<a href="https://colab.research.google.com/github/mijanr/GANs/blob/master/cGAN_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#necessary imports
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms, datasets
import torchvision.transforms as transforms 

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

from tqdm import tqdm, trange

In [None]:
# Global plot defaults: grayscale colormap for images, 12x10 figures,
# then seaborn's "darkgrid" style.
sns.set(
    rc={
        'image.cmap': 'gray',
        'figure.figsize': (12, 10),
    }
)
sns.set_style("darkgrid")

In [None]:
# MNIST training split stored under data/ (downloaded on first use);
# ToTensor turns every image into a tensor.
dataset = datasets.MNIST(
    root='data/',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

In [None]:
# Shuffling loader that serves the dataset in batches of 100 samples.
dataloader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=100,
    shuffle=True,
)

In [None]:
#plot a grid of images
def plot_grid(images):
    """Show a batch of images as a single tiled figure.

    Args:
        images: anything ``torchvision.utils.make_grid`` accepts; the calls
            in this notebook pass ``(N, 1, 28, 28)`` tensors.
    """
    # detach + cpu so the grid can always be converted to a NumPy array,
    # even if the input lives on an accelerator or is attached to a graph
    grid = torchvision.utils.make_grid(images).detach().cpu()
    # (C, H, W) -> (H, W, C) is the layout imshow expects. matplotlib clips
    # float RGB data to [0, 1] anyway; clamping here gives the same picture
    # without the clipping warning for Tanh outputs below zero.
    plt.imshow(grid.permute(1, 2, 0).clamp(0, 1).numpy())
    plt.show()

In [None]:
# Preview one batch of real digits (element 0 of the batch holds the images).
first_batch = next(iter(dataloader))
plot_grid(first_batch[0])

In [None]:
#Let's create a Conditional GAN
#Generator
class Generator(nn.Module):
    """Fully connected conditional generator.

    The class label is embedded, concatenated with the noise vector and
    passed through an MLP. The final Tanh keeps every output in [-1, 1].

    Args:
        latent_dim: width of the noise vector ``z``.
        num_classes: number of distinct labels.
        embed_dim: width of the label embedding.
        img_dim: number of values produced per sample (784 = 28 * 28).

    The defaults reproduce the original fixed architecture (110 -> 256 ->
    512 -> 1024 -> 784), and the attribute names and layer order are
    unchanged so existing state dicts still load.
    """
    def __init__(self, latent_dim: int = 100, num_classes: int = 10,
                 embed_dim: int = 10, img_dim: int = 784):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.sequential = nn.Sequential(
            # input is noise and label embedding side by side
            nn.Linear(latent_dim + embed_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, z, labels):
        """Generate flat images for ``labels`` from noise ``z``.

        Args:
            z: noise of shape ``(N, latent_dim)``.
            labels: integer class indices of shape ``(N,)``.

        Returns:
            Tensor of shape ``(N, img_dim)``; reshaping to an image is left
            to the caller.
        """
        c = self.embedding(labels)
        x = torch.cat([z, c], 1)
        return self.sequential(x)
#Discriminator
class Discriminator(nn.Module):
    """Fully connected conditional discriminator.

    The class label is embedded, concatenated with the flattened image and
    mapped by an MLP with dropout to a single sigmoid probability.

    Args:
        img_dim: number of values per flattened image (784 = 28 * 28).
        num_classes: number of distinct labels.
        embed_dim: width of the label embedding.
        dropout: dropout probability used after each hidden layer.

    The defaults reproduce the original fixed architecture (794 -> 1024 ->
    512 -> 256 -> 1), and the attribute names and layer order are unchanged
    so existing state dicts still load.
    """
    def __init__(self, img_dim: int = 784, num_classes: int = 10,
                 embed_dim: int = 10, dropout: float = 0.3):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, embed_dim)
        self.sequential = nn.Sequential(
            # input is flattened image and label embedding side by side
            nn.Linear(img_dim + embed_dim, 1024),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def forward(self, x, labels):
        """Score how real each image looks given its label.

        Args:
            x: images, either already flat ``(N, img_dim)`` or with extra
                trailing dimensions such as ``(N, 1, 28, 28)``.
            labels: integer class indices of shape ``(N,)``.

        Returns:
            Tensor of shape ``(N, 1)`` with values in [0, 1].
        """
        c = self.embedding(labels)
        # flatten everything but the batch axis; a no-op for flat input
        x = torch.cat([x.flatten(start_dim=1), c], 1)
        return self.sequential(x)

In [None]:
# Instantiate the generator first, then the discriminator.
gen, disc = Generator(), Discriminator()

In [None]:
# Smoke test: 10 noise vectors of width 100 plus 10 random labels in [0, 10)
# go through the generator; print the input and output shapes.
X = torch.randn((10, 100))
print(X.size())
labels = torch.randint(low=0, high=10, size=(10,))
out = gen(X, labels)
print(out.size())

In [None]:
# Smoke test: run the generated batch `out` and its `labels` through the
# discriminator; as the cell's last expression, the result is echoed.
disc(out, labels)

In [None]:
# One Adam optimizer per network, both with learning rate 2e-4.
opt_gen = torch.optim.Adam(params=gen.parameters(), lr=2e-4)
opt_disc = torch.optim.Adam(params=disc.parameters(), lr=2e-4)
# Binary cross-entropy, applied to the discriminator's sigmoid output.
loss = nn.BCELoss()

In [None]:
# Peek at the first batch only and report the image and label tensor shapes.
for real, labels in dataloader:
    print(real.size(), labels.size(), sep="\n")
    break

In [None]:
#training
# Alternating GAN updates: per batch, one discriminator step (real batch
# labelled 1, generated batch labelled 0) followed by one generator step
# (generated batch labelled 1, i.e. "fool the discriminator").
# NOTE(review): the dataset uses ToTensor, so real pixels are presumably in
# [0, 1], while the generator ends in Tanh ([-1, 1]). Consider normalising
# the data to [-1, 1]; not changed here so `gen` outputs keep their range.
epochs = 10
latent_dim = 100  # noise width; must match the generator's expected input
for epoch in trange(epochs):
    for real, labels in tqdm(dataloader):
        # take the batch size from the data so fake and real batches always
        # match, even for a short final batch or a different loader setting
        batch_size = real.size(0)
        # flatten each image to one vector for the fully connected networks
        real = real.view(batch_size, -1)
        #training discriminator
        opt_disc.zero_grad()
        #real
        D_real = disc(real, labels)
        loss_real = loss(D_real, torch.ones_like(D_real))
        #fake
        z = torch.randn(batch_size, latent_dim)
        fake_labels = torch.randint(0, 10, (batch_size,))
        fake = gen(z, fake_labels)
        # detach: the discriminator step must not backprop into the generator
        D_fake = disc(fake.detach(), fake_labels)
        loss_fake = loss(D_fake, torch.zeros_like(D_fake))
        #total loss
        loss_disc = loss_real + loss_fake
        loss_disc.backward()
        opt_disc.step()
        #training generator
        opt_gen.zero_grad()
        # re-score the same fakes, this time keeping the graph to the generator
        D_fake = disc(fake, fake_labels)
        loss_gen = loss(D_fake, torch.ones_like(D_fake))
        loss_gen.backward()
        opt_gen.step()
    # losses reported are those of the last batch of the epoch
    print(f"Epoch {epoch+1}/{epochs}, Loss D: {loss_disc.item():.4f}, Loss G: {loss_gen.item():.4f}")
    #plot fake images every 10 epochs, and always after the final epoch
    if (epoch+1)%10 == 0 or epoch+1 == epochs:
        plot_grid(fake.view(-1, 1, 28, 28).detach())
    