# NIPS Paper Implementation Challenge 
## PyTorch Code Implementation for Paper Structured Generative Adversarial Networks


In [1]:
### IMPORTS ###
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
import utils
# initialize logger
import logging.config
import yaml
# Configure the logging module from ./log_config.yaml.
with open('./log_config.yaml') as file:
    # safe_load instead of load: yaml.load() without an explicit Loader can
    # construct arbitrary Python objects and is rejected by PyYAML >= 6.
    Dict = yaml.safe_load(file)    # load config file
logging.config.dictConfig(Dict)    # import config

logger = logging.getLogger(__name__)
logger.info('PyTorch version: ' + str(torch.__version__))

# import SGAN utils
from layers import rampup, rampdown
from zca import ZCA
from models import Generator, InferenceNet, ClassifierNet, DConvNet1, DConvNet2
from trainGAN import pretrain_classifier, train_classifier, train_gan, eval_classifier

from sklearn.metrics import accuracy_score
from torch.autograd import Variable
import logging
import torch.nn.functional as F
from torch.nn.utils import weight_norm as wn

from layers import conv_concat, mlp_concat, init_weights, Gaussian_NoiseLayer, MeanOnlyBatchNorm

   INFO [12:46:29] __main__: PyTorch version: 0.2.0_3


### Global Parameter Setting

In [2]:
### GLOBAL PARAMS ###
# Module-level configuration constants for the whole experiment.
BATCH_SIZE = 200        # mini-batch size; also used to derive the labelled/unlabelled batch counts
BATCH_SIZE_EVAL = 200   # mini-batch size used to derive the evaluation batch count
NUM_CLASSES = 10        # number of target classes
NUM_LABELLED = 4000     # total labelled samples kept (NUM_LABELLED / NUM_CLASSES per class)
SSL_SEED = 1            # seed of the RandomState that shuffles the training set before the labelled split
NP_SEED = 1234          # seed of the general-purpose RandomState `rng`
CUDA = torch.cuda.is_available()  # True when a CUDA device is usable
logger.info('Cuda = ' + str(CUDA))

# data dependent
IN_CHANNELS = 3  # presumably the number of image channels (RGB) -- confirm against utils.load

# evaluation
VIS_EPOCH = 1   # presumably: visualise every VIS_EPOCH epochs -- usage not visible here
EVAL_EPOCH = 1  # presumably: evaluate every EVAL_EPOCH epochs -- usage not visible here

# C (presumably: classifier settings)
SCALED_UNSUP_WEIGHT_MAX = 100.0  # presumably the upper bound of the unsupervised loss weight -- TODO confirm

# G (presumably: generator settings)
N_Z = 100  # presumably the dimensionality of the noise vector z -- usage not visible here

# optimization
B1 = 0.5  # beta1 in Adam
LR = 3e-4               # base learning rate
LR_CLA = 3e-3           # learning rate, presumably for the classifier
NUM_EPOCHS = 1000       # number of training epochs
NUM_EPOCHS_PRE = 20     # presumably the number of classifier pre-training epochs
ANNEAL_EPOCH = 200      # presumably the epoch at which learning-rate annealing starts
ANNEAL_EVERY_EPOCH = 1  # presumably the annealing interval in epochs
ANNEAL_FACTOR = 0.995       # presumably the multiplicative decay applied to LR
ANNEAL_FACTOR_CLA = 0.99    # presumably the multiplicative decay applied to LR_CLA

path_out = "./results"  # output directory path


   INFO [12:46:29] __main__: Cuda = False


### Data Preprocessing
- Download the cifar-10 dataset from 'http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'
- Split the dataset into labelled training and test datasets, with 50000 and 10000 samples respectively
- Create an unlabelled dataset by copying the training dataset created above
- Shuffle the labelled training dataset
- Create a much smaller labelled training dataset of 4000 samples for our semi-supervised classification setting.
- Calculate the number of minibatches for labelled training and test, and unlabelled training datasets respectively

In [3]:
### DATA ###
# Load CIFAR-10, keep an unshuffled copy of all training images as the
# unlabelled pool, and carve out a class-balanced labelled subset.
logger.info('Loading data...')
train_x, train_y = utils.load('./cifar10/', 'train')
eval_x, eval_y = utils.load('./cifar10/', 'test')

# labels as int32
train_y, eval_y = np.int32(train_y), np.int32(eval_y)

# the unlabelled pool is copied *before* the training set is shuffled
x_unlabelled = train_x.copy()

# shuffle images and labels with the same permutation
rng_data = np.random.RandomState(SSL_SEED)
inds = rng_data.permutation(train_x.shape[0])
train_x, train_y = train_x[inds], train_y[inds]

# take the first NUM_LABELLED / NUM_CLASSES shuffled samples of every class
x_labelled, y_labelled = [], []
for j in range(NUM_CLASSES):
    class_mask = train_y == j
    n_per_class = int(NUM_LABELLED / NUM_CLASSES)
    x_labelled.append(train_x[class_mask][:n_per_class])
    y_labelled.append(train_y[class_mask][:n_per_class])

x_labelled = np.concatenate(x_labelled, axis=0)
y_labelled = np.concatenate(y_labelled, axis=0)
del train_x  # the full shuffled training images are no longer needed

# number of full mini-batches per dataset (remainders are dropped)
num_batches_l = x_labelled.shape[0] // BATCH_SIZE
num_batches_u = x_unlabelled.shape[0] // BATCH_SIZE
num_batches_e = eval_x.shape[0] // BATCH_SIZE_EVAL
rng = np.random.RandomState(NP_SEED)

   INFO [12:46:31] __main__: Loading data...


### Model Structures
The model structures of structured GAN consist of one generator network, two discriminator networks, and two inference networks which are named as classifier and inferentor respectively. 

#### Generator Network $p_g (x \mid y, z)$
The generator parametrizes the sampling process $x \sim p_g(x \mid y, z) = G(y, z)$ by taking some hidden structures $y$ of $x$ and other factors $z$ as inputs, and outputting generated samples $x$. 
![title](generator_paper.png)

The generator architecture consists of following layers, activation functions, and parameters.
![title](generator.png)

In [6]:
### MODEL STRUCTURES ###

class Generator(nn.Module):
    '''
    Conditional generator G(z, y): maps a noise batch and class labels to images.

    A dense layer produces a (512, 4, 4) feature map that three stride-2
    transposed convolutions upsample to (256, 8, 8), (128, 16, 16) and finally
    (3, 32, 32). The labels are concatenated to the input of the dense layer
    (mlp_concat) and to the input of every transposed convolution (conv_concat).

    Args:
        input_size(int): width of the dense layer input, i.e. the width of
            mlp_concat(z, y) (presumably n_z + num_classes -- confirm in layers.py)
        num_classes(int): number of target classes
        dense_neurons(int): width of the dense layer output; forward() reshapes
            it to (512, 4, 4), so it has to be 512 * 4 * 4 = 8192
        weight_init(bool): apply init_weights to all sub-modules and log the net
    '''

    def __init__(self, input_size, num_classes, dense_neurons, weight_init=True):
        super(Generator, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        self.num_classes = num_classes

        self.Dense = nn.Linear(input_size, dense_neurons)
        self.Relu = nn.ReLU()
        self.Tanh = nn.Tanh()

        # Each deconv input carries num_classes extra label channels added by
        # conv_concat. The counts were hard-coded as 522 / 266 / 138, which is
        # only correct for num_classes == 10; derive them instead.
        self.Deconv2D_0 = nn.ConvTranspose2d(in_channels=512 + num_classes, out_channels=256,
                                             kernel_size=5, stride=2, padding=2,
                                             output_padding=1, bias=False)
        self.Deconv2D_1 = nn.ConvTranspose2d(in_channels=256 + num_classes, out_channels=128,
                                             kernel_size=5, stride=2, padding=2,
                                             output_padding=1, bias=False)
        self.Deconv2D_2 = wn(nn.ConvTranspose2d(in_channels=128 + num_classes, out_channels=3,
                                                kernel_size=5, stride=2, padding=2,
                                                output_padding=1, bias=False))

        self.BatchNorm1D = nn.BatchNorm1d(dense_neurons)

        self.BatchNorm2D_0 = nn.BatchNorm2d(256)
        self.BatchNorm2D_1 = nn.BatchNorm2d(128)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, z, y):
        '''
        Args:
            z: noise batch, first dimension is the batch size
            y: class labels of the batch

        Returns:
            generated images in the tanh range, shape (batch, 3, 32, 32)
        '''
        x = mlp_concat(z, y, self.num_classes)

        x = self.Dense(x)
        x = self.Relu(x)
        x = self.BatchNorm1D(x)

        # view() is the differentiable reshape; the non-inplace resize() used
        # before is deprecated.
        x = x.view(z.size(0), 512, 4, 4)
        x = conv_concat(x, y, self.num_classes)

        x = self.Deconv2D_0(x)                    # output shape (256,8,8) = 8192 * 2
        x = self.Relu(x)
        x = self.BatchNorm2D_0(x)

        x = conv_concat(x, y, self.num_classes)

        x = self.Deconv2D_1(x)                    # output shape (128,16,16) = 8192 * 2 * 2
        x = self.Relu(x)
        x = self.BatchNorm2D_1(x)

        x = conv_concat(x, y, self.num_classes)
        x = self.Deconv2D_2(x)                    # output shape (3, 32, 32) = 3072
        x = self.Tanh(x)

        return x


#### Discriminator Network 1

Discriminator 1 is trained to distinguish pairs $(x, y) \sim p_g(x,y)$ produced by the generator introduced above from real pairs drawn from $p(x,y)$. It takes $x$ and $y$ as inputs and outputs a decision on whether the joint pair is drawn from $p_g(x,y)$ or $p(x,y)$.

![title](discriminator1.png)

In [None]:
# discriminator xy2p: test a pair of input comes from p(x, y) instead of p_c or p_g
class DConvNet1(nn.Module):
    '''
    1st convolutional discriminative net (discriminator xy2p)
    --> does a pair of input come from p(x, y) instead of p_c or p_g ?

    Six weight-normalised 3x3 convolutions without bias; the labels are
    concatenated to the input of every convolution (conv_concat) and to the
    input of the final linear layer (mlp_concat). Output is a sigmoid score.

    Args:
        channel_in(int): number of channels of the input x
        num_classes(int): number of target classes
        p_dropout(float): dropout probability of the shared dropout layer
        weight_init(bool): apply init_weights to all sub-modules and log the net
    '''

    def __init__(self, channel_in, num_classes, p_dropout=0.2, weight_init=True):
        super(DConvNet1, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger
        self.num_classes = num_classes

        # stateless layers shared by all stages
        self.LReLU = nn.LeakyReLU(negative_slope=0.2)
        self.sgmd = nn.Sigmoid()
        self.drop = nn.Dropout(p=p_dropout)

        def label_conv(features_in, features_out, stride, padding):
            # 3x3 weight-normalised conv whose input carries num_classes label maps
            return wn(nn.Conv2d(in_channels=features_in + num_classes,
                                out_channels=features_out,
                                kernel_size=(3, 3), stride=stride,
                                padding=padding, bias=False))

        # stage 1: full resolution
        self.conv1 = label_conv(channel_in, 32, (1, 1), 1)
        self.conv2 = label_conv(32, 32, 2, 1)       # stride 2 halves the resolution
        # stage 2
        self.conv3 = label_conv(32, 64, (1, 1), 1)
        self.conv4 = label_conv(64, 64, 2, 1)       # stride 2 halves the resolution
        # stage 3: unpadded convs shrink each side by 2
        self.conv5 = label_conv(64, 128, (1, 1), 0)
        self.conv6 = label_conv(128, 128, (1, 1), 0)

        # pools to a fixed 4x4 map (for 32x32 inputs conv6 already yields 4x4)
        self.globalPool = nn.AdaptiveAvgPool2d(output_size=4)

        self.lin = nn.Linear(in_features=128 * 4 * 4 + num_classes,
                             out_features=1)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, x, y):
        '''
        Args:
            x: image batch, (bs, channel_in, height, width)
            y: class labels of the batch

        Returns:
            sigmoid score per sample, shape (bs, 1)
        '''
        n_cls = self.num_classes

        h = conv_concat(self.drop(x), y, n_cls)
        h = conv_concat(self.LReLU(self.conv1(h)), y, n_cls)
        h = conv_concat(self.drop(self.LReLU(self.conv2(h))), y, n_cls)
        h = conv_concat(self.LReLU(self.conv3(h)), y, n_cls)
        h = conv_concat(self.drop(self.LReLU(self.conv4(h))), y, n_cls)
        h = conv_concat(self.LReLU(self.conv5(h)), y, n_cls)
        h = self.LReLU(self.conv6(h))

        pooled = self.globalPool(h).view(-1, 128 * 4 * 4)
        features = mlp_concat(pooled, y, n_cls)

        return self.sgmd(self.lin(features))


#### Discriminator Network 2

Discriminator 2 is trained to distinguish pairs $(x, z) \sim p_g(x,z)$ produced by the generator introduced above from inferred pairs drawn from $p_i (x,z)$ by the inferentor introduced below. It takes $x$ and $z$ as inputs and outputs a decision on whether the joint pair is drawn from $p_g(x,z)$ or $p_i(x,z)$.
![title](discriminator2.png)

In [None]:
# discriminator xz
class DConvNet2(nn.Module):
    '''
    2nd convolutional discriminative net (discriminator xz)

    A two-layer MLP embeds z and three stride-2 5x5 convolutions embed x; both
    embeddings are concatenated and scored by a two-layer MLP with a sigmoid
    output.

    Args:
        n_z(int): width of the z input
        channel_in(int): number of channels of the x input
        num_classes(int): stored on the module, not used by any layer
        weight_init(bool): apply init_weights to all sub-modules and log the net
    '''

    def __init__(self, n_z, channel_in, num_classes, weight_init=True):
        super(DConvNet2, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger
        self.num_classes = num_classes

        # stateless activations shared by all layers
        self.LReLU = nn.LeakyReLU(negative_slope=0.2)
        self.sgmd = nn.Sigmoid()

        def down_conv(features_in, features_out):
            # 5x5 conv, stride 2, padding 2: halves the spatial resolution
            return nn.Conv2d(in_channels=features_in, out_channels=features_out,
                             kernel_size=(5, 5), stride=2, padding=2, bias=False)

        # ---- z branch: n_z -> 512 -> 512
        self.lin_z0 = nn.Linear(in_features=n_z, out_features=512)
        self.lin_z1 = nn.Linear(in_features=512, out_features=512)

        # ---- x branch: channel_in -> 128 -> 256 -> 512 feature maps
        self.conv_x0 = down_conv(channel_in, 128)
        self.conv_x1 = down_conv(128, 256)
        self.bn1 = nn.BatchNorm2d(num_features=256)
        self.conv_x2 = down_conv(256, 512)
        self.bn2 = nn.BatchNorm2d(num_features=512)

        # ---- fusion: flattened (512, 4, 4) x features + 512 z features = 8704
        self.lin_f0 = nn.Linear(in_features=512 * 4 * 4 + 512, out_features=1024)
        self.lin_f1 = nn.Linear(in_features=1024, out_features=1)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, z, x):
        '''
        Args:
            z: latent batch, (bs, n_z)
            x: image batch, (bs, channel_in, height, width); the fusion layer
               width fits 32x32 inputs

        Returns:
            sigmoid score per sample, shape (bs, 1)
        '''
        act = self.LReLU

        # embed z
        z_feat = act(self.lin_z1(act(self.lin_z0(z))))

        # embed x (activation first, batch norm afterwards)
        h = act(self.conv_x0(x))
        h = self.bn1(act(self.conv_x1(h)))
        h = self.bn2(act(self.conv_x2(h)))

        # flatten per sample and join both embeddings
        x_feat = h.view(h.size(0), -1)
        joint = torch.cat([x_feat, z_feat], dim=1)

        return self.sgmd(self.lin_f1(act(self.lin_f0(joint))))

#### Inference Networks
Two inference networks define two distributions $p_i(z\mid x)$ and $p_c(y\mid x)$ that are used to approximate the true posterior $p(z\mid x)$ and $p(y \mid x)$ using two different adversarial games, which are named InferenceNet and ClassifierNet accordingly. 
![title](inferencenets.png)

In [None]:
class InferenceNet(nn.Module):
    '''
    Inference net: maps an image batch x to a latent code z.

    Four blocks of (4x4 conv with stride 2, batch norm, leaky ReLU) followed by
    a linear layer on the flattened (512, 2, 2) feature map.

    Args:
        in_channels(int): number of channels of the input x
        n_z(int): width of the inferred latent code
        weight_init(bool): apply init_weights to all sub-modules and log the net
    '''

    def __init__(self, in_channels, n_z, weight_init=True):
        super(InferenceNet, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        def down_conv(features_in, features_out):
            # 4x4 conv, stride 2, padding 1: halves the spatial resolution
            return nn.Conv2d(in_channels=features_in, out_channels=features_out,
                             kernel_size=4, stride=2, padding=1)

        self.inf02 = down_conv(in_channels, 64)
        self.inf03 = nn.BatchNorm2d(64)
        self.inf11 = down_conv(64, 128)
        self.inf12 = nn.BatchNorm2d(128)
        self.inf21 = down_conv(128, 256)
        self.inf22 = nn.BatchNorm2d(256)
        self.inf31 = down_conv(256, 512)
        self.inf32 = nn.BatchNorm2d(512)
        self.inf4 = nn.Linear(in_features=512*2*2, out_features=n_z)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, x):
        '''
        Args:
            x: image batch; the flatten step fits 32x32 inputs (32 -> 16 -> 8 -> 4 -> 2)

        Returns:
            inferred latent code, shape (bs, n_z)
        '''
        stages = (
            (self.inf02, self.inf03),
            (self.inf11, self.inf12),
            (self.inf21, self.inf22),
            (self.inf31, self.inf32),
        )
        for conv, norm in stages:
            x = F.leaky_relu(norm(conv(x)))

        flat = x.view(-1, 512*2*2)
        return self.inf4(flat)


#### Classifier
The classifier is an inference network $C: x \rightarrow y$ which approximates the posterior $p(y\mid x)$ as $y \sim p_c(y \mid x) =C(x)$. In the case of the cifar-10 dataset, it is a 10-way classifier. 

In [None]:
# classifier module
class ClassifierNet(nn.Module):
    '''
    Classifier net C(x): maps an image batch to class probabilities.

    Input noise, then three conv stages (3 x 128 channels, 3 x 256 channels,
    512 -> 256 -> 128 channels); every conv is followed by a leaky ReLU and a
    MeanOnlyBatchNorm. The first two stages end in 2x2 max pooling and 2D
    dropout. A linear layer with 10 outputs and a softmax close the net.

    The MeanOnlyBatchNorm shapes fit 32x32 inputs (32 -> 16 -> 8 -> 6).

    NOTE(review): convWN1 is one instance applied after conv1a, conv1b and
    conv1c (likewise convWN2 after conv2a/b/c) -- confirm the sharing is
    intended.
    NOTE(review): forward() returns softmax probabilities, and callers in this
    file feed them to losses['ce']; if that is nn.CrossEntropyLoss, softmax is
    effectively applied twice -- confirm.

    Args:
        in_channels(int): number of channels of the input x
        weight_init(bool): apply init_weights to all sub-modules and log the net
    '''

    def __init__(self, in_channels, weight_init=True):
        super(ClassifierNet, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        def conv3x3(features_in, features_out):
            # resolution-preserving 3x3 conv
            return nn.Conv2d(in_channels=features_in, out_channels=features_out,
                             kernel_size=3, stride=1, padding=1)

        def conv1x1(features_in, features_out):
            return nn.Conv2d(in_channels=features_in, out_channels=features_out,
                             kernel_size=1, stride=1, padding=0)

        self.gaussian = Gaussian_NoiseLayer()

        # ---- stage 1: 128 channels at 32x32
        self.conv1a = conv3x3(in_channels, 128)
        self.convWN1 = MeanOnlyBatchNorm([1, 128, 32, 32])
        self.conv1b = conv3x3(128, 128)
        self.conv_relu = nn.LeakyReLU(negative_slope=0.1)  # shared by all convs
        self.conv1c = conv3x3(128, 128)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.dropout1 = nn.Dropout2d(p=0.5)

        # ---- stage 2: 256 channels at 16x16
        self.conv2a = conv3x3(128, 256)
        self.convWN2 = MeanOnlyBatchNorm([1, 256, 16, 16])
        self.conv2b = conv3x3(256, 256)
        self.conv2c = conv3x3(256, 256)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout2d(p=0.5)

        # ---- stage 3: unpadded 3x3 conv (8x8 -> 6x6), then two 1x1 convs
        self.conv3a = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3,
                                stride=1, padding=0)  # output[6,6]
        self.convWN3a = MeanOnlyBatchNorm([1, 512, 6, 6])
        self.conv3b = conv1x1(512, 256)
        self.convWN3b = MeanOnlyBatchNorm([1, 256, 6, 6])
        self.conv3c = conv1x1(256, 128)
        self.convWN3c = MeanOnlyBatchNorm([1, 128, 6, 6])

        # pools to a fixed 6x6 map
        self.conv_globalpool = nn.AdaptiveAvgPool2d(6)

        self.dense = nn.Linear(in_features=128 * 6 * 6, out_features=10)
        self.smx = nn.Softmax()

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, x, cuda):
        '''
        Args:
            x: image batch
            cuda(bool): cuda flag, forwarded to the noise layer

        Returns:
            softmax output of the dense layer, shape (bs, 10)
        '''
        act = self.conv_relu

        x = self.gaussian(x, cuda=cuda)

        # stage 1: every conv is followed by the same convWN1 instance
        for conv in (self.conv1a, self.conv1b, self.conv1c):
            x = self.convWN1(act(conv(x)))
        x = self.dropout1(self.conv_maxpool1(x))

        # stage 2: every conv is followed by the same convWN2 instance
        for conv in (self.conv2a, self.conv2b, self.conv2c):
            x = self.convWN2(act(conv(x)))
        x = self.dropout2(self.conv_maxpool2(x))

        # stage 3: each conv has its own normalisation layer
        stage3 = (
            (self.conv3a, self.convWN3a),
            (self.conv3b, self.convWN3b),
            (self.conv3c, self.convWN3c),
        )
        for conv, norm in stage3:
            x = norm(act(conv(x)))

        flat = self.conv_globalpool(x).view(-1, 128 * 6 * 6)
        return self.smx(self.dense(flat))

### Training Process
The training process of structured GAN consists of pretraining the classifier, and then training the classifier, discriminator, inferentor, and generator in succession. It involves training two adversarial games $\mathcal{L}_{xy}$ and $\mathcal{L}_{xz}$, and two collaborative games $\mathcal{R}_y$ and $\mathcal{R}_z$. In the following subsections, training details for the different networks, such as the loss functions and optimizers used, are described. 

In [None]:
def train_gan(discriminator1, discriminator2, generator, inferentor, classifier, whitener,
              x_labelled, x_unlabelled, y_labelled, p_u_d, p_u_i,
              num_classes, batch_size, num_batches_u,
              batch_c, batch_l, batch_g,
              n_z, optimizers, losses, rng, cuda=False):

    '''
    Run one pass of discriminator, inferentor and generator updates over
    num_batches_u mini-batches.

    Args:
        discriminator1(DConvNet1): Discriminator instance xy
        discriminator2(DConvNet2): Discriminator instance xz
        generator(Generator): Generator instance
        inferentor(InferenceNet): Inference Net instance
        classifier(ClassifierNet): Classifier Net instance
        whitener(ZCA): ZCA instance
        x_labelled: labelled input data (sliced per batch with batch_l)
        x_unlabelled: unlabelled input data (passed on unsliced)
        y_labelled: labels corresponding to x_labelled
        p_u_d: index sequence into x_unlabelled, sliced per batch with batch_c
               (discriminator step)
        p_u_i: index sequence into x_unlabelled, sliced per batch with batch_size
               (discriminator and inferentor steps)
        num_classes(int): number of target classes
        batch_size(int): size of mini-batch
        num_batches_u(int): number of mini-batches to process
        batch_c(int): number of p_u_d indexes used per batch
        batch_l(int): number of labelled samples used per batch
        batch_g(int): number of (y_real, z_real) samples drawn per batch
        n_z(int): width of the noise vectors z_real / z_rand
        optimizers(dict): optimizer instances under the keys 'dis', 'gen', 'inf'
        losses(dict): loss instances under the keys 'bce', 'mse', 'ce'
        rng: accepted for interface compatibility; not used by this function
             (random draws use the global numpy / torch generators)
        cuda(bool): cuda flag

    Returns:
        dict with the keys 'dis', 'inf' and 'gen' holding the losses of the
        last processed mini-batch, or None when num_batches_u is 0

    '''

    # Labels fed to the generator are deterministic, so build them once:
    # every class repeated int(batch_size / num_classes) times.
    sample_y = Variable(torch.from_numpy(
        np.int32(np.repeat(np.arange(num_classes), int(batch_size/num_classes)))))
    if cuda:
        sample_y = sample_y.cuda()

    gan_loss = None

    for i in range(num_batches_u):
        # the labelled set is smaller than the unlabelled one: cycle through it
        i_l = i % (x_labelled.shape[0] // batch_l)

        from_u_i = i*batch_size  # unlabelled inferentor slice
        to_u_i = (i+1)*batch_size
        from_u_d = i*batch_c    # unlabelled discriminator slice
        to_u_d = (i+1) * batch_c
        from_l = i_l*batch_l    # labelled
        to_l = (i_l+1)*batch_l

        # create samples and labels
        # (class draw bounded by num_classes; it was hard-coded to 10)
        y_real = torch.from_numpy(np.int32(np.random.randint(num_classes, size=batch_g)))
        z_real = torch.from_numpy(np.random.uniform(size=(batch_g, n_z)).astype(np.float32))
        z_rand = torch.rand((batch_size*n_z)).view(batch_size, n_z)

        y_real, z_real, z_rand = Variable(y_real), Variable(z_real), Variable(z_rand)
        if cuda:
            y_real, z_real, z_rand = y_real.cuda(), z_real.cuda(), z_rand.cuda()

        dis_losses = train_discriminator(discriminator1=discriminator1,
                                         discriminator2=discriminator2,
                                         generator=generator,
                                         inferentor=inferentor,
                                         classificator=classifier,
                                         whitener=whitener,
                                         x_labelled=x_labelled[from_l:to_l],  # sym_x_l
                                         x_unlabelled=x_unlabelled,
                                         y_labelled=y_labelled[from_l:to_l],  # sym_y
                                         slice_x_dis=p_u_d[from_u_d:to_u_d],  # slice_x_u_d
                                         y_real=y_real,  # sym_y_m
                                         z_real=z_real,  # sym_z_m
                                         slice_x_inf=p_u_i[from_u_i:to_u_i],  # slice_x_u_i
                                         sample_y=sample_y,  # sym_y_g
                                         z_rand=z_rand,
                                         batch_size=batch_size,
                                         optimizer=optimizers['dis'],
                                         loss=losses['bce'],
                                         cuda=cuda)

        inf_losses = train_inferentor(x_unlabelled=x_unlabelled,
                                      sample_y=sample_y,
                                      generator=generator,
                                      z_rand=z_rand,
                                      discriminator2=discriminator2,
                                      inferentor=inferentor,
                                      mse=losses['mse'],
                                      bce=losses['bce'],
                                      slice_x_u_i=p_u_i[from_u_i:to_u_i],
                                      optimizer=optimizers['inf'],
                                      cuda=cuda)

        gen_losses = train_generator(whitener=whitener,
                                     optimizer=optimizers['gen'],
                                     BCE_loss=losses['bce'],
                                     MSE_loss=losses['mse'],
                                     cross_entropy_loss=losses['ce'],
                                     discriminator1=discriminator1,
                                     discriminator2=discriminator2,
                                     inferentor=inferentor,
                                     generator=generator,
                                     classifier=classifier,
                                     sample_y=sample_y,
                                     z_rand=z_rand,
                                     cuda=cuda)

        gan_loss = {
            'dis': dis_losses,
            'inf': inf_losses,
            'gen': gen_losses
        }

    # The return used to sit inside the loop body, which ended training after
    # the first mini-batch; it now runs after all num_batches_u batches.
    return gan_loss




#### Pretraining
Pretraining aims to obtain relatively good initial weights for the later training phase. It is achieved by minimizing the reconstruction error of $y$ in terms of $C$, on both labeled data $X_l$ and generated data:  $$\min_{C,G} \mathcal{R}_y = - \mathbb{E}_{(x,y)\sim p(x,y)}[\log p_c(y\mid x)]$$

Adam optimizer is used for training.

In [None]:
def pretrain_classifier(x_labelled, x_unlabelled, y_labelled, eval_x, eval_y, num_batches_l,
                        batch_size, num_batches_u, classifier, whitener, losses, rng, cuda):
    '''
    Pre-train the classifier for one pass over num_batches_u mini-batches.

    Every step minimises the supervised loss on a labelled batch plus 100 times
    a consistency loss between two classifier passes over the same unlabelled
    batch (the first pass is detached and used as the target).

    Args:
        x_labelled(ndarray): labelled input data
        x_unlabelled(ndarray): unlabelled input data
        y_labelled(ndarray): labels corresponding to x_labelled
        eval_x: accepted for interface compatibility; not used
        eval_y: accepted for interface compatibility; not used
        num_batches_l(int): number of labelled mini-batches (cycled through)
        batch_size(int): size of mini-batch
        num_batches_u(int): number of optimisation steps / unlabelled mini-batches
        classifier: Classifier Net instance, called as classifier(x, cuda=cuda)
        whitener: object whose apply(x) is run on every batch before classification
        losses(dict): loss callables under the keys 'ce' and 'mse'
        rng: numpy RandomState used to permute the labelled and unlabelled data
        cuda(bool): cuda flag

    Returns:
        the classifier that was passed in (updated in place)
    '''

    # randomly permute data and labels
    permutation_labelled = rng.permutation(x_labelled.shape[0])
    x_labelled = x_labelled[permutation_labelled]
    y_labelled = y_labelled[permutation_labelled]
    permutation_unlabelled = rng.permutation(x_unlabelled.shape[0]).astype('int32')

    x_labelled = Variable(torch.from_numpy(x_labelled))
    y_labelled = Variable(torch.from_numpy(y_labelled))
    x_unlabelled = Variable(torch.from_numpy(x_unlabelled))

    if cuda:
        x_labelled, y_labelled, x_unlabelled = \
            x_labelled.cuda(), y_labelled.cuda(), x_unlabelled.cuda()

    # One optimizer for the whole pass: re-creating Adam in every iteration
    # discarded its moment estimates after each step.
    cla_optimizer = optim.Adam(classifier.parameters(), betas=(0.9, 0.999), lr=3e-3)  # they implement robust adam

    for i in range(num_batches_u):
        # the labelled set is smaller than the unlabelled one: cycle through it
        i_c = i % num_batches_l
        x_l = x_labelled[i_c * batch_size:(i_c + 1) * batch_size]
        x_l_zca = whitener.apply(x_l)
        y = y_labelled[i_c * batch_size:(i_c + 1) * batch_size].long()

        # supervised loss on the labelled batch
        cla_out_y_l = classifier(x_l_zca, cuda=cuda)
        cla_cost_l = losses['ce'](cla_out_y_l, y)

        batch_slicer = torch.from_numpy(permutation_unlabelled[i * batch_size:(i + 1) * batch_size]).type(torch.LongTensor)
        if cuda:
            batch_slicer = batch_slicer.cuda()
        x_u = x_unlabelled[batch_slicer]

        # first pass over the unlabelled batch: detached, serves as the target
        x_u_rep_zca = whitener.apply(x_u)
        target = classifier(x_u_rep_zca, cuda=cuda).detach()

        # second pass over the same unlabelled batch: carries the gradient
        x_u_zca = whitener.apply(x_u)
        cla_out_y = classifier(x_u_zca, cuda=cuda)

        # consistency loss between both passes
        cla_cost_u = 100 * losses['mse'](cla_out_y, target)

        # sum losses
        pretrain_cost = cla_cost_l + cla_cost_u

        # run optimization and update weights; gradients have to be cleared
        # first, otherwise they accumulate over all previous iterations
        cla_optimizer.zero_grad()
        pretrain_cost.backward()
        cla_optimizer.step()

    return classifier

#### Training Generator

In [None]:
def train_generator(whitener, optimizer, BCE_loss, MSE_loss, cross_entropy_loss,
                    discriminator1, discriminator2, inferentor, generator, classifier, sample_y, z_rand, cuda):
    '''
    Run one generator update step.

    The generator cost is the sum of two adversarial terms (both discriminators
    should rate the generated pairs as real) and two reconstruction terms (the
    inferentor should recover z, the classifier should recover y).

    Args:
        whitener(ZCA):      ZCA instance, applied to generated images before classification
        optimizer:          optimizer used for the update step
        BCE_loss:           binary cross entropy loss
        MSE_loss:           mean squared error loss
        cross_entropy_loss: cross entropy loss
        discriminator1(DConvNet1): Discriminator instance xy
        discriminator2(DConvNet2): Discriminator instance xz
        inferentor:         Inference net
        generator:          Generator net
        classifier:         Classification net
        sample_y:           sampled labels
        z_rand:             random z sample
        cuda(bool):         cuda flag

    Returns:
        the generator cost of this step as a numpy scalar

    '''
    def real_labels(scores):
        # all-ones targets shaped like the discriminator output
        target = Variable(torch.ones(scores.size()))
        return target.cuda() if cuda else target

    # generate a batch of samples
    fake_x = generator(z_rand, sample_y)

    # reconstruction of z through the inference net
    z_rec = inferentor(fake_x)

    # reconstruction of y through the classifier (on whitened samples)
    y_probs = classifier(whitener.apply(fake_x), cuda=cuda)

    rz = MSE_loss(z_rec, z_rand)
    sample_y = sample_y.long()
    ry = cross_entropy_loss(y_probs, sample_y)

    # adversarial terms
    score_xy = discriminator1(x=fake_x, y=sample_y)
    score_xz = discriminator2(z=z_rand, x=fake_x)

    adv_xy = BCE_loss(score_xy, real_labels(score_xy))
    adv_xz = BCE_loss(score_xz, real_labels(score_xz))

    generator_cost = adv_xy + adv_xz + rz + ry

    # optimization routines and weight updates
    optimizer.zero_grad()
    generator_cost.backward()
    optimizer.step()

    return generator_cost.cpu().data.numpy().mean()


#### Training Discriminator

In [None]:
def train_discriminator(discriminator1, discriminator2, generator, inferentor, classificator, whitener,
                        x_labelled, x_unlabelled, y_labelled,
                        slice_x_dis, y_real, z_real, slice_x_inf, sample_y, z_rand,
                        batch_size, optimizer, loss, cuda):
    '''
    Run one joint update step for both discriminators.

    Args:
        discriminator1(DConvNet1): Discriminator instance xy
        discriminator2(DConvNet2): Discriminator instance xz
        generator(Generator): Generator instance
        inferentor(InferenceNet): Inference Net instance
        classificator(ClassifierNet): Classifier Net instance
        whitener(ZCA): ZCA instance
        x_labelled(ndarray): batch of labelled input data
        x_unlabelled(ndarray): unlabelled input data, indexed with the two slices
        y_labelled(ndarray): batch of corresponding labels
        slice_x_dis: indexes to select unlabelled data for discriminator
        y_real: class labels
        z_real: generator_x_m noise input
        slice_x_inf: indexes to select unlabelled data for inference net
        sample_y: sampled labels
        z_rand: generator_x noise input
        batch_size(int): size of mini-batch; the concatenated discriminator1
            input is cut to this many samples
        optimizer(torch.optim): optimizer instance used for the update step
        loss(torch.nn.Loss): loss instance for discriminators (BCE)
        cuda(bool): cuda flag (GPU)

    Returns: summed loss of both discriminators as a numpy scalar

    Parameter Translation: Theano original --> PyTorch
        x_labelled   # sym_x_l
        y_labelled   # sym_y
        slice_x_dis  # slice_x_u_d
        y_real       # sym_y_m
        z_real       # sym_z_m
        slice_x_inf  # slice_x_u_i
        sample_y     # sym_y_g
    '''

    def as_variable(array):
        # ndarray -> tensor variable, moved to the GPU when requested
        var = Variable(torch.from_numpy(array))
        return var.cuda() if cuda else var

    def labels_like(scores, real):
        # constant targets shaped like a discriminator output: 1 = real, 0 = generated
        fill = torch.ones if real else torch.zeros
        target = Variable(fill(scores.size()))
        return target.cuda() if cuda else target

    # select the unlabelled samples of this batch; whitening runs on the ndarray
    picked_dis = x_unlabelled[slice_x_dis]          # original: sym_x_u_d
    picked_dis_zca = whitener.apply(picked_dis)     # original: sym_x_u_d_zca
    picked_inf = x_unlabelled[slice_x_inf]          # original: sym_x_u_i

    x_labelled = as_variable(x_labelled)
    y_labelled = as_variable(y_labelled)
    unlabel_dis = as_variable(picked_dis)
    unlabel_dis_zca = as_variable(picked_dis_zca)
    unlabel_inf = as_variable(picked_inf)

    # generate samples
    gen_out_x = generator(z=z_rand, y=sample_y)
    gen_out_x_m = generator(z=z_real, y=y_real)

    # infer z for the unlabelled samples
    inf_z = inferentor(unlabel_inf)

    # predicted class = arg max of the classifier output
    _, cla_out_idx = classificator(unlabel_dis_zca, cuda=cuda).max(dim=1)

    # pairs labelled as real for discriminator 1: labelled data, classified
    # unlabelled data and generated samples, cut to batch_size
    x_in = torch.cat([x_labelled, unlabel_dis, gen_out_x_m], dim=0)[:batch_size]
    y_in = torch.cat([y_labelled.long(), cla_out_idx, y_real.long()], dim=0)[:batch_size]

    # discriminator scores
    dis1_out_p = discriminator1(x=x_in, y=y_in)
    dis1_out_pg = discriminator1(x=gen_out_x, y=sample_y)
    dis2_out_p = discriminator2(z=inf_z, x=unlabel_inf)
    dis2_out_pg = discriminator2(z=z_rand, x=gen_out_x)

    # per-discriminator cost: "real" term plus "generated" term
    dis1_cost = (loss(dis1_out_p, labels_like(dis1_out_p, real=True))
                 + loss(dis1_out_pg, labels_like(dis1_out_pg, real=False)))
    dis2_cost = (loss(dis2_out_p, labels_like(dis2_out_p, real=True))
                 + loss(dis2_out_pg, labels_like(dis2_out_pg, real=False)))
    total_cost = dis1_cost + dis2_cost

    # optimization routines and weight updates
    optimizer.zero_grad()
    total_cost.backward()
    optimizer.step()

    return total_cost.cpu().data.numpy().mean()


#### Training Classifier

In [None]:
def train_classifier(x_labelled, y_labelled, x_unlabelled, num_batches_u, eval_epoch,
                     size_l, size_u, size_g, n_z, whitener, classifier, p_u,
                     unsup_weight, losses, generator, w_g, cla_lr, rng, b1_c, cuda):
    '''
    Run ``num_batches_u * eval_epoch`` classifier updates and return the mean loss.

    Each update minimises the sum of three terms:
      * cross-entropy on a labelled batch,
      * ``unsup_weight`` * MSE between two classifier passes over the same
        unlabelled batch (the first pass is detached and used as target),
      * ``w_g`` * cross-entropy on samples produced by ``generator`` for
        randomly drawn labels.

    Args:
        x_labelled: numpy array of labelled inputs; sliced into batches of ``size_l``
        y_labelled: numpy array of labels aligned with ``x_labelled``
        x_unlabelled: numpy array of unlabelled inputs; indexed through ``p_u``
        num_batches_u: multiplied with ``eval_epoch`` to get the number of updates;
            ``p_u`` is re-drawn when the unlabelled batch index reaches ``num_batches_u - 1``
        eval_epoch: multiplier for the number of updates (see ``num_batches_u``)
        size_l: labelled batch size
        size_u: unlabelled batch size
        size_g: number of generated samples per update
        n_z: length of the noise vector fed to the generator
        whitener: object whose ``apply`` is called on every batch before classification
        classifier: Classifier Net instance, called as ``classifier(x, cuda=cuda)``
        p_u: index permutation into ``x_unlabelled``
        unsup_weight: weight of the unlabelled MSE term
        losses(dict): dictionary containing respective loss instances; the keys
            'ce' and 'mse' are used here
        generator: called as ``generator(z=..., y=...)``
        w_g: weight of the generated-sample cross-entropy term
        cla_lr: Adam learning rate
        rng: object providing ``permutation`` (used to reshuffle the data)
        b1_c: Adam beta1
        cuda(bool): move tensors to the GPU when True

    Returns:
        Mean of the total classifier loss over all updates, or 0.0 when no
        update was requested.

    '''

    epochs = num_batches_u * eval_epoch
    if epochs <= 0:
        # nothing to train on; also avoids dividing by zero in the final average
        return 0.0

    # Build the optimizer once per call. Re-creating Adam inside the loop would
    # reset its running moment estimates on every single mini-batch.
    # NOTE: the state is still discarded between calls to this function.
    cla_optimizer = optim.Adam(classifier.parameters(), betas=(b1_c, 0.999), lr=cla_lr)

    running_cla_cost = 0.0

    for i in range(epochs):

        # batch indices wrap around independently for labelled / unlabelled data
        i_l = i % (x_labelled.shape[0] // size_l)
        i_u = i % (x_unlabelled.shape[0] // size_u)

        # random labels and uniform noise for the generated samples
        y_real = np.int32(np.random.randint(10, size=size_g))
        z_real = np.random.uniform(size=(size_g, n_z)).astype(np.float32)

        x_l = x_labelled[i_l * size_l:(i_l + 1) * size_l]
        y = y_labelled[i_l * size_l:(i_l + 1) * size_l]
        x_l_zca = whitener.apply(x_l)

        # the same unlabelled batch is fetched twice: once for the target pass
        # and once for the prediction pass
        slice_x_u_c = p_u[i_u*size_u:(i_u+1)*size_u]
        x_u_rep = x_unlabelled[slice_x_u_c]
        x_u = x_unlabelled[slice_x_u_c]
        x_u_rep_zca = whitener.apply(x_u_rep)
        x_u_zca = whitener.apply(x_u)

        # convert to torch tensor variable
        y_real = Variable(torch.from_numpy(y_real))
        z_real = Variable(torch.from_numpy(z_real))
        y = Variable(torch.from_numpy(y).type(torch.LongTensor))
        x_l_zca = Variable(torch.from_numpy(x_l_zca))
        x_u_rep_zca = Variable(torch.from_numpy(x_u_rep_zca))
        x_u_zca = Variable(torch.from_numpy(x_u_zca))
        if cuda:
            y_real, z_real, y = y_real.cuda(), z_real.cuda(), y.cuda()
            x_l_zca, x_u_rep_zca, x_u_zca = x_l_zca.cuda(), x_u_rep_zca.cuda(), x_u_zca.cuda()

        # supervised term: classify the labelled batch
        cla_out_y_l = classifier(x_l_zca, cuda=cuda)
        cla_cost_l = losses['ce'](cla_out_y_l, y)  # size_average in pytorch is by default

        # first pass over the unlabelled batch, detached so it acts as a fixed target
        # (presumably the classifier is stochastic so the two passes differ - TODO confirm)
        cla_out_y_rep = classifier(x_u_rep_zca, cuda=cuda)
        target = cla_out_y_rep.detach()
        del cla_out_y_rep

        # second pass over the unlabelled batch, pulled towards the target
        cla_out_y = classifier(x_u_zca, cuda=cuda)
        cla_cost_u = unsup_weight * losses['mse'](cla_out_y, target)

        # generated-sample term: the sampled labels serve as ground truth
        y_m = y_real.type(torch.LongTensor)
        z_m = z_real
        if cuda:
            y_m, z_m = y_m.cuda(), z_m.cuda()
        gen_out_x_m = generator(z=z_m, y=y_m)
        gen_out_x_m_zca = whitener.apply(gen_out_x_m)

        cla_out_y_m = classifier(gen_out_x_m_zca, cuda=cuda)
        cla_cost_g = losses['ce'](cla_out_y_m, y_m) * float(w_g)

        # sum individual losses for backward
        cla_cost = cla_cost_l + cla_cost_u + cla_cost_g

        # zero the parameter gradients, optimize and update parameters
        cla_optimizer.zero_grad()
        cla_cost.backward()
        cla_optimizer.step()

        # update batch permutations once a full pass has been consumed
        if i_l == ((x_labelled.shape[0] // size_l) - 1):
            p_l = rng.permutation(x_labelled.shape[0])
            x_labelled = x_labelled[p_l]
            y_labelled = y_labelled[p_l]
        if i_u == (num_batches_u - 1):
            p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')

        running_cla_cost += cla_cost.cpu().data.numpy().mean()

    return running_cla_cost/epochs


#### Training Inferentor

In [None]:
def train_inferentor(x_unlabelled, sample_y, generator, z_rand, discriminator2, inferentor,
                     mse, bce, slice_x_u_i, optimizer, cuda):
    '''
    Perform one optimisation step for the inference network.

    The loss is the sum of
      * an adversarial term: BCE between ``discriminator2``'s output for the
        pair (inferred z, real x) and an all-zero target, and
      * a reconstruction term: MSE between the code inferred from a generated
        sample and the noise ``z_rand`` that produced that sample.

    Args:
        x_unlabelled: numpy array of unlabelled inputs
        sample_y: labels passed to the generator together with ``z_rand``
        generator: called as ``generator(z_rand, sample_y)``
        z_rand: noise the generated samples are produced from; used as the
            fixed reconstruction target
        discriminator2: called as ``discriminator2(z=..., x=...)``
        inferentor: inference network being trained
        mse: mean-squared-error loss instance
        bce: binary-cross-entropy loss instance
        slice_x_u_i: index (slice or index array) selecting the batch from ``x_unlabelled``
        optimizer: optimizer that is zeroed and stepped for this update
        cuda(bool): move tensors created here to the GPU when True

    Returns:
        The total inference loss of this step as a numpy scalar.
    '''

    x_u_i = Variable(torch.from_numpy(x_unlabelled[slice_x_u_i]))
    if cuda:
        x_u_i = x_u_i.cuda()

    # forward passes (kept in the original order)
    gen_out_x = generator(z_rand, sample_y)
    inf_z = inferentor(x_u_i)
    inf_z_g = inferentor(gen_out_x)
    disxz_out_p = discriminator2(z=inf_z, x=x_u_i)

    # Reconstruction term. The inferred code must stay attached to the graph
    # and the noise is the constant target; detaching ``inf_z_g`` instead (as
    # was done before) leaves this term without any gradient path to the
    # inferentor, so it would never be optimised.
    rz = mse(inf_z_g, z_rand.detach())

    # adversarial term against an all-zero target
    adv_target = Variable(torch.zeros(disxz_out_p.size()))
    if cuda:
        adv_target = adv_target.cuda()
    inf_cost_p_i = bce(disxz_out_p, adv_target)

    inf_cost = inf_cost_p_i + rz

    optimizer.zero_grad()
    inf_cost.backward()
    optimizer.step()
    return inf_cost.cpu().data.numpy().mean()

#### Accuracy Evaluation

In [None]:
def eval_classifier(num_batches_e, eval_x, eval_y, batch_size, whitener, classifier, cuda):
    '''
    Compute the classifier accuracy averaged over evaluation batches.

    Args:
        num_batches_e: number of consecutive batches to evaluate
        eval_x: evaluation inputs, sliced into chunks of ``batch_size``
        eval_y: evaluation labels aligned with ``eval_x``
        batch_size: number of samples per evaluation batch
        whitener: object whose ``apply`` is called on each input batch
        classifier: called as ``classifier(x, cuda=cuda)``
        cuda(bool): move the input batch to the GPU when True

    Returns:
        Mean of the per-batch accuracy scores.
    '''

    batch_scores = []
    for batch_idx in range(num_batches_e):
        start = batch_idx * batch_size
        stop = (batch_idx + 1) * batch_size

        # whiten the batch and wrap it for the network
        inputs = Variable(torch.from_numpy(whitener.apply(eval_x[start:stop])))
        if cuda:
            inputs = inputs.cuda()

        # predicted class = index of the largest classifier output per sample
        outputs = classifier(inputs, cuda=cuda).cpu().data.numpy()
        predicted = np.argmax(outputs, axis=1)

        batch_scores.append(accuracy_score(eval_y[start:stop], predicted))

    return np.mean(batch_scores)


In [None]:
### INITS ###
# Instantiate the five networks of the model. The statements are kept in this
# order on purpose: construction order determines the order in which random
# initial weights are drawn.

# GENERATOR
# NOTE(review): 110 presumably equals N_Z (100) + NUM_CLASSES (10) - TODO confirm against Generator
generator = Generator(input_size=110, num_classes=NUM_CLASSES, dense_neurons=(4 * 4 * 512))

# INFERENCE
inference = InferenceNet(in_channels=IN_CHANNELS, n_z=N_Z)

# CLASSIFIER
classifier = ClassifierNet(in_channels=IN_CHANNELS)

# DISCRIMINATOR
# discriminator1 is built from image channels and class count, discriminator2 additionally takes n_z
discriminator1 = DConvNet1(channel_in=IN_CHANNELS, num_classes=NUM_CLASSES)
discriminator2 = DConvNet2(n_z=N_Z, channel_in=IN_CHANNELS, num_classes=NUM_CLASSES)



In [None]:
# move every network to the GPU when one is available
if CUDA:
    for net in (generator, inference, classifier, discriminator1, discriminator2):
        net.cuda()


In [None]:
# ZCA whitening object built from the unlabelled data
whitener = ZCA(x=x_unlabelled)

# LOSS FUNCTIONS
# build the criteria once, then move them to the GPU if required
losses = {
    'bce': nn.BCELoss(),
    'mse': nn.MSELoss(),
    'ce': nn.CrossEntropyLoss()
}
if CUDA:
    losses = {name: criterion.cuda() for name, criterion in losses.items()}


In [None]:
### PRETRAIN CLASSIFIER ###
# Runs NUM_EPOCHS_PRE rounds of classifier pretraining before the GAN training
# starts and logs the evaluation error rate after every round.

logger.info('Start pretraining...')
for epoch in range(1, 1+NUM_EPOCHS_PRE):

    # pretrain classifier net; the returned network replaces the current one
    classifier = pretrain_classifier(x_labelled, x_unlabelled, y_labelled, eval_x, eval_y, num_batches_l,
                                     BATCH_SIZE, num_batches_u, classifier, whitener, losses, rng, CUDA)

    # evaluate on the evaluation set
    accurracy = eval_classifier(num_batches_e, eval_x, eval_y, BATCH_SIZE_EVAL, whitener, classifier, CUDA)

    # error rate is reported as 1 - accuracy
    logger.info(str(epoch) + ':Pretrain error_rate: ' + str(1 - accurracy))



In [None]:

### GAN TRAINING ###
# Alternates, once per epoch, between training the classifier (on evaluation
# epochs) and training discriminators / generator / inference net on
# num_batches_u mini-batches, then anneals the learning rates and logs the
# latest losses.

# assign start values
lr_cla = LR_CLA
lr = LR
start_full = time.time()

# OPTIMIZERS
# Built once so that Adam's running moment estimates survive across epochs;
# the learning rate is updated in place at the start of every epoch instead of
# re-creating the optimizers (which would throw that state away).
optimizers = {
    'dis': optim.Adam(list(discriminator1.parameters()) + list(discriminator2.parameters()), betas=(B1, 0.999), lr=lr),
    'gen': optim.Adam(generator.parameters(), betas=(B1, 0.999), lr=lr),
    'inf': optim.Adam(inference.parameters(), betas=(B1, 0.999), lr=lr)
}

# defined up front so the end-of-epoch report also works on epochs in which
# the classifier is not trained (EVAL_EPOCH > 1)
cla_losses = None

logger.info("Start GAN training...")
for epoch in range(1, 1+NUM_EPOCHS):

    # apply the current (possibly annealed) learning rate
    for gan_optimizer in optimizers.values():
        for param_group in gan_optimizer.param_groups:
            param_group['lr'] = lr

    # randomly permute data and labels each epoch
    p_l = rng.permutation(x_labelled.shape[0])
    x_labelled = x_labelled[p_l]
    y_labelled = y_labelled[p_l]

    # permuted slicer objects
    p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
    p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32')
    p_u_i = rng.permutation(x_unlabelled.shape[0]).astype('int32')

    # set epoch dependent values
    if epoch < (NUM_EPOCHS/2):
        if epoch % 50 == 1:
            batch_l = 200 - (epoch // 50 + 1) * 16
            batch_c = (epoch // 50 + 1) * 16
            batch_g = 1
    elif epoch < NUM_EPOCHS and epoch % 100 == 0:
        # floor division keeps the batch sizes integral under Python 3; the
        # value is unchanged because epoch is a multiple of 100 here
        batch_l = 50
        batch_c = 140 - 10 * ((epoch-500) // 100)
        batch_g = 10 + 10 * ((epoch-500) // 100)

    # if current epoch is an evaluation epoch, train classifier and report results
    if epoch % EVAL_EPOCH == 0:

        logger.info('Train classifier...')

        rampup_value = rampup(epoch-1)
        rampdown_value = rampdown(epoch-1)
        # Adam beta1 for the classifier is interpolated between 0.5 and 0.9
        b1_c = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
        # no unsupervised loss in the very first epoch
        unsup_weight = rampup_value * SCALED_UNSUP_WEIGHT_MAX if epoch > 1 else 0.0
        # weight of the generated-sample loss grows linearly and saturates at 1 after 300 epochs
        w_g = np.float32(min(float(epoch) / 300.0, 1.0))

        size_l = 100
        size_g = 100
        size_u = 100

        cla_losses = train_classifier(x_labelled=x_labelled,
                                      y_labelled=y_labelled,
                                      x_unlabelled=x_unlabelled,
                                      num_batches_u=num_batches_u,
                                      eval_epoch=EVAL_EPOCH,
                                      size_l=size_l,
                                      size_u=size_u,
                                      size_g=size_g,
                                      n_z=N_Z,
                                      whitener=whitener,
                                      classifier=classifier,
                                      p_u=p_u,
                                      unsup_weight=unsup_weight,
                                      losses=losses,
                                      generator=generator,
                                      w_g=w_g,
                                      cla_lr=lr_cla,
                                      rng=rng,
                                      b1_c=b1_c,
                                      cuda=CUDA)

        # evaluate & report
        accurracy = eval_classifier(num_batches_e, eval_x, eval_y, BATCH_SIZE_EVAL, whitener, classifier, CUDA)

        logger.info('Evaluation error_rate: %.5f\n' % (1 - accurracy))

    logger.info('Train generator, inference and discriminator model...')
    # train GAN model; only the losses of the last batch are kept for the report
    for i in range(num_batches_u):
        gan_losses = train_gan(discriminator1=discriminator1,
                               discriminator2=discriminator2,
                               generator=generator,
                               inferentor=inference,
                               classifier=classifier,
                               whitener=whitener,
                               x_labelled=x_labelled,
                               x_unlabelled=x_unlabelled,
                               y_labelled=y_labelled,
                               p_u_d=p_u_d,
                               p_u_i=p_u_i,
                               num_classes=NUM_CLASSES,
                               batch_size=BATCH_SIZE,
                               num_batches_u=num_batches_u,
                               batch_c=batch_c,
                               batch_l=batch_l,
                               batch_g=batch_g,
                               n_z=N_Z,
                               optimizers=optimizers,
                               losses=losses,
                               rng=rng,
                               cuda=CUDA)

    # anneal the learning rates
    if (epoch >= ANNEAL_EPOCH) and (epoch % ANNEAL_EVERY_EPOCH == 0):
        lr = lr * ANNEAL_FACTOR
        lr_cla *= ANNEAL_FACTOR_CLA

    # report and log training info
    t = time.time() - start_full
    line = "*Epoch=%d Time=%.2f LR=%.5f\n" % (epoch, t, lr) + "DisLosses: " + str(gan_losses['dis']) + "\nGenLosses: " + \
           str(gan_losses['gen']) + "\nInfLosses: " + str(gan_losses['inf']) + "\nClaLosses: " + str(cla_losses)
    logger.info(line)

