In [11]:
from __future__ import print_function
#%matplotlib inline
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

# Set random seed for reproducibility
manualSeed = 999
#manualSeed = random.randint(1, 10000) # use if you want new results
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)


Random Seed:  999


<torch._C.Generator at 0x7f5f00ab7d80>

In [12]:
dataroot = "/media/fico/Data/Celeba/CelebAMask-HQ"

workers = 4

batch_size = 128

image_size = 64

nz = 100

nf = 64

num_epochs = 50

lr = 0.0001

betas = (0, .9)

lambda_gp = 10

d_ratio = 5

device = torch.device("cuda:0" if (torch.cuda.is_available()) else "cpu")


In [13]:
# We can use an image folder dataset the way we have it setup.
# Create the dataset
dataset = dset.ImageFolder(root=dataroot,
                           transform=transforms.Compose([
                               transforms.Resize(image_size),
                               transforms.CenterCrop(image_size),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
# Create the dataloader
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                         shuffle=True, num_workers=workers)

print(len(dataloader))
# Decide which device we want to run on



235


In [14]:
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


In [15]:
# Generator Code

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.block1 = nn.Sequential(
            nn.ConvTranspose2d( nz, nf * 16, 4, 1, 0, bias=False),
            nn.BatchNorm2d(nf * 16),
            nn.ReLU(True),)
        self.block2 = nn.Sequential(
            nn.ConvTranspose2d(nf * 16, nf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 8),
            nn.ReLU(True),)
        self.block3 = nn.Sequential(
            nn.ConvTranspose2d( nf * 8, nf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf * 4),
            nn.ReLU(True),)
        self.block4 = nn.Sequential(
            nn.ConvTranspose2d( nf * 4, nf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(nf*2),
            nn.ReLU(True),)
        self.block5 = nn.Sequential(
            nn.ConvTranspose2d( nf *2, 3, 4, 2, 1, bias=False),
            nn.Tanh())
        
        

    def forward(self, input):
        output = self.block1(input)
        output = self.block2(output)
        output = self.block3(output)
        output = self.block4(output)
        output = self.block5(output)
        return output


In [16]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(nc, nf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf, nf * 2, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(nf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf * 2, nf * 4, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(nf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf * 4, nf * 8, 4, 2, 1, bias=False),
            nn.InstanceNorm2d(nf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf * 8, 1, 4, 1, 0, bias=False),
        )
    def forward(self, input):
        return self.main(input)


In [17]:
netD = Discriminator().to(device)
netG = Generator().to(device)
netG.apply(weights_init)
netD.apply(weights_init)




Discriminator(
  (main): Sequential(
    (0): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (1): LeakyReLU(negative_slope=0.2, inplace=True)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (3): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (4): LeakyReLU(negative_slope=0.2, inplace=True)
    (5): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (6): InstanceNorm2d(256, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (7): LeakyReLU(negative_slope=0.2, inplace=True)
    (8): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (9): InstanceNorm2d(512, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
    (10): LeakyReLU(negative_slope=0.2, inplace=True)
    (11): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)
  )
)

In [18]:
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

epo = 0
# Setup Adam optimizers for both G and D
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=betas)

In [22]:
path = "./Saved_Models/WGAN_CheckPoint_Epoch:_" + str(19)
check = torch.load(path)
netD.load_state_dict(check['netD_state_dict'])
netG.load_state_dict(check['netG_state_dict'])
optimizerD.load_state_dict(check['optimizerD_state_dict'])
optimizerG.load_state_dict(check['optimizerG_state_dict'])
epo = check['epoch']
epo += 1


In [23]:
print("Train Start")
# For each epoch

for epoch in range(epo, num_epochs):
    count = 0
    itera = 0
    # For each batch in the dataloader
    for i, data in enumerate(dataloader, 0):
        
        netD.zero_grad()
        
        real_imgs = data[0].to(device)

        
        mini_batch = real_imgs.size(0)

        loss_real = netD(real_imgs)
        loss_real = -torch.mean(loss_real)
        loss_real.backward()

        noise = torch.randn(mini_batch, nz, 1, 1, device=device)

        fake_imgs = netG(noise)

        loss_fake = netD(fake_imgs.detach())
        loss_fake = torch.mean(loss_fake)
        loss_fake.backward()

        alpha = torch.randn(mini_batch, 1, 1, 1, device = device)

        interp = alpha * real_imgs + ((1-alpha) * fake_imgs.detach())
        interp.requires_grad_()

        model_interp = netD(interp)

        grads = torch.autograd.grad(outputs=model_interp, inputs=interp,
                                  grad_outputs=torch.ones(model_interp.size()).to(device),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]

        grads = torch.square(grads)
        grads = torch.sum(grads, dim = [1,2,3])
        grads = torch.sqrt(grads)
        grads = grads - 1
        grads = torch.square(grads)
        grad_pen = torch.mean(grads * lambda_gp)

        grad_pen.backward()


        d_loss = loss_fake - loss_real + grad_pen

        optimizerD.step()
#         print(output_real)

#             print(itera)
        if i % d_ratio == 0:

            netG.zero_grad()

            g_loss = netD(fake_imgs)

            g_loss = -torch.mean(g_loss)
            g_loss.backward()
            optimizerG.step()

            if i % 10 == 0:

                print("\nEpoch:", epoch, "Iteration:", count)
                print("WGAN GP Loss:", d_loss.item(), "G Loss:", g_loss.item())


            # Check how the generator is doing by saving G's output on fixed_noise
        if (i % 10 == 0):
            with torch.no_grad():
                    guess = netG(fixed_noise)
                    guess = guess.cpu()
                    old_min = -1
                    old_max = 1
                    old_range = old_max - old_min
                    new_range = 1 - 0
                    guess = (((guess - old_min)*new_range)/ old_range) + 0
                    guess = guess.permute(0,2,3,1)

                    fig = plt.figure(figsize=(4,4))
                    for i in range(16):
                        plt.subplot(4, 4, i+1)
                        plt.imshow(guess[i, :, :])

                        plt.axis('off')
                    path = "./Training_Imgs/Epoch: " + str(epoch) + " training_step: " + str(count) + ".png"
                    plt.savefig(path, dpi=300)
                    plt.close('all')
        count += mini_batch

    path = "./Saved_Models/WGAN_CheckPoint_Epoch:_" + str(epoch)
    torch.save({
    'netD_state_dict': netD.state_dict(),
    'netG_state_dict': netG.state_dict(),
    'optimizerD_state_dict': optimizerD.state_dict(),
    'optimizerG_state_dict': optimizerG.state_dict(),
    'epoch': epoch
    }, path)


Train Start

Epoch: 20 Iteration: 0
WGAN GP Loss: -47.113319396972656 G Loss: 33.14257049560547

Epoch: 20 Iteration: 1280
WGAN GP Loss: -46.35020446777344 G Loss: 31.81186294555664

Epoch: 20 Iteration: 2560
WGAN GP Loss: -46.919700622558594 G Loss: 34.638763427734375

Epoch: 20 Iteration: 3840
WGAN GP Loss: -45.11269760131836 G Loss: 31.40787696838379

Epoch: 20 Iteration: 5120
WGAN GP Loss: -47.773681640625 G Loss: 32.78889846801758

Epoch: 20 Iteration: 6400
WGAN GP Loss: -52.816436767578125 G Loss: 34.75325012207031

Epoch: 20 Iteration: 7680
WGAN GP Loss: -50.3585090637207 G Loss: 34.274566650390625

Epoch: 20 Iteration: 8960
WGAN GP Loss: -49.76044464111328 G Loss: 34.696250915527344

Epoch: 20 Iteration: 10240
WGAN GP Loss: -50.89025115966797 G Loss: 33.42718505859375

Epoch: 20 Iteration: 11520
WGAN GP Loss: -45.4417724609375 G Loss: 33.327205657958984

Epoch: 20 Iteration: 12800
WGAN GP Loss: -49.05784606933594 G Loss: 31.94373321533203

Epoch: 20 Iteration: 14080
WGAN GP Los


Epoch: 23 Iteration: 28160
WGAN GP Loss: -46.143798828125 G Loss: 31.045108795166016

Epoch: 23 Iteration: 29440
WGAN GP Loss: -49.747249603271484 G Loss: 30.485034942626953

Epoch: 24 Iteration: 0
WGAN GP Loss: -48.998470306396484 G Loss: 32.019493103027344

Epoch: 24 Iteration: 1280
WGAN GP Loss: -48.301326751708984 G Loss: 31.479228973388672

Epoch: 24 Iteration: 2560
WGAN GP Loss: -39.9539680480957 G Loss: 28.777835845947266

Epoch: 24 Iteration: 3840
WGAN GP Loss: -45.33827209472656 G Loss: 31.178813934326172

Epoch: 24 Iteration: 5120
WGAN GP Loss: -48.408546447753906 G Loss: 31.112382888793945

Epoch: 24 Iteration: 6400
WGAN GP Loss: -49.90498352050781 G Loss: 30.955108642578125

Epoch: 24 Iteration: 7680
WGAN GP Loss: -49.46299362182617 G Loss: 33.647090911865234

Epoch: 24 Iteration: 8960
WGAN GP Loss: -56.04025650024414 G Loss: 34.290016174316406

Epoch: 24 Iteration: 10240
WGAN GP Loss: -37.65788269042969 G Loss: 29.58294677734375

Epoch: 24 Iteration: 11520
WGAN GP Loss: -


Epoch: 27 Iteration: 26880
WGAN GP Loss: -45.458740234375 G Loss: 31.090065002441406

Epoch: 27 Iteration: 28160
WGAN GP Loss: -51.060508728027344 G Loss: 31.530933380126953

Epoch: 27 Iteration: 29440
WGAN GP Loss: -50.762359619140625 G Loss: 31.35879135131836

Epoch: 28 Iteration: 0
WGAN GP Loss: -46.355247497558594 G Loss: 30.70623207092285

Epoch: 28 Iteration: 1280
WGAN GP Loss: -51.753700256347656 G Loss: 32.61414337158203

Epoch: 28 Iteration: 2560
WGAN GP Loss: -51.11876678466797 G Loss: 32.1957893371582

Epoch: 28 Iteration: 3840
WGAN GP Loss: -45.93735122680664 G Loss: 30.888996124267578

Epoch: 28 Iteration: 5120
WGAN GP Loss: -49.08203125 G Loss: 30.480613708496094

Epoch: 28 Iteration: 6400
WGAN GP Loss: -51.029022216796875 G Loss: 33.43902587890625

Epoch: 28 Iteration: 7680
WGAN GP Loss: -49.607723236083984 G Loss: 32.908546447753906

Epoch: 28 Iteration: 8960
WGAN GP Loss: -44.238189697265625 G Loss: 28.953598022460938

Epoch: 28 Iteration: 10240
WGAN GP Loss: -50.2875


Epoch: 31 Iteration: 25600
WGAN GP Loss: -47.39619445800781 G Loss: 31.235864639282227

Epoch: 31 Iteration: 26880
WGAN GP Loss: -51.98931121826172 G Loss: 31.70515251159668

Epoch: 31 Iteration: 28160
WGAN GP Loss: -50.23931121826172 G Loss: 32.18425750732422

Epoch: 31 Iteration: 29440
WGAN GP Loss: -45.153099060058594 G Loss: 30.859525680541992

Epoch: 32 Iteration: 0
WGAN GP Loss: -40.92937469482422 G Loss: 28.45077896118164

Epoch: 32 Iteration: 1280
WGAN GP Loss: -42.98257064819336 G Loss: 30.40206527709961

Epoch: 32 Iteration: 2560
WGAN GP Loss: -52.77229309082031 G Loss: 32.931549072265625

Epoch: 32 Iteration: 3840
WGAN GP Loss: -44.19158935546875 G Loss: 29.376750946044922

Epoch: 32 Iteration: 5120
WGAN GP Loss: -48.148193359375 G Loss: 29.79088020324707

Epoch: 32 Iteration: 6400
WGAN GP Loss: -51.046546936035156 G Loss: 32.48418426513672

Epoch: 32 Iteration: 7680
WGAN GP Loss: -49.992454528808594 G Loss: 31.72456169128418

Epoch: 32 Iteration: 8960
WGAN GP Loss: -48.376


KeyboardInterrupt

