In [1]:
import os
import torch
import torchvision
import dataloaders
import random
import torch.nn as nn
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm
from PIL import Image
import glob
from torchsummary import summary

KeyboardInterrupt: 

In [None]:
# --- Reproducibility ---------------------------------------------------------
# Python's `random` and PyTorch share the same seed value.
seed = 0
torch.manual_seed(seed)
random.seed(seed)

# --- Data locations ----------------------------------------------------------
TABLE_PATH = "data/tmdb-movies-220921-clean.pkl"
DATA_PATH = "data/tmdb-64"

# --- Batching ----------------------------------------------------------------
workers = torch.get_num_threads()  # passed to the DataLoader as num_workers
batch_size = 128
number_of_samples = 64  # size of the preview slices and of the fixed noise batch

# --- Image / latent dimensions -----------------------------------------------
image_size = 64  # NOTE(review): not referenced by the visible cells
number_of_channels = 3
noise_vector_size = 100
label_vector_size = 1  # > 1 selects the stacked multi-label branch further down

# --- Network widths ----------------------------------------------------------
generator_feature_maps = 128
discriminator_feature_maps = 64
ngpu = 1  # handed to the Discriminator constructor

# --- Optimisation ------------------------------------------------------------
num_epochs = 15
learning_rate = 2e-4
beta1, beta2 = 0.5, 0.999  # Adam betas

In [None]:
# Each run writes to its own folder: <cwd>/cdcgan-output/<YYYY-MM-DD_HH-MM-SS>.
run_timestamp = datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
output_directory = os.path.join(os.getcwd(), 'cdcgan-output', run_timestamp)
os.makedirs(output_directory)

In [None]:
# Preprocessing: ToTensor followed by a per-channel Normalize with mean 0.5 / std 0.5.
poster_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

dataset = dataloaders.PosterDataset(
    table_path=TABLE_PATH,
    img_root_path=DATA_PATH,
    img_transform=poster_transform,
    img_in_ram=False,
    genre=None,
    genre_logic='and',
    og_lang=None,
    year=None,
    runtime=(40, np.inf),
    max_num=None,
    sort='popularity',
)

# NOTE(review): shuffle=False means every epoch visits the batches in dataset
# order — confirm this is intended for training.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=workers,
    pin_memory=True,
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# A single batch is pulled up front; it feeds both the preview plot and the
# fixed labels used for the per-epoch sample images.
real_batch = next(iter(dataloader))

# Batch field 23 is stacked along dim 1 in the multi-label case, field 8 is
# used directly in the single-label case (field meanings are defined by
# PosterDataset — not visible here).
if label_vector_size > 1:
    sample_labels = torch.stack(real_batch[23], dim=1).type(torch.FloatTensor).to(device)[:number_of_samples]
    label_view = sample_labels[:, :, None, None]
else:
    sample_labels = real_batch[8].type(torch.FloatTensor).to(device)[:number_of_samples]
    label_view = sample_labels[:, None, None, None]
# Broadcast every label over the generator's 3x1 input grid.
sample_labels_generator = label_view.expand(number_of_samples, label_vector_size, 3, 1)


# Preview of the first `number_of_samples` images of that batch.
preview_grid = vutils.make_grid(real_batch[0].to(device)[:number_of_samples], padding=2, normalize=True).cpu()
plt.figure(figsize=(8, 8))
plt.axis("off")
plt.title("Training Images")
plt.imshow(np.transpose(preview_grid, (1, 2, 0)))
print(f'The dataset has {len(dataset)} entries.')

In [None]:
def init_weights(m):
    """Re-initialise a single module in place (used via ``Module.apply``).

    The decision is made on the module's class name:

    * name contains ``'Conv'``      -> weight ~ N(0.0, 0.02)
    * name contains ``'BatchNorm'`` -> weight ~ N(1.0, 0.02), bias = 0
    * anything else                 -> left untouched
    """
    layer_kind = type(m).__name__
    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        return
    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

In [None]:
class Generator(nn.Module):
    """Conditional generator built from transposed convolutions.

    The noise and the label tensor each pass through their own 4x4, stride-1
    transposed convolution (+ BatchNorm + ReLU). The two results are
    concatenated on the channel axis and pushed through four stride-2
    transposed convolutions, each doubling the spatial size. The last one maps
    to ``number_of_channels`` and is followed by ``tanh``.

    With the 3x1 inputs used in this notebook the output is 96x64.
    """

    def __init__(self):
        super().__init__()
        width = generator_feature_maps

        # Separate entry branches for noise and label (H, W) -> (H + 3, W + 3).
        self.deconv1_image = nn.ConvTranspose2d(noise_vector_size, width * 4, kernel_size=4, stride=1, bias=False)
        self.deconv1_image_bn = nn.BatchNorm2d(width * 4)
        self.deconv1_label = nn.ConvTranspose2d(label_vector_size, width * 4, kernel_size=4, stride=1, bias=False)
        self.deconv1_label_bn = nn.BatchNorm2d(width * 4)

        # Takes the concatenation of both branches, hence width * 8 input channels.
        self.deconv2 = nn.ConvTranspose2d(width * 8, width * 4, kernel_size=4, stride=2, padding=1, bias=False)
        self.deconv2_bn = nn.BatchNorm2d(width * 4)

        self.deconv3 = nn.ConvTranspose2d(width * 4, width * 2, kernel_size=4, stride=2, padding=1, bias=False)
        self.deconv3_bn = nn.BatchNorm2d(width * 2)

        self.deconv4 = nn.ConvTranspose2d(width * 2, width, kernel_size=4, stride=2, padding=1, bias=False)
        self.deconv4_bn = nn.BatchNorm2d(width)

        # Output layer: no BatchNorm, tanh is applied in forward().
        self.deconv5 = nn.ConvTranspose2d(width, number_of_channels, kernel_size=4, stride=2, padding=1, bias=False)

    def forward(self, input, label):
        """Generate images.

        Args:
            input: noise tensor with ``noise_vector_size`` channels.
            label: label tensor with ``label_vector_size`` channels and the
                same batch and spatial size as ``input``.

        Returns:
            Tensor with ``number_of_channels`` channels and values in [-1, 1].
        """
        noise_features = F.relu(self.deconv1_image_bn(self.deconv1_image(input)))
        label_features = F.relu(self.deconv1_label_bn(self.deconv1_label(label)))
        hidden = torch.cat([noise_features, label_features], 1)

        upsampling_stages = (
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
            (self.deconv4, self.deconv4_bn),
        )
        for deconv, batch_norm in upsampling_stages:
            hidden = F.relu(batch_norm(deconv(hidden)))

        return torch.tanh(self.deconv5(hidden))

In [None]:
# Build the generator on the selected device and re-initialise its layers.
# `Module.apply` returns the module itself, so the calls can be chained.
generator_network = Generator().to(device).apply(init_weights)
print(generator_network)

In [None]:
# Per-sample input shapes (batch dimension excluded). The spatial layout is
# (3, 1), the same one `samples_noise`, the training noise and the generator
# label maps use, so the reported layer shapes match what training produces.
summary(generator_network, input_size=[(noise_vector_size, 3, 1), (label_vector_size, 3, 1)])

In [None]:
class Discriminator(nn.Module):
    """Conditional discriminator: scores (image, label map) pairs.

    Image and label map each go through their own stride-2 convolution. The
    results are concatenated on the channel axis and reduced by three more
    stride-2 convolutions (each followed by BatchNorm), then a 4x4 and a 3x1
    convolution. The output passes through a sigmoid.

    For 96x64 inputs the result is one value per sample (shape N x 1 x 1 x 1).

    Changes compared with the previous revision:
      * ``conv4_bn`` was registered but never applied; ``forward`` now uses it,
        in line with ``conv2``/``conv3``.
      * every LeakyReLU uses the same negative slope (0.2); the last two
        activations previously fell back to the library default.
      * the image branch takes ``number_of_channels`` input channels instead
        of a literal 3, matching the generator's output layer.
    Parameter names are unchanged, so existing state dicts still load.
    """

    # Negative slope shared by all LeakyReLU activations.
    LEAKY_SLOPE = 0.2

    def __init__(self, ngpu):
        """Create the layers.

        Args:
            ngpu: accepted for call compatibility; not used by this class.
        """
        super(Discriminator, self).__init__()
        self.conv1_1 = nn.Conv2d(number_of_channels, discriminator_feature_maps, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.conv1_2 = nn.Conv2d(label_vector_size, discriminator_feature_maps, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)

        # Input is the channel-wise concatenation of both branches.
        self.conv2 = nn.Conv2d(discriminator_feature_maps * 2, discriminator_feature_maps * 4, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.conv2_bn = nn.BatchNorm2d(discriminator_feature_maps * 4)

        self.conv3 = nn.Conv2d(discriminator_feature_maps * 4, discriminator_feature_maps * 8, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.conv3_bn = nn.BatchNorm2d(discriminator_feature_maps * 8)

        self.conv4 = nn.Conv2d(discriminator_feature_maps * 8, discriminator_feature_maps * 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
        self.conv4_bn = nn.BatchNorm2d(discriminator_feature_maps * 16)
        # Unpadded 4x4 convolution: shrinks each spatial dimension by 3.
        self.conv5 = nn.Conv2d(discriminator_feature_maps * 16, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
        # Collapses the remaining 3 rows into a single score.
        self.conv6 = nn.Conv2d(1, 1, kernel_size=(3, 1), stride=(1, 1), bias=False)

    def forward(self, input, label):
        """Score a batch.

        Args:
            input: image tensor with ``number_of_channels`` channels.
            label: label map with ``label_vector_size`` channels and the same
                batch and spatial size as ``input``.

        Returns:
            Sigmoid output of the final convolution (values in (0, 1)).
        """
        slope = self.LEAKY_SLOPE
        x = F.leaky_relu(self.conv1_1(input), slope)
        y = F.leaky_relu(self.conv1_2(label), slope)
        x = torch.cat([x, y], 1)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), slope)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), slope)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), slope)
        x = F.leaky_relu(self.conv5(x), slope)
        return torch.sigmoid(self.conv6(x))

In [None]:
# Build the discriminator on the selected device and re-initialise its layers.
# `Module.apply` returns the module itself, so the calls can be chained.
discriminator_network = Discriminator(ngpu).to(device).apply(init_weights)
print(discriminator_network)

In [None]:
# Per-sample input shapes (batch dimension excluded): a 3-channel 96x64 image
# and a label map of the same spatial size.
summary(
    discriminator_network,
    input_size=[
        (3, 96, 64),
        (label_vector_size, 96, 64),
    ],
)

In [None]:
def generate_and_save_images(generator_network, epoch, noise, img_list, sample_labels):
    """Render one grid of generated samples, store it, save it and show it.

    Args:
        generator_network: generator called as ``generator_network(noise, sample_labels)``.
        epoch: epoch index; used (zero-padded to 4 digits) in the file name.
        noise: noise batch passed to the generator.
        img_list: list that receives the image grid (appended in place).
        sample_labels: label batch passed to the generator.

    Side effects:
        Writes ``image_at_epoch_<epoch>.png`` into ``output_directory`` and
        displays the figure.
    """
    # No gradients are needed for sampling; the result is moved to the CPU
    # for plotting.
    with torch.no_grad():
        images = generator_network(noise, sample_labels).cpu()

    grid = vutils.make_grid(images, padding=2, normalize=True)
    img_list.append(grid)

    fig = plt.figure(figsize=(8, 8))
    plt.axis("off")
    plt.imshow(np.transpose(grid, (1, 2, 0)))
    # Save before showing so the file never depends on what the backend does
    # with the figure during show().
    fig.savefig(os.path.join(output_directory, 'image_at_epoch_{:04d}.png'.format(epoch)))
    plt.show()
    # This runs once per epoch; close the figure so they do not accumulate.
    plt.close(fig)

def save_checkpoint(generator_network, optimizer_generator, discriminator_network, optimizer_discriminator, epoch):
    """Write a checkpoint for the given epoch into ``output_directory``.

    The file ``gan_at_epoch_<epoch>.pt`` (epoch zero-padded to 4 digits)
    holds the epoch number plus the ``state_dict`` of both networks and both
    optimizers.
    """
    checkpoint = {
        'epoch': epoch,
        'generator_model_state_dict': generator_network.state_dict(),
        'generator_optimizer_state_dict': optimizer_generator.state_dict(),
        'discriminator_model_state_dict': discriminator_network.state_dict(),
        'discriminator_optimizer_state_dict': optimizer_discriminator.state_dict(),
    }
    checkpoint_path = output_directory + '/gan_at_epoch_{:04d}.pt'.format(epoch)
    torch.save(checkpoint, checkpoint_path)

In [None]:
# Binary cross-entropy, applied to the discriminator's sigmoid output.
loss_function = nn.BCELoss()

# Fixed noise batch: reused after every epoch so the sample grids are comparable.
samples_noise = torch.randn(number_of_samples, noise_vector_size, 3, 1, device=device)

# Both networks use Adam with identical hyper-parameters.
adam_settings = dict(lr=learning_rate, betas=(beta1, beta2))
optimizer_discriminator = optim.Adam(discriminator_network.parameters(), **adam_settings)
optimizer_generator = optim.Adam(generator_network.parameters(), **adam_settings)

# Target values for the BCE loss.
real_label, fake_label = 1.0, 0.0

In [None]:
# Training loop.
#
# For every batch the discriminator is updated `discriminator_steps` times
# (real batch + freshly generated fakes), then the generator is updated once
# using the fakes of the last discriminator step. After each epoch a sample
# grid and a checkpoint are written.
img_list = []
generator_losses = []
discriminator_losses = []
current_iteration = 0
discriminator_steps = 2

for epoch in range(num_epochs):
    for data in (pbar := tqdm(dataloader)):

        # The batch tensors do not change between the inner steps, so they are
        # moved to the device and expanded once per batch.
        real = data[0].to(device)
        current_batch_size = real.size(0)
        if label_vector_size > 1:
            input_labels = torch.stack(data[23], dim=1).type(torch.FloatTensor).to(device)
            label_view = input_labels[:, :, None, None]
        else:
            # Single-label case: read batch field 8, the field `sample_labels`
            # is built from. Field 23 is a sequence (it is passed to
            # torch.stack above) and has no `.type` method.
            input_labels = data[8].type(torch.FloatTensor).to(device)
            label_view = input_labels[:, None, None, None]
        # Generator labels cover its 3x1 input grid; discriminator labels cover
        # the full image (the two must agree spatially for the channel concat).
        input_labels_generator = label_view.expand(current_batch_size, label_vector_size, 3, 1)
        input_labels_discriminator = label_view.expand(current_batch_size, label_vector_size, real.size(2), real.size(3))

        # Train Discriminator
        for _ in range(discriminator_steps):
            discriminator_network.zero_grad()

            # Real batch, target = real_label.
            label = torch.full((current_batch_size,), real_label, dtype=torch.float, device=device)
            output = discriminator_network(real, input_labels_discriminator).view(-1)
            discriminator_error_on_real = loss_function(output, label)
            discriminator_error_on_real.backward()
            D_x = output.mean().item()

            # Fake batch, target = fake_label. `detach()` keeps this backward
            # pass out of the generator.
            noise = torch.randn(current_batch_size, noise_vector_size, 3, 1, device=device)
            fake = generator_network(noise, input_labels_generator)
            label.fill_(fake_label)
            output = discriminator_network(fake.detach(), input_labels_discriminator).view(-1)
            discriminator_error_on_fake = loss_function(output, label)
            discriminator_error_on_fake.backward()
            D_G_z1 = output.mean().item()

            discriminator_error = discriminator_error_on_real + discriminator_error_on_fake
            optimizer_discriminator.step()

        # Train Generator: the fakes should be classified as real.
        generator_network.zero_grad()
        label.fill_(real_label)
        output = discriminator_network(fake, input_labels_discriminator).view(-1)
        generator_error = loss_function(output, label)
        generator_error.backward()
        D_G_z2 = output.mean().item()
        optimizer_generator.step()

        pbar.set_description('[%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f' % (epoch, num_epochs-1, discriminator_error.item(), generator_error.item(), D_x, D_G_z1, D_G_z2))
        generator_losses.append(generator_error.item())
        discriminator_losses.append(discriminator_error.item())

        current_iteration += 1

    generate_and_save_images(generator_network, epoch, samples_noise, img_list, sample_labels_generator)
    save_checkpoint(generator_network, optimizer_generator, discriminator_network, optimizer_discriminator, epoch)

In [None]:
# Per-iteration losses of both networks, shown and saved as training_loss.png.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for loss_history, series_name in ((generator_losses, "Generator"), (discriminator_losses, "Discriminator")):
    ax.plot(loss_history, label=series_name)
ax.set_xlabel("Total Batch Iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()
fig.savefig(f"{output_directory}/training_loss.png")

In [None]:
# Assemble the per-epoch sample grids into an animated GIF.
# glob returns paths in arbitrary order; the epoch number in the file names is
# zero-padded, so a plain sort puts the frames in epoch order.
frame_paths = sorted(glob.glob(f"{output_directory}/image*.png"))
sample_images = [Image.open(path) for path in frame_paths]

if sample_images:
    thumbnail = sample_images[0]
    # The first frame is the image being saved; only the remaining frames are
    # appended, otherwise frame 0 would appear twice.
    thumbnail.save(f"{output_directory}/dcgan.gif", format="GIF", append_images=sample_images[1:], save_all=True, duration=1000, loop=0)
else:
    print(f"No sample images found in {output_directory}; skipping GIF creation.")

In [None]:
# Side-by-side comparison: the preview batch loaded at the top of the notebook
# (`real_batch`) against the grid generated after the last epoch.
# The previous version also fetched a fresh batch into `real` without ever
# reading it; that fetch only cost a dataloader pass and overwrote the
# training loop's `real` tensor, so it has been removed.
fig = plt.figure(figsize=(15, 15))

plt.subplot(1, 2, 1)
plt.axis("off")
plt.title("Real Images")
real_grid = vutils.make_grid(real_batch[0].to(device)[:number_of_samples], padding=5, normalize=True).cpu()
plt.imshow(np.transpose(real_grid, (1, 2, 0)))

plt.subplot(1, 2, 2)
plt.axis("off")
plt.title("Fake Images")
plt.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

fig.savefig(f"{output_directory}/real_vs_fake.png")