# Task 1 - Generative Models

## Introduction

This is a research environment for Task 1 - Generative Models. The implementations of the networks are adapted from William Beng's implementation of the BiGAN network.
[[Paper]](https://arxiv.org/abs/1605.09782) [[Original implementation]](https://github.com/WilliBee/bigan_SRL)

Additionally, this relies on Cyrille Rossant's post 'An illustrated introduction to the t-SNE algorithm' for t-SNE visualisation utilities. [[Post]](https://www.oreilly.com/learning/an-illustrated-introduction-to-the-t-sne-algorithm)

Triplet data processing and triplet loss calculations are based on Adam Bielski's implementation of triplet networks. [[Implementation]](https://github.com/adambielski/siamese-triplet/blob/master/losses.py)
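For reference, the triplet loss in Bielski's repository is a margin-based hinge on squared Euclidean distances between embeddings. The sketch below restates it in NumPy. The function name `triplet_loss` and the default `margin=1.0` are assumptions for illustration, not values taken from this notebook.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin=1.0):
    """Mean hinge loss max(0, d(a, p) - d(a, n) + margin), where d is the
    squared Euclidean distance computed row-wise over a batch of embeddings.
    The margin value is an illustrative assumption."""
    d_pos = np.sum((anchor - positive) ** 2, axis=1)
    d_neg = np.sum((anchor - negative) ** 2, axis=1)
    return np.mean(np.maximum(0.0, d_pos - d_neg + margin))

# Toy batch of two 2-D embeddings
a = np.array([[0.0, 0.0], [1.0, 1.0]])
p = np.array([[0.1, 0.0], [1.0, 1.2]])
n = np.array([[3.0, 0.0], [1.0, 1.5]])
print(triplet_loss(a, p, n))  # first triplet is already satisfied, second is not
```

Only triplets whose negative is not at least `margin` further from the anchor than the positive contribute to the loss.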

The BiCoGAN architecture comes from a paper by Ayush Jaiswal, Wael AbdAlmageed, Yue Wu and Premkumar Natarajan. [[Paper]](https://arxiv.org/abs/1711.07461)

The overriding principle has been to prototype rapidly and get research results as quickly as possible. As a result, the code contains some obvious inefficiencies; for example, network error printing could be rewritten as a separate function and reused as needed.

Keeping different architectures in separate classes is a conscious design decision.

Thanks to the Colab environment, further architectures and variations of the ones mentioned above were tested. In many cases, their results were either similar to those achieved with the BiCoGAN architecture or not substantially different from the baseline method (BiGAN), so they are not reported.

Let's get the ball rolling.

## Preliminaries

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import torch.optim as optim
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.patheffects as PathEffects
import seaborn as sns
import numpy as np
import os
from itertools import *
import math
from PIL import Image
from sklearn.manifold import TSNE
# Faster implementation of t-SNE:
# from MulticoreTSNE import MulticoreTSNE as TSNE
from sklearn.neighbors import KNeighborsClassifier
from sklearn import metrics


In [0]:
#from google.colab import drive
#drive.mount('/content/gdrive')


## Shared Utilities

In [0]:
# print out net and number of parameters

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)
    
            

In [0]:
# weight initialization

def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

            

In [0]:
# handles log cases with x = 0

def log(x):
    return torch.log(x + 1e-8)

    

In [0]:
# scatter plot for t-SNE

def scatter(x, colors):
    # We choose a color palette with seaborn.
    palette = np.array(sns.color_palette("hls", 10))

    # We create a scatter plot.
    f = plt.figure(figsize=(8, 8))
    ax = plt.subplot(aspect='equal')
    sc = ax.scatter(x[:,0], x[:,1], lw=0, s=40,
                    c=palette[colors.astype(int)])  # np.int was removed in NumPy 1.24
    plt.xlim(-25, 25)
    plt.ylim(-25, 25)
    ax.axis('off')
    ax.axis('tight')

    # We add the labels for each digit.
    txts = []
    for i in range(10):
        # Position of each label.
        xtext, ytext = np.median(x[colors == i, :], axis=0)
        txt = ax.text(xtext, ytext, str(i), fontsize=24)
        txt.set_path_effects([
            PathEffects.Stroke(linewidth=5, foreground="w"),
            PathEffects.Normal()])
        txts.append(txt)

    return f, ax, sc, txts

In [0]:
#!rm -rf sample_data
#!zip -r /content/file.zip /content


## Datasets

In [0]:
# classic MNIST

class Mnist:
    def __init__(self, batch_size):
        MNIST_MEAN = 0.1307
        MNIST_STD = 0.3081

        dataset_transform = transforms.Compose([
                       transforms.ToTensor(),
                       # transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
                   ])

        train_dataset = datasets.MNIST('/home/zelazny/Downloads/tooploox/data/', train=True, download=True, transform=dataset_transform)
        test_dataset = datasets.MNIST('/home/zelazny/Downloads/tooploox/data/', train=False, download=True, transform=dataset_transform)

        self.train_loader  = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

        

In [0]:
# Dataset to transform MNIST into triplets

class TripletMNIST(Dataset):
    """
    Train: For each sample (anchor) randomly chooses a positive and negative samples
    Test: Creates fixed triplets for testing
    """

    def __init__(self, mnist_dataset):
        self.mnist_dataset = mnist_dataset
        self.train = self.mnist_dataset.train
        self.transform = self.mnist_dataset.transform

        if self.train:
            self.targets= self.mnist_dataset.targets
            self.data = self.mnist_dataset.data
            self.labels_set = set(self.targets.numpy())
            self.label_to_indices = {label: np.where(self.targets.numpy() == label)[0]
                                     for label in self.labels_set}

        else:
            self.targets = self.mnist_dataset.targets
            self.data = self.mnist_dataset.data
            # generate fixed triplets for testing
            self.labels_set = set(self.targets.numpy())
            self.label_to_indices = {label: np.where(self.targets.numpy() == label)[0]
                                     for label in self.labels_set}

            random_state = np.random.RandomState(29)

            triplets = [[i,
                         random_state.choice(self.label_to_indices[self.targets[i].item()]),
                         random_state.choice(self.label_to_indices[
                                                 # seeded generator, so the fixed test triplets are reproducible
                                                 random_state.choice(
                                                     list(self.labels_set - set([self.targets[i].item()]))
                                                 )
                                             ])
                         ]
                        for i in range(len(self.data))]
            self.test_triplets = triplets

    def __getitem__(self, index):
        if self.train:
            img1, label1 = self.data[index], self.targets[index].item()
            positive_index = index
            while positive_index == index:
                positive_index = np.random.choice(self.label_to_indices[label1])
            negative_label = np.random.choice(list(self.labels_set - set([label1])))
            negative_index = np.random.choice(self.label_to_indices[negative_label])
            img2 = self.data[positive_index]
            img3 = self.data[negative_index]
        else:
            img1 = self.data[self.test_triplets[index][0]]
            img2 = self.data[self.test_triplets[index][1]]
            img3 = self.data[self.test_triplets[index][2]]
            label1 = self.targets[self.test_triplets[index][0]].item()

        img1 = Image.fromarray(img1.numpy(), mode='L')
        img2 = Image.fromarray(img2.numpy(), mode='L')
        img3 = Image.fromarray(img3.numpy(), mode='L')
        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)
            img3 = self.transform(img3)
        #return (img1, img2, img3), []
        return (img1, img2, img3), label1

    def __len__(self):
        return len(self.mnist_dataset)

    

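In training mode, `TripletMNIST.__getitem__` draws a positive with the anchor's label (excluding the anchor itself) and a negative from a different label. The same sampling logic on a plain label array, as a minimal NumPy sketch; `sample_triplet` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def sample_triplet(labels, index, rng):
    """Return (anchor, positive, negative) indices for the anchor at `index`.
    Assumes the anchor's label has at least two samples, otherwise the
    positive-sampling loop never terminates (as in __getitem__)."""
    labels = np.asarray(labels)
    label_to_indices = {lab: np.where(labels == lab)[0] for lab in set(labels.tolist())}
    anchor_label = labels[index].item()
    # positive: same label, different sample
    positive = index
    while positive == index:
        positive = rng.choice(label_to_indices[anchor_label])
    # negative: first pick a different label, then a sample with that label
    negative_label = rng.choice(sorted(set(label_to_indices) - {anchor_label}))
    negative = rng.choice(label_to_indices[negative_label])
    return index, int(positive), int(negative)

rng = np.random.default_rng(0)
labels = [0, 1, 0, 2, 1, 0]
print(sample_triplet(labels, 0, rng))
```

Picking the negative label first, rather than a random sample from all other classes, gives every other class the same probability regardless of its size.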
In [0]:
# MNIST transformed into triplets

class TripletMnist:
    def __init__(self, batch_size):
        MNIST_MEAN = 0.1307
        MNIST_STD = 0.3081

        dataset_transform = transforms.Compose([
                       transforms.ToTensor(),
                       # transforms.Normalize((MNIST_MEAN,), (MNIST_STD,))
                   ])

        train_dataset = datasets.MNIST('/home/zelazny/Downloads/tooploox/data/', train=True, download=True, transform=dataset_transform)
        test_dataset = datasets.MNIST('/home/zelazny/Downloads/tooploox/data/', train=False, download=True, transform=dataset_transform)
        
        triplet_train_dataset = TripletMNIST(train_dataset)
        triplet_test_dataset = TripletMNIST(test_dataset)

        self.train_loader  = torch.utils.data.DataLoader(triplet_train_dataset, batch_size=batch_size, shuffle=True)
        self.test_loader = torch.utils.data.DataLoader(triplet_test_dataset, batch_size=batch_size, shuffle=False)

        

## BiGAN Architecture & Utilities

In [0]:
# fully-connected generator


class Generator_FC(nn.Module):
    """
    MLP with two hidden layers of dimension h_dim.
    Input is a vector from representation space of dimension z_dim.
    Output is a vector from image space of dimension X_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim):
        super(Generator_FC, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim

        self.fc = torch.nn.Sequential(
            torch.nn.Linear(z_dim, h_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.BatchNorm1d(h_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, X_dim),
            torch.nn.Sigmoid()
            )

        initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        return x

    

The Discriminator is implemented as outlined in the BiGAN paper:

*The BiGAN
discriminator D(x, z) takes data x as its initial input, and at each linear layer thereafter, the latent
representation z is transformed using a learned linear transformation to the hidden layer dimension
and added to the non-linearity input.*

This formulation is not standard. Alternative implementations based on concatenating X and z are possible and perform similarly.

This Discriminator is also used in the basic variants of the BiCoGAN and TriBiGAN architectures.
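One observation that may help explain the similar performance: at any single layer, adding a learned linear map of z to the linear map of the layer input is algebraically the same as one linear layer applied to the concatenation of that input with z, with the two weight matrices stacked side by side and the biases summed. A NumPy check of this identity, with toy shapes rather than the notebook's sizes:

```python
import numpy as np

rng = np.random.default_rng(0)
X_dim, z_dim, h_dim, batch = 6, 3, 4, 5

x = rng.standard_normal((batch, X_dim))
z = rng.standard_normal((batch, z_dim))

# Separate weights, as in Discriminator_FC: fc1 acts on x, z1 acts on z
W_x, b_x = rng.standard_normal((h_dim, X_dim)), rng.standard_normal(h_dim)
W_z, b_z = rng.standard_normal((h_dim, z_dim)), rng.standard_normal(h_dim)
additive = (x @ W_x.T + b_x) + (z @ W_z.T + b_z)

# Equivalent single layer on the concatenation [x, z]:
# stacked weight matrix and merged bias
W_cat = np.concatenate([W_x, W_z], axis=1)   # shape (h_dim, X_dim + z_dim)
b_cat = b_x + b_z
concatenated = np.concatenate([x, z], axis=1) @ W_cat.T + b_cat

print(np.allclose(additive, concatenated))  # True
```

The practical difference from a plain concatenation discriminator is therefore that z is re-injected before every non-linearity, not only at the input layer.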

In [0]:
# fully-connected Discriminator


class Discriminator_FC(nn.Module):
    """
    MLP with two hidden layers of dimension h_dim.
    Initial input is a vector X from data space: either a real image or G(z) for z sampled from representation space.
    Vector z from latent space (representation space) is transformed linearly
    and added to activation input in hidden layers.
    For example, if X comes from the dataset, corresponding
    z is Encoder(X), and if z is sampled from representation space, X is Generator(z).
    """
    def __init__(self, z_dim, h_dim, X_dim):
        super(Discriminator_FC, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim

        self.fc1 = torch.nn.Linear(X_dim, h_dim)
        
        self.z1 = torch.nn.Linear(z_dim, h_dim)   
        
        self.leaky1 = nn.LeakyReLU(0.2)
        
        self.fc2 = torch.nn.Linear(h_dim, h_dim)
        
        self.z2 = torch.nn.Linear(z_dim, h_dim)
        
        self.leaky2 = nn.LeakyReLU(0.2)
        
        self.fc3 = torch.nn.Sequential(
            torch.nn.Linear(h_dim, 1),
            torch.nn.Sigmoid()
            )

        initialize_weights(self)

    def forward(self, input_x, input_z):
        x = self.fc1(input_x)
        x = torch.add(x, self.z1(input_z))
        x = self.leaky1(x)
        x = self.fc2(x)
        x = torch.add(x, self.z2(input_z))
        x = self.leaky2(x)
        return self.fc3(x)



In [0]:
# fully-connected Encoder

class Encoder_FC(nn.Module):
    """
    MLP with two hidden layers of dimension h_dim.
    Input is a vector X from image space of dimension X_dim.
    Output is vector z from representation space of dimension z_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim):
        super(Encoder_FC, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim
        
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(X_dim, h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.BatchNorm1d(h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, z_dim),
            )
        
        initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        return x


In [0]:
# save plots of loss functions after training


def save_plot_losses(train_D_loss, train_G_loss, eval_D_loss, eval_G_loss, model_used, z_dim, epochs, lr, batch_size):

    x = np.arange(1, len(train_D_loss) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, train_G_loss, label="Train G loss", linewidth=2)
    plt.plot(x, eval_G_loss, label="Eval G loss", linewidth=2)


    # label the current axes; in recent Matplotlib, plt.axes() adds a new, empty Axes instead
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the Train and Eval losses of G")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size:" + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_G_losses.eps", format='eps', dpi=1000)
    plt.close()



    plt.figure(figsize=(8, 6))
    plt.plot(x, train_D_loss, label="Train D loss", linewidth=2)
    plt.plot(x, eval_D_loss, label="Eval D loss", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')

    plt.suptitle("Evolution of the Train and Eval losses of D")

    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size:" + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_D_losses.eps", format='eps', dpi=1000)
    plt.close()


def save_plot_pixel_norm(mean_pixel_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_pixel_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_pixel_norm, label="Reconstruction error", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between X and G(E(X))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size:" + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("pix2pix_norm.eps", format='eps', dpi=1000)
    plt.close()

def save_plot_z_norm(mean_z_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_z_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_z_norm, label="Reconstruction error", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between z and E(G(z))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size:" + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("z_norm.eps", format='eps', dpi=1000)
    plt.close()


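The reconstruction curves saved by `save_plot_pixel_norm` and `save_plot_z_norm` are fed by a per-batch metric computed in the evaluation loop of `BIGAN.train`: the norm of the residual taken over the whole batch at once, divided by the per-sample dimension (`X_dim` or `z_dim`). A NumPy sketch of that metric, with constant arrays standing in for X and G(E(X)); `batch_reconstruction_error` is a hypothetical name.

```python
import numpy as np

def batch_reconstruction_error(original, reconstruction):
    """Frobenius norm of the residual over the whole batch, divided by the
    per-sample dimension, mirroring (X - G(E(X))).norm() / X_dim."""
    residual = original - reconstruction
    return np.linalg.norm(residual) / original.shape[1]

X = np.zeros((2, 4))
X_rec = np.full((2, 4), 0.5)
print(batch_reconstruction_error(X, X_rec))  # sqrt(8 * 0.25) / 4
```

Because the norm covers the whole batch and is divided only by the per-sample dimension, the value grows roughly with the square root of the batch size, so these curves are comparable only between runs with the same batch size.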
In [0]:
# BiGAN net

class BIGAN(object):
    """
    Class implementing a BiGAN network that trains from an observations dataset.
    """

    def __init__(self, kwargs):
        
        self.epoch = kwargs['epoch']
        self.batch_size = kwargs['batch_size']
        self.save_dir = kwargs['save_dir']
        self.result_dir = kwargs['result_dir']
        self.log_dir = kwargs['log_dir']
        self.gpu_mode = kwargs['gpu_mode']
        self.learning_rate = kwargs['lr']
        self.lr_decay = kwargs['lr_decay']
        self.beta1 = kwargs['beta1']
        self.beta2 = kwargs['beta2']
        self.decay = kwargs['decay']
        self.network_type = kwargs['network_type']
        self.dataset = kwargs['dataset']
        self.dataset_path = kwargs['dataset_path']

        # BIGAN parameters
        self.z_dim = kwargs['z_dim']    #dimension of feature space
        self.h_dim = kwargs['h_dim']    #dimension of the hidden layer

        if kwargs['dataset'] == 'mnist':
            self.X_dim = 28*28    #dimension of data
            self.num_channels = 1

        if kwargs['network_type'] == 'FC':
            # networks init
            self.G = Generator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_FC(self.z_dim, self.h_dim, self.X_dim)
        else:
            raise Exception("[!] There is no option for " + kwargs['network_type'])

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()

        
        self.G_solver = optim.Adam(chain(self.E.parameters(), self.G.parameters()), lr=self.learning_rate, betas=[self.beta1,self.beta2], weight_decay=self.decay)
        self.D_solver = optim.Adam(self.D.parameters(), lr=self.learning_rate, betas=[self.beta1,self.beta2], weight_decay=self.decay)
        
        # exponential learning rate decay starting halfway through learning
        self.G_scheduler = optim.lr_scheduler.MultiStepLR(self.G_solver, milestones=list(range(self.epoch//2, self.epoch)), gamma=self.lr_decay)
        self.D_scheduler = optim.lr_scheduler.MultiStepLR(self.D_solver, milestones=list(range(self.epoch//2, self.epoch)), gamma=self.lr_decay)



        print('---------- Networks architecture -------------')
        print_network(self.G)
        print_network(self.E)
        print_network(self.D)
        print('-----------------------------------------------')



    def D_(self, X, z):
        return self.D(X, z)

    def reset_grad(self):
        self.E.zero_grad()
        self.G.zero_grad()
        self.D.zero_grad()


    def train(self):
        if self.dataset == 'mnist':
            dataset = Mnist(self.batch_size)


        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []

        self.eval_hist = {}
        self.eval_hist['D_loss'] = []
        self.eval_hist['G_loss'] = []
        self.eval_hist['pixel_norm'] = []
        self.eval_hist['z_norm'] = []


        for epoch in range(self.epoch):
            print("epoch ",str(epoch))

            self.D.train()
            self.E.train()
            self.G.train()

            train_loss_G = 0
            train_loss_D = 0
            
            for batch_id, (data, target) in enumerate(dataset.train_loader):

                # sample z uniformly on [-1, 1); X is a real image from the dataset
                z = 2 * torch.rand(self.batch_size, self.z_dim) - 1
                X = data
                if self.gpu_mode:
                    z = z.cuda()
                    X = X.cuda()

                # sometimes batchsize of X is not equal to actual batch_size
                if X.size(0) == self.batch_size:

                    if self.network_type == 'FC':
                        X = X.view(self.batch_size, -1)
                        z_hat = self.E(X)
                        X_hat = self.G(z)

                        D_enc = self.D_(X, z_hat)
                        D_gen = self.D_(X_hat, z)


                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))

                    # Compute the gradients of both losses before updating any weights:
                    # stepping D_solver first modifies D's weights in place, and recent
                    # PyTorch versions then refuse to run G_loss.backward() on the old graph.
                    D_grads = torch.autograd.grad(D_loss, list(self.D.parameters()), retain_graph=True)
                    G_loss.backward()  # gradients for E and G (D's gradients are replaced below)
                    for param, grad in zip(self.D.parameters(), D_grads):
                        param.grad = grad
                    self.D_solver.step()
                    self.G_solver.step()
                    self.reset_grad()

                    train_loss_G += G_loss.item()
                    train_loss_D += D_loss.item()

                    if batch_id % 1000 == 0:
                        # Print and plot every now and then
                        samples = X_hat.data.cpu().numpy()

                        fig = plt.figure(figsize=(8, 4))
                        gs = gridspec.GridSpec(4, 8)
                        gs.update(wspace=0.05, hspace=0.05)

                        for i, sample in enumerate(samples):
                            if i<32:
                                ax = plt.subplot(gs[i])
                                plt.axis('off')
                                ax.set_xticklabels([])
                                ax.set_yticklabels([])
                                ax.set_aspect('equal')

                                if self.network_type == 'FC':
                                    if self.dataset == 'mnist':
                                        sample = sample.reshape(28, 28)
                                        plt.imshow(sample, cmap='Greys_r')


                        if not os.path.exists(self.result_dir + '/train/'):
                            os.makedirs(self.result_dir + '/train/')

                        filename = "epoch_" + str(epoch) + "_batchid_" + str(batch_id)
                        plt.savefig(self.result_dir + '/train/{}.png'.format(filename), bbox_inches='tight')
                        plt.close()

            # learning rate schedule: step once per epoch, after the optimizer updates
            self.G_scheduler.step()
            self.D_scheduler.step()

            print("Train loss G:", train_loss_G / len(dataset.train_loader))
            print("Train loss D:", train_loss_D / len(dataset.train_loader))

            self.train_hist['D_loss'].append(train_loss_D / len(dataset.train_loader))
            self.train_hist['G_loss'].append(train_loss_G / len(dataset.train_loader))


            self.D.eval()
            self.E.eval()
            self.G.eval()
            test_loss_G = 0
            test_loss_D = 0

            mean_pixel_norm = 0
            mean_z_norm = 0
            norm_counter = 0  # number of full test batches evaluated

            for batch_id, (data, target) in enumerate(dataset.test_loader):
                # Sample data
                z = Variable((1 - (-1))*torch.rand(self.batch_size, self.z_dim) + (-1)) # uniform on [-1,1]
                X_data = Variable(data)

                if self.gpu_mode:
                    z = z.cuda()
                    X_data = X_data.cuda()

                if X_data.size(0) == self.batch_size:
                    X = X_data
                    if self.network_type == 'FC':
                        X = X.view(self.batch_size, -1)
                        z_hat = self.E(X)
                        X_hat = self.G(z)

                        D_enc = self.D_(X, z_hat)
                        D_gen = self.D_(X_hat, z)

                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))

                    # .item() returns Python floats, so the history lists can be plotted
                    # even when the networks run on the GPU
                    test_loss_G += G_loss.item()
                    test_loss_D += D_loss.item()

                    pixel_norm = X - self.G(z_hat)
                    pixel_norm = pixel_norm.norm().item() / float(self.X_dim)
                    mean_pixel_norm += pixel_norm


                    z_norm = z - self.E(X_hat)
                    z_norm = z_norm.norm().item() / float(self.z_dim)
                    mean_z_norm += z_norm

                    norm_counter += 1


            print("Eval loss G:", test_loss_G / norm_counter)
            print("Eval loss D:", test_loss_D / norm_counter)

            self.eval_hist['D_loss'].append(test_loss_D / norm_counter)
            self.eval_hist['G_loss'].append(test_loss_G / norm_counter)

            print("Pixel norm:", mean_pixel_norm / norm_counter)
            self.eval_hist['pixel_norm'].append( mean_pixel_norm / norm_counter )

            with open('pixel_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_pixel_norm / norm_counter) + '\n')

            print("z norm:", mean_z_norm / norm_counter)
            self.eval_hist['z_norm'].append( mean_z_norm / norm_counter )

            with open('z_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_z_norm / norm_counter) + '\n')

            ##### At the end of the epoch, save X and its reconstruction G(E(X))
            samples = X.data.cpu().numpy()

            fig = plt.figure(figsize=(10, 2))
            gs = gridspec.GridSpec(2, 10)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                if i<10:
                    ax = plt.subplot(gs[i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        
                        
            X_hat = self.G(self.E(X).view(self.batch_size, self.z_dim))
            samples = X_hat.data.cpu().numpy()


            for i, sample in enumerate(samples):
                if i<10:
                    ax = plt.subplot(gs[10+i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        

            if not os.path.exists(self.result_dir + '/recons/'):
                os.makedirs(self.result_dir + '/recons/')

            filename = "epoch_" + str(epoch)
            plt.savefig(self.result_dir + '/recons/{}.png'.format(filename), bbox_inches='tight')
            plt.close()


        save_plot_losses(self.train_hist['D_loss'], self.train_hist['G_loss'], self.eval_hist['D_loss'], self.eval_hist['G_loss'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_pixel_norm(self.eval_hist['pixel_norm'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_z_norm(self.eval_hist['z_norm'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)

    def save_model(self):
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        
        torch.save(self.G.state_dict(), self.save_dir + "/G.pt")
        torch.save(self.E.state_dict(), self.save_dir + "/E.pt")
        torch.save(self.D.state_dict(), self.save_dir + "/D.pt")

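    def load_model(self, kwargs):
        if kwargs['network_type'] == 'FC':
            # networks init
            self.G = Generator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_FC(self.z_dim, self.h_dim, self.X_dim)

        # load from the directory save_model() writes to;
        # map_location lets GPU-trained weights be loaded on a CPU-only machine
        map_location = None if self.gpu_mode else 'cpu'
        self.G.load_state_dict(torch.load(self.save_dir + "/G.pt", map_location=map_location))
        self.E.load_state_dict(torch.load(self.save_dir + "/E.pt", map_location=map_location))
        self.D.load_state_dict(torch.load(self.save_dir + "/D.pt", map_location=map_location))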

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()



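The two objectives optimised in `BIGAN.train` are the BiGAN adversarial losses: D is pushed towards D(X, E(X)) = 1 on real pairs and D(G(z), z) = 0 on generated pairs, while E and G minimise the same expression with the roles of the two pairs swapped, analogous to the non-saturating generator loss of standard GANs. A NumPy sketch of both losses from a batch of discriminator outputs; the probabilities below are made up for illustration.

```python
import numpy as np

EPS = 1e-8  # same epsilon as the notebook's log() helper

def bigan_losses(D_enc, D_gen):
    """D_enc: D(X, E(X)) on real data; D_gen: D(G(z), z) on generated data.
    Returns (D_loss, G_loss) following the formulas in BIGAN.train."""
    D_loss = -np.mean(np.log(D_enc + EPS) + np.log(1 - D_gen + EPS))
    G_loss = -np.mean(np.log(D_gen + EPS) + np.log(1 - D_enc + EPS))
    return D_loss, G_loss

# A discriminator that cannot tell the pairs apart outputs 0.5 everywhere
D_loss, G_loss = bigan_losses(np.full(4, 0.5), np.full(4, 0.5))
print(D_loss, G_loss)  # both close to 2 * ln(2), about 1.386
```

Values of both losses near 2 ln(2) during training therefore indicate that D is close to chance on the encoder and generator pairs.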
## BiGAN Training

In [0]:
"""config"""

kwargs = {
    'dataset': 'mnist',
    'dataset_path': '/home/zelazny/Downloads/tooploox/data/',
    'gpu_mode': True,
    'save_dir': 'models',
    'result_dir': 'results',
    'log_dir': 'logs',
    'epoch': 400,
    'batch_size': 128,
    'lr': 1e-4, # initial learning rate
    'lr_decay': (0.01)**(1/200), # exponential decay to 1e-6 over 200 epochs
    'beta1': 0.5, # Adam1
    'beta2': 0.999, # Adam2
    'decay': 2.5*1e-5, # weight decay
    'network_type': 'FC',
    'z_dim': 50, # latent space dimension
    'h_dim': 1024 # number of hidden units
}

"""check arguments"""

def check_kwargs(kwargs):
    # create save_dir, result_dir and log_dir if they do not exist yet
    for key in ('save_dir', 'result_dir', 'log_dir'):
        os.makedirs(kwargs[key], exist_ok=True)

    # epoch
    if kwargs['epoch'] < 1:
        raise ValueError('number of epochs must be larger than or equal to one')

    # batch_size
    if kwargs['batch_size'] < 1:
        raise ValueError('batch size must be larger than or equal to one')

    return kwargs


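With `lr_decay = 0.01 ** (1/200)` and one `MultiStepLR` milestone per epoch from `epoch // 2` onwards, the schedule amounts to exponential decay over the second half of training. The closed form below mirrors MultiStepLR's rule (the learning rate is multiplied by gamma once for every milestone reached), assuming `scheduler.step()` is called once at the end of every epoch; `lr_at_epoch` is a hypothetical helper.

```python
from bisect import bisect_right

def lr_at_epoch(epoch, lr0=1e-4, total_epochs=400, gamma=0.01 ** (1 / 200)):
    """Learning rate in effect during 0-indexed `epoch`, assuming the
    scheduler has been stepped once per completed epoch (`epoch` times)."""
    milestones = list(range(total_epochs // 2, total_epochs))
    # gamma is applied once for every milestone <= number of steps taken
    return lr0 * gamma ** bisect_right(milestones, epoch)

for e in (0, 199, 200, 300, 399):
    print(e, lr_at_epoch(e))
```

The first half runs at 1e-4, and the last epoch reaches 1e-4 * 0.01 = 1e-6, matching the comment on `lr_decay` in the config.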
In [0]:
"""main"""
def main():
    # check arguments
    if kwargs is None:
        exit()
    else:
        check_kwargs(kwargs)

    bigan = BIGAN(kwargs)

    # wipe old files
    with open('pixel_error_BIGAN.txt', 'w') as f:
        f.writelines('')
    with open('z_error_BIGAN.txt', 'w') as f:
        f.writelines('')

    bigan.train()
    print(" [*] Training finished!")

    bigan.save_model()
    

if __name__ == '__main__':
    main()

    

## BiGAN Evaluation

In [0]:
# load trained net

bigan = BIGAN(kwargs)
bigan.load_model(kwargs)

dataset = Mnist(128)



In [0]:
# generate encodings for dataset

def generate_encodings(dataset, train_test):
    if train_test == 'train':
        loader = dataset.train_loader
    elif train_test == 'test':
        loader = dataset.test_loader
    else:
        raise ValueError("train_test must be 'train' or 'test'")

    # evaluation mode: BatchNorm uses its running statistics, not batch statistics
    bigan.E.eval()

    encodings = []
    labels = []

    # no gradients are needed to compute the encodings
    with torch.no_grad():
        for batch_id, (data, target) in enumerate(loader):

            X_data = data

            if bigan.gpu_mode:
                X_data = X_data.cuda()

            if X_data.size(0) == bigan.batch_size:
                X = X_data.view(bigan.batch_size, -1)
                z_hat = bigan.E(X)
                encodings.append(z_hat)
                labels.append(target)

    encodings = torch.cat(encodings).cpu().numpy()
    labels = torch.cat(labels).numpy()

    return encodings, labels



In [0]:
# generate encodings for the train and test sets

encodings_train, labels_train = generate_encodings(dataset, 'train')
encodings_test, labels_test = generate_encodings(dataset, 'test')



In [0]:
# generate and visualize t-SNE encodings for test set

tsne_encodings_test = TSNE().fit_transform(encodings_test)
scatter(tsne_encodings_test, labels_test)


In [0]:
# kNN classification accuracy on test set

def knn_results(n):
    knn = KNeighborsClassifier(n_neighbors=n)
    knn.fit(encodings_train, labels_train)
    labels_hat = knn.predict(encodings_test)

    print('%sNN classification accuracy (%%)' %n, round(metrics.accuracy_score(labels_test, labels_hat)*100, 2))

    
for i in range(1, 11):
    knn_results(i)

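`KNeighborsClassifier` predicts each test encoding's label by a majority vote among its `n_neighbors` nearest training encodings (Euclidean distance and uniform weights by default). A minimal NumPy version of that rule, on toy 2-D points rather than the BiGAN encodings; `knn_predict` is a hypothetical helper, and its tie-breaking (smallest label wins) need not match scikit-learn's.

```python
import numpy as np

def knn_predict(train_x, train_y, test_x, n_neighbors=1):
    """Majority vote among the n_neighbors closest training points (Euclidean)."""
    # pairwise squared distances, shape (n_test, n_train)
    d2 = ((test_x[:, None, :] - train_x[None, :, :]) ** 2).sum(axis=2)
    nearest = np.argsort(d2, axis=1)[:, :n_neighbors]
    votes = train_y[nearest]  # neighbour labels, shape (n_test, n_neighbors)
    # bincount(...).argmax() returns the smallest label among tied counts
    return np.array([np.bincount(v).argmax() for v in votes])

train_x = np.array([[0.0, 0.0], [0.1, 0.0], [5.0, 5.0], [5.1, 5.0]])
train_y = np.array([0, 0, 1, 1])
test_x = np.array([[0.2, 0.1], [4.9, 5.2]])
pred = knn_predict(train_x, train_y, test_x, n_neighbors=3)
print(pred, np.mean(pred == np.array([0, 1])) * 100)  # accuracy in %
```

High kNN accuracy on the encodings indicates that E maps images of the same digit close together in latent space, even though the BiGAN never sees the labels.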

## BiCoGAN Architecture and Utilities

In [0]:
# BiCoGAN Generator

class Generator_BiCoGAN(nn.Module):
    """
    MLP incorporating label data. Input: two vectors, one from representation
    space of dimension z_dim, the second one a one-hot encoding of the label,
    of dimension y_dim. Intermediate layers are of dimension h_dim. Output is
    a vector from image space of dimension X_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim, y_dim):
        super(Generator_BiCoGAN, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim
        self.y_dim = y_dim
        
        self.fc_y = torch.nn.Linear(y_dim, h_dim)
        self.fc_z = torch.nn.Linear(z_dim, h_dim)

        self.fc = torch.nn.Sequential(
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.BatchNorm1d(h_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(h_dim, X_dim),
            torch.nn.Sigmoid()
            )
        
        initialize_weights(self)

    def forward(self, input_z, input_y):
        y = self.fc_y(input_y)
        z = self.fc_z(input_z)
        x = y*z # element-wise multiplication
        x = self.fc(x)
        return x
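`Generator_BiCoGAN` embeds the noise vector and the one-hot label with two separate linear layers and multiplies the embeddings element-wise before the shared MLP; in the training loop the one-hot vectors are built with `y_onehot.scatter_(1, y, 1)`. The NumPy sketch below reproduces these two steps to show the shapes involved, assuming random matrices `W_z`/`W_y` (hypothetical names) in place of the trained `fc_z`/`fc_y` layers and omitting their biases.

```python
import numpy as np

def one_hot(labels, num_classes):
    """NumPy equivalent of y_onehot.zero_(); y_onehot.scatter_(1, y, 1)."""
    out = np.zeros((len(labels), num_classes), dtype=np.float32)
    out[np.arange(len(labels)), labels] = 1.0
    return out

batch_size, z_dim, y_dim, h_dim = 4, 50, 10, 1024
rng = np.random.default_rng(0)

# hypothetical stand-ins for the fc_z and fc_y weights (random, biases omitted)
W_z = rng.normal(size=(z_dim, h_dim)).astype(np.float32)
W_y = rng.normal(size=(y_dim, h_dim)).astype(np.float32)

# z ~ U[-1, 1], sampled as in the training loop
z = rng.uniform(-1, 1, size=(batch_size, z_dim)).astype(np.float32)
labels = np.array([3, 1, 4, 1])
y = one_hot(labels, y_dim)

# combine the two embeddings element-wise, as in Generator_BiCoGAN.forward
h = (z @ W_z) * (y @ W_y)

# with a one-hot y, y @ W_y just selects row labels[i] of W_y for sample i
print(h.shape, np.allclose(y @ W_y, W_y[labels]))
```

Because `y` is one-hot, the label acts as a class-specific element-wise gain on the latent embedding (plus a bias term in the real `fc_y` layer).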

    

In [0]:
# BiCoGAN Encoder

class Encoder_BiCoGAN(nn.Module):
    """
    MLP incorporating label data. Input: vector from data space
    of dimension X_dim. Output: a vector from representation space of dimension
    z_dim and a second one, of dimension y_dim, which is a prediction of the label.
    Intermediate layers are of dimension h_dim.
    """
    def __init__(self, z_dim, h_dim, X_dim, y_dim):
        super(Encoder_BiCoGAN, self).__init__()

        self.z_dim = z_dim
        self.h_dim = h_dim
        self.X_dim = X_dim
        self.y_dim = y_dim

        
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(X_dim, h_dim),
            nn.LeakyReLU(0.2),
            torch.nn.Linear(h_dim, h_dim),
            torch.nn.BatchNorm1d(h_dim),
            nn.LeakyReLU(0.2)
            )
        
        self.z_out = torch.nn.Linear(h_dim, z_dim)
        
        self.y_out = torch.nn.Sequential(
            torch.nn.Linear(h_dim, y_dim),
            torch.nn.Softmax(dim=1)  # class probabilities; dim=1 is the implicit choice for 2-D input
            )

        initialize_weights(self)

    def forward(self, input_x):
        x = self.fc(input_x)
        y = self.y_out(x)
        x = self.z_out(x)
        return [x, y]
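`Encoder_BiCoGAN` returns softmax probabilities for the label, while the training loop passes `y_hat` to `torch.nn.CrossEntropyLoss`, which applies log-softmax to its input itself and expects unnormalised logits. The loss therefore sees a softmax applied twice, and for 10 classes it can never fall below log(1 + 9/e) ≈ 1.46 (confident and correct) or rise above log(9 + e) ≈ 2.46 (confident and wrong). The NumPy sketch below, using a hypothetical single-sample `cross_entropy` helper that mirrors what `CrossEntropyLoss` computes, illustrates this numerically.

```python
import numpy as np

def cross_entropy(scores, target):
    """Single-sample torch.nn.CrossEntropyLoss: -log_softmax(scores)[target]."""
    shifted = scores - scores.max()          # for numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())
    return float(-log_probs[target])

y_dim, target = 10, 3

# a perfectly confident, correct prediction given as probabilities (softmax output)
probs = np.zeros(y_dim)
probs[target] = 1.0
best_case = cross_entropy(probs, target)

# the worst case: all probability mass on a wrong class
probs_wrong = np.zeros(y_dim)
probs_wrong[0] = 1.0
worst_case = cross_entropy(probs_wrong, target)

# the same confident prediction given as logits, which is what the loss expects
logits = np.full(y_dim, -10.0)
logits[target] = 10.0
logit_case = cross_entropy(logits, target)

print(round(best_case, 3), round(worst_case, 3), logit_case)
```

If the double normalisation is unintended, two standard options are removing the `Softmax` from `y_out` (so `CrossEntropyLoss` receives logits; the `torch.argmax` used for the reconstructions is unaffected) or keeping it and applying `torch.nn.NLLLoss` to `torch.log(y_hat)`.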

    

In [0]:
# save plots of loss functions after training

def save_plot_losses_BiCoGAN(train_D_loss, train_G_loss, train_E_loss, eval_D_loss, eval_G_loss, eval_E_loss, model_used, z_dim, epochs, lr, batch_size):

    x = np.arange(1, len(train_D_loss) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, train_G_loss, label="Train G loss", linewidth=2)
    plt.plot(x, eval_G_loss, label="Eval G loss", linewidth=2)


    # label the current Axes (plt.axes() with no arguments may create a new Axes instead)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the Train and Eval losses of G")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_G_losses.eps", format='eps', dpi=1000)
    plt.close()



    plt.figure(figsize=(8, 6))
    plt.plot(x, train_D_loss, label="Train D loss", linewidth=2)
    plt.plot(x, eval_D_loss, label="Eval D loss", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')

    plt.suptitle("Evolution of the Train and Eval losses of D")

    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_D_losses.eps", format='eps', dpi=1000)
    plt.close()
    
    
    
    plt.figure(figsize=(8, 6))
    plt.plot(x, train_E_loss, label="Train E loss", linewidth=2)
    plt.plot(x, eval_E_loss, label="Eval E loss", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')

    plt.suptitle("Evolution of the Train and Eval losses of E")

    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_E_losses.eps", format='eps', dpi=1000)
    plt.close()


def save_plot_pixel_norm(mean_pixel_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_pixel_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_pixel_norm, label="Reconstruction error", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between X and G(E(X))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("pix2pix_norm.eps", format='eps', dpi=1000)
    plt.close()

def save_plot_z_norm(mean_z_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_z_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_z_norm, label="Reconstruction error", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between z and E(G(z))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("z_norm.eps", format='eps', dpi=1000)
    plt.close()
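The two norm helpers above plot per-epoch averages collected in the evaluation loop, where each batch contributes `(X - G(z_hat, y_onehot)).norm() / X_dim`. Since `.norm()` on a tensor defaults to the Frobenius norm over the whole `(batch_size, X_dim)` batch, the value grows with the square root of the batch size, so curves are only comparable at equal batch sizes. The NumPy check below (a hypothetical `reconstruction_error` helper with toy data and a uniform per-pixel error of 0.1) shows that scaling.

```python
import numpy as np

def reconstruction_error(x, x_rec):
    """Per-batch error as in the evaluation loop: ||X - X_rec||_F / X_dim,
    with the Frobenius norm taken over the whole (batch_size, X_dim) matrix."""
    return float(np.linalg.norm(x - x_rec) / x.shape[1])

X_dim = 28 * 28
x = np.zeros((128, X_dim))
x_rec = x + 0.1  # every pixel reconstructed with an error of 0.1

err_128 = reconstruction_error(x, x_rec)
err_32 = reconstruction_error(x[:32], x_rec[:32])

# the same per-pixel error yields twice the value at batch size 128 as at 32
print(round(err_128, 4), round(err_32, 4), round(err_128 / err_32, 4))
```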



In [0]:
# BiCoGAN net

import math  # math.exp is used for the label-loss weight schedule in train()

class BiCoGAN(object):
    """
    Class implementing a BiCoGAN network that trains from an observations dataset.
    """

    def __init__(self, kwargs):
        
        self.epoch = kwargs['epoch']
        self.batch_size = kwargs['batch_size']
        self.save_dir = kwargs['save_dir']
        self.result_dir = kwargs['result_dir']
        self.log_dir = kwargs['log_dir']
        self.gpu_mode = kwargs['gpu_mode']
        self.alpha = kwargs['alpha']
        self.phi = kwargs['phi']
        self.rho = kwargs['rho']
        self.learning_rate = kwargs['lr']
        self.lr_decay = kwargs['lr_decay']
        self.beta1 = kwargs['beta1']
        self.beta2 = kwargs['beta2']
        self.decay = kwargs['decay']
        self.network_type = kwargs['network_type']
        self.dataset = kwargs['dataset']
        self.dataset_path = kwargs['dataset_path']

        # BIGAN parameters
        self.z_dim = kwargs['z_dim']    #dimension of feature space
        self.h_dim = kwargs['h_dim']    #dimension of the hidden layer
        self.y_dim = kwargs['y_dim']    #number of label classes

        if kwargs['dataset'] == 'mnist':
            self.X_dim = 28*28    #dimension of data
            self.num_channels = 1

        if kwargs['network_type'] == 'FC':
            # networks init
            self.G = Generator_BiCoGAN(self.z_dim, self.h_dim, self.X_dim, self.y_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_BiCoGAN(self.z_dim, self.h_dim, self.X_dim, self.y_dim)
        else:
            raise Exception("[!] There is no option for " + kwargs['network_type'])

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()

        
        self.G_solver = optim.Adam(chain(self.E.parameters(), self.G.parameters()), lr=self.learning_rate, betas=[self.beta1,self.beta2], weight_decay=self.decay)
        self.D_solver = optim.Adam(self.D.parameters(), lr=self.learning_rate, betas=[self.beta1,self.beta2], weight_decay=self.decay)
        
        # exponential learning rate decay starting halfway through learning
        self.G_scheduler = optim.lr_scheduler.MultiStepLR(self.G_solver, milestones=list(range(self.epoch//2, self.epoch)), gamma=self.lr_decay)
        self.D_scheduler = optim.lr_scheduler.MultiStepLR(self.D_solver, milestones=list(range(self.epoch//2, self.epoch)), gamma=self.lr_decay)



        print('---------- Networks architecture -------------')
        print_network(self.G)
        print_network(self.E)
        print_network(self.D)
        print('-----------------------------------------------')



    def D_(self, X, z):
        return self.D(X, z)

    def reset_grad(self):
        self.E.zero_grad()
        self.G.zero_grad()
        self.D.zero_grad()


    def train(self):
        if self.dataset == 'mnist':
            dataset = Mnist(self.batch_size)


        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['E_loss'] = []

        self.eval_hist = {}
        self.eval_hist['D_loss'] = []
        self.eval_hist['G_loss'] = []
        self.eval_hist['E_loss'] = []
        self.eval_hist['pixel_norm'] = []
        self.eval_hist['z_norm'] = []

                    
        for epoch in range(self.epoch):
            print("epoch ", str(epoch))

            self.D.train()
            self.E.train()
            self.G.train()

            train_loss_G = 0
            train_loss_D = 0
            train_loss_E = 0
            
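            # learning rate schedule (PyTorch >= 1.1 expects scheduler.step() after
            # optimizer.step(); calling it here applies each decay one epoch early)
            self.G_scheduler.step()
            self.D_scheduler.step()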
            
            self.efl_gamma = min(self.alpha*math.exp(self.rho*epoch), self.phi)


            for batch_id, (data, target) in enumerate(dataset.train_loader):
                # skip the last, smaller batch: y_onehot and the view() calls assume exactly batch_size samples
                if data.size(0) == self.batch_size:

                    if self.gpu_mode:
                        # sample z
                        z = Variable((1 - (-1))*torch.rand(self.batch_size, self.z_dim) + (-1)).cuda() # uniform on [-1,1]
                        # X is a real image from the dataset
                        X = data
                        X = Variable(X).cuda()
                        y_onehot = torch.FloatTensor(self.batch_size, self.y_dim).cuda() # dummy for one-hot encodings
                        y = Variable(target).cuda()                    
                        target = target.cuda()
                    else:
                        z = Variable((1 - (-1))*torch.rand(self.batch_size, self.z_dim) + (-1)) # uniform on [-1,1]
                        X = data
                        X = Variable(X)
                        y_onehot = torch.FloatTensor(self.batch_size, self.y_dim) # dummy for one-hot encodings
                        y = Variable(target)


                    if self.network_type == 'FC':
                        X = X.view(self.batch_size, -1)
                        y = y.view(self.batch_size, -1)                        
                        y_onehot.zero_()
                        y_onehot.scatter_(1, y, 1)                        
                        z_hat, y_hat = self.E(X)  # single encoder pass for both outputs
                        X_hat = self.G(z, y_onehot)

                        D_enc = self.D_(X, z_hat)
                        D_gen = self.D_(X_hat, z)


                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))
                    
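                    # CrossEntropyLoss applies log-softmax itself and expects logits,
                    # whereas y_hat already holds softmax probabilities from the encoder
                    EFL_loss = torch.nn.CrossEntropyLoss()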
                    E_loss = torch.mean(self.efl_gamma*EFL_loss(y_hat, target))
                    
                    G_EFL_loss = G_loss + E_loss

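                    # caution: D_solver.step() updates D's weights in place, so in recent PyTorch
                    # versions backpropagating G_EFL_loss through D afterwards may raise an in-place
                    # modification error; recomputing D_enc and D_gen after the D step avoids it
                    D_loss.backward(retain_graph=True)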
                    self.D_solver.step()
                    self.reset_grad()

                    G_EFL_loss.backward()
                    self.G_solver.step()
                    self.reset_grad()
                    

                    train_loss_G += G_loss.data
                    train_loss_D += D_loss.data
                    train_loss_E += E_loss.data

                    if batch_id % 1000 == 0:
                        # Print and plot every now and then
                        samples = X_hat.data.cpu().numpy()

                        fig = plt.figure(figsize=(8, 4))
                        gs = gridspec.GridSpec(4, 8)
                        gs.update(wspace=0.05, hspace=0.05)

                        for i, sample in enumerate(samples):
                            if i<32:
                                ax = plt.subplot(gs[i])
                                plt.axis('off')
                                ax.set_xticklabels([])
                                ax.set_yticklabels([])
                                ax.set_aspect('equal')

                                if self.network_type == 'FC':
                                    if self.dataset == 'mnist':
                                        sample = sample.reshape(28, 28)
                                        plt.imshow(sample, cmap='Greys_r')


                        if not os.path.exists(self.result_dir + '/train/'):
                            os.makedirs(self.result_dir + '/train/')

                        filename = "epoch_" + str(epoch) + "_batchid_" + str(batch_id)
                        plt.savefig(self.result_dir + '/train/{}.png'.format(filename), bbox_inches='tight')
                        plt.close()

            print("Train loss G:", train_loss_G / len(dataset.train_loader))
            print("Train loss D:", train_loss_D / len(dataset.train_loader))
            print("Train loss E:", train_loss_E / len(dataset.train_loader))

            self.train_hist['D_loss'].append(train_loss_D / len(dataset.train_loader))
            self.train_hist['G_loss'].append(train_loss_G / len(dataset.train_loader))
            self.train_hist['E_loss'].append(train_loss_E / len(dataset.train_loader))


            self.D.eval()
            self.E.eval()
            self.G.eval()
            test_loss_G = 0
            test_loss_D = 0
            test_loss_E = 0

            mean_pixel_norm = 0
            mean_z_norm = 0
            norm_counter = 0  # number of evaluated (full) batches

            for batch_id, (data, target) in enumerate(dataset.test_loader):
                if data.size(0) == self.batch_size:
                    # Sample data
                    z = Variable((1 - (-1))*torch.rand(self.batch_size, self.z_dim) + (-1)) # uniform on [-1,1]
                    X_data = Variable(data)
                    y_onehot = torch.FloatTensor(self.batch_size, self.y_dim) # dummy for one-hot encodings
                    y = Variable(target)

                    if self.gpu_mode:
                        z = z.cuda()
                        X_data = X_data.cuda()
                        y_onehot = y_onehot.cuda()
                        y = y.cuda()
                        target = target.cuda()

                    
                    X = X_data
                    if self.network_type == 'FC':                        
                        X = X.view(self.batch_size, -1)
                        y = y.view(self.batch_size, -1)
                        y_onehot.zero_()
                        y_onehot.scatter_(1, y, 1) # one-hot encoding of labels
                        z_hat, y_hat = self.E(X)  # single encoder pass for both outputs
                        z_hat = z_hat.view(self.batch_size, -1)
                        X_hat = self.G(z, y_onehot)

                        D_enc = self.D_(X, z_hat)
                        D_gen = self.D_(X_hat, z)

                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))                    
                    
                    EFL_loss = torch.nn.CrossEntropyLoss()
                    E_loss = torch.mean(self.efl_gamma*EFL_loss(y_hat, target))
                    
                    # not needed for test data
                    #G_EFL_loss = G_loss + E_loss
                    
                    test_loss_G += G_loss.data
                    test_loss_D += D_loss.data
                    test_loss_E += E_loss.data

                    # reconstruction G(E(X)), conditioned on the true labels
                    pixel_norm = X - self.G(z_hat, y_onehot)
                    pixel_norm = pixel_norm.norm().data / float(self.X_dim)
                    mean_pixel_norm += pixel_norm


                    z_norm = z - self.E(X_hat)[0]
                    z_norm = z_norm.norm().data / float(self.z_dim)
                    mean_z_norm += z_norm

                    norm_counter += 1


            print("Eval loss G:", test_loss_G / norm_counter)
            print("Eval loss D:", test_loss_D / norm_counter)
            print("Eval loss E:", test_loss_E / norm_counter)

            self.eval_hist['D_loss'].append(test_loss_D / norm_counter)
            self.eval_hist['G_loss'].append(test_loss_G / norm_counter)
            self.eval_hist['E_loss'].append(test_loss_E / norm_counter)

            print("Pixel norm:", mean_pixel_norm / norm_counter)
            self.eval_hist['pixel_norm'].append( mean_pixel_norm / norm_counter )

            with open('pixel_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_pixel_norm / norm_counter) + '\n')

            print("z norm:", mean_z_norm / norm_counter)
            self.eval_hist['z_norm'].append( mean_z_norm / norm_counter )

            with open('z_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_z_norm / norm_counter) + '\n')

            ##### At the end of the epoch, save X and its reconstruction G(E(X))
            samples = X.data.cpu().numpy()

            fig = plt.figure(figsize=(10, 2))
            gs = gridspec.GridSpec(2, 10)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                if i<10:
                    ax = plt.subplot(gs[i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        
            # alternative logic: keep the Encoder's soft label output as is
            #X_hat = self.G(self.E(X)[0].view(self.batch_size, self.z_dim), self.E(X)[1].view(self.batch_size, self.y_dim))
            
            # transform Encoder output to one-hot encoding
            y_hat = torch.argmax(self.E(X)[1].view(self.batch_size, self.y_dim), dim=1)
            y_hat = y_hat.view(self.batch_size, -1)
            y_hat_onehot = torch.FloatTensor(self.batch_size, self.y_dim) # dummy for one-hot encodings
            
            if self.gpu_mode:
                y_hat_onehot = y_hat_onehot.cuda()                           
            
            y_hat_onehot.zero_()
            y_hat_onehot.scatter_(1, y_hat, 1)
            
            X_hat = self.G(self.E(X)[0].view(self.batch_size, self.z_dim), y_hat_onehot)
            samples = X_hat.data.cpu().numpy()


            for i, sample in enumerate(samples):
                if i<10:
                    ax = plt.subplot(gs[10+i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        

            if not os.path.exists(self.result_dir + '/recons/'):
                os.makedirs(self.result_dir + '/recons/')

            filename = "epoch_" + str(epoch)
            plt.savefig(self.result_dir + '/recons/{}.png'.format(filename), bbox_inches='tight')
            plt.close()


        save_plot_losses_BiCoGAN(self.train_hist['D_loss'], self.train_hist['G_loss'], self.train_hist['E_loss'], self.eval_hist['D_loss'], self.eval_hist['G_loss'], self.eval_hist['E_loss'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_pixel_norm(self.eval_hist['pixel_norm'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_z_norm(self.eval_hist['z_norm'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)

    def save_model(self):
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        
        torch.save(self.G.state_dict(), self.save_dir + "/G.pt")
        torch.save(self.E.state_dict(), self.save_dir + "/E.pt")
        torch.save(self.D.state_dict(), self.save_dir + "/D.pt")

    def load_model(self, kwargs):
        if kwargs['network_type'] == 'FC':
            # networks init
            self.G = Generator_BiCoGAN(self.z_dim, self.h_dim, self.X_dim, self.y_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_BiCoGAN(self.z_dim, self.h_dim, self.X_dim, self.y_dim)
        
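        # load from the directory that save_model() writes to; map_location='cpu'
        # lets GPU-trained weights load anywhere before being moved with .cuda() below
        self.G.load_state_dict(torch.load(self.save_dir + "/G.pt", map_location='cpu'))
        self.E.load_state_dict(torch.load(self.save_dir + "/E.pt", map_location='cpu'))
        self.D.load_state_dict(torch.load(self.save_dir + "/D.pt", map_location='cpu'))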
        

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()
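The two `MultiStepLR` schedulers place a milestone at every epoch from `epoch//2` to `epoch - 1`, so the learning rate is multiplied by `lr_decay` once per epoch during the second half of training, which amounts to exponential decay. The pure-Python simulation below (a sketch of the milestone rule only, with a hypothetical `multistep_lr` helper, not PyTorch's scheduler internals) checks the end point for the values in the BiCoGAN Training configuration: `lr = 1e-4`, `lr_decay = 0.01**(1/200)` and `epoch = 400`.

```python
def multistep_lr(base_lr, gamma, milestones, epochs):
    """Learning rate in use at each epoch when it is multiplied by gamma
    every time a milestone epoch is reached (the MultiStepLR rule)."""
    milestones = set(milestones)
    lr, history = base_lr, []
    for epoch in range(epochs):
        if epoch in milestones:
            lr *= gamma
        history.append(lr)
    return history

epochs = 400
lr_decay = 0.01 ** (1 / 200)
history = multistep_lr(1e-4, lr_decay, range(epochs // 2, epochs), epochs)

print(history[0], history[epochs // 2 - 1])  # constant during the first half
print(history[-1])                            # approx. 1e-6 after 200 decays
```

After 200 decays the rate has been multiplied by (0.01^(1/200))^200 = 0.01, taking it from 1e-4 to 1e-6, as the config comment states.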



## BiCoGAN Training

In [0]:
"""config"""

kwargs_bicogan = {
    'dataset': 'mnist',
    'dataset_path': '/home/zelazny/Downloads/tooploox/data/',
    'gpu_mode': True,
    'save_dir': 'models',
    'result_dir': 'results',
    'log_dir': 'logs',
    'epoch': 400,
    'batch_size': 128,
    'alpha' : 5,
    'phi' : 10,
    'rho' : 0.25,
    'lr': 1e-4,
    'lr_decay': (0.01)**(1/200), # exponential decay to 1e-6 over 200 epochs
    'beta1': 0.5,
    'beta2': 0.999,
    # 'slope': 1e-2,
    'decay': 2.5*1e-5,
    # 'dropout': 0.2,
    'network_type': 'FC',
    'z_dim': 50,
    'h_dim': 1024,
    'y_dim': 10
}

"""check arguments"""

def check_kwargs(kwargs_bicogan):
    # save_dir
    if not os.path.exists(kwargs_bicogan['save_dir']):
        os.makedirs(kwargs_bicogan['save_dir'])

    # result_dir
    if not os.path.exists(kwargs_bicogan['result_dir']):
        os.makedirs(kwargs_bicogan['result_dir'])

    # log_dir
    if not os.path.exists(kwargs_bicogan['log_dir']):
        os.makedirs(kwargs_bicogan['log_dir'])

    # epoch
    if kwargs_bicogan['epoch'] < 1:
        raise ValueError('number of epochs must be larger than or equal to one')

    # batch_size
    if kwargs_bicogan['batch_size'] < 1:
        raise ValueError('batch size must be larger than or equal to one')

    return kwargs_bicogan
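During training the weight on the label-prediction loss (`EFL_loss` in the code) follows `efl_gamma = min(alpha * exp(rho * epoch), phi)`. The sketch below tabulates it for the configuration above (`alpha = 5`, `phi = 10`, `rho = 0.25`) using a hypothetical `efl_weight` helper: the weight starts at 5 and is capped at 10 from epoch 3 onward, since 5 · e^(0.25 · 3) ≈ 10.59 exceeds `phi`.

```python
import math

def efl_weight(epoch, alpha=5, rho=0.25, phi=10):
    """Weight on the label-prediction loss, as computed for self.efl_gamma in BiCoGAN.train."""
    return min(alpha * math.exp(rho * epoch), phi)

for epoch in range(5):
    print(epoch, round(efl_weight(epoch), 3))

# first epoch at which the cap phi is reached: ceil(ln(phi / alpha) / rho)
first_capped_epoch = math.ceil(math.log(10 / 5) / 0.25)
print(first_capped_epoch)
```

With these settings the label loss runs at its full weight `phi` for all but the first three of the 400 epochs.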


In [0]:
#import warnings

"""main"""
def main():
    #warnings.simplefilter('error', UserWarning)
    # check arguments
    if kwargs_bicogan is None:
        exit()
    else:
        check_kwargs(kwargs_bicogan)

    bicogan = BiCoGAN(kwargs_bicogan)

    # wipe old files
    with open('pixel_error_BIGAN.txt', 'w') as f:
        f.writelines('')
    with open('z_error_BIGAN.txt', 'w') as f:
        f.writelines('')

    bicogan.train()
    print(" [*] Training finished!")

    bicogan.save_model()
    

if __name__ == '__main__':
    main()


## BiCoGAN Evaluation

In [0]:
# load trained net

bicogan = BiCoGAN(kwargs_bicogan)
bicogan.load_model(kwargs_bicogan)
bicogan.E.eval()  # freshly loaded modules start in training mode; use BatchNorm running statistics

dataset = Mnist(128)


In [0]:
# generate encodings for dataset

def generate_encodings(dataset, train_test):
    if train_test == 'train':
        loader = dataset.train_loader
    elif train_test == 'test':
        loader = dataset.test_loader
    else:
        raise ValueError("train_test must be 'train' or 'test'")
    
    encodings = []
    labels = []
    
    # gradients are not needed to extract encodings; disabling them saves memory
    with torch.no_grad():
        for batch_id, (data, target) in enumerate(loader):
            X_data = Variable(data)
            if bicogan.gpu_mode:
                X_data = X_data.cuda()
            
            # skip the last, smaller batch
            if X_data.size(0) == bicogan.batch_size:
                X = X_data.view(bicogan.batch_size, -1)
                z_hat = bicogan.E(X)[0]
                encodings.append(z_hat)
                labels.append(target)

    encodings = torch.cat(encodings).data.cpu().numpy()
    labels = torch.cat(labels).data.cpu().numpy()

    return encodings, labels



In [0]:
# generate encodings for train and test sets

encodings_train, labels_train = generate_encodings(dataset, 'train')
encodings_test, labels_test = generate_encodings(dataset, 'test')


In [0]:
# generate and visualize t-SNE encodings for test set

tsne_encodings_test = TSNE().fit_transform(encodings_test)
scatter(tsne_encodings_test, labels_test)


In [0]:
# kNN classification accuracy on test set

def knn_results(n):
    knn = KNeighborsClassifier(n_neighbors=n)
    knn.fit(encodings_train, labels_train)
    labels_hat = knn.predict(encodings_test)

    print('%sNN classification accuracy (%%)' %n, round(metrics.accuracy_score(labels_test, labels_hat)*100, 2))

    
for i in range(1, 11):
    knn_results(i)


## TriBiGAN Architecture and Utilities

In [0]:
# triplet loss

class TripletLoss(nn.Module):
    """
    Triplet loss
    Takes embeddings of an anchor sample, a positive sample and a negative sample
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative, size_average=True):
        # squared Euclidean distances (adding .pow(.5) would give plain distances)
        distance_positive = (anchor - positive).pow(2).sum(1)
        distance_negative = (anchor - negative).pow(2).sum(1)
        losses = F.relu(distance_positive - distance_negative + self.margin)
        return losses.mean() if size_average else losses.sum()
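`TripletLoss` penalises a triplet whenever the squared distance from anchor to positive is not smaller than the anchor-to-negative distance by at least `margin`: loss = max(0, ‖a − p‖² − ‖a − n‖² + margin), averaged (or summed) over the batch. A NumPy mirror of `forward`, with toy 2-D embeddings chosen so the numbers can be checked by hand:

```python
import numpy as np

def triplet_loss(anchor, positive, negative, margin, size_average=True):
    """NumPy mirror of TripletLoss.forward: hinge on squared Euclidean distances."""
    d_pos = ((anchor - positive) ** 2).sum(axis=1)
    d_neg = ((anchor - negative) ** 2).sum(axis=1)
    losses = np.maximum(d_pos - d_neg + margin, 0.0)
    return losses.mean() if size_average else losses.sum()

anchor   = np.array([[0.0, 0.0], [0.0, 0.0]])
positive = np.array([[1.0, 0.0], [1.0, 0.0]])   # d_pos = 1 for both triplets
negative = np.array([[3.0, 0.0], [1.0, 1.0]])   # d_neg = 9 and 2

# triplet 1: max(0, 1 - 9 + 2) = 0, already separated by more than the margin
# triplet 2: max(0, 1 - 2 + 2) = 1, the negative is not far enough away
print(triplet_loss(anchor, positive, negative, margin=2.0))
print(triplet_loss(anchor, positive, negative, margin=2.0, size_average=False))
```

Because the distances are squared, the margin is expressed on the scale of squared distances in the latent space.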
   


In [0]:
# save plots of loss functions after training

def save_plot_losses_triplet(train_D_loss, train_G_loss, train_triplet_loss, eval_D_loss, eval_G_loss, eval_triplet_loss, model_used, z_dim, epochs, lr, batch_size):

    x = np.arange(1, len(train_D_loss) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, train_G_loss, label="Train G loss", linewidth=2)
    plt.plot(x, eval_G_loss, label="Eval G loss", linewidth=2)


    # label the current Axes (plt.axes() with no arguments may create a new Axes instead)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the Train and Eval losses of G")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_G_losses.eps", format='eps', dpi=1000)
    plt.close()



    plt.figure(figsize=(8, 6))
    plt.plot(x, train_D_loss, label="Train D loss", linewidth=2)
    plt.plot(x, eval_D_loss, label="Eval D loss", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')

    plt.suptitle("Evolution of the Train and Eval losses of D")

    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_D_losses.eps", format='eps', dpi=1000)
    plt.close()
    
    
    
    plt.figure(figsize=(8, 6))
    plt.plot(x, train_triplet_loss, label="Train triplet loss", linewidth=2)
    plt.plot(x, eval_triplet_loss, label="Eval triplet loss", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend(loc='upper right')

    plt.suptitle("Evolution of the Train and Eval triplet losses")

    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("plot_triplet_losses.eps", format='eps', dpi=1000)
    plt.close()


def save_plot_pixel_norm(mean_pixel_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_pixel_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_pixel_norm, label="Reconstruction error", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between X and G(E(X))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("pix2pix_norm.eps", format='eps', dpi=1000)
    plt.close()

def save_plot_z_norm(mean_z_norm, model_used, z_dim, epochs, lr, batch_size):
    x = np.arange(1, len(mean_z_norm) + 1)

    plt.figure(figsize=(8, 6))
    plt.plot(x, mean_z_norm, label="Reconstruction error", linewidth=2)

    plt.xlabel('Epoch')
    plt.ylabel('Norm')
    plt.legend(loc='upper right')
    plt.suptitle("Evolution of the reconstruction error between z and E(G(z))")
    params = "Network type: " + model_used + ", Dimension of latent space: " + str(z_dim) + ", epochs: " + str(epochs) + ", learning rate: " + str(lr) + ", batch size: " + str(batch_size)
    plt.title(params, fontsize=8)
    # plt.show()
    plt.savefig("z_norm.eps", format='eps', dpi=1000)
    plt.close()


In [0]:
# TriBiGAN net

class TripletBIGAN(object):
    """
    Class implementing a BIGAN network with an additional triplet loss.
    """

    def __init__(self, kwargs):
        
        self.epoch = kwargs['epoch']
        self.batch_size = kwargs['batch_size']
        self.save_dir = kwargs['save_dir']
        self.result_dir = kwargs['result_dir']
        self.log_dir = kwargs['log_dir']
        self.gpu_mode = kwargs['gpu_mode']
        self.triplet_margin = kwargs['triplet_margin']
        self.learning_rate = kwargs['lr']
        self.lr_decay = kwargs['lr_decay']
        self.beta1 = kwargs['beta1']
        self.beta2 = kwargs['beta2']
        self.decay = kwargs['decay']      
        self.network_type = kwargs['network_type']
        self.dataset = kwargs['dataset']
        self.dataset_path = kwargs['dataset_path']

        # BIGAN parameters
        self.z_dim = kwargs['z_dim']    #dimension of feature space
        self.h_dim = kwargs['h_dim']    #dimension of the hidden layer

        if kwargs['dataset'] == 'mnist':
            self.X_dim = 28*28    #dimension of data
            self.num_channels = 1

        if kwargs['network_type'] == 'FC':
            # networks init
            self.G = Generator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_FC(self.z_dim, self.h_dim, self.X_dim)
        else:
            raise Exception("[!] There is no option for " + kwargs['network_type'])

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()

        
        self.G_solver = optim.Adam(chain(self.E.parameters(), self.G.parameters()), lr=self.learning_rate, betas=[self.beta1,self.beta2], weight_decay=self.decay)
        self.D_solver = optim.Adam(self.D.parameters(), lr=self.learning_rate, betas=[self.beta1,self.beta2], weight_decay=self.decay)
        
        # exponential learning rate decay starting halfway through learning
        self.G_scheduler = optim.lr_scheduler.MultiStepLR(self.G_solver, milestones=list(range(self.epoch//2, self.epoch)), gamma=self.lr_decay)
        self.D_scheduler = optim.lr_scheduler.MultiStepLR(self.D_solver, milestones=list(range(self.epoch//2, self.epoch)), gamma=self.lr_decay)



        print('---------- Networks architecture -------------')
        print_network(self.G)
        print_network(self.E)
        print_network(self.D)
        print('-----------------------------------------------')



    def D_(self, X, z):
        return self.D(X, z)

    def reset_grad(self):
        self.E.zero_grad()
        self.G.zero_grad()
        self.D.zero_grad()


    def train(self):
        if self.dataset == 'mnist':
            dataset = TripletMnist(self.batch_size)


        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['triplet_loss'] = []

        self.eval_hist = {}
        self.eval_hist['D_loss'] = []
        self.eval_hist['G_loss'] = []
        self.eval_hist['triplet_loss'] = []
        self.eval_hist['pixel_norm'] = []
        self.eval_hist['z_norm'] = []


        for epoch in range(self.epoch):
            print("epoch ",str(epoch))

            self.D.train()
            self.E.train()
            self.G.train()

            train_loss_G = 0
            train_loss_D = 0
            train_triplet_loss = 0
            
            # learning rate schedule
            self.G_scheduler.step()
            self.D_scheduler.step()


            for batch_id, (data, target) in enumerate(dataset.train_loader):
                
                # the last batch can be smaller than batch_size; such batches are skipped
                if data[0].size(0) == self.batch_size:
                    # sample z
                    z = Variable(2 * torch.rand(self.batch_size, self.z_dim) - 1)  # uniform on [-1, 1]
                    # (X = anchor) is a real image from the dataset                                                       
                    anchor = Variable(data[0])
                    positive = Variable(data[1])
                    negative = Variable(data[2])

                    if self.gpu_mode:
                        z = z.cuda()
                        anchor = anchor.cuda()
                        positive = positive.cuda()
                        negative = negative.cuda()

                    if self.network_type == 'FC':                        
                        anchor = anchor.view(self.batch_size, -1)
                        positive = positive.view(self.batch_size, -1)
                        negative = negative.view(self.batch_size, -1)
                        z_hat = self.E(anchor)
                        X_hat = self.G(z)

                        D_enc = self.D_(anchor, z_hat)
                        D_gen = self.D_(X_hat, z)                      


                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))                    
                    
                    triplet_fn = TripletLoss(self.triplet_margin)
                    triplet_loss = triplet_fn(self.E(anchor), self.E(positive), self.E(negative))
                    
                    G_triplet_loss = G_loss + triplet_loss

                    D_loss.backward(retain_graph=True)
                    self.D_solver.step()
                    self.reset_grad()

                    G_triplet_loss.backward()
                    self.G_solver.step()
                    self.reset_grad()

                    train_loss_G += G_loss.data
                    train_loss_D += D_loss.data
                    train_triplet_loss += triplet_loss.data

                    if batch_id % 1000 == 0:
                        # Print and plot every now and then
                        samples = X_hat.data.cpu().numpy()

                        fig = plt.figure(figsize=(8, 4))
                        gs = gridspec.GridSpec(4, 8)
                        gs.update(wspace=0.05, hspace=0.05)

                        for i, sample in enumerate(samples):
                            if i<32:
                                ax = plt.subplot(gs[i])
                                plt.axis('off')
                                ax.set_xticklabels([])
                                ax.set_yticklabels([])
                                ax.set_aspect('equal')

                                if self.network_type == 'FC':
                                    if self.dataset == 'mnist':
                                        sample = sample.reshape(28, 28)
                                        plt.imshow(sample, cmap='Greys_r')


                        if not os.path.exists(self.result_dir + '/train/'):
                            os.makedirs(self.result_dir + '/train/')

                        filename = "epoch_" + str(epoch) + "_batchid_" + str(batch_id)
                        plt.savefig(self.result_dir + '/train/{}.png'.format(filename), bbox_inches='tight')
                        plt.close()

            print("Train loss G:", train_loss_G / len(dataset.train_loader))
            print("Train loss D:", train_loss_D / len(dataset.train_loader))
            print("Train triplet loss:", train_triplet_loss / len(dataset.train_loader))

            self.train_hist['D_loss'].append(train_loss_D / len(dataset.train_loader))
            self.train_hist['G_loss'].append(train_loss_G / len(dataset.train_loader))
            self.train_hist['triplet_loss'].append(train_triplet_loss / len(dataset.train_loader))


            self.D.eval()
            self.E.eval()
            self.G.eval()
            test_loss_G = 0
            test_loss_D = 0
            test_triplet_loss = 0

            mean_pixel_norm = 0
            mean_z_norm = 0
            norm_counter = 0  # number of full test batches evaluated

            for batch_id, (data, target) in enumerate(dataset.test_loader):

                if data[0].size(0) == self.batch_size:
                    
                    # Sample data
                    z = Variable(2 * torch.rand(self.batch_size, self.z_dim) - 1)  # uniform on [-1, 1]
                    anchor = Variable(data[0])
                    positive = Variable(data[1])
                    negative = Variable(data[2])

                    if self.gpu_mode:
                        z = z.cuda()
                        anchor = anchor.cuda()
                        positive = positive.cuda()
                        negative = negative.cuda()                
                    
                    if self.network_type == 'FC':
                        anchor = anchor.view(self.batch_size, -1)
                        positive = positive.view(self.batch_size, -1)
                        negative = negative.view(self.batch_size, -1)
                        z_hat = self.E(anchor)
                        X_hat = self.G(z)

                        D_enc = self.D_(anchor, z_hat)
                        D_gen = self.D_(X_hat, z)
                                                

                    D_loss = -torch.mean(log(D_enc) + log(1 - D_gen))
                    G_loss = -torch.mean(log(D_gen) + log(1 - D_enc))
                    triplet_fn = TripletLoss(self.triplet_margin)                  
                    triplet_loss = triplet_fn(self.E(anchor), self.E(positive), self.E(negative))

                    test_loss_G += G_loss.data
                    test_loss_D += D_loss.data
                    test_triplet_loss += triplet_loss.data

                    pixel_norm = anchor -  self.G(z_hat)
                    pixel_norm = pixel_norm.norm().data / float(self.X_dim)
                    mean_pixel_norm += pixel_norm


                    z_norm = z - self.E(X_hat)
                    z_norm = z_norm.norm().data / float(self.z_dim)
                    mean_z_norm += z_norm

                    norm_counter += 1


            print("Eval loss G:", test_loss_G / norm_counter)
            print("Eval loss D:", test_loss_D / norm_counter)
            print("Eval triplet loss:", test_triplet_loss / norm_counter)

            self.eval_hist['D_loss'].append(test_loss_D / norm_counter)
            self.eval_hist['G_loss'].append(test_loss_G / norm_counter)
            self.eval_hist['triplet_loss'].append(test_triplet_loss / norm_counter)

            print("Pixel norm:", mean_pixel_norm / norm_counter)
            self.eval_hist['pixel_norm'].append( mean_pixel_norm / norm_counter )

            with open('pixel_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_pixel_norm / norm_counter) + '\n')

            print("z norm:", mean_z_norm / norm_counter)
            self.eval_hist['z_norm'].append( mean_z_norm / norm_counter )

            with open('z_error_BIGAN.txt', 'a') as f:
                f.writelines(str(mean_z_norm / norm_counter) + '\n')

            ##### At the end of the epoch, save X and its reconstruction G(E(X))
            samples = anchor.data.cpu().numpy()

            fig = plt.figure(figsize=(10, 2))
            gs = gridspec.GridSpec(2, 10)
            gs.update(wspace=0.05, hspace=0.05)

            for i, sample in enumerate(samples):
                if i<10:
                    ax = plt.subplot(gs[i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        
                        
            X_hat = self.G(self.E(anchor).view(self.batch_size, self.z_dim))
            samples = X_hat.data.cpu().numpy()


            for i, sample in enumerate(samples):
                if i<10:
                    ax = plt.subplot(gs[10+i])
                    plt.axis('off')
                    ax.set_xticklabels([])
                    ax.set_yticklabels([])
                    ax.set_aspect('equal')
                    if self.network_type == 'FC':
                        if self.dataset == 'mnist':
                            sample = sample.reshape(28, 28)
                            plt.imshow(sample, cmap='Greys_r')
                        

            if not os.path.exists(self.result_dir + '/recons/'):
                os.makedirs(self.result_dir + '/recons/')

            filename = "epoch_" + str(epoch)
            plt.savefig(self.result_dir + '/recons/{}.png'.format(filename), bbox_inches='tight')
            plt.close()


        save_plot_losses_triplet(self.train_hist['D_loss'], self.train_hist['G_loss'], self.train_hist['triplet_loss'], self.eval_hist['D_loss'], self.eval_hist['G_loss'], self.eval_hist['triplet_loss'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_pixel_norm(self.eval_hist['pixel_norm'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)
        save_plot_z_norm(self.eval_hist['z_norm'], self.network_type, self.z_dim, self.epoch, self.learning_rate, self.batch_size)

    def save_model(self):
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        
        torch.save(self.G.state_dict(), self.save_dir + "/G.pt")
        torch.save(self.E.state_dict(), self.save_dir + "/E.pt")
        torch.save(self.D.state_dict(), self.save_dir + "/D.pt")

    def load_model(self, kwargs):
        if kwargs['network_type'] == 'FC':
            # networks init
            self.G = Generator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.D = Discriminator_FC(self.z_dim, self.h_dim, self.X_dim)
            self.E = Encoder_FC(self.z_dim, self.h_dim, self.X_dim)
        
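        # load from the same directory save_model() writes to; map to CPU first, move to GPU below if needed
        self.G.load_state_dict(torch.load(self.save_dir + "/G.pt", map_location='cpu'))
        self.E.load_state_dict(torch.load(self.save_dir + "/E.pt", map_location='cpu'))
        self.D.load_state_dict(torch.load(self.save_dir + "/D.pt", map_location='cpu'))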
        

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.E.cuda()
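
For reference, the objectives minimised in `train()` can be written out explicitly. Here $B$ is the batch size, $x_i$, $p_i$ and $n_i$ are the anchor, positive and negative images, $z_i \sim \mathcal{U}[-1, 1]^{d_z}$, and $m$ is `triplet_margin`. The `log` helper is assumed to be the natural logarithm (possibly with a small stabilising epsilon, depending on how it was defined earlier), and the triplet term assumes that `TripletLoss` follows Bielski's `losses.py` (squared Euclidean distances, hinge at the margin, mean over the batch):

$$\mathcal{L}_D = -\frac{1}{B}\sum_{i=1}^{B}\Big[\log D\big(x_i, E(x_i)\big) + \log\Big(1 - D\big(G(z_i), z_i\big)\Big)\Big]$$

$$\mathcal{L}_G = -\frac{1}{B}\sum_{i=1}^{B}\Big[\log D\big(G(z_i), z_i\big) + \log\Big(1 - D\big(x_i, E(x_i)\big)\Big)\Big]$$

$$\mathcal{L}_{\text{triplet}} = \frac{1}{B}\sum_{i=1}^{B}\max\Big(0,\ \|E(x_i) - E(p_i)\|_2^2 - \|E(x_i) - E(n_i)\|_2^2 + m\Big)$$

`D_solver` minimises $\mathcal{L}_D$ over the parameters of $D$, while `G_solver` minimises $\mathcal{L}_G + \mathcal{L}_{\text{triplet}}$ over the parameters of $E$ and $G$ jointly. The two evaluation metrics are computed per test batch $X \in \mathbb{R}^{B \times d_X}$ with its latent batch $Z \in \mathbb{R}^{B \times d_z}$,

$$\text{pixel\_norm} = \frac{\|X - G(E(X))\|_F}{d_X}, \qquad \text{z\_norm} = \frac{\|Z - E(G(Z))\|_F}{d_z},$$

then summed over the test batches and divided by `norm_counter`.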

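The triplet term is what distinguishes TriBiGAN from the plain BiGAN. Below is a minimal NumPy sketch of it, assuming that the `TripletLoss` class defined earlier mirrors Bielski's `losses.py` (squared Euclidean distances, hinge at the margin, mean over the batch). `triplet_loss` is a hypothetical stand-alone helper for illustration, not the class used in training.

```python
import numpy as np


def triplet_loss(anchor, positive, negative, margin=0.2):
    """Hinge triplet loss on squared Euclidean distances, averaged over the batch.

    Assumption: mirrors the TripletLoss used in training (Bielski's losses.py).
    anchor, positive, negative: (batch, z_dim) arrays of encodings E(.).
    """
    d_pos = np.sum((anchor - positive) ** 2, axis=1)  # squared distance anchor-positive
    d_neg = np.sum((anchor - negative) ** 2, axis=1)  # squared distance anchor-negative
    losses = np.maximum(d_pos - d_neg + margin, 0.0)  # zero once the negative is far enough
    return losses.mean()


# Toy batch of two triplets with z_dim = 2
a = np.array([[0.0, 0.0], [1.0, 1.0]])
p = np.array([[0.1, 0.0], [1.0, 1.0]])
n = np.array([[1.0, 0.0], [1.0, 1.2]])
print(triplet_loss(a, p, n, margin=0.2))
```

The first triplet is already separated by more than the margin and contributes nothing; the second contributes roughly 0.2 - 0.04 = 0.16, so the batch mean is about 0.08. Minimising this pulls encodings of same-class digits together and pushes different-class digits at least `triplet_margin` apart in squared distance.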


## TriBiGAN Training

In [0]:
"""config"""

kwargs_triplet = {
    'dataset': 'mnist',
    'dataset_path': '/home/zelazny/Downloads/tooploox/data/',
    'gpu_mode': True,
    'save_dir': 'models',
    'result_dir': 'results',
    'log_dir': 'logs',
    'epoch': 400,
    'batch_size': 128,
    'triplet_margin': 0.2,
    'lr': 1e-4,
    'lr_decay': (0.01)**(1/200), # exponential decay to 1e-6 over 200 epochs
    'beta1': 0.5,
    'beta2': 0.999,
    # 'slope': 1e-2,
    'decay': 2.5*1e-5,
    # 'dropout': 0.2,
    'network_type': 'FC',
    'z_dim': 50,
    'h_dim': 1024
}

"""check arguments"""

def check_kwargs(kwargs_triplet):
    # save_dir
    if not os.path.exists(kwargs_triplet['save_dir']):
        os.makedirs(kwargs_triplet['save_dir'])

    # result_dir
    if not os.path.exists(kwargs_triplet['result_dir']):
        os.makedirs(kwargs_triplet['result_dir'])

    # log_dir
    if not os.path.exists(kwargs_triplet['log_dir']):
        os.makedirs(kwargs_triplet['log_dir'])

    # epoch
    if kwargs_triplet['epoch'] < 1:
        raise ValueError('number of epochs must be larger than or equal to one')

    # batch_size
    if kwargs_triplet['batch_size'] < 1:
        raise ValueError('batch size must be larger than or equal to one')

    return kwargs_triplet
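
The `lr_decay` comment in the config can be checked by hand. `MultiStepLR` multiplies the learning rate by `gamma` each time a milestone is passed, so one milestone per epoch in the second half of training amounts to exponential decay. A minimal pure-Python sketch of that closed form, using the values from `kwargs_triplet` (here `step` is the scheduler's internal epoch counter; depending on the PyTorch version and on where `scheduler.step()` is called, it may be offset by one from the training loop's `epoch`):

```python
from bisect import bisect_right

# Values copied from kwargs_triplet above
base_lr = 1e-4                     # 'lr'
n_epochs = 400                     # 'epoch'
gamma = 0.01 ** (1 / 200)          # 'lr_decay'
milestones = list(range(n_epochs // 2, n_epochs))  # one milestone per epoch in the second half


def multistep_lr(step):
    """Closed form of MultiStepLR: base_lr * gamma ** (number of milestones <= step)."""
    return base_lr * gamma ** bisect_right(milestones, step)


for step in (0, 199, 200, 299, 399):
    print(step, multistep_lr(step))
```

The rate stays at 1e-4 for the first half of training and then shrinks by a factor of `gamma` per epoch: after 100 decay steps it is about 1e-5, and after all 200 it is about 1e-6, matching the comment in the config.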


In [0]:
#import warnings

"""main"""
def main():
    #warnings.simplefilter('error', UserWarning)
    # check arguments
    if kwargs_triplet is None:
        exit()
    else:
        check_kwargs(kwargs_triplet)

    triplet_bigan = TripletBIGAN(kwargs_triplet)

    # wipe old files
    with open('pixel_error_BIGAN.txt', 'w') as f:
        f.writelines('')
    with open('z_error_BIGAN.txt', 'w') as f:
        f.writelines('')

    triplet_bigan.train()
    print(" [*] Training finished!")

    triplet_bigan.save_model()
    

if __name__ == '__main__':
    main()


## TriBiGAN Evaluation

In [0]:
# load trained net

triplet_bigan = TripletBIGAN(kwargs_triplet)
triplet_bigan.load_model(kwargs_triplet)

dataset = TripletMnist(kwargs_triplet['batch_size'])


In [0]:
# generate encodings for dataset

def generate_encodings(dataset, train_test):
    if train_test == 'train':
        loader = dataset.train_loader
    elif train_test == 'test':
        loader = dataset.test_loader
    else:
        raise ValueError("train_test must be 'train' or 'test'")
    
    encodings = []
    labels = []
    
    for batch_id, (data, target) in enumerate(loader):

        X_data = Variable(data[0])

        if triplet_bigan.gpu_mode:
            X_data = X_data.cuda()
        
        if X_data.size(0) == triplet_bigan.batch_size:
            X = X_data
            X = X.view(triplet_bigan.batch_size, -1)
            z_hat = triplet_bigan.E(X)
            encodings.append(z_hat.data)  # detach so the autograd graph is not kept in memory
            labels.append(target)


    encodings = torch.cat(encodings).data.cpu().numpy()
    labels = torch.cat(labels).data.cpu().numpy()

    return encodings, labels



In [0]:
# generate encodings for the train and test sets

encodings_train, labels_train = generate_encodings(dataset, 'train')
encodings_test, labels_test = generate_encodings(dataset, 'test')


In [0]:
# generate and visualize t-SNE encodings for test set

tsne_encodings_test = TSNE().fit_transform(encodings_test)
scatter(tsne_encodings_test, labels_test)


In [0]:
# kNN classification accuracy on test set

def knn_results(n):
    knn = KNeighborsClassifier(n_neighbors=n)
    knn.fit(encodings_train, labels_train)
    labels_hat = knn.predict(encodings_test)

    print('%sNN classification accuracy (%%)' %n, round(metrics.accuracy_score(labels_test, labels_hat)*100, 2))

    
for i in range(1, 11):
    knn_results(i)
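
`knn_results` relies on scikit-learn's `KNeighborsClassifier`, whose defaults are uniform weights and Euclidean (Minkowski, p = 2) distance. To illustrate what it computes, here is a minimal NumPy sketch of majority-vote kNN plus the accuracy that `knn_results` prints. `knn_predict` and `accuracy_percent` are hypothetical helpers; their tie-breaking may differ from scikit-learn's, and the brute-force distance computation is only practical on small toy inputs, not on the full MNIST encodings.

```python
import numpy as np


def knn_predict(X_train, y_train, X_test, k):
    """Majority-vote k-nearest-neighbour prediction with Euclidean distance."""
    # Pairwise squared Euclidean distances, shape (n_test, n_train).
    # Builds an (n_test, n_train, dim) intermediate, so this is for small inputs only.
    d2 = ((X_test[:, None, :] - X_train[None, :, :]) ** 2).sum(axis=2)
    # Indices of the k closest training points for each test point
    nn_idx = np.argsort(d2, axis=1, kind="stable")[:, :k]
    nn_labels = y_train[nn_idx]
    # Majority vote; bincount(...).argmax() breaks ties toward the smallest label
    return np.array([np.bincount(row).argmax() for row in nn_labels])


def accuracy_percent(y_true, y_pred):
    """Share of correct predictions, in percent."""
    return float(np.mean(y_true == y_pred) * 100)


# Toy "encodings": two well-separated clusters labelled 0 and 1
X_train = np.array([[0.0, 0.0], [0.0, 1.0], [5.0, 5.0], [5.0, 6.0]])
y_train = np.array([0, 0, 1, 1])
X_test = np.array([[0.0, 0.5], [5.0, 5.5]])
y_test = np.array([0, 1])

for k in (1, 3):
    y_hat = knn_predict(X_train, y_train, X_test, k)
    print('%sNN classification accuracy (%%)' % k, accuracy_percent(y_test, y_hat))
```

The better the encoder clusters digits of the same class, the higher this accuracy, which is why kNN on $E(x)$ serves as a proxy for the quality of the learned representation.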


## Conclusion

Now we can compare the three architectures (BiGAN, BiCoGAN and TriBiGAN) using the reconstruction errors and kNN accuracies reported for each.

And we are done! Hope you have enjoyed it!