# NIPS Paper Implementation Challenge 
## PyTorch Implementation of the Paper "Structured Generative Adversarial Networks"


In [3]:
### IMPORTS ###
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import time
import utils
# initialize logger
import logging.config
import yaml
# Configure logging from the YAML file next to the notebook.
# yaml.safe_load is used instead of yaml.load: calling yaml.load without an
# explicit Loader is deprecated (and an error in recent PyYAML releases), and
# a logging config only needs plain YAML types.
with open('./log_config.yaml') as file:
    Dict = yaml.safe_load(file)    # load config file
    logging.config.dictConfig(Dict)    # import config

logger = logging.getLogger(__name__)
logger.info('PyTorch version: ' + str(torch.__version__))

# import SGAN utils
from layers import rampup, rampdown
from zca import ZCA
from models import Generator, InferenceNet, ClassifierNet, DConvNet1, DConvNet2
from trainGAN import pretrain_classifier, train_classifier, train_gan, eval_classifier


   INFO [16:56:14] __main__: PyTorch version: 0.2.0_3


### Global Parameter Setting

In [5]:
### GLOBAL PARAMS ###
# Notebook-wide constants; later cells read these as module globals.
BATCH_SIZE = 200  # batch size used for training and for counting train batches
BATCH_SIZE_EVAL = 200  # batch size used for evaluation batches
NUM_CLASSES = 10  # number of label classes
NUM_LABELLED = 4000  # total labelled samples kept (NUM_LABELLED / NUM_CLASSES per class)
SSL_SEED = 1  # seed of the RandomState that shuffles data before picking labelled samples
NP_SEED = 1234  # seed of the RandomState handed to the training routines
CUDA = torch.cuda.is_available()  # True when a CUDA device can be used
logger.info('Cuda = ' + str(CUDA))

# data dependent
IN_CHANNELS = 3  # number of image channels fed to the networks

# evaluation
VIS_EPOCH = 1  # presumably the visualisation interval in epochs; not used in the cells shown here
EVAL_EPOCH = 1  # the classifier is trained and evaluated every EVAL_EPOCH epochs

# C (classifier)
SCALED_UNSUP_WEIGHT_MAX = 100.0  # multiplied by the ramp-up value to get the unsupervised loss weight

# G (generator)
N_Z = 100  # size of the latent vector z

# optimization
B1 = 0.5  # moment1 in Adam
LR = 3e-4  # initial learning rate for discriminators, generator and inference net
LR_CLA = 3e-3  # initial learning rate for the classifier
NUM_EPOCHS = 1000  # number of GAN training epochs
NUM_EPOCHS_PRE = 20  # number of classifier pretraining epochs
ANNEAL_EPOCH = 200  # learning-rate annealing starts at this epoch
ANNEAL_EVERY_EPOCH = 1  # annealing is applied every this many epochs
ANNEAL_FACTOR = 0.995  # multiplicative decay applied to LR
ANNEAL_FACTOR_CLA = 0.99  # multiplicative decay applied to LR_CLA

path_out = "./results"  # output directory; not used in the cells shown here


   INFO [17:05:11] __main__: Cuda = False


### Data Preprocessing

In [4]:
### DATA ###
# Load CIFAR-10, keep the full training set as the unlabelled pool and pick a
# class-balanced labelled subset from a seeded shuffle of the training data.
logger.info('Loading data...')
train_x, train_y = utils.load('./cifar10/', 'train')
eval_x, eval_y = utils.load('./cifar10/', 'test')

train_y = np.int32(train_y)
eval_y = np.int32(eval_y)

# the unlabelled pool is a copy of the training images taken before shuffling
x_unlabelled = train_x.copy()

# seeded shuffle so the labelled subset is reproducible
rng_data = np.random.RandomState(SSL_SEED)
inds = rng_data.permutation(train_x.shape[0])
train_x = train_x[inds]
train_y = train_y[inds]

# take the first `per_class` shuffled samples of every class
per_class = int(NUM_LABELLED / NUM_CLASSES)
x_labelled = []
y_labelled = []
for j in range(NUM_CLASSES):
    class_mask = train_y == j
    x_labelled.append(train_x[class_mask][:per_class])
    y_labelled.append(train_y[class_mask][:per_class])

x_labelled = np.concatenate(x_labelled, axis=0)
y_labelled = np.concatenate(y_labelled, axis=0)
del train_x

# number of full batches per data split (incomplete trailing batches are dropped)
num_batches_l = len(x_labelled) // BATCH_SIZE
num_batches_u = len(x_unlabelled) // BATCH_SIZE
num_batches_e = len(eval_x) // BATCH_SIZE_EVAL
rng = np.random.RandomState(NP_SEED)

   INFO [16:56:47] __main__: Cuda = False
   INFO [16:56:47] __main__: Loading data...
>> Downloading cifar-10-python.tar.gz 100.0%
Successfully downloaded cifar-10-python.tar.gz 170498071 bytes.


### Model Structures
#### Generator

In [6]:
### MODEL STRUCTURES ###

# generator y2x: p_g(x, y) = p(y) p_g(x | y) where x = G(z, y), z follows p_g(z)
class Generator(nn.Module):
    """Conditional generator: x = G(z, y).

    The label ``y`` is re-injected before the dense layer (``mlp_concat``) and
    before every transposed convolution (``conv_concat``).  Both helpers are
    defined outside this block; they are assumed to append ``num_classes``
    label features/channels, which is what the ``+ num_classes`` input sizes
    below (and in ``DConvNet1``) rely on.

    Args:
        input_size: number of input features of the dense layer, i.e. the
            width of the concatenated (z, y) vector.
        num_classes: number of label classes.
        dense_neurons: output width of the dense layer; must equal
            ``512 * 4 * 4`` because the result is reshaped to (512, 4, 4).
        weight_init: if True, apply ``init_weights`` to all sub-modules and
            log the network structure at debug level.
    """

    def __init__(self, input_size, num_classes, dense_neurons, weight_init=True):
        super(Generator, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        self.num_classes = num_classes

        self.Dense = nn.Linear(input_size, dense_neurons)
        self.Relu = nn.ReLU()
        self.Tanh = nn.Tanh()

        # Input channels are the feature channels plus the label channels
        # added by conv_concat (previously hard-coded as 522 / 266 / 138,
        # which is only correct for num_classes == 10).
        self.Deconv2D_0 = nn.ConvTranspose2d(in_channels=512 + num_classes, out_channels=256,
                                             kernel_size=5, stride=2, padding=2,
                                             output_padding=1, bias=False)
        self.Deconv2D_1 = nn.ConvTranspose2d(in_channels=256 + num_classes, out_channels=128,
                                             kernel_size=5, stride=2, padding=2,
                                             output_padding=1, bias=False)
        self.Deconv2D_2 = wn(nn.ConvTranspose2d(in_channels=128 + num_classes, out_channels=3,
                                                kernel_size=5, stride=2, padding=2,
                                                output_padding=1, bias=False))

        self.BatchNorm1D = nn.BatchNorm1d(dense_neurons)

        self.BatchNorm2D_0 = nn.BatchNorm2d(256)
        self.BatchNorm2D_1 = nn.BatchNorm2d(128)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, z, y):
        """Generate a batch of images from latent codes ``z`` and labels ``y``.

        Returns a tensor of shape (batch, 3, 32, 32) squashed by tanh.
        """
        x = mlp_concat(z, y, self.num_classes)

        x = self.Dense(x)
        x = self.Relu(x)
        x = self.BatchNorm1D(x)

        # view() instead of the deprecated non-in-place resize(): it is a pure
        # reshape and fails loudly if dense_neurons != 512 * 4 * 4.
        x = x.view(z.size(0), 512, 4, 4)
        x = conv_concat(x, y, self.num_classes)

        x = self.Deconv2D_0(x)                    # output shape (256, 8, 8)
        x = self.Relu(x)
        x = self.BatchNorm2D_0(x)

        x = conv_concat(x, y, self.num_classes)

        x = self.Deconv2D_1(x)                    # output shape (128, 16, 16)
        x = self.Relu(x)
        x = self.BatchNorm2D_1(x)

        x = conv_concat(x, y, self.num_classes)
        x = self.Deconv2D_2(x)                    # output shape (3, 32, 32)
        x = self.Tanh(x)

        return x


#### Discriminator 1

In [8]:
# discriminator xy2p: tests whether an input pair (x, y) comes from p(x, y) rather than from p_c or p_g
class DConvNet1(nn.Module):
    '''
    1st convolutional discriminative net (discriminator xy2p)
    --> scores whether an (x, y) pair comes from p(x, y) rather than p_c or p_g.

    The label y is concatenated to the input of every conv layer and to the
    input of the final linear layer; the output is a sigmoid score per sample.
    '''

    def __init__(self, channel_in, num_classes, p_dropout=0.2, weight_init=True):
        super(DConvNet1, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        self.num_classes = num_classes

        def conv3x3(features_in, features_out, stride, padding):
            # weight-normalised, bias-free 3x3 convolution whose input carries
            # `num_classes` extra channels from the preceding label concat
            return wn(nn.Conv2d(in_channels=features_in + num_classes,
                                out_channels=features_out,
                                kernel_size=(3, 3), stride=stride,
                                padding=padding, bias=False))

        # shared, parameter-free layers
        self.LReLU = nn.LeakyReLU(negative_slope=0.2)
        self.sgmd = nn.Sigmoid()
        self.drop = nn.Dropout(p=p_dropout)

        # layer stack (see forward for the dropout / label-concat placement)
        self.conv1 = conv3x3(channel_in, 32, stride=(1, 1), padding=1)
        self.conv2 = conv3x3(32, 32, stride=2, padding=1)
        self.conv3 = conv3x3(32, 64, stride=(1, 1), padding=1)
        self.conv4 = conv3x3(64, 64, stride=2, padding=1)
        self.conv5 = conv3x3(64, 128, stride=(1, 1), padding=0)
        self.conv6 = conv3x3(128, 128, stride=(1, 1), padding=0)

        # pools the last feature map to 4x4
        self.globalPool = nn.AdaptiveAvgPool2d(output_size=4)

        # flattened pooled features + label -> single logit
        self.lin = nn.Linear(in_features=128 * 4 * 4 + num_classes,
                             out_features=1)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, x, y):
        # x: image batch, y: labels (concatenated via conv_concat / mlp_concat)
        n_cls = self.num_classes

        # input: dropout, then label concat
        h = conv_concat(self.drop(x), y, n_cls)

        # conv1 .. conv5, each followed by LReLU and a label concat;
        # conv2 and conv4 additionally get dropout before the concat
        h = conv_concat(self.LReLU(self.conv1(h)), y, n_cls)
        h = conv_concat(self.drop(self.LReLU(self.conv2(h))), y, n_cls)
        h = conv_concat(self.LReLU(self.conv3(h)), y, n_cls)
        h = conv_concat(self.drop(self.LReLU(self.conv4(h))), y, n_cls)
        h = conv_concat(self.LReLU(self.conv5(h)), y, n_cls)

        # last conv has no label concat afterwards
        h = self.LReLU(self.conv6(h))

        # pool, flatten, append the label and score
        pooled = self.globalPool(h).view(-1, 128 * 4 * 4)
        features = mlp_concat(pooled, y, n_cls)

        return self.sgmd(self.lin(features))


#### Discriminator 2

In [None]:
# discriminator xz
class DConvNet2(nn.Module):
    '''
    2nd convolutional discriminative net (discriminator xz).

    Scores a (z, x) pair: z goes through two linear layers, x through three
    stride-2 convolutions, and the two feature vectors are concatenated and
    mapped to a single sigmoid output.

    Args:
        n_z: size of the latent vector z.
        channel_in: number of channels of the image x.
        num_classes: stored on the module; not used by the layers here.
        weight_init: if True, apply ``init_weights`` to all sub-modules and
            log the network structure at debug level.

    The fusion layer width assumes 32x32 inputs (three stride-2 convolutions
    leave a 512 x 4 x 4 feature map).
    '''

    def __init__(self, n_z, channel_in, num_classes, weight_init=True):
        super(DConvNet2, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        self.num_classes = num_classes

        # general reusable layers:
        self.LReLU = nn.LeakyReLU(negative_slope=0.2)  # leaky ReLU activation function
        self.sgmd = nn.Sigmoid()  # sigmoid activation function

        # z input -->
        self.lin_z0 = nn.Linear(in_features=n_z,
                                out_features=512)
        # LReLU

        self.lin_z1 = nn.Linear(in_features=512,
                                out_features=512)
        # LReLU

        # -------------------------------------

        # x input -->
        self.conv_x0 = nn.Conv2d(in_channels=channel_in, out_channels=128,
                                 kernel_size=(5, 5), stride=2, padding=2, bias=False)
        # LReLU

        self.conv_x1 = nn.Conv2d(in_channels=128, out_channels=256,
                                 kernel_size=(5, 5), stride=2, padding=2, bias=False)
        # LReLU
        self.bn1 = nn.BatchNorm2d(num_features=256)

        self.conv_x2 = nn.Conv2d(in_channels=256, out_channels=512,
                                 kernel_size=(5, 5), stride=2, padding=2, bias=False)
        # LReLU
        self.bn2 = nn.BatchNorm2d(num_features=512)

        # -------------------------------------

        # concat x & z -->
        # 512 * 4 * 4 flattened image features + 512 latent features (= 8704)
        self.lin_f0 = nn.Linear(in_features=512 * 4 * 4 + 512,
                                out_features=1024)
        # LReLU

        self.lin_f1 = nn.Linear(in_features=1024,
                                out_features=1)
        # sigmoid

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, z, x):
        # x: (bs, channel_in, height, width)
        # z: (bs, n_z)

        # latent branch
        z0 = self.LReLU(self.lin_z0(z))
        z_out = self.LReLU(self.lin_z1(z0))

        # image branch (batch norm is applied after the activation)
        x0 = self.LReLU(self.conv_x0(x))
        x1 = self.LReLU(self.conv_x1(x0))
        x1 = self.bn1(x1)
        x_out = self.LReLU(self.conv_x2(x1))
        x_out = self.bn2(x_out)

        # Flatten to (bs, features).  The former trailing squeeze(-1) calls
        # were no-ops on this 2-D tensor and would have dropped the feature
        # axis if it ever had size 1, so they are gone.
        x_flat = x_out.view(x_out.size(0), -1)
        fusion = torch.cat([x_flat, z_out], dim=1)

        f_out = self.LReLU(self.lin_f0(fusion))
        out = self.sgmd(self.lin_f1(f_out))

        return out

#### Classifier

In [9]:
# classifier module
class ClassifierNet(nn.Module):
    """Convolutional classifier producing softmax class probabilities.

    Three convolutional stages (the first two followed by max-pooling and
    dropout), an adaptive average pool to 6x6 and a dense softmax layer with
    10 outputs.  Input noise and normalisation come from
    ``Gaussian_NoiseLayer`` and ``MeanOnlyBatchNorm``, defined elsewhere.
    One normalisation module is shared by all convolutions of stage 1, and
    one by all convolutions of stage 2.
    """

    def __init__(self, in_channels, weight_init=True):
        super(ClassifierNet, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        def conv(features_in, features_out, kernel_size, padding):
            # stride-1 convolution; only kernel size and padding vary
            return nn.Conv2d(in_channels=features_in, out_channels=features_out,
                             kernel_size=kernel_size, stride=1, padding=padding)

        self.gaussian = Gaussian_NoiseLayer()

        # stage 1: 128 feature maps
        self.conv1a = conv(in_channels, 128, kernel_size=3, padding=1)
        self.convWN1 = MeanOnlyBatchNorm([1, 128, 32, 32])
        self.conv1b = conv(128, 128, kernel_size=3, padding=1)
        self.conv_relu = nn.LeakyReLU(negative_slope=0.1)
        self.conv1c = conv(128, 128, kernel_size=3, padding=1)
        self.conv_maxpool1 = nn.MaxPool2d(kernel_size=2)
        self.dropout1 = nn.Dropout2d(p=0.5)

        # stage 2: 256 feature maps
        self.conv2a = conv(128, 256, kernel_size=3, padding=1)
        self.convWN2 = MeanOnlyBatchNorm([1, 256, 16, 16])
        self.conv2b = conv(256, 256, kernel_size=3, padding=1)
        self.conv2c = conv(256, 256, kernel_size=3, padding=1)
        self.conv_maxpool2 = nn.MaxPool2d(kernel_size=2)
        self.dropout2 = nn.Dropout2d(p=0.5)

        # stage 3: unpadded 3x3 conv, then two 1x1 convs
        self.conv3a = conv(256, 512, kernel_size=3, padding=0)
        self.convWN3a = MeanOnlyBatchNorm([1, 512, 6, 6])
        self.conv3b = conv(512, 256, kernel_size=1, padding=0)
        self.convWN3b = MeanOnlyBatchNorm([1, 256, 6, 6])
        self.conv3c = conv(256, 128, kernel_size=1, padding=0)
        self.convWN3c = MeanOnlyBatchNorm([1, 128, 6, 6])

        self.conv_globalpool = nn.AdaptiveAvgPool2d(6)

        # classification head
        self.dense = nn.Linear(in_features=128 * 6 * 6, out_features=10)
        self.smx = nn.Softmax()

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, x, cuda):
        """Return class probabilities for the image batch ``x``.

        ``cuda`` is forwarded to the Gaussian noise layer.
        """
        h = self.gaussian(x, cuda=cuda)

        # stage 1: conv -> leaky ReLU -> shared normalisation, three times
        for conv in (self.conv1a, self.conv1b, self.conv1c):
            h = self.convWN1(self.conv_relu(conv(h)))
        h = self.dropout1(self.conv_maxpool1(h))

        # stage 2: same pattern with its own shared normalisation
        for conv in (self.conv2a, self.conv2b, self.conv2c):
            h = self.convWN2(self.conv_relu(conv(h)))
        h = self.dropout2(self.conv_maxpool2(h))

        # stage 3: each conv has a dedicated normalisation module
        stage3 = ((self.conv3a, self.convWN3a),
                  (self.conv3b, self.convWN3b),
                  (self.conv3c, self.convWN3c))
        for conv, norm in stage3:
            h = norm(self.conv_relu(conv(h)))

        # pool, flatten and classify
        flat = self.conv_globalpool(h).view(-1, 128 * 6 * 6)
        return self.smx(self.dense(flat))

#### Inference Network

In [10]:
class InferenceNet(nn.Module):
    """Inference network mapping an image x to a latent code z.

    Four stride-2 convolutions (each followed by batch norm and leaky ReLU)
    and a final linear layer.  The linear layer expects a 512 x 2 x 2 feature
    map, which is what 32x32 inputs produce.

    Args:
        in_channels: number of channels of the input image.
        n_z: size of the inferred latent vector.
        weight_init: if True, apply ``init_weights`` to all sub-modules and
            log the network structure at debug level.
    """

    def __init__(self, in_channels, n_z, weight_init=True):
        super(InferenceNet, self).__init__()

        self.logger = logging.getLogger(__name__)  # initialize logger

        self.inf02 = nn.Conv2d(in_channels=in_channels, out_channels=64, kernel_size=4,
                               stride=2, padding=1)
        self.inf03 = nn.BatchNorm2d(64)
        self.inf11 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4,
                               stride=2, padding=1)
        self.inf12 = nn.BatchNorm2d(128)
        self.inf21 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4,
                               stride=2, padding=1)
        self.inf22 = nn.BatchNorm2d(256)
        self.inf31 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4,
                               stride=2, padding=1)
        self.inf32 = nn.BatchNorm2d(512)
        self.inf4 = nn.Linear(in_features=512*2*2, out_features=n_z)

        if weight_init:
            # initialize weights for all conv and lin layers
            self.apply(init_weights)
            # log network structure
            self.logger.debug(self)

    def forward(self, x):
        """Return the inferred latent code of shape (batch, n_z)."""
        # `F` is not imported in this file; go through the already imported
        # `nn` package so forward() does not raise NameError.
        leaky_relu = nn.functional.leaky_relu

        x = leaky_relu(self.inf03(self.inf02(x)))
        x = leaky_relu(self.inf12(self.inf11(x)))
        x = leaky_relu(self.inf22(self.inf21(x)))
        x = leaky_relu(self.inf32(self.inf31(x)))
        x = x.view(-1, 512*2*2)
        x = self.inf4(x)

        return x


In [None]:
### INITS ###
# Instantiate all five networks from the global parameters.

# GENERATOR
# The dense input is the latent vector with the label features appended, so
# its size is derived from N_Z and NUM_CLASSES instead of a hard-coded 110.
generator = Generator(input_size=N_Z + NUM_CLASSES, num_classes=NUM_CLASSES,
                      dense_neurons=(4 * 4 * 512))

# INFERENCE
inference = InferenceNet(in_channels=IN_CHANNELS, n_z=N_Z)

# CLASSIFIER
classifier = ClassifierNet(in_channels=IN_CHANNELS)

# DISCRIMINATOR
discriminator1 = DConvNet1(channel_in=IN_CHANNELS, num_classes=NUM_CLASSES)
discriminator2 = DConvNet2(n_z=N_Z, channel_in=IN_CHANNELS, num_classes=NUM_CLASSES)



In [None]:
# move every network to the GPU when CUDA is available
if CUDA:
    for net in (generator, inference, classifier, discriminator1, discriminator2):
        net.cuda()


In [None]:
# ZCA whitening built from the unlabelled images
whitener = ZCA(x=x_unlabelled)

# LOSS FUNCTIONS
# build the criteria once, then move them to the GPU if one is in use
losses = {
    'bce': nn.BCELoss(),
    'mse': nn.MSELoss(),
    'ce': nn.CrossEntropyLoss()
}
if CUDA:
    losses = {name: criterion.cuda() for name, criterion in losses.items()}


In [None]:
### PRETRAIN CLASSIFIER ###
# Pretrain the classifier for NUM_EPOCHS_PRE epochs, logging the error rate
# on the evaluation set after each one.

logger.info('Start pretraining...')
for epoch in range(1, NUM_EPOCHS_PRE + 1):

    # the helper returns the classifier that is carried into the next epoch
    classifier = pretrain_classifier(
        x_labelled, x_unlabelled, y_labelled, eval_x, eval_y, num_batches_l,
        BATCH_SIZE, num_batches_u, classifier, whitener, losses, rng, CUDA)

    # accuracy on the evaluation set, reported as an error rate
    accurracy = eval_classifier(
        num_batches_e, eval_x, eval_y, BATCH_SIZE_EVAL, whitener, classifier, CUDA)
    error_rate = 1 - accurracy

    logger.info(str(epoch) + ':Pretrain error_rate: ' + str(error_rate))



In [None]:

### GAN TRAINING ###
# Main loop: every epoch reshuffles the data, updates the batch composition,
# optionally trains/evaluates the classifier, trains the GAN networks,
# anneals the learning rates and logs a summary.

# assign start values
lr_cla = LR_CLA
lr = LR
start_full = time.time()

# OPTIMIZERS
# Built once, outside the epoch loop: re-creating Adam every epoch throws
# away its running moment estimates.  The (annealed) learning rate is pushed
# into the existing optimizers at the start of each epoch instead.
optimizers = {
    'dis': optim.Adam(list(discriminator1.parameters()) + list(discriminator2.parameters()), betas=(B1, 0.999), lr=lr),
    'gen': optim.Adam(generator.parameters(), betas=(B1, 0.999), lr=lr),
    'inf': optim.Adam(inference.parameters(), betas=(B1, 0.999), lr=lr)
}

# defined up front so the epoch report never hits an unbound name when an
# epoch is not an evaluation epoch
cla_losses = None

logger.info("Start GAN training...")
for epoch in range(1, 1+NUM_EPOCHS):

    # apply the current learning rate to all GAN optimizers
    for optimizer in optimizers.values():
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    # randomly permute data and labels each epoch
    p_l = rng.permutation(x_labelled.shape[0])
    x_labelled = x_labelled[p_l]
    y_labelled = y_labelled[p_l]

    # permuted slicer objects
    p_u = rng.permutation(x_unlabelled.shape[0]).astype('int32')
    p_u_d = rng.permutation(x_unlabelled.shape[0]).astype('int32')
    p_u_i = rng.permutation(x_unlabelled.shape[0]).astype('int32')

    # set epoch dependent values (batch composition passed to train_gan)
    if epoch < (NUM_EPOCHS/2):
        if epoch % 50 == 1:
            batch_l = 200 - (epoch // 50 + 1) * 16
            batch_c = (epoch // 50 + 1) * 16
            batch_g = 1
    elif epoch < NUM_EPOCHS and epoch % 100 == 0:
        # integer division keeps the batch sizes ints; true division would
        # turn them into floats under Python 3
        step = (epoch - 500) // 100
        batch_l = 50
        batch_c = 140 - 10 * step
        batch_g = 10 + 10 * step

    # if current epoch is an evaluation epoch, train classifier and report results
    if epoch % EVAL_EPOCH == 0:

        logger.info('Train classifier...')

        rampup_value = rampup(epoch-1)
        rampdown_value = rampdown(epoch-1)
        # Adam beta1 for the classifier, blended between 0.9 and 0.5
        b1_c = rampdown_value * 0.9 + (1.0 - rampdown_value) * 0.5
        # no unsupervised loss in the very first epoch
        unsup_weight = rampup_value * SCALED_UNSUP_WEIGHT_MAX if epoch > 1 else 0.0
        # weight passed as w_g; grows linearly to 1.0 over the first 300 epochs
        w_g = np.float32(min(float(epoch) / 300.0, 1.0))

        size_l = 100
        size_g = 100
        size_u = 100

        cla_losses = train_classifier(x_labelled=x_labelled,
                                      y_labelled=y_labelled,
                                      x_unlabelled=x_unlabelled,
                                      num_batches_u=num_batches_u,
                                      eval_epoch=EVAL_EPOCH,
                                      size_l=size_l,
                                      size_u=size_u,
                                      size_g=size_g,
                                      n_z=N_Z,
                                      whitener=whitener,
                                      classifier=classifier,
                                      p_u=p_u,
                                      unsup_weight=unsup_weight,
                                      losses=losses,
                                      generator=generator,
                                      w_g=w_g,
                                      cla_lr=lr_cla,
                                      rng=rng,
                                      b1_c=b1_c,
                                      cuda=CUDA)

        # evaluate & report
        accurracy = eval_classifier(num_batches_e, eval_x, eval_y, BATCH_SIZE_EVAL, whitener, classifier, CUDA)

        logger.info('Evaluation error_rate: %.5f\n' % (1 - accurracy))

    logger.info('Train generator, inference and discriminator model...')
    # train GAN model
    # NOTE(review): the loop index is not passed to train_gan, which also
    # receives num_batches_u -- confirm whether train_gan already iterates
    # over the batches itself.
    for i in range(num_batches_u):
        gan_losses = train_gan(discriminator1=discriminator1,
                               discriminator2=discriminator2,
                               generator=generator,
                               inferentor=inference,
                               classifier=classifier,
                               whitener=whitener,
                               x_labelled=x_labelled,
                               x_unlabelled=x_unlabelled,
                               y_labelled=y_labelled,
                               p_u_d=p_u_d,
                               p_u_i=p_u_i,
                               num_classes=NUM_CLASSES,
                               batch_size=BATCH_SIZE,
                               num_batches_u=num_batches_u,
                               batch_c=batch_c,
                               batch_l=batch_l,
                               batch_g=batch_g,
                               n_z=N_Z,
                               optimizers=optimizers,
                               losses=losses,
                               rng=rng,
                               cuda=CUDA)

    # anneal the learning rates
    if (epoch >= ANNEAL_EPOCH) and (epoch % ANNEAL_EVERY_EPOCH == 0):
        lr = lr * ANNEAL_FACTOR
        lr_cla *= ANNEAL_FACTOR_CLA

    # report and log training info (losses are those of the last GAN batch)
    t = time.time() - start_full
    line = "*Epoch=%d Time=%.2f LR=%.5f\n" % (epoch, t, lr) + "DisLosses: " + str(gan_losses['dis']) + "\nGenLosses: " + \
           str(gan_losses['gen']) + "\nInfLosses: " + str(gan_losses['inf']) + "\nClaLosses: " + str(cla_losses)
    logger.info(line)

