In [1]:
import torch
from torch import nn
from tqdm.auto import tqdm
import torchvision
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.utils import make_grid
from torch.utils.data import DataLoader, Dataset
import matplotlib.pyplot as plt
torch.manual_seed(0) # Set for testing purposes, please do not change!


def show_tensor_images(image_tensor, num_images=1, size=(1, 256, 256)):
    '''
    Function for visualizing images: Given a tensor of images, number of images, and
    size per image, plots and prints the images in an uniform grid.
    '''
    image_tensor = (image_tensor + 1) / 2
    image_unflat = image_tensor.detach().cpu()
    image_grid = make_grid(image_unflat[:num_images], nrow=5)
    plt.imshow(image_grid.permute(1, 2, 0).squeeze())
    plt.show()

In [2]:
class Discriminator(nn.Module):
    '''
    Discriminator Class
    Values:
        im_chan: the number of channels of the output image, a scalar
              (MNIST is black-and-white, so 1 channel is your default)
    hidden_dim: the inner dimension, a scalar
    '''
    def __init__(self, im_chan=3, hidden_dim=16):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            self.make_disc_block(im_chan, hidden_dim, stride=4),
            self.make_disc_block(hidden_dim, hidden_dim * 2, stride=3),
            self.make_disc_block(hidden_dim * 2, hidden_dim * 2, stride=3),
            self.make_disc_block(hidden_dim * 2, 1, kernel_size=5, final_layer=True),
        )

    def make_disc_block(self, input_channels, output_channels, kernel_size=4, stride=2, final_layer=False):
        '''
        Function to return a sequence of operations corresponding to a discriminator block of DCGAN, 
        corresponding to a convolution, a batchnorm (except for in the last layer), and an activation.
        Parameters:
            input_channels: how many channels the input feature representation has
            output_channels: how many channels the output feature representation should have
            kernel_size: the size of each convolutional filter, equivalent to (kernel_size, kernel_size)
            stride: the stride of the convolution
            final_layer: a boolean, true if it is the final layer and false otherwise 
                      (affects activation and batchnorm)
        '''
        #     Steps:
        #       1) Add a convolutional layer using the given parameters.
        #       2) Do a batchnorm, except for the last layer.
        #       3) Follow each batchnorm with a LeakyReLU activation with slope 0.2.
        
        # Build the neural block
        if not final_layer:
            return nn.Sequential(
                #### START CODE HERE #### #
                nn.Conv2d(input_channels, output_channels, kernel_size, stride),
                nn.BatchNorm2d(output_channels),
                nn.LeakyReLU(0.2, inplace=True)
                #### END CODE HERE ####
            )
        else: # Final Layer
            return nn.Sequential(
                #### START CODE HERE #### #
                nn.Conv2d(input_channels, output_channels, kernel_size, stride)
                #### END CODE HERE ####
            )

    def forward(self, image):
        '''
        Function for completing a forward pass of the discriminator: Given an image tensor, 
        returns a 1-dimension tensor representing fake/real.
        Parameters:
            image: a flattened image tensor with dimension (im_dim)
        '''
        disc_pred = self.disc(image)
        return disc_pred.view(len(disc_pred), -1)

In [18]:
criterion = nn.BCEWithLogitsLoss()
z_dim = 64
display_step = 100
batch_size = 20
# A learning rate of 0.0002 works well on DCGAN
lr = 0.0002

# These parameters control the optimizer's momentum, which you can read more about here:
# https://distill.pub/2017/momentum/ but you don’t need to worry about it for this course!
beta_1 = 0.5 
beta_2 = 0.999
device = 'cuda'

# You can tranform the image values to be between -1 and 1 (the range of the tanh activation)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = torchvision.datasets.ImageFolder(root="disc-data/vangogh", 
                                            transform=transform)

dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True)

In [19]:

disc = Discriminator().to(device) 
disc_opt = torch.optim.Adam(disc.parameters(), lr=lr, betas=(beta_1, beta_2))

# You initialize the weights to the normal distribution
# with mean 0 and standard deviation 0.02
def weights_init(m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

disc = disc.apply(weights_init)

In [20]:
def plot_loss(cur_step, disc_steps, disc_losses, t='van-gogh'):
    plt.figure()
    plt.plot(disc_steps, disc_losses)
    if t == 'van-gogh':
        plt.title('Discriminator Losses for Van Gogh paintings')
    else:
        plt.title('Discriminator Losses for Monet paintings')
        
    plt.xlabel('Steps')
    plt.ylabel('Losses')
    plt.savefig('disc-results/plots/' + t + '-' + str(cur_step) + '.pdf')

In [22]:

n_epochs = 100
cur_step = 0
mean_discriminator_loss = 0
disc_losses = []
disc_steps = []

for epoch in range(n_epochs):
    print('\n EPOCH ' + str(epoch) + '\n')
    # Dataloader returns the batches
    for img, label in tqdm(dataloader):
        cur_batch_size = len(img)
        img = img.to(device)

        ## Update discriminator ##
        disc_opt.zero_grad()
        
        disc_pred = disc(img)
        disc_pred = disc_pred[:,0]
        label = label.type_as(disc_pred)
        disc_loss = criterion(disc_pred, label)
        disc_loss.backward(retain_graph=True)
        # Update optimizer
        disc_opt.step()
        
        # Keep track of the average discriminator loss
        mean_discriminator_loss += disc_loss.item() / display_step

        #disc_losses.append(disc_loss)
        #disc_steps.append(cur_step)
        
        ## Visualization code ##
        if cur_step % display_step == 0 and cur_step > 0:
            print("Step " + str(cur_step) + ': ' + str(mean_discriminator_loss))
            #plot_loss(cur_step, disc_steps, disc_losses)
            mean_discriminator_loss = 0
            torch.save(disc.state_dict(), 'disc-results/models/vangogh_disc_latest.pth')
            
        cur_step += 1


 EPOCH 0



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 1



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 100: 0.5356849241256715


 EPOCH 2



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 3



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 200: 0.46639860272407535


 EPOCH 4



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 5



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 300: 0.39868572711944594


 EPOCH 6



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 400: 0.38325640961527835


 EPOCH 7



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 8



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 500: 0.3157583811879159


 EPOCH 9



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 10



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 600: 0.25914408288896096


 EPOCH 11



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 12



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 700: 0.21446415565907956


 EPOCH 13



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 800: 0.16893841236829754


 EPOCH 14



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 15



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 900: 0.15296912323683498


 EPOCH 16



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 17



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1000: 0.1080858474969864


 EPOCH 18



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1100: 0.10407098136842251


 EPOCH 19



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 20



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1200: 0.08146379454061387


 EPOCH 21



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 22



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1300: 0.0782480700686574


 EPOCH 23



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 24



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1400: 0.06467971766367556


 EPOCH 25



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1500: 0.04968314323574302


 EPOCH 26



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 27



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1600: 0.04154657822102308


 EPOCH 28



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 29



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1700: 0.04673801412805915


 EPOCH 30



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 31



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1800: 0.04560056244488804


 EPOCH 32



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 1900: 0.033763945777900525


 EPOCH 33



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 34



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2000: 0.022132974190171784


 EPOCH 35



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 36



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2100: 0.012808867964195088


 EPOCH 37



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2200: 0.033785658457782135


 EPOCH 38



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 39



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2300: 0.017074689650908106


 EPOCH 40



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 41



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2400: 0.017111248844303192


 EPOCH 42



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 43



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2500: 0.01243323339615017


 EPOCH 44



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2600: 0.02169345992268063


 EPOCH 45



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 46



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2700: 0.013066048522596253


 EPOCH 47



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 48



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2800: 0.040710506888572126


 EPOCH 49



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 50



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 2900: 0.031438387540983964


 EPOCH 51



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 3000: 0.010894819701788941


 EPOCH 52



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 53



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 3100: 0.010297106912184968


 EPOCH 54



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 55



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 3200: 0.013631057112943379


 EPOCH 56



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 3300: 0.005186137089622206


 EPOCH 57



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))



 EPOCH 58



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

Step 3400: 0.0035784930436057035


 EPOCH 59



HBox(children=(FloatProgress(value=0.0, max=58.0), HTML(value='')))

KeyboardInterrupt: 