In [1]:
# -*- coding: utf-8 -*-
# @Author: aaronlai

import torch.nn as nn
import torch.nn.functional as F


class InfoGAN_Discriminator(nn.Module):

    def __init__(self, n_layer=3, n_conti=2, n_discrete=1,
                 num_category=10, use_gpu=False, featmap_dim=256,
                 n_channel=1):
        """
        InfoGAN Discriminator, have additional outputs for latent codes.
        Architecture brought from DCGAN.
        """
        super(InfoGAN_Discriminator, self).__init__()
        self.n_layer = n_layer
        self.n_conti = n_conti
        self.n_discrete = n_discrete
        self.num_category = num_category

        # Discriminator
        self.featmap_dim = featmap_dim

        convs = []
        BNs = []
        for layer in range(self.n_layer):
            if layer == (self.n_layer - 1):
                n_conv_in = n_channel
            else:
                n_conv_in = int(featmap_dim / (2**(layer + 1)))
            n_conv_out = int(featmap_dim / (2**layer))

            _conv = nn.Conv2d(n_conv_in, n_conv_out, kernel_size=5,
                              stride=2, padding=2)
            if use_gpu:
                _conv = _conv.cuda()
            convs.append(_conv)

            if layer != (self.n_layer - 1):
                _BN = nn.BatchNorm2d(n_conv_out)
                if use_gpu:
                    _BN = _BN.cuda()
                BNs.append(_BN)

        # output layer - prob(real) and auxiliary distributions Q(c_j|x)
        n_hidden = featmap_dim * 4 * 4
        n_output = 1 + n_conti + n_discrete * num_category
        self.fc = nn.Linear(n_hidden, n_output)

        # register all nn modules
        self.convs = nn.ModuleList(convs)
        self.BNs = nn.ModuleList(BNs)

    def forward(self, x):
        """
        Output the probability of being in real dataset
        plus the conditional distributions of latent codes.
        """
        for layer in range(self.n_layer):
            conv_layer = self.convs[self.n_layer - layer - 1]

            if layer == 0:
                x = F.leaky_relu(conv_layer(x), negative_slope=0.2)
            else:
                BN_layer = self.BNs[self.n_layer - layer - 1]
                x = F.leaky_relu(BN_layer(conv_layer(x)), negative_slope=0.2)

        x = x.view(-1, self.featmap_dim * 4 * 4)

        # output layer
        x = self.fc(x)
        x[:, 0] = F.sigmoid(x[:, 0].clone())
        for j in range(self.n_discrete):
            start = 1 + self.n_conti + j * self.num_category
            end = start + self.num_category
            x[:, start:end] = F.softmax(x[:, start:end].clone())

        return x


class InfoGAN_Generator(nn.Module):

    def __init__(self, noise_dim=10, n_layer=3, n_conti=2, n_discrete=1,
                 num_category=10, use_gpu=False, featmap_dim=256, n_channel=1):
        """
        InfoGAN Generator, have an additional input branch for latent codes.
        Architecture brought from DCGAN.
        """
        super(InfoGAN_Generator, self).__init__()
        self.n_layer = n_layer
        self.n_conti = n_conti
        self.n_discrete = n_discrete
        self.num_category = num_category

        # calculate input dimension
        n_input = noise_dim + n_conti + n_discrete * num_category

        # Generator
        self.featmap_dim = featmap_dim
        self.fc_in = nn.Linear(n_input, featmap_dim * 4 * 4)

        convs = []
        BNs = []
        for layer in range(self.n_layer):
            if layer == 0:
                n_conv_out = n_channel
            else:
                n_conv_out = featmap_dim / (2 ** (self.n_layer - layer))

            n_conv_in = featmap_dim / (2 ** (self.n_layer - layer - 1))
            n_width = 5 if layer == (self.n_layer - 1) else 6

            _conv = nn.ConvTranspose2d(n_conv_in, n_conv_out, n_width,
                                       stride=2, padding=2)

            if use_gpu:
                _conv = _conv.cuda()
            convs.append(_conv)

            if layer != 0:
                _BN = nn.BatchNorm2d(n_conv_out)
                if use_gpu:
                    _BN = _BN.cuda()
                BNs.append(_BN)

        # register all nn modules
        self.convs = nn.ModuleList(convs)
        self.BNs = nn.ModuleList(BNs)

    def forward(self, x):
        """
        Input the random noise plus latent codes to generate fake images.
        """
        x = self.fc_in(x)
        x = x.view(-1, self.featmap_dim, 4, 4)

        for layer in range(self.n_layer):
            conv_layer = self.convs[self.n_layer - layer - 1]
            if layer == (self.n_layer - 1):
                x = F.tanh(conv_layer(x))
            else:
                BN_layer = self.BNs[self.n_layer - layer - 2]
                x = F.relu(BN_layer(conv_layer(x)))

        return x


In [None]:
# -*- coding: utf-8 -*-
# @Author: aaronlai

import torch
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torchvision.utils as vutils
from torch.autograd import Variable


def load_dataset(batch_size=10, download=True):
    """
    The output of torchvision datasets are PILImage images of range [0, 1].
    Transform them to Tensors of normalized range [-1, 1]
    """
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5),
                                                         (0.5, 0.5, 0.5))])
    trainset = torchvision.datasets.MNIST(root='../data', train=True,
                                          download=download,
                                          transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                              shuffle=True, num_workers=2)

    testset = torchvision.datasets.MNIST(root='../data', train=False,
                                         download=download,
                                         transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                             shuffle=False, num_workers=2)

    return trainloader, testloader


def gen_noise(n_instance, n_dim=2):
    """generate n-dim uniform random noise"""
    return torch.Tensor(np.random.uniform(low=-1.0, high=1.0,
                                          size=(n_instance, n_dim)))


def gen_conti_codes(n_instance, n_conti, mean=0, std=1):
    """generate gaussian continuous codes with specified mean and std"""
    codes = np.random.randn(n_instance, n_conti) * std + mean
    return torch.Tensor(codes)


def gen_discrete_code(n_instance, n_discrete, num_category=10):
    """generate discrete codes with n categories"""
    codes = []
    for i in range(n_discrete):
        code = np.zeros((n_instance, num_category))
        random_cate = np.random.randint(0, num_category, n_instance)
        code[range(n_instance), random_cate] = 1
        codes.append(code)

    codes = np.concatenate(codes, 1)
    return torch.Tensor(codes)


def train_InfoGAN(InfoGAN_Dis, InfoGAN_Gen, D_criterion, G_criterion,
                  D_optimizer, G_optimizer, info_reg_discrete, info_reg_conti,
                  n_conti, n_discrete, mean, std, num_category, trainloader,
                  n_epoch, batch_size, noise_dim,
                  n_update_dis=1, n_update_gen=1, use_gpu=False,
                  print_every=50, update_max=None):
    """train InfoGAN and print out the losses for D and G"""

    for epoch in range(n_epoch):

        D_running_loss = 0.0
        G_running_loss = 0.0

        for i, data in enumerate(trainloader, 0):
            # get the inputs from true distribution
            true_inputs, lab = data
            true_inputs = Variable(true_inputs)
            if use_gpu:
                true_inputs = true_inputs.cuda()

            # get inputs (noises and codes) for Generator
            noises = Variable(gen_noise(batch_size, n_dim=noise_dim))
            conti_codes = Variable(gen_conti_codes(batch_size, n_conti,
                                                   mean, std))
            discr_codes = Variable(gen_discrete_code(batch_size, n_discrete,
                                                     num_category))
            if use_gpu:
                noises = noises.cuda()
                conti_codes = conti_codes.cuda()
                discr_codes = discr_codes.cuda()

            # generate fake images
            gen_inputs = torch.cat((noises, conti_codes, discr_codes), 1)
            fake_inputs = InfoGAN_Gen(gen_inputs)
            vutils.save_image(fake_inputs.data, 'fake_samples.png')
            
            inputs = torch.cat([true_inputs, fake_inputs])

            # make a minibatch of labels
            labels = np.zeros(2 * batch_size)
            labels[:batch_size] = 1
            labels = torch.from_numpy(labels.astype(np.float32))
            if use_gpu:
                labels = labels.cuda()
            labels = Variable(labels)

            # Discriminator
            D_optimizer.zero_grad()
            outputs = InfoGAN_Dis(inputs)

            # calculate mutual information lower bound L(G, Q)
            for j in range(n_discrete):
                shift = (j * num_category)
                start = 1 + n_conti + shift
                end = start + num_category
                Q_cx_discr = outputs[batch_size:, start:end]
                codes = discr_codes[:, shift:(shift+num_category)]
                condi_entro = -torch.mean(torch.sum(Q_cx_discr * codes, 1))

                if j == 0:
                    L_discrete = -condi_entro
                else:
                    L_discrete -= condi_entro
            L_discrete /= n_discrete

            Q_cx_conti = outputs[batch_size:, 1:(1 + n_conti)]
            L_conti = torch.mean(-(((Q_cx_conti - mean) / std) ** 2))

            # Update Discriminator
            D_loss = D_criterion(outputs[:, 0], labels)
            if n_discrete > 0:
                D_loss = D_loss - info_reg_discrete * L_discrete

            if n_conti > 0:
                D_loss = D_loss - info_reg_conti * L_conti

            if i % n_update_dis == 0:
                D_loss.backward(retain_variables=True)
                D_optimizer.step()

            # Update Generator
            if i % n_update_gen == 0:
                G_optimizer.zero_grad()
                G_loss = G_criterion(outputs[batch_size:, 0],
                                     labels[:batch_size])

                if n_discrete > 0:
                    G_loss = G_loss - info_reg_discrete * L_discrete

                if n_conti > 0:
                    G_loss = G_loss - info_reg_conti * L_conti

                G_loss.backward()
                G_optimizer.step()

            # print statistics
            D_running_loss += D_loss.data[0]
            G_running_loss += G_loss.data[0]
            if i % print_every == (print_every - 1):
                print('[%d, %5d] D loss: %.3f ; G loss: %.3f' %
                      (epoch+1, i+1, D_running_loss / print_every,
                       G_running_loss / print_every))
                D_running_loss = 0.0
                G_running_loss = 0.0

            if update_max and i > update_max:
                break

    print('Finished Training')


def run_InfoGAN(info_reg_discrete=1.0, info_reg_conti=0.5, noise_dim=10,
                n_conti=2, n_discrete=1, mean=0.0, std=0.5, num_category=10,
                n_layer=3, n_channel=1, D_featmap_dim=256, G_featmap_dim=1024,
                n_epoch=2, batch_size=50, use_gpu=False, dis_lr=1e-4,
                gen_lr=1e-3, n_update_dis=1, n_update_gen=1, update_max=None):
    # loading data
    trainloader, testloader = load_dataset(batch_size=batch_size)

    # initialize models
    InfoGAN_Dis = InfoGAN_Discriminator(n_layer, n_conti, n_discrete,
                                        num_category, use_gpu, D_featmap_dim,
                                        n_channel)

    InfoGAN_Gen = InfoGAN_Generator(noise_dim, n_layer, n_conti, n_discrete,
                                    num_category, use_gpu, G_featmap_dim,
                                    n_channel)

    if use_gpu:
        InfoGAN_Dis = InfoGAN_Dis.cuda()
        InfoGAN_Gen = InfoGAN_Gen.cuda()

    # assign loss function and optimizer (Adam) to D and G
    D_criterion = torch.nn.BCELoss()
    D_optimizer = optim.Adam(InfoGAN_Dis.parameters(), lr=dis_lr,
                             betas=(0.5, 0.999))

    G_criterion = torch.nn.BCELoss()
    G_optimizer = optim.Adam(InfoGAN_Gen.parameters(), lr=gen_lr,
                             betas=(0.5, 0.999))

    train_InfoGAN(InfoGAN_Dis, InfoGAN_Gen, D_criterion, G_criterion,
                  D_optimizer, G_optimizer, info_reg_discrete, info_reg_conti,
                  n_conti, n_discrete, mean, std, num_category, trainloader,
                  n_epoch, batch_size, noise_dim,
                  n_update_dis, n_update_gen, use_gpu, update_max=update_max)


if __name__ == '__main__':
    run_InfoGAN(n_conti=2, n_discrete=1, D_featmap_dim=64, G_featmap_dim=128,
                n_epoch=25, batch_size=10, update_max=20000)


[1,    50] D loss: 0.633 ; G loss: 0.344
[1,   100] D loss: 0.532 ; G loss: 0.480
[1,   150] D loss: 0.575 ; G loss: 0.587
[1,   200] D loss: 0.571 ; G loss: 0.631
[1,   250] D loss: 0.586 ; G loss: 0.607
[1,   300] D loss: 0.581 ; G loss: 0.598
[1,   350] D loss: 0.575 ; G loss: 0.674
[1,   400] D loss: 0.601 ; G loss: 0.565
[1,   450] D loss: 0.584 ; G loss: 0.586
[1,   500] D loss: 0.522 ; G loss: 0.703
[1,   550] D loss: 0.574 ; G loss: 0.639
[1,   600] D loss: 0.495 ; G loss: 0.711
[1,   650] D loss: 0.565 ; G loss: 0.586
[1,   700] D loss: 0.523 ; G loss: 0.724
[1,   750] D loss: 0.550 ; G loss: 0.657
[1,   800] D loss: 0.546 ; G loss: 0.707
[1,   850] D loss: 0.508 ; G loss: 0.670
[1,   900] D loss: 0.537 ; G loss: 0.642
[1,   950] D loss: 0.547 ; G loss: 0.668
[1,  1000] D loss: 0.404 ; G loss: 0.909
[1,  1050] D loss: 0.364 ; G loss: 0.889
[1,  1100] D loss: 0.490 ; G loss: 0.796
[1,  1150] D loss: 0.380 ; G loss: 0.878
[1,  1200] D loss: 0.439 ; G loss: 0.790
[1,  1250] D los

In [6]:
train, test = load_dataset()

In [9]:
train[0]

TypeError: 'DataLoader' object does not support indexing