In [None]:
import numpy as np 
import torch
from torch import nn, optim
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt

In [None]:
class Descriminator(nn.Module):
    """Fully connected binary classifier used as the GAN discriminator.

    The network is Linear(in_features, 256) -> LeakyReLU(0.1) ->
    Linear(256, 128) -> LeakyReLU(0.1) -> Linear(128, 1) -> Sigmoid, so each
    input row is mapped to a single score in (0, 1).

    Args:
        in_features: size of the last dimension of the input tensor.
    """

    def __init__(self, in_features):
        super().__init__()
        # Build the hidden stack in order; every hidden Linear is followed by
        # the same leaky activation.
        stack = []
        width_in = in_features
        for width_out in (256, 128):
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.LeakyReLU(0.1))
            width_in = width_out
        # Output head: one unit squashed by a sigmoid.
        stack.append(nn.Linear(width_in, 1))
        stack.append(nn.Sigmoid())
        self.disc = nn.Sequential(*stack)

    def forward(self, x):
        """Return the sigmoid score for each row of `x`."""
        scores = self.disc(x)
        return scores

In [None]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent vectors to flat images.

    The network is Linear(z_dim, 256) -> LeakyReLU(0.1) -> Linear(256, 512)
    -> LeakyReLU(0.1) -> Linear(512, img_dim) -> Tanh, so every output value
    lies in (-1, 1).

    Args:
        z_dim: size of the latent (noise) vector.
        img_dim: number of values produced per sample.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Hidden stack first, each Linear followed by the same leaky activation.
        stack = []
        width_in = z_dim
        for width_out in (256, 512):
            stack.append(nn.Linear(width_in, width_out))
            stack.append(nn.LeakyReLU(0.1))
            width_in = width_out
        # Output head squashed by tanh.
        stack.append(nn.Linear(width_in, img_dim))
        stack.append(nn.Tanh())
        self.gen = nn.Sequential(*stack)

    def forward(self, x):
        """Return the generated sample for each latent row of `x`."""
        generated = self.gen(x)
        return generated

In [None]:
# Run configuration: everything a reader might want to tune lives here.
DEVICE  = 'cuda' if torch.cuda.is_available() else 'cpu'  # use the GPU when one is visible
SEED    = 42
# Seed torch's RNG so weight init, noise sampling and DataLoader shuffling
# repeat on a fresh "Restart & Run All".
torch.manual_seed(SEED)
LR      = 3e-4       # learning rate shared by both Adam optimizers
Z_DIM   = 64         # size of the generator's latent vector
IMG_DIM = 28*28*1    # flattened image size: height * width * channels
BS      = 64         # batch size
EPOCHS  = 100        # number of passes over the dataset

In [None]:
%%time
# Models, fixed evaluation noise, data pipeline, optimizers and loss.
disc        = Descriminator(IMG_DIM).to(DEVICE)
gen         = Generator(Z_DIM, IMG_DIM).to(DEVICE)
# One fixed batch of latent vectors, reused to visualise generator progress.
fixed_noice = torch.randn((BS, Z_DIM), device=DEVICE)
# Bound to `transform` (singular) so the `torchvision.transforms` module alias
# is not overwritten; rebinding `transforms` made this cell fail when re-run.
# No horizontal flip: mirrored digits are not valid digits and would be
# presented to the GAN as "real" data.
# Normalize((0.5,), (0.5,)) rescales ToTensor's [0, 1] pixels to [-1, 1],
# matching the range of the generator's Tanh output.
transform   = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
dataset     = torchvision.datasets.MNIST(root = 'dataset/', transform = transform, download = True)
loader      = DataLoader(dataset, batch_size=BS, shuffle=True)
opt_disc    = optim.Adam(disc.parameters(), lr=LR)
opt_gen     = optim.Adam(gen.parameters(), lr=LR)
# Binary cross-entropy on the discriminator's sigmoid output.
criterion   = nn.BCELoss()

In [None]:
# Preview a 5x5 grid of real training images taken from one batch of `loader`.
real, _ = next(iter(loader))  # labels are not needed for the preview
fig, axes = plt.subplots(5, 5, figsize=(10, 10))
fig.suptitle('Some real samples', fontsize=19, fontweight='bold')

# Pair each axis with the image at the same position, starting at index 0
# (the previous counter was incremented before use and skipped the first image).
# zip stops at the shorter sequence, so a batch with fewer than 25 images
# leaves the remaining axes empty instead of raising.
for axis, img in zip(axes.flat, real):
    axis.imshow(img[0])  # img is (channels, H, W); show the first channel
        

In [None]:
def train_fn(loader, device):
    """Run one pass over `loader`, alternating discriminator and generator steps.

    Uses the notebook-level `gen`, `disc`, `opt_gen`, `opt_disc`, `criterion`
    and `Z_DIM`; the models are updated in place and nothing is returned.

    Args:
        loader: iterable yielding `(images, labels)` batches; labels are ignored.
        device: device on which the image batch and the noise are placed.
    """
    for real, _ in loader:
        batch_size = real.shape[0]
        # Flatten each image to a vector without hard-coding the pixel count.
        real = real.view(batch_size, -1).to(device)

        ######## train Discriminator: max log(D(real)) + log(1-D(G(z)))
        noise = torch.randn(batch_size, Z_DIM, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)  # shape [batch, 1] -> [batch]
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator step must not backpropagate into the
        # generator. This leaves the generator's graph intact for the step
        # below, so retain_graph=True is no longer needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ######## train Generator: min log(1-D(G(z))) <==> max log(D(G(z)))
        # Re-score the same fakes with the freshly updated discriminator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        

In [None]:
# Per-iteration loss history (one entry per batch), kept for plotting afterwards.
gen_losses  = []
disc_losses = []
for epoch in range(EPOCHS):
    for real, _ in loader:
        real       = real.view(-1, 784).to(DEVICE)
        batch_size = real.shape[0]

        ######## train Discriminator: max log(D(real)) + log(1-D(G(z)))
        noise      = torch.randn(batch_size, Z_DIM, device=DEVICE)
        fake       = gen(noise)
        disc_real  = disc(real).view(-1)  # shape [batch, 1] -> [batch]
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator step must not backpropagate into the
        # generator. This leaves the generator's graph intact for the step
        # below, so retain_graph=True is no longer needed.
        disc_fake  = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD      = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ######## train Generator: min log(1-D(G(z))) <==> max log(D(G(z)))
        # Re-score the same fakes with the freshly updated discriminator.
        output = disc(fake).view(-1)
        lossG  = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Store plain floats so the history does not keep autograd graphs alive.
        gen_losses.append(lossG.item())
        disc_losses.append(lossD.item())

    # Report the losses of the last batch of the epoch. The message is a single
    # line; the old backslash continuation embedded the source indentation in it.
    print(f"Epoch [{epoch}/{EPOCHS}] Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}")

    if epoch % 10 == 0:
        # Render the fixed latent batch so successive grids are comparable.
        with torch.no_grad():
            fake = gen(fixed_noice).reshape(-1, 1, 28, 28).cpu()
        fig, axes = plt.subplots(5, 5, figsize=(10, 10))
        fig.suptitle('Results of epoch ' + str(epoch), fontsize=19, fontweight='bold')

        # Start at sample 0 (the previous counter was incremented before use
        # and skipped it); zip stops after the 25 axes are filled.
        for axis, img in zip(axes.flat, fake):
            axis.imshow(img[0])
        # Show the grid now, then release the figure so figures do not
        # accumulate over the run.
        plt.show()
        plt.close(fig)

In [None]:
# Loss curves: one point per recorded training iteration for each network.
for history, curve_name in ((gen_losses, 'Generator'), (disc_losses, 'Discriminator')):
    plt.plot(history, label=curve_name)
plt.xlabel('Iteration')
plt.ylabel('Loss')
plt.legend()
plt.show()