Code taken from [this link](https://github.com/wiseodd/generative-models/blob/master/GAN/conditional_gan/cgan_pytorch.py)

In [10]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
import torchvision.datasets as dsets
import torchvision.transforms as transforms

In [69]:
train_data = dsets.MNIST(root="./datasets/", train=True, transform=transforms.ToTensor())

In [70]:
all_data = train_data.train_data.float().view(-1, 784)
print(all_data.size())

torch.Size([60000, 784])


In [72]:
labels = train_data.train_labels
print("labels size: %d" %(labels.size()))

labels size: 60000


In [73]:
#mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)
mb_size = 64
Z_dim = 100

In [117]:
#X_dim = mnist.train.images.shape[1]
X_dim = all_data.size(1)
#y_dim = mnist.train.labels.shape[1]
y_dim = labels.size(0)

h_dim = 128
cnt = 0
lr = 1e-3
epochs = 1
num_classes = 10

In [75]:
def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)

In [142]:
""" ==================== GENERATOR ======================== """

Wzh = xavier_init(size=[Z_dim + y_dim, h_dim])
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def G(z, c):
    print(z.size(), c.size())
    inputs = torch.cat([z, c], 1)
    h = nn.relu(inputs @ Wzh + bzh.repeat(inputs.size(0), 1))
    X = nn.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X

In [143]:
""" ==================== DISCRIMINATOR ======================== """

Wxh = xavier_init(size=[X_dim + y_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Why = xavier_init(size=[h_dim, 1])
bhy = Variable(torch.zeros(1), requires_grad=True)


def D(X, c):
    print(X.size(), c.size())
    inputs = torch.cat([X, c], 1)
    h = nn.relu(inputs @ Wxh + bxh.repeat(inputs.size(0), 1))
    y = nn.sigmoid(h @ Why + bhy.repeat(h.size(0), 1))
    return y

In [144]:
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
params = G_params + D_params

In [145]:
train_loader = torch.utils.data.DataLoader(dataset=train_data, shuffle=True, 
                                           batch_size=mb_size, drop_last=True)

In [146]:
(img, label) = iter(train_loader).next()

In [147]:
label.type(), label.size()

('torch.LongTensor', torch.Size([64]))

In [148]:
blah = label
blahz = Variable(torch.randn(mb_size, Z_dim))

In [149]:
eye = torch.eye(4)
blahl = torch.LongTensor([0, 2, 3])
blahl = eye[blahl]
blahl


 1  0  0  0
 0  0  1  0
 0  0  0  1
[torch.FloatTensor of size 3x4]

In [150]:
def one_hot_vector(label, num_classes):
    eye = torch.eye(num_classes)
    return eye[label]
    

In [152]:
""" ===================== TRAINING ======================== """


def reset_grad():
    for p in params:
        p.grad.data.zero_()


G_solver = optim.Adam(G_params, lr=1e-3)
D_solver = optim.Adam(D_params, lr=1e-3)

ones_label = Variable(torch.ones(mb_size))
zeros_label = Variable(torch.zeros(mb_size))


for epoch in range(epochs):
    for i, (img, label) in enumerate(train_loader):
        # Sample data
        z = Variable(torch.randn(mb_size, Z_dim))
        #X, c = mnist.train.next_batch(mb_size)
        X = img.view(mb_size, -1) # <batch_size, 784>
        c = one_hot_vector(label, num_classes)
        X = Variable(X)
        c = Variable(c)
        
        # Dicriminator forward-loss-backward-update
        G_sample = G(z, c)
        D_real = D(X, c)
        D_fake = D(G_sample, c)
        
        D_loss_real = nn.binary_cross_entropy(D_real, ones_label)
        D_loss_fake = nn.binary_cross_entropy(D_fake, zeros_label)
        D_loss = D_loss_real + D_loss_fake
        
        D_loss.backward()
        D_solver.step()
        
        # Housekeeping - reset gradient
        reset_grad()

        # Generator forward-loss-backward-update
        z = Variable(torch.randn(mb_size, Z_dim))
        G_sample = G(z, c)
        D_fake = D(G_sample, c)

        G_loss = nn.binary_cross_entropy(D_fake, ones_label)

        G_loss.backward()
        G_solver.step()

        # Housekeeping - reset gradient
        reset_grad()

        # Print and plot every now and then
        if it % 1000 == 0:
            print('Iter-{}; D_loss: {}; G_loss: {}'.format(it, D_loss.data.numpy(), G_loss.data.numpy()))

            c = np.zeros(shape=[mb_size, y_dim], dtype='float32')
            c[:, np.random.randint(0, 10)] = 1.
            c = Variable(torch.from_numpy(c))
            samples = G(z, c).data.numpy()[:16]

            fig = plt.figure(figsize=(4, 4))
            gs = gridspec.GridSpec(4, 4)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                ax = plt.subplot(gs[i])
                plt.axis('off')
                ax.set_xticklabels([])
                ax.set_yticklabels([])
                ax.set_aspect('equal')
                plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

            if not os.path.exists('out/'):
                os.makedirs('out/')

            plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
            cnt += 1
    plt.close(fig)

torch.Size([64, 100]) torch.Size([64, 10])


RuntimeError: size mismatch, m1: [64 x 110], m2: [60100 x 128] at d:\downloads\pytorch-master-1\torch\lib\th\generic/THTensorMath.c:1238