# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
import torchvision.transforms as transforms

from datasets.noisy_mnist import NoisyMNIST
from models.generator import Generator
from models.discriminator import Discriminator
from models import utils

import matplotlib.pyplot as plt


# Global variables

In [None]:
# --- Reproducibility -------------------------------------------------------
SEED = 0

# --- Data --------------------------------------------------------------------
DATAROOT = "data"
BATCH_SIZE = 128
WORKERS = 1  # worker count intended for the DataLoader

# --- Model sizes ---------------------------------------------------------------
NZ = 100  # length of the latent noise vector fed to the generator
NGF = 64  # width argument passed to Generator
NDF = 64  # width argument passed to Discriminator

# --- Optimisation ----------------------------------------------------------
NITER = 5  # number of training epochs
LR = 0.0002
BETA1 = 0.5  # Adam beta1

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed Python's and PyTorch's RNGs; a loop statement keeps the notebook from
# echoing the Generator object that torch.manual_seed returns.
for seed_fn in (random.seed, torch.manual_seed):
    seed_fn(SEED)

# Data

In [None]:
# 60k-sample dataset; noise_level=0.0 presumably means clean MNIST — confirm
# against datasets.noisy_mnist.NoisyMNIST.
dataset = NoisyMNIST(dataset_size=60000, noise_level=0.0, root=DATAROOT)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    # Honour the WORKERS setting from the globals cell instead of a hard-coded 0.
    num_workers=WORKERS,
)

In [None]:
def plot_batch(images, labels):
    """Display up to the first 16 images of a batch in a 4x4 grid.

    Args:
        images: Batch indexed along its first dimension. Tensor items are
            detached and moved to the CPU, so CUDA outputs can be plotted.
            3-D tensor items are treated as channel-first ``(C, H, W)`` and
            permuted to ``(H, W, C)`` for matplotlib.
        labels: Per-image values shown in each subplot title. Single-element
            tensor entries are unwrapped with ``.item()``, so titles read
            ``Label: 5`` rather than ``Label: tensor(5)``.

    Batches with fewer than 16 items, or fewer labels than images, are
    plotted as far as both sequences go instead of raising ``IndexError``.
    """
    count = min(16, len(images), len(labels))

    plt.figure(figsize=(16, 16))

    for i in range(count):
        ax = plt.subplot(4, 4, i + 1)

        label = labels[i]
        if isinstance(label, torch.Tensor) and label.numel() == 1:
            label = label.item()
        ax.set_title(f"Label: {label}")
        ax.axis("off")

        image = images[i]
        if isinstance(image, torch.Tensor):
            # matplotlib needs host memory and no autograd history.
            image = image.detach().cpu()
            if image.dim() == 3:
                image = image.permute(1, 2, 0)  # CHW -> HWC
        image = image.squeeze()
        # Single-channel (2-D) images are drawn in grayscale, not the default colormap.
        plt.imshow(image, cmap="gray" if image.ndim == 2 else None)

    # Lay out once after every subplot exists, then tighten the gaps between them.
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    plt.show()

# Preview one shuffled batch. The batch is unpacked into plot_batch(images, labels);
# this assumes the dataset yields (image, label) pairs — TODO confirm for NoisyMNIST.
plot_batch(*next(iter(dataloader)))
    

# Models

In [None]:
# Build the generator, place it on DEVICE, then re-initialise every submodule
# with utils.weights_init (Module.apply recurses through all children).
netG = Generator(NZ, NGF, dataset.num_channels, dataset.output_shape)
netG = netG.to(DEVICE)
netG.apply(utils.weights_init)
netG  # trailing expression: the notebook displays the module structure

In [None]:
# Build the discriminator, place it on DEVICE, then re-initialise every submodule
# with utils.weights_init (Module.apply recurses through all children).
netD = Discriminator(dataset.num_channels, NDF)
netD = netD.to(DEVICE)
netD.apply(utils.weights_init)
netD  # trailing expression: the notebook displays the module structure

# Train

In [None]:
# Binary cross-entropy between discriminator scores and real/fake targets.
# nn.BCELoss expects probabilities, so netD presumably ends in a sigmoid — TODO confirm.
criterion = nn.BCELoss()

# Target values: 1 marks a real sample, 0 a generated one.
real_label = 1
fake_label = 0

# A single latent batch drawn once and reused after training to visualise netG.
fixed_noise = torch.randn((BATCH_SIZE, NZ, 1, 1), device=DEVICE)

In [None]:
# Both networks use Adam with identical hyper-parameters from the globals cell.
adam_kwargs = dict(lr=LR, betas=(BETA1, 0.999))
optimizerD = optim.Adam(netD.parameters(), **adam_kwargs)
optimizerG = optim.Adam(netG.parameters(), **adam_kwargs)


In [None]:
# Alternating GAN training: per minibatch, one discriminator step followed by
# one generator step. Note: after the loop, `label` is left filled with
# real_label and sized to the final batch.
for epoch in range(NITER):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(DEVICE)  # despite the name, this tensor is on DEVICE
        batch_size = real_cpu.size(0)  # may be smaller than BATCH_SIZE (e.g. the final batch)
        label = torch.full(
            (batch_size,), real_label, dtype=real_cpu.dtype, device=DEVICE
        )

        output = netD(real_cpu)
        # squeeze() drops singleton dims so D's output lines up with the 1-D targets
        # (assumes netD emits one score per sample — TODO confirm)
        errD_real = criterion(output.squeeze(), label.squeeze())
        errD_real.backward()
        D_x = output.mean().item()  # mean D score on real images, for logging

        # train with fake
        noise = torch.randn(batch_size, NZ, 1, 1, device=DEVICE)
        fake = netG(noise)
        label.fill_(fake_label)  # reuse the target tensor in place; errD_real.backward() already ran
        output = netD(fake.detach())  # detach: this D step must not send gradients into netG
        errD_fake = criterion(output.squeeze(), label.squeeze())
        errD_fake.backward()  # accumulates onto the gradients from the real batch
        D_G_z1 = output.mean().item()  # mean D score on fakes before the D update
        errD = errD_real + errD_fake  # logging only; gradients came from the two backward() calls
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)  # not detached, so gradients flow back into netG
        errG = criterion(output.squeeze(), label.squeeze())
        errG.backward()  # also writes grads into netD; cleared by netD.zero_grad() next iteration
        D_G_z2 = output.mean().item()  # mean D score on the same fakes after the D update
        optimizerG.step()

        if i % 100 == 0:
            print(
                "[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f"
                % (
                    epoch,
                    NITER,
                    i,
                    len(dataloader),
                    errD.item(),
                    errG.item(),
                    D_x,
                    D_G_z1,
                    D_G_z2,
                )
            )

# Results

In [None]:
# Sample the generator once on the fixed latent batch drawn before training.
# no_grad: inference only, so no autograd graph is kept (output values are unchanged).
with torch.no_grad():
    fake = netG(fixed_noise)
# Map a [-1, 1] output range back to [0, 1] for display (assumes the generator
# ends in tanh — TODO confirm). Move to CPU because matplotlib cannot read CUDA
# tensors. Generated samples carry no class label, so titles say "generated"
# instead of echoing the leftover training-target tensor.
plot_batch((fake * 0.5 + 0.5).cpu(), ["generated"] * fake.size(0))