In [15]:
import os
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.autograd import Variable

import torchvision.transforms as transforms
from torchvision.utils import save_image
from torchvision import datasets

from einops.layers.torch import Rearrange

In [2]:
# ---- Notebook-wide configuration -------------------------------------------
num_eps = 10         # number of full passes over the data loader
bsize = 32           # mini-batch size handed to the DataLoader
lrate = 0.001        # learning rate used by both Adam optimizers
lat_dimension = 64   # length of the noise vector fed to the generator
image_sz = 64        # images are resized to image_sz x image_sz
chnls = 1            # number of image channels
logging_intv = 200   # log losses / save samples every this many batches

In [43]:
class GANGenerator(nn.Module):
    """Convolutional generator mapping a latent vector to an image.

    A linear layer projects the latent vector onto a 128-channel feature map
    at a quarter of the target resolution. Two (upsample x2 -> conv ->
    batch-norm -> LeakyReLU) stages bring it to full resolution, and a final
    conv + tanh produces the image.

    Args:
        latent_dim: Size of the input noise vector. Defaults to the
            notebook-level ``lat_dimension``.
        img_size: Side length of the generated (square) image; must be
            divisible by 4. Defaults to the notebook-level ``image_sz``.
        channels: Number of output image channels. Defaults to the
            notebook-level ``chnls``.

    Raises:
        ValueError: If ``img_size`` is not divisible by 4.
    """

    def __init__(self, latent_dim=None, img_size=None, channels=None):
        super().__init__()
        # Fall back to the notebook configuration only when not overridden,
        # so GANGenerator() keeps behaving exactly as before.
        latent_dim = lat_dimension if latent_dim is None else latent_dim
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels
        if img_size % 4 != 0:
            # Two x2 upsampling stages can only reach multiples of 4.
            raise ValueError(f"img_size must be divisible by 4, got {img_size}")

        self.inp_sz = img_size // 4
        # Output is a flattened (128, inp_sz, inp_sz) feature map.
        self.lin = nn.Linear(latent_dim, 128 * self.inp_sz ** 2)
        self.bn1 = nn.BatchNorm2d(128)
        self.up1 = nn.Upsample(scale_factor=2)
        self.cn1 = nn.Conv2d(128, 128, 3, stride=1, padding=1)
        # The second positional argument of BatchNorm2d is `eps`, so the
        # former `BatchNorm2d(128, 0.8)` set eps=0.8. The value is passed as
        # momentum here, which is what it was presumably meant to be.
        self.bn2 = nn.BatchNorm2d(128, momentum=0.8)
        self.rl1 = nn.LeakyReLU(0.2, inplace=True)
        self.up2 = nn.Upsample(scale_factor=2)
        self.cn2 = nn.Conv2d(128, 64, 3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(64, momentum=0.8)
        self.rl2 = nn.LeakyReLU(0.2, inplace=True)
        self.cn3 = nn.Conv2d(64, channels, 3, stride=1, padding=1)
        # tanh bounds the output to [-1, 1], the range of data normalized
        # with mean=0.5 and std=0.5.
        self.act = nn.Tanh()

    def forward(self, x):
        """Generate images from latent vectors.

        Args:
            x: Tensor of shape (batch, latent_dim).

        Returns:
            Tensor of shape (batch, channels, img_size, img_size) in [-1, 1].
        """
        x = self.lin(x)
        # 'b (c h w) -> b c h w' done with native torch; the einops function
        # `rearrange` was never imported (only the `Rearrange` layer is).
        x = x.reshape(x.shape[0], -1, self.inp_sz, self.inp_sz)
        x = self.bn1(x)
        x = self.up1(x)
        x = self.cn1(x)
        x = self.bn2(x)
        x = self.rl1(x)
        x = self.up2(x)
        x = self.cn2(x)
        x = self.bn3(x)
        x = self.rl2(x)
        x = self.cn3(x)
        out = self.act(x)
        return out

In [44]:
class GANDiscriminator(nn.Module):
    """Convolutional discriminator scoring images as real (1) or fake (0).

    Four stride-2 conv stages (conv -> LeakyReLU -> Dropout2d, with
    batch-norm on all but the first) downsample the image. The flattened
    features go through a linear layer and a sigmoid to give one probability
    per image.

    Args:
        img_size: Side length of the (square) input image. Defaults to the
            notebook-level ``image_sz``.
        channels: Number of input image channels. Defaults to the
            notebook-level ``chnls``.
    """

    def __init__(self, img_size=None, channels=None):
        super().__init__()
        # Fall back to the notebook configuration only when not overridden,
        # so GANDiscriminator() keeps behaving exactly as before.
        img_size = image_sz if img_size is None else img_size
        channels = chnls if channels is None else channels

        def disc_module(ip_chnls, op_chnls, bnorm=True):
            """Build one downsampling stage as a list of layers."""
            mod = [nn.Conv2d(ip_chnls, op_chnls, 3, 2, 1),
                   nn.LeakyReLU(0.2, inplace=True),
                   nn.Dropout2d(0.25)]
            if bnorm:
                # The second positional argument of BatchNorm2d is `eps`;
                # 0.8 is passed as momentum here instead (presumed intent).
                mod.append(nn.BatchNorm2d(op_chnls, momentum=0.8))
            return mod

        self.disc_model = nn.Sequential(
            *disc_module(channels, 16, bnorm=False),
            *disc_module(16, 32),
            *disc_module(32, 64),
            *disc_module(64, 128),
        )

        # Each conv (kernel 3, stride 2, padding 1) maps a side of n to
        # ceil(n / 2). Applying that four times equals img_size // 16 for
        # multiples of 16 and stays correct for other sizes (e.g. 28 -> 2).
        ds_size = img_size
        for _ in range(4):
            ds_size = (ds_size + 1) // 2
        self.adverse_lyr = nn.Sequential(nn.Linear(128 * ds_size ** 2, 1), nn.Sigmoid())

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape (batch, channels, img_size, img_size).

        Returns:
            Tensor of shape (batch, 1) with values in [0, 1].
        """
        x = self.disc_model(x)
        # 'b c h w -> b (c h w)' done with native torch; the einops function
        # `rearrange` was never imported (only the `Rearrange` layer is).
        x = x.flatten(start_dim=1)
        out = self.adverse_lyr(x)
        return out

In [45]:
# Build the two competing networks.
gen, disc = GANGenerator(), GANDiscriminator()

# Binary cross-entropy between the discriminator's output and the 1/0 targets.
adv_loss_func = nn.BCELoss()

In [16]:
# Preprocessing: resize to the configured resolution, convert to a tensor and
# normalize with mean 0.5 / std 0.5.
mnist_transform = transforms.Compose([
    transforms.Resize((image_sz, image_sz)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# MNIST is downloaded into ./data/mnist/ when it is not already there.
mnist_data = datasets.MNIST(
    "./data/mnist/",
    download=True,
    transform=mnist_transform,
)

# Shuffled mini-batches of `bsize` images.
dloader = DataLoader(mnist_data, batch_size=bsize, shuffle=True)

# One Adam optimizer per network, both using the same learning rate.
opt_gen = torch.optim.Adam(gen.parameters(), lr=lrate)
opt_disc = torch.optim.Adam(disc.parameters(), lr=lrate)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 14444160.77it/s]


Extracting ./data/mnist/MNIST/raw/train-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


100%|████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 463764.28it/s]


Extracting ./data/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 4276642.64it/s]


Extracting ./data/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1006)>

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 1055080.24it/s]

Extracting ./data/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/mnist/MNIST/raw






In [46]:
# Adversarial training loop: for every batch, take one generator step followed
# by one discriminator step, and periodically log losses and save samples.
os.makedirs("./images_mnist", exist_ok=True)

for ep in range(num_eps):
    for idx, (images, _) in enumerate(dloader):
        n_batch = images.shape[0]

        # BCE targets (despite the names these are labels, not images):
        # 1.0 marks "real", 0.0 marks "fake". Plain tensors replace the
        # deprecated `Variable` wrapper, and ones/zeros replace filling an
        # uninitialised FloatTensor.
        good_img = torch.ones(n_batch, 1)
        bad_img = torch.zeros(n_batch, 1)

        actual_images = images.float()

        # ---- Generator step: try to make the discriminator say "real" ----
        opt_gen.zero_grad()

        noise = torch.randn(n_batch, lat_dimension)  # standard-normal latents
        gen_images = gen(noise)

        generator_loss = adv_loss_func(disc(gen_images), good_img)
        generator_loss.backward()
        opt_gen.step()

        # ---- Discriminator step: separate real images from generated ones ----
        opt_disc.zero_grad()

        actual_image_loss = adv_loss_func(disc(actual_images), good_img) # loss on identifying real images
        # detach() keeps this step from back-propagating into the generator.
        fake_image_loss = adv_loss_func(disc(gen_images.detach()), bad_img) # loss on identifying fake images
        discriminator_loss = (actual_image_loss + fake_image_loss) / 2

        discriminator_loss.backward()
        opt_disc.step()

        batches_completed = ep * len(dloader) + idx
        if batches_completed % logging_intv == 0:
            print(f'epoch number {ep} | batch number {idx} | generator loss = {generator_loss.item()} | discriminator loss = {discriminator_loss.item()}')
            # The generator emits values in [-1, 1] (tanh) but save_image
            # clamps to [0, 1]; map back to [0, 1] so the lower half of the
            # range is not cut off to black.
            samples = gen_images.detach()[:25] * 0.5 + 0.5
            save_image(samples, f'images_mnist/{batches_completed}.png')

epoch number 0 | batch number 0 | generator loss = 0.6900531053543091 | discriminator loss = 0.6957554817199707
epoch number 0 | batch number 200 | generator loss = 0.6902207732200623 | discriminator loss = 0.6935006380081177
epoch number 0 | batch number 400 | generator loss = 0.6897798776626587 | discriminator loss = 0.6943982839584351
epoch number 0 | batch number 600 | generator loss = 0.6902846097946167 | discriminator loss = 0.6958463191986084
epoch number 0 | batch number 800 | generator loss = 0.6900146007537842 | discriminator loss = 0.6948859691619873
epoch number 0 | batch number 1000 | generator loss = 0.6898703575134277 | discriminator loss = 0.6949820518493652
epoch number 0 | batch number 1200 | generator loss = 0.6901360750198364 | discriminator loss = 0.6944220066070557
epoch number 0 | batch number 1400 | generator loss = 0.6895678639411926 | discriminator loss = 0.6951930522918701
epoch number 0 | batch number 1600 | generator loss = 0.6904693245887756 | discriminato

KeyboardInterrupt: 