## Import

In [33]:
import torch 
import torch.nn as nn 
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader 
import torchvision.transforms as transforms 
from torch.utils.tensorboard import SummaryWriter 

## Main Classes

In [17]:
# new class that inherits from torch.nn.Module, which is the base class for all PyTorch neural networks.
class Discriminator(nn.Module):
    # in_features, which is the size of the input vector
    def __init__(self,in_features):
        super().__init__()
        self.disc = nn.Sequential(
                 nn.Linear(in_features,128),
                 nn.LeakyReLU(0.01),
                 nn.Linear(128,1),
            # binary classifier that takes an image as input and outputs a probability (0 = fake, 1 = real).
                 nn.Sigmoid(),
        )
    def forward(self,x):
        return self.disc(x)

In [20]:
class Generator(nn.Module):
    # noise_dim: This is the size of the input noise vector
    # image_dim: This is the size of the output image
    def __init__(self,noise_dim,image_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(noise_dim,256),
            nn.LeakyReLU(0.01),
            nn.Linear(256,image_dim),
            nn.Tanh(),
        )
    def forward ( self,x):
        return self.gen(x)        

## Hyperparameters & Initializations

In [37]:
device = "cuda" if torch.cuda.is_available() else "cpu" 
lr = 3e-4
noise_dim = 64 
image_dim = 28*28*1 
batch_size = 32 
num_epochs = 50 

In [39]:
disc = Discriminator(image_dim).to(device)
gen = Generator(noise_dim,image_dim).to(device)
fixed_noise = torch.randn((batch_size,noise_dim)).to(device)

# prepare your image dataset (e.g., MNIST or CIFAR) for training.
transforms = transforms.Compose(
    [   # Converts a PIL image or NumPy array to a PyTorch tensor.
        transforms.ToTensor(),
      transforms.Normalize((0.5,),(0.5,)), ]
)

In [43]:
dataset = datasets.MNIST(root="dataset/" , transform=transforms , download = True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(),lr=lr)
# The binary cross-entropy loss function
criterion = nn.BCELoss()
# Initializes TensorBoard writer for logging fake and real image statistics during training.
writer_fake = SummaryWriter(f"logs/fake")
writer_real = SummaryWriter(f"logs/real")
step = 0 

## Train Loop 

In [54]:
for epoch in range(num_epochs):
    for batch_idx, (real,_) in enumerate(loader):
        real = real.view(-1,784).to(device)
        batch_size = real.shape[0]
        
        # Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size,noise_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real , torch.zeros_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake)/2
        disc.zero_grad()
        lossD.backward(retain_graph =True)
        opt_disc.step()

        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).view(-1)
        lossG = criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()


        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1
        

Epoch [0/50] Batch 0/1875                       Loss D: 0.6053, loss G: 0.7135
Epoch [1/50] Batch 0/1875                       Loss D: 0.0038, loss G: 5.2382
Epoch [2/50] Batch 0/1875                       Loss D: 0.0004, loss G: 7.9914
Epoch [3/50] Batch 0/1875                       Loss D: 0.0002, loss G: 8.8522
Epoch [4/50] Batch 0/1875                       Loss D: 0.0001, loss G: 9.8630
Epoch [5/50] Batch 0/1875                       Loss D: 0.0000, loss G: 11.2355
Epoch [6/50] Batch 0/1875                       Loss D: 0.0000, loss G: 12.2036
Epoch [7/50] Batch 0/1875                       Loss D: 0.0000, loss G: 13.4619
Epoch [8/50] Batch 0/1875                       Loss D: 0.0000, loss G: 14.7443
Epoch [9/50] Batch 0/1875                       Loss D: 0.0000, loss G: 16.7441
Epoch [10/50] Batch 0/1875                       Loss D: 0.0000, loss G: 17.5778
Epoch [11/50] Batch 0/1875                       Loss D: 0.0000, loss G: 19.1621
Epoch [12/50] Batch 0/1875                 