# In [None]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms.functional as TF
import torchvision.transforms as transformtransforms
from tqdm import tqdm
from torchsummary import summary
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torch.utils.data import DataLoader

# --- Runtime environment ---------------------------------------------------
# Sets KMP_DUPLICATE_LIB_OK for this process (presumably to tolerate a
# duplicated OpenMP runtime - confirm it is still required).
os.environ.update(KMP_DUPLICATE_LIB_OK="TRUE")
# NOTE(review): this is only a Python variable, not an environment variable,
# and nothing visible in this file reads it.
TORCH_CUDA_ARCH_LIST = "8.6"

# --- Training hyper-parameters --------------------------------------------
EPOCH_COUNT = 20       # passes over the dataset
BATCH_SIZE = 20        # samples per mini-batch
WORKERS = 0            # DataLoader worker processes
SEED_VAL = 2099        # seed value (not applied anywhere in the visible code)
IMG_SIZE = 64          # side length images are resized / cropped to
LEARNING_RATE = 2e-4   # Adam step size for both networks
NOISE_DIM = 10         # channels of the generator's latent input

# Wall-clock reference taken before any setup work.
time_start_all = time.time()

# Report the PyTorch / CUDA environment.
print('torch.version: ', torch.__version__)
print('torch.version.cuda: ', torch.version.cuda)
print('torch.cuda.is_available: ', torch.cuda.is_available())
print('torch.cuda.device_count: ', torch.cuda.device_count())

# The per-device queries below raise when no CUDA device is usable, so they
# are only issued when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    current_device = torch.cuda.current_device()
    print('torch.cuda.current_device: ', current_device)
    print('torch.cuda.get_device_name: ', torch.cuda.get_device_name(current_device))
    computing_device = torch.device("cuda")
else:
    current_device = None
    print('CUDA is not available; falling back to CPU.')
    computing_device = torch.device("cpu")

# Per-image preprocessing, applied in list order: resize, centre-crop to a
# square of side IMG_SIZE, convert to a tensor, then normalise each of the
# three channels with mean 0.5 and std 0.5.
_preprocessing_steps = [
    transforms.Resize(IMG_SIZE),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]
transform = transforms.Compose(_preprocessing_steps)

# CIFAR-10 is downloaded into ./data/ when it is not already there.
dataset = torchvision.datasets.CIFAR10(
    root="./data/",
    download=True,
    transform=transform,
)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
)

print(dataset.classes)
print(dataset.data.shape)

# Number of classes; used as the width of the discriminator's class head.
NUM_CLASSES = len(dataset.classes)

class Generator(nn.Module):
    """Transposed-convolution generator mapping a latent tensor to an image.

    The network is a stack of ``ConvTranspose2d -> BatchNorm2d -> ReLU``
    stages followed by a final ``ConvTranspose2d -> Tanh``.  For an input of
    shape ``(N, noise, 1, 1)`` the first stage produces a 4x4 map and each of
    the four following stride-2 stages doubles the spatial size, giving a
    3-channel 64x64 output bounded by the final ``Tanh``.

    Args:
        noise: Number of channels of the latent input.
    """

    def __init__(self, noise=NOISE_DIM):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each normalised stage.
        stage_specs = [
            (noise, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        ]
        layers = []
        for in_ch, out_ch, stride, pad in stage_specs:
            layers.append(nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4,
                                             stride=stride, padding=pad, bias=False))
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(True))
        # Output stage: 3 channels, no batch norm, squashed by Tanh.
        layers.append(nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2,
                                         padding=1, bias=False))
        layers.append(nn.Tanh())
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent tensor ``input`` through the layer stack."""
        return self.main(input)

    def initialize_weights(self, w_mean=0, w_std=0.02, b_mean=1, b_std=0.02):
        """Re-initialise convolution and batch-norm parameters in place.

        Modules are matched by class name: names containing ``'Conv'`` get
        weights drawn from N(w_mean, w_std); names containing ``'BatchNorm'``
        get weights drawn from N(b_mean, b_std) and a zero bias.
        """
        for module in self.modules():
            kind = type(module).__name__
            if 'Conv' in kind:
                nn.init.normal_(module.weight.data, w_mean, w_std)
            elif 'BatchNorm' in kind:
                nn.init.normal_(module.weight.data, b_mean, b_std)
                nn.init.constant_(module.bias.data, 0)

# Use the GPU when one is available, otherwise fall back to the CPU.
computing_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
gen = Generator().to(computing_device)
# The summary input follows NOISE_DIM (the Generator's default latent width)
# rather than a hard-coded 10, so the two cannot drift apart.
summary(gen, input_size=(NOISE_DIM, 1, 1))

class Discriminator(nn.Module):
    """Convolutional discriminator with a real/fake head and a class head.

    Four stride-2 convolutions followed by a 4x4 convolution without padding
    reduce a ``(N, 3, 64, 64)`` input to ``(N, 64, 1, 1)``, which is flattened
    to ``(N, 64)`` and fed to two linear heads.  With any other input size the
    ``view(-1, 64)`` in ``forward`` would fold spatial positions into the
    batch dimension.

    Args:
        num_class: Number of outputs of the classification head.
    """

    def __init__(self, num_class=NUM_CLASSES):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 64, kernel_size=4, stride=1, padding=0, bias=False),
        )
        self.fc_discrimination = nn.Linear(64, 1)
        self.fc_classification = nn.Linear(64, num_class)
        # The features are flattened to (N, 64), so the class axis is dim 1.
        # Naming it explicitly avoids the deprecated implicit-dim behaviour
        # (and the warning it emits on every forward pass).
        self.softmax = nn.Softmax(dim=1)
        self.sigmoid = nn.Sigmoid()
        self.num_class = num_class

    def forward(self, input):
        """Score a batch of images.

        Returns:
            A tuple ``(score_discrimination, score_classification)``: the
            sigmoid of the discrimination head, shape ``(N, 1)``, and the
            softmax of the classification head, shape ``(N, num_class)``.
            Both are probabilities, not logits or log-probabilities.
        """
        x = self.main(input)
        x = x.view(-1, 64)
        x_discrimination = self.fc_discrimination(x)
        x_classification = self.fc_classification(x)
        score_discrimination = self.sigmoid(x_discrimination)
        score_classification = self.softmax(x_classification)
        return score_discrimination, score_classification

    def initialize_weights(self, w_mean=0, w_std=0.02, b_mean=1, b_std=0.02):
        """Re-initialise convolution and batch-norm parameters in place.

        Modules are matched by class name: names containing ``'Conv'`` get
        weights drawn from N(w_mean, w_std); names containing ``'BatchNorm'``
        get weights drawn from N(b_mean, b_std) and a zero bias.  The linear
        heads are left untouched.
        """
        for m in self.modules():
            classname = m.__class__.__name__
            if classname.find('Conv') != -1:
                nn.init.normal_(m.weight.data, w_mean, w_std)
            elif classname.find('BatchNorm') != -1:
                nn.init.normal_(m.weight.data, b_mean, b_std)
                nn.init.constant_(m.bias.data, 0)

# Use the GPU when one is available, otherwise fall back to the CPU.
computing_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
disc = Discriminator().to(computing_device)
# The summary input follows IMG_SIZE, the size the data pipeline produces,
# instead of repeating the literal 64.
summary(disc, input_size=(3, IMG_SIZE, IMG_SIZE))

def normalize(data):
    """Linearly rescale ``data`` to the range 0..255 and cast to ``uint8``.

    The minimum of ``data`` maps to 0 and the maximum to 255; fractional
    results are truncated by the integer cast.

    Args:
        data: Array-like of numbers.

    Returns:
        A ``uint8`` array with the same shape as ``data``.  Empty input is
        returned as an empty ``uint8`` array, and input whose values are all
        equal yields all zeros (instead of dividing by zero).
    """
    data = np.asarray(data)
    if data.size == 0:
        return data.astype(np.uint8)

    low = np.min(data)
    span = np.max(data) - low
    if span == 0:
        # Constant input: there is no range to stretch.
        return np.zeros(data.shape, dtype=np.uint8)

    return ((data - low) / span * 255).astype(np.uint8)

def create_image_wall(images):
    """Build a 2x5 mosaic from the first ten images plus one random sample.

    Args:
        images: Indexable batch with at least ten entries; each entry is a
            3-dimensional tensor whose axes are reordered with
            ``permute(1, 2, 0)`` (assumed channel-first - confirm).

    Returns:
        ``(wall, single)``: ``wall`` stacks images 0-4 side by side above
        images 5-9, and ``single`` is one image drawn at random from the
        first ten.  Each is passed through ``normalize`` separately, so they
        are rescaled independently of one another.
    """
    def _as_hwc(tensor):
        # Detach from autograd, move the first axis last, copy to host memory.
        return tensor.detach().permute(1, 2, 0).cpu().numpy()

    single = _as_hwc(images[random.randint(0, 9)])

    top_strip = np.hstack([_as_hwc(images[k]) for k in range(0, 5)])
    bottom_strip = np.hstack([_as_hwc(images[k]) for k in range(5, 10)])

    wall = normalize(np.vstack((top_strip, bottom_strip)))
    return wall, normalize(single)

# Real/fake loss applied to the discriminator's sigmoid score.
ds_loss = nn.BCELoss()

_nll_loss = nn.NLLLoss()


def class_loss(class_probs, target):
    """Negative log-likelihood of ``target`` under softmax probabilities.

    ``nn.NLLLoss`` expects log-probabilities, while the discriminator's class
    head returns softmax probabilities, so the logarithm is taken here.  The
    probabilities are clamped away from zero to keep the log finite.

    Args:
        class_probs: Tensor of per-class probabilities, shape (N, classes).
        target: Tensor of integer class indices, shape (N,).

    Returns:
        Scalar loss tensor.
    """
    return _nll_loss(torch.log(torch.clamp(class_probs, min=1e-8)), target)


# Adam settings shared by both optimisers.
beta1 = 0.5
beta2 = 0.999
optimizer_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))
optimizer_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(beta1, beta2))

# Placeholders only: both names are rebound inside the training loop before
# they are read, so the (uninitialised) contents are never used.
disc_label = torch.Tensor(BATCH_SIZE)
class_label = torch.Tensor(BATCH_SIZE)

# Number of discriminator / generator updates performed per mini-batch.
disc_training_period = 1
gen_training_period = 1
# Per-iteration loss values.
disc_loss_history = []
gen_loss_history = []
# Sample mosaics (images_history) and single samples (frames_history)
# collected at each logging step.
images_history = []
frames_history = []

print('start training...')
print('Epochs: %d | Batch Size: %d | Learning Rate: %.4f' % (EPOCH_COUNT, BATCH_SIZE, LEARNING_RATE))
gen.train(True)
disc.train(True)
time_start_training = time.time()

# Main training loop.  For every mini-batch the discriminator is updated on a
# real batch and a generated batch, then the generator is updated against the
# discriminator.  Roughly 25 times per epoch a sample mosaic is shown and the
# losses are printed; after every epoch the loss curves and both models'
# weights are written to disk.

# torch.save does not create missing directories.
os.makedirs('./weights', exist_ok=True)
# At least 1 so the modulo below cannot divide by zero on short loaders.
log_interval = max(1, len(dataloader) // 25)

for epoch in range(EPOCH_COUNT):
    for i, data in enumerate(dataloader):
        real_data, labels = data
        real_data = real_data.to(computing_device)
        labels = labels.to(computing_device)
        # Size targets and noise from the actual batch so a smaller final
        # batch does not cause a shape mismatch in the losses.
        batch_size = real_data.size(0)
        real_target = torch.full((batch_size, 1), 1., device=computing_device)
        fake_target = torch.full((batch_size, 1), 0., device=computing_device)
        # NOTE(review): generated images are scored against the class labels
        # of the real batch although the generator receives only noise -
        # confirm this is the intended objective.
        class_label = labels

        for disc_iter in range(disc_training_period):
            optimizer_disc.zero_grad()

            # Real batch: target score 1.
            disc_score, disc_class = disc(real_data)
            loss_disc_score = ds_loss(disc_score, real_target)
            loss_disc_class = class_loss(disc_class, class_label)
            disc_loss_real = loss_disc_score + loss_disc_class

            # Generated batch: target score 0.  Detached so this update does
            # not back-propagate through the generator.
            noise = torch.randn(batch_size, NOISE_DIM, 1, 1, device=computing_device)
            fake_data = gen(noise).detach()

            disc_score, disc_class = disc(fake_data)
            loss_disc_score = ds_loss(disc_score, fake_target)
            loss_disc_class = class_loss(disc_class, class_label)
            disc_loss_fake = loss_disc_score + loss_disc_class

            disc_loss = disc_loss_real + disc_loss_fake

            disc_loss.backward()
            optimizer_disc.step()

        for gen_iter in range(gen_training_period):
            optimizer_gen.zero_grad()

            noise = torch.randn(batch_size, NOISE_DIM, 1, 1, device=computing_device)
            fake_data = gen(noise)

            # The generator is rewarded when its samples are scored as real.
            gen_score, gen_class = disc(fake_data)
            loss_gen_score = ds_loss(gen_score, real_target)
            loss_gen_class = class_loss(gen_class, class_label)

            gen_loss = loss_gen_score + loss_gen_class

            gen_loss.backward()
            optimizer_gen.step()

        disc_loss_history.append(disc_loss.item())
        gen_loss_history.append(gen_loss.item())

        # create_image_wall reads the first ten images of a batch, so skip the
        # preview for a batch that holds fewer than ten.
        if i % log_interval == 0 and batch_size >= 10:
            disc.eval()
            gen.eval()

            # Sampling for display only; no gradients are needed.
            with torch.no_grad():
                noise = torch.randn(10, NOISE_DIM, 1, 1, device=computing_device)
                fake_data = gen(noise)
            fake_images, gen_image = create_image_wall(fake_data)
            real_images, _ = create_image_wall(real_data)

            # Real samples on top, generated samples below.
            frames = np.vstack((real_images, fake_images))

            plt.figure(figsize=(10, 10))
            plt.imshow(frames)
            plt.axis('off')
            plt.title('Epoch: %d niter: %d' % (epoch, i), fontsize=16)
            plt.show()
            # Release the figure so previews do not accumulate in memory.
            plt.close()

            frames_history.append(gen_image)
            images_history.append(fake_images)

            # Time elapsed since the previous logging step.
            time_end_training = time.time()
            training_duration = time_end_training - time_start_training
            time_start_training = time.time()

            print('[%d/%d][%d/%d] | Loss_D: %.4f | Loss_G: %.4f | Time: %.4f' % (
                epoch + 1, EPOCH_COUNT, i, len(dataloader), disc_loss.item(), gen_loss.item(), training_duration))

            disc.train()
            gen.train()

    # End of epoch: plot the full loss history and checkpoint both networks.
    plt.figure(figsize=(20, 10))
    plt.plot(gen_loss_history, label='Gen')
    plt.plot(disc_loss_history, label='Disc')
    plt.xlabel('iter', fontsize=20)
    plt.ylabel('loss', fontsize=20)
    plt.title('Training', fontsize=20)
    plt.legend(fontsize=20)
    plt.savefig('ACGAN_loss.png')
    plt.show()
    plt.close()
    torch.save(gen.state_dict(), './weights/ACGAN_Gen_test.pth')
    torch.save(disc.state_dict(), './weights/ACGAN_Disc_test.pth')


def display_images(frame_list, groups=None):
    """Pick one random frame per group and tile the picks into a grid.

    ``frame_list`` is split into consecutive runs of ``len(frame_list) //
    groups`` frames (at least one frame per run); one frame is drawn at
    random from each run.  The picks are laid out five per row, top to
    bottom, and picks that do not fill a complete row are dropped.

    Args:
        frame_list: Sequence of equally sized image arrays.
        groups: Number of runs to split the frames into.  Defaults to
            ``EPOCH_COUNT``.

    Returns:
        The tiled image.  An empty array is returned for an empty
        ``frame_list``; when fewer than five frames are picked they are
        returned as a single short row instead of an empty array.
    """
    count = len(frame_list)
    if count == 0:
        return np.array([])

    if groups is None:
        groups = EPOCH_COUNT
    # At least 1: a zero step would make range() raise when there are fewer
    # frames than groups.
    interval = max(1, count // max(1, groups))

    new_frame_list = []
    for index in range(0, count, interval):
        # Clamp the upper bound so a shorter final run cannot index past the
        # end of the list.
        last = min(index + interval, count) - 1
        new_frame_list.append(frame_list[random.randint(index, last)])

    columns = 5
    rows = len(new_frame_list) // columns
    if rows == 0:
        return np.hstack(new_frame_list)

    frame_rows = [
        np.hstack(new_frame_list[row * columns:(row + 1) * columns])
        for row in range(rows)
    ]
    return np.vstack(frame_rows)

# Tile a random selection of the generated samples collected during training.
final_frames = display_images(frames_history)

plt.figure(figsize=(10, 10))
plt.imshow(final_frames)
plt.axis('off')
plt.title('ACGAN Epochs:%d Batch:%d D/G: %d/%d' % (epoch + 1, BATCH_SIZE, disc_training_period, gen_training_period), fontsize=20)
# No legend: the figure holds a single image and no labelled artists.
plt.savefig('image_wall_plot.png')
plt.show()

# cv2.imwrite expects BGR channel order; the frames are displayed as RGB
# above, so swap the channels before writing.
cv2.imwrite('image_wall.png', cv2.cvtColor(final_frames, cv2.COLOR_RGB2BGR))

def test_model(image_count, score_threshold=0.001, max_attempts=None):
    """Sample the generator until enough images pass the score filter.

    One image is generated per attempt and scored by the discriminator.
    Every score is recorded; an image is kept when its score is below
    ``score_threshold``.

    NOTE(review): the training loop uses target 1 for real and 0 for
    generated images, so a score below the threshold marks samples the
    discriminator rates as generated - confirm this is the intended filter
    for "best" images.

    Args:
        image_count: Number of images to collect.
        score_threshold: Upper bound (exclusive) on the discrimination score
            of a kept image.
        max_attempts: Optional cap on the number of generated samples.  With
            the default ``None`` the loop runs until ``image_count`` images
            have been kept, which may never happen.

    Returns:
        ``(best_images, score_list)``: the kept images after ``normalize``,
        and the discrimination score of every attempt.
    """
    disc.eval()
    gen.eval()

    best_images = []
    score_list = []
    attempts = 0

    # Inference only: gradients are not needed.
    with torch.no_grad():
        while len(best_images) < image_count:
            if max_attempts is not None and attempts >= max_attempts:
                break
            attempts += 1

            noise = torch.randn(1, NOISE_DIM, 1, 1, device=computing_device)
            test_image = gen(noise)
            score, _ = disc(test_image)
            score = score.detach().cpu().numpy()[0][0]
            score_list.append(score)
            if score < score_threshold:
                print(score)
                # Reorder to channel-last for display and take the only sample.
                image = test_image.detach().permute(0, 2, 3, 1).cpu().numpy()[0]
                best_images.append(normalize(image))

    return best_images, score_list

# Collect ten samples that pass the score filter, then plot the
# discrimination score of every attempt made while collecting them.
best_images, score_list = test_model(image_count=10)
plt.plot(score_list)

def show_final_images(image_list):
    """Tile images into a grid that is five images wide.

    Images are placed left to right, top to bottom.  Only complete rows are
    used: trailing images that do not fill a row are dropped, and an empty
    array is returned when there are fewer than five images.

    Args:
        image_list: Indexable sequence of equally sized image arrays.

    Returns:
        The tiled image array, or an empty array when no row can be filled.
    """
    per_row = 5
    full_rows = len(image_list) // per_row
    if full_rows == 0:
        return np.array([])

    strips = []
    for row in range(full_rows):
        first = row * per_row
        tiles = [image_list[first + offset] for offset in range(per_row)]
        strips.append(np.hstack(tiles))

    return np.vstack(strips)

# Tile the images selected by test_model and save the result.
final_frames = show_final_images(best_images)
plt.figure()
plt.imshow(final_frames)
plt.axis('off')
# Report the configured update periods rather than a hard-coded 1/1.
plt.title('ACGAN Best Epochs:%d Batch:%d D/G: %d/%d' % (epoch + 1, BATCH_SIZE, disc_training_period, gen_training_period), fontsize=20)
# No legend: the figure holds a single image and no labelled artists.
plt.savefig('ACGAN_image_wall_plot.png')
plt.show()

# cv2.imwrite expects BGR channel order; the frames are displayed as RGB
# above, so swap the channels before writing.
cv2.imwrite('ACGAN_image_wall.png', cv2.cvtColor(final_frames, cv2.COLOR_RGB2BGR))
