In [1]:
from PIL import Image
from PIL import ImageFile
import torch
import torchvision.transforms as transforms
import glob
import os
import platform
import numpy as np
from __future__ import print_function

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML

from ImageDataset import ImageDataset
from WGAN_net import Generator, Discriminator, weights_init
from WGAN_net import gradient_penalty

%matplotlib inline
# Ask PIL to load image files whose data is truncated instead of raising an
# error while decoding them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# Training dataset built by the project-local ImageDataset class from 'DATA'.
# NOTE(review): ImageDataset is not shown in this notebook; 'DATA' is presumably
# a directory of images, and the training loop unpacks every batch as
# (real, _), so samples are presumably (image, label) pairs -- confirm.
dataset = ImageDataset('DATA')

In [3]:
# ---- Hardware ---------------------------------------------------------------
# Number of GPUs requested; 0 forces CPU, values > 1 enable DataParallel below.
ngpu = 1
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ---- Data -------------------------------------------------------------------
batch_size = 64   # samples per mini-batch handed to the DataLoader
image_size = 64   # not referenced by the cells visible in this notebook
img_ch = 3        # image channel count passed to Generator / Discriminator
workers = 2       # DataLoader num_workers

# ---- Model ------------------------------------------------------------------
z_dim = 100       # length of the latent vector fed to the generator
features_g = 64   # `features_g` argument of Generator
features_d = 64   # `features_d` argument of Discriminator

# ---- Optimisation -----------------------------------------------------------
learning_rate = 1e-4    # Adam learning rate for both networks
num_epochs = 200        # full passes over the dataloader
critic_iterations = 5   # critic updates per generator update
lambda_pen = 10         # multiplier on the gradient penalty in the critic loss

In [4]:
# Mini-batch iterator over `dataset`: reshuffled every epoch, loaded by
# `workers` worker processes.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=workers,
)

In [5]:
# Instantiate the generator and the critic on the selected device.
netG = Generator(z_dim=z_dim, img_ch=img_ch, features_g=features_g).to(device)
netD = Discriminator(img_ch=img_ch, features_d=features_d).to(device)

# Replicate both networks across GPUs only when more than one was requested.
if device.type == 'cuda' and ngpu > 1:
    gpu_ids = list(range(ngpu))
    netD = nn.DataParallel(netD, gpu_ids)
    netG = nn.DataParallel(netG, gpu_ids)

# Run the project-defined `weights_init` over every submodule of each network
# (critic first, then generator).
netD.apply(weights_init)
netG.apply(weights_init)

# Both networks are trained with Adam using identical settings.
adam_kwargs = dict(lr=learning_rate, betas=(0.0, 0.9))
optimizerD = optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)

In [None]:
# Fixed latent batch reused at every logging step, so the generated previews
# written to TensorBoard are comparable across training.
fixed_noise = torch.randn(32, z_dim, 1, 1).to(device)

# One TensorBoard writer per logged quantity, each in its own run directory.
writer_real = SummaryWriter(log_dir='logs2/real')
writer_fake = SummaryWriter(log_dir='logs2/fake')
writer_lossD = SummaryWriter(log_dir='logs2/lossD')
writer_lossG = SummaryWriter(log_dir='logs2/lossG')
writer_penalty = SummaryWriter(log_dir='logs2/penalty')

# Counts logging events (not batches); used as the TensorBoard global step.
iters = 0

print("Starting Training Loop...")
try:
    for epoch in range(num_epochs):
        for i, (real, _) in enumerate(dataloader):
            real = real.to(device)

            # ---- Critic: `critic_iterations` updates per generator update ----
            for _ in range(critic_iterations):
                noise = torch.randn(real.shape[0], z_dim, 1, 1).to(device)
                fake = netG(noise)
                disc_real = netD(real).reshape(-1)
                disc_fake = netD(fake).reshape(-1)
                # gradient_penalty is a project helper imported from WGAN_net;
                # its result is weighted by lambda_pen in the critic loss.
                gp = gradient_penalty(netD, real, fake, device=device)
                errD_real = torch.mean(disc_real)
                errD_fake = torch.mean(disc_fake)
                # The critic minimises mean(D(fake)) - mean(D(real)) plus the
                # weighted penalty.
                loss_disc = -(errD_real - errD_fake) + lambda_pen * gp
                netD.zero_grad()
                # retain_graph=True keeps the generator part of the graph that
                # produced `fake` alive: the generator step below backpropagates
                # through that same `fake` tensor.
                loss_disc.backward(retain_graph=True)
                optimizerD.step()

            # ---- Generator: one update using the last `fake` batch ----
            # The critic is re-evaluated so the generator sees its updated
            # weights; netG.zero_grad() also discards the gradients the critic
            # backward passes accumulated in the generator.
            output = netD(fake).view(-1)
            loss_gen = -torch.mean(output)
            netG.zero_grad()
            loss_gen.backward()
            optimizerG.step()

            # ---- Logging every 20th batch ----
            if i % 20 == 0:
                print(
                    f'Epoch [{epoch}/{num_epochs}] Batch {i}/{len(dataloader)} '
                    + f'Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}'
                    )
                with torch.no_grad():
                    # Separate name so the training batch `fake` is not clobbered.
                    fake_fixed = netG(fixed_noise)
                    # `vutils` is the alias torchvision.utils is imported under;
                    # the bare name `torchvision` is never bound by the imports.
                    img_grid_real = vutils.make_grid(real[:32], normalize=True)
                    img_grid_fake = vutils.make_grid(fake_fixed[:32], normalize=True)

                    # NOTE(review): the generated-image grid is logged under the
                    # tag 'D(x)' and the critic score on real data under 'Real';
                    # tags are left unchanged to stay continuous with existing
                    # logs -- consider renaming.
                    writer_real.add_image('Real', img_grid_real, global_step=iters)
                    writer_real.add_scalar('Real', errD_real.item(), global_step=iters)
                    writer_fake.add_image('D(x)', img_grid_fake, global_step=iters)
                    writer_fake.add_scalar('D(G(z))', errD_fake.item(), global_step=iters)
                    writer_lossD.add_scalar('Loss_Discriminator', loss_disc.item(), global_step=iters)
                    writer_lossG.add_scalar('Loss_Generator', loss_gen.item(), global_step=iters)
                    writer_penalty.add_scalar('Gradient_Penalty', gp.item(), global_step=iters)
                iters += 1
finally:
    # Flush and close the event files even when training is interrupted or fails.
    for writer in (writer_real, writer_fake, writer_lossD, writer_lossG, writer_penalty):
        writer.close()

Starting Training Loop...
Epoch [0/200] Batch 0/43 Loss D: -21.8954, loss G: 29.1958
Epoch [0/200] Batch 20/43 Loss D: -21.3246, loss G: 42.1368
Epoch [0/200] Batch 40/43 Loss D: -26.9997, loss G: 45.7437
Epoch [1/200] Batch 0/43 Loss D: -31.9731, loss G: 44.3318
Epoch [1/200] Batch 20/43 Loss D: -26.5479, loss G: 47.3191
Epoch [1/200] Batch 40/43 Loss D: -24.7024, loss G: 47.7742
Epoch [2/200] Batch 0/43 Loss D: -31.0710, loss G: 40.8739
Epoch [2/200] Batch 20/43 Loss D: -27.5465, loss G: 49.6100
Epoch [2/200] Batch 40/43 Loss D: -32.2657, loss G: 44.7909
Epoch [3/200] Batch 0/43 Loss D: -27.3051, loss G: 52.6935
Epoch [3/200] Batch 20/43 Loss D: -27.7235, loss G: 60.3729
Epoch [3/200] Batch 40/43 Loss D: -30.3863, loss G: 49.7284
Epoch [4/200] Batch 0/43 Loss D: -24.2956, loss G: 54.2500
Epoch [4/200] Batch 20/43 Loss D: -27.5710, loss G: 56.5612
Epoch [4/200] Batch 40/43 Loss D: -32.0501, loss G: 44.1032
Epoch [5/200] Batch 0/43 Loss D: -25.2204, loss G: 54.3083
Epoch [5/200] Batch 