# PyTorch for GANs Workshop Part 1

This notebook implements a simple GAN training pipeline which can easily be adapted for your projects. Components of a typical GAN training pipeline include:

*   Custom Dataset Class + Dataloader
*   Generator
*   Discriminator
*   Optimizer + Scheduler
*   Loss Functions
*   Training Loop
*   Loading and Saving Checkpoints

This notebook implements DCGAN for simplicity but feel free to swap out the archtecture for newer and better models.

In [30]:
# TOTAL_CLASSES = 1103
TOTAL_CLASSES = 1103 #reducing the number of labels to 398 (only culture tags)

DATA_PATH_SMALL = "drive/MyDrive/CS236G/data-sample/"
DATA_PATH_BIG = "drive/MyDrive/CS236G/data-big/"
device = 'cuda'
z_dim = 64 #noise vector dimension

image_resize = 10
postcard_shape = (3, image_resize, image_resize) 
n_classes = TOTAL_CLASSES
batch_size =128

criterion = nn.BCEWithLogitsLoss()
n_epochs = 10
display_step = 20
batch_size = 128
lr = 0.0002
device = 'cuda'


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
import torch
from torch import nn
from tqdm.auto import tqdm
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import torch.utils.data as data
import os
torch.manual_seed(0) # Set for our testing purposes, please do not change!
from PIL import Image
import csv



def show_tensor_images(image_tensor, num_images=25, size=(1, 28, 28), nrow=5, show=True):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=nrow)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    if show:
        plt.show()

# Generator

> Indented block



In [4]:
class MyConvTranspose2d(nn.Module):
    def __init__(self, conv, output_size):
        super(MyConvTranspose2d, self).__init__()
        self.output_size = output_size
        self.conv = conv
        
    def forward(self, x):
        x = self.conv(x, output_size=self.output_size)
        return x

class Generator(nn.Module):
    '''
    Generator Class
    Values:
        input_dim: the dimension of the input vector, a scalar
        im_chan: the number of channels in the images, fitted for the dataset used, a scalar
              (MNIST is black-and-white, so 1 channel is your default)
        hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, input_dim=10, im_chan=3, hidden_dim=64): #changing im_chan
        super(Generator, self).__init__()
        self.input_dim = input_dim
        # Build the neural network
        self.gen = nn.Sequential(
            self.make_gen_block(input_dim, hidden_dim * 4),
            # self.make_gen_block(hidden_dim * 4, hidden_dim * 2, kernel_size=6, stride=1, final_layer=True)
            self.make_gen_block(hidden_dim * 4, im_chan, kernel_size=6, stride=1, final_layer=True)
            # self.make_gen_block(hidden_dim * 2, hidden_dim),
            # self.make_gen_block(hidden_dim, im_chan, kernel_size=4, stride=4, final_layer=True)
        ) 

    def make_gen_block(self, input_channels, output_channels, kernel_size=5, stride=5, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a generator block of DCGAN;
        a transposed convolution, a batchnorm (except in the final layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.ReLU(inplace=True),
            )
        else:
            return nn.Sequential(
                nn.ConvTranspose2d(input_channels, output_channels, kernel_size, stride),
                nn.Tanh(),
            )

    def forward(self, noise):
        '''
        Function for completing a forward pass of the generator: Given a noise tensor, 
        returns generated images.
        Parameters:
            noise: a noise tensor with dimensions (n_samples, input_dim)
        '''
        x = noise.view(len(noise), self.input_dim, 1, 1)
        return self.gen(x)

def get_noise(n_samples, input_dim, device='cpu'):
    '''
    Function for creating noise vectors: Given the dimensions (n_samples, input_dim)
    creates a tensor of that shape filled with random numbers from the normal distribution.
    Parameters:
        n_samples: the number of samples to generate, a scalar
        input_dim: the dimension of the input vector, a scalar
        device: the device type
    '''
    return torch.randn(n_samples, input_dim, device=device)

In [None]:
temp_noise = torch.rand((1,1167)).to(device)
fake_ = gen(temp_noise)

NameError: ignored

In [None]:
fake_.shape

NameError: ignored

# Discriminator

In [5]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
      im_chan: the number of channels in the images, fitted for the dataset used, a scalar
            (MNIST is black-and-white, so 1 channel is your default)
      hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=3, hidden_dim=64): #changing im_chan
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim),
            self.make_disc_block(hidden_dim, hidden_dim * 2),
            self.make_disc_block(hidden_dim * 2, 1, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=2, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a discriminator block of the DCGAN; 
        a convolution, a batchnorm (except in the final layer), and an activation (except in the final layer).
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        if not final_layer:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True),
            )
        else:
            return nn.Sequential(
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: Given an image tensor, 
        returns a 1-dimension tensor representing fake/real.
        Parameters:
            image: a flattened image tensor with dimension (im_chan)
        '''
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

In [None]:
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
temp_ = torch.rand((20, 3312, 10, 10)).to(device)
output = disc(temp_)

#Class Input : We will replace this with description embeddings

In conditional GANs, the input vector for the generator will also need to include the class information. The class is represented using a one-hot encoded vector where its length is the number of classes and each index represents a class. The vector is all 0's and a 1 on the chosen class.

In [6]:
import torch.nn.functional as F
def get_one_hot_labels(labels, n_classes):
    '''
    Function for creating one-hot vectors for the labels, returns a tensor of shape (?, num_classes).
    Parameters:
        labels: tensor of labels from the dataloader, size (?)
        n_classes: the total number of classes in the dataset, an integer scalar
    '''
    #### START CODE HERE ####
    return nn.functional.one_hot(labels, n_classes) 
    #### END CODE HERE ####

# Combine vectors

In [7]:
def combine_vectors(x, y):
    '''
    Function for combining two vectors with shapes (n_samples, ?) and (n_samples, ?).
    Parameters:
      x: (n_samples, ?) the first vector. 
        In this assignment, this will be the noise vector of shape (n_samples, z_dim), 
        but you shouldn't need to know the second dimension's size.
      y: (n_samples, ?) the second vector.
        Once again, in this assignment this will be the one-hot class vector 
        with the shape (n_samples, n_classes), but you shouldn't assume this in your code.
    '''
    # Note: Make sure this function outputs a float no matter what inputs it receives
    #### START CODE HERE ####
#     print(x.shape, y.shape)
    combined = torch.cat((x.float(),y.float()),dim=1)
    #### END CODE HERE ####
    return combined

# Dataset transformations and Hyperparams
We will replace n_classes with embedding size => n_dim

In [16]:


class Dataset(data.Dataset):
    """Dataset class for dsprites"""

    def __init__(self, data_root, normalize=True, rotate=False):
        
        
        #get the data labels
        self.img_to_labels = {}
        with open(DATA_PATH_BIG+'train.csv') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            next(csv_reader) # skips the first row containing the row names
            for row in tqdm(csv_reader):
              classes = row[1].split()
              classes = [int(e) for e in classes]
              self.img_to_labels[str(row[0])] = classes

        # Recursively exract paths to all .png files in subdirectories
        self.file_paths = []
        self.file_names = []
        for path, subdirs, files in tqdm(os.walk(data_root)):
            for name in files:
                plain_name = name.split('.')[0]

                if name.endswith(".png") and plain_name in self.img_to_labels:
                    
                    # print('here')
                    # Way 1
                    # condition = True
                    # for class_ in self.img_to_labels[plain_name]:
                    #     if class_ >= TOTAL_CLASSES:
                    #         condition = False
                    #         break

                    # Way 2
                    condition = True
                    labels = self.img_to_labels[plain_name]
                    filtered_labels = [e for e in labels if e < TOTAL_CLASSES]
                    if len(filtered_labels) == 0:
                        condition = False

                    if condition:
                      self.file_paths.append(os.path.join(path, name))
                      name = plain_name
                      self.file_names.append(name)
            # break
        self.transform = self._set_transforms(normalize, rotate)


    def _set_transforms(self, normalize, rotate):
        """Decide transformations to data to be applied"""

        transform = transforms.Compose([
                    transforms.Resize((image_resize,image_resize)),
                    transforms.ToTensor(),
                    transforms.Normalize((0.5,), (0.5,)),
                    ])
        return transform

    def __len__(self):
        """Required: specify dataset length for dataloader"""
        return len(self.file_paths)

    def __getitem__(self, index):
        """Required: specify what each iteration in dataloader yields"""
        img = Image.open(self.file_paths[index])
        img = self.transform(img)
        one_hot = torch.zeros(n_classes, dtype=int)
        classes = self.img_to_labels[self.file_names[index]]

        for class_ in classes:
          one_hot[class_]=1
        
        return img,one_hot


# dataset = Dataset(data_root=DATA_PATH_SMALL+'train/')

#loading the bigger dataset
dataset = Dataset(data_root=DATA_PATH_BIG+'train/')

dataloader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             num_workers=4,
                             shuffle=True)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here


In [None]:
#testing if dataloader works
for real,labels in tqdm(dataloader):
  cur_batch_size = len(real)
  real = real.to(device='cpu')
  print(real.shape) #currently have 101 images
  print('hi')
  break

In [22]:
labels.shape

torch.Size([128, 1103])

# Initializing input dimensions and weights

In [23]:
def get_input_dimensions(z_dim, postcard_shape, n_classes):
    '''
    Function for getting the size of the conditional input dimensions 
    from z_dim, the image shape, and number of classes.
    Parameters:
        z_dim: the dimension of the noise vector, a scalar
        postcard_shape: the shape of each postcard image as (C, W, H), which is (3, 200, 200)
        n_classes: the total number of classes in the dataset, an integer scalar
                (10 for MNIST)
    Returns: 
        generator_input_dim: the input dimensionality of the conditional generator, 
                          which takes the noise and class vectors
        discriminator_im_chan: the number of input channels to the discriminator
                            (e.g. C x 200 x 200 for postcard)
    '''
    #### START CODE HERE ####
    generator_input_dim = z_dim+n_classes
    discriminator_im_chan = postcard_shape[0]*(n_classes+1)
    # discriminator_im_chan = postcard_shape[0]+n_classes

    #### END CODE HERE ####
    return generator_input_dim, discriminator_im_chan

In [28]:
generator_input_dim, discriminator_im_chan = get_input_dimensions(z_dim, postcard_shape, n_classes)

gen = Generator(input_dim=generator_input_dim).to(device)
gen_opt = torch.optim.Adam(gen.parameters(), lr=lr)
disc = Discriminator(im_chan=discriminator_im_chan).to(device)
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr)

def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)
gen = gen.apply(weights_init)
disc = disc.apply(weights_init)

In [29]:
discriminator_im_chan

3312

In [None]:

cur_step = 0
generator_losses = []
discriminator_losses = []

#UNIT TEST NOTE: Initializations needed for grading
noise_and_labels = False
fake = False

fake_image_and_labels = False
real_image_and_labels = False
disc_fake_pred = False
disc_real_pred = False


for epoch in range(n_epochs):
    # Dataloader returns the batches and the labels
    for real, labels in tqdm(dataloader):

        cur_batch_size = len(real)
        # Flatten the batch of real images from the dataset
        real = real.to(device)
        # one_hot_labels = get_one_hot_labels(labels.to(device), n_classes)
        one_hot_labels = labels.to(device)
        image_one_hot_labels = one_hot_labels[:, :, None, None]
        image_one_hot_labels = image_one_hot_labels.repeat(1, postcard_shape[0], postcard_shape[1], postcard_shape[2])
        # image_one_hot_labels = image_one_hot_labels.repeat(1, postcard_shape[0], 28, 28) #ToDo 

        ### Update discriminator ###
        # Zero out the discriminator gradients
        disc_opt.zero_grad()
        # Get noise corresponding to the current batch_size 
        fake_noise = get_noise(cur_batch_size, z_dim, device=device)
        
        # Now you can get the images from the generator
        # Steps: 1) Combine the noise vectors and the one-hot labels for the generator
        #        2) Generate the conditioned fake images
       
        #### START CODE HERE ####
        noise_and_labels = combine_vectors(fake_noise, one_hot_labels)
        fake = gen(noise_and_labels)

        #### END CODE HERE ####
        
        # Make sure that enough images were generated
        assert len(fake) == len(real)
        # Check that correct tensors were combined
        assert tuple(noise_and_labels.shape) == (cur_batch_size, fake_noise.shape[1] + one_hot_labels.shape[1])
        # It comes from the correct generator
        assert tuple(fake.shape) == (len(real), 3, image_resize, image_resize) #ToDo: check the fake image size
        # Now you can get the predictions from the discriminator
        # Steps: 1) Create the input for the discriminator
        #           a) Combine the fake images with image_one_hot_labels, 
        #              remember to detach the generator (.detach()) so you do not backpropagate through it
        #           b) Combine the real images with image_one_hot_labels
        #        2) Get the discriminator's prediction on the fakes as disc_fake_pred
        #        3) Get the discriminator's prediction on the reals as disc_real_pred
        
        #### START CODE HERE ####
        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        real_image_and_labels = combine_vectors(real, image_one_hot_labels)

        disc_fake_pred = disc(fake_image_and_labels.detach())
        disc_real_pred = disc(real_image_and_labels)
        #### END CODE HERE ####
        
        # Make sure shapes are correct 
        assert tuple(fake_image_and_labels.shape) == (len(real), fake.detach().shape[1] + image_one_hot_labels.shape[1], image_resize ,image_resize)
        assert tuple(real_image_and_labels.shape) == (len(real), real.shape[1] + image_one_hot_labels.shape[1],image_resize , image_resize)
        # Make sure that enough predictions were made
        assert len(disc_real_pred) == len(real)
        # Make sure that the inputs are different
        assert torch.any(fake_image_and_labels != real_image_and_labels)
        # Shapes must match
        assert tuple(fake_image_and_labels.shape) == tuple(real_image_and_labels.shape)
        assert tuple(disc_fake_pred.shape) == tuple(disc_real_pred.shape)
        
        
        disc_fake_loss = criterion(disc_fake_pred, torch.zeros_like(disc_fake_pred))
        disc_real_loss = criterion(disc_real_pred, torch.ones_like(disc_real_pred))
        disc_loss = (disc_fake_loss + disc_real_loss) / 2
        disc_loss.backward(retain_graph=True)
        disc_opt.step() 

        # Keep track of the average discriminator loss
        discriminator_losses += [disc_loss.item()]

        ### Update generator ###
        # Zero out the generator gradients
        gen_opt.zero_grad()

        fake_image_and_labels = combine_vectors(fake, image_one_hot_labels)
        # This will error if you didn't concatenate your labels to your image correctly
        disc_fake_pred = disc(fake_image_and_labels)
        gen_loss = criterion(disc_fake_pred, torch.ones_like(disc_fake_pred))
        gen_loss.backward()
        gen_opt.step()

        # Keep track of the generator losses
        generator_losses += [gen_loss.item()]
        #

        if cur_step % display_step == 0 and cur_step > 0:
            gen_mean = sum(generator_losses[-display_step:]) / display_step
            disc_mean = sum(discriminator_losses[-display_step:]) / display_step
            print(f"Step {cur_step}: Generator loss: {gen_mean}, discriminator loss: {disc_mean}")
            show_tensor_images(fake)
            show_tensor_images(real)
            step_bins = 20
            x_axis = sorted([i * step_bins for i in range(len(generator_losses) // step_bins)] * step_bins)
            num_examples = (len(generator_losses) // step_bins) * step_bins
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(generator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Generator Loss"
            )
            plt.plot(
                range(num_examples // step_bins), 
                torch.Tensor(discriminator_losses[:num_examples]).view(-1, step_bins).mean(1),
                label="Discriminator Loss"
            )
            plt.legend()
            plt.show()
        elif cur_step == 0:
            print("Congratulations! If you've gotten here, it's working. Please let this train until you're happy with how the generated numbers look, and then go on to the exploration!")
        cur_step += 1

Output hidden; open in https://colab.research.google.com to view.

In [None]:
fake_image_and_labels.shape

torch.Size([20, 3312, 10, 10])