<a href="https://colab.research.google.com/github/kunalburgul/MLDS_Problems/blob/main/GANS/GANs_on_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [None]:
class Discriminator(nn.Module):
    """Fully connected classifier that scores flattened images.

    The network is Linear(img_dim, 128) -> LeakyReLU(0.1) -> Linear(128, 1)
    -> Sigmoid, so each input row yields a single value in (0, 1).

    Args:
        img_dim: Number of input features per sample.
    """

    def __init__(self, img_dim):
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(img_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the sigmoid output."""
        scores = self.disc(x)
        return scores

In [None]:
class Generator(nn.Module):
    """Fully connected network mapping latent vectors to flattened images.

    The network is Linear(z_dim, 256) -> LeakyReLU(0.1) -> Linear(256, img_dim)
    -> Tanh, so every output feature lies in [-1, 1].

    Args:
        z_dim: Number of features in each latent input vector.
        img_dim: Number of output features per sample (28 x 28 x 1 in this
            notebook).
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_dim, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_dim),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run the latent batch ``x`` through the layer stack."""
        generated = self.gen(x)
        return generated


In [None]:
# Hyperparameters
# Seed PyTorch's global RNG up front: weight initialisation and every
# torch.randn call below draw from it, so unseeded runs are not repeatable.
seed = 42
torch.manual_seed(seed)
device = "cuda" if torch.cuda.is_available() else "cpu"  # use the GPU when one is visible
lr = 3e-4  # learning rate passed to both optimizers
z_dim = 64  # latent vector size; author's noted alternatives: 128 , 256
image_dim = 28 * 28 * 1 # 784 features per flattened image
batch_size = 32
num_epochs = 50

In [None]:
# Build the two networks (discriminator first, then generator) and move each
# one onto the selected device.
disc = Discriminator(img_dim=image_dim).to(device)
gen = Generator(z_dim=z_dim, img_dim=image_dim).to(device)

In [None]:
# A fixed batch of latent vectors; presumably kept constant so that samples
# generated from it at different points in training can be compared.
fixed_noise = torch.randn((batch_size, z_dim)).to(device)

# Reach the module through `torchvision.transforms` instead of the bare
# `transforms` alias: this cell rebinds `transforms` to the Compose object
# (the dataset cell reads it under that name), so going through the alias
# would raise AttributeError the second time the cell is run.
#
# Normalize((0.5,), (0.5,)) maps ToTensor's [0, 1] pixels onto [-1, 1], the
# output range of the generator's final Tanh. The MNIST mean/std pair
# (0.1307, 0.3081) used previously put real pixels in roughly [-0.42, 2.82],
# a range the generator cannot reproduce.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])

In [None]:
# MNIST stored under /dataset/ (fetched if needed, since download=True), with
# the `transforms` pipeline applied to every image.
dataset = datasets.MNIST(
    root="/dataset/",
    download=True,
    transform=transforms,
)
# Shuffled mini-batches of `batch_size` samples.
loader = DataLoader(dataset, shuffle=True, batch_size=batch_size)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Using downloaded and verified file: /dataset/MNIST/raw/train-images-idx3-ubyte.gz
Extracting /dataset/MNIST/raw/train-images-idx3-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Using downloaded and verified file: /dataset/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting /dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 503: Service Unavailable

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting /dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /dataset/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting /dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /dataset/MNIST/raw



  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [None]:
# One Adam optimizer per network, both driven by the shared learning rate `lr`.
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)
# Binary cross-entropy, applied to the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

In [None]:
# TensorBoard writers, one log directory each for generated and real images.
writer_fake = SummaryWriter(log_dir="runs/GAN_MNIST/fake")
writer_real = SummaryWriter(log_dir="runs/GAN_MNIST/real")
# Counter starting at zero; presumably the TensorBoard global step.
step = 0

In [None]:
# Adversarial training: for every mini-batch, update the discriminator once
# and then the generator once. On the first batch of each epoch, print both
# losses and log a grid of generated and real images to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image into a vector of `image_dim` features.
        real = real.view(-1, image_dim).to(device)
        # Use a local name so the `batch_size` hyperparameter is not clobbered.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach() keeps the discriminator update from backpropagating into
        # the generator, so the generator's graph stays intact for the step
        # below without needing retain_graph=True.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train the Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # Scoring the fakes against a target of ones implements the
        # "max log(D(G(z)))" form of the objective with BCELoss.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] "
                f"Loss D:{lossD:.4f}, Loss G:{lossG:.4f}"
            )

            # Logging only: no gradients are needed for the sample grids.
            with torch.no_grad():
                # Reuse the same latent batch every time so the grids are
                # comparable from one epoch to the next.
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                # normalize=True rescales each grid into [0, 1] for display.
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

        