In [1]:
# -*- coding: utf-8 -*-
# @Author: aaronlai

import torch.nn as nn
import torch.nn.functional as F


class InfoGAN_Discriminator(nn.Module):

    def __init__(self, n_layer=3, n_conti=2, n_discrete=1,
                 num_category=10, use_gpu=False, featmap_dim=256,
                 n_channel=1):
        """
        InfoGAN Discriminator, have additional outputs for latent codes.
        Architecture brought from DCGAN.
        """
        super(InfoGAN_Discriminator, self).__init__()
        self.n_layer = n_layer
        self.n_conti = n_conti
        self.n_discrete = n_discrete
        self.num_category = num_category

        # Discriminator
        self.featmap_dim = featmap_dim

        convs = []
        BNs = []
        for layer in range(self.n_layer):
            if layer == (self.n_layer - 1):
                n_conv_in = n_channel
            else:
                n_conv_in = int(featmap_dim / (2**(layer + 1)))
            n_conv_out = int(featmap_dim / (2**layer))

            _conv = nn.Conv2d(n_conv_in, n_conv_out, kernel_size=5,
                              stride=2, padding=2)
            if use_gpu:
                _conv = _conv.cuda()
            convs.append(_conv)

            if layer != (self.n_layer - 1):
                _BN = nn.BatchNorm2d(n_conv_out)
                if use_gpu:
                    _BN = _BN.cuda()
                BNs.append(_BN)

        # output layer - prob(real) and auxiliary distributions Q(c_j|x)
        n_hidden = featmap_dim * 4 * 4
        n_output = 1 + n_conti + n_discrete * num_category
        self.fc = nn.Linear(n_hidden, n_output)

        # register all nn modules
        self.convs = nn.ModuleList(convs)
        self.BNs = nn.ModuleList(BNs)

    def forward(self, x):
        """
        Output the probability of being in real dataset
        plus the conditional distributions of latent codes.
        """
        for layer in range(self.n_layer):
            conv_layer = self.convs[self.n_layer - layer - 1]

            if layer == 0:
                x = F.leaky_relu(conv_layer(x), negative_slope=0.2)
            else:
                BN_layer = self.BNs[self.n_layer - layer - 1]
                x = F.leaky_relu(BN_layer(conv_layer(x)), negative_slope=0.2)

        x = x.view(-1, self.featmap_dim * 4 * 4)

        # output layer
        x = self.fc(x)
        x[:, 0] = F.sigmoid(x[:, 0].clone())
        for j in range(self.n_discrete):
            start = 1 + self.n_conti + j * self.num_category
            end = start + self.num_category
            x[:, start:end] = F.softmax(x[:, start:end].clone())

        return x


class InfoGAN_Generator(nn.Module):

    def __init__(self, noise_dim=10, n_layer=3, n_conti=2, n_discrete=1,
                 num_category=10, use_gpu=False, featmap_dim=256, n_channel=1):
        """
        InfoGAN Generator, have an additional input branch for latent codes.
        Architecture brought from DCGAN.
        """
        super(InfoGAN_Generator, self).__init__()
        self.n_layer = n_layer
        self.n_conti = n_conti
        self.n_discrete = n_discrete
        self.num_category = num_category

        # calculate input dimension
        n_input = noise_dim + n_conti + n_discrete * num_category

        # Generator
        self.featmap_dim = featmap_dim
        self.fc_in = nn.Linear(n_input, featmap_dim * 4 * 4)

        convs = []
        BNs = []
        for layer in range(self.n_layer):
            if layer == 0:
                n_conv_out = n_channel
            else:
                n_conv_out = featmap_dim / (2 ** (self.n_layer - layer))

            n_conv_in = featmap_dim / (2 ** (self.n_layer - layer - 1))
            n_width = 5 if layer == (self.n_layer - 1) else 6

            _conv = nn.ConvTranspose2d(n_conv_in, n_conv_out, n_width,
                                       stride=2, padding=2)

            if use_gpu:
                _conv = _conv.cuda()
            convs.append(_conv)

            if layer != 0:
                _BN = nn.BatchNorm2d(n_conv_out)
                if use_gpu:
                    _BN = _BN.cuda()
                BNs.append(_BN)

        # register all nn modules
        self.convs = nn.ModuleList(convs)
        self.BNs = nn.ModuleList(BNs)

    def forward(self, x):
        """
        Input the random noise plus latent codes to generate fake images.
        """
        x = self.fc_in(x)
        x = x.view(-1, self.featmap_dim, 4, 4)

        for layer in range(self.n_layer):
            conv_layer = self.convs[self.n_layer - layer - 1]
            if layer == (self.n_layer - 1):
                x = F.tanh(conv_layer(x))
            else:
                BN_layer = self.BNs[self.n_layer - layer - 2]
                x = F.relu(BN_layer(conv_layer(x)))

        return x


In [2]:
# -*- coding: utf-8 -*-
# @Author: aaronlai

import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np

from torch.autograd import Variable


In [3]:
def load_dataset(batch_size=10, download=True):
    """
    The output of torchvision datasets are PILImage images of range [0, 1].
    Transform them to Tensors of normalized range [-1, 1]
    """
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.MNIST(root='../data', train=True,
                                          download=download,
                                          transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='../data', train=False,
                                         download=download,
                                         transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    return trainloader, testloader


def gen_noise(n_instance, n_dim=2):
    """generate n-dim uniform random noise"""
    return torch.Tensor(np.random.uniform(low=-1.0, high=1.0,
                                          size=(n_instance, n_dim)))


def gen_conti_codes(n_instance, n_conti, mean=0, std=1):
    """generate gaussian continuous codes with specified mean and std"""
    codes = np.random.randn(n_instance, n_conti) * std + mean
    return torch.Tensor(codes)


def gen_discrete_code(n_instance, n_discrete, num_category=10):
    """generate discrete codes with n categories"""
    codes = []
    for i in range(n_discrete):
        code = np.zeros((n_instance, num_category))
        random_cate = np.random.randint(0, num_category, n_instance)
        code[range(n_instance), random_cate] = 1
        codes.append(code)

    codes = np.concatenate(codes, 1)
    return torch.Tensor(codes)


In [5]:
def train_InfoGAN(InfoGAN_Dis, InfoGAN_Gen, D_criterion, G_criterion,
                  D_optimizer, G_optimizer, info_reg_discrete, info_reg_conti,
                  n_conti, n_discrete, mean, std, num_category, trainloader,
                  n_epoch, batch_size, noise_dim,
                  n_update_dis=1, n_update_gen=1, use_gpu=False,
                  print_every=50, update_max=None):
    """train InfoGAN and print out the losses for D and G"""

    for epoch in range(n_epoch):

        D_running_loss = 0.0
        G_running_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            # get the inputs from true distribution
            true_inputs, lab = data
            true_inputs = Variable(true_inputs)
            if use_gpu:
                true_inputs = true_inputs.cuda()

            # get inputs (noises and codes) for Generator
            noises = Variable(gen_noise(batch_size, n_dim=noise_dim))
            conti_codes = Variable(gen_conti_codes(batch_size, n_conti,
                                                   mean, std))
            discr_codes = Variable(gen_discrete_code(batch_size, n_discrete,
                                                     num_category))
            print (noises, conti_codes, discr_codes)
            if use_gpu:
                noises = noises.cuda()
                conti_codes = conti_codes.cuda()
                discr_codes = discr_codes.cuda()
                

            # generate fake images
            gen_inputs = torch.cat((noises, conti_codes, discr_codes), 1)
            fake_inputs = InfoGAN_Gen(gen_inputs)
            
            vutils.save_image(fake_inputs.data, 'fake_inputs.png',nrow=10)
            
            inputs = torch.cat([true_inputs, fake_inputs])

            # make a minibatch of labels
            labels = np.zeros(2 * batch_size)
            labels[:batch_size] = 1
            labels = torch.from_numpy(labels.astype(np.float32))
            if use_gpu:
                labels = labels.cuda()
            labels = Variable(labels)

            # Discriminator
            D_optimizer.zero_grad()
            outputs = InfoGAN_Dis(inputs)

            # calculate mutual information lower bound L(G, Q)
            for j in range(n_discrete):
                shift = (j * num_category)
                start = 1 + n_conti + shift
                end = start + num_category
                Q_cx_discr = outputs[batch_size:, start:end]
                codes = discr_codes[:, shift:(shift+num_category)]
                condi_entro = -torch.mean(torch.sum(Q_cx_discr * codes, 1))

                if j == 0:
                    L_discrete = -condi_entro
                else:
                    L_discrete -= condi_entro
            L_discrete /= n_discrete

            Q_cx_conti = outputs[batch_size:, 1:(1 + n_conti)]
            L_conti = torch.mean(-(((Q_cx_conti - mean) / std) ** 2))

            # Update Discriminator
            D_loss = D_criterion(outputs[:, 0], labels)
            if n_discrete > 0:
                D_loss = D_loss - info_reg_discrete * L_discrete

            if n_conti > 0:
                D_loss = D_loss - info_reg_conti * L_conti

            if i % n_update_dis == 0:
                D_loss.backward(retain_variables=True)
                D_optimizer.step()

            # Update Generator
            if i % n_update_gen == 0:
                G_optimizer.zero_grad()
                G_loss = G_criterion(outputs[batch_size:, 0],
                                     labels[:batch_size])

                if n_discrete > 0:
                    G_loss = G_loss - info_reg_discrete * L_discrete

                if n_conti > 0:
                    G_loss = G_loss - info_reg_conti * L_conti

                G_loss.backward()
                G_optimizer.step()

            # print statistics
            D_running_loss += D_loss.data[0]
            G_running_loss += G_loss.data[0]
            if i % print_every == (print_every - 1):
                print('[%d, %5d] D loss: %.3f ; G loss: %.3f' %
                      (epoch+1, i+1, D_running_loss / print_every,
                       G_running_loss / print_every))
                D_running_loss = 0.0
                G_running_loss = 0.0
                
                _c3_fix = Variable(torch.randn(1,2).cuda().uniform_(-1,1))
                for q in range(0,10):
                    _c1 = np.zeros((1,10),dtype = np.float32)
                    _c1[0,q] = 1
                    _c1 = Variable(torch.Tensor(_c1).cuda())
                    noise = torch.cat([_c1,_c2_fix,_c3_fix,z_fix],1)
                    G_sample = netG(noise)
                    storage[k*10+q] = G_sample.data.cpu().numpy()


            if update_max and i > update_max:
                break

    print('Finished Training')


def run_InfoGAN(info_reg_discrete=1.0, info_reg_conti=0.5, noise_dim=10,
                n_conti=2, n_discrete=1, mean=0.0, std=0.5, num_category=10,
                n_layer=3, n_channel=1, D_featmap_dim=256, G_featmap_dim=1024,
                n_epoch=2, batch_size=50, use_gpu=False, dis_lr=1e-4,
                gen_lr=1e-3, n_update_dis=1, n_update_gen=1, update_max=None):
    # loading data
    trainloader, testloader = load_dataset(batch_size=batch_size)

    # initialize models
    InfoGAN_Dis = InfoGAN_Discriminator(n_layer, n_conti, n_discrete,
                                        num_category, use_gpu, D_featmap_dim,
                                        n_channel)

    InfoGAN_Gen = InfoGAN_Generator(noise_dim, n_layer, n_conti, n_discrete,
                                    num_category, use_gpu, G_featmap_dim,
                                    n_channel)

    if use_gpu:
        InfoGAN_Dis = InfoGAN_Dis.cuda()
        InfoGAN_Gen = InfoGAN_Gen.cuda()

    # assign loss function and optimizer (Adam) to D and G
    D_criterion = torch.nn.BCELoss()
    D_optimizer = optim.Adam(InfoGAN_Dis.parameters(), lr=dis_lr,
                             betas=(0.5, 0.999))

    G_criterion = torch.nn.BCELoss()
    G_optimizer = optim.Adam(InfoGAN_Gen.parameters(), lr=gen_lr,
                             betas=(0.5, 0.999))

    train_InfoGAN(InfoGAN_Dis, InfoGAN_Gen, D_criterion, G_criterion,
                  D_optimizer, G_optimizer, info_reg_discrete, info_reg_conti,
                  n_conti, n_discrete, mean, std, num_category, trainloader,
                  n_epoch, batch_size, noise_dim,
                  n_update_dis, n_update_gen, use_gpu, update_max=update_max)


if __name__ == '__main__':
    run_InfoGAN(n_conti=2, n_discrete=1, D_featmap_dim=64, G_featmap_dim=128,
                n_epoch=1, batch_size=10, update_max=200)


(Variable containing:
-0.5619 -0.5759 -0.0501  0.7748  0.8173 -0.2253 -0.2041  0.1070  0.9425  0.2985
-0.7721 -0.5022 -0.7415  0.4669  0.2111 -0.4062 -0.4161 -0.0211  0.4009  0.6022
 0.8112 -0.1573 -0.4208 -0.8612 -0.6203  0.4121  0.4028 -0.0320 -0.8657 -0.5539
-0.2803  0.0650  0.6164  0.3515 -0.1804 -0.6027 -0.2142 -0.4378  0.0115  0.4682
 0.8992 -0.0747 -0.3857 -0.2657 -0.0587  0.9938  0.4232  0.5902 -0.3354  0.1744
 0.3869  0.7138 -0.4790  0.7496  0.3506  0.6166  0.4613  0.7808 -0.2872  0.0924
 0.1025  0.1222  0.4651 -0.2540 -0.0238  0.7983 -0.0792 -0.4399  0.0748 -0.2969
 0.7205 -0.2320 -0.0464 -0.1224 -0.9998 -0.3553  0.2423 -0.5924  0.0665  0.4739
-0.9500  0.8510  0.7272 -0.7940  0.6544  0.4640 -0.6919 -0.1258 -0.3671  0.6257
 0.1278 -0.7626  0.6199 -0.0241 -0.2509 -0.7392 -0.2772  0.8582  0.1457  0.2488
[torch.FloatTensor of size 10x10]
, Variable containing:
 0.8891  0.2308
-0.0382 -0.0731
-0.7821  0.3871
-0.1851 -1.0973
-0.4095  0.3323
-0.4840 -0.5791
-1.2121  0.2350
-0.6005  

Process Process-3:
Process Process-4:
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
  File "/usr/lib/python2.7/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
    self.run()
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/lib/python2.7/multiprocessing/process.py", line 114, in run
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 35, in _worker_loop
    self._target(*self._args, **self._kwargs)
    r = index_queue.get()
  File "/usr/lib/python2.7/multiprocessing/queues.py", line 378, in get
  File "/usr/local/lib/python2.7/dist-packages/torch/utils/data/dataloader.py", line 35, in _worker_loop
    return recv()
  File "/usr/local/lib/python2.7/dist-packages/torch/multiprocessing/queue.py", line 21, in recv
    r = index_queue.get()
    buf = self.

(Variable containing:
 0.1499 -0.4778  0.9060 -0.4313 -0.9350 -0.3977  0.3440  0.5113 -0.6693 -0.0150
 0.0748 -0.9997 -0.2833  0.0848 -0.8306 -0.6828  0.2618  0.3055 -0.0795  0.8585
 0.8174 -0.2803 -0.8764 -0.6650  0.9967 -0.1577  0.6745  0.8956 -0.7304 -0.9721
 0.7480 -0.5349  0.8766  0.4356  0.4382 -0.6811  0.0195 -0.5709 -0.3006  0.5314
-0.1940  0.2335  0.2326  0.6668  0.3854  0.9003  0.2779  0.6885  0.4033 -0.0541
-0.4044 -0.8425 -0.4944  0.2804  0.7920  0.4238 -0.2951 -0.8358 -0.8103 -0.2369
-0.2043  0.7068  0.0413 -0.3494 -0.0759 -0.9432 -0.6595  0.1829 -0.1114  0.9322
-0.4580  0.5508 -0.3535 -0.6602 -0.4037 -0.6364  0.0264  0.7172 -0.8352  0.1016
 0.9250 -0.1507  0.7303  0.4110  0.5430 -0.8716 -0.7607  0.2870 -0.9677  0.7428
-0.2300  0.7895  0.6782  0.8749  0.9842 -0.5989  0.5146  0.3075  0.9180 -0.7334
[torch.FloatTensor of size 10x10]
, Variable containing:
 0.4841 -0.8371
-0.2425  0.8418
 0.3397  0.0184
 1.0742 -0.1434
-0.7744  0.9786
-0.0042 -0.9155
 0.6209 -0.0045
-0.6693 -

KeyboardInterrupt: 