<a href="https://colab.research.google.com/github/mchivuku/csb659-project/blob/master/InfoGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# InfoGAN

Notebook contains implementation for InfoGAN - https://arxiv.org/abs/1606.03657

Code adapted from: https://github.com/eriklindernoren/Pytorch-GAN/

## Connect to google drive

In [0]:
%%capture
!pip install tqdm six

In [2]:
from google.colab import drive

drive.mount("/content/drive")


Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
%cd /content/drive/My\ Drive/Masters-DS/CSCI-B659/project/examples/infoGAN


/content/drive/My Drive/Masters-DS/CSCI-B659/project/examples/infoGAN


In [0]:
import os

os.makedirs("infoGAN")

In [12]:
%cd infoGAN

/content/drive/My Drive/Masters-DS/CSCI-B659/project/examples/infoGAN


In [0]:
os.makedirs("./results/static", exist_ok=True)
os.makedirs('./results/varying_c1/', exist_ok=True)
os.makedirs('./results/varying_c2/', exist_ok=True)

## Torch Imports

In [0]:
from __future__ import print_function
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image

import torch.nn as nn

import itertools
from torch.autograd import Variable



In [5]:
torch.manual_seed(5)

print('Torch', torch.__version__, 'CUDA', torch.version.cuda)
print('Device:', torch.device('cuda:0'))
print(torch.cuda.is_available())

is_cuda = torch.cuda.is_available()
device = torch.device ( "cuda:0" if torch.cuda.is_available () else "cpu" )

Torch 1.0.1.post2 CUDA 10.0.130
Device: cuda:0
True


In [8]:
class Params:
    n_epochs = 200
    batch_size = 128
    lr = 0.0002
    b1 = 0.5 #adam, decay of first order momentum of gradient
    b2 = 0.999 #decay of first order momentum of gradient
    n_cpu = 8
    latent_dim = 62
    code_dim = 2
    n_classes = 10
    img_size = 32
    channels = 1
    sample_interval = 400 # interval  between image sampling
    
    
    
Params.n_epochs

200

In [0]:
## function to initialize weights
def weights_init_normal(m):
  classname= m.__class__.__name__
  if classname.find("Conv")!=-1:
    torch.nn.init.normal_(m.weight.data,0.0,0.02)
  elif classname.find("BatchNorm")!=-1:
    torch.nn.init.normal_(m.weight.data,1.0,0.02)
    torch.nn.init.constant_(m.bias.data,0.0)
    
    
"""
One hot encoded vector for category
"""
def to_categorical(y, num_cols):
  """
  """
  y_cat = np.zeros((y.shape[0],num_cols))
  y_cat[range(y.shape[0]),y] =1.
  
  return Variable(FloatTensor(y_cat))
  

In [0]:
## Models
class Generator(nn.Module):
  def __init__(self):
    super(Generator,self).__init__()
    input_dim = Params.latent_dim + Params.n_classes + Params.code_dim
    
    self.init_size = Params.img_size // 4 #initial size before upsampling
    self.l1 = nn.Sequential(nn.Linear(input_dim, 128*self.init_size **2))
    
    self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, 0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, Params.channels, 3, stride=1, padding=1),
            nn.Tanh()
        )
    
  def forward(self,noise, labels, code):
    gen_input = torch.cat((noise, labels, code),-1) ## include to contains noise and code
    out = self.l1(gen_input)
    
    out = out.view(out.shape[0],128,self.init_size, self.init_size)
    img = self.conv_blocks(out)
    return img
  

In [0]:
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator,self).__init__()
    
    def discriminator_block(in_filters, out_filters, bn=True):
      """Returns layers of each discriminator block"""
      block = [   nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                nn.LeakyReLU(0.2, inplace=True),
                nn.Dropout2d(0.25)]
      if bn:
        block.append(nn.BatchNorm2d(out_filters, 0.8))
      return block
          
    self.conv_blocks = nn.Sequential(*discriminator_block(Params.channels,16, bn=False),
                                    *discriminator_block(16,32),
                                     *discriminator_block(32,64),
                                     *discriminator_block(64,128)
                                        
                                    )    
    
    ## height and width of downsampled image
    ds_size = Params.img_size // 2**4
    self.adv_layer = nn.Sequential(nn.Linear(128*ds_size**2, 1))
    self.aux_layer = nn.Sequential(
            nn.Linear(128*ds_size**2, Params.n_classes),
            nn.Softmax()
        )
    
    self.latent_layer = nn.Sequential(nn.Linear(128*ds_size**2, Params.code_dim))
    
  def forward(self, img):
      out = self.conv_blocks(img)
      out = out.view(out.shape[0], -1)
      validity = self.adv_layer(out)
      label = self.aux_layer(out)
      latent_code = self.latent_layer(out)

      return validity, label, latent_code
      
    
    
    
    
    
    

In [25]:
## Loss Functions

adversarial_loss = torch.nn.MSELoss()
categorical_loss = torch.nn.CrossEntropyLoss() # cross entropy loss - discrete loss
continuous_loss = torch.nn.MSELoss()


## Loss weights
lambda_cat = 1
lambda_con = 0.1

## Initialize generator and discriminator netowrk
generator  = Generator()
discriminator = Discriminator()


generator = generator.cuda()
discriminator = discriminator.cuda()

print(generator)
print()
print(discriminator)

Generator(
  (l1): Sequential(
    (0): Linear(in_features=74, out_features=8192, bias=True)
  )
  (conv_blocks): Sequential(
    (0): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (1): Upsample(scale_factor=2, mode=nearest)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Upsample(scale_factor=2, mode=nearest)
    (6): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): Tanh()
  )
)

Discriminator(
  (conv_blocks): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)


In [26]:
## Initialize weights
# Initialize weights
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)


Discriminator(
  (conv_blocks): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2, inplace)
    (2): Dropout2d(p=0.25)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (4): LeakyReLU(negative_slope=0.2, inplace)
    (5): Dropout2d(p=0.25)
    (6): BatchNorm2d(32, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (7): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (8): LeakyReLU(negative_slope=0.2, inplace)
    (9): Dropout2d(p=0.25)
    (10): BatchNorm2d(64, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
    (11): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (12): LeakyReLU(negative_slope=0.2, inplace)
    (13): Dropout2d(p=0.25)
    (14): BatchNorm2d(128, eps=0.8, momentum=0.1, affine=True, track_running_stats=True)
  )
  (adv_layer): Sequential(
    (0): Linear(in_features=512, out_features=1, bias=True

In [0]:
## Data
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('../vae/mnist/data', train=True, download=True,
                   transform=transforms.Compose([
                        transforms.Resize(Params.img_size),
                        transforms.ToTensor(),
                        transforms.Normalize((0.5, ), (0.5,))
                   ])),
    batch_size=Params.batch_size, shuffle=True)

In [0]:
## Define Optimizers
# Optimizers
optimizer_G = torch.optim.Adam(generator.parameters(), lr=Params.lr, betas=(Params.b1, Params.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=Params.lr, betas=(Params.b1, Params.b2))
optimizer_info = torch.optim.Adam(itertools.chain(generator.parameters(), discriminator.parameters()),
                                    lr=Params.lr, betas=(Params.b1, Params.b2))

FloatTensor = torch.cuda.FloatTensor  
LongTensor = torch.cuda.LongTensor  


In [0]:
## Static generator inputs for sampling
# Static generator inputs for sampling
import numpy as np

static_z = Variable(FloatTensor(np.zeros((Params.n_classes**2, Params.latent_dim))))
static_label = to_categorical(np.array([num for _ in range(Params.n_classes) for num in range(Params.n_classes)]),
                                 Params.n_classes)
static_code = Variable(FloatTensor(np.zeros((Params.n_classes**2, Params.code_dim))))

In [0]:
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Static sample
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, Params.latent_dim))))
    static_sample = generator(z, static_label, static_code)
    save_image(static_sample.data, './results/static/%d.png' % batches_done, nrow=n_row, normalize=True)

    # Get varied c1 and c2
    zeros = np.zeros((n_row**2, 1))
    c_varied = np.repeat(np.linspace(-1, 1, n_row)[:, np.newaxis], n_row, 0)
    c1 = Variable(FloatTensor(np.concatenate((c_varied, zeros), -1)))
    c2 = Variable(FloatTensor(np.concatenate((zeros, c_varied), -1)))
    sample1 = generator(static_z, static_label, c1)
    sample2 = generator(static_z, static_label, c2)
    save_image(sample1.data, './results/varying_c1/%d.png' % batches_done, nrow=n_row, normalize=True)
    save_image(sample2.data, './results/varying_c2/%d.png' % batches_done, nrow=n_row, normalize=True)


## Training

In [0]:
for epoch in range(Params.n_epochs):
  for i, (images,labels) in enumerate(dataloader):
    batch_size = images.size(0)
    
    ## real and fake ground truth
    real = Variable(FloatTensor(batch_size,1).fill_(1.0),requires_grad=False)
    fake = Variable(FloatTensor(batch_size,1).fill_(0.0),requires_grad=False)
    
    ## configure input
    real_imgs = Variable(images.type(FloatTensor))
    labels = Variable(labels.type(FloatTensor))
    
    ### Train Generator
    optimizer_G.zero_grad()
    
    ## sample noise and labels as generator input
    z = Variable(FloatTensor(np.random.normal(0,1,(batch_size,Params.latent_dim))))
    label_input = to_categorical(np.random.randint(0,Params.n_classes, batch_size),Params.n_classes)
    
    code_input = Variable(FloatTensor(np.random.uniform(-1,1,(batch_size,Params.code_dim))))
    
    # Generate a batch of images - labels, code and z
    gen_imgs = generator(z, label_input, code_input)
    
    
    # discriminator
    validity, _, _ = discriminator(gen_imgs)
    
    g_loss = adversarial_loss(validity,real)
    
    g_loss.backward()
    optimizer_G.step()
    
    
    ##
    # Train Discriminator
    
    optimizer_D.zero_grad()
    
    ## Loss on real images
    real_pred, _, _ = discriminator(real_imgs)
    d_real_loss = adversarial_loss(real_pred, real)
    
    ## Loss on fake images
    fake_pred, _, _ = discriminator(gen_imgs.detach())
    d_fake_loss = adversarial_loss(fake_pred, fake)
    
    ## Total discriminator loss
    d_loss = (d_real_loss + d_fake_loss)/2
    
    d_loss.backward()
    optimizer_D.step()
    
    
    
    ### MutualInformation Loss
    ####
    
    optimizer_info.zero_grad()
    
    ## sample labels
    sampled_labels = np.random.randint(0,Params.n_classes, batch_size)
    
    # ground truth labels
    gt_labels = Variable(LongTensor(sampled_labels),requires_grad = False)
    
    
    # Sample noise, labels and code as generator input
    z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, Params.latent_dim))))
    label_input = to_categorical(sampled_labels, Params.n_classes)
    code_input = Variable(FloatTensor(np.random.normal(-1, 1, (batch_size, Params.code_dim))))

    gen_imgs = generator(z, label_input, code_input)
    _, pred_label, pred_code = discriminator(gen_imgs)
    
    info_loss = lambda_cat * categorical_loss(pred_label, gt_labels) + \
                lambda_con * continuous_loss(pred_code, code_input)
    
    info_loss.backward()
    
    optimizer_info.step()
    
    #--------------
    # Log Progress
    #--------------

    print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [info loss: %f]" % (epoch, Params.n_epochs, i, len(dataloader),
                                                            d_loss.item(), g_loss.item(), info_loss.item()))
    
    
    batches_done = epoch * len(dataloader) + i
    if batches_done % Params.sample_interval == 0:
            sample_image(n_row=10, batches_done=batches_done)

    
    

  input = module(input)


[Epoch 0/200] [Batch 0/469] [D loss: 0.488959] [G loss: 0.978922] [info loss: 2.485431]
[Epoch 0/200] [Batch 1/469] [D loss: 0.486341] [G loss: 0.974166] [info loss: 2.495313]
[Epoch 0/200] [Batch 2/469] [D loss: 0.483580] [G loss: 0.968782] [info loss: 2.466264]
[Epoch 0/200] [Batch 3/469] [D loss: 0.479737] [G loss: 0.962913] [info loss: 2.481699]
[Epoch 0/200] [Batch 4/469] [D loss: 0.475852] [G loss: 0.956066] [info loss: 2.484247]
[Epoch 0/200] [Batch 5/469] [D loss: 0.470917] [G loss: 0.946706] [info loss: 2.498102]
[Epoch 0/200] [Batch 6/469] [D loss: 0.464695] [G loss: 0.936057] [info loss: 2.487242]
[Epoch 0/200] [Batch 7/469] [D loss: 0.457368] [G loss: 0.923186] [info loss: 2.488725]
[Epoch 0/200] [Batch 8/469] [D loss: 0.447540] [G loss: 0.905998] [info loss: 2.494567]
[Epoch 0/200] [Batch 9/469] [D loss: 0.436703] [G loss: 0.880101] [info loss: 2.495739]
[Epoch 0/200] [Batch 10/469] [D loss: 0.418810] [G loss: 0.845121] [info loss: 2.496710]
[Epoch 0/200] [Batch 11/469] [D