In [None]:
# ***************** Import the necessary libraries *****************

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
import torch.optim as optim
from torchvision.utils import save_image
from torch.utils.data import ConcatDataset
from torchvision import transforms

from PIL import Image
import pandas as pd
import numpy as np
import sys
import os
import random
import matplotlib.pyplot as plt

# Mount the google drive
# (Colab-only: the project code and the dataset are read from paths under the mounted drive.)
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Put the coursework folder first on the module search path so the project modules imported
# right below (Generator_class, Discriminator_class, CustomDataset_class, Utils) can be found.
# NOTE: this path is relative, so it is resolved against the current working directory.
sys.path.insert(0, 'drive/MyDrive/CI642_Coursework')

from Generator_class import Generator
from Discriminator_class import Discriminator
from CustomDataset_class import CustomDataset
from Utils import plot_images, initialise_weights, test_implementation, calculate_CID_indices

In [None]:
# ***************** Initialise the hyperparameters *****************

# Seed every random number generator the notebook relies on (Python, NumPy, PyTorch) so that
# weight initialisation, data shuffling and noise sampling are repeatable after
# "Restart kernel -> Run all". torch.manual_seed also seeds the CUDA generators.
seed = 42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Train on the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Optimisation settings
learning_rate = 1e-3   # shared by the generator and discriminator Adam optimisers
batch_size = 128       # batch size of the training dataloader
epochs = 50            # number of passes over the training data

# Data / model dimensions
num_classes = 5        # number of image classes the GAN is conditioned on
img_size = 64          # images are resized to img_size x img_size
img_channels = 3       # colour channels per image
z_dim = 100            # length of the noise vector fed to the generator
discriminator_num_filters = 128  # feature-map width passed to the Discriminator
generator_num_filters = 128      # feature-map width passed to the Generator

# Root folder holding one sub-folder of images per class
data_folder = "drive/MyDrive/CI642_Coursework/Dataset"

In [None]:
# ***************** Pre-processing and loading the dataset *****************

# Resize the images, convert them to tensors, and normalise the image channels to have 0.5 mean and standard deviation.
transform = transforms.Compose(
    [
      transforms.Resize((img_size, img_size)),
      transforms.ToTensor(),
      transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])


# One CustomDataset per class folder; they are concatenated into a single dataset at the end
training_data = []

# Maps class index -> class name, both parsed from the folder names
classes = {}

# Disable the PIL image limit
Image.MAX_IMAGE_PIXELS = None

# Every folder in data_folder contains images of a specific class and is named with the format ClassIndex_ClassName
# (sorted so the class order does not depend on the file system's listing order)
for img_class in sorted(os.listdir(data_folder)):

  class_dir = os.path.join(data_folder, img_class)

  # Split on the first underscore only, so class names that themselves contain "_" stay intact
  index_text, separator, class_name = img_class.partition("_")
  try:
    class_index = int(index_text)
  except ValueError:
    class_index = None

  # Ignore stray entries (plain files, folders not named ClassIndex_ClassName) instead of crashing on them
  if class_index is None or not separator or not os.path.isdir(class_dir):
    print(f"Skipping {class_dir}: not a ClassIndex_ClassName folder")
    continue

  classes[class_index] = class_name

  # Collect the paths of the images that survive the check below.
  # The list is built AFTER the check so that it never contains a file that was just deleted.
  imgs_paths = []

  for img_name in os.listdir(class_dir):
    img = os.path.join(class_dir, img_name)

    # Close the file handle before (possibly) removing the file
    with Image.open(img) as image:
      img_shape = np.array(image).shape

    # Delete images whose pixel array is not 3-dimensional (height, width, channels), e.g. greyscale images.
    # WARNING: this permanently removes the file from the dataset folder.
    # NOTE(review): images with 2 or 4 channels still pass this check - confirm CustomDataset handles them.
    if len(img_shape) != 3:
      print(img, img_shape)
      os.remove(img)
      continue

    imgs_paths.append(img)

  print(f"************ There are {len(imgs_paths)} {class_name} images")

  # Put these images into a CustomDataset object
  dataset = CustomDataset(imgs_paths, class_index, class_name, transform=transform)

  training_data.append(dataset)

# Create a training dataloader to iterate over the training data
train_dataloader = torch.utils.data.DataLoader(ConcatDataset(training_data), batch_size=batch_size, shuffle=True)

In [None]:
# ***************** Show 10 sample images from the training dataset *****************

# Take the first batch the dataloader yields; each batch is (images, class indices, class names)
dataiter = iter(train_dataloader)
images, label_indices, label_names = next(dataiter)

# Display only the first ten images of that batch together with their class names
preview_images, preview_names = images[:10], label_names[:10]
plot_images(preview_images, preview_names)

In [None]:
# ***************** Test the architecture implementation before starting the training *****************

# Drive the check from the configured hyperparameters rather than repeated literals, so it always
# exercises the same channel count / class count that the real models are built with.
# One sample is requested per class, with labels 0 .. num_classes - 1.
test_implementation(
    num_samples = num_classes,
    img_channels = img_channels,
    img_size = img_size,
    num_classes = num_classes,
    noise_dim = z_dim,
    labels = torch.arange(num_classes),
    disc_num_feature_maps = discriminator_num_filters,
    gen_num_feature_maps = generator_num_filters,
)

In [None]:
# ***************** Initialise the generator and discriminator objects, and their optimisers *****************

# Generator: build it on the target device, apply the project's weight initialisation, print its layers
generator = Generator(z_dim, img_channels, generator_num_filters, num_classes).to(device)
initialise_weights(generator)
print(generator)

# Discriminator: same three steps
discriminator = Discriminator(img_channels, discriminator_num_filters, num_classes, img_size).to(device)
initialise_weights(discriminator)
print(discriminator)

# Both networks use Adam with identical settings
adam_settings = dict(lr=learning_rate, betas=(0.5, 0.999))
generator_optimiser = optim.Adam(generator.parameters(), **adam_settings)
discriminator_optimiser = optim.Adam(discriminator.parameters(), **adam_settings)

# The training loss function (Binary Cross Entropy)
criterion = nn.BCELoss()

In [None]:
# ***************** Train the conditional DCGAN *****************

# Put the generator and discriminator in training modes
generator.train()
discriminator.train()

# One entry per epoch: the loss summed over all batches of that epoch
generator_losses = []
discriminator_losses = []


for epoch in range(epochs):

  total_disc_loss = 0
  total_gen_loss = 0

  # Each batch is (images, class indices, class names); the names are not needed for training
  for real_imgs, label_indices, _ in train_dataloader:

    real_imgs = real_imgs.to(device)
    label_indices = label_indices.to(device)

    # Generate a batch of images from the random noise (sampled directly on the target device)
    noise = torch.randn(len(real_imgs), z_dim, 1, 1, device=device)
    fake_imgs = generator(noise, label_indices)

    # ******* Train the discriminator

    # Make the discriminator trainable
    discriminator.requires_grad_(True)

    discriminator_optimiser.zero_grad()

    # Set the ground truth for the real images to 0.9 (label smoothing)
    # Calculate the loss of the discriminator on real images
    real_imgs_output = discriminator(real_imgs, label_indices).reshape(-1)
    real_imgs_loss = criterion(real_imgs_output, torch.full_like(real_imgs_output, 0.9))

    # Set the ground truth for the fake images to 0
    # Calculate the loss of the discriminator on fake images.
    # The fake batch is detached: the discriminator update must not backpropagate into the
    # generator, and detaching avoids both that wasted work and the need for retain_graph.
    fake_imgs_output = discriminator(fake_imgs.detach(), label_indices).reshape(-1)
    fake_imgs_loss = criterion(fake_imgs_output, torch.zeros_like(fake_imgs_output))

    # Measure the total discriminator loss
    discriminator_loss = (real_imgs_loss + fake_imgs_loss) / 2
    total_disc_loss += discriminator_loss.item()

    # Perform the backward propagation and update the model parameters
    discriminator_loss.backward()
    discriminator_optimiser.step()


    # ******* Train the generator
    generator_optimiser.zero_grad()

    # Make the discriminator untrainable, so no gradients are computed for its parameters here
    discriminator.requires_grad_(False)

    # Score the same fake batch again with the just-updated discriminator; this time the batch
    # is NOT detached, so the gradient flows back into the generator.
    # NOTE(review): the generator target is smoothed to 0.9 as well; label smoothing is usually
    # applied only to the discriminator's real labels - confirm this is intentional.
    fake_imgs_output = discriminator(fake_imgs, label_indices).reshape(-1)
    generator_loss = criterion(fake_imgs_output, torch.full_like(fake_imgs_output, 0.9))

    total_gen_loss += generator_loss.item()

    # Perform backpropagation and update parameters according to the gradient from the discriminator
    generator_loss.backward()
    generator_optimiser.step()


  generator_losses.append(total_gen_loss)
  discriminator_losses.append(total_disc_loss)


  # Every 5th epoch (starting with epoch 0), report the losses and preview the generator.
  if epoch % 5 == 0:
    print(
        f"\n************************************************************************* \
        \nEpoch {epoch} | Discriminator Loss: {total_disc_loss:.4f}, generator loss: {total_gen_loss:.4f}"
    )

    # Switch to evaluation mode and disable gradient tracking while sampling
    generator.eval()
    with torch.no_grad():

      # Create random noise and one label per class
      noise = torch.randn(num_classes, z_dim, 1, 1, device=device)
      labels = torch.arange(num_classes, device=device)

      # Generate an image for each class
      generated_imgs = generator(noise, labels).squeeze(1).cpu()

      # Show the generated images and their labels
      plot_images(generated_imgs, [classes[class_index] for class_index in labels.tolist()])

    # Put the generator back in the training mode
    generator.train()

# ***************** Save the models *****************
torch.save(generator.state_dict(), "drive/MyDrive/CI642_Coursework/generator_model.pth")
torch.save(discriminator.state_dict(), "drive/MyDrive/CI642_Coursework/discriminator_model.pth")

In [None]:
# ***************** Plot the training loss curve *****************

# The loss lists hold one value per epoch (the loss summed over that epoch's batches),
# so the x-axis is the epoch index, not the iteration count.
plt.figure(1)
plt.plot(range(len(discriminator_losses)), discriminator_losses, 'y', label="Discriminator loss")
plt.plot(range(len(generator_losses)), generator_losses, 'c', label="Generator loss")
plt.xlabel("Epoch")
plt.ylabel("Loss (summed over batches)")
plt.title("Conditional DCGAN Training Loss Curve")
plt.legend(loc="upper right")
plt.show()

In [None]:
# ***************** Save the models *****************

# Destination file -> network whose parameters (state_dict) are written there
checkpoints = {
    "drive/MyDrive/CI642_Coursework/generator_model.pth": generator,
    "drive/MyDrive/CI642_Coursework/discriminator_model.pth": discriminator,
}

# Generator first, then discriminator (dicts keep insertion order)
for checkpoint_path, network in checkpoints.items():
  torch.save(network.state_dict(), checkpoint_path)

In [None]:
# ***************** Load the generator parameters and generate some images *****************

generator = Generator(z_dim, img_channels, generator_num_filters, num_classes).to(device)

# map_location lets a checkpoint saved on a GPU runtime be loaded on a CPU-only runtime (and vice versa)
generator.load_state_dict(
    torch.load("drive/MyDrive/CI642_Coursework/generator_model.pth", map_location=device)
)

# Inference only: use evaluation mode (as the in-training preview does) and skip gradient tracking
generator.eval()

with torch.no_grad():
  # Create the noise vector and set the generated image labels to all class labels
  noise = torch.randn(num_classes, z_dim, 1, 1, device=device)
  labels = torch.arange(num_classes, device=device)

  # Generate an image for each class and move the batch to the CPU for plotting
  generated_imgs = generator(noise, labels).squeeze(1).cpu()

# Plot the generated images and their labels
plot_images(generated_imgs, [classes[class_index] for class_index in labels.tolist()])

In [None]:
# ***************** Evaluation *****************

# Number of images generated per class for the evaluation (also reported in the plot title)
num_fake_imgs = 1000

# Load the training images, grouped by class index and resized to the training resolution
training_images = {}

for img_class in os.listdir(data_folder):
  class_dir = os.path.join(data_folder, img_class)

  # Only class folders hold images; ignore any stray file in the dataset root
  if not os.path.isdir(class_dir):
    continue

  label = int(img_class.split("_")[0])
  # Get full paths of all images in this folder
  imgs_paths = [os.path.join(class_dir, img) for img in os.listdir(class_dir)]

  # Open each image in a context manager so its file handle is closed once the pixels are copied out
  resized_imgs = []
  for img_path in imgs_paths:
    with Image.open(img_path) as image:
      resized_imgs.append(np.array(image.resize((img_size, img_size))))

  training_images[label] = resized_imgs


# Calculate the CID indices
# Each entry is indexed below as [0] creativity, [1] inheritance, [2] diversity, keyed by class label.
CID_indices = calculate_CID_indices(z_dim, generator, training_images, num_fake_imgs = num_fake_imgs)
print(CID_indices)


# Plot the CID indices
# Plot the classes in ascending label order so the lines do not zig-zag with the dict's ordering
x_axis = sorted(CID_indices)
creativity_indices = [CID_indices[label][0] for label in x_axis]
inheritance_indices = [CID_indices[label][1] for label in x_axis]
diversity_indices = [CID_indices[label][2] for label in x_axis]

# Scale the diversity index by its maximum for plotting; leave it unscaled if that maximum is 0
max_diversity = max(diversity_indices, default=0)
if max_diversity != 0:
  scaled_diversity_indices = [value / max_diversity for value in diversity_indices]
else:
  scaled_diversity_indices = list(diversity_indices)

plt.figure(2)
plt.plot(x_axis, creativity_indices, 'y', label="Creativity Index")
plt.plot(x_axis, inheritance_indices, 'c', label="Inheritance Index")
plt.plot(x_axis, scaled_diversity_indices, 'm', label="Diversity Index")
plt.xticks(ticks = x_axis, labels = [classes[label] for label in x_axis])
plt.xlabel("Image Class")
plt.title(f"CID Indices Measured With {num_fake_imgs} Generated Images per Class")
plt.legend(loc="center left", bbox_to_anchor=(1, 0.5))
plt.show()