In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; project_dir (next cell) lives under this mount point.
drive.mount('/content/drive')

In [None]:
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import
import os
import sys
import shutil
import numpy as np
from itertools import cycle,islice
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
from torch.nn.parameter import Parameter
from torch.utils.data import Dataset, IterableDataset, DataLoader
import torch.optim as optim
import torch.nn.functional as F
from torchvision import transforms
from itertools import combinations
from tqdm import tqdm
import json
# from paths import *

# Project root on the mounted Drive. Set the IFT6269_PROJECT_DIR environment
# variable to run the notebook from another location without editing this cell.
project_dir = os.environ.get("IFT6269_PROJECT_DIR", "/content/drive/MyDrive/ift6269/project")
# Work from the project root; presumably this is also what lets the local
# `entity` / `retabulate` imports below resolve (TODO confirm).
os.chdir(project_dir)
data_dir = os.path.join(project_dir, "data")
rec_data_dir = os.path.join(data_dir, "recordings")
mnist_data_dir = os.path.join(data_dir, "mnist")
speech_data_dir = os.path.join(data_dir, "speech")
lookup_embd_dir = os.path.join(data_dir, "lookup_embd")
# Directory the trained checkpoint is read from (see the "Load Model" section).
OUT_DIR = os.path.join(project_dir, "modelnew")

from entity import person
from retabulate import tabulate,init_history,mean_history
import matplotlib.pyplot as plt
import pickle as pkl

In [None]:
import random

N_LATENTS = 512  # latent dimensionality
CUDA = torch.cuda.is_available()

# Training hyperparameters (they appear unused in this analysis notebook; kept for reference).
EPOCHS = 500
lambda_image = 1
lambda_speech = 100
learning_rate = 1e-3
annealing_epochs = 200
log_interval = 100
batch_size = 128
N_TRAIN_SAMPLES = 48000  # presumably the size of the training split
N_mini_batches = N_TRAIN_SAMPLES // batch_size
N_mini_batches_val = 188
N_mini_batches_test = 317

# Seed the Python, NumPy and torch RNGs so any random sampling below is reproducible.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# Dataset object from the local `entity` module. The two language names
# presumably select the digit sets behind image1 / image2 (TODO confirm in entity.py).
LANGUAGES = ("Farsi", "Kannada")
p = person(*LANGUAGES, dtype=np.float32)

In [None]:
class MyIterableDataset(IterableDataset):
    """Stream the samples of one split of a ``person`` data object.

    @param p: object exposing ``_sample(...)`` and the split arrays
              ``mnist_1_X_<s>``, ``mnist_2_X_<s>``, ``speech_X_<s>``,
              ``mnist_1_y_<s>``, ``mnist_2_y_<s>``, ``speech_y_<s>``
              for s in {train, valid, test}
    @param mode: "train", "val", or anything else (selects the test split)
    """
    _FIELDS = ("mnist_1_X", "mnist_2_X", "speech_X",
               "mnist_1_y", "mnist_2_y", "speech_y")

    def __init__(self, p, mode):
        self.data = p
        self.mode = mode

    def parse_data(self):
        """Yield every sample that ``self.data._sample`` produces for the chosen split."""
        if self.mode == "train":
            split = "train"
        elif self.mode == "val":
            split = "valid"
        else:
            split = "test"
        arrays = [getattr(self.data, field + "_" + split) for field in self._FIELDS]
        for item in self.data._sample(*arrays):
            yield item

    def __iter__(self):
        return self.parse_data()

In [None]:
iterable_dataset = MyIterableDataset(p, mode="train")
iterable_dataset_val = MyIterableDataset(p, mode="val")
iterable_dataset_test = MyIterableDataset(p, mode="test")

# Training is batched with `batch_size`; validation and test stream one sample per batch.
train_loader = DataLoader(iterable_dataset, batch_size=batch_size)
val_loader = DataLoader(iterable_dataset_val, batch_size=1)
test_loader = DataLoader(iterable_dataset_test, batch_size=1)

In [None]:

# Pull a single test batch into globals for interactive inspection.
for batch_idx, batch in enumerate(test_loader):
    (mnist1, mnist2, speech), label_y = batch
    break

# Model

In [None]:
class MVAE(nn.Module):
    """Multimodal Variational Autoencoder over two image modalities and speech.

    Each available modality is encoded to a Gaussian; those experts are fused
    with a standard-normal prior expert by a product of experts, and every
    decoder (image1, image2, speech, label) is run on the fused latent.

    @param n_latents: integer
                      number of latent dimensions
    """
    def __init__(self, n_latents):
        super(MVAE, self).__init__()
        # Attribute names are the state_dict keys of saved checkpoints; do not rename.
        self.image1_encoder = ImageEncoder(n_latents)
        self.image1_decoder = ImageDecoder(n_latents)
        self.image2_encoder = ImageEncoder(n_latents)
        self.image2_decoder = ImageDecoder(n_latents)
        self.speech_encoder = SpeechEncoder(n_latents)
        self.speech_decoder = SpeechDecoder(n_latents)
        self.label_decoder = LabelDecoder(n_latents)

        self.experts = ProductOfExperts()
        self.n_latents = n_latents

    def weights_init(self, m):
        """Zero the weight and bias of ``m`` if it is an ``nn.Linear`` (not applied by default)."""
        if isinstance(m, nn.Linear):
            torch.nn.init.zeros_(m.weight)
            torch.nn.init.zeros_(m.bias)

    def reparametrize(self, mu, logvar):
        """Sample z = mu + eps * exp(logvar / 2) in training mode; return ``mu`` in eval mode."""
        if not self.training:
            return mu
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, image1=None, image2=None, speech=None):
        """Encode the given modalities, sample z, and decode every modality.

        @returns: (img1_recon, img2_recon, sp_recon, label_recon, mu, logvar)
        @raises ValueError: if no modality is given
        """
        mu, logvar = self.infer(image1, image2, speech)
        # reparametrization trick to sample
        z = self.reparametrize(mu, logvar)
        img1_recon = self.image1_decoder(z)
        img2_recon = self.image2_decoder(z)
        sp_recon = self.speech_decoder(z)
        label_recon = self.label_decoder(z)
        return img1_recon, img2_recon, sp_recon, label_recon, mu, logvar

    def infer(self, image1=None, image2=None, speech=None):
        """Fuse the posteriors of all given modalities with the N(0, I) prior expert.

        Any subset of modalities may be passed (None = missing).

        @returns: (mu, logvar) of the product-of-experts posterior
        @raises ValueError: if image1, image2 and speech are all None
        """
        candidates = (
            (image1, self.image1_encoder),
            (image2, self.image2_encoder),
            (speech, self.speech_encoder),
        )
        present = [(x, encoder) for x, encoder in candidates if x is not None]
        if not present:
            raise ValueError("MVAE.infer needs at least one of image1, image2 or speech")

        batch_size = present[0][0].size(0)
        use_cuda = next(self.parameters()).is_cuda
        # The prior expert is always included, so a single modality still yields a proper posterior.
        prior_mu, prior_logvar = prior_expert((1, batch_size, self.n_latents),
                                              use_cuda=use_cuda)
        mus, logvars = [prior_mu], [prior_logvar]
        for x, encoder in present:
            expert_mu, expert_logvar = encoder(x)
            mus.append(expert_mu.unsqueeze(0))
            logvars.append(expert_logvar.unsqueeze(0))

        # Experts are stacked along dim 0 and combined there by the product of experts.
        mu = torch.cat(mus, dim=0)
        logvar = torch.cat(logvars, dim=0)
        return self.experts(mu, logvar)

In [None]:
class ImageEncoder(nn.Module):
    """Parametrizes q(z|x) for one image modality.

    Flattens each image to 784 features and runs three Linear+Swish layers
    (784 -> 2048 -> 1024 -> 512). It then returns two linear heads of size
    ``n_latents``, used by MVAE.infer as the mean and log-variance.

    @param n_latents: integer
                      number of latent dimensions
    """
    def __init__(self, n_latents):
        super(ImageEncoder, self).__init__()
        # Attribute names are checkpoint state_dict keys; do not rename.
        self.fc1   = nn.Linear(784, 2048)
        self.fc1_2 = nn.Linear(2048, 1024)
        self.fc2   = nn.Linear(1024, 512)
        self.fc31  = nn.Linear(512, n_latents)
        # The log-variance head starts as an all-zero layer, so it initially outputs 0.
        self.fc32  = nn.Linear(512, n_latents).apply(self.weights_init)
        self.swish = Swish()

    def weights_init(self, m):
        """Zero the weight and bias of module ``m``."""
        torch.nn.init.zeros_(m.weight)
        torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        # reshape (not view) so non-contiguous input batches are accepted as well
        h = x.reshape(-1, 784).float()
        for layer in (self.fc1, self.fc1_2, self.fc2):
            h = self.swish(layer(h))
        return self.fc31(h), self.fc32(h)

In [None]:
class ImageDecoder(nn.Module):
    """Parametrizes p(x|z) for one image modality.

    Four Linear+Swish layers of width 512 followed by a Linear layer to 784
    outputs, which callers reshape to 28x28. No output activation is applied.

    @param n_latents: integer
                      number of latent dimensions
    """
    def __init__(self, n_latents):
        super(ImageDecoder, self).__init__()
        # Attribute names are checkpoint state_dict keys; do not rename.
        self.fc1 = nn.Linear(n_latents, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc3_4 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 784)
        self.swish = Swish()

    def forward(self, z):
        hidden = z
        for layer in (self.fc1, self.fc2, self.fc3, self.fc3_4):
            hidden = self.swish(layer(hidden))
        return self.fc4(hidden)

In [None]:
class SpeechEncoder(nn.Module):
    """Parametrizes q(z|speech) from a 13-dimensional speech feature vector.

    Two Linear+Swish layers (13 -> 512 -> 512) feed two linear heads of size
    ``n_latents``, which MVAE.infer uses as the mean and log-variance.

    @param n_latents: integer
                      number of latent dimensions
    """
    def __init__(self, n_latents):
        super(SpeechEncoder, self).__init__()
        # Attribute names are checkpoint state_dict keys; do not rename.
        self.fc1 = nn.Linear(13, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc31 = nn.Linear(512, n_latents)
        # The log-variance head starts as an all-zero layer.
        self.fc32 = nn.Linear(512, n_latents).apply(self.weights_init)
        self.swish = Swish()

    def weights_init(self, m):
        """Zero the weight and bias of module ``m``."""
        torch.nn.init.zeros_(m.weight)
        torch.nn.init.zeros_(m.bias)

    def forward(self, x):
        hidden = x.float()
        for layer in (self.fc1, self.fc2):
            hidden = self.swish(layer(hidden))
        return self.fc31(hidden), self.fc32(hidden)

In [None]:
#TODO
class SpeechDecoder(nn.Module):
    """Parametrizes p(speech|z): maps a latent code back to 13 speech features.

    Three Linear+Swish layers of width 512 followed by a Linear layer with 13
    outputs (the SpeechEncoder input size, presumably MFCC-like features).
    The output is returned as-is, with no activation.

    @param n_latents: integer
                      number of latent dimensions
    """
    def __init__(self, n_latents):
        super(SpeechDecoder, self).__init__()
        # Attribute names are checkpoint state_dict keys; do not rename.
        self.fc1 = nn.Linear(n_latents, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 13)
        self.swish = Swish()

    def forward(self, z):
        hidden = z
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.swish(layer(hidden))
        return self.fc4(hidden)

In [None]:
class LabelDecoder(nn.Module):
    """Parametrizes p(label|z): maps a latent code to 10 class scores.

    Three Linear+Swish layers of width 512 and a final Linear layer. The
    scores are returned unnormalised (no softmax is applied here).

    @param n_latents: integer
                      number of latent dimensions
    """
    def __init__(self, n_latents):
        super(LabelDecoder, self).__init__()
        # Attribute names are checkpoint state_dict keys; do not rename.
        self.fc1 = nn.Linear(n_latents, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 512)
        self.fc4 = nn.Linear(512, 10)
        self.swish = Swish()

    def forward(self, z):
        hidden = z
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = self.swish(layer(hidden))
        return self.fc4(hidden)

In [None]:
class ProductOfExperts(nn.Module):
    """Combine independent Gaussian experts into one Gaussian (product of experts).

    See https://arxiv.org/pdf/1410.7827.pdf for equations.
    @param mu: stacked expert means, experts along dim 0 (MVAE.infer passes M x B x D)
    @param logvar: stacked expert log-variances, same shape as ``mu``
    @param eps: small constant added for numerical stability
    @returns: (pd_mu, pd_logvar) with dim 0 reduced away
    """
    def forward(self, mu, logvar, eps=1e-8):
        var = torch.exp(logvar) + eps
        # per-expert precision; eps is deliberately added a second time here
        precision = 1. / (var + eps)
        total_precision = torch.sum(precision, dim=0)
        pd_mu = torch.sum(mu * precision, dim=0) / total_precision
        pd_var = 1. / total_precision
        pd_logvar = torch.log(pd_var + eps)
        return pd_mu, pd_logvar

In [None]:
class Swish(nn.Module):
    """Activation used by every encoder/decoder in this notebook.

    Despite the name (and the Swish reference https://arxiv.org/abs/1710.05941),
    this computes Mish, x * tanh(softplus(x)), not Swish's x * sigmoid(x).
    It is left unchanged because the saved checkpoint was presumably trained with it.
    """
    def forward(self, x):
        softplus_x = F.softplus(x)
        return x * torch.tanh(softplus_x)

In [None]:
class Mish(nn.Module):
    """Mish activation, x * tanh(softplus(x)) (https://arxiv.org/abs/1908.08681).

    Computes the same function as ``Swish`` in this notebook.
    """
    def __init__(self):
        super(Mish, self).__init__()

    def forward(self, x):
        # kept as a single expression (no named temporaries)
        return x.mul(torch.tanh(F.softplus(x)))

In [None]:
def prior_expert(size, use_cuda=False):
    """Universal prior expert: a spherical standard Gaussian N(0, I).

    @param size: shape of the returned tensors, e.g. ``(1, batch, n_latents)``
                 (anything ``torch.zeros`` accepts as a size)
    @param use_cuda: boolean [default: False]
                     allocate the tensors on the current CUDA device
    @returns: ``(mu, logvar)``, both all-zero tensors of shape ``size``
              (zero log-variance means unit variance); neither requires grad
    """
    # Allocate directly on the target device instead of creating on CPU and copying.
    device = torch.device('cuda') if use_cuda else torch.device('cpu')
    mu = torch.zeros(size, device=device)
    logvar = torch.zeros(size, device=device)
    return mu, logvar

In [None]:
def save_checkpoint(state, is_best, folder='./', filename='checkpoint.pth.tar'):
    """Save ``state`` with torch.save and optionally keep a copy as the best model.

    @param state: picklable object to save (e.g. a dict holding a state_dict)
    @param is_best: if True, also copy the file to ``folder/model_best.pth.tar``
    @param folder: output directory; created, including missing parents, if needed
    @param filename: checkpoint file name inside ``folder``
    """
    # makedirs handles nested paths and an already-existing folder (os.mkdir did not).
    if folder:
        os.makedirs(folder, exist_ok=True)
    checkpoint_path = os.path.join(folder, filename)
    torch.save(state, checkpoint_path)
    if is_best:
        shutil.copyfile(checkpoint_path, os.path.join(folder, 'model_best.pth.tar'))


def load_checkpoint(file_path, use_cuda=False):
    """Rebuild an MVAE from a checkpoint file (e.g. one written by save_checkpoint).

    The checkpoint must be a dict with ``'n_latents'`` and ``'state_dict'``.
    torch.load unpickles the file, so only load trusted checkpoints.

    @param file_path: path to the checkpoint
    @param use_cuda: if False, every stored tensor is mapped onto the CPU
    @returns: an MVAE holding the saved weights (moving it to a device is up to the caller)
    """
    if use_cuda:
        checkpoint = torch.load(file_path)
    else:
        checkpoint = torch.load(file_path,
                                map_location=lambda storage, location: storage)
    net = MVAE(checkpoint['n_latents'])
    net.load_state_dict(checkpoint['state_dict'])
    return net

# Load Model

In [None]:
# Restore the trained MVAE on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if CUDA else "cpu")
model = load_checkpoint(os.path.join(OUT_DIR, 'checkpoint.pth.tar'), use_cuda=CUDA)
model.to(device)
# Analysis only: eval mode makes reparametrize return the posterior mean.
model.eval()

In [None]:
def get_mu_logvar(digit=7, net=None, loader=None):
  """Encode the first sample of class ``digit`` and return its fused posterior.

  Scans ``loader`` in order and stops at the first sample whose label argmax
  equals ``digit``. That sample's two images and speech features are then
  encoded jointly by ``net`` in eval mode, without gradient tracking.

  @param digit: class index to look for
  @param net: MVAE used for encoding; defaults to the notebook-level ``model``
  @param loader: iterable of ``((image1, image2, speech), y)`` batches; defaults
                 to ``val_loader``. The argmax is taken over the flattened
                 label tensor, so this assumes batch size 1.
  @returns: ``(mu, logvar)`` of the product-of-experts posterior
  @raises ValueError: if no sample of class ``digit`` is found
  """
  net = model if net is None else net
  loader = val_loader if loader is None else loader
  net.eval()
  for (image1, image2, speech), y in loader:
    if int(torch.argmax(y)) == digit:
      break
  else:
    # Previously the last sample seen was silently used instead.
    raise ValueError("no sample with label {} found in loader".format(digit))

  device = next(net.parameters()).device
  with torch.no_grad():
    *_, mu, logvar = net(image1=image1.to(device).float(),
                         image2=image2.to(device).float(),
                         speech=speech.to(device).float())
  return mu, logvar

In [None]:
def show_image1(model, Tens, columns=10, max_images=100, save_path=None):
  """Decode latent codes with ``model.image1_decoder`` and tile the images in a grid.

  @param model: MVAE whose first image decoder is used
  @param Tens: tensor of latent codes, one row per image
  @param columns: number of images per grid row
  @param max_images: show at most this many (the first) codes. The grid has
                     ceil(n / columns) rows, so 100 codes give a 10 x 10 layout.
  @param save_path: if given, also save the figure to this path
  """
  # Plotting only: no autograd graph is needed.
  with torch.no_grad():
    recons = model.image1_decoder(Tens).view(-1, 28, 28).cpu().numpy()
  n_images = min(len(recons), max_images)
  rows = max(1, -(-n_images // columns))  # ceiling division
  fig = plt.figure(figsize=(columns, rows))
  for i in range(n_images):
    ax = fig.add_subplot(rows, columns, i + 1)
    ax.axis('off')
    ax.imshow(recons[i], interpolation='nearest')
  fig.subplots_adjust(wspace=0, hspace=0)
  if save_path is not None:
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)

In [None]:
def show_image2(model, Tens, columns=10, max_images=100, save_path=None):
  """Decode latent codes with ``model.image2_decoder`` and tile the images in a grid.

  @param model: MVAE whose second image decoder is used
  @param Tens: tensor of latent codes, one row per image
  @param columns: number of images per grid row
  @param max_images: show at most this many (the first) codes. The grid has
                     ceil(n / columns) rows, so 100 codes give a 10 x 10 layout.
  @param save_path: if given, also save the figure to this path
  """
  # Plotting only: no autograd graph is needed.
  with torch.no_grad():
    recons = model.image2_decoder(Tens).view(-1, 28, 28).cpu().numpy()
  n_images = min(len(recons), max_images)
  rows = max(1, -(-n_images // columns))  # ceiling division
  fig = plt.figure(figsize=(columns, rows))
  for i in range(n_images):
    ax = fig.add_subplot(rows, columns, i + 1)
    ax.axis('off')
    ax.imshow(recons[i], interpolation='nearest')
  fig.subplots_adjust(wspace=0, hspace=0)
  if save_path is not None:
    fig.savefig(save_path, bbox_inches='tight', pad_inches=0)

In [None]:
from copy import deepcopy
def discover_disentanglement(z, idx, change, num=100):
  """Build a batch of latent codes that sweeps one coordinate of ``z``.

  Each row is a copy of ``z`` (flattened) with ``change * k`` added at
  position ``idx``, for k = -(num // 2), ..., -1, 1, ..., num // 2 in that
  order. When ``num`` is odd, an unmodified copy (k = 0) sits in the middle.

  @param z: a single latent code, shape (D,) or (1, D)
  @param idx: index of the latent coordinate to traverse
  @param change: offset between consecutive steps
  @param num: number of rows to generate
  @returns: detached tensor of shape (num, D) with z's dtype and device
  """
  base = z.detach().reshape(-1)
  half = num // 2
  # Odd num previously left the last row as random noise; it is now the unshifted code.
  steps = list(range(-half, 0)) + ([0] if num % 2 else []) + list(range(1, half + 1))
  offsets = torch.tensor([change * k for k in steps], dtype=base.dtype, device=base.device)
  Z = base.repeat(num, 1)
  Z[:, idx] += offsets
  return Z


In [None]:
def denorm(x):
    """Map values from [-1, 1] to [0, 1] via x * 0.5 + 0.5."""
    return x * 0.5 + 0.5


def plot_line(img_tensor, save=False, name='default', n_panels=11):
    """Show images side by side in a single row.

    @param img_tensor: tensor of images, shape (N, H, W)
    @param save: if True, save the figure to ``name`` before showing it
    @param name: output path used when ``save`` is True
    @param n_panels: show at most this many (the first) images
    """
    images = denorm(img_tensor.detach().cpu()).numpy()
    # At least one panel so plt.subplots is valid; fewer than n_panels images no longer crash.
    n = max(1, min(n_panels, len(images)))
    _, axs = plt.subplots(1, n, figsize=(10, 5), squeeze=False)
    for img, ax in zip(images, axs.flatten()):
        ax.imshow(img)
        ax.axis('off')
    if save:
        plt.savefig(name)
    plt.show()

In [None]:
def get_z(digit, model):
  """Return a latent code for the first validation sample of class ``digit``.

  The encoding itself is done by get_mu_logvar, which uses the notebook-level
  ``model``. The ``model`` argument here is only used for ``reparametrize``,
  which returns the posterior mean when the model is in eval mode.
  """
  posterior_mean, posterior_logvar = get_mu_logvar(digit=digit)
  return model.reparametrize(posterior_mean, posterior_logvar)

# Disentanglement

In [None]:
# Traverse one latent coordinate around the code of a "7" and decode the sweep
# with both image decoders.
SWEEP_DIGIT = 7
SWEEP_DIM = 201   # latent coordinate being traversed
SWEEP_STEP = 30   # offset added per step along that coordinate
z = get_z(digit=SWEEP_DIGIT, model=model)
tens = discover_disentanglement(z.clone().detach(), SWEEP_DIM, SWEEP_STEP)
show_image1(model, tens)
show_image2(model, tens)

# Transition / arithmetic in latent space

In [None]:
# Linear interpolation between the latent codes of a "5" and a "0":
# z(alpha) = alpha * z1 + (1 - alpha) * z0 for N_INTERP evenly spaced alpha in [0, 1],
# decoded with both image decoders.
N_INTERP = 11  # matches the 11 panels plot_line draws by default
with torch.no_grad():
    z0 = get_z(digit=5, model=model)
    z1 = get_z(digit=0, model=model)
    # linspace instead of a float-step arange: the count and the endpoint 1.0 are exact
    alpha = torch.linspace(0, 1, N_INTERP, device=z0.device).view(-1, 1)
    z_dash_a = alpha * z1 + (1 - alpha) * z0
    image_tensor_1 = model.image1_decoder(z_dash_a).view(-1, 28, 28)
    image_tensor_2 = model.image2_decoder(z_dash_a).view(-1, 28, 28)
plot_line(image_tensor_1, False, 'line1')
plot_line(image_tensor_2, False, 'line2')