In [47]:
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter  # to print to tensorboard
from torchvision.utils import save_image

In [40]:
class Discriminator(nn.Module):
    """Small fully connected network that scores an input vector.

    The stack is Linear -> LeakyReLU(0.01) -> Linear -> Sigmoid, so every
    score lies strictly between 0 and 1 and can be fed to a binary
    cross-entropy loss.
    """

    def __init__(self, img_features):
        """Create the layers.

        Args:
            img_features: Size of the last dimension of the input, i.e. the
                number of values in one flattened sample.
        """
        super().__init__()
        hidden_units = 128
        layers = [
            nn.Linear(img_features, hidden_units),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_units, 1),
            nn.Sigmoid(),
        ]
        # Attribute name is kept so state_dict keys stay "disc.<idx>.*".
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Return one score per input row (last dimension has size 1)."""
        scores = self.disc(x)
        return scores

In [41]:
class Generator(nn.Module):
    """Small fully connected network that maps a noise vector to a sample.

    The stack is Linear -> LeakyReLU(0.1) -> Linear -> Tanh, so every output
    value lies in [-1, 1].
    """

    def __init__(self, img_features, z_features):
        """Create the layers.

        Args:
            img_features: Number of values produced per sample (size of the
                output's last dimension).
            z_features: Size of the last dimension of the input noise.
        """
        super().__init__()
        hidden_units = 256
        layers = [
            nn.Linear(z_features, hidden_units),
            nn.LeakyReLU(0.1),
            nn.Linear(hidden_units, img_features),
            nn.Tanh(),  # squashes outputs into [-1, 1]
        ]
        # Attribute name is kept so state_dict keys stay "gen.<idx>.*".
        self.gen = nn.Sequential(*layers)

    def forward(self, x):
        """Return the generated samples for the noise batch ``x``."""
        generated = self.gen(x)
        return generated

In [42]:
# Seed torch's global RNG before anything random happens so that the weight
# initialisation and `fixed_noise` below are the same on every fresh run.
# NOTE(review): this does not by itself guarantee bit-identical GPU results.
seed = 42
torch.manual_seed(seed)

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hyperparameters.
lr = 3e-4                # learning rate
z_dim = 64               # length of the generator's input noise vector
image_dim = 28 * 28 * 1  # 784 values per flattened image
batch_size = 32
num_epochs = 50

# Models: the discriminator consumes flattened images, the generator
# produces them from noise.
disc = Discriminator(image_dim).to(device)
gen = Generator(image_dim, z_dim).to(device)

# A fixed noise batch, created directly on the target device, so generated
# samples can be compared across epochs using the same inputs.
fixed_noise = torch.randn(batch_size, z_dim, device=device)

In [43]:
# Preprocessing pipeline: ToTensor followed by Normalize with mean 0.5 and
# std 0.5, which puts pixel values in [-1, 1] (the same range as the
# generator's Tanh output).
# The pipeline gets its own name on purpose: binding it to `transforms` would
# replace the imported torchvision.transforms module, and re-running this cell
# would then fail with AttributeError on `transforms.Compose`.
mnist_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST is downloaded into dataset/ if it is not already there.
dataset = datasets.MNIST(root="dataset/", transform=mnist_transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, both using the shared learning rate.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

# Binary cross-entropy on the discriminator's sigmoid scores.
criterion = nn.BCELoss()

# TensorBoard writers for generated and real image grids.
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")

# Counter for logged snapshots.
step = 0

In [44]:
# Pull a single (images, labels) batch from the loader for a quick sanity check.
x, y = next(iter(loader))

In [45]:
# Smallest pixel value in the sampled batch.
torch.min(x)

tensor(-1.)

In [46]:
# Largest pixel value in the sampled batch.
torch.max(x)

tensor(1.)

In [49]:
# Create the output directory up front so the save_image call below cannot
# fail on a missing folder.
os.makedirs("generated", exist_ok=True)

for epoch in range(num_epochs):
    # The labels are ignored: GAN training here is unsupervised.
    for batch_idx, (real, _labels) in enumerate(loader):
        # Flatten each image to a vector of image_dim values.
        real = real.view(-1, image_dim).to(device)
        # Local name so the global `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.shape[0]

        # ---- Discriminator step: maximise log(D(x)) + log(1 - D(G(z))) ----
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        # With BCE, a target of 1 keeps only the log(D(x)) term (real images)
        # and a target of 0 keeps only the log(1 - D(G(z))) term (fake images).
        disc_real = disc(real).view(-1)
        disc_loss_real = criterion(disc_real, torch.ones_like(disc_real))
        # Detach so this backward pass stops at the discriminator: no generator
        # gradients are computed here, and the generator's graph stays intact
        # for the generator step below without needing retain_graph=True.
        disc_fake = disc(fake.detach()).view(-1)
        disc_loss_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        disc_loss = 0.5 * (disc_loss_real + disc_loss_fake)
        opt_disc.zero_grad()
        disc_loss.backward()
        opt_disc.step()

        # ---- Generator step ----
        # Instead of minimising log(1 - D(G(z))), which saturates early in
        # training, maximise log(D(G(z))) by using a target of 1 for fakes.
        output = disc(fake).view(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Report and snapshot once per epoch, on the first batch.
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {disc_loss:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                # Use the fixed noise so snapshots are comparable across epochs.
                fixed_fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fixed_fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

                # TensorBoard logging is currently disabled; uncomment to enable.
#                 writer_fake.add_image(
#                     "Mnist Fake Images", img_grid_fake, global_step=step
#                 )
#                 writer_real.add_image(
#                     "Mnist Real Images", img_grid_real, global_step=step
#                 )
                # Map the Tanh output from [-1, 1] back to [0, 1] before saving,
                # otherwise negative pixels are clamped to black. The file is
                # named by snapshot index so each epoch gets its own image.
                save_image(fixed_fake[0] * 0.5 + 0.5, f"generated/generated_ex{step}.png")
                step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.6591, loss G: 0.7523
Epoch [1/50] Batch 0/1875                       Loss D: 0.6397, loss G: 0.8196
Epoch [2/50] Batch 0/1875                       Loss D: 0.7097, loss G: 0.7650
Epoch [3/50] Batch 0/1875                       Loss D: 0.6835, loss G: 0.7706
Epoch [4/50] Batch 0/1875                       Loss D: 0.6479, loss G: 0.8503
Epoch [5/50] Batch 0/1875                       Loss D: 0.6843, loss G: 0.9461
Epoch [6/50] Batch 0/1875                       Loss D: 0.6475, loss G: 0.8681
Epoch [7/50] Batch 0/1875                       Loss D: 0.7365, loss G: 0.9128
Epoch [8/50] Batch 0/1875                       Loss D: 0.6958, loss G: 0.8614
Epoch [9/50] Batch 0/1875                       Loss D: 0.6042, loss G: 0.9767
Epoch [10/50] Batch 0/1875                       Loss D: 0.7599, loss G: 0.7792
Epoch [11/50] Batch 0/1875                       Loss D: 0.7263, loss G: 0.8289
Epoch [12/50] Batch 0/1875                       L