In [1]:
import os

import torch
from torch import nn
from tqdm.auto import tqdm
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
torch.manual_seed(seed=0)  # Fixed global seed for reproducible runs -- set for testing purposes, please do not change!


def show_tensor_images(image_tensor, num_images=1, size=(1, 256, 256)):
    '''
    Plot the first `num_images` images of a batch as a grid with 5 images per row.

    Parameters:
        image_tensor: image batch, presumably (N, C, H, W) with values in [-1, 1]
            (the range produced by the Normalize transform); mapped back to [0, 1]
        num_images: how many images, from the start of the batch, to show
        size: per-image size; accepted for API compatibility but not used
    '''
    # Undo the [-1, 1] scaling, then detach from autograd and move to host memory.
    images = ((image_tensor + 1) / 2).detach().cpu()
    grid = make_grid(images[:num_images], nrow=5)
    # make_grid returns (C, H, W); imshow expects channels last.
    channels_last = grid.permute(1, 2, 0)
    plt.imshow(channels_last.squeeze())
    plt.show()

In [None]:
class Discriminator(nn.Module):
    '''
    DCGAN-style convolutional discriminator: maps a batch of images to one
    unnormalized real/fake score (logit) per image.
    Values:
        im_chan: the number of channels of the INPUT images, a scalar
              (defaults to 3 for RGB paintings; use 1 for grayscale data such as MNIST)
        hidden_dim: the inner dimension (channel width of the first block), a scalar
    '''
    def __init__(self, im_chan=3, hidden_dim=16):
        super().__init__()
        wide = hidden_dim * 2
        # Four blocks; the last one has no batchnorm/activation, so it emits raw logits.
        # With kernel_size=4 (and 5 in the last block), a 256x256 input shrinks
        # spatially 256 -> 64 -> 21 -> 6 -> 1.
        blocks = [
            self.make_disc_block(im_chan, hidden_dim, stride=4),
            self.make_disc_block(hidden_dim, wide, stride=3),
            self.make_disc_block(wide, wide, stride=3),
            self.make_disc_block(wide, 1, kernel_size=5, final_layer=True),
        ]
        self.disc = nn.Sequential(*blocks)

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Build one discriminator block of a DCGAN: a convolution, followed (except
        in the final layer) by a batchnorm and a LeakyReLU with negative slope 0.2.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution (no padding is used)
            final_layer: True for the last block, which keeps only the convolution
        Returns:
            nn.Sequential of [Conv2d, BatchNorm2d, LeakyReLU], or of [Conv2d] alone
            when final_layer is True
        '''
        layers = [nn.Conv2d(input_channels, output_channels, kernel_size, stride)]
        if not final_layer:
            layers.append(nn.BatchNorm2d(output_channels))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(*layers)

    def forward(self, image):
        '''
        Score a batch of images.
        Parameters:
            image: image batch of shape (N, im_chan, H, W); BatchNorm2d requires 4-D input
        Returns:
            tensor of shape (N, H' * W'), where H' x W' is the final feature-map size;
            (N, 1) for a 256x256 input
        '''
        scores = self.disc(image)
        return scores.view(scores.size(0), -1)

In [None]:
criterion = nn.BCEWithLogitsLoss()
# Latent-noise size; not referenced by the cells shown here (kept for compatibility).
z_dim = 64
# Number of training steps between loss reports, plot updates and checkpoints.
display_step = 117
batch_size = 30
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

# These parameters control the optimizer's momentum, which you can read more about here:
# https://distill.pub/2017/momentum/ but you don't need to worry about it for this course!
beta_1 = 0.5
beta_2 = 0.999
# Fall back to the CPU so the notebook still runs on machines without a CUDA GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Transform the image values to be between -1 and 1 (the range of the tanh activation):
# ToTensor maps pixels to [0, 1], then (x - 0.5) / 0.5 maps them to [-1, 1].
# The single-element mean/std broadcast across every channel.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# ImageFolder assigns one integer label per sub-directory of `root` (sub-directory
# names sorted). The training loop feeds these labels directly to BCEWithLogitsLoss,
# whose targets must lie in [0, 1], so at most two class folders are allowed.
# NOTE(review): no Resize is applied, so this assumes every image already has the
# same size (default collation stacks the batch) -- TODO confirm for this data.
dataset = torchvision.datasets.ImageFolder(root="disc-data/ukiyoe",
                                           transform=transform)
assert len(dataset.classes) <= 2, (
    "Expected at most 2 class folders under disc-data/ukiyoe for binary real/fake "
    f"labels, found {len(dataset.classes)}: {dataset.classes}"
)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)

In [None]:

# Build the discriminator on the target device, then its Adam optimizer.
disc = Discriminator()
disc = disc.to(device)
disc_opt = torch.optim.Adam(params=disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# Initialize conv weights (and the batchnorm scale) from a normal distribution
# with mean 0 and standard deviation 0.02; the batchnorm shift is zeroed.
def weights_init(m):
    '''
    Per-module initializer intended for `nn.Module.apply`.

    Conv2d / ConvTranspose2d: weight ~ N(0, 0.02); bias is left untouched.
    BatchNorm2d: weight ~ N(0, 0.02), bias set to 0.
    Any other module is left unchanged.
    '''
    # NOTE(review): PyTorch's DCGAN tutorial centres the BatchNorm weight at 1.0;
    # this keeps the original mean of 0.0.
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    if isinstance(m, conv_types):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
    elif isinstance(m, nn.BatchNorm2d):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        nn.init.zeros_(m.bias)

# Re-initialize every Conv/BatchNorm layer in place; apply() returns the same module.
# The optimizer created earlier references these same Parameter objects, so it sees the new values.
disc = disc.apply(fn=weights_init)

In [None]:
def plot_loss(disc_steps, disc_losses, t='van-gogh'):
    '''
    Plot the discriminator's mean loss against training step and save it to
    disc-results/plots/<t>.pdf.

    Parameters:
        disc_steps: x-values, the training steps at which a mean loss was recorded
        disc_losses: y-values, the mean loss recorded at each of those steps
        t: dataset tag; names the output file and selects the title
           ('van-gogh', 'ukiyoe'; any other value gets the Monet title)
    '''
    titles = {
        'van-gogh': 'Discriminator loss for Van Gogh paintings',
        'ukiyoe': 'Discriminator loss for Ukiyoe paintings',
    }
    fig, ax = plt.subplots()
    ax.plot(disc_steps, disc_losses)
    ax.set_title(titles.get(t, 'Discriminator loss for Monet paintings'))
    ax.set_xlabel('Steps')
    ax.set_ylabel('Mean loss')

    # savefig does not create missing directories.
    out_dir = os.path.join('disc-results', 'plots')
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig(os.path.join(out_dir, t + '.pdf'))
    # This is called repeatedly from the training loop; close the figure so
    # figures do not pile up in memory (and all render at the end of the cell).
    plt.close(fig)

In [None]:

n_epochs = 100
cur_step = 0
mean_discriminator_loss = 0
disc_losses = []
disc_steps = []
# Loss summed since the last report, and the number of steps that contributed to it.
window_loss_sum = 0.0
window_steps = 0

# torch.save does not create missing directories.
os.makedirs('disc-results/models', exist_ok=True)

for epoch in range(n_epochs):
    # Dataloader returns (images, integer class labels) batches
    for img, label in tqdm(dataloader, desc=f"Epoch {epoch + 1}/{n_epochs}"):
        img = img.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()

        # (N, 1) scores -> (N,) so they line up with the per-image labels
        disc_pred = disc(img)[:, 0]
        # BCEWithLogitsLoss needs float targets on the same device as the logits
        label = label.to(device=disc_pred.device, dtype=disc_pred.dtype)
        disc_loss = criterion(disc_pred, label)
        # No retain_graph: the graph is never reused, and retaining it keeps this
        # step's autograd buffers alive until the next iteration rebinds the names.
        disc_loss.backward()
        disc_opt.step()

        # Keep track of the average discriminator loss over the current window
        window_loss_sum += disc_loss.item()
        window_steps += 1

        ## Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            # Divide by the actual step count: the first window covers steps
            # 0..display_step (display_step + 1 steps), later windows display_step.
            mean_discriminator_loss = window_loss_sum / window_steps
            print(f"Step {cur_step}: {mean_discriminator_loss}")
            disc_losses.append(mean_discriminator_loss)
            disc_steps.append(cur_step)
            plot_loss(disc_steps, disc_losses, t='ukiyoe')
            window_loss_sum = 0.0
            window_steps = 0
            torch.save(disc.state_dict(), 'disc-results/models/ukiyoe_disc_latest.pth')

        cur_step += 1