In [75]:
import os 
import numpy as np
import math
import matplotlib.pyplot as plt
from torchvision.utils import save_image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch
import time


In [76]:
# Shape of one image without the batch dimension, as (channels, height, width).
# Sizes the generator's output layer and the discriminator's input layer.
img_shape = (1, 28, 28)

In [77]:
class Generator(nn.Module):
    """Fully connected generator mapping latent noise vectors to images.

    The network is a stack of Linear -> (BatchNorm1d) -> LeakyReLU(0.2)
    blocks followed by a Linear layer and Tanh, so outputs lie in [-1, 1].

    Args:
        latent_dim: Length of each input noise vector. Defaults to 100.
        image_shape: Shape of one generated image without the batch
            dimension. When None, the notebook-level ``img_shape`` is used,
            so ``Generator()`` behaves as before.
    """

    def __init__(self, latent_dim=100, image_shape=None):
        super(Generator, self).__init__()
        self.latent_dim = latent_dim
        # Only touch the notebook-level constant when no shape is passed in.
        self.img_shape = tuple(img_shape if image_shape is None else image_shape)

        def block(in_features, out_features, normalize=True):
            """Return the layers of one Linear -> (BatchNorm1d) -> LeakyReLU block."""
            layers = [nn.Linear(in_features, out_features)]
            if normalize:
                # BatchNorm1d's second positional argument is `eps`, so 0.8 must
                # be passed by keyword to act as the running-stat momentum.
                # NOTE(review): PyTorch's momentum is the weight of the *new*
                # batch statistic; if 0.8 was meant as the decay of the running
                # average, use momentum=0.2 instead.
                layers.append(nn.BatchNorm1d(out_features, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(self.img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Generate a batch of images from noise.

        Args:
            z: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, *image_shape) with values in [-1, 1].
        """
        img = self.model(z)
        # Un-flatten the last Linear layer's output into image form.
        return img.view(img.size(0), *self.img_shape)

In [78]:
class Discriminator(nn.Module):
    """Fully connected discriminator scoring images as real (1) or fake (0).

    Images are flattened and passed through Linear -> LeakyReLU(0.2) blocks,
    then a final Linear layer whose single logit goes straight into Sigmoid.

    Args:
        image_shape: Shape of one input image without the batch dimension.
            When None, the notebook-level ``img_shape`` is used, so
            ``Discriminator()`` behaves as before.
    """

    def __init__(self, image_shape=None):
        super(Discriminator, self).__init__()
        # Only touch the notebook-level constant when no shape is passed in.
        n_inputs = int(np.prod(img_shape if image_shape is None else image_shape))

        def block(in_features, out_features):
            """Return the layers of one hidden Linear -> LeakyReLU block."""
            return [
                nn.Linear(in_features, out_features),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        self.model = nn.Sequential(
            *block(n_inputs, 512),
            *block(512, 256),
            # No LeakyReLU on the output logit: it would scale negative logits
            # by 0.2 before the sigmoid and compress the scores given to fakes.
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; the remaining
                dimensions are flattened.

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        flatten = img.view(img.size(0), -1)
        return self.model(flatten)

In [79]:
# Select the compute device once; `cuda` is the matching boolean flag.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cuda = device.type == 'cuda'

# Binary cross-entropy is the adversarial criterion for both networks.
loss = torch.nn.BCELoss()
generator = Generator()
discriminator = Discriminator()

# Move the networks and the criterion to the GPU when one is available.
if cuda:
    generator.to(device)
    discriminator.to(device)
    loss.to(device)

In [80]:
# The "images" directory is both the MNIST download root and, later in the
# notebook, the place where generated samples are written.
os.makedirs("images", exist_ok=True)

# Convert to tensors, then normalize with mean 0.5 / std 0.5.
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([.5], [.5]),
])

mnist_train = datasets.MNIST(
    "images",
    train=True,
    download=True,
    transform=mnist_transform,
)

# Shuffled mini-batches of 128 training images.
dataloader = DataLoader(mnist_train, batch_size=128, shuffle=True)

In [81]:
# Both networks use AdamW with the same learning rate and beta coefficients.
adamw_settings = dict(lr=.0002, betas=(.5, .999))
optimizer_G = torch.optim.AdamW(generator.parameters(), **adamw_settings)
optimizer_D = torch.optim.AdamW(discriminator.parameters(), **adamw_settings)

In [None]:
# Adversarial training loop.
# NOTE: `epoch` is the TOTAL number of epochs; `epochs` is the current epoch index.
epoch = 1000
for epochs in range(epoch):
    for i, (imgs, _) in enumerate(dataloader):
        n_samples = imgs.size(0)  # the last batch of an epoch may be smaller

        # Ground-truth labels for the adversarial loss: 1 = real, 0 = fake.
        valid = torch.ones((n_samples, 1), device=device)
        fake = torch.zeros((n_samples, 1), device=device)

        # Move the real images to the training device.
        real_imgs = imgs.to(device)

        # ---- Generator step ----
        optimizer_G.zero_grad()
        z = torch.randn((n_samples, 100), device=device)

        # Generate images from the noise z.
        gen_imgs = generator(z)

        # The generator is rewarded when the discriminator labels its output real.
        g_loss = loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---- Discriminator step ----
        # zero_grad also discards the gradients the generator's backward pass
        # accumulated in the discriminator's parameters.
        optimizer_D.zero_grad()

        # Average of the loss on real images and on (detached) generated images.
        real_loss = loss(discriminator(real_imgs), valid)
        fake_loss = loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        # Discriminator accuracy after its update. This is monitoring only, so
        # no autograd graph is built.
        with torch.no_grad():
            real_preds = discriminator(real_imgs)
            fake_preds = discriminator(gen_imgs)
            # A real image is classified correctly when scored above 0.5 ...
            real_acc = (real_preds > 0.5).float().mean()
            # ... and a generated image when scored below 0.5.
            fake_acc = (fake_preds < 0.5).float().mean()

        d_acc = (real_acc + fake_acc) / 2

        print(f"[Epoch {epochs}/{epoch}] [Batch {i}/{len(dataloader)}] "
              f"[D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}] "
              f"[D accuracy: {d_acc.item() * 100:.2f}%]")

        # Global batch counter built from the CURRENT epoch index, so a sample
        # grid is written every 600 batches across the whole run.
        batches_done = epochs * len(dataloader) + i
        if batches_done % 600 == 0:
            filename = f"images/epoch_{epochs}_batch_{i}.png"
            print(f"Saving image to {filename}")
            save_image(gen_imgs.detach()[:25], filename, nrow=5, normalize=True)

