# GAN training
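We have embeddings of over 1000 summaries of popular movies, plus a few hundred completely unrelated images. We'll train a GAN to generate images from movie embeddings, without any constraint on the relationship between a movie and its generated image. During training, the generator is fed random Gaussian vectors whose standard deviation matches that of the movie embeddings; the actual movie embeddings are only used at the end, to generate one image per movie.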

In [None]:
import pandas as pd
import numpy as np
import torch
import torchvision
%matplotlib inline
import matplotlib.pyplot as plt
from tqdm import tqdm_notebook as tqdm
from skimage.io import imsave

In [None]:
# Base directory; the data/ and generated/ paths used below are relative to it.
root_dir = "./"

In [None]:
movies = pd.read_csv(root_dir + 'data/movies/processed.csv')
# The 'vector' column stores each embedding as space-separated numbers;
# parse every row and stack them into one array (one row per movie).
movie_vectors = np.array(
    movies['vector'].map(lambda text: np.fromstring(text, sep=' ')).tolist()
)
# Single scalar std over all embedding entries.
movie_vectors_std = np.std(movie_vectors)
print(f'movie_vectors: shape {movie_vectors.shape}, std {movie_vectors_std}')

In [None]:
class PenabrancaDataset(torch.utils.data.Dataset):
    """Augmenting dataset over a stack of single-channel images stored in a .npy file.

    Every access returns a randomly flipped / transposed version of the stored
    image with additive noise sampled from ``noise_distribution``. The stored
    images themselves are never modified.

    Args:
        data_path: path to a ``.npy`` file; items are transposed as 2-D tensors,
            so the array is expected to be (N, H, W). NOTE(review): the random
            transpose changes the shape of non-square images, so presumably
            H == W — confirm against the preprocessing.
        noise_distribution: a ``torch.distributions`` object whose ``sample``
            is called with the item's shape.
    """

    def __init__(self, data_path, noise_distribution=torch.distributions.Normal(0, .02)):
        # Loaded entirely into memory and cast to float32.
        self.images = torch.tensor(np.load(data_path)).float()
        self.noise_distribution = noise_distribution

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        result = self.images[idx]
        if np.random.rand() < .5:
            result = result.flip(-2)
        if np.random.rand() < .5:
            result = result.flip(-1)
        if np.random.rand() < .5:
            result = result.permute(1, 0)
        # Out-of-place add on purpose: integer indexing and ``permute`` return
        # views of ``self.images``, so ``+=`` would permanently accumulate noise
        # into the stored dataset on every access.
        return result + self.noise_distribution.sample(result.shape)

    @property
    def image_size(self):
        """Side length of the stored images (size of the last axis)."""
        return self.images.shape[-1]

In [None]:
# Augmented image dataset backed by the preprocessed .npy array.
image_dataset = PenabrancaDataset(
    data_path=root_dir + 'data/penabranca/processed.npy',
)

In [None]:
# Show one (randomly augmented) dataset item as a sanity check.
plt.imshow(image_dataset[0], cmap='gray')
plt.title('test image')
plt.show()

In [None]:
# Shuffled mini-batches of 16 augmented images.
image_loader = torch.utils.data.DataLoader(
    dataset=image_dataset,
    batch_size=16,
    shuffle=True,
)

The generator and discriminator below are based on [this DCGAN implementation](https://github.com/eriklindernoren/PyTorch-GAN/blob/master/implementations/dcgan/dcgan.py).

In [None]:
class Generator(torch.nn.Module):
    """DCGAN-style generator mapping a vector to a single-channel square image.

    A linear layer projects the input vector to a 128-channel feature map of
    side ``image_size // 4``. Two upsample(x2) + conv blocks bring it up to
    ``image_size``, and a final conv + sigmoid yields one channel in [0, 1].

    Args:
        image_size: side length of the generated images. Defaults to the
            notebook's ``image_dataset.image_size``. Must be a positive
            multiple of 4.
        vector_size: length of the input vectors. Defaults to the notebook's
            ``movie_vectors.shape[-1]``.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of 4 (the
            output would otherwise silently have a different size).
    """

    def __init__(self, image_size=None, vector_size=None):
        super().__init__()
        # Fall back to the notebook globals so ``Generator()`` keeps working.
        if image_size is None:
            image_size = image_dataset.image_size
        if vector_size is None:
            vector_size = movie_vectors.shape[-1]
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f'image_size must be a positive multiple of 4 '
                f'(two x2 upsampling stages), got {image_size}'
            )

        self.start_size = image_size // (2 * 2)
        # Kept inside a Sequential so state_dict keys ("dense.0.*") are unchanged.
        self.dense = torch.nn.Sequential(torch.nn.Linear(vector_size, 128 * self.start_size ** 2))

        def conv_block(in_filters, out_filters, upscaling):
            return [
                torch.nn.Upsample(scale_factor=upscaling),
                torch.nn.Conv2d(in_filters, out_filters, 3, stride=1, padding=1),
                # NOTE(review): 0.8 is BatchNorm's eps (as in the referenced
                # DCGAN code), not its momentum; kept for reproducibility.
                torch.nn.BatchNorm2d(out_filters, eps=0.8),
                torch.nn.LeakyReLU(0.2, inplace=True),
            ]

        self.conv_blocks = torch.nn.Sequential(
            torch.nn.BatchNorm2d(128),
            *conv_block(128, 128, upscaling=2),
            *conv_block(128, 64, upscaling=2),
            torch.nn.Conv2d(64, 1, 3, stride=1, padding=1),
            torch.nn.Sigmoid(),
        )

    def forward(self, vector):
        """Generate a batch of images from a (batch, vector_size) tensor.

        Returns a (batch, image_size, image_size) tensor with values in [0, 1].
        """
        # Follow the module's weights (CPU or GPU) instead of assuming the
        # model was moved to CUDA just because CUDA is available.
        vector = vector.to(next(self.parameters()).device)
        start = self.dense(vector)
        start = start.view(start.shape[0], -1, self.start_size, self.start_size)
        img = self.conv_blocks(start)
        return img.squeeze(1)  # skimage grayscale format (no channel dimension)

In [None]:
class Discriminator(torch.nn.Module):
    """DCGAN-style discriminator scoring single-channel images as real (1) or fake (0).

    Four stride-2 conv blocks downsample the image. A linear layer + sigmoid
    maps the flattened features to one probability per image.

    Args:
        image_size: side length of the input images. Defaults to the
            notebook's ``image_dataset.image_size``.
    """

    def __init__(self, image_size=None):
        super().__init__()
        # Fall back to the notebook global so ``Discriminator()`` keeps working.
        if image_size is None:
            image_size = image_dataset.image_size

        def conv_block(in_filters, out_filters, batch_norm=True):
            block = [
                torch.nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                torch.nn.LeakyReLU(0.2, inplace=True),
                torch.nn.Dropout2d(0.25),
            ]
            if batch_norm:
                # NOTE(review): 0.8 is BatchNorm's eps (as in the referenced
                # DCGAN code), not its momentum; kept for reproducibility.
                block.append(torch.nn.BatchNorm2d(out_filters, eps=0.8))
            return block

        self.conv_blocks = torch.nn.Sequential(
            *conv_block(1, 16, batch_norm=False),
            *conv_block(16, 32),
            *conv_block(32, 64),
            *conv_block(64, 128),
        )

        # A kernel-3, stride-2, padding-1 conv maps a side of n to ceil(n / 2).
        # Apply that once per conv block. This equals image_size // 2 ** 4 when
        # image_size is divisible by 16, and stays correct otherwise.
        downsampled_size = image_size
        for _ in range(4):
            downsampled_size = (downsampled_size + 1) // 2
        self.adv_layer = torch.nn.Sequential(
            torch.nn.Linear(128 * downsampled_size ** 2, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a (batch, image_size, image_size) tensor; returns shape (batch,)."""
        # Follow the module's weights (CPU or GPU) instead of assuming the
        # model was moved to CUDA just because CUDA is available.
        img = img.to(next(self.parameters()).device)
        img = img.unsqueeze(1)  # add channel dimension
        out = self.conv_blocks(img)
        out = out.view(out.shape[0], -1)
        realness = self.adv_layer(out)
        return realness.squeeze(1)

In [None]:
# Construct generator first, then discriminator. The order fixes which part of
# the RNG stream each network's default weight initialization consumes.
generator, discriminator = Generator(), Discriminator()

In [None]:
def init_weights(m):
    """DCGAN-style initialization, intended to be passed to ``Module.apply``.

    - Modules whose class name contains 'Conv': weight ~ N(0, 0.02); bias untouched.
    - Modules whose class name contains 'BatchNorm2d': weight ~ N(1, 0.02), bias = 0.
    - Every other module is left unchanged.
    """
    layer_type = type(m).__name__
    if 'Conv' in layer_type:
        torch.nn.init.normal_(m.weight.data, mean=0.0, std=0.02)
        return
    if 'BatchNorm2d' not in layer_type:
        return
    torch.nn.init.normal_(m.weight.data, mean=1.0, std=0.02)
    torch.nn.init.constant_(m.bias.data, val=0.0)

In [None]:
# Re-initialize every conv / BatchNorm2d submodule of the generator.
generator.apply(fn=init_weights)

In [None]:
# Re-initialize every conv / BatchNorm2d submodule of the discriminator.
discriminator.apply(fn=init_weights)

In [None]:
# Binary cross-entropy on the discriminator's sigmoid outputs, averaged over the batch.
criterion = torch.nn.BCELoss(reduction='mean')

In [None]:
# Move the networks (and the loss module) to the GPU when one is present.
if not torch.cuda.is_available():
    print('cuda unavailable')
else:
    print('switching to cuda')
    generator.cuda()
    discriminator.cuda()
    criterion.cuda()

In [None]:
def train(generator, discriminator, optimizer_params=None, num_epochs=64, epochs_per_preview=4):
    """Train ``generator`` and ``discriminator`` adversarially.

    For every batch of real images from the notebook's ``image_loader``:
      1. The generator receives random normal vectors (std ``movie_vectors_std``,
         length ``movie_vectors.shape[-1]``). It is updated so that the
         discriminator scores its fakes as real (target 1).
      2. The discriminator is updated to score the (detached) fakes as 0 and
         the real images as 1.
    Both networks use Adam with ``optimizer_params``; the loss is the
    notebook's ``criterion``.

    Args:
        generator: module mapping (batch, vector_size) vectors to images.
        discriminator: module mapping images to (batch,) probabilities.
        optimizer_params: keyword arguments for ``torch.optim.Adam``.
            Defaults to ``{'lr': 2e-4, 'betas': (.5, .999)}``.
        num_epochs: number of passes over ``image_loader``.
        epochs_per_preview: show the first generated image of the last batch
            every this many epochs (starting with epoch 0).

    Returns:
        (generator_losses, discriminator_losses): numpy arrays with one entry per batch.
    """
    if optimizer_params is None:
        optimizer_params = {'lr': 2e-4, 'betas': (.5, .999)}
    generator_optimizer = torch.optim.Adam(generator.parameters(), **optimizer_params)
    discriminator_optimizer = torch.optim.Adam(discriminator.parameters(), **optimizer_params)

    generator_losses = []
    discriminator_losses = []
    fakes = None  # most recent generated batch, used for previews

    for epoch in tqdm(range(num_epochs)):
        for image_batch in image_loader:
            vectors = torch.normal(mean=0, std=movie_vectors_std, size=(len(image_batch), movie_vectors.shape[-1]))
            fakes = generator(vectors)

            # Generator step: push the discriminator to call the fakes real.
            generator_optimizer.zero_grad()
            fake_realness = discriminator(fakes)
            zero_realness = torch.zeros_like(fake_realness)
            full_realness = torch.ones_like(fake_realness)
            generator_loss = criterion(fake_realness, full_realness)
            generator_losses.append(generator_loss.item())
            generator_loss.backward()
            generator_optimizer.step()

            # Discriminator step on *detached* fakes. Backpropagating into the
            # generator graph here is wrong: its weights were just updated in
            # place (autograd raises an in-place modification error), and the
            # generator must not receive discriminator gradients anyway.
            # The zero_grad also clears the grads that the generator loss
            # accumulated into the discriminator.
            discriminator_optimizer.zero_grad()
            fake_realness = discriminator(fakes.detach())
            real_realness = discriminator(image_batch)
            discriminator_loss = (criterion(fake_realness, zero_realness) + criterion(real_realness, full_realness)) / 2
            discriminator_losses.append(discriminator_loss.item())
            discriminator_loss.backward()
            discriminator_optimizer.step()

        # ``fakes`` is None only if the loader yielded no batches.
        if fakes is not None and epoch % epochs_per_preview == 0:
            plt.title(f'image generated in epoch {epoch}')
            plt.imshow(fakes[0].detach().cpu().numpy(), cmap='gray')
            plt.show()

    return np.array(generator_losses), np.array(discriminator_losses)

In [None]:
# Long run: 1024 epochs, with a preview image every 16 epochs.
generator_loss, discriminator_loss = train(
    generator=generator,
    discriminator=discriminator,
    num_epochs=1024,
    epochs_per_preview=16,
)

In [None]:
# Per-batch training losses of both networks.
plt.plot(generator_loss, label='generator loss')
plt.plot(discriminator_loss, label='discriminator loss')
plt.legend()
plt.xlabel('batch number')
plt.ylabel('loss')
plt.show()  # was `plt.show)()`, a syntax error that prevented the cell from running

## Training's done, let's save the results!

In [None]:
import os

# Inference mode: BatchNorm layers use their running statistics.
generator.eval()

# Make sure the output directory exists before writing into it.
os.makedirs(root_dir + 'generated', exist_ok=True)

# No gradients are needed to generate images; skip building autograd graphs.
with torch.no_grad():
    for idx, vector in tqdm(enumerate(movie_vectors), total=len(movie_vectors)):
        generated = generator(torch.tensor(vector).float().unsqueeze(0))[0]
        # The generator ends in a sigmoid, so values lie in [0, 1]. Convert
        # explicitly to 8-bit instead of relying on the image writer's
        # lossy, plugin-dependent float handling.
        image = (generated.cpu().numpy() * 255).round().astype(np.uint8)
        imsave(root_dir + f'generated/{idx}.png', image)