In [1]:

%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import warnings
warnings.filterwarnings('ignore')

In [2]:
import torch
import torch.nn.functional as nn
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import os
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data


  from ._conv import register_converters as _register_converters


In [3]:
# Read the MNIST data sets from the local data directory, with one-hot labels.
mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

# Dimensions taken from the training split: features per image and label width.
X_dim, y_dim = mnist.train.images.shape[1], mnist.train.labels.shape[1]

# Model / training hyper-parameters.
mb_size = 32   # Minibatch size
Z_dim = 16     # Random noise input for generator
h_dim = 128    # Hidden layer dimension
lr = 1e-3      # Learning rate

# Running index used to number the sample images written during training.
cnt = 0


Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../MNIST_data\train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting ../../MNIST_data\train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting ../../MNIST_data\t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data\t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.


In [4]:
def xavier_init(size):
    """Create a trainable weight tensor drawn from a scaled normal distribution.

    Args:
        size: sequence of dimensions for the tensor; ``size[0]`` is treated
            as the fan-in.

    Returns:
        A tensor of shape ``size`` with ``requires_grad=True`` whose entries
        are standard-normal samples multiplied by ``1 / sqrt(size[0] / 2)``.

    Note: this scale equals ``sqrt(2 / fan_in)``, which is the scaling
    usually associated with He initialisation rather than Xavier, despite
    the function name.
    """
    fan_in = size[0]
    std = 1. / np.sqrt(fan_in / 2.)
    noise = torch.randn(*size)
    return Variable(noise * std, requires_grad=True)


 # ##  ==================== GENERATOR ========================


In [5]:
""" ==================== GENERATOR ======================== """

# Generator parameters (one hidden layer).
# The first layer consumes the noise vector concatenated with the 10-wide
# latent code, hence Z_dim + 10 input rows.
Wzh = xavier_init([Z_dim + 10, h_dim])
Whx = xavier_init([h_dim, X_dim])

# Biases start at zero and are trained along with the weights.
bzh = Variable(torch.zeros(h_dim), requires_grad=True)
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def G(z, c):
    """Generator: map noise and a latent code to an output in (0, 1).

    Args:
        z: noise batch, one row per example.
        c: latent-code batch with the same number of rows as ``z``.

    Returns:
        Tensor with one row per example and ``Whx.shape[1]`` columns, squashed
        through a sigmoid.
    """
    # Noise and code are concatenated column-wise into a single input row.
    inputs = torch.cat([z, c], 1)
    # The 1-D biases broadcast over the batch dimension, so no repeat() is needed.
    h = nn.relu(inputs @ Wzh + bzh)
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    X = torch.sigmoid(h @ Whx + bhx)
    return X


#  ==================== DISCRIMINATOR ========================


In [6]:

# Discriminator parameters (one hidden layer, single output unit).
Wxh = xavier_init([X_dim, h_dim])
Why = xavier_init([h_dim, 1])

# Biases start at zero and are trained along with the weights.
bxh = Variable(torch.zeros(h_dim), requires_grad=True)
bhy = Variable(torch.zeros(1), requires_grad=True)


def D(X):
    """Discriminator: score each input row with a value in (0, 1).

    Args:
        X: input batch, one row per example.

    Returns:
        Tensor with one row per example and ``Why.shape[1]`` columns holding
        the sigmoid output; the training loss treats it as the probability
        that the row is real.
    """
    # The 1-D biases broadcast over the batch dimension, so no repeat() is needed.
    h = nn.relu(X @ Wxh + bxh)
    # torch.sigmoid replaces the deprecated torch.nn.functional.sigmoid.
    y = torch.sigmoid(h @ Why + bhy)
    return y


### ====================== Q(c|X) ========================== 


In [7]:

# Q-network parameters (one hidden layer, 10 outputs -- one per code category).
Wqxh = xavier_init([X_dim, h_dim])
Whc = xavier_init([h_dim, 10])

# Biases start at zero and are trained along with the weights.
bqxh = Variable(torch.zeros(h_dim), requires_grad=True)
bhc = Variable(torch.zeros(10), requires_grad=True)


def Q(X):
    """Q network: predict a categorical distribution over the latent code.

    Args:
        X: input batch, one row per example.

    Returns:
        Tensor with one row per example and ``Whc.shape[1]`` columns; each row
        is a softmax distribution (non-negative, sums to 1).
    """
    # The 1-D biases broadcast over the batch dimension, so no repeat() is needed.
    h = nn.relu(X @ Wqxh + bqxh)
    # Normalise across the category axis explicitly; calling softmax without
    # `dim` relies on a deprecated implicit choice.
    c = nn.softmax(h @ Whc + bhc, dim=1)
    return c


# Parameter groups, one per network, in [hidden W, hidden b, output W, output b] order.
G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
Q_params = [Wqxh, bqxh, Whc, bhc]

# Flat list of every trainable tensor; reset_grad() walks this list.
params = [*G_params, *D_params, *Q_params]


In [8]:
""" ===================== TRAINING ======================== """


def reset_grad():
    """Replace every populated gradient in ``params`` with a fresh zero tensor.

    Parameters whose ``grad`` is still ``None`` (never back-propagated into)
    are left untouched.
    """
    for param in params:
        if param.grad is None:
            continue
        grad_data = param.grad.data
        # A new tensor of the same type and shape as the old gradient, zero-filled.
        zeros = grad_data.new().resize_as_(grad_data).zero_()
        param.grad = Variable(zeros)


# One Adam optimizer per training phase. All of them use the learning rate
# `lr` from the configuration cell instead of repeating the literal value.
G_solver = optim.Adam(G_params, lr=lr)
D_solver = optim.Adam(D_params, lr=lr)
# The Q step updates the generator together with the Q network.
Q_solver = optim.Adam(G_params + Q_params, lr=lr)


### Sample the latent code c from a uniform categorical distribution over ten categories.
# Each category has probability 0.1, and every sample is returned as a one-hot vector of length 10.
# At the start no category carries any meaning. After training, we expect InfoGAN to tie each category
# to a distinct kind of content, ideally one category per digit (like a one-hot class label).

def sample_c(size):
    """Draw one-hot latent codes from a uniform 10-way categorical distribution.

    Args:
        size: number of codes to draw (forwarded as ``size`` to
            ``np.random.multinomial``).

    Returns:
        float32 tensor whose last dimension has length 10, with exactly one
        entry set to 1 in each code.
    """
    uniform_probs = [0.1] * 10
    one_hot = np.random.multinomial(1, uniform_probs, size=size).astype('float32')
    return Variable(torch.from_numpy(one_hot))
### We start with sample_c



## 100000

n_iters = 100000    # Total number of training iterations
log_every = 1000    # Print losses and save a sample grid every this many iterations

for it in range(n_iters):

    # ------------------------------ Sample data ------------------------------
    X, _ = mnist.train.next_batch(mb_size)       # Real images; the labels are not used
    X = Variable(torch.from_numpy(X))

    z = Variable(torch.randn(mb_size, Z_dim))    # (mb_size, Z_dim) standard-normal noise
    c = sample_c(mb_size)                        # (mb_size, 10) one-hot codes, uniform over categories

    # --------------------------- Discriminator step ---------------------------
    # Only D_solver steps here, so only the discriminator is updated.
    G_sample = G(z, c)                           # Fake images from noise z and code c
    D_real = D(X)                                # Scores for real images (target: 1)
    # detach(): the generator is not updated in this step, so do not
    # back-propagate through it; D's own gradients are unaffected.
    D_fake = D(G_sample.detach())                # Scores for fake images (target: 0)

    # 1e-8 keeps log() away from log(0).
    D_loss = -torch.mean(torch.log(D_real + 1e-8) + torch.log(1 - D_fake + 1e-8))

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # ----------------------------- Generator step -----------------------------
    # Only G_solver steps here; gradients that reach D are discarded by reset_grad().
    G_sample = G(z, c)
    D_fake = D(G_sample)

    # The generator wants its fakes to be scored as real (1).
    G_loss = -torch.mean(torch.log(D_fake + 1e-8))

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # --------------------------------- Q step ---------------------------------
    # Q_solver updates the generator and Q together so that the code c can be
    # recovered from the generated image.
    G_sample = G(z, c)
    Q_c_given_x = Q(G_sample)                    # Predicted distribution over the code

    # Cross-entropy between the sampled code and Q's prediction.
    crossent_loss = torch.mean(-torch.sum(c * torch.log(Q_c_given_x + 1e-8), dim=1))
    mi_loss = crossent_loss

    mi_loss.backward()
    Q_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # -------------------- Print and plot every now and then --------------------
    if it % log_every == 0:
        # Fix the code to one random category for the whole batch so the saved
        # grid shows what that category generates. A separate name is used so
        # the training code `c` is not overwritten.
        idx = np.random.randint(0, 10)
        c_fixed = np.zeros([mb_size, 10], dtype='float32')
        c_fixed[:, idx] = 1
        c_fixed = Variable(torch.from_numpy(c_fixed))
        samples = G(z, c_fixed).data.numpy()[:16]

        print('Iter-{}; D_loss: {}; G_loss: {}; Idx: {}'
              .format(it, D_loss.data.numpy(), G_loss.data.numpy(), idx))

        # 4x4 grid of generated samples, each reshaped to 28x28.
        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        # exist_ok avoids the separate exists() check.
        os.makedirs('out/', exist_ok=True)

        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)+"label_"+str(idx)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)   # Free the figure; many are created over a long run



Iter-0; D_loss: 1.4468754529953003; G_loss: 2.042823314666748; Idx: 8
Iter-1000; D_loss: 0.15183107554912567; G_loss: 4.062690258026123; Idx: 8
Iter-2000; D_loss: 0.12419113516807556; G_loss: 3.6387805938720703; Idx: 5
Iter-3000; D_loss: 0.15582631528377533; G_loss: 4.37735652923584; Idx: 6
Iter-4000; D_loss: 0.14815331995487213; G_loss: 3.676138401031494; Idx: 3
Iter-5000; D_loss: 0.508362352848053; G_loss: 4.521950721740723; Idx: 9
Iter-6000; D_loss: 0.5160889029502869; G_loss: 2.9797861576080322; Idx: 5
Iter-7000; D_loss: 0.4400549530982971; G_loss: 2.8511452674865723; Idx: 7
Iter-8000; D_loss: 0.8536098003387451; G_loss: 2.386115312576294; Idx: 5
Iter-9000; D_loss: 1.0276927947998047; G_loss: 2.3005568981170654; Idx: 8
Iter-10000; D_loss: 0.6924546360969543; G_loss: 1.9490838050842285; Idx: 2
Iter-11000; D_loss: 0.5206872224807739; G_loss: 2.314621925354004; Idx: 0
Iter-12000; D_loss: 0.5898966789245605; G_loss: 2.147003173828125; Idx: 2
Iter-13000; D_loss: 0.6004413962364197; G_lo

KeyboardInterrupt: 