In [1]:
import argparse
import math
import os
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.utils import save_image

# Directory for generated sample grids.
os.makedirs("images", exist_ok=True)

# Original command-line interface, kept for reference only. The Arguments class
# below hard-codes the same defaults instead (presumably because parse_args()
# does not work inside a notebook kernel).
# parser = argparse.ArgumentParser()
# parser.add_argument("--n_epochs", type=int, default=200, help="number of epochs of training")
# parser.add_argument("--batch_size", type=int, default=64, help="size of the batches")
# parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
# parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
# parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
# parser.add_argument("--n_cpu", type=int, default=8, help="number of cpu threads to use during batch generation")
# parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
# parser.add_argument("--n_classes", type=int, default=10, help="number of classes for dataset")
# parser.add_argument("--img_size", type=int, default=32, help="size of each image dimension")
# parser.add_argument("--channels", type=int, default=1, help="number of image channels")
# parser.add_argument("--sample_interval", type=int, default=400, help="interval between image sampling")
# opt = parser.parse_args()
# print(opt)

@dataclass
class Arguments:
    """Training hyperparameters, a notebook stand-in for the argparse CLI.

    Every field defaults to the value the command-line interface used, so
    ``Arguments()`` reproduces the default run. Individual values can be
    overridden by keyword, e.g. ``Arguments(n_epochs=5)``.
    """

    n_epochs: int = 200  # number of passes over the training set
    batch_size: int = 64  # images per mini-batch
    lr: float = 0.0002  # Adam learning rate
    b1: float = 0.5  # Adam beta1: decay of the first-moment estimate
    b2: float = 0.999  # Adam beta2: decay of the second-moment estimate
    # CPU threads for batch generation (per the original CLI help).
    # NOTE(review): not read anywhere in this notebook -- confirm intent.
    n_cpu: int = 8
    latent_dim: int = 100  # dimensionality of the noise vector z
    n_classes: int = 10  # number of class labels to condition on
    img_size: int = 32  # side length images are resized to
    channels: int = 1  # number of image channels
    sample_interval: int = 400  # batches between saved sample grids

# Global hyperparameters read by the model definitions and the training loop.
opt = Arguments()

# (channels, height, width) of a single image. The generator reshapes its flat
# output to this, and the discriminator flattens inputs of this size.
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Use the GPU whenever PyTorch can see one.
cuda = torch.cuda.is_available()


class Generator(nn.Module):
    """Class-conditional MLP generator.

    Maps a noise vector of size ``opt.latent_dim`` plus an integer class label
    to an image of shape ``img_shape``. Pixel values lie in the Tanh range
    (-1, 1).
    """

    def __init__(self):
        super().__init__()

        # Learned opt.n_classes-dimensional embedding for each class label.
        self.label_emb = nn.Embedding(opt.n_classes, opt.n_classes)

        def block(in_feat, out_feat, normalize=True):
            """Return the layers [Linear, (BatchNorm1d), LeakyReLU] as a list."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is ``eps``, so this
                # layer uses eps=0.8 (NOT momentum=0.8). It is spelled out to
                # keep the trained behavior identical.
                # NOTE(review): confirm whether momentum=0.8 was intended.
                layers.append(nn.BatchNorm1d(out_feat, eps=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(opt.latent_dim + opt.n_classes, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate a batch of images.

        Args:
            noise: float tensor of shape (batch, opt.latent_dim).
            labels: integer class labels of shape (batch,).

        Returns:
            Tensor of shape (batch, *img_shape).
        """
        # Concatenate the label embedding and the noise vector to form the input.
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        img = self.model(gen_input)
        return img.view(img.size(0), *img_shape)


class Discriminator(nn.Module):
    """Class-conditional MLP discriminator scoring (image, label) pairs.

    Produces one raw score per sample. There is no final sigmoid.
    """

    def __init__(self):
        super().__init__()

        self.label_embedding = nn.Embedding(opt.n_classes, opt.n_classes)

        # Input features = flattened image pixels followed by the label embedding.
        in_features = opt.n_classes + int(np.prod(img_shape))
        hidden = 512

        layers = [nn.Linear(in_features, hidden), nn.LeakyReLU(0.2, inplace=True)]
        # Two regularized hidden layers, each Linear -> Dropout -> LeakyReLU.
        for _ in range(2):
            layers.extend(
                [nn.Linear(hidden, hidden), nn.Dropout(0.4), nn.LeakyReLU(0.2, inplace=True)]
            )
        layers.append(nn.Linear(hidden, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Score a batch of images conditioned on their labels; returns (batch, 1)."""
        flat_pixels = img.view(img.size(0), -1)
        label_vec = self.label_embedding(labels)
        return self.model(torch.cat((flat_pixels, label_vec), -1))


# Mean-squared-error adversarial objective (least-squares GAN style).
adversarial_loss = torch.nn.MSELoss()

# Build the generator first, then the discriminator. This keeps parameter
# initialization consuming the RNG in a fixed order.
generator = Generator()
discriminator = Discriminator()

if cuda:
    for _module in (generator, discriminator, adversarial_loss):
        _module.cuda()

# MNIST, resized to img_size and rescaled from [0, 1] to [-1, 1] to match the
# generator's Tanh output range.
# NOTE(review): opt.n_cpu is not passed as num_workers, so batches are loaded
# in the main process.
_mnist_root = "../../data/mnist"
os.makedirs(_mnist_root, exist_ok=True)
_mnist_transform = transforms.Compose(
    [
        transforms.Resize(opt.img_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
dataloader = DataLoader(
    datasets.MNIST(_mnist_root, train=True, download=True, transform=_mnist_transform),
    batch_size=opt.batch_size,
    shuffle=True,
)

# One Adam optimizer per network, sharing the same hyperparameters.
_adam_betas = (opt.b1, opt.b2)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=_adam_betas)
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=_adam_betas)

# Legacy tensor constructors for the active device, used to build inputs,
# noise and label tensors.
FloatTensor, LongTensor = (
    (torch.cuda.FloatTensor, torch.cuda.LongTensor)
    if cuda
    else (torch.FloatTensor, torch.LongTensor)
)


def sample_image(n_row, batches_done):
    """Save an ``n_row`` x ``n_row`` grid of generated digits.

    Every grid row holds one sample for each class ``0 .. n_row - 1``, so each
    column shows a single class. Fresh noise is drawn for every image. The grid
    is written to ``images/original/<batches_done>.png``.

    Args:
        n_row: number of rows (and columns) in the grid. It is also the number
            of distinct labels, so it must not exceed ``opt.n_classes``.
        batches_done: global batch counter, used as the output file name.

    Raises:
        ValueError: if ``n_row`` exceeds ``opt.n_classes``.
    """
    if n_row > opt.n_classes:
        # Labels run 0..n_row-1; larger values would index past the label
        # embedding (a hard device-side assert on CUDA).
        raise ValueError(
            "n_row (%d) must not exceed opt.n_classes (%d)" % (n_row, opt.n_classes)
        )

    out_dir = "images/original"
    # Only "images" is created at start-up. Make sure the subdirectory exists
    # before save_image tries to write into it.
    os.makedirs(out_dir, exist_ok=True)

    # Sampling only, so no autograd graph is needed. The generator stays in
    # training mode, so BatchNorm uses this batch's statistics.
    with torch.no_grad():
        # One standard-normal noise vector per grid cell.
        z = FloatTensor(np.random.normal(0, 1, (n_row ** 2, opt.latent_dim)))
        # Labels 0..n_row-1, repeated once per grid row.
        labels = LongTensor(np.array([num for _ in range(n_row) for num in range(n_row)]))
        gen_imgs = generator(z, labels)
    save_image(gen_imgs, "%s/%d.png" % (out_dir, batches_done), nrow=n_row, normalize=True)


# ----------
#  Training
# ----------
# Each iteration updates the generator first, then the discriminator, on one
# mini-batch. Losses are printed every `log_interval` batches, and a sample
# grid is saved every `opt.sample_interval` batches.

log_interval = 300  # batches between progress prints
n_batches = len(dataloader)

for epoch in range(opt.n_epochs):
    for i, (imgs, labels) in enumerate(dataloader):

        # The last batch of an epoch may be smaller than opt.batch_size.
        batch_size = imgs.shape[0]

        # Adversarial targets: 1 for real, 0 for fake (constants, no grad).
        valid = FloatTensor(batch_size, 1).fill_(1.0)
        fake = FloatTensor(batch_size, 1).fill_(0.0)

        # Move the real batch to the active device with the expected dtypes.
        real_imgs = imgs.type(FloatTensor)
        labels = labels.type(LongTensor)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise and random class labels as generator input.
        z = FloatTensor(np.random.normal(0, 1, (batch_size, opt.latent_dim)))
        gen_labels = LongTensor(np.random.randint(0, opt.n_classes, batch_size))

        gen_imgs = generator(z, gen_labels)

        # The generator is rewarded when the discriminator scores its images
        # as real (target 1).
        validity = discriminator(gen_imgs, gen_labels)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        # Also clears the discriminator gradients accumulated by g_loss.backward().
        optimizer_D.zero_grad()

        validity_real = discriminator(real_imgs, labels)
        d_real_loss = adversarial_loss(validity_real, valid)

        # detach() so the discriminator update does not backpropagate into G.
        validity_fake = discriminator(gen_imgs.detach(), gen_labels)
        d_fake_loss = adversarial_loss(validity_fake, fake)

        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

        if i % log_interval == 0:
            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, opt.n_epochs, i, n_batches, d_loss.item(), g_loss.item())
            )

        batches_done = epoch * n_batches + i
        if batches_done % opt.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

[Epoch 0/200] [Batch 0/938] [D loss: 0.563877] [G loss: 1.049378]
[Epoch 0/200] [Batch 300/938] [D loss: 0.099793] [G loss: 0.505118]
[Epoch 0/200] [Batch 600/938] [D loss: 0.406350] [G loss: 1.504236]
[Epoch 0/200] [Batch 900/938] [D loss: 0.354785] [G loss: 1.628829]
[Epoch 1/200] [Batch 0/938] [D loss: 0.096207] [G loss: 0.617486]
[Epoch 1/200] [Batch 300/938] [D loss: 0.081843] [G loss: 0.748146]
[Epoch 1/200] [Batch 600/938] [D loss: 0.109547] [G loss: 0.579923]
[Epoch 1/200] [Batch 900/938] [D loss: 0.115510] [G loss: 0.650354]
[Epoch 2/200] [Batch 0/938] [D loss: 0.101971] [G loss: 0.581561]
[Epoch 2/200] [Batch 300/938] [D loss: 0.099785] [G loss: 0.528440]
[Epoch 2/200] [Batch 600/938] [D loss: 0.113738] [G loss: 0.386501]
[Epoch 2/200] [Batch 900/938] [D loss: 0.097983] [G loss: 0.556454]
[Epoch 3/200] [Batch 0/938] [D loss: 0.135327] [G loss: 0.586107]
[Epoch 3/200] [Batch 300/938] [D loss: 0.137688] [G loss: 0.447845]
[Epoch 3/200] [Batch 600/938] [D loss: 0.121836] [G loss

[Epoch 30/200] [Batch 0/938] [D loss: 0.212085] [G loss: 0.327978]
[Epoch 30/200] [Batch 300/938] [D loss: 0.192269] [G loss: 0.510927]
[Epoch 30/200] [Batch 600/938] [D loss: 0.205106] [G loss: 0.380007]
[Epoch 30/200] [Batch 900/938] [D loss: 0.199391] [G loss: 0.465925]
[Epoch 31/200] [Batch 0/938] [D loss: 0.213559] [G loss: 0.269654]
[Epoch 31/200] [Batch 300/938] [D loss: 0.207497] [G loss: 0.539645]
[Epoch 31/200] [Batch 600/938] [D loss: 0.163928] [G loss: 0.496454]
[Epoch 31/200] [Batch 900/938] [D loss: 0.182382] [G loss: 0.417145]
[Epoch 32/200] [Batch 0/938] [D loss: 0.172102] [G loss: 0.544492]
[Epoch 32/200] [Batch 300/938] [D loss: 0.199699] [G loss: 0.334973]
[Epoch 32/200] [Batch 600/938] [D loss: 0.192352] [G loss: 0.423734]
[Epoch 32/200] [Batch 900/938] [D loss: 0.194896] [G loss: 0.288208]
[Epoch 33/200] [Batch 0/938] [D loss: 0.222022] [G loss: 0.795579]
[Epoch 33/200] [Batch 300/938] [D loss: 0.214806] [G loss: 0.582232]
[Epoch 33/200] [Batch 600/938] [D loss: 0.

[Epoch 59/200] [Batch 900/938] [D loss: 0.192334] [G loss: 0.307849]
[Epoch 60/200] [Batch 0/938] [D loss: 0.131668] [G loss: 0.630270]
[Epoch 60/200] [Batch 300/938] [D loss: 0.139205] [G loss: 0.658366]
[Epoch 60/200] [Batch 600/938] [D loss: 0.128060] [G loss: 0.555829]
[Epoch 60/200] [Batch 900/938] [D loss: 0.148948] [G loss: 0.449543]
[Epoch 61/200] [Batch 0/938] [D loss: 0.136589] [G loss: 0.476546]
[Epoch 61/200] [Batch 300/938] [D loss: 0.151210] [G loss: 0.693677]
[Epoch 61/200] [Batch 600/938] [D loss: 0.144048] [G loss: 0.674751]
[Epoch 61/200] [Batch 900/938] [D loss: 0.144444] [G loss: 0.422457]
[Epoch 62/200] [Batch 0/938] [D loss: 0.143307] [G loss: 0.614684]
[Epoch 62/200] [Batch 300/938] [D loss: 0.168336] [G loss: 0.340708]
[Epoch 62/200] [Batch 600/938] [D loss: 0.178037] [G loss: 0.785117]
[Epoch 62/200] [Batch 900/938] [D loss: 0.154148] [G loss: 0.415526]
[Epoch 63/200] [Batch 0/938] [D loss: 0.165001] [G loss: 0.859073]
[Epoch 63/200] [Batch 300/938] [D loss: 0.

[Epoch 89/200] [Batch 600/938] [D loss: 0.140505] [G loss: 0.870988]
[Epoch 89/200] [Batch 900/938] [D loss: 0.192502] [G loss: 0.438837]
[Epoch 90/200] [Batch 0/938] [D loss: 0.113548] [G loss: 0.610395]
[Epoch 90/200] [Batch 300/938] [D loss: 0.125120] [G loss: 0.873740]
[Epoch 90/200] [Batch 600/938] [D loss: 0.114452] [G loss: 0.737281]
[Epoch 90/200] [Batch 900/938] [D loss: 0.148528] [G loss: 0.348770]
[Epoch 91/200] [Batch 0/938] [D loss: 0.124586] [G loss: 0.431220]
[Epoch 91/200] [Batch 300/938] [D loss: 0.244614] [G loss: 0.226086]
[Epoch 91/200] [Batch 600/938] [D loss: 0.094371] [G loss: 0.762733]
[Epoch 91/200] [Batch 900/938] [D loss: 0.142735] [G loss: 0.545407]
[Epoch 92/200] [Batch 0/938] [D loss: 0.137409] [G loss: 0.439121]
[Epoch 92/200] [Batch 300/938] [D loss: 0.173981] [G loss: 0.290557]
[Epoch 92/200] [Batch 600/938] [D loss: 0.123945] [G loss: 0.691155]
[Epoch 92/200] [Batch 900/938] [D loss: 0.135517] [G loss: 0.977987]
[Epoch 93/200] [Batch 0/938] [D loss: 0.

[Epoch 119/200] [Batch 0/938] [D loss: 0.076513] [G loss: 0.685787]
[Epoch 119/200] [Batch 300/938] [D loss: 0.098736] [G loss: 0.767387]
[Epoch 119/200] [Batch 600/938] [D loss: 0.093218] [G loss: 0.622034]
[Epoch 119/200] [Batch 900/938] [D loss: 0.140459] [G loss: 0.399323]
[Epoch 120/200] [Batch 0/938] [D loss: 0.114701] [G loss: 0.437412]
[Epoch 120/200] [Batch 300/938] [D loss: 0.135472] [G loss: 0.409887]
[Epoch 120/200] [Batch 600/938] [D loss: 0.288013] [G loss: 1.047294]
[Epoch 120/200] [Batch 900/938] [D loss: 0.168577] [G loss: 0.893852]
[Epoch 121/200] [Batch 0/938] [D loss: 0.192391] [G loss: 0.251037]
[Epoch 121/200] [Batch 300/938] [D loss: 0.076166] [G loss: 0.698154]
[Epoch 121/200] [Batch 600/938] [D loss: 0.127043] [G loss: 0.396835]
[Epoch 121/200] [Batch 900/938] [D loss: 0.097042] [G loss: 0.566063]
[Epoch 122/200] [Batch 0/938] [D loss: 0.296715] [G loss: 0.136358]
[Epoch 122/200] [Batch 300/938] [D loss: 0.083972] [G loss: 0.702947]
[Epoch 122/200] [Batch 600/9

[Epoch 148/200] [Batch 300/938] [D loss: 0.149274] [G loss: 0.416853]
[Epoch 148/200] [Batch 600/938] [D loss: 0.296557] [G loss: 1.053385]
[Epoch 148/200] [Batch 900/938] [D loss: 0.084732] [G loss: 0.664017]
[Epoch 149/200] [Batch 0/938] [D loss: 0.042029] [G loss: 0.905223]
[Epoch 149/200] [Batch 300/938] [D loss: 0.076429] [G loss: 0.634982]
[Epoch 149/200] [Batch 600/938] [D loss: 0.077851] [G loss: 0.610547]
[Epoch 149/200] [Batch 900/938] [D loss: 0.153270] [G loss: 0.361222]
[Epoch 150/200] [Batch 0/938] [D loss: 0.084640] [G loss: 0.625465]
[Epoch 150/200] [Batch 300/938] [D loss: 0.096516] [G loss: 0.950441]
[Epoch 150/200] [Batch 600/938] [D loss: 0.082571] [G loss: 0.843550]
[Epoch 150/200] [Batch 900/938] [D loss: 0.074898] [G loss: 0.651730]
[Epoch 151/200] [Batch 0/938] [D loss: 0.068270] [G loss: 0.744203]
[Epoch 151/200] [Batch 300/938] [D loss: 0.090057] [G loss: 0.770030]
[Epoch 151/200] [Batch 600/938] [D loss: 0.155975] [G loss: 0.649675]
[Epoch 151/200] [Batch 900

[Epoch 177/200] [Batch 600/938] [D loss: 0.094875] [G loss: 0.540559]
[Epoch 177/200] [Batch 900/938] [D loss: 0.048919] [G loss: 0.825742]
[Epoch 178/200] [Batch 0/938] [D loss: 0.138401] [G loss: 0.389016]
[Epoch 178/200] [Batch 300/938] [D loss: 0.067839] [G loss: 0.795975]
[Epoch 178/200] [Batch 600/938] [D loss: 0.049651] [G loss: 0.957341]
[Epoch 178/200] [Batch 900/938] [D loss: 0.069743] [G loss: 0.646859]
[Epoch 179/200] [Batch 0/938] [D loss: 0.165743] [G loss: 0.613261]
[Epoch 179/200] [Batch 300/938] [D loss: 0.082292] [G loss: 0.645154]
[Epoch 179/200] [Batch 600/938] [D loss: 0.068295] [G loss: 0.780937]
[Epoch 179/200] [Batch 900/938] [D loss: 0.057608] [G loss: 0.857681]
[Epoch 180/200] [Batch 0/938] [D loss: 0.067419] [G loss: 0.806958]
[Epoch 180/200] [Batch 300/938] [D loss: 0.095359] [G loss: 0.535465]
[Epoch 180/200] [Batch 600/938] [D loss: 0.057337] [G loss: 0.829018]
[Epoch 180/200] [Batch 900/938] [D loss: 0.234255] [G loss: 0.256645]
[Epoch 181/200] [Batch 0/9