Parameter
====================

In [2]:
batch_size = 64
latent_dim = 100
label_classes = 10
img_size = 32
channels = 1

Packages
=================

In [3]:
import torch
from torch import nn, optim
import matplotlib.pyplot as plt
from collections import OrderedDict
from torch.autograd.variable import Variable
import seaborn as sns
import numpy as np

Dataset
============


In [4]:
from tensorflow.keras.datasets import mnist
(X_train_numpy, Y_train_numpy), (X_test_numpy, Y_test_numpy) = mnist.load_data()
print(f"x_train's shape is {X_train_numpy.shape}\n x_test's shape is {X_test_numpy.shape}\n \
y_train's shape is {Y_train_numpy.shape}\n y_test's shape is {Y_test_numpy.shape}")

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
x_train's shape is (60000, 28, 28)
 x_test's shape is (10000, 28, 28)
 y_train's shape is (60000,)
 y_test's shape is (10000,)


In [5]:
X_train = torch.div(torch.from_numpy(X_train_numpy).type(torch.FloatTensor), 127.5)[0:10000].cuda() - 1

Generator
=======================

In [6]:
class Generator(torch.nn.Module):
    def __init__(self):
        super(Generator, self).__init__()

        self.label_emb = nn.Embedding(label_classes, latent_dim) # a dictionary whose size = (n_classes * latent_dim)
        # self.noise = torch.FloatTensor(np.random.normal(0, 1, (label_classes, latent_dim)))

        self.linear = nn.Sequential(OrderedDict([
            'fc1', nn.Linear(latent_dim, 128 * (img_size // 4) ** 2)
        ]))

        self.conv = nn.Sequential(OrderedDict([
            'bn1', nn.BatchNorm2d(128),
            'us1', nn.Upsample(scale_factor=2),
            'conv1', nn.Conv2d(128, 128, 3, stride=1, padding=1),
            'bn2', nn.BatchNorm2d(128, 0.8),
            'lr1', nn.LeakyReLU(0.2, inplace=True),
            'us2', nn.Upsample(scale_factor=2),
            'conv2', nn.Conv2d(128, 64, 3, stride=1, padding=1),
            'bn3', nn.BatchNorm2d(64, 0.8),
            'lr2', nn.LeakyReLU(0.2, inplace=True),
            'conv3', nn.Conv2d(64, opt.channels, 3, stride=1, padding=1),
            'th1', nn.Tanh(),
        ]))

    def forward(self, x):
        x = torch.mul(self.label_emb(x), 
            torch.cuda.FloatTensor(
                np.random.normal(0, 1, (x.size()[0], latent_dim)))) # add noise, size = (batch, latent_dim)
        x = self.linear(x)
        x = x.view(x.size()[0], 128, img_size // 4, img_size // 4)
        x = self.conv(x)
        return x

In [50]:
generator = Generator().cuda()
gen_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, n_classes, batch_size)))
generator(gen_labels).size()

torch.Size([64, 100])

=========================

In [8]:
z = Variable(torch.cuda.FloatTensor(np.random.normal(0, 1, (64, 100))))
z.size()

torch.Size([64, 100])

In [11]:
gen_labels = Variable(torch.cuda.LongTensor(np.random.randint(0, n_classes, batch_size)))
gen_labels

tensor([3, 7, 5, 5, 4, 6, 5, 8, 5, 8, 8, 9, 6, 3, 7, 9, 9, 3, 5, 0, 4, 9, 3, 8,
        9, 2, 5, 7, 1, 9, 1, 3, 5, 1, 3, 1, 5, 5, 3, 8, 2, 9, 6, 7, 6, 9, 4, 2,
        1, 8, 2, 4, 4, 9, 6, 9, 2, 1, 8, 5, 8, 7, 4, 3], device='cuda:0')

In [16]:
label_emb = nn.Embedding(n_classes, latent_dim).cuda()

In [25]:
label_emb(gen_labels).size()

torch.Size([64, 100])

In [26]:
input = torch.mul(label_emb(gen_labels), z)
input.size()

torch.Size([64, 100])

In [24]:
label_emb(torch.LongTensor([3, 7, 5, 7, 4, 3]).cuda())

tensor([[ 2.7215e-02,  7.8519e-01, -8.8001e-01,  2.1255e-01, -2.4295e-01,
         -8.5597e-01, -5.4428e-01,  7.2043e-01, -2.3116e-01,  8.9193e-01,
         -1.9086e-01,  2.4456e-01,  2.9121e-01,  6.8968e-02,  3.4306e-01,
          1.3845e+00, -1.6111e+00, -1.1383e+00, -1.2362e+00,  5.9456e-02,
          4.4886e-01, -1.3046e+00,  1.8611e-01,  1.1382e+00, -1.1620e+00,
         -7.1407e-01, -6.5871e-01,  2.5875e-01,  6.2464e-01, -3.5525e-02,
         -6.2551e-01,  3.2627e-01,  1.4171e-01,  1.1860e-02,  1.2237e+00,
         -8.1045e-01, -8.0309e-01, -1.2119e+00,  8.5478e-01, -1.0928e+00,
          1.4775e+00, -2.8349e-01,  7.8462e-02, -1.0319e+00, -8.6210e-01,
         -7.1445e-02, -5.9968e-01, -5.7613e-01,  1.5089e-01,  4.9831e-01,
         -1.2920e+00,  1.4540e+00, -1.9555e+00,  2.5548e-01, -1.4829e+00,
         -3.2472e-02,  4.3744e-01,  1.5100e+00, -3.6293e-01,  2.1843e+00,
          2.5614e-01,  1.1126e+00, -4.4586e-01,  1.0466e+00, -1.0846e+00,
         -6.5934e-01,  5.5246e-01, -1.