In [1]:
from __future__ import print_function  # must be the first statement of the cell

import argparse
import glob
import os
import platform
import random

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from IPython.display import HTML
from PIL import Image
from PIL import ImageFile
from torch.autograd import Variable  # legacy wrapper; kept for any cell that still references it
from torch.utils.tensorboard import SummaryWriter

from ImageDataset import ImageDataset
from cGAN_net import Generator, Discriminator, weights_init

# Render matplotlib figures directly in the notebook output area.
%matplotlib inline
# Tell PIL to load image files whose data is truncated instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [3]:
# ---------------------------------------------------------------------------
# Global configuration.
# NOTE: the hyper-parameter sweep cell further down rebinds `beta1`,
# `batch_size` and `num_epochs` (and uses its own `lr`), so the values here
# act only as defaults for cells that run before it.
# ---------------------------------------------------------------------------

# --- Reproducibility -------------------------------------------------------
# GAN training is stochastic (weight init, latent noise, shuffling, random
# class labels). Seed every RNG the notebook uses so that
# "Restart Kernel -> Run All" gives repeatable runs.
SEED = 999
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# --- Hardware / data loading -----------------------------------------------
ngpu = 1        # number of GPUs to use; 0 forces CPU
workers = 2     # passed to DataLoader as num_workers
device = torch.device("cuda:0" if (torch.cuda.is_available() and ngpu > 0) else "cpu")

# --- Optimiser defaults ----------------------------------------------------
learning_rate = 1e-4    # Adam learning rate
beta1 = 0.5             # first Adam beta; the second is fixed to 0.999 where optimisers are built
batch_size = 64
num_epochs = 200

# --- Model / data shape ----------------------------------------------------
image_size = 64     # passed to the Discriminator constructor
img_ch = 3          # passed to both networks (presumably image channels -- confirm in cGAN_net)
z_dim = 100         # passed to the Generator constructor; length of each latent noise vector
features_g = 64     # Generator width argument (meaning defined in cGAN_net)
features_d = 64     # Discriminator width argument (meaning defined in cGAN_net)

In [4]:
# Embedding size handed to both the Generator and the Discriminator
# constructors (presumably the class-label embedding width -- confirm in cGAN_net).
embed_size = 100

# Build the project's ImageDataset from 'DATA' (presumably a folder relative
# to the notebook -- confirm in ImageDataset) and record how many classes it reports.
dataset = ImageDataset('DATA')
num_classes = dataset.num_classes

In [8]:
# Multi-GPU wrapping is only needed when running on CUDA with more than one GPU.
use_data_parallel = (device.type == 'cuda') and (ngpu > 1)
gpu_ids = list(range(ngpu))

# --- Generator --------------------------------------------------------------
netG = Generator(ngpu, features_g, z_dim, img_ch, num_classes, embed_size).to(device)
if use_data_parallel:
    netG = nn.DataParallel(netG, gpu_ids)
# weights_init comes from cGAN_net; its exact initialisation scheme is not
# visible in this notebook.
netG.apply(weights_init)

# --- Discriminator ----------------------------------------------------------
netD = Discriminator(ngpu, features_d, img_ch, num_classes, image_size, embed_size).to(device)
if use_data_parallel:
    netD = nn.DataParallel(netD, gpu_ids)
netD.apply(weights_init)

# Binary cross-entropy between D's output and the real (1) / fake (0) targets.
criterion = nn.BCELoss()

# Default Adam optimisers for D and G. The hyper-parameter sweep cell builds
# its own optimisers per configuration and rebinds these two names.
adam_betas = (beta1, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=adam_betas)

In [None]:
# Hyper-parameter grid. Every (lr, beta1, batch_size) combination is trained
# from freshly re-initialised weights, logged to its own TensorBoard run and
# saved under model/. Widen the lists to sweep more settings; candidates:
#   lr:    0.0008, 0.0004, 0.0002, 0.00008, 0.00004, 0.00002
#   beta1: 0.5, 0.7, 0.9
#   batch: 64
learning_range = [0.0001]
beta_range = [0.2]
batch_range = [32]
os.makedirs('model/Generator_cGAN', exist_ok=True)
os.makedirs('model/Discriminator_cGAN', exist_ok=True)

for lr in learning_range:
    for beta1 in beta_range:
        for batch_size in batch_range:
            sufix = f'lr={lr}_beta={beta1}_batch={batch_size}'

            # Fixed latent vectors AND fixed class labels: every snapshot in
            # img_list is generated from the same inputs, so successive frames
            # show the generator's progress rather than a change of labels.
            fixed_noise = torch.randn(batch_size, z_dim, device=device)
            fixed_labels = torch.randint(0, dataset.num_classes, (batch_size,), device=device)

            dataloader = torch.utils.data.DataLoader(dataset,
                                                     batch_size=batch_size,
                                                     shuffle=True,
                                                     num_workers=workers)

            # Re-apply weights_init so each configuration starts from fresh
            # weights instead of continuing from the previous run.
            netD.apply(weights_init)
            netG.apply(weights_init)
            optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=(beta1, 0.999))
            optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

            # Number of training epochs for each configuration of the sweep.
            num_epochs = 20
            writer = SummaryWriter(log_dir=f'runs-D_{optimizerD.__class__.__name__}-G_{optimizerG.__class__.__name__}/'
                                   + sufix,
                                   comment=f'batch={batch_size}')
            img_list = []
            iters = 0

            print("Starting Training Loop...")
            for epoch in range(num_epochs):
                for i, (data, label) in enumerate(dataloader, 0):

                    ############################
                    # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                    ###########################
                    ## Train with all-real batch
                    netD.zero_grad()
                    real_cpu = data.to(device)
                    real_labels = label.to(device)
                    output = netD(real_cpu, real_labels).view(-1)
                    # Target is 1 for every real sample.
                    errD_real = criterion(output, torch.ones_like(output))
                    errD_real.backward()
                    D_x = output.mean().item()

                    ## Train with all-fake batch
                    noise = torch.randn(batch_size, z_dim, device=device)
                    fake_labels = torch.randint(0, dataset.num_classes, (batch_size,), device=device)
                    fake = netG(noise, fake_labels)
                    # detach(): the D update must not back-propagate into G.
                    output = netD(fake.detach(), fake_labels).view(-1)
                    # Target is 0 for every generated sample.
                    errD_fake = criterion(output, torch.zeros_like(output))
                    errD_fake.backward()
                    D_G_z1 = output.mean().item()
                    # Gradients of both passes have accumulated; errD is for reporting.
                    errD = errD_real + errD_fake
                    optimizerD.step()

                    ############################
                    # (2) Update G network: maximize log(D(G(z)))
                    ###########################
                    netG.zero_grad()
                    # D was just updated, so score the same fake batch again.
                    # `fake` is deliberately NOT detached here: the gradient has
                    # to flow through D back into G, otherwise G never learns.
                    output = netD(fake, fake_labels).view(-1)
                    # G wants D to label its samples as real (1).
                    errG = criterion(output, torch.ones_like(output))
                    errG.backward()
                    D_G_z2 = output.mean().item()
                    optimizerG.step()

                    # Output training stats
                    if i % 10 == 0:
                        print('[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                              % (epoch, num_epochs, i, len(dataloader),
                                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

                    # Snapshot G's output on the fixed inputs every 200
                    # iterations and on the very last batch of the run.
                    if (iters % 200 == 0) or ((epoch == num_epochs-1) and (i == len(dataloader)-1)):
                        with torch.no_grad():
                            snapshot = netG(fixed_noise, fixed_labels).detach().cpu()
                        img_list.append(vutils.make_grid(snapshot, padding=2, normalize=True))

                    iters += 1

                # Per-epoch scalars use the values of the epoch's last batch.
                writer.add_scalar('Loss_Generator', errG.item(), epoch)
                writer.add_scalar('Loss_Discriminator', errD.item(), epoch)
                writer.add_scalar('D(x)', D_x, epoch)
                writer.add_scalar('D(G(z))', D_G_z2, epoch)

            # Final sample grid for this configuration.
            with torch.no_grad():
                fake = netG(fixed_noise, fixed_labels).detach().cpu()
            writer.add_image('images', vutils.make_grid(fake, padding=2, normalize=True), 0)
            # Flush pending events and release the event file for this run.
            writer.close()

            # NOTE(review): torch.save on a module pickles the whole object, so
            # loading these files needs the cGAN_net classes to be importable.
            # Kept as-is so existing loading code keeps working.
            torch.save(netD, 'model/Discriminator_cGAN/' + sufix)
            torch.save(netG, 'model/Generator_cGAN/' + sufix)
            
            