In [None]:
# https://arxiv.org/pdf/1406.2661.pd

In [11]:
import torch
import torch.nn as nn 
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.utils import save_image

In [12]:
latent_dim = 100
device = 'mps'

class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        
        def block(input_dim, output_dim, normalize=True):
            layers = [nn.Linear(input_dim, output_dim)]
            if normalize:
                layers.append(nn.BatchNorm1d(output_dim, 0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers
        
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, 1*28*28),
            nn.Tanh()
        )
        
    def forward(self, z):
        img = self.model(z)
        img = img.view(img.size(0), 1, 28, 28)
        return img

In [13]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            nn.Linear(1*28*28, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
        
    def forward(self, img):
        flattened = img.view(img.size(0), -1)
        output = self.model(flattened)
        
        return output

In [14]:
transforms_train = transforms.Compose([
    transforms.Resize(28),
    transforms.ToTensor(),
    transforms.Normalize([.5],[.5])
])

train_dataset = datasets.MNIST(root="./data", train=True, download=True, transform=transforms_train)
dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)

In [15]:
generator = Generator().to(device)
discriminator = Discriminator().to(device)

adversarial_loss = nn.BCELoss().to(device)

lr = 0.0002

optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

In [16]:
import time

n_epochs = 200
sample_interval = 2000
start_time = time.time()

for epoch in range(n_epochs):
    for i, (imgs,_) in enumerate(dataloader):
        real = torch.FloatTensor(imgs.size(0), 1).fill_(1.0).to(device)
        fake = torch.FloatTensor(imgs.size(0), 1).fill_(0.0).to(device)
        
        real_imgs = imgs.to(device)
        
        optimizer_G.zero_grad()
        
        z = torch.normal(mean=0, std=1, size=(imgs.shape[0], latent_dim)).to(device)
        
        generated_imgs = generator(z)
        
        g_loss = adversarial_loss(discriminator(generated_imgs), real)
        
        g_loss.backward()
        optimizer_G.step()
        
        optimizer_D.zero_grad()
        
        real_loss = adversarial_loss(discriminator(real_imgs), real)
        fake_loss = adversarial_loss(discriminator(generated_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2
        
        d_loss.backward()
        optimizer_D.step()
        
        done = epoch * len(dataloader) + i
        if done % sample_interval == 0:
            save_image(generated_imgs[:25], f'{done}.png',nrow=5,normalize=True)
            
        print(f"[Epoch {epoch}/{n_epochs}] [D loss: {d_loss.item():.6f}] [G loss: {g_loss.item():.6f}] [Elapsed time: {time.time()-start_time:2f}s")

[Epoch 0/200] [D loss: 0.714689] [G loss: 0.658624] [Elapsed time: 5.613758s
[Epoch 0/200] [D loss: 0.630700] [G loss: 0.656052] [Elapsed time: 5.652444s
[Epoch 0/200] [D loss: 0.563831] [G loss: 0.653501] [Elapsed time: 5.673531s
[Epoch 0/200] [D loss: 0.509699] [G loss: 0.650731] [Elapsed time: 5.693347s
[Epoch 0/200] [D loss: 0.465850] [G loss: 0.647541] [Elapsed time: 5.714069s
[Epoch 0/200] [D loss: 0.436302] [G loss: 0.643486] [Elapsed time: 5.732920s
[Epoch 0/200] [D loss: 0.417482] [G loss: 0.638582] [Elapsed time: 5.752312s
[Epoch 0/200] [D loss: 0.406150] [G loss: 0.633012] [Elapsed time: 5.770996s
[Epoch 0/200] [D loss: 0.402728] [G loss: 0.626107] [Elapsed time: 5.790784s
[Epoch 0/200] [D loss: 0.403103] [G loss: 0.615706] [Elapsed time: 5.808821s
[Epoch 0/200] [D loss: 0.406123] [G loss: 0.605439] [Elapsed time: 5.828607s
[Epoch 0/200] [D loss: 0.411598] [G loss: 0.595287] [Elapsed time: 5.847169s
[Epoch 0/200] [D loss: 0.417796] [G loss: 0.582895] [Elapsed time: 5.868842s