# Basic GAN implementation

### Import relevant modules

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torchvision.utils import make_grid

### Check device

In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
  print(f"{device} - {torch.cuda.get_device_name()}")
else:
  print(f"{device}")

'cuda'

## 0. Prepare Data

### Image transformations

In [3]:
# Image transformations. They can be chained together using Compose.
transforms = transforms.Compose(
    # Normalize to Mean Standard Deviation
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)

### Define and get dataset

In [4]:
# Latent space (noise) dimension
z_dim = 64 # 128, 256
# Image dimensions (flatened)
img_dim = 28 * 28 * 1 # 784

# Defne dataset
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True)

## 1. Create Model (Arquitecture)
#### Create Discriminator and Generator classes

In [5]:
class Discriminator(nn.Module):
    """
    A three hidden-layer discriminative neural network
    """
    def __init__(self, img_dim):
        super().__init__()

        self.hidden0 == nn.Sequential(
            nn.Linear(img_dim, 1024),
            # Like RELU but it has a small slope for negative values instead of a flat slope. (In GANs is often a better choice than ReLU)
            nn.LeakyReLU(0.2),
            # To prevent overfitting (see README)
            nn.Dropout(0.3)
        )
        self.hidden1 == nn.Sequential(
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.hidden2 == nn.Sequential(
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )
        self.out == nn.Sequential(
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        x = self.hidden0(x)
        x = self.hidden1(x)
        x = self.hidden2(x)
        x = self.out(x)
        return x

class Generator(nn.Module):
    # z_dim is the dimension of the latent noise that the generator takes as input
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            # Input (latent noise)
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim), # 28x28x1 --> 784
            # Tanh to ensure the pixel output values are between -1 and 1
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)

###  Initialize a Discriminator and Generator objects

In [6]:
# Create a Disc object
dis = Discriminator(img_dim).to(device)
# Create a Gen object
gen = Generator(z_dim, img_dim).to(device)

## 2. Loss and optimizers

### Hyperparameters

In [7]:
lr = 3e-4 # Learning rate
batch_size = 32
num_epochs = 50

### Loss function and optimizers

In [8]:
# Loss function
criterion = nn.BCELoss()

# Optimizers
# Here we tell which parameters (tensors) of the model we should update (dis.parameters(), gen.parameters())
opt_dis = optim.Adam(dis.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)


### Tensorboard settings

In [9]:
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

##  3. Training loop

In [12]:
# Set data
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        
        # Real image flatened and sent to device
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max [ log(D(real)) + log(1 - D(G(z)) ]
        # Reset gradients to zero
        dis.zero_grad() # Clear out the gradients of all variables
        noise = torch.randn(batch_size, z_dim).to(device)
        ## Forward pass
        fake = gen(noise) # Generate fake image
        disc_real = dis(real).view(-1) # Real image Discriminator output
        disc_fake = dis(fake).view(-1) # Fake (Generated) image Discriminator output.
        ## Loss
        loss_dis_real = criterion(disc_real, torch.ones_like(disc_real)) # (input: disc_real, target: real image, ones with size disc_real)
        loss_dis_fake = criterion(disc_fake, torch.zeros_like(disc_fake)) # (input: disc_fake, target: fake image, zeros with size disc_fake)
        ## Backward pass 
        # Train discriminator on real data
        loss_dis = (loss_dis_real + loss_dis_fake) / 2
        loss_dis.backward() # Compute local gradients and compute dLoss/dWeights using the chain rule (backward pass) 
        # Train discriminator on fake data
        # loss_dis_fake.backward() # Compute local gradients and compute dLoss/dWeights using the chain rule (backward pass) 
        # Update discriminator weights
        opt_dis.step()


        ### Train Generator: min [ log(1 - D(G(z))) ] <-> max [ log(D(G(z))) ]
        # Reset gradients to zero
        gen.zero_grad() # Clear out the gradients of all variables
        ## Forward pass
        output = dis(gen(noise)).view(-1) # Fake (Generated) image Discriminator output.
        ## Loss
        loss_gen = criterion(output, torch.ones_like(output)) # (input: fake image disc output, target: fame image disc output being ones)
        ## Backward pass
        loss_gen.backward()
        # Update discriminator weights
        opt_gen.step()


        # On each epoch at the first mini-batch:
        if batch_idx == 0:
            # Print epochs and losses
            print( 
                f"Epoch [{epoch}/{num_epochs} - "
                f"Loss D fake: {loss_dis_fake:.4f}, Loss D real: {loss_dis_fake:.4f}, Loss G: {loss_gen:.4f}]"
            )
            # Get images for each epoch
            with torch.no_grad(): # Context-manager that disable gradient calculation. It will reduce memory consumption.
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28) # Fake image with right dimensions
                data = real.reshape(-1, 1, 28, 28) # Real image with right dimensions
                img_grid_fake = make_grid(fake, normalize=True)
                img_grid_real = make_grid(data, normalize=True)

                writer_fake.add_image(
                    "MNIST Fake Images", img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    "MNIST Real Images", img_grid_real, global_step=step
                )

                step += 1

        

Epoch [0/50 - Loss D fake: 5.4481, Loss D real: 5.4481, Loss G: 0.0110]
Epoch [1/50 - Loss D fake: 0.1658, Loss D real: 0.1658, Loss G: 2.5002]
Epoch [2/50 - Loss D fake: 0.0705, Loss D real: 0.0705, Loss G: 3.9593]
Epoch [3/50 - Loss D fake: 0.0392, Loss D real: 0.0392, Loss G: 4.6524]
Epoch [4/50 - Loss D fake: 0.0241, Loss D real: 0.0241, Loss G: 4.8136]
Epoch [5/50 - Loss D fake: 0.0070, Loss D real: 0.0070, Loss G: 6.5268]
Epoch [6/50 - Loss D fake: 0.0318, Loss D real: 0.0318, Loss G: 4.5064]
Epoch [7/50 - Loss D fake: 0.0142, Loss D real: 0.0142, Loss G: 5.1170]
Epoch [8/50 - Loss D fake: 0.0124, Loss D real: 0.0124, Loss G: 5.3214]
Epoch [9/50 - Loss D fake: 0.0243, Loss D real: 0.0243, Loss G: 5.2867]
Epoch [10/50 - Loss D fake: 0.0135, Loss D real: 0.0135, Loss G: 5.0862]
Epoch [11/50 - Loss D fake: 0.0215, Loss D real: 0.0215, Loss G: 5.7599]
Epoch [12/50 - Loss D fake: 0.0638, Loss D real: 0.0638, Loss G: 5.3474]
Epoch [13/50 - Loss D fake: 0.0091, Loss D real: 0.0091, Loss