In [None]:
# importing modules


import os
import pickle

import matplotlib.pyplot as plt
import numpy as np
import scipy.io
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset
from torchvision.utils import make_grid, save_image

In [None]:
# setting hyperparameters


img_size = 64      # presumably the image side length (64 * 64 = 4096, the encoder input width)
batch_size = 100   # DataLoader batch size; the held-out split is batch_size * 100 rows
num_threads = 8    # number of DataLoader worker processes
lat_dim = 10       # dimensionality of the latent code z

beta = 4           # not referenced in this notebook; presumably the beta-VAE KL weight - TODO confirm

# Prefer the second GPU, but only when it actually exists: selecting "cuda:1"
# on a single-GPU machine fails as soon as a tensor is moved to the device.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# constructing the dataset from the sample data generated by 'SamplesGenerator.ipynb'


class SampledData(Dataset):
    """Train/test view over a 2-D tensor of pre-generated samples.

    The tensor stored in ``file`` is read with ``torch.load``; every column
    except the last is treated as the input and the last column as the
    label (cast to ``int64``). The final ``batch_size * 100`` rows form the
    'test' split, all preceding rows the 'train' split.

    Args:
        mode: ``'train'`` or ``'test'``.
        file: path of the saved sample tensor.

    Raises:
        ValueError: if ``mode`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, mode, file='data/dsprites/base.pt'):
        # Fail fast: an unknown mode used to leave self.x / self.y undefined
        # and only surfaced later as an AttributeError in __len__.
        if mode not in ('train', 'test'):
            raise ValueError("mode must be 'train' or 'test', got %r" % (mode,))
        samples = torch.load(file)
        img = samples[:, :-1]
        label = samples[:, -1]
        self.mode = mode
        # Size of the held-out tail; depends on the notebook-level batch_size.
        n_test = batch_size * 100
        if mode == 'train':
            self.x = img[:-n_test]
            self.y = label[:-n_test].type(dtype=torch.int64)
        else:
            self.x = img[-n_test:]
            self.y = label[-n_test:].type(dtype=torch.int64)

    def __getitem__(self, idx):
        """Return the ``(input, label)`` pair stored at ``idx``."""
        return self.x[idx], self.y[idx]

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.x)

In [None]:
# data loader for generated samples

def samples_loader():
    """Build the (train, test) DataLoaders over the generated samples.

    Both loaders use the notebook-level ``batch_size`` and ``num_threads``
    and keep the stored sample order (``shuffle=False``).

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    loader_opts = dict(batch_size=batch_size, shuffle=False, num_workers=num_threads)
    train_loader, test_loader = (
        torch.utils.data.DataLoader(SampledData(split), **loader_opts)
        for split in ('train', 'test')
    )
    return train_loader, test_loader

In [None]:
# utility functions, mostly for visualization purposes


def param_fix(layer):
    """Freeze ``layer``: switch off gradient tracking for all its parameters."""
    for weight in layer.parameters():
        weight.requires_grad_(False)

def plot_loss_curve(generated, train_list, test_list):
    """Plot train/test loss per epoch and save it under ``plots/<generated>/``.

    Args:
        generated: sub-directory name below ``plots/`` (callers pass the
            model's ``generated`` attribute).
        train_list: sequence of per-epoch training losses.
        test_list: sequence of per-epoch test losses.

    The figure is written to ``plots/<generated>/loss_curves.png``; the
    target directory is created when it does not exist yet.
    """
    out_dir = 'plots/' + generated
    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.plot(train_list, 'r-', label='train loss')
        ax.plot(test_list, 'r--', label='test loss')
        ax.legend()
        fig.savefig(out_dir + '/loss_curves.png')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    
def plot_accur_curve(accur_list):
    """Plot the accuracy history and save it to ``plots/accur/accuracys.png``.

    Args:
        accur_list: sequence of accuracy values, one per evaluated epoch.

    The target directory is created when it does not exist yet.
    """
    os.makedirs('plots/accur', exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.plot(accur_list, 'b-', label='accuracy')
        ax.legend()
        fig.savefig('plots/accur/accuracys.png')
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
    

In [None]:
# Constructing the BetaVAE model that appeared in 'BetaVAE(shapes).ipynb'

class BetaVAE(nn.Module):
    """Fully connected VAE with an additional linear classification head.

    The encoder maps a flattened 4096-value input to ``2 * lat_dim`` values
    (latent mean and log-variance); the decoder maps a ``lat_dim`` code back
    to 4096 sigmoid-activated values. ``self.layer`` is a linear layer with
    softmax over 5 classes applied to a latent code.

    Args:
        supervised: when true, ``forward`` in 'learn' mode returns the class
            probabilities of ``self.layer`` instead of a reconstruction.
        generated: stored as ``self.generated``; not used inside the class.
        mode: ``'learn'`` (encode first) or ``'generate'`` (decode a given code).
    """

    def __init__(self, supervised, generated='Bernoulli', mode='learn'):
        super().__init__()
        self.supervised = supervised
        self.generated = generated
        self.mode = mode
        self.sigmoid = nn.Sigmoid()
        self.encoder = nn.Sequential(nn.Linear(4096, 1200),
                                     nn.ReLU(),
                                     nn.Linear(1200, 1200),
                                     nn.ReLU(),
                                     nn.Linear(1200, lat_dim*2))

        self.decoder = nn.Sequential(nn.Linear(lat_dim, 1200),
                                     nn.Tanh(),
                                     nn.Linear(1200, 1200),
                                     nn.Tanh(),
                                     nn.Linear(1200, 1200),
                                     nn.Tanh(),
                                     nn.Linear(1200, 4096))

        # dim=1 is the dimension the implicit (deprecated) choice resolves to
        # for the 2-D (batch, lat_dim) codes this layer receives.
        self.layer = nn.Sequential(nn.Linear(lat_dim, 5),
                                   nn.Softmax(dim=1))

    def reparametrize(self, z_mu, z_log_var):
        """Sample ``z = mu + std * eps`` with ``eps ~ N(0, I)``."""
        std = torch.exp(0.5 * z_log_var)
        # randn_like creates the noise on std's own device and dtype, so
        # sampling works wherever the inputs live, independent of any global.
        eps = torch.randn_like(std)
        return z_mu + std * eps

    def encoderNet(self, x):
        """Encode ``x`` and return a sampled latent code.

        Side effect: stores the KL term in ``self.kl`` (averaged over the
        batch dimension, summed over the latent dimensions).
        """
        code = self.encoder(x)
        # First lat_dim columns are the mean, the remaining ones the log-variance.
        z_mu = code[:, :lat_dim]
        z_log_var = code[:, lat_dim:]
        z = self.reparametrize(z_mu, z_log_var)
        self.kl = -0.5 * ((1 + z_log_var) - z_mu * z_mu - torch.exp(z_log_var)).mean(dim=0).sum()
        return z

    def decoderNet(self, z):
        """Decode a latent code into sigmoid-activated outputs in (0, 1)."""
        h = self.decoder(z)
        x_ = self.sigmoid(h)
        return x_

    def forward(self, x):
        """Run the model according to ``self.mode`` and ``self.supervised``.

        Returns:
            * mode 'learn', unsupervised: ``(x_, recon, kl)`` - reconstruction,
              Bernoulli negative log-likelihood (mean over the batch, summed
              over outputs) and the KL term.
            * mode 'learn', supervised: class probabilities from ``self.layer``.
            * mode 'generate': the decoding of ``x``, which is treated as a
              latent code.

        Raises:
            ValueError: if ``self.mode`` is neither 'learn' nor 'generate'.
        """
        if self.mode == 'learn':
            self.z = self.encoderNet(x)
            if not self.supervised:
                self.x_ = self.decoderNet(self.z)
                # 1e-10 keeps log() finite when the sigmoid saturates at 0 or 1.
                self.recon = -(x * torch.log(self.x_ + 1e-10) + (1 - x) * torch.log(1 - self.x_ + 1e-10)).mean(dim=0).sum()
                return self.x_, self.recon, self.kl
            else:
                self.factors = self.layer(self.z)
                return self.factors
        elif self.mode == 'generate':
            self.x_ = self.decoderNet(x)
            return self.x_
        # An unknown mode previously fell through and silently returned None.
        raise ValueError("unknown mode %r; expected 'learn' or 'generate'" % (self.mode,))

In [None]:
# main function for training the linear classifier on the sampled data

def _pair_difference_logits(model, inputs):
    """Return the linear-classifier logits for one batch.

    The first and the last ``batch_size // 2`` samples of ``inputs`` are
    encoded separately; the mean absolute difference of the two sets of
    latent codes (reshaped to a single row) is passed through the linear layer.
    """
    half = batch_size // 2
    z1 = model.encoderNet(inputs[:half])
    z2 = model.encoderNet(inputs[-half:])
    z_mean = torch.abs(z1 - z2).mean(dim=0).view(1, -1)
    # nn.CrossEntropyLoss applies log-softmax itself, so it must receive the
    # raw output of the linear layer (model.layer[0]) rather than the softmax
    # probabilities produced by the full model.layer.
    return model.layer[0](z_mean)


def Shapes_evaluate(dist='Bernoulli'):
    """Train and evaluate the linear classifier on top of the frozen VAE.

    Reads the notebook-level globals ``lr``, ``n_epoch``, ``device`` and
    ``batch_size``. Each batch yields one training example: the mean absolute
    latent difference between its two halves, labelled with the first label
    of the batch.

    Args:
        dist: name of the model sub-directory below ``models/``; also passed
            to the model as ``generated``.

    Side effects:
        * loads ``models/<dist>/Classifiers.pt`` and overwrites it every epoch;
        * loads (if present) and rewrites ``accur/accur_list.txt``, a pickled
          list of per-epoch accuracies that keeps growing across runs;
        * prints per-epoch losses and the best accuracy, and saves the loss
          and accuracy plots.
    """
    train_loader, test_loader = samples_loader()
    model = BetaVAE(supervised=True, generated=dist).to(device)
    # map_location lets a checkpoint that was saved on another device load here.
    model.load_state_dict(torch.load('models/'+dist+'/Classifiers.pt', map_location=device))
    param_fix(model.encoder)
    param_fix(model.decoder)

    model.mode = 'learn'
    # Only the classifier is trainable; keep the frozen encoder/decoder
    # weights out of the optimizer so no state is allocated for them.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adagrad(trainable, lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_list, test_list = [], []

    # Resume the accuracy history of earlier runs; start fresh on a first run.
    # NOTE(security): pickle.load can execute arbitrary code - only use it on
    # files written by this notebook.
    try:
        with open('accur/accur_list.txt', 'rb') as f:
            accur_list = pickle.load(f)
    except FileNotFoundError:
        accur_list = []
    os.makedirs('accur', exist_ok=True)

    for epoch in range(n_epoch):
        train_loss, cnt = 0.0, 0
        for x, y in train_loader:
            cnt += 1
            inputs, labels = x.to(device), y.to(device)
            # The first label stands for the whole batch.
            label = labels[0].view(-1)

            loss = criterion(_pair_difference_logits(model, inputs), label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the value; accumulating the tensor itself would
            # keep every batch's autograd graph alive for the whole epoch.
            train_loss += loss.item()

        train_loss = train_loss / max(cnt, 1)
        train_list.append(train_loss)

        test_loss, cnt, correct = 0.0, 0, 0
        with torch.no_grad():
            for x, y in test_loader:
                cnt += 1
                inputs, labels = x.to(device), y.to(device)
                label = labels[0].view(-1)

                logits = _pair_difference_logits(model, inputs)
                test_loss += criterion(logits, label).item()
                correct += int((torch.argmax(logits, dim=1) == label).sum())

        test_loss = test_loss / max(cnt, 1)
        test_list.append(test_loss)

        # One prediction per batch, so cnt is the number of predictions.
        accuracy = correct / max(cnt, 1) * 100
        accur_list.append(accuracy)

        torch.save(model.state_dict(), 'models/'+dist+'/Classifiers.pt')

        with open('accur/accur_list.txt', 'wb') as f:
            pickle.dump(accur_list, f)

        print('[Epoch %d] train_loss: %.3f, test_loss: %.3f'
              % (epoch+1, train_loss, test_loss))

    print('accuracy is ', max(accur_list),'%')
    plot_loss_curve(model.generated, train_list, test_list)
    plot_accur_curve(accur_list)

In [None]:
# 2nd experiment: run the training of the linear classifier in order to compute the disentanglement metric via its accuracy


if __name__ == '__main__':
    # Shapes_evaluate() reads `lr` and `n_epoch` as module-level globals,
    # so both have to be bound before it is called.
    lr, n_epoch = 1e-2, 100
    print(device)
    Shapes_evaluate()