In [1]:
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import argparse
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from collections import defaultdict
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from scipy.stats import entropy

import numpy as np

from load_moonboard import load_moonboard


### Parsing all arguments

In [2]:
parser = argparse.ArgumentParser()
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--epochs", type=int, default=20)
parser.add_argument("--lambd", type=int, default=0.5)
parser.add_argument("--batch_size", type=int, default=64)
parser.add_argument("--learning_rate", type=float, default=0.001)
parser.add_argument("--encoder_layer_sizes", type=list, default=[784, 256])
parser.add_argument("--decoder_layer_sizes", type=list, default=[256, 784])
parser.add_argument("--latent_size", type=int, default=30)
parser.add_argument("--print_every", type=int, default=10)
parser.add_argument("--fig_root", type=str, default='figs')
parser.add_argument("--representation", type=str, default=None)
parser.add_argument("--loss", type=str, default="MSE")
parser.add_argument("--test_loss", type=str, default="MSE")
parser.add_argument("--net", type=str, default="ResNet")
parser.add_argument("--conditional", action='store_true')
parser.add_argument("--variational", action='store_true')



args = parser.parse_args([])
args.conditional = True
args.lambd = 0.9
args.epochs = 70

### Setting up model and dataset

In [3]:
if args.net == "ResNet":
    from models_resnet import VAE
elif args.net == "conv":
    from models_conv import VAE
else:
    from models import VAE

torch.manual_seed(args.seed)
torch.set_default_tensor_type('torch.cuda.FloatTensor')
if torch.cuda.is_available():
    torch.cuda.manual_seed(args.seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


ts = time.time()


"""Load the dataset"""
class MoonBoardDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, train = True):
        self.train = train

        (self.x_train, self.y_train), (self.x_test, self.y_test) = load_moonboard()
        self.x_train = self.x_train.transpose(0,1,3,2 ).astype(float)
        self.x_test = self.x_test.transpose(0,1,3,2 ).astype(float)
        self.y_train = self.y_train.reshape(-1,1).astype(int)
        self.y_test = self.y_test.reshape(-1,1).astype(int)

    def __len__(self):
        if self.train:
            return len(self.x_train)
        else:
            return len(self.x_test)

    def __getitem__(self, idx):
        if self.train:
            return self.x_train[idx], self.y_train[idx]
        else:
            return self.x_test[idx], self.y_test[idx]

dataset = MoonBoardDataset(train = True)
dataset_test = MoonBoardDataset(train = False)

data_loader = torch.utils.data.DataLoader(dataset,
batch_size=128, shuffle=True)

data_loader_test = torch.utils.data.DataLoader(MoonBoardDataset(train = False),
batch_size=32, shuffle=True)

### Define loss

In [4]:
"""Define the loss"""
def loss_fn(recon_x, x, mean, log_var):
    """
        recon_x : reconstructed x after being through VAE or CVAE
        x : original x
        mean : center of the gaussian in average
        log_var : related to the standard deviation 
    """
    if args.loss == "BCE":
        recon_loss = torch.nn.BCELoss()(
        recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
    elif args.loss == "MSE":
        recon_loss = torch.nn.MSELoss()(recon_x.view(-1,3*18*11), x.view(-1, 3*18*11))
    elif args.loss == "L1":
        recon_loss = torch.nn.L1Loss()(recon_x.view(-1,3*18*11), x.view(-1, 3*18*11))
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())/x.size(0)

    return (args.lambd * recon_loss + (1 - args.lambd) * KLD) 

def test_loss_fn(recon_x, x, mean, log_var):
    if args.loss == "BCE":
        recon_loss = torch.nn.BCELoss()(
        recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
    elif args.loss == "MSE":
        recon_loss = torch.nn.MSELoss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
    elif args.loss == "L1":
        recon_loss = torch.nn.L1Loss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
    KLD = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())/x.size(0)

    return (args.lambd * recon_loss + (1 - args.lambd) * KLD)  





### Define model

In [5]:
vae = VAE(
    encoder_layer_sizes=args.encoder_layer_sizes,
    latent_size=args.latent_size,
    decoder_layer_sizes=args.decoder_layer_sizes,
    conditional=args.conditional, variational = args.variational,
    num_labels=13 if args.conditional else 0).to(device)

"""Define the optimizer"""
optimizer = torch.optim.Adam(vae.parameters(), lr=args.learning_rate)


"""Prints the number of parameters in the model"""
print(sum(p.numel() for p in vae.parameters()))

logs = defaultdict(list)


1163896


In [6]:
import sys
print(sys.version)

3.6.8 |Anaconda, Inc.| (default, Feb 11 2019, 15:03:47) [MSC v.1915 64 bit (AMD64)]


In [11]:
class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=13):
        super(ResNet, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
        self.linear = nn.Linear(512*block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out
    
class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out
    
inception_model = torch.load("inception_model.trch").to(device)

### Inception score

In [12]:
def inception_score(imgs, inception_model, splits = 1):

    N = len(imgs)    
    split_scores = []
    
    preds = nn.Softmax(dim = 1)(inception_model(imgs))
    

    for k in range(splits):
        part = preds[k * (N // splits): (k+1) * (N // splits), :]
        py = torch.mean(part, 0)
        scores = []
        for i in range(part.shape[0]):
            pyx = part[i, :]
            scores.append(entropy(pyx.cpu().detach().numpy(), py.cpu().detach().numpy()))
        split_scores.append(np.exp(np.mean(scores)))

    return np.mean(split_scores), np.std(split_scores)

In [13]:
print(device)

cuda


In [14]:
for epoch in range(args.epochs):

    tracker_epoch = defaultdict(lambda: defaultdict(dict))

    """Do training"""

    for iteration, (x, y) in enumerate(data_loader):
        """Send data to GPU"""
        x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)
        

        """CVAE or VAE generates data"""
        if args.conditional:
            recon_x, mean, log_var, z = vae(x, y)
        elif args.variational:
            recon_x, mean, log_var, z = vae(x)
        else:
            recon_x, z = vae(x)

        for i, yi in enumerate(y):
            id = len(tracker_epoch)
            for j in range(args.latent_size):
                tracker_epoch[id][str(j)] = z[i, j].item()
            tracker_epoch[id]['label'] = yi.item()


        """Compute loss"""
        if args.variational or args.conditional:                
            loss = loss_fn(recon_x, x, mean, log_var)
            """Compute KL divergence and binary crossentropy"""
            diverge = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp()) / x.size(0)
            if args.loss == "MSE":
                recon_loss = torch.nn.MSELoss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
            elif args.loss == "BCE":
                recon_loss = torch.nn.functional.binary_cross_entropy(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11)) 
            elif args.loss == "L1":
                recon_loss = torch.nn.L1Loss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
            logs['KL divergence'].append(-diverge.item())
            logs['Reconstruction Loss'].append(recon_loss.item())

        else:
            if args.loss == "MSE":
                loss = torch.nn.MSELoss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
            elif args.loss == "BCE":
                loss = torch.nn.functional.binary_cross_entropy(recon_x.view(-1, 3,18,11), x.view(-1, 3,18,11)) 
            elif args.loss == "L1":
                loss = torch.nn.L1Loss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
            #elif args.loss == "Hinge":
            #    loss = torch.nn.MarginRankingLoss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logs['loss'].append(loss.item()/x.size(0))


        if iteration % args.print_every == 0 or iteration == len(data_loader)-1:
            if args.variational or args.conditional:
                #print(" "*50, end = "\r")
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:f} KL {:f} Recon loss {:f}".format(
                    epoch+1, args.epochs, iteration, len(data_loader)-1, loss.item(), diverge.item(), recon_loss.item()))#, end="\r")

            else:
                #print(" "*50, end="\r")
                print("Epoch {:02d}/{:02d} Batch {:04d}/{:d}, Loss {:f} ".format(
                    epoch+1, args.epochs, iteration, len(data_loader)-1, loss.item()))#, end="\r")


            """Create images from only latent variable"""
            if args.conditional:
                c = torch.arange(0, 10).long().unsqueeze(1)
                x = vae.inference(n=c.size(0), c=c)
            else:
                x = vae.inference(n=10)


            plt.figure()
            plt.figure(figsize=(5, 10))
            for p in range(10):
                plt.subplot(5, 2, p+1)
                if args.conditional:
                    plt.text(
                        0, 0, "c={:d}".format(c[p].item()), color='black',
                        backgroundcolor='white', fontsize=8)
                plt.imshow(x[p].view( 3,18, 11).data.cpu().numpy().transpose(1,2,0))
                plt.axis('off')

            if not os.path.exists(os.path.join(args.fig_root, str(ts))):
                if not(os.path.exists(os.path.join(args.fig_root))):
                    os.mkdir(os.path.join(args.fig_root))
                os.mkdir(os.path.join(args.fig_root, str(ts)))

            plt.savefig(
                os.path.join(args.fig_root, str(ts),
                             "E{:d}I{:d}.png".format(epoch, iteration)),
                dpi=300)
            plt.clf()
            plt.close('all')

            """Reconstruction of already existing images"""



            rnd_id = random.sample(range(1, len(dataset_test)), 5)
            x = [torch.from_numpy(np.array(dataset_test[i][0])).type(torch.FloatTensor) for i in rnd_id]
            x = torch.stack(x)
            x = x.to(device)

            c = [torch.from_numpy(np.array(dataset_test[i][1])).type(torch.FloatTensor) for i in rnd_id]
            c = torch.stack(c).view(5,1)
            c = c.to(device)
            



            if not args.conditional:
                x = torch.cat((x,vae(x, testing = True)[0].view(5,3,18, 11)),0)
            else:
                x = torch.cat((x,vae(x,c, testing = True)[0].view(5,3,18, 11)),0)




            plt.figure()
            plt.figure(figsize=(5, 10))
            for p in range(5):
                plt.subplot(5, 2, 2*p+1)
                if args.conditional:
                    plt.text(
                        0, 0, "c={:f}".format(c[p].item()), color='black',
                        backgroundcolor='white', fontsize=8)
                plt.imshow(x[p].view(3,18,11).data.cpu().numpy().transpose(1,2,0))
                plt.axis('off')

                plt.subplot(5, 2, 2*p+2)
                if args.conditional:
                    plt.text(
                        0, 0, "c={:f}".format(c[p].item()), color='black',
                        backgroundcolor='white', fontsize=8)
                plt.imshow(x[p+5].view(3,18,11).data.cpu().numpy().transpose(1,2,0))
                plt.axis('off')

            if not os.path.exists(os.path.join(args.fig_root, str(ts))):
                if not(os.path.exists(os.path.join(args.fig_root))):
                    os.mkdir(os.path.join(args.fig_root))
                os.mkdir(os.path.join(args.fig_root, str(ts)))

            plt.savefig(
                os.path.join(args.fig_root, str(ts),
                             "Reconstruction E{:d}I{:d}.png".format(epoch, iteration)),
                dpi=300)
            plt.clf()
            plt.close('all')
    print("")

    """Test model"""

    test_loss = 0
    with torch.no_grad():
        bs = 0
        for iteration, (x, y) in enumerate(data_loader_test):
            #Send data to GPU
            x, y = x.type(torch.FloatTensor).to(device), y.type(torch.FloatTensor).to(device)

            #CVAE or VAE generates data
            if args.conditional:
                recon_x, mean, log_var, z = vae(x, y, testing = True)
            elif args.variational:
                recon_x, mean, log_var, z = vae(x, testing = True)
            else:
                recon_x, z = vae(x, testing = True)

            if not args.conditional and not args.variational:

                if args.test_loss == "MSE":
                    loss_res = torch.nn.MSELoss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11))
                elif args.test_loss == "BCE":
                    loss_res = torch.nn.functional.binary_cross_entropy(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11)) 
                elif args.test_loss == "L1":
                    loss_res = torch.nn.L1Loss()(recon_x.view(-1, 3*18*11), x.view(-1, 3*18*11)) 
            else:
                loss_res = test_loss_fn(recon_x, x, mean, log_var) 
            test_loss += loss_res.item() * x.size(0)

    test_loss /= len(data_loader_test.dataset)
    logs['test loss'].append(test_loss)

    print("Test loss : ", test_loss)
    
    
    if args.conditional:
        c = torch.randint(0,18,[1000]).long().unsqueeze(1)
        imgs = vae.inference(n=c.size(0), c=c)
    else:
        imgs = vae.inference(n=1000)
    print("Inception score ", inception_score(imgs, inception_model))
    


    """Print the points of the latent space"""
    df = pd.DataFrame.from_dict(tracker_epoch, orient='index')

    if args.representation == None:
        g = sns.lmplot(
        x='0', y='1', hue='label', data=df.groupby('label').head(100),
        fit_reg=False, legend=True)

    elif args.representation == "PCA":
        pca = PCA(n_components=2)
        feat_cols = [str(i) for i in range(args.latent_size)]
        pca_result = pca.fit_transform(df[feat_cols].values)
        df['pca-one'] = pca_result[:,0]
        df['pca-two'] = pca_result[:,1]
        print('Explained variation per principal component: {}'.format(pca.explained_variance_ratio_))
        g = sns.lmplot(
        x='pca-one', y='pca-two', hue='label', data=df.groupby('label').head(100),
        fit_reg=False, legend=True)

    elif args.representation == "TSNE":     
        time_start = time.time()  
        feat_cols = [str(i) for i in range(args.latent_size)] 
        tsne = TSNE(n_components=2, verbose=0, perplexity=40, n_iter=50)
        tsne_results = tsne.fit_transform(df[feat_cols].head(500).values)
        df_tsne = df.head(500).copy()
        df_tsne['x-tsne'] = tsne_results[:,0]
        df_tsne['y-tsne'] = tsne_results[:,1]

        print('t-SNE done! Time elapsed: {} seconds'.format(time.time()-time_start))

        g = sns.lmplot(
        x='x-tsne', y='y-tsne', hue='label', data=df_tsne.groupby('label').head(500),
        fit_reg=False, legend=True)

    g.savefig(os.path.join(
        args.fig_root, str(ts), "E{:d}-Dist.png".format(epoch)),
        dpi=300)


"""Print curves for loss, KL divergence and BCE"""
plt.clf()
plt.plot(logs['loss'], label='Loss')
plt.savefig(os.path.join(args.fig_root, str(ts),"loss_summary.png"),dpi=300)

plt.clf()
plt.plot(logs['test loss'], label='Test Loss')
plt.savefig(os.path.join(args.fig_root, str(ts),"test_loss_summary.png"),dpi=300)

if args.variational or args.conditional:
    plt.clf()
    plt.plot(logs['KL divergence'], label='KL divergence')
    plt.savefig(os.path.join(args.fig_root, str(ts),"KL_summary.png"),dpi=300)

    plt.clf()
    plt.plot(logs['Reconstruction Loss'], label='Reconstruction Loss')
    plt.savefig(os.path.join(args.fig_root, str(ts),"reconstruction_summary.png"),dpi=300)

Epoch 01/70 Batch 0000/32, Loss 0.023888 KL 0.051113 Recon loss 0.020863
Epoch 01/70 Batch 0010/32, Loss 0.018473 KL 0.032722 Recon loss 0.016890
Epoch 01/70 Batch 0020/32, Loss 0.016298 KL 0.023750 Recon loss 0.015470
Epoch 01/70 Batch 0030/32, Loss 0.015209 KL 0.019516 Recon loss 0.014730
Epoch 01/70 Batch 0032/32, Loss 0.015397 KL 0.020832 Recon loss 0.014793

Test loss :  0.01522843289889561
Inception score  (1.1047145, 0.0)
Epoch 02/70 Batch 0000/32, Loss 0.015432 KL 0.021637 Recon loss 0.014742


KeyboardInterrupt: 