In [1]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

In [2]:
# Root directory for dataset (override with the ICR_GAN_DATAROOT environment
# variable so the notebook runs on machines other than the author's)
dataroot = os.environ.get("ICR_GAN_DATAROOT", "/Volumes/gordonssd/cifar10")
# Directory to output images and model checkpoints (override: ICR_GAN_OUTF)
outf = os.environ.get("ICR_GAN_OUTF", "/Users/soymilk/Desktop/icr-gan/output")
# Number of workers for dataloader
workers = 2
# Batch size during training
batch_size = 64
# Spatial size of training images. All images will be resized to this
#   size using a transformer.
image_size = 32
# Number of channels in the training images. For color images this is 3
nc = 3
# Size of z latent vector (i.e. size of generator input)
nz = 100
# Size of feature maps in generator
ngf = 64
# Size of feature maps in discriminator
ndf = 64
# Number of training epochs
num_epochs = 5
# Learning rate for optimizers
lr = 0.0002
# Beta1 hyperparam for Adam optimizers
beta1 = 0.5
# Number of GPUs available. Use 0 for CPU mode.
ngpu = 0
# Consistency regularization for real and fake images (bCR)
lambda_real = 10
lambda_fake = 10
# Consistency regularization for discriminator and generator (zCR)
lambda_dis = 5
lambda_gen = 0.5
# Latent transform noise: std-dev of the Gaussian added to z to form T(z)
sigma_noise = 0.03
# Random seed so that shuffling, latent sampling and augmentations are
# reproducible across Restart & Run All (seeds Python, NumPy and PyTorch RNGs).
manual_seed = 999
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)

<h3>Load CIFAR10</h3>

In [3]:
# CIFAR-10 (torchvision default split: training), resized to `image_size` and
# mapped from [0, 1] to [-1, 1] per channel by Normalize(mean=0.5, std=0.5).
dataset = dset.CIFAR10(root=dataroot, download=True,
                       transform=transforms.Compose([
                           transforms.Resize(image_size),
                           transforms.ToTensor(),
                           transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                       ]))
# `nc` is defined in the configuration cell; check it instead of silently
# re-assigning it here, so that cell stays the single source of truth.
assert nc == 3, "CIFAR-10 images are RGB: set nc = 3 in the configuration cell"

dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

Files already downloaded and verified


<h3>Initialise pre-trained DCGAN</h3>

In [4]:
# Build the DCGAN networks from the project-local `dcgan` module and load the
# `*_epoch_4.pth` checkpoints stored under weights/.
from dcgan import Discriminator, Generator

# Target device, computed with the same rule as the later `device` definition.
# The models used to be moved to CUDA whenever it was merely available, even
# with ngpu = 0, while every batch was moved to the CPU `device`. The first
# forward pass then failed with a device mismatch.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

netD = Discriminator(ngpu)
netG = Generator(ngpu)

# Load the weights onto the CPU first (this works whether or not they were saved
# from a GPU), then move both models to the target device.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
netD.load_state_dict(torch.load('weights/netD_epoch_4.pth', map_location=torch.device('cpu')))
netG.load_state_dict(torch.load('weights/netG_epoch_4.pth', map_location=torch.device('cpu')))
netD = netD.to(device)
netG = netG.to(device)

In [5]:
# Pick the compute device: the first GPU only when CUDA is present *and* the
# config asks for GPU mode (ngpu > 0); otherwise run on the CPU.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

<h3>Loss Functions and Optimizers</h3>

In [6]:
# Adversarial loss on D's outputs and the squared-error loss shared by every
# consistency-regularization term.
# NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes the
# dcgan Discriminator ends in a sigmoid. TODO confirm.
criterion = nn.BCELoss()
l2loss = nn.MSELoss()

# One fixed latent batch, reused to render comparable sample grids over training.
fixed_noise = torch.randn(batch_size, nz, 1, 1, device=device)
real_label, fake_label = 1, 0

# Adam for both networks, with beta1 taken from the configuration cell.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

<h3>Augmentations</h3>

In [7]:
# Augmentation T(.) used by bCR: a random rotation with angle drawn from
# [-25, 25] degrees.
# NOTE(review): when called on a whole batch tensor, torchvision presumably
# samples one angle per call, so every image in the batch gets the same
# rotation. Confirm whether per-image augmentation was intended.
transform = transforms.RandomRotation(degrees=25)

<h2>Train Balanced Consistency Regularization DCGAN (bCR-DCGAN)</h2>

In [None]:
# Balanced consistency regularization (bCR). On top of the DCGAN losses, D is
# penalised for changing its output when a real or a generated image is
# augmented with `transform`:
#   L_real = |D(x) - D(T(x))|^2,   L_fake = |D(G(z)) - D(T(G(z)))|^2
# NOTE(review): this cell keeps training whatever weights netD/netG currently
# hold. Run it right after loading the pretrained weights for a clean bCR run.
bcr_outf = os.path.join(outf, 'bcr')  # per-method dir: runs don't overwrite each other
os.makedirs(bcr_outf, exist_ok=True)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # Train with all-real batch
        netD.zero_grad()
        x = data[0].to(device)

        # bCR: Augment real images
        T_x = transform(x)

        # Local name on purpose: reassigning the global `batch_size` would leak
        # the size of the last (smaller) batch into later cells.
        b_size = x.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        D_x = netD(x)
        # bCR: Forward pass augmented real batch through D
        D_T_x = netD(T_x)

        errD_real = criterion(D_x, label)
        # bCR: L_real = |D(x) - D(T(x))|^2
        L_real = l2loss(D_x, D_T_x)
        (errD_real + lambda_real * L_real).backward()
        D_x = D_x.mean().item()

        # Train with all-fake batch
        z = torch.randn(b_size, nz, 1, 1, device=device)
        G_z = netG(z)

        # bCR: Augment generated images (detached: D's update must not reach G)
        T_G_z = transform(G_z.detach())

        label.fill_(fake_label)
        D_G_z = netD(G_z.detach())
        # bCR: Forward pass augmented fake batch through D
        D_T_G_z = netD(T_G_z)

        errD_fake = criterion(D_G_z, label)
        # bCR: L_fake = |D(G(z)) - D(T(G(z)))|^2
        L_fake = l2loss(D_G_z, D_T_G_z)
        (errD_fake + lambda_fake * L_fake).backward()

        D_G_z1 = D_G_z.mean().item()
        L_D = errD_real + errD_fake  # reported adversarial D loss (CR terms excluded)
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        D_G_z = netD(G_z)
        errG = criterion(D_G_z, label)
        errG.backward()
        D_G_z2 = D_G_z.mean().item()
        optimizerG.step()

        # Log and snapshot every 100 iterations instead of flooding the output.
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_real: %.4f Loss_fake: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     L_D.item(), L_real.item(), L_fake.item(), errG.item(), D_x, D_G_z1, D_G_z2))
            vutils.save_image(x,
                    '%s/real_samples.png' % bcr_outf,
                    normalize=True)
            # Rendering the fixed-noise grid needs no autograd graph.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake,
                    '%s/fake_samples_epoch_%03d.png' % (bcr_outf, epoch),
                    normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (bcr_outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (bcr_outf, epoch))

<h2>Train Latent Consistency Regularization DCGAN (zCR-DCGAN)</h2>

In [None]:
# Latent consistency regularization (zCR). z is perturbed to T(z) = z + noise:
#   L_dis = |D(G(z)) - D(G(T(z)))|^2   added to D's loss
#   L_gen = -|G(z) - G(T(z))|^2        added to G's loss
# NOTE(review): this cell keeps training whatever weights netD/netG currently
# hold. Run it right after loading the pretrained weights for a clean zCR run.
zcr_outf = os.path.join(outf, 'zcr')  # per-method dir: runs don't overwrite each other
os.makedirs(zcr_outf, exist_ok=True)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # Train with all-real batch
        netD.zero_grad()
        x = data[0].to(device)

        # Local name on purpose: do not clobber the global `batch_size` config.
        b_size = x.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        D_x = netD(x)
        errD_real = criterion(D_x, label)
        errD_real.backward()
        D_x = D_x.mean().item()

        # Train with all-fake batch
        z = torch.randn(b_size, nz, 1, 1, device=device)

        # zCR: T(z) = z + N(0, sigma_noise^2). randn_like creates the noise on
        # z's device; torch.normal(..., z.shape) always built it on the CPU,
        # which broke the addition when training on a GPU.
        T_z = z + sigma_noise * torch.randn_like(z)

        G_z = netG(z)
        # zCR: Forward pass T(z) through G
        G_T_z = netG(T_z)

        label.fill_(fake_label)
        D_G_z = netD(G_z.detach())
        errD_fake = criterion(D_G_z, label)

        # zCR: Forward pass G(T(z)) through D
        D_G_T_z = netD(G_T_z.detach())
        # zCR: L_dis = |D(G(z)) - D(G(T(z)))|^2
        L_dis = l2loss(D_G_z, D_G_T_z)
        (errD_fake + lambda_dis * L_dis).backward()

        D_G_z1 = D_G_z.mean().item()
        L_D = errD_real + errD_fake  # reported adversarial D loss (CR term excluded)
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        D_G_z = netD(G_z)
        errG = criterion(D_G_z, label)

        # zCR: L_gen = -|G(z) - G(T(z))|^2. It only involves G's outputs, so the
        # former, unused D forward pass of G(T(z)) is not needed here.
        L_gen = -l2loss(G_z, G_T_z)
        (errG + lambda_gen * L_gen).backward()

        D_G_z2 = D_G_z.mean().item()
        optimizerG.step()

        # Log and snapshot every 100 iterations instead of flooding the output.
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_dis: %.4f Loss_G: %.4f Loss_gen: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     L_D.item(), L_dis.item(), errG.item(), L_gen.item(), D_x, D_G_z1, D_G_z2))
            vutils.save_image(x,
                    '%s/real_samples.png' % zcr_outf,
                    normalize=True)
            # Rendering the fixed-noise grid needs no autograd graph.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake,
                    '%s/fake_samples_epoch_%03d.png' % (zcr_outf, epoch),
                    normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (zcr_outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (zcr_outf, epoch))

<h2>Train Improved Consistency Regularization DCGAN (ICR-DCGAN = bCR + zCR)</h2>

In [None]:
# Improved consistency regularization (ICR) = bCR + zCR:
#   D loss adds lambda_real*L_real, lambda_fake*L_fake and lambda_dis*L_dis
#   G loss adds lambda_gen*L_gen
# NOTE(review): this cell keeps training whatever weights netD/netG currently
# hold. Run it right after loading the pretrained weights for a clean ICR run.
icr_outf = os.path.join(outf, 'icr')  # per-method dir: runs don't overwrite each other
os.makedirs(icr_outf, exist_ok=True)

for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # Train with all-real batch
        netD.zero_grad()
        x = data[0].to(device)

        # bCR: Augment real images
        T_x = transform(x)

        # Local name on purpose: do not clobber the global `batch_size` config.
        b_size = x.size(0)
        label = torch.full((b_size,), real_label, dtype=torch.float, device=device)

        D_x = netD(x)
        # bCR: Forward pass augmented real batch through D
        D_T_x = netD(T_x)

        errD_real = criterion(D_x, label)
        # bCR: L_real = |D(x) - D(T(x))|^2
        L_real = l2loss(D_x, D_T_x)
        (errD_real + lambda_real * L_real).backward()
        D_x = D_x.mean().item()

        # Train with all-fake batch
        z = torch.randn(b_size, nz, 1, 1, device=device)

        # zCR: T(z) = z + N(0, sigma_noise^2), the same perturbation as the zCR
        # cell. This previously used std 1 instead of the configured
        # sigma_noise, and built the noise on the CPU regardless of z's device.
        T_z = z + sigma_noise * torch.randn_like(z)

        G_z = netG(z)

        # bCR: Augment generated images (detached: D's update must not reach G)
        T_G_z = transform(G_z.detach())

        # zCR: Forward pass T(z) through G
        G_T_z = netG(T_z)

        label.fill_(fake_label)
        D_G_z = netD(G_z.detach())
        # bCR: Forward pass augmented fake batch through D
        D_T_G_z = netD(T_G_z)

        errD_fake = criterion(D_G_z, label)
        # bCR: L_fake = |D(G(z)) - D(T(G(z)))|^2
        L_fake = l2loss(D_G_z, D_T_G_z)

        # zCR: Forward pass G(T(z)) through D
        D_G_T_z = netD(G_T_z.detach())
        # zCR: L_dis = |D(G(z)) - D(G(T(z)))|^2
        L_dis = l2loss(D_G_z, D_G_T_z)

        (errD_fake + lambda_dis * L_dis + lambda_fake * L_fake).backward()

        D_G_z1 = D_G_z.mean().item()
        L_D = errD_real + errD_fake  # reported adversarial D loss (CR terms excluded)
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        D_G_z = netD(G_z)
        errG = criterion(D_G_z, label)

        # zCR: L_gen = -|G(z) - G(T(z))|^2. It only involves G's outputs, so the
        # former, unused D forward pass of G(T(z)) is not needed here.
        L_gen = -l2loss(G_z, G_T_z)
        (errG + lambda_gen * L_gen).backward()

        D_G_z2 = D_G_z.mean().item()
        optimizerG.step()

        # Log and snapshot every 100 iterations.
        if i % 100 == 0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_real: %.4f Loss_fake: %.4f Loss_dis: %.4f Loss_G: %.4f Loss_gen: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     L_D.item(), L_real.item(), L_fake.item(), L_dis.item(), errG.item(), L_gen.item(), D_x, D_G_z1, D_G_z2))
            vutils.save_image(x,
                    '%s/real_samples.png' % icr_outf,
                    normalize=True)
            # Rendering the fixed-noise grid needs no autograd graph.
            with torch.no_grad():
                fake = netG(fixed_noise)
            vutils.save_image(fake,
                    '%s/fake_samples_epoch_%03d.png' % (icr_outf, epoch),
                    normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (icr_outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (icr_outf, epoch))