In [1]:
import os
import numpy as np
import math
import random 
import torchvision.transforms as tfs
from torchvision.utils import save_image
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.datasets as dset
from torch.autograd import Variable
import torch.backends.cudnn as cudnn
import itertools
import torch.nn as nn
import torch.nn.functional as F
import torch
import PIL.ImageOps

### Load pottery data and augmented methods as described in official example InfoGAN 

In [2]:
manualSeed = random.randint(1, 10000)
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)

#cudnn.benchmark = True

transformations = [
                   tfs.Resize((112, 112)),
                   tfs.Grayscale(1),
                   tfs.Lambda(lambda x: PIL.ImageOps.invert(x)),
                   tfs.ToTensor()]

dataset = dset.ImageFolder('data/png_clasificados/',
                                     transform=tfs.Compose(transformations))

dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

Random Seed:  5409


### Parameter definitions for training

In [3]:
device = torch.device("cuda")

In [4]:
def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)

def to_categorical(y, num_columns=11):
    """Returns one-hot encoded Variable"""
    y_cat = np.zeros((y.shape[0], num_columns))
    y_cat[range(y.shape[0]), y] = 1.

    return Variable(FloatTensor(y_cat))

### Create $G(x)$ and $D(x)$ with weights initialization, define criterion and optimizers

In [13]:
from models.infogenerator import Generator
from models.infodiscriminator import Discriminator

# Loss functions
adversarial_loss = torch.nn.MSELoss()
categorical_loss = torch.nn.CrossEntropyLoss()
continuous_loss = torch.nn.MSELoss()

# Loss weights
lambda_cat = 1
lambda_con = 0.1

# Initialize generator and discriminator
netG = Generator()
netD = Discriminator()


netG.cuda()
netD.cuda()
adversarial_loss.cuda()
categorical_loss.cuda()
continuous_loss.cuda()

# Initialize weights
netG.apply(weights_init_normal)
netD.apply(weights_init_normal)

optimizerD = optim.Adam(netD.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_info = torch.optim.Adam(itertools.chain(netG.parameters(), netD.parameters()),
                                    lr=0.0002, betas=(0.5, 0.999))

FloatTensor = torch.cuda.FloatTensor
LongTensor = torch.cuda.LongTensor
#FloatTensor = torch.FloatTensor
#LongTensor = torch.LongTensor

In [14]:
static_z = Variable(FloatTensor(np.zeros((11**2, 62))))
static_label = to_categorical(np.array([num for _ in range(11) for num in range(11)]),
                                num_columns=11)

static_code = Variable(FloatTensor(np.zeros((11**2, 2))))

In [15]:
def sample_image(n_row, batches_done):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Static sample
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, 62))))
    static_sample = netG(z, static_label, static_code)
    save_image(static_sample.data, 'out/static_%d.png' % batches_done, nrow=n_row, normalize=True)

    # Get varied c1 and c2
    zeros = np.zeros((n_row**2, 1))
    c_varied = np.repeat(np.linspace(-1, 1, n_row)[:, np.newaxis], n_row, 0)
    c1 = Variable(FloatTensor(np.concatenate((c_varied, zeros), -1)))
    c2 = Variable(FloatTensor(np.concatenate((zeros, c_varied), -1)))
    sample1 = netG(static_z, static_label, c1)
    sample2 = netG(static_z, static_label, c2)
    save_image(sample1.data, 'out/c1_%d.png' % batches_done, nrow=n_row, normalize=True)
    save_image(sample2.data, 'out/c2_%d.png' % batches_done, nrow=n_row, normalize=True)

### Train $G(x)$ , $D(x)$ and calculate information loss 

In [None]:
for epoch in range(400):
    for i, (imgs, labels) in enumerate(dataloader):

        batch_size = imgs.shape[0]

        # Adversarial ground truths
        valid = Variable(FloatTensor(batch_size, 1).fill_(1.0), requires_grad=False)
        fake = Variable(FloatTensor(batch_size, 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(FloatTensor))
        labels = to_categorical(labels.numpy(), num_columns=11)

        # -----------------
        #  Train Generator
        # -----------------

        optimizerG.zero_grad()

        # Sample noise and labels as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, 62))))
        label_input = to_categorical(np.random.randint(0, 11, batch_size), num_columns=11)
        code_input = Variable(FloatTensor(np.random.uniform(-1, 1, (batch_size, 2))))

        # Generate a batch of images
        gen_imgs = netG(z, label_input, code_input)

        # Loss measures generator's ability to fool the discriminator
        validity, _, _ = netD(gen_imgs)
        g_loss = adversarial_loss(validity, valid)

        g_loss.backward()
        optimizerG.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizerD.zero_grad()

        # Loss for real images
        real_pred, _, _ = netD(real_imgs)
        d_real_loss = adversarial_loss(real_pred, valid)

        # Loss for fake images
        fake_pred, _, _ = netD(gen_imgs.detach())
        d_fake_loss = adversarial_loss(fake_pred, fake)

        # Total discriminator loss
        d_loss = (d_real_loss + d_fake_loss) / 2

        d_loss.backward()
        optimizerD.step()

        #------------------
        # Information Loss
        #------------------

        optimizer_info.zero_grad()

        # Sample labels
        sampled_labels = np.random.randint(0, 11, batch_size)

        # Ground truth labels
        gt_labels = Variable(LongTensor(sampled_labels), requires_grad=False)


        # Sample noise, labels and code as generator input
        z = Variable(FloatTensor(np.random.normal(0, 1, (batch_size, 62))))
        label_input = to_categorical(sampled_labels, num_columns=11)
        code_input = Variable(FloatTensor(np.random.normal(-1, 1, (batch_size, 2))))

        gen_imgs = netG(z, label_input, code_input)
        _, pred_label, pred_code = netD(gen_imgs)

        info_loss = lambda_cat * categorical_loss(pred_label, gt_labels) + \
                    lambda_con * continuous_loss(pred_code, code_input)

        info_loss.backward()
        optimizer_info.step()

        #--------------
        # Log Progress
        #--------------

        print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f] [info loss: %f]" % (epoch, 400, i, len(dataloader),
                                                            d_loss.item(), g_loss.item(), info_loss.item()))
        batches_done = epoch * len(dataloader) + i
        if batches_done % 100 == 0:
            sample_image(n_row=11, batches_done=batches_done)

[Epoch 0/400] [Batch 0/21] [D loss: 0.493480] [G loss: 0.986736] [info loss: 2.620036]
[Epoch 0/400] [Batch 1/21] [D loss: 0.491779] [G loss: 0.983794] [info loss: 2.573261]
[Epoch 0/400] [Batch 2/21] [D loss: 0.489931] [G loss: 0.980199] [info loss: 2.617831]
[Epoch 0/400] [Batch 3/21] [D loss: 0.487247] [G loss: 0.974698] [info loss: 2.595105]
[Epoch 0/400] [Batch 4/21] [D loss: 0.483560] [G loss: 0.967247] [info loss: 2.593553]
[Epoch 0/400] [Batch 5/21] [D loss: 0.478536] [G loss: 0.957331] [info loss: 2.581253]
[Epoch 0/400] [Batch 6/21] [D loss: 0.472338] [G loss: 0.944221] [info loss: 2.582580]
[Epoch 0/400] [Batch 7/21] [D loss: 0.464474] [G loss: 0.928286] [info loss: 2.593290]
[Epoch 0/400] [Batch 8/21] [D loss: 0.453671] [G loss: 0.906881] [info loss: 2.629110]
[Epoch 0/400] [Batch 9/21] [D loss: 0.442241] [G loss: 0.881457] [info loss: 2.624499]
[Epoch 0/400] [Batch 10/21] [D loss: 0.426958] [G loss: 0.849875] [info loss: 2.590659]
[Epoch 0/400] [Batch 11/21] [D loss: 0.411

### Funcs to get generated samples

In [11]:
def sample_image_class(n_row, batches_done=-1, class_n=7):
    """Saves a grid of generated digits ranging from 0 to n_classes"""
    # Sample noise
    z = Variable(FloatTensor(np.random.normal(0, 1, (n_row**2, 100))))
    # Get labels ranging from 0 to n_classes for n rows
    labels = np.array([num for _ in range(n_row) for num in [class_n] * n_row])
    print(labels)
    labels = Variable(LongTensor(labels))
    gen_imgs = netG(z, labels)
    save_image(gen_imgs.data, 'out/class_%d_%d.png' % (batches_done, class_n) , nrow=n_row, normalize=True)

In [None]:
for j in range(100):
        z = Variable(FloatTensor(np.random.normal(0, 1, (10**2, 62))))
        fake = netG(z, static_label, static_code)
        # save_image(static_sample.data, 'out/static_%d.png' % batches_done, nrow=n_row, normalize=True)
        for i in range(100):
            for class_ in range(10):
                if class_ == 4:
                    save_image(fake.detach()[i], 'output/fake_infogan_{}{}_class_{}.png'.format(j, i, class_),
                              nrow=1, normalize=True)
                    

In [None]:
sample_image_class(10, class_n=0)