In [1]:
import torch 
from torch import nn,optim
from torch.utils import data 
from torchvision import datasets,transforms
from torch.utils.data import DataLoader 
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from torchvision.utils import make_grid

In [2]:
class Discriminator(nn.Module):
    """Small fully connected network that scores flattened images.

    Architecture: Linear(img_dim, 128) -> LeakyReLU(0.1) -> Linear(128, 1)
    -> Sigmoid. Each input row therefore produces a single score between
    0 and 1.

    Args:
        img_dim: Number of features in one flattened input image.
    """

    def __init__(self, img_dim):
        super().__init__()

        hidden_units = 128
        layers = [
            nn.Linear(in_features=img_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=hidden_units, out_features=1),
            # Sigmoid squashes the logit into the 0..1 range: a continuous
            # fake/real score, not a hard 0-or-1 label.
            nn.Sigmoid(),
        ]
        # Attribute name and layer order are kept so state_dict keys stay
        # 'disc.0.*' and 'disc.2.*'.
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return one score per row."""
        scores = self.disc(x)
        return scores


class Generator(nn.Module):
    """Small fully connected network that maps noise vectors to flat images.

    Architecture: Linear(z_dim, 256) -> LeakyReLU(0.1) -> Linear(256, img_dim)
    -> Tanh. The final Tanh bounds every output value to the range -1..1.

    Args:
        z_dim: Length of each input noise vector.
        img_dim: Number of features in each flattened output image.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()

        hidden_units = 256
        layers = [
            nn.Linear(in_features=z_dim, out_features=hidden_units),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(in_features=hidden_units, out_features=img_dim),
            nn.Tanh(),
        ]
        # Attribute name and layer order are kept so state_dict keys stay
        # 'gen.0.*' and 'gen.2.*'.
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Run noise batch ``x`` through the layer stack and return the result."""
        generated = self.gen(x)
        return generated



In [3]:
# Use the first CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    dev = torch.device('cuda:0')
else:
    dev = torch.device('cpu')

# Hyperparameters shared by the cells below.
# NOTE(review): 3e-3 may be too high for this setup -- the logged generator
# loss swings widely in the first epochs. Consider trying a smaller value.
lr = 3e-3                  # learning rate handed to both Adam optimizers
z_dim = 64                 # length of the generator's input noise vector
image_dim = 1 * 28 * 28    # flattened image size: channels * height * width = 784
batch_size = 32            # samples per DataLoader batch

In [4]:
# Mount Google Drive at /content/drive. The training cell saves model
# checkpoints under /content/drive/MyDrive/GAN_models, so this cell has to run
# before training starts.
# NOTE(review): google.colab is presumably only importable inside a Colab
# runtime; this cell will fail elsewhere -- confirm before running locally.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [7]:
# --- Models -----------------------------------------------------------------
# The discriminator is built before the generator, as in the original cell, so
# parameter initialisation consumes the random stream in the same order.
disc = Discriminator(image_dim).to(device=dev)
gen = Generator(z_dim, image_dim).to(device=dev)

# --- Optimizers and loss ----------------------------------------------------
# One Adam optimizer per network, both using the shared learning rate.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid output.
criterion = nn.BCELoss()

# --- Data -------------------------------------------------------------------
# ToTensor followed by Normalize(mean=0.5, std=0.5) on the single channel.
preprocessing_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
t = transforms.Compose(preprocessing_steps)

# Downloads MNIST into ./data on first use.
dataset = datasets.MNIST(root='./data', download=True, transform=t)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# --- TensorBoard logging ----------------------------------------------------
# Separate writers so generated and real image grids land in separate runs.
writer_fake = SummaryWriter(log_dir='runs/GAN_MNIST/fake')
writer_real = SummaryWriter(log_dir='runs/GAN_MNIST/real')

# Counter used as the TensorBoard global_step and in checkpoint file names.
step = 1

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
Processing...
Done!




  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)






In [8]:
# Fixed latent batch, reused at every logging step so the TensorBoard grids
# show how the same noise vectors are rendered as training progresses.
fixed_noise = torch.randn((batch_size, z_dim), device=dev)
epochs = 1000
for epoch in range(epochs):
    for batch_idx, (real, _) in enumerate(loader):

        # Flatten each image to a vector of image_dim features.
        real = real.view(-1, image_dim).to(dev)
        # Size of *this* batch. Kept in its own name so the global
        # `batch_size` hyperparameter is not overwritten by the loop.
        cur_batch_size = real.shape[0]

        # ---- Train discriminator: real -> 1, fake -> 0 ----
        noise = torch.randn((cur_batch_size, z_dim), device=dev)
        fake = gen(noise)

        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        # detach() cuts the graph at the generator output: the discriminator
        # update needs no generator gradients, so backward() stops here and
        # retain_graph=True is no longer required.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_fake + lossD_real) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # ---- Train generator: make the updated discriminator say "real" ----
        # `fake` is used un-detached here so gradients flow back into gen.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Log, visualise and checkpoint once per epoch (on its first batch).
        if batch_idx == 0:
            print(f'Epoch : [{epoch}] Loss D : {lossD.item():.4f}, Loss G {lossG.item():.4f}')

            with torch.no_grad():
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)

                fake_grid = make_grid(fake_images, normalize=True)
                real_grid = make_grid(real_images, normalize=True)

            writer_fake.add_image(
                'Mnist fake images', fake_grid, global_step=step
            )
            writer_real.add_image(
                'Mnist real images', real_grid, global_step=step
            )
            # Flush now so the images reach disk even if the run is
            # interrupted before the writers flush on their own.
            writer_fake.flush()
            writer_real.flush()

            # NOTE(review): torch.save does not create missing directories;
            # the GAN_models folder must already exist on the mounted Drive.
            torch.save(disc.state_dict(), f'/content/drive/MyDrive/GAN_models/disc{step}.pt')
            torch.save(gen.state_dict(), f'/content/drive/MyDrive/GAN_models/gen{step}.pt')

            step += 1

Epoch : [0] Loss D : 0.6536, Loss G 0.8027
Epoch : [1] Loss D : 0.1131, Loss G 7.5832
Epoch : [2] Loss D : 0.1510, Loss G 9.2025
Epoch : [3] Loss D : 0.0210, Loss G 4.7537
Epoch : [4] Loss D : 0.1537, Loss G 4.7853
Epoch : [5] Loss D : 0.0383, Loss G 7.6136
Epoch : [6] Loss D : 0.0469, Loss G 5.3292
Epoch : [7] Loss D : 0.0379, Loss G 6.8874
Epoch : [8] Loss D : 0.0485, Loss G 7.1203
Epoch : [9] Loss D : 0.0365, Loss G 9.5649
Epoch : [10] Loss D : 0.1157, Loss G 10.3714
Epoch : [11] Loss D : 0.1534, Loss G 8.0941
Epoch : [12] Loss D : 0.5535, Loss G 4.7307
Epoch : [13] Loss D : 0.6157, Loss G 2.0698
Epoch : [14] Loss D : 0.5069, Loss G 1.6221
Epoch : [15] Loss D : 0.5459, Loss G 2.3597
Epoch : [16] Loss D : 0.4360, Loss G 1.9356
Epoch : [17] Loss D : 0.6701, Loss G 2.4467
Epoch : [18] Loss D : 0.4633, Loss G 2.0354
Epoch : [19] Loss D : 0.6449, Loss G 1.2468
Epoch : [20] Loss D : 0.7444, Loss G 1.8514
Epoch : [21] Loss D : 0.4908, Loss G 2.3424
Epoch : [22] Loss D : 0.6121, Loss G 3.15

KeyboardInterrupt: ignored