In [2]:
%matplotlib inline

In [3]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import torch.nn.functional as F
from skimage import io, transform

import copy
from torch.autograd import grad
import torchvision
import pandas as pd

from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter('runs/cPGGANs')

# import torchvision.transforms.functional as TF

from my_utils import GeneratorC, DiscriminatorC, save_model, save_model2

In [5]:
def get_data_loader(x: torch.Tensor, y: torch.Tensor, batch_size=5, numThreads = 1) -> torch.utils.data.DataLoader:
    """Wrap paired input/label tensors in a shuffling PyTorch ``DataLoader``.

    :param x: (torch.Tensor) inputs
    :param y: (torch.Tensor) labels, paired index-by-index with ``x``
    :param batch_size: (int) number of samples per batch
    :param numThreads: (int) forwarded to the loader as ``num_workers``
    :return: a ``DataLoader`` over ``TensorDataset(x, y)`` with ``shuffle=True``
    """
    return torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(x, y),
        batch_size=batch_size,
        shuffle=True,
        num_workers=numThreads,
    )

def unroll(data):
    """Run one critic update on the unrolled discriminator copy ``netDK``.

    The loss is the Wasserstein critic objective used by the main training
    loop: ``-mean(D(x))`` on the real batch plus ``mean(D(G(z)))`` on a freshly
    sampled fake batch. The two halves are back-propagated separately and the
    accumulated gradients are applied with ``optimizerDK``.

    Reads the notebook globals ``netDK``, ``netG``, ``optimizerDK``, ``device``,
    ``flag`` and ``alpha``; only ``netDK`` / ``optimizerDK`` are modified.

    :param data: one batch from the dataloader; ``data[0]`` holds the images
        and ``data[1]`` the integer class labels.
    """
    netDK.zero_grad()

    # Real half of the critic loss.
    real_cpu = data[0].to(device)
    y_onehot = turn_label_to_one_hot(data[1].to(device))
    b_size = real_cpu.size(0)
    output = netDK(real_cpu, y_onehot, flag=flag, alpha=alpha).view(-1)
    errD_real = -output.mean()
    errD_real.backward()

    # Fake half. The generator is not updated here, so its forward pass runs
    # under no_grad instead of building a graph that is immediately detached.
    z_fake, y_onehot_fake = gen_fake(b_size)
    with torch.no_grad():
        fake = netG(z_fake, flag=flag, alpha=alpha)
    output = netDK(fake, y_onehot_fake, flag=flag, alpha=alpha).view(-1)
    # NOTE(review): the gradient-penalty term (get_gp) is disabled here, as in
    # the main loop, so this update does not constrain the critic.
    errD_fake = output.mean()
    errD_fake.backward()

    # Apply the gradients accumulated from both halves.
    optimizerDK.step()

def copy_unroll(data, K=1):
    """Reset the unrolled critic to the current critic, then step it ``K`` times.

    ``optimizerDK`` and ``netDK`` are overwritten with the state of
    ``optimizerD`` and ``netD`` (all notebook globals) before ``unroll`` is
    called ``K`` times on the same batch.

    :param data: batch forwarded unchanged to ``unroll``
    :param K: (int) number of unrolled critic updates
    """
    # Synchronise the copy with the live critic and its optimizer.
    optimizerDK.load_state_dict(optimizerD.state_dict())
    netDK.load_state_dict(netD.state_dict())

    for _ in range(K):
        unroll(data)
        
def get_gp(x, fake_x, nn, y_onehot_fake, flag , alpha, gamma=750.0):
    """Gradient penalty of a critic on random interpolates of real and fake images.

    For each sample an interpolation weight is drawn uniformly from [0, 1) and
    the critic is evaluated on ``eps * x + (1 - eps) * fake_x``. The penalty is
    ``mean((||grad|| - gamma)^2) / gamma^2``, where the gradient is taken with
    respect to the interpolated input.

    :param x: real image batch; indexed as 4-D (the weights are ``(batch, 1, 1, 1)``)
    :param fake_x: generated image batch, same shape and device as ``x``
    :param nn: the critic, called as ``nn(x_hat, y_onehot_fake, flag=..., alpha=...)``.
        The name shadows the ``torch.nn`` alias inside this function and is
        kept only for caller compatibility.
    :param y_onehot_fake: conditioning labels passed through to the critic
    :param flag: progressive-growing stage descriptor passed through to the critic
    :param alpha: fade-in factor passed through to the critic
    :param gamma: target gradient norm (default 750, the value previously hard-coded)
    :return: scalar tensor that can be back-propagated through
    """
    batch = x.shape[0]
    # Keep the interpolation weights in their own variable. The previous code
    # stored them in `alpha`, so the critic received random weights as its
    # fade-in factor.
    eps = torch.rand(batch, 1, 1, 1, device=x.device)

    x_hat = eps * x.detach() + (1 - eps) * fake_x.detach()
    x_hat.requires_grad_(True)

    pred_hat = nn(x_hat, y_onehot_fake, flag=flag, alpha=alpha)
    # create_graph=True so the penalty can itself be differentiated with
    # respect to the critic parameters.
    gradients = grad(outputs=pred_hat, inputs=x_hat,
                     grad_outputs=torch.ones_like(pred_hat),
                     create_graph=True, retain_graph=True, only_inputs=True)[0]

    grad_norm = gradients.reshape(batch, -1).norm(2, dim=1)
    return grad_norm.sub(gamma).pow(2).mean() / (gamma ** 2)

In [22]:
# ---- Data loading ----
workers = 8         # worker processes handed to the DataLoader
batch_size = 16     # samples per training batch

# ---- Image geometry ----
# NOTE(review): image_size, cropSec and nc are not referenced by any other
# visible cell; they presumably describe how x_train was prepared.
image_size = 64     # spatial size of the training images
cropSec = 64
nc = 3              # channels in the training images (3 = colour)

# ---- Model ----
nz = 512            # length of the latent vector z fed to the generator
nb_digits = 5       # number of class labels; width of the one-hot condition

# ---- Adam optimizers ----
lr = 1e-4           # learning rate
beta1 = 0.0         # first-moment decay rate

# ---- Hardware ----
ngpu = 1            # number of GPUs available; 0 selects CPU mode

In [1]:
# LOAD DATA
# NOTE(review): x_train / y_train are not defined in any visible cell of this
# notebook; they must already exist in the kernel for this cell to run.

# Pick the compute device: first CUDA GPU when available and requested, else CPU.
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

dataloader = get_data_loader(x_train, y_train, batch_size, numThreads=workers)

# Preview one shuffled batch of training images as a grid.
real_batch = next(iter(dataloader))
preview_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=2, normalize=True).cpu()

fig, ax = plt.subplots(figsize=(8, 8))
ax.axis("off")
ax.set_title("Training Images")
ax.imshow(np.transpose(preview_grid, (1, 2, 0)))


In [2]:
# Optimizer steps per epoch. Use the loader's own batch count: the loader keeps
# a final partial batch, so `N // batch_size` under-counts by one whenever N is
# not a multiple of batch_size (and is 0 for N < batch_size, which would make
# the alpha schedule divide by zero).
a_step = len(dataloader)

# Total number of training images to show the networks, expressed as epochs.
TOTAL_TRAIN_IMAGES = 800000
num_epochs = TOTAL_TRAIN_IMAGES // len(dataloader.dataset)

In [28]:
def turn_label_to_one_hot(y, num_classes=None):
    """Convert integer class labels to a float one-hot matrix.

    :param y: (torch.LongTensor) labels of shape ``(batch,)`` or ``(batch, 1)``
    :param num_classes: (int, optional) width of the encoding; defaults to the
        notebook-level ``nb_digits``
    :return: float tensor of shape ``(batch, num_classes)`` on the same device
        as ``y`` (every caller in this notebook moves ``y`` to ``device`` first)
    """
    if num_classes is None:
        num_classes = nb_digits
    # Allocate zeros directly on the target device instead of creating an
    # uninitialised CPU FloatTensor and copying it over on every call.
    y_onehot = torch.zeros(y.size(0), num_classes, device=y.device)
    y_onehot.scatter_(1, y.view(-1, 1), 1)
    return y_onehot

def gen_fake(b_size):
    """Sample generator inputs: Gaussian latents with random one-hot class labels.

    Uses the notebook globals ``nz``, ``nb_digits`` and ``device``.

    :param b_size: (int) number of samples to draw
    :return: tuple ``(z_cond, y_onehot)`` where ``z_cond`` has shape
        ``(b_size, nz + nb_digits, 1, 1)`` (latent vector concatenated with the
        one-hot label along dim 1) and ``y_onehot`` has shape
        ``(b_size, nb_digits)``
    """
    z = torch.randn(b_size, nz, 1, 1, device=device)
    # Draw class ids uniformly from [0, nb_digits) directly on the device,
    # replacing the legacy LongTensor.random_() % nb_digits construction.
    y = torch.randint(0, nb_digits, (b_size, 1), device=device)
    y_onehot = torch.zeros(b_size, nb_digits, device=device)
    y_onehot.scatter_(1, y, 1)
    # The original repeated the scatter_ in the return expression; it was a no-op.
    return torch.cat((z, y_onehot.view(-1, nb_digits, 1, 1)), dim=1), y_onehot

In [None]:
# Build the conditional generator on the selected device.
netG = GeneratorC(ngpu, nb_digits).to(device)

# Replicate across GPUs only when running on CUDA with more than one GPU requested.
if ngpu > 1 and device.type == 'cuda':
    netG = nn.DataParallel(netG, device_ids=list(range(ngpu)))

In [None]:
# Build the conditional discriminator (critic) on the selected device.
netD = DiscriminatorC(ngpu, nb_digits).to(device)

# Replicate across GPUs only when running on CUDA with more than one GPU requested.
if ngpu > 1 and device.type == 'cuda':
    netD = nn.DataParallel(netD, device_ids=list(range(ngpu)))

# netDK is the independent copy used for the unrolled critic steps; it is taken
# after the optional DataParallel wrapping so it has the same structure as netD.
netDK = copy.deepcopy(netD)

In [None]:
# No criterion object is created: the losses in this notebook are computed
# directly from the critic outputs.

# Label convention for real and fake samples.
real_label = 1
fake_label = 0

# One Adam optimizer per network (critic, unrolled critic copy, generator),
# all with the same learning rate and betas.
optimizerD, optimizerDK, optimizerG = (
    optim.Adam(net.parameters(), lr=lr, betas=(beta1, 0.999))
    for net in (netD, netDK, netG)
)

In [None]:
####################
# Restore generator and discriminator (weights + optimizer state) from disk.

def _load_checkpoint(path, model, optimizers):
    """Load one checkpoint file into ``model`` and each optimizer in ``optimizers``.

    ``map_location=device`` lets a checkpoint saved on a GPU be restored on a
    CPU-only machine (and vice versa).

    :return: the loaded checkpoint dict, for access to its other entries
    """
    checkpoint = torch.load(path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    for optimizer in optimizers:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint


PATH = "/PATH/TO/SAVED/MODEL/Gen-32-alph-1-epoch-99-complete.pt"
checkpoint = _load_checkpoint(PATH, netG, [optimizerG])
epoch_s = checkpoint['epoch']
loss = checkpoint['loss']

# NOTE(review): eval() switches layers such as dropout / batch-norm to
# inference behaviour; call .train() on the networks before resuming training.
netG.eval()

PATH = "/PATH/TO/SAVED/MODEL/Dis-32-alph-1-epoch-99-complete.pt"
# The unrolled copy shares the discriminator's optimizer state and weights.
checkpoint = _load_checkpoint(PATH, netD, [optimizerD, optimizerDK])
epoch_s1 = checkpoint['epoch']
netDK.load_state_dict(netD.state_dict())

netD.eval()
print("Done!")

In [29]:

# Fixed generator input (64 samples) reused for every progress snapshot so that
# images from different iterations are comparable.
fixed_noise = gen_fake(64)[0]

# One real batch and its size, plus a batch-sized generator input.
real_batch = next(iter(dataloader))
b_size = real_batch[0].shape[0]
noise = gen_fake(batch_size)[0]

# flags = [[4, "stable"], [8, "transition"], [8, "stable"], [16, "transition"], [16, "stable"], [32, "transition"], [32, "stable"]]

In [30]:
#GEN AND DISC CHECK
# Shape sanity check: push one real batch and one fake batch through D and G.
flag = [64, "transition"]
alpha = 0.1

# Fetch a single batch directly instead of starting a loop and breaking out of it.
data = next(iter(dataloader))
real_images = data[0].to(device)
real_onehot = turn_label_to_one_hot(data[1].to(device))

# Only shapes are inspected, so no autograd graph is needed.
with torch.no_grad():
    print(real_images.shape)
    print(netD(real_images, real_onehot, flag=flag, alpha=alpha).shape)
    z_fake, y_onehot_fake = gen_fake(batch_size)
    print("y_onehot_fake: ", y_onehot_fake.shape, " y_onehot_real: ", real_onehot.shape)
    fake_images = netG(z_fake, flag=flag, alpha=alpha)
    print(fake_images.shape)
    fake_score_shape = netD(fake_images, y_onehot_fake, flag=flag, alpha=alpha).shape

# Last expression: displayed as the cell output.
fake_score_shape

torch.Size([16, 3, 64, 64])
torch.Size([16, 1])
y_onehot_fake:  torch.Size([16, 5])  y_onehot_real:  torch.Size([16, 5])
torch.Size([16, 3, 64, 64])


torch.Size([16, 1])

In [None]:
# Training Loop

# Lists to keep track of progress
img_list = []
G_losses = []
D_losses = []
iters = 0

flag = [64, "transition"]
# alpha = (epoch_s+1)/num_epochs
alpha = 0

# Checkpoint output directory and the integer passed to save_model2 for the
# current stage (0 when flag[1] is "stable", 1 otherwise).
SAVE_DIR = "/home/jovyan/GANs/trained_models_r/"

# The checkpoint-loading cell calls netG.eval() / netD.eval(); switch all three
# networks back to training mode so mode-dependent layers (if any) train.
netG.train()
netD.train()
netDK.train()

print("Starting Training Loop...")
# for epoch in range(epoch_s + 1, num_epochs):
for epoch in range(num_epochs):
    for i, data in enumerate(dataloader, 0):
        # Fade-in: during a "transition" stage alpha grows linearly towards 1
        # over the whole run; during a "stable" stage it is left untouched.
        if flag[1] != "stable":
            alpha = min(1, alpha + 1 / (a_step * num_epochs))

        ############################
        # (1) Update D network (Wasserstein critic): minimize D(G(z)) - D(x)
        ############################
        # NOTE(review): the gradient-penalty term (get_gp) is disabled and no
        # other constraint on D is visible, so the critic is unbounded; the
        # recorded run shows loss magnitudes around 1e8.
        netD.zero_grad()

        ## Real batch
        real_cpu = data[0].to(device)
        y_onehot = turn_label_to_one_hot(data[1].to(device))
        b_size = real_cpu.size(0)
        output = netD(real_cpu, y_onehot, flag=flag, alpha=alpha).view(-1)
        errD_real = -output.mean()
        errD_real.backward()
        D_x = output.mean().item()

        ## Fake batch (generator output detached: only D is updated here)
        z_fake, y_onehot_fake = gen_fake(b_size)
        fake = netG(z_fake, flag=flag, alpha=alpha)
        output = netD(fake.detach(), y_onehot_fake, flag=flag, alpha=alpha).view(-1)
        errD_fake = output.mean()
        errD_fake.backward()
        D_G_z1 = output.mean().item()

        # Gradients from both halves have accumulated; apply them.
        errD = errD_real + errD_fake
        optimizerD.step()

        # Refresh the unrolled critic copy from netD and step it 5 more times
        # on this batch; the generator is trained against that copy.
        copy_unroll(data, K=5)

        ############################
        # (2) Update G network: minimize -D_K(G(z)) using the unrolled critic
        ############################
        netG.zero_grad()
        output = netDK(fake, y_onehot_fake, flag=flag, alpha=alpha).view(-1)
        errG = -output.mean()
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        # Output training stats
        if i % 50 == 0:
            global_step = epoch * len(dataloader) + i
            print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f \t alpha: %.4f'
                  % (epoch, num_epochs, i, len(dataloader),
                     errD.item(), errG.item(), D_x, D_G_z1, D_G_z2, alpha))
            writer.add_scalar('cPGGANS Discriminator loss {}-{}'.format(flag[0], flag[1]),
                              errD.item(),
                              global_step)
            writer.add_scalar('cPGGANS Generator loss {}-{}'.format(flag[0], flag[1]),
                              errG.item(),
                              global_step)
            with torch.no_grad():
                fixed_fake = netG(fixed_noise, flag=flag, alpha=alpha).detach().cpu()
                # Min-max rescale to [0, 1] for TensorBoard display.
                fixed_fake_01 = (fixed_fake - fixed_fake.min()) / (fixed_fake.max() - fixed_fake.min())
                real_01 = (real_cpu - real_cpu.min()) / (real_cpu.max() - real_cpu.min())
            img_list.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))
            writer.add_image('cPGGANS real{}-{}'.format(flag[0], flag[1]),
                             torchvision.utils.make_grid(real_01),
                             global_step=global_step)
            writer.add_image('cPGGANS fake{}-{}'.format(flag[0], flag[1]),
                             torchvision.utils.make_grid(fixed_fake_01),
                             global_step=global_step)

        # Save Losses for plotting later
        G_losses.append(errG.item())
        D_losses.append(errD.item())

        # Check how the generator is doing by saving G's output on fixed_noise
        if (iters % 500 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
            with torch.no_grad():
                fixed_fake = netG(fixed_noise, flag=flag, alpha=alpha).detach().cpu()
            img_list.append(vutils.make_grid(fixed_fake, padding=2, normalize=True))

        iters += 1

    # Checkpoint both networks every 25th epoch.
    if epoch % 25 == 24:
        stage_code = 0 if flag[1] == "stable" else 1
        save_model2(netG, 'Gen', flag[0], optimizerG, errG, epoch, stage_code, SAVE_DIR=SAVE_DIR)
        save_model2(netD, 'Dis', flag[0], optimizerD, errD, epoch, stage_code, SAVE_DIR=SAVE_DIR)

Starting Training Loop...
[0/131][0/381]	Loss_D: -236224432.0000	Loss_G: 90336208.0000	D(x): 145625120.0000	D(G(z)): -90599312.0000 / -90336208.0000 	 alpha: 0.0000
[0/131][50/381]	Loss_D: -52932.0000	Loss_G: -44678200.0000	D(x): 44795244.0000	D(G(z)): 44742312.0000 / 44678200.0000 	 alpha: 0.0010
[0/131][100/381]	Loss_D: 20004.0000	Loss_G: -10873172.0000	D(x): 11060832.0000	D(G(z)): 11080836.0000 / 10873172.0000 	 alpha: 0.0020
[0/131][150/381]	Loss_D: -101572656.0000	Loss_G: -10973203.0000	D(x): 112346416.0000	D(G(z)): 10773759.0000 / 10973203.0000 	 alpha: 0.0030
[0/131][200/381]	Loss_D: -34010900.0000	Loss_G: 22974744.0000	D(x): 11009969.0000	D(G(z)): -23000932.0000 / -22974744.0000 	 alpha: 0.0040
[0/131][250/381]	Loss_D: -101674064.0000	Loss_G: 56691408.0000	D(x): 44975460.0000	D(G(z)): -56698608.0000 / -56691408.0000 	 alpha: 0.0050
[0/131][300/381]	Loss_D: -33928396.0000	Loss_G: 25595376.0000	D(x): 8912856.0000	D(G(z)): -25015540.0000 / -25595376.0000 	 alpha: 0.0060
[0/131][35

In [3]:
# Per-iteration generator and discriminator losses collected by the training loop.
fig, ax = plt.subplots(figsize=(10, 5))
ax.set_title("Generator and Discriminator Loss During Training")
for losses, name in ((G_losses, "G"), (D_losses, "D")):
    ax.plot(losses, label=name)
ax.set_xlabel("iterations")
ax.set_ylabel("Loss")
ax.legend()
plt.show()

In [4]:
# Side-by-side comparison: one batch of real images vs. the most recent
# generator snapshot stored in img_list.
real_batch = next(iter(dataloader))
real_grid = vutils.make_grid(real_batch[0].to(device)[:64], padding=5, normalize=True).cpu()

fig, (ax_real, ax_fake) = plt.subplots(1, 2, figsize=(15, 15))

# Left panel: real images
ax_real.axis("off")
ax_real.set_title("Real Images")
ax_real.imshow(np.transpose(real_grid, (1, 2, 0)))

# Right panel: last saved grid of generated images
ax_fake.axis("off")
ax_fake.set_title("Fake Images")
ax_fake.imshow(np.transpose(img_list[-1], (1, 2, 0)))
plt.show()

In [5]:
#SAVE SNAPSHOTS
# Generate `s` images of a single class `idx` and save them as one grid.
# NOTE(review): this cell overwrites the global `flag` / `alpha` used by the
# training loop, and the output file name below is not derived from them.

flag = [32, "transition"]
alpha = 1
idx = 4
s=64

# Same class id for every sample, built directly on the target device.
class_ids = torch.full((s,), idx, dtype=torch.long, device=device)
y_onehot = turn_label_to_one_hot(class_ids)
z = torch.randn(s, nz, 1, 1, device=device)

# Inference only: no autograd graph needed. The condition width comes from
# nb_digits rather than a hard-coded 5.
with torch.no_grad():
    a_im = netG(torch.cat((z, y_onehot.view(-1, nb_digits, 1, 1)), dim=1), flag=flag, alpha=alpha)
print(a_im.shape)

a_grid = vutils.make_grid(a_im, normalize=True)
plt.imshow(a_grid.cpu().numpy().transpose(1,2,0))
# plt.show()
plt.savefig("/PATH/TO/Figs/cPGGANs-32-transition-49-RP=4.jpg",bbox_inches='tight', pad_inches=0,dpi=32*8)