# Simple GAN

A minimal fully connected GAN trained on MNIST, following the [original video](https://youtu.be/OljTVUVzPpM).

In [None]:
import os
import shutil
import urllib.request

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [None]:
# A, B, C, D = 3, 3, 2, 2
# c = torch.ones(A, B) * 2
# v = torch.randn(A, B, C, D)

# d = c[:, :, None, None] * v
# print(c[:, :, None, None])
# # print(c[:, :, None])
# print(c[:, :, None, None].shape)
# print((d[0, 0] == v[0, 0]* 2).all())

In [None]:
# Run TensorBoard

# To fix the error, because PyTorch and TensorFlow are installed both:
# AttributeError: module 'tensorflow._api.v2.io.gfile' has no attribute 'get_filesystem'
import tensorflow as tf
import tensorboard as tb
tf.io.gfile = tb.compat.tensorflow_stub.io.gfile

if os.path.exists('runs'):
    !rm -rf runs

# Load the TensorBoard notebook extension
%load_ext tensorboard

# Start TensorBoard before training to monitor it in progress
%tensorboard --logdir runs

# Reload TensorBoard
%reload_ext tensorboard

In [None]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images as real or fake.

    Args:
        img_dim: number of features in a flattened input image
            (784 for a 28x28 single-channel MNIST digit).
        hidden_dim: width of the single hidden layer (default 128).
    """

    def __init__(self, img_dim, hidden_dim=128):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            # Leaky slope keeps a gradient for negative pre-activations;
            # a better choice for GANs than ReLU.
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, 1),  # one score per image: fake (0) or real (1)
            nn.Sigmoid(),  # squash the score into (0, 1)
        )

    def forward(self, x):
        """Return the estimated probability that each input image is real.

        ``x`` must have ``img_dim`` features in its last dimension; the result
        has a trailing dimension of size 1.
        """
        return self.disc(x)


class CustomTanh(nn.Module):
    """Tanh activation whose output range is scaled to (-multiplier, multiplier).

    Intended as a generator output layer when real images are normalized to a
    range wider than (-1, 1), e.g. ``CustomTanh(2.83)`` for outputs in
    (-2.83, 2.83).

    Args:
        multiplier: scale applied to the tanh output (assumed positive).
    """

    def __init__(self, multiplier):
        super().__init__()
        self.multiplier = multiplier

    def forward(self, x):
        # Scale the output, not the input: tanh(multiplier * x) would stay
        # bounded by (-1, 1) and merely make the curve steeper.
        return self.multiplier * torch.tanh(x)


class Generator(nn.Module):
    """Fully connected generator mapping latent noise to a flattened image.

    Args:
        z_dim: dimension of the latent noise vector given as input.
        img_dim: number of features in a flattened output image
            (784 for a 28x28 single-channel MNIST digit).
        hidden_dim: width of the single hidden layer (default 256).
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, img_dim),
            # Real images are normalized to [-1, 1], so the generator's
            # outputs are squashed into (-1, 1) with tanh to match.
            nn.Tanh(),
            # If real images were instead normalized with the MNIST mean/std
            # (range ≈ (-0.42, 2.82)), the output activation would need to
            # cover that range, e.g. a tanh whose output is scaled to
            # (-2.83, 2.83).
        )

    def forward(self, x):
        """Map noise with ``z_dim`` features in its last dimension to images
        with ``img_dim`` features in the last dimension."""
        return self.gen(x)


# Hyperparameters. GANs are sensitive to hyperparameters.
seed = 42
# Seed torch's RNGs so weight init, noise and DataLoader shuffling repeat
# across runs.
torch.manual_seed(seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr = 3e-4
z_dim = 64
img_dim = 28*28*1  # 784: a flattened 28x28 single-channel MNIST image
batch_size = 32
num_epochs = 50

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)
# The same noise is reused at every logging step so that the generated samples
# shown in TensorBoard are comparable across epochs.
fixed_noise = torch.randn(batch_size, z_dim).to(device)
transform = transforms.Compose([
    transforms.ToTensor(),  # image -> float tensor in [0, 1]

    # convert [0, 1] to [-1, 1], matching the generator's tanh output range
    transforms.Normalize((0.5,), (0.5,)),

    # Alternative: MNIST mean/std, which converts [0, 1] to ≈ (-0.42, 2.82)
    # transforms.Normalize((0.1307,), (0.3081,)),
])


# Download MNIST from a GitHub mirror, because torchvision's default mirrors
# have failed with: HTTPError: HTTP Error 503: Service Unavailable
# The files are placed in <root>/MNIST/raw for root='dataset/', the directory
# `datasets.MNIST(root='dataset/')` reads raw files from, so that it can use
# them (presumably without downloading again - confirm for your torchvision
# version). A path relative to the working directory keeps both in sync on any
# machine, not just on Colab.
data_dir = os.path.join('dataset', 'MNIST', 'raw')
mnist_mirror = 'https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/'
mnist_files = [
    't10k-images-idx3-ubyte.gz',
    't10k-labels-idx1-ubyte.gz',
    'train-images-idx3-ubyte.gz',
    'train-labels-idx1-ubyte.gz',
]

# For the error: HTTPError: HTTP Error 403: Forbidden
# Install a global urllib opener that sends a browser-like User-agent.
# StackOverflow: https://stackoverflow.com/a/66461122/7550928
opener = urllib.request.build_opener()
opener.addheaders = [('User-agent', 'Mozilla/5.0')]
urllib.request.install_opener(opener)

# Start from an empty directory so no stale or partial file is reused.
# makedirs also creates the missing parent directories (plain `mkdir` did not).
shutil.rmtree(data_dir, ignore_errors=True)
os.makedirs(data_dir, exist_ok=True)
for file_name in mnist_files:
    urllib.request.urlretrieve(mnist_mirror + file_name,
                               os.path.join(data_dir, file_name))

# MNIST training split (torchvision's default), normalized by `transform`.
dataset = datasets.MNIST(root='dataset/', transform=transform, download=True)

loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
optim_disc = optim.Adam(disc.parameters(), lr=lr)
optim_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()  # binary cross-entropy, the loss used in the original GAN paper
# Separate writers so real and generated image grids appear as distinct runs
# in TensorBoard.
writer_fake = SummaryWriter('runs/GAN_MNIST/fake')
writer_real = SummaryWriter('runs/GAN_MNIST/real')
step = 0  # TensorBoard global step, advanced once per logged batch

for epoch in range(num_epochs):
    for idx, (real, _) in enumerate(loader):
        # Flatten each image into a vector of img_dim features for the linear layers.
        real = real.view(-1, img_dim).to(device)
        # The DataLoader keeps a final partial batch when the dataset size is
        # not divisible by batch_size, so size the noise from the real batch.
        cur_batch_size = real.shape[0]

        # Train Discriminator: max log(D(real)) + log(1 - D(G(noise)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)

        disc_real = disc(real).view(-1)  # flatten everything
        # BCELoss(x, y) == -(y*log(x) + (1-y)*log(1-x)).
        # With y = 1 this is -log(D(real)); minimizing it maximizes log(D(real)).
        loss_real = criterion(disc_real, torch.ones_like(disc_real))

        # fake.detach() keeps loss_disc.backward() from flowing into (and
        # freeing) the generator's part of the graph, which loss_gen.backward()
        # below still needs. The alternative is backward(retain_graph=True).
        disc_fake = disc(fake.detach()).view(-1)
        # With y = 0 BCELoss is -log(1 - D(fake)); minimizing it maximizes
        # log(1 - D(G(noise))).
        loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        loss_disc = (loss_real + loss_fake) / 2.0

        # Also clears the gradients that the previous loss_gen.backward()
        # accumulated on the discriminator's parameters.
        disc.zero_grad()
        loss_disc.backward()
        optim_disc.step()

        # Train Generator: min log(1 - D(G(noise))) saturates and gives weak
        # gradients, so instead max log(D(G(noise))), i.e. min -log(D(G(noise))).
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))

        gen.zero_grad()
        loss_gen.backward()
        optim_gen.step()

        if idx == 0:
            print(f'Epoch [{epoch}/{num_epochs}] '
                  f'Loss D {loss_disc.item():.4f}, '
                  f'Loss G {loss_gen.item():.4f}')

            with torch.no_grad():
                # Separate names so the training tensors `fake`/`real` are not shadowed.
                fake_samples = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_samples = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_samples, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_samples, normalize=True)
                writer_fake.add_image('MNIST fake images', img_grid_fake, global_step=step)
                writer_real.add_image('MNIST real images', img_grid_real, global_step=step)
                step += 1

--2021-03-13 12:43:38--  https://github.com/golbin/TensorFlow-MNIST/raw/master/mnist/data/t10k-images-idx3-ubyte.gz
Resolving github.com (github.com)... 192.30.255.113
Connecting to github.com (github.com)|192.30.255.113|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/master/mnist/data/t10k-images-idx3-ubyte.gz [following]
--2021-03-13 12:43:38--  https://raw.githubusercontent.com/golbin/TensorFlow-MNIST/master/mnist/data/t10k-images-idx3-ubyte.gz
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1648877 (1.6M) [application/octet-stream]
Saving to: ‘/content/dataset/MNIST/raw/t10k-images-idx3-ubyte.gz’


2021-03-13 12:43:38 (16.4 MB/s) - ‘/content/dataset/MNIST/raw/t10k-im

In [None]:
def save_checkpoint(state, filename):
    """Announce and write a checkpoint to disk with ``torch.save``.

    Args:
        state: object to serialize (in this notebook, a dict of model and
            optimizer state dicts plus the TensorBoard step).
        filename: destination path of the checkpoint file.
    """
    banner = "=> Saving checkpoint"
    print(banner)
    torch.save(obj=state, f=filename)

# Bundle each network with its optimizer state and the current TensorBoard
# step, then write both checkpoints to the current working directory.
disc_checkpoint = dict(
    state_dict=disc.state_dict(),
    optimizer=optim_disc.state_dict(),
    step=step,
)
save_checkpoint(disc_checkpoint, 'discriminator_checkpoint.pth.tar')

gen_checkpoint = dict(
    state_dict=gen.state_dict(),
    optimizer=optim_gen.state_dict(),
    step=step,
)
save_checkpoint(gen_checkpoint, 'generator_checkpoint.pth.tar')

=> Saving checkpoint
=> Saving checkpoint
