**Main imports**

In [None]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import matplotlib.pyplot as plt
from IPython import display as disp
from torch.nn.utils import spectral_norm    
import os

# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device.type)

**Import dataset**

In [None]:
# helper that turns a re-iterable (e.g. a DataLoader) into an endless stream of batches
def cycle(iterable):
    """Yield the items of ``iterable`` over and over, restarting it each time it is exhausted.

    ``iterable`` is iterated afresh on every pass, so it must support repeated
    iteration (a list or a DataLoader does; a one-shot generator does not).
    The returned generator never terminates on its own.
    """
    while True:
        yield from iterable

# Human-readable class names (100 entries). In the cells shown here only the
# length of this list is used (printed as the number of classes).
# NOTE(review): presumably ordered to match the CIFAR-100 label indices — confirm before indexing by label.
class_names = ['apple','aquarium_fish','baby','bear','beaver','bed','bee','beetle','bicycle','bottle','bowl','boy','bridge','bus','butterfly','camel','can','castle','caterpillar','cattle','chair','chimpanzee','clock','cloud','cockroach','couch','crab','crocodile','cup','dinosaur','dolphin','elephant','flatfish','forest','fox','girl','hamster','house','kangaroo','computer_keyboard','lamp','lawn_mower','leopard','lion','lizard','lobster','man','maple_tree','motorcycle','mountain','mouse','mushroom','oak_tree','orange','orchid','otter','palm_tree','pear','pickup_truck','pine_tree','plain','plate','poppy','porcupine','possum','rabbit','raccoon','ray','road','rocket','rose','sea','seal','shark','shrew','skunk','skyscraper','snail','snake','spider','squirrel','streetcar','sunflower','sweet_pepper','table','tank','telephone','television','tiger','tractor','train','trout','tulip','turtle','wardrobe','whale','willow_tree','wolf','woman','worm',]

# Both loaders convert images to tensors and then apply Normalize(mean=0.5, std=0.5)
# per channel, i.e. pixel values end up in [-1, 1] (the same range as the
# generator's Tanh output). Anything that displays or saves these tensors must
# map them back to [0, 1] first.
#
# The training loader shuffles so that each epoch sees the data in a different
# order; without it every epoch replays the exact same sequence of batches.
# drop_last=True keeps every batch at exactly 64 samples.
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=True, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])),
    shuffle=True, batch_size=64, drop_last=True)

# The test loader is left unshuffled so its order is deterministic.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.CIFAR100('data', train=False, download=True, transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize([32,32]),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])),
    batch_size=64, drop_last=True)

# Endless batch streams: next(train_iterator) always returns another batch.
train_iterator = iter(cycle(train_loader))
test_iterator = iter(cycle(test_loader))

print(f'> Size of training dataset {len(train_loader.dataset)}')
print(f'> Size of test dataset {len(test_loader.dataset)}')
print("Number of classes: ", len(class_names))

**View some of the training dataset**

In [None]:
# let's view some of the training data
plt.rcParams['figure.dpi'] = 100
x, t = next(train_iterator)

# The loader normalises pixels to [-1, 1]; imshow expects floats in [0, 1] and
# clips everything else, so undo the normalisation before plotting.
# The batch is only displayed here, so there is no need to move it to the GPU.
grid = torchvision.utils.make_grid((x + 1) / 2)

# Plot the images (make_grid returns C x H x W; imshow wants H x W x C)
plt.imshow(grid.cpu().numpy().transpose(1, 2, 0))
plt.show()

### Generator, Discriminator, and Training Loop

In [62]:
# hyperparameters
batch_size = train_loader.batch_size

# Length of the noise vector fed to the generator. It must equal the
# `latent_size` the Generator is built with (default 100); the FID cell reads
# this name when sampling, so it has to be defined rather than commented out.
latent_dim = 100

# Number of label classes (width of the one-hot conditioning vector).
num_classes = 100

# Show a grid of generated samples every `check_interval` epochs.
check_interval = 5

# Total number of passes over the training set.
num_of_epochs = 50000

Repo: https://github.com/atapour/dl-pytorch/blob/main/Conditional_GAN_Example/Conditional_GAN_Example.ipynb

In [None]:
# Image geometry shared by the generator and the discriminator:
n_channels = 3
img_width = 32

# define the generator
class Generator(nn.Module):
    """Fully-connected conditional generator.

    The noise vector and the label vector are concatenated and pushed through
    an MLP (200 -> 128 -> 256 -> 512 -> 1024 -> n_channels*img_width*img_width
    with the default sizes). Every hidden layer uses LeakyReLU(0.2); all hidden
    layers except the first also use BatchNorm1d. The final Tanh bounds the
    output to [-1, 1].

    Args:
        latent_size: length of the noise vector per sample.
        label_size: length of the (one-hot) label vector per sample.
    """

    def __init__(self, latent_size=100, label_size=100):
        super().__init__()
        hidden_widths = (128, 256, 512, 1024)

        modules = []
        width_in = latent_size + label_size
        for depth, width_out in enumerate(hidden_widths):
            modules.append(nn.Linear(width_in, width_out))
            if depth > 0:  # the first hidden layer has no batch norm
                modules.append(nn.BatchNorm1d(width_out))
            modules.append(nn.LeakyReLU(0.2))
            width_in = width_out

        # output layer: one value per pixel/channel, squashed to [-1, 1]
        modules.append(nn.Linear(width_in, n_channels*img_width*img_width))
        modules.append(nn.Tanh())

        self.layer = nn.Sequential(*modules)

    def forward(self, x, c):
        """Generate images from noise ``x`` conditioned on labels ``c``.

        Both inputs are flattened to (batch, -1); ``c`` is cast to float.
        Returns a tensor of shape (batch, n_channels, img_width, img_width).
        """
        noise = x.view(x.size(0), -1)
        labels = c.view(c.size(0), -1).float()
        joined = torch.cat((noise, labels), 1)  # [input, label] concatenated
        flat = self.layer(joined)
        return flat.view(flat.size(0), n_channels, img_width, img_width)

# define the discriminator
class Discriminator(nn.Module):
    """Fully-connected conditional discriminator.

    The flattened image and the label vector are concatenated and pushed
    through an MLP (512 -> 256 -> 128 -> 64 -> 32 -> 1) with LeakyReLU(0.2)
    between the hidden layers. The last Linear layer produces a raw logit that
    goes straight into the Sigmoid: no activation is placed in between, because
    a LeakyReLU there would scale every negative logit by 0.2 and make it
    needlessly hard for the network to output probabilities close to 0.

    Args:
        label_size: length of the (one-hot) label vector per sample.
        in_features: number of values in one flattened image. Defaults to
            n_channels*img_width*img_width (the module-level image geometry).
    """

    def __init__(self, label_size=100, in_features=None):
        super(Discriminator, self).__init__()
        if in_features is None:
            in_features = n_channels*img_width*img_width
        self.layer = nn.Sequential(
            nn.Linear(in_features+label_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 128),
            nn.LeakyReLU(0.2),
            nn.Linear(128,64),
            nn.LeakyReLU(0.2),
            nn.Linear(64, 32),
            nn.LeakyReLU(0.2),
            nn.Linear(32, 1),  # raw logit
            nn.Sigmoid(),      # probability that the input is real
        )

    def forward(self, x, c):
        """Score images ``x`` given labels ``c``.

        Both inputs are flattened to (batch, -1); ``c`` is cast to float.
        Returns a (batch, 1) tensor of values in (0, 1).
        """
        x, c = x.view(x.size(0), -1), c.view(c.size(0), -1).float()
        x = torch.cat((x, c), 1) # [input, label] concatenated
        return self.layer(x)
        
# Instantiate both networks on the selected device.
G = Generator().to(device)
D = Discriminator().to(device)

# Report the total number of scalar parameters in each network.
print(f'Generator has {sum(p.numel() for p in G.parameters())} parameters.')
print(f'Discriminator has {sum(p.numel() for p in D.parameters())} parameters')

# initialise the optimisers: Adam for both, with the discriminator's learning
# rate set to twice the generator's
adam_betas = (0.5, 0.999)
optimiser_G = torch.optim.Adam(G.parameters(), lr=2e-4, betas=adam_betas)
optimiser_D = torch.optim.Adam(D.parameters(), lr=4e-4, betas=adam_betas)
print('Optimisers have been created!')

# Binary cross entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()
print('Loss function is Binary Cross Entropy!')

# training loop
#
# Each iteration performs one discriminator update followed by one generator
# update on the same batch of real images and labels.
latent_size = 100  # length of the noise vector; must match Generator(latent_size=...)

for epoch in range(num_of_epochs):

    # per-batch losses of this epoch (plain lists: appending is O(1), whereas
    # np.append copies the whole array on every call)
    gen_loss_arr = []
    dis_loss_arr = []

    # iterate over the train dataset
    for x, t in train_loader:
        x, t = x.to(device), t.to(device)
        n = x.size(0)  # actual batch size; never assume a fixed 64

        # convert target labels "t" to a one-hot vector, e.g. 3 becomes [0,0,0,1,0,0,0,...]
        y = F.one_hot(t, num_classes)

        # BCE targets sized to the current batch
        real_target = torch.ones(n, 1, device=device)
        fake_target = torch.zeros(n, 1, device=device)

        # train discriminator
        z = torch.randn(n, latent_size, device=device)
        # detach: the discriminator step must not back-propagate into the generator
        fake = G(z, y).detach()
        l_r = criterion(D(x, y), real_target)     # real -> 1
        l_f = criterion(D(fake, y), fake_target)  # fake -> 0
        loss_d = (l_r + l_f)/2.0
        optimiser_D.zero_grad()
        loss_d.backward()
        optimiser_D.step()

        # train generator
        z = torch.randn(n, latent_size, device=device)
        loss_g = criterion(D(G(z, y), y), real_target)  # fake -> 1
        optimiser_G.zero_grad()
        loss_g.backward()
        optimiser_G.step()

        gen_loss_arr.append(loss_g.item())
        dis_loss_arr.append(loss_d.item())

    # one summary line per epoch (mean loss over the epoch's batches) instead of
    # one line per batch, and with the real epoch budget in the denominator
    print(f'Epoch [{epoch+1}/{num_of_epochs}], Discriminator Loss: {np.mean(dis_loss_arr):.4f}, Generator Loss: {np.mean(gen_loss_arr):.4f}')

    # periodically show a grid of samples for randomly chosen classes
    if (epoch + 1) % check_interval == 0:
        G.eval()  # use BatchNorm running statistics while sampling
        with torch.no_grad():
            z = torch.randn(batch_size, latent_size, device=device)
            fake_labels = torch.randint(0, num_classes, (batch_size,), device=device)
            y_fake = F.one_hot(fake_labels, num_classes)
            # Tanh output lies in [-1, 1]; map it to [0, 1] for display
            fake_images = (G(z, y_fake) + 1) / 2
        G.train()

        plt.figure(figsize=(10, 10))
        plt.axis('off')
        plt.title(f'Generated Images at Epoch {epoch+1}')
        plt.imshow(np.transpose(torchvision.utils.make_grid(fake_images, padding=2).cpu(), (1, 2, 0)))
        plt.show()

**Latent interpolations**

In [None]:
# now show some interpolations (note you do not have to do linear interpolations as shown here, you can do non-linear or gradient-based interpolation if you wish)
#
# Layout of the grid: each column interpolates between two latent vectors
# (top row -> bottom row) while its class label stays fixed.
col_size = int(np.sqrt(batch_size))
latent_size = 100  # length of the noise vector; must match Generator(latent_size=...)

# sample the end points here instead of relying on whatever `z` a previous cell left behind
z_ends = torch.randn(2 * col_size, latent_size, device=device)
z0 = z_ends[:col_size].repeat(col_size, 1) # z for top row
z1 = z_ends[col_size:].repeat(col_size, 1) # z for bottom row

# interpolation weight: constant along a row, going 0 -> 1 from top to bottom
alpha = torch.linspace(0, 1, col_size, device=device).repeat_interleave(col_size).unsqueeze(1)

# one random class per column, repeated for every row
col_labels = torch.randint(0, num_classes, (col_size,), device=device)
lerp_y = F.one_hot(col_labels, num_classes).repeat(col_size, 1)

lerp_z = (1-alpha)*z0 + alpha*z1 # linearly interpolate between two points in the latent space

# sample the conditional generator at the interpolated latents
was_training = G.training
G.eval()  # use BatchNorm running statistics while sampling
with torch.no_grad():
    lerp_g = G(lerp_z, lerp_y)
G.train(was_training)  # restore the previous mode

plt.rcParams['figure.dpi'] = 100
plt.grid(False)
# Tanh output lies in [-1, 1]; map it to [0, 1] for display
plt.imshow(torchvision.utils.make_grid((lerp_g + 1) / 2, nrow=col_size).cpu().numpy().transpose(1, 2, 0))
plt.show()

**FID scores**

Evaluate the FID from 10k of your model samples (do not sample more than this) and compare it against the 10k test images. Calculating FID is somewhat involved, so we use a library for it. It can take a few minutes to evaluate. Lower FID scores are better.

In [7]:
%%capture
!pip install clean-fid
import os
from cleanfid import fid
from torchvision.utils import save_image

In [9]:
# define directories
real_images_dir = 'real_images'
generated_images_dir = 'generated_images'
num_samples = 10000 # do not change

# create/clean the directories
def setup_directory(directory):
    """Make ``directory`` exist and be empty.

    Any existing directory tree (or plain file) at that path is deleted first,
    so old images never leak into a new evaluation. Pure Python is used rather
    than a shell ``rm -r`` so that it works on every platform and with paths
    containing spaces or shell metacharacters.
    """
    import shutil  # local import: only this helper needs it

    if os.path.isdir(directory):
        shutil.rmtree(directory) # remove any existing (old) data
    elif os.path.exists(directory):
        os.remove(directory) # a stray file is occupying the path
    os.makedirs(directory)

# start from empty output directories so that stale images from an earlier run
# are never mixed into the evaluation (and so the directories exist at all)
setup_directory(real_images_dir)
setup_directory(generated_images_dir)

latent_size = 100  # length of the noise vector; must match Generator(latent_size=...)

# generate and save 10k model samples
was_training = G.training
G.eval()  # use BatchNorm running statistics while sampling
num_generated = 0
with torch.no_grad():
    while num_generated < num_samples:

        # sample from the conditional generator with uniformly random class labels
        z = torch.randn(batch_size, latent_size, device=device)
        labels = torch.randint(0, num_classes, (batch_size,), device=device)
        samples_batch = G(z, F.one_hot(labels, num_classes))

        # Tanh output lies in [-1, 1]; save_image expects [0, 1] and clamps anything outside
        samples_batch = ((samples_batch + 1) / 2).clamp(0, 1).cpu()

        for image in samples_batch:
            if num_generated >= num_samples:
                break
            save_image(image, os.path.join(generated_images_dir, f"gen_img_{num_generated}.png"))
            num_generated += 1
G.train(was_training)  # restore the previous mode

# save 10k images from the CIFAR-100 test dataset
# Index the dataset directly: the cycling, drop_last test iterator yields only
# full batches, so it would skip the tail of the test set and wrap around to
# save some images twice.
test_dataset = test_loader.dataset
num_saved_real = 0
for index in range(min(num_samples, len(test_dataset))):
    image, _ = test_dataset[index]
    # undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1]
    save_image((image + 1) / 2, os.path.join(real_images_dir, f"real_img_{num_saved_real}.png"))
    num_saved_real += 1

In [None]:
# compute FID between the saved real images and the saved generated samples
score = fid.compute_fid(fdir1=real_images_dir, fdir2=generated_images_dir, mode="clean")
print(f"FID score: {score}")