In [None]:
from __future__ import print_function
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np


In [None]:
import torchvision.datasets as dset

DATA_PATH = '~/Data/mnist'

# Preprocessing for the real MNIST digits:
#   * Resize(64): the shorter side becomes 64 px (MNIST digits are square, so 64x64).
#   * ToTensor(): values in [0, 1].
#   * Normalize(0.5, 0.5): maps [0, 1] -> [-1, 1], the range the display helpers in this
#     notebook undo with (x + 1) / 2 and presumably the range of the generator's output
#     (Tanh in the standard DCGAN generator -- confirm in models.py). Without it, real and
#     fake images live on different scales and the discriminator can tell them apart trivially.
# MNIST is single-channel, so one mean/std entry suffices and no Grayscale transform is needed.
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

dataset = dset.MNIST(root=DATA_PATH, download=True,
                     transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=128,
                                         shuffle=True, num_workers=2)


In [None]:
from models import *

# Seed Python's `random` module and PyTorch's global RNG with one shared value.
manual_seed = 1623
print("Random Seed: ", manual_seed)
random.seed(manual_seed)
torch.manual_seed(manual_seed)

# Training is pinned to the CPU. Set use_cuda to torch.cuda.is_available() to use a GPU.
use_cuda = False
device = torch.device("cuda" if use_cuda else "cpu")

# ngpu: GPU count passed to the models; nz: latent vector size (used for noise below);
# ngf / ndf: generator / discriminator feature widths; nc: image channels.
# NOTE(review): only ngpu is passed to the constructors and only nz is used later in this
# notebook -- whether models.py reads ngf/ndf/nc from here is not visible; confirm.
ngpu = 1
nz = 100
ngf = 64
ndf = 64
nc = 1

# Build the generator, move it to `device`, apply the custom weight initialiser, show it.
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
print(netG)

# Same for the discriminator.
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
print(netD)


In [None]:
# Binary cross-entropy between the discriminator's output and the real/fake target.
criterion = nn.BCELoss()

# One fixed batch of 64 latent vectors, reused after every epoch to visualise progress.
fixed_noise = torch.randn(64, nz, 1, 1, device=device)

# BCELoss targets. They are floats so torch.full(...) builds a float tensor.
real_label = 1.0
fake_label = 0.0

# Adam for both networks, same learning rate, beta1 = 0.5.
lr = 0.001
adam_betas = (0.5, 0.999)
optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=adam_betas)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=adam_betas)

# Divide the generator's LR by 10 once the metric passed to scheduler.step(metric)
# has not improved for more than `patience` consecutive steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizerG,
    patience=5,
    factor=0.1,
)


In [None]:
import torchvision.utils as utils
import matplotlib.pyplot as plt
def imshow(img):
    """Show a single image tensor, whose values are presumably in [-1, 1], in grayscale.

    The tensor is first detached and moved to the CPU. Without this, `Tensor.numpy()`
    raises on a generator output that requires grad or lives on the GPU. This matches
    how imshow_grid prepares its input.

    Note: with cmap='gray' and no vmin/vmax, matplotlib rescales the colormap to the
    data's own min/max.
    """
    img = img.detach().cpu()
    img = (img + 1) / 2
    # Drop singleton dims (e.g. (1, H, W) -> (H, W)) so matplotlib receives a 2-D array.
    np_img = img.squeeze().numpy()
    plt.imshow(np_img, cmap='gray')
    plt.show()

def imshow_grid(img):
    """Display a batch of images (values presumably in [-1, 1]) tiled into one grid.

    The batch is moved to the CPU and detached before tiling, so generator outputs can
    be passed in directly.
    """
    grid = utils.make_grid(img.cpu().detach())
    # [-1, 1] -> [0, 1], then CHW -> HWC, which is the layout matplotlib expects.
    hwc = np.transpose(((grid + 1) / 2).numpy(), (1, 2, 0))
    plt.imshow(hwc)
    plt.show()
    
# Sanity check: draw one shuffled batch of real digits and show the first 16.
example_mini_batch_img, example_mini_batch_label = next(iter(dataloader))
imshow_grid(example_mini_batch_img[:16])

In [None]:
from tqdm import tqdm

# Main training loop with alternating updates.
#   Per batch: one discriminator step on a real and a generated batch, then one generator step.
#   Per epoch: step the generator's LR scheduler, log the last batch's statistics,
#   show samples from `fixed_noise`, and save both networks' state dicts.
num_epochs = 100
checkpoint_dir = 'best_gan_weights'
# torch.save does not create missing directories, so make sure the target exists up front.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in tqdm(range(num_epochs)):
    g_loss_sum = 0.0  # summed generator loss; its epoch mean drives the LR scheduler
    num_batches = 0
    for data in dataloader:
        # --- (1) Discriminator: real batch -> real_label, generated batch -> fake_label ---
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)

        # Flatten so the prediction matches label's (batch_size,) shape, as BCELoss requires.
        # This is a no-op if the discriminator already returns a 1-D tensor.
        output = netD(real_cpu).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        D_x = output.mean().item()

        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        # detach(): the discriminator's loss must not backpropagate into the generator.
        output = netD(fake.detach()).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        # --- (2) Generator: train it so D labels the same fakes as real ---
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake).view(-1)
        errG = criterion(output, label)
        errG.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        g_loss_sum += errG.item()
        num_batches += 1

    # ReduceLROnPlateau.step() requires the monitored value; calling it with no argument
    # raises TypeError. The scheduler wraps optimizerG, so monitor the epoch-mean
    # generator loss (ReduceLROnPlateau defaults to mode='min').
    scheduler.step(g_loss_sum / num_batches)

    # The losses and D outputs logged below come from the last batch of the epoch.
    print('[%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
          % (epoch, num_epochs,
             errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))

    # Visualise progress on the fixed latent batch. no_grad avoids building an unused graph.
    with torch.no_grad():
        samples = netG(fixed_noise)
    imshow_grid(samples[0:16])

    # do checkpointing (one file per network per epoch)
    torch.save(netG.state_dict(),
               os.path.join(checkpoint_dir, 'mnist_gen_epoch_%d.pth' % epoch))
    torch.save(netD.state_dict(),
               os.path.join(checkpoint_dir, 'mnist_dis_epoch_%d.pth' % epoch))