In [None]:
%matplotlib inline

import argparse
parser = argparse.ArgumentParser('pcgan')

# GAN TYPE:
# 'SCGAN' for Standard Conditional GAN 
# 'PCGAN' for Partially Conditioned GAN
parser.add_argument('--gan_type', type = str, default='PCGAN')

# proportion of conditioning entries for training whith PCGAN
parser.add_argument('--prob_train', type = float, default = 0.85)

# proportion of conditioning entries for generating test images
parser.add_argument('--prob_test', type = float, default = 0.7)


parser.add_argument('--batch_size', type=int, default=128) # batch size
parser.add_argument('--nepoch', type=int, default=12) # number of training epochs
parser.add_argument('--nz', type=int, default=10) # latent space dimension
parser.add_argument('--lr', type=float, default=0.0001) # learning rate
parser.add_argument('--input_size', type = int, default = 784) # MNIST im. size


config, _ = parser.parse_known_args()

In [None]:
import numpy as np
import torch
from torch import nn, optim

import torch.nn.functional as F
from torch.autograd import Variable
from torch.utils.data import DataLoader

import torchvision
from torchvision.utils import save_image
from torchvision import datasets, transforms

import matplotlib.pyplot as plt

In [None]:
import os

# Create the output folder idempotently: `%mkdir` errors when it already exists,
# which breaks Restart & Run All.
os.makedirs('results', exist_ok=True)

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

mkdir: cannot create directory ‘results’: File exists


In [None]:
# Training split of MNIST. ToTensor scales pixels to [0, 1], the same range as the
# generator's sigmoid output; the loader reshuffles every epoch.
train_dataset = datasets.MNIST(
    'data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw
Processing...
Done!




In [None]:
# BUILDING FEATURES
# Every digit class is described by NUM_PARTNS random partitions of the classes.
# For partition j the classes are shuffled (row A[j]) and split by position into
# NUM_PARTS groups; a class is one-hot encoded by the group it lands in, giving a
# feature vector of NUM_PARTNS * NUM_PARTS entries with exactly one 1 per partition.
NUM_PARTS = 3     # groups per partition
NUM_PARTNS = 10   # number of random partitions
NUM_CLASES = 10   # number of digit classes
num_labels = NUM_PARTNS*NUM_PARTS

A = np.zeros([NUM_PARTNS,NUM_CLASES])
for k in range(NUM_PARTNS):
    A[k,:] = np.random.permutation(NUM_CLASES)

# Positions [0, g), [g, 2g), ... of a shuffled row map to groups 0, 1, ...; the last
# group also takes the remainder (10 classes in 3 groups: positions 0-2, 3-5, 6-9).
_group_size = max(1, NUM_CLASES // NUM_PARTS)
# _position_of[j, c] is the position of class c in row A[j]
# (argsort of a permutation is its inverse permutation).
_position_of = np.argsort(A, axis=1)
_group_of = np.minimum(_position_of // _group_size, NUM_PARTS - 1)

set_labels = torch.zeros([NUM_CLASES,num_labels])
for k in range(NUM_CLASES):
    for j in range(NUM_PARTNS):
        set_labels[k, NUM_PARTS*j + int(_group_of[j, k])] = 1


def make_features(batch_size, num_labels, y, one_prob, label_table=None, out_device=None):
    """Build randomly masked feature vectors for a batch of class labels.

    Args:
        batch_size: number of rows to build; the first `batch_size` labels of `y` are used.
        num_labels: width of each feature vector; must match the width of the table.
        y: integer class labels of shape (N,) or (N, 1) with N >= batch_size. May be a
            tensor on any device, a NumPy array or a list.
        one_prob: probability that each entry is kept (independent Bernoulli mask);
            1 keeps every entry, 0 zeroes everything.
        label_table: class -> feature table; defaults to the global `set_labels`.
        out_device: device of the result; defaults to the global `device`.

    Returns:
        float tensor of shape (batch_size, num_labels) on the output device.

    Raises:
        ValueError: if fewer than `batch_size` labels are given or the table width
            differs from `num_labels`.
    """
    table = set_labels if label_table is None else label_table
    target_device = device if out_device is None else out_device

    # Gather on the CPU where the table lives; indexing a CPU tensor with CUDA index
    # tensors is rejected by recent PyTorch.
    idx = torch.as_tensor(y).reshape(-1)[:batch_size].cpu().long()
    features = table[idx]
    if features.shape != (batch_size, num_labels):
        raise ValueError('expected features of shape %s, got %s'
                         % ((batch_size, num_labels), tuple(features.shape)))

    # Keep each entry independently with probability `one_prob`.
    mask = torch.from_numpy(np.random.binomial(1, one_prob, (batch_size, num_labels)))
    return (features * mask.to(features.dtype)).to(target_device)






In [None]:
# define DISCRIMINATOR
class ModelD(nn.Module):
    """Conditional discriminator scoring (28x28 image, conditioning vector) pairs.

    With gan_type 'SCGAN' the raw num_labels-wide feature vector is embedded to
    300 units by fc3. With any other gan_type the conditioning vector is used as
    given and must already be 300 wide, because fc1 expects the flattened conv
    features plus 300 entries.

    forward returns a (batch, 1) tensor of probabilities in (0, 1).
    """

    def __init__(self, gan_type):
        super().__init__()
        self.gan_type = gan_type
        # Two 5x5 convolutions with padding 2 keep the 28x28 resolution.
        self.conv1 = nn.Conv2d(1, 32, 5, 1, 2)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(64)
        # Flattened conv features concatenated with a 300-wide conditioning vector.
        self.fc1 = nn.Linear(64*28*28+300, 1024)
        self.fc2 = nn.Linear(1024, 1)
        # Label embedding; only applied when gan_type == 'SCGAN'.
        self.fc3 = nn.Linear(num_labels,300)

    def forward(self, x, labels):
        n = x.size(0)
        h = x.view(n, 1, 28, 28)
        h = F.relu(self.bn1(self.conv1(h)))
        h = F.relu(self.bn2(self.conv2(h)))
        h = h.view(n, 64*28*28)
        cond = F.relu(self.fc3(labels)) if self.gan_type == 'SCGAN' else labels
        joint = torch.cat([h, cond], 1)
        return torch.sigmoid(self.fc2(F.relu(self.fc1(joint))))

In [None]:
# define GENERATOR
class ModelG(nn.Module):
    """Conditional generator mapping (noise z, feature vector) to a 1x28x28 image.

    fc2 embeds the num_labels-wide feature vector into 300 units. `FE` exposes
    that embedding on its own so the discriminator can be conditioned on it.
    The output passes through a sigmoid, so pixels lie in (0, 1).
    """

    def __init__(self):
        super().__init__()
        self.fc2 = nn.Linear(num_labels, 300)          # feature embedding
        self.fc = nn.Linear(config.nz+300, 64*28*28)   # (z, embedding) -> 64x28x28 map
        self.bn1 = nn.BatchNorm2d(64)
        self.deconv1 = nn.ConvTranspose2d(64, 32, 5, 1, 2)
        self.bn2 = nn.BatchNorm2d(32)
        self.deconv2 = nn.ConvTranspose2d(32, 1, 5, 1, 2)

    def forward(self, x, labels):
        n = x.size(0)
        z_and_cond = torch.cat([x, self.FE(labels)], 1)
        h = self.fc(z_and_cond).view(n, 64, 28, 28)
        h = F.relu(self.bn1(h))
        h = F.relu(self.bn2(self.deconv1(h)))
        return torch.sigmoid(self.deconv2(h))

    def FE(self, labels):
        """Return the ReLU-activated 300-unit embedding of `labels` (feature extractor)."""
        return F.relu(self.fc2(labels))

In [None]:
# Fixed inputs for monitoring the generator across epochs: N_FIXED_Z latent codes
# x NUM_CLASES digits. Entry NUM_CLASES*k + j pairs latent code k with a randomly
# masked feature vector of digit j (each entry kept with probability prob_test).
N_FIXED_Z = 5
fixed_z = torch.empty(N_FIXED_Z*NUM_CLASES, config.nz, device=device)
fixed_y = torch.empty(N_FIXED_Z*NUM_CLASES, num_labels, device=device)
for k in range(N_FIXED_Z):
    aux_z = torch.randn(1, config.nz)
    for j in range(NUM_CLASES):
        keep = torch.from_numpy(np.random.binomial(1, config.prob_test, [1, num_labels])).float()
        fixed_z[NUM_CLASES*k + j] = aux_z
        fixed_y[NUM_CLASES*k + j] = set_labels[j, :] * keep

In [None]:
print('_________________________') 
print('Training ',config.gan_type)
print('...') 

# Probability that each conditioning entry is kept during training:
# SCGAN always sees the full feature vector, PCGAN a random subset.
if config.gan_type == 'SCGAN':
    ONE_PROB = 1
elif config.gan_type == 'PCGAN':
    ONE_PROB = config.prob_train
else:
    raise ValueError("gan_type must be 'SCGAN' or 'PCGAN', got %r" % (config.gan_type,))

# Define model
model_d = ModelD(config.gan_type).to(device)
model_g = ModelG().to(device)

criterion = nn.BCELoss()

real_label = 1
fake_label = 0

# Choosing optimizer
optim_d = optim.SGD(model_d.parameters(), lr=config.lr)
optim_g = optim.SGD(model_g.parameters(), lr=config.lr)


def _random_labels(batch_size):
    """Uniformly random digit classes of shape (batch_size, 1), kept on the CPU.

    They only feed make_features, so they never need to be on the GPU; a
    hard-coded .cuda() would also crash on CPU-only machines.
    """
    return torch.from_numpy(np.random.randint(0, NUM_CLASES, size=(batch_size, 1)))


def _d_condition(one_hot):
    """Conditioning input for the discriminator.

    PCGAN conditions D on the generator's feature embedding (model_g.FE), so in
    the generator step errG also trains that embedding through D. SCGAN passes
    the raw features and lets ModelD embed them itself.
    """
    if config.gan_type == 'PCGAN':
        return model_g.FE(one_hot)
    return one_hot


# TRAIN 
for epoch_idx in range(config.nepoch):

    model_d.train()
    model_g.train()
    d_loss = 0.0
    g_loss = 0.0

    for train_x, train_y in train_loader:
        batch_size = train_x.size(0)
        # images from matrix to vector; labels stay on the CPU for make_features
        real_x = train_x.view(-1, config.input_size).to(device)

        ####### TRAIN DISCRIMINATOR #######
        optim_d.zero_grad()

        # Real images, conditioned on (masked) features of their true digits.
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)
        one_hot_features = make_features(batch_size, num_labels, train_y, ONE_PROB)
        # BCELoss needs input and target of the same shape: flatten D's (batch, 1) output.
        output = model_d(real_x, _d_condition(one_hot_features)).view(-1)
        errD_real = criterion(output, label)
        errD_real.backward()
        realD_mean = output.detach().mean().item()  # mean D(x) on real images

        # Generated images for random digits. g_out is detached: this step only
        # updates D, and G's gradients are reset before the generator step anyway.
        rand_y = _random_labels(batch_size)
        one_hot_features = make_features(batch_size, num_labels, rand_y, ONE_PROB)
        noise = torch.randn(batch_size, config.nz, device=device)
        label = torch.full((batch_size,), fake_label, dtype=torch.float, device=device)
        g_out = model_g(noise, one_hot_features).detach()

        # NOTE(review): D gets a freshly re-masked copy of the same digits' features
        # here, whereas the generator step reuses G's features for D. Kept as-is;
        # confirm this asymmetry is intended.
        one_hot_features = make_features(batch_size, num_labels, rand_y, ONE_PROB)
        output = model_d(g_out, _d_condition(one_hot_features)).view(-1)
        errD_fake = criterion(output, label)
        errD_fake.backward()
        fakeD_mean = output.detach().mean().item()  # mean D(G(z)) on generated images
        errD = errD_real + errD_fake
        optim_d.step()

        ####### TRAIN GENERATOR #######
        optim_g.zero_grad()
        rand_y = _random_labels(batch_size)
        one_hot_features = make_features(batch_size, num_labels, rand_y, ONE_PROB)
        noise = torch.randn(batch_size, config.nz, device=device)
        # G is rewarded when D classifies its images as real.
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        g_out = model_g(noise, one_hot_features)
        output = model_d(g_out, _d_condition(one_hot_features)).view(-1)
        errG = criterion(output, label)
        errG.backward()
        optim_g.step()

        # add batch losses to epoch losses (as Python floats, not graph tensors)
        d_loss += errD.item()
        g_loss += errG.item()

    print('Epoch {} - D loss = {:.4f}, G loss = {:.4f}'.format(epoch_idx,
        d_loss, g_loss))

    # Save the fixed-input samples. fixed_z/fixed_y are laid out as 10*k + j, so
    # NUM_CLASES images per row gives one row per latent code, one column per digit.
    with torch.no_grad():
        fake = model_g(fixed_z, fixed_y)
        save_image(1-fake.detach(),'results/'+config.gan_type+'samples_e_%03d.png' % (epoch_idx),normalize=True,nrow=NUM_CLASES)

_________________________
Training  PCGAN
...


  return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)


KeyboardInterrupt: ignored

In [None]:
import glob
import os

import matplotlib.image as mpimg

# Show the sample grid saved after the last epoch. If training stopped early (e.g.
# it was interrupted), fall back to the most recent grid instead of failing with
# FileNotFoundError. Zero-padded epoch numbers make lexicographic order = epoch order.
sample_prefix = 'results/' + config.gan_type + 'samples_e_'
sample_path = sample_prefix + '%03d.png' % (config.nepoch - 1)
sample_title = config.gan_type + ' w/ %0.2f entries' % (config.prob_test)
if not os.path.exists(sample_path):
    saved_grids = sorted(glob.glob(glob.escape(sample_prefix) + '*.png'))
    if not saved_grids:
        raise FileNotFoundError('no sample grids matching %r*.png; run the training cell first'
                                % sample_prefix)
    sample_path = saved_grids[-1]
    sample_title += ' (latest saved: %s)' % os.path.basename(sample_path)

fig, ax = plt.subplots(figsize=(15, 7))
ax.imshow(mpimg.imread(sample_path), cmap='gray', aspect='equal')
ax.set_title(sample_title)
ax.axis('off');