In [16]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

In [17]:
class Discriminator(nn.Module):
    def __init__(self,img_dim):
        super(Discriminator,self).__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim,128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid()
        )
    
    def forward(self,x):
        return self.disc(x)
    

class Generator(nn.Module):
    def __init__(self,z_dim, img_dim):
        super(Generator,self).__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256,img_dim),
            nn.Tanh())
        
    def forward(self,x):
        return self.gen(x)
    



In [18]:
# Hyperparameters etc.
device = 'mps' if torch.backends.mps.is_available() else 'cpu'
lr = 3e-4
z_dim = 64 # 128, 256 as alternatives
img_dim = 28 * 28 * 1
batch_size = 32
epochs = 50

disc = Discriminator(img_dim).to(device)
gen = Generator(z_dim, img_dim).to(device)
fixed_noise = torch.randn(batch_size, z_dim).to(device)

transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST(root='../data', download=True, transform=transforms)
loader = DataLoader(dataset, batch_size=batch_size,shuffle = True)

opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)

criterion = nn.BCELoss()

writer_fake = SummaryWriter(f'../data/runs/GAN_MNIST/fake')
writer_real = SummaryWriter(f'../data/runs/GAN_MNIST/real')
step = 0

In [20]:
# Training loop

for epoch in range(epochs):
    for idx, (real_img, _) in enumerate(loader):
        real_img = real_img.view(-1,784).to(device)
        batch_size = real_img.shape[0]

        ### TRAIN DISC ### max log(D(real)) + log(1 - FD)

        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real_img).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))

        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()

        ### TRAIN GEN ### min(log(1- D(G(z)))) <-> max(log(D(G(z))))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if idx == 0:
            print(
                f'Epoch [{epoch / epochs}] \  '
                f'Loss D: {lossD.item():.4f}, Loss G: {lossG.item():.4f}'
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1,28,28)
                data = real_img.reshape(-1, 1, 28, 28)

                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    'MNIST Fake Images', img_grid_fake, global_step=step
                )

                writer_real.add_image(
                    'MNIST Real Images', img_grid_real, global_step=step
                )

                step += 1





Epoch [0.0] \  Loss D: 0.6425, Loss G: 0.9537
Epoch [0.02] \  Loss D: 0.4302, Loss G: 1.1987
Epoch [0.04] \  Loss D: 0.6934, Loss G: 0.7924
Epoch [0.06] \  Loss D: 0.7322, Loss G: 0.8545
Epoch [0.08] \  Loss D: 0.5555, Loss G: 1.3533
Epoch [0.1] \  Loss D: 0.5912, Loss G: 0.9121
Epoch [0.12] \  Loss D: 0.8059, Loss G: 0.9252
Epoch [0.14] \  Loss D: 0.5508, Loss G: 1.1309
Epoch [0.16] \  Loss D: 0.6394, Loss G: 0.9582
Epoch [0.18] \  Loss D: 0.4820, Loss G: 1.6010
Epoch [0.2] \  Loss D: 0.4869, Loss G: 1.3634
Epoch [0.22] \  Loss D: 0.5367, Loss G: 1.0627
Epoch [0.24] \  Loss D: 0.7572, Loss G: 1.0065
Epoch [0.26] \  Loss D: 0.5100, Loss G: 1.1262
Epoch [0.28] \  Loss D: 0.6252, Loss G: 1.1670
Epoch [0.3] \  Loss D: 0.5999, Loss G: 1.5315
Epoch [0.32] \  Loss D: 0.3958, Loss G: 1.2876
Epoch [0.34] \  Loss D: 0.6455, Loss G: 1.1663
Epoch [0.36] \  Loss D: 0.6972, Loss G: 1.0367
Epoch [0.38] \  Loss D: 0.5647, Loss G: 1.1603
Epoch [0.4] \  Loss D: 0.6285, Loss G: 0.9057
Epoch [0.42] \  Lo