# In [None]:
import os
import cv2
import time
import random
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
import torchvision.transforms.functional as TF
import torchvision.transforms as transformtransforms
from tqdm import tqdm
from torchsummary import summary
from torchvision import transforms
from torchvision.transforms import ToPILImage
from torch.utils.data import Dataset, DataLoader

# ---- Hyper-parameters ------------------------------------------------------
NUM_EPOCHS = 20
BATCH_SIZE = 50
WORKERS = 0

RANDOM_SEED = 2099

IMG_SIZE = 64
LEARNING_RATE = 0.0002
NOISE_DIM = 100

start_time = time.time()

# Apply RANDOM_SEED to every RNG this notebook draws from: Python `random`,
# NumPy, torch CPU and all CUDA devices. It was previously defined but never
# used, so weight init, shuffling and noise were different on every run.
# (Bit-exact GPU reproducibility additionally requires deterministic cuDNN.)
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
torch.cuda.manual_seed_all(RANDOM_SEED)

# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort seen on some conda installs -- confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
# A bare Python variable has no effect: torch reads TORCH_CUDA_ARCH_LIST from
# the environment (only when compiling C++/CUDA extensions), so export it.
TORCH_CUDA_ARCH_LIST = "8.6"
os.environ["TORCH_CUDA_ARCH_LIST"] = TORCH_CUDA_ARCH_LIST

# ---- Device report ---------------------------------------------------------
print('torch.version: ', torch.__version__)
print('torch.version.cuda: ', torch.version.cuda)
print('torch.cuda.is_available: ', torch.cuda.is_available())
print('torch.cuda.device_count: ', torch.cuda.device_count())
# This notebook requires a CUDA device: current_device() fails without one.
current_device = torch.cuda.current_device()
print('torch.cuda.get_device_name: ', torch.cuda.get_device_name(current_device))
compute_device = torch.device("cuda")

# ---- Data ------------------------------------------------------------------
# Resize + centre-crop to IMG_SIZE, then map pixel values from [0, 1] to
# [-1, 1] so real images share the range of the generator's Tanh output.
transform_pipeline = transforms.Compose([transforms.Resize(IMG_SIZE),
                                         transforms.CenterCrop(IMG_SIZE),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5), 
                                                              (0.5, 0.5, 0.5)),])

data_set = torchvision.datasets.CIFAR10(root="./data/", 
                                        download=True,
                                        transform=transform_pipeline)

data_loader = DataLoader(data_set, 
                         batch_size=BATCH_SIZE,
                         shuffle=True, 
                         num_workers=WORKERS)

print(data_set.classes)
print(data_set.data.shape)

class ImageGenerator(nn.Module):
    """DCGAN generator: maps a noise tensor to a 3-channel image in (-1, 1).

    With this layer stack an input of shape (N, input_noise, 1, 1) is
    upsampled 1 -> 4 -> 8 -> 16 -> 32 -> 64, producing (N, 3, 64, 64).
    The final Tanh bounds every output value to (-1, 1).

    Args:
        input_noise: number of channels of the noise input (latent size).
    """

    def __init__(self, input_noise=NOISE_DIM):
        super(ImageGenerator, self).__init__()
        self.model = nn.Sequential(
            # (N, input_noise, 1, 1) -> (N, 512, 4, 4)
            nn.ConvTranspose2d(input_noise, 512, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # -> (N, 256, 8, 8)
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # -> (N, 128, 16, 16)
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # -> (N, 64, 32, 32)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # -> (N, 3, 64, 64)
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def forward(self, x):
        """Run the noise batch ``x`` through the generator stack."""
        return self.model(x)

    def init_weights(self, mean_w=0, std_w=0.02, mean_b=1, std_b=0.02):
        """Re-initialise parameters in place (DCGAN-style).

        * ConvTranspose2d weights ~ N(mean_w, std_w).
        * BatchNorm2d scale (``weight``) ~ N(mean_b, std_b).
        * Any existing bias is set to 0. The transposed convs here are built
          with ``bias=False`` (their ``bias`` is None) and are skipped for
          that step instead of crashing.
        """
        for module in self.modules():
            if isinstance(module, nn.ConvTranspose2d):
                nn.init.normal_(module.weight, mean_w, std_w)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, mean_b, std_b)
            else:
                continue
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

# Build the generator, move it to the compute device, and print a per-layer
# summary for a single (NOISE_DIM, 1, 1) noise input.
generator = ImageGenerator()
generator = generator.to(compute_device)
summary(generator, input_size=(NOISE_DIM, 1, 1))

class ImageDiscriminator(nn.Module):
    """DCGAN discriminator: scores 3x64x64 images with a probability of "real".

    An input of shape (N, 3, 64, 64) is downsampled 64 -> 32 -> 16 -> 8 -> 4
    and a final 4x4 convolution collapses it to (N, 1, 1, 1). The Sigmoid maps
    the score into (0, 1), which is what ``nn.BCELoss`` expects.
    """

    def __init__(self):
        super(ImageDiscriminator, self).__init__()
        self.model = nn.Sequential(
            # (N, 3, 64, 64) -> (N, 64, 32, 32); no BatchNorm on the first layer
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 128, 16, 16)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 256, 8, 8)
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 512, 4, 4)
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # -> (N, 1, 1, 1)
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return the per-image probability tensor of shape (N, 1, 1, 1)."""
        return self.model(x)

    def init_weights(self, mean_w=0, std_w=0.02, mean_b=1, std_b=0.02):
        """Re-initialise parameters in place (DCGAN-style).

        * Conv2d weights ~ N(mean_w, std_w).
        * BatchNorm2d scale (``weight``) ~ N(mean_b, std_b).
        * Any existing bias is set to 0. The convs here are built with
          ``bias=False`` (their ``bias`` is None) and are skipped for that
          step instead of crashing.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, mean_w, std_w)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.normal_(module.weight, mean_b, std_b)
            else:
                continue
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

# Build the discriminator, move it to the compute device, and print a
# per-layer summary for one 3x64x64 image.
discriminator = ImageDiscriminator()
discriminator = discriminator.to(compute_device)
summary(discriminator, input_size=(3, 64, 64))

# BCE targets: 1.0 for real images, 0.0 for generated ones.
real_mark = 1.0
fake_mark = 0.0

# Binary cross-entropy on the discriminator's sigmoid output; both networks
# use Adam with beta1 = 0.5 and the shared LEARNING_RATE.
loss_function = nn.BCELoss()
optim_gen = optim.Adam(generator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
optim_disc = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))

# A batch of latent vectors sampled once on the CPU, then moved to the device.
noise_input = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(compute_device)

print('start training...')
for epoch in range(NUM_EPOCHS):
    # Preview/log about 50 times per epoch; max(1, ...) avoids a
    # ZeroDivisionError when the loader yields fewer than 50 batches.
    log_every = max(1, len(data_loader) // 50)
    for i, batch_data in enumerate(data_loader):
        # ---- (1) Discriminator step on one real and one fake batch ----------
        discriminator.zero_grad()
        real_images = batch_data[0].to(compute_device)
        # Size labels/noise from the actual batch: the final batch is smaller
        # than BATCH_SIZE whenever len(data_set) % BATCH_SIZE != 0.
        cur_batch = real_images.size(0)
        real_labels = torch.full((cur_batch, 1, 1, 1), real_mark, device=compute_device)
        real_output = discriminator(real_images)
        loss_real = loss_function(real_output, real_labels)
        loss_real.backward()
        real_score = real_output.mean().item()

        noise = torch.randn(cur_batch, NOISE_DIM, 1, 1, device=compute_device)
        fake_images = generator(noise)
        fake_labels = torch.full((cur_batch, 1, 1, 1), fake_mark, device=compute_device)
        # detach(): the discriminator update must not backprop into G.
        fake_output = discriminator(fake_images.detach())
        loss_fake = loss_function(fake_output, fake_labels)
        loss_fake.backward()
        fake_score = fake_output.mean().item()
        disc_loss = loss_real + loss_fake
        optim_disc.step()

        # ---- (2) Generator step: push D to label the fakes as real ----------
        generator.zero_grad()
        fake_labels.fill_(real_mark)
        fake_output = discriminator(fake_images)
        gen_loss = loss_function(fake_output, fake_labels)
        gen_loss.backward()
        optim_gen.step()

        if i % log_every == 0:
            discriminator.eval()
            generator.eval()
            with torch.no_grad():
                # Reuse the fixed noise batch so successive previews show the
                # same latent points and progress is directly comparable.
                test_images = generator(noise_input[:10]).cpu()
            # Tanh output is in (-1, 1); map to [0, 1], tile 5 per row, and
            # convert CHW -> HWC for matplotlib.
            test_images_grid = torchvision.utils.make_grid(
                (test_images + 1) / 2, nrow=5).permute(1, 2, 0).clamp(0, 1).numpy()
            plt.figure(figsize=(10, 10))
            plt.imshow(test_images_grid)
            plt.axis('off')
            plt.title(f'Epoch: {epoch} Iteration: {i}', fontsize=16)
            plt.show()
            print(f'[{epoch}/{NUM_EPOCHS}][{i}/{len(data_loader)}] Loss_D: {disc_loss.item():.4f} Loss_G: {gen_loss.item():.4f} D(x): {real_score:.4f} D(G(z)): {fake_score:.4f}')
            discriminator.train()
            generator.train()

end_time = time.time()
print('Done! Total time cost: ', end_time - start_time)

def normalize_image(image):
    """Map an HWC image with Tanh-range values ([-1, 1]) to [0, 1] for display."""
    return np.clip((image + 1.0) / 2.0, 0.0, 1.0)


def test_generator(num_images=10, threshold=0.9, max_attempts=10000):
    """Sample generated images that the discriminator scores above ``threshold``.

    One image is generated per attempt and kept only if D(G(z)) > threshold.
    The search is bounded by ``max_attempts``. An unbounded loop would never
    terminate if the discriminator rarely rates fakes that highly.

    Args:
        num_images: how many accepted images to collect.
        threshold: minimum discriminator score for an image to be kept.
        max_attempts: upper bound on generated samples before giving up.

    Returns:
        A list of HWC numpy arrays in [0, 1]. It may contain fewer than
        ``num_images`` entries if ``max_attempts`` is exhausted, in which case
        a warning is printed.
    """
    discriminator.eval()
    generator.eval()
    top_images = []
    # Inference only: no autograd graph is needed for sampling/scoring.
    with torch.no_grad():
        for _ in range(max_attempts):
            if len(top_images) >= num_images:
                break
            test_noise = torch.randn(1, NOISE_DIM, 1, 1, device=compute_device)
            test_image = generator(test_noise)
            image_score = discriminator(test_image).item()
            if image_score > threshold:
                print(image_score)
                # (1, C, H, W) -> (H, W, C) for plotting / saving.
                formatted_image = test_image.permute(0, 2, 3, 1).cpu().numpy()[0]
                top_images.append(normalize_image(formatted_image))
    if len(top_images) < num_images:
        print(f'warning: only {len(top_images)}/{num_images} images scored above '
              f'{threshold} after {max_attempts} attempts')
    return top_images

# Collect ten high-scoring generated samples for the final image wall.
top_test_images = test_generator(num_images=10)

def display_images(images):
    """Tile images into a wall of 5 per row.

    Each row is built by concatenating 5 consecutive images side by side
    (``np.hstack``). The rows are then stacked top to bottom (``np.vstack``).
    Only complete rows are used, so trailing images beyond a multiple of 5
    are dropped. With fewer than 5 images an empty array is returned.
    """
    per_row = 5
    full_rows = len(images) // per_row
    if full_rows == 0:
        return np.array([])

    strips = [
        np.hstack(tuple(images[r * per_row:(r + 1) * per_row]))
        for r in range(full_rows)
    ]
    wall = strips[0]
    for strip in strips[1:]:
        wall = np.vstack((wall, strip))
    return wall

final_image_display = display_images(top_test_images)

# Show and save the wall with matplotlib (float RGB input is expected in [0, 1]).
plt.figure(figsize=(10, 10))
plt.imshow(final_image_display)
plt.axis('off')
plt.title(f'DCGAN Epochs:{NUM_EPOCHS} Batch:{BATCH_SIZE} D/G: 1/1', fontsize=20)
plt.savefig('DCGAN_image_wall_plot.png')
plt.show()

# cv2.imwrite stores PNGs as 8/16-bit. Other depths are converted to uint8
# *without* rescaling, so a float wall in [0, 1] would be saved almost black.
# Scale float data explicitly to 0..255 first.
wall_to_save = final_image_display
if np.issubdtype(wall_to_save.dtype, np.floating):
    wall_to_save = (np.clip(wall_to_save, 0.0, 1.0) * 255.0).round().astype(np.uint8)
# The array is RGB (as shown by matplotlib); OpenCV writes BGR.
cv2.imwrite('DCGAN_image_wall.png', cv2.cvtColor(wall_to_save, cv2.COLOR_RGB2BGR))
