In [31]:
import torch
import os
import numpy as np
from torch.utils.data import DataLoader
import matplotlib
import matplotlib.pyplot as plt
import import_ipynb
import gibbs_sampler_poise
import kl_divergence_calculator
import data_preprocessing
from torchvision.utils import save_image
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets
import torchvision.transforms as transforms
from torch.nn import functional as F  #for the activation function
from torchviz import make_dot
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms
import torchvision
import umap
import random
import shutil
from numpy import prod

In [2]:
_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# learning parameters
latent_dim1 = 32
latent_dim2 = 16
batch_size = 10
dim_MNIST   = 784
lr = 1e-4
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
tx = transforms.ToTensor()
MNIST_TRAINING_PATH = "/home/achint/Practice_code/VAE/MNIST/MNIST/processed/training.pt"
SVHN_TRAINING_PATH  = "/home/achint/Practice_code/VAE/SVHN/train_32x32.mat"
MNIST_TEST_PATH     = "/home/achint/Practice_code/VAE/MNIST/MNIST/processed/test.pt"
SVHN_TEST_PATH  = "/home/achint/Practice_code/VAE/SVHN/test_32x32.mat"
SUMMARY_WRITER_PATH = "/home/achint/Practice_code/logs"
RECONSTRUCTION_PATH = "/home/achint/Practice_code/Updated_POISE_VAE/MNIST_SVHN/reconstructions/"
PATH = "/home/achint/Practice_code/Updated_POISE_VAE/MNIST_SVHN/mnist_svhn_parameters.txt"

In [3]:
# Remove the logs directory and the reconstruction directory 
if os.path.exists(RECONSTRUCTION_PATH):
    shutil.rmtree(RECONSTRUCTION_PATH)
    os.makedirs(RECONSTRUCTION_PATH)

if os.path.exists(SUMMARY_WRITER_PATH):
    shutil.rmtree(SUMMARY_WRITER_PATH)
    os.makedirs(SUMMARY_WRITER_PATH)

In [75]:
## Importing MNIST and SVHN datasets
joint_dataset_train=data_preprocessing.JointDataset(mnist_pt_path=MNIST_TRAINING_PATH,
                             svhn_mat_path=SVHN_TRAINING_PATH)
joint_dataset_test = data_preprocessing.JointDataset(mnist_pt_path=MNIST_TEST_PATH,
                             svhn_mat_path=SVHN_TEST_PATH)

joint_dataset_train_loader = DataLoader(
    joint_dataset_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True
)
joint_dataset_test_loader = DataLoader(
    joint_dataset_test,
    batch_size=batch_size,
    shuffle=False,
    drop_last=True
)

In [76]:
def _latent_dims_type_setter(lds):
    ret, ret_flatten = [], []
    for ld in lds:
        if hasattr(ld, '__iter__'): # Iterable
            ld_tuple = tuple([i for i in ld])
            if not all(map(lambda i: isinstance(i, int), ld_tuple)):
                raise ValueError('`latent_dim` must be either iterable of ints or int.')
            ret.append(ld_tuple)
            ret_flatten.append(int(prod(ld_tuple)))
        elif isinstance(ld, int):
            ret.append((ld, ))
            ret_flatten.append(ld)
        else:
            raise ValueError('`latent_dim` must be either iterable of ints or int.')
    return ret, ret_flatten

In [77]:
class POISEVAE(nn.Module):
    __version__ = 2.0
    
    def __init__(self, encoders, decoders, batch_size, loss, latent_dims=None,
                 device=_device):
        """
        encoders: list of nn.Module
            Each encoder must have an attribute `latent_dim` specifying the dimension of the
            latent space to which it encodes. An alternative way to avoid adding this attribute
            is to specify the `latent_dims` parameter (see below). 
            Note that each `latent_dim` must be unsqueezed, e.g. (10, ) is not the same as (10, 1).
            
        decoders: list of nn.Module
            The number and indices of decoders must match those of encoders.
            
        batch_size: int
        
        loss: str
            Can either be 'MSE' for MSE loss or 'BCE' for BCE loss. The users should properly 
            restrict the range of the output of their decoders for the loss chosen.
        
        latent_dims: iterable, optional; default None
            The dimensions of the latent spaces to which the encoders encode. The indices of the 
            entries must match those of encoders. An alternative way to specify the dimensions is
            to add the attribute `latent_dim` to each encoder (see above).
            Note that each entry must be unsqueezed, e.g. (10, ) is not the same as (10, 1).
        
        device: torch.device, optional
        """
        super(POISEVAE,self).__init__()

        if len(encoders) != len(decoders):
            raise ValueError('The number of encoders must match that of decoders.')
        
        if len(encoders) > 2:
            raise NotImplementedError('> 3 latent spaces not yet supported.')
        
        # Type check
        if not all(map(lambda x: isinstance(x, nn.Module), (*encoders, *decoders))):
            raise TypeError('`encoders` and `decoders` must be lists of `nn.Module` class.')

        # Get the latent dimensions
        if latent_dims is not None:
            if not hasattr(latent_dims, '__iter__'): # Iterable
                raise TypeError('`latent_dims` must be iterable.')
            self.latent_dims = latent_dims
        else:
            self.latent_dims = tuple(map(lambda l: l.latent_dim, encoders))
        self.latent_dims, self.latent_dims_flatten = _latent_dims_type_setter(self.latent_dims)

        if batch_size <= 0:
            raise ValueError('Invalid batch size')
        self.batch_size = batch_size
        
        if loss not in ['MSE', 'BCE']: 
            raise NotImplementedError('Not yet supported for other loss functions')
        self.loss = loss
        
        self.encoders = nn.ModuleList(encoders)
        self.decoders = nn.ModuleList(decoders)
        
        self.device = device

        self.gibbs = gibbs_sampler(self.latent_dims_flatten, batch_size)
        self.kl_div = kl_divergence(self.latent_dims_flatten, batch_size)

        self.register_parameter(name='g11', 
                                param=nn.Parameter(torch.randn(*self.latent_dims_flatten, 
                                                               device=self.device)))
        self.register_parameter(name='g22', 
                                param=nn.Parameter(torch.randn(*self.latent_dims_flatten, 
                                                               device=self.device)))
        self.flag_initialize = 1

    def _decoder_helper(self):
        """
        Reshape samples drawn from each latent space, and decode with considering the loss function
        """
        ret = []
        for decoder, z, ld in zip(self.decoders, self.z_gibbs_posteriors, self.latent_dims):
            z = z.view(self.batch_size, *ld) # Match the shape to the output
            x_ = decoder(z)
            ret.append(x_)
        return ret

    def forward(self, x):
        """
        Return
        ------
        results: dict
            z: list of torch.Tensor
                Samples from the posterior distributions in the corresponding latent spaces
            x_rec: list of torch.Tensor
                Reconstructed samples
            mu: list of torch.Tensor
                Posterior distribution means
            var: list of torch.Tensor
                Posterior distribution variances
            total_loss: torch.Tensor
            rec_losses: list of torch.tensor
                Reconstruction loss for each dataset
            KL_loss: torch.Tensor
        """
        mu, var = [], []
        for i, xi in enumerate(x):
            _mu, _log_var = self.encoders[i].forward(xi)
            mu.append(_mu.view(self.batch_size, -1))
            var.append(-torch.exp(_log_var.view(self.batch_size, -1)))

        g22 = -torch.exp(self.g22)

        # Initializing gibbs sample
        if self.flag_initialize == 1:
            z_priors = self.gibbs.sample(self.g11, g22, n_iterations=5000)
            z_posteriors = self.gibbs.sample(self.g11, g22, lambda1s=mu, lambda2s=var,
                                             n_iterations=5000)

            self.z_priors = z_priors
            self.z_posteriors = z_posteriors
            self.flag_initialize = 0

        z_priors = list(map(lambda z: z.detach(), self.z_priors))
        z_posteriors = list(map(lambda z: z.detach(), self.z_posteriors))

        # If lambda not provided, treat as zeros to save memory and computation
        self.z_gibbs_priors = self.gibbs.sample(self.g11, g22, z=z_priors, n_iterations=5)
        self.z_gibbs_posteriors = self.gibbs.sample(self.g11, g22, lambda1s=mu, lambda2s=var,
                                                    z=z_posteriors, n_iterations=5)

        self.z_priors = list(map(lambda z: z.detach(), self.z_gibbs_priors))
        self.z_posteriors = list(map(lambda z: z.detach(), self.z_gibbs_posteriors))

        G = torch.block_diag(self.g11, self.g22)

        x_ = self._decoder_helper() # Decoding

        # self.z2_gibbs_posterior = self.z2_gibbs_posterior.squeeze()
        for i in range(len(self.z_gibbs_posteriors)):
            self.z_gibbs_posteriors[i] = self.z_gibbs_posteriors[i].squeeze()

        # KL loss
        kls = self.kl_div.calc(G, self.z_gibbs_posteriors, self.z_gibbs_priors, mu,var)
        KL_loss  = sum(kls)

        # Reconstruction loss
        rec_loss_func = nn.MSELoss(reduction='sum') if self.loss == 'MSE' else \
                        nn.BCELoss(reduction='sum')
        recs = list(map(lambda x: rec_loss_func(x[0], x[1]), zip(x_, x)))
        rec_loss = sum(recs)
        
        # Total loss
        total_loss = KL_loss + rec_loss

        results = {
            'z': self.z_posteriors, 'x_rec': x_, 'mu': mu, 'var': var, 
            'total_loss': total_loss, 'rec_losses': recs, 'KL_loss': KL_loss
        }

        return results

In [78]:
class Encoder1(nn.Module):
    def __init__(self):
        super(Encoder1, self).__init__()
        self.latent_dim = 32
        self.dim_MNIST   = 784

        ## Encoder set1(MNIST)
        self.set1_enc1 = nn.Linear(in_features = self.dim_MNIST,out_features = 512)
        self.set1_enc2 = nn.Linear(in_features = 512,out_features = 128)
        self.set1_enc3 = nn.Linear(in_features = 128,out_features = 2*self.latent_dim)

    def forward(self, x):
        # Modality 1 (MNIST)
        x       = F.relu(self.set1_enc1(x))
        x       = F.relu(self.set1_enc2(x))  
        x       = self.set1_enc3(x).view(-1,2,self.latent_dim)  # ->[128,2,32]
        mu      = x[:,0,:] # ->[128,32]
        log_var = x[:,1,:] # ->[128,32]
        var     = -torch.exp(log_var)           #lambdap_2<0
        return mu, log_var
    
class Encoder2(nn.Module):
    # 64*64 -> 40*40 -> 16*16 -> 4*4
    def __init__(self):
        super(Encoder2, self).__init__()
        self.latent_dim = 16     
        # input size: 3 x 32 x 32
        self.set2_enc1 = nn.Conv2d(in_channels=3, out_channels=2*self.latent_dim, kernel_size=4, stride=2, padding=1)
        # size: 32 x 16 x 16
        self.set2_enc2 = nn.Conv2d(in_channels=2*self.latent_dim, out_channels=2*self.latent_dim, kernel_size=4, stride=2, padding=1)
        # size: 32 x 8 x 8
        self.set2_enc3 = nn.Conv2d(in_channels=2*self.latent_dim, out_channels=self.latent_dim, kernel_size=4, stride=2, padding=1)
        # size: 16 x 4 x 4   
        self.SVHNc1 = nn.Conv2d(self.latent_dim, self.latent_dim, 4, 1, 0)
        # size: 16 x 1 x 1
        self.SVHNc2 = nn.Conv2d(self.latent_dim, self.latent_dim, 4, 1, 0)
        # size: 16 x 1 x 1
    def forward(self, x):
        # Modality 2 (SVHN)
        x = x.view(-1,3, 32,32) 
        x = F.relu(self.set2_enc1(x))
        x = F.relu(self.set2_enc2(x))
        x = F.relu(self.set2_enc3(x))
        # get 'mu' and 'log_var' for SVHN
        mu = (self.SVHNc1(x).squeeze(3)).squeeze(2)
        log_var = (self.SVHNc2(x).squeeze(3)).squeeze(2)
        return mu, log_var
    
class Decoder1(nn.Module):
    def __init__(self):
        super(Decoder1, self).__init__()  
        self.latent_dim = 32
        self.dim_MNIST   = 784
        ## Decoder set1(MNIST)
        self.set1_dec1 = nn.Linear(in_features = self.latent_dim,out_features = 128)
        self.set1_dec2 = nn.Linear(in_features = 128,out_features = 512)
        self.set1_dec3 = nn.Linear(in_features = 512,out_features = self.dim_MNIST)
    def forward(self,x):
        x = F.relu(self.set1_dec1(x))
        x = self.set1_dec2(x) 
        x = self.set1_dec3(x)
        return x
        
        
class Decoder2(nn.Module):
    def __init__(self):
        super(Decoder2, self).__init__()  
        self.latent_dim = 16    
        ## Decoder set2(SVHN)
        # input size: 16x1x1
        self.set2_dec0 = nn.ConvTranspose2d(in_channels=self.latent_dim,out_channels=self.latent_dim, kernel_size=4, stride=1, padding=0)
        # input size: 16x4x4
        self.set2_dec1 = nn.ConvTranspose2d(in_channels=self.latent_dim,out_channels=2*self.latent_dim, kernel_size=3, stride=1, padding=1)
        # size: 32 x 4 x 4
        self.set2_dec2 = nn.ConvTranspose2d(in_channels=2*self.latent_dim,out_channels=2*self.latent_dim, kernel_size=5, stride=1, padding=0)
        # size: 32 x 8 x 8
        self.set2_dec3 = nn.ConvTranspose2d(in_channels=2*self.latent_dim,out_channels=2*self.latent_dim, kernel_size=4, stride=2, padding=1)
        # size: 32 x 16 x 16
        self.set2_dec4 = nn.ConvTranspose2d(in_channels=2*self.latent_dim,out_channels=3, kernel_size=4, stride=2, padding=1)
        # size: 3 x 32 x 32
    def forward(self,x):
        x = F.relu(self.set2_dec0(x))
        x = F.relu(self.set2_dec1(x))
        x = F.relu(self.set2_dec2(x))
        x = F.relu(self.set2_dec3(x))
        x = self.set2_dec4(x).view(-1,3072)
        return x
    
enc1 = Encoder1()
enc2 = Encoder2().to(_device)
dec1 = Decoder1().to(_device)
dec2 = Decoder2().to(_device)


In [86]:
state = torch.load(PATH)
model = POISEVAE([enc1, enc2], [dec1, dec2], batch_size, loss='MSE', latent_dims=[32, (16, 1, 1)]).to(device)
optimizer = optim.Adam(model.parameters(),lr=lr)
model.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
for name, para in model.named_parameters():
    print(name)

g11
g22
encoders.0.set1_enc1.weight
encoders.0.set1_enc1.bias
encoders.0.set1_enc2.weight
encoders.0.set1_enc2.bias
encoders.0.set1_enc3.weight
encoders.0.set1_enc3.bias
encoders.1.set2_enc1.weight
encoders.1.set2_enc1.bias
encoders.1.set2_enc2.weight
encoders.1.set2_enc2.bias
encoders.1.set2_enc3.weight
encoders.1.set2_enc3.bias
encoders.1.SVHNc1.weight
encoders.1.SVHNc1.bias
encoders.1.SVHNc2.weight
encoders.1.SVHNc2.bias
decoders.0.set1_dec1.weight
decoders.0.set1_dec1.bias
decoders.0.set1_dec2.weight
decoders.0.set1_dec2.bias
decoders.0.set1_dec3.weight
decoders.0.set1_dec3.bias
decoders.1.set2_dec0.weight
decoders.1.set2_dec0.bias
decoders.1.set2_dec1.weight
decoders.1.set2_dec1.bias
decoders.1.set2_dec2.weight
decoders.1.set2_dec2.bias
decoders.1.set2_dec3.weight
decoders.1.set2_dec3.bias
decoders.1.set2_dec4.weight
decoders.1.set2_dec4.bias


In [87]:
def train(model,joint_dataloader,epoch):
    model.train()
    running_loss = 0.0
    running_mse1 = 0.0
    running_mse2 = 0.0
    running_kld  = 0.0
    running_loss = 0.0
    for i,joint_data in enumerate(joint_dataloader):
        data1    = joint_data[0]
        data1    = data1.float()
        data2   = joint_data[1]
        data2   = data2.float()
        data1    = data1.to(device)
        data2   = data2.to(device)
        data1    = data1.view(data1.size(0), -1)
        data2   = data2.view(data2.size(0), -1)
        optimizer.zero_grad()
        
        results = model([data1,data2])
        z_posterior, mu, var = results['z'], results['mu'], results['var']
        reconstruction = results['x_rec']
        total_loss, rec_loss, KLD = results['total_loss'], results['rec_losses'], results['KL_loss']

        running_mse1 += rec_loss[0].item()
        running_mse2 += rec_loss[1].item()
        running_kld  += KLD.item()
        running_loss += total_loss.item()          #.item converts tensor with one element to number
        total_loss.backward()                      #.backward
        optimizer.step()                     #.step one learning step
    train_loss = running_loss/(len(joint_dataloader.dataset))
    mse1_loss = running_mse1 / (len(joint_dataloader.dataset))
    mse2_loss = running_mse2 / (len(joint_dataloader.dataset))
    kld_loss = running_kld / (len(joint_dataloader.dataset))
#     for name, param in model.named_parameters():
#         writer.add_histogram(name, param.clone().cpu().data.numpy(), epoch)
#     writer.add_scalar("training/loss", train_loss, epoch)
#     writer.add_scalar("training/MSE1", mse1_loss, epoch)
#     writer.add_scalar("training/MSE2", mse2_loss, epoch)
#     writer.add_scalar("training/KLD", kld_loss, epoch)    
    return train_loss
    
def test(model,joint_dataloader,epoch):
    latent_repMNIST= []
    latent_repSVHN= []
    label_mnist= []
    label_svhn= []
    model.eval()
    running_loss = 0.0
    running_mse1 = 0.0
    running_mse2 = 0.0
    running_kld  = 0.0
    running_loss = 0.0
    with torch.no_grad():
        for i,joint_data in enumerate(joint_dataloader):
            data1   = joint_data[0]
            data1   = data1.float()

            data2  =joint_data[1]
            data2 = data2.float()

            label1  =joint_data[2]
            label2  =joint_data[3]
            
            data1 = data1.to(device)
            data2 = data2.to(device)
            data1 = data1.view(data1.size(0), -1)
            data2 = data2.view(data2.size(0), -1)            
            z_posterior,reconstruction,mu,var,total_loss, rec_loss,  KLD       = model([data1,data2])

            running_loss += total_loss.item()
            running_mse1 += rec_loss[0].item()
            running_mse2 += rec_loss[0].item()
            running_kld  += KLD.item()    
            
            latent_repMNIST.append(z_posterior[0])
            latent_repSVHN.append(z_posterior[1])
            label_mnist.append(label1)
            label_svhn.append(label2)

            #save the last batch input and output of every epoch
            if i == int(len(joint_dataloader.dataset)/joint_dataloader.batch_size) - 1:
                num_rows = 8
                both = torch.cat((data1.view(batch_size, 1, 28, 28)[:8], 
                                  reconstruction[0].view(batch_size, 1, 28, 28)[:8]))
                bothp = torch.cat((data2.view(batch_size, 3, 32, 32)[:8], 
                                  reconstruction[1].view(batch_size, 3, 32, 32)[:8]))
                save_image(both.cpu(), os.path.join(RECONSTRUCTION_PATH, f"1_outputMNIST_{epoch}.png"), nrow=num_rows)
                save_image(bothp.cpu(), os.path.join(RECONSTRUCTION_PATH, f"1_outputSVHN_{epoch}.png"), nrow=num_rows)
    test_loss = running_loss/(len(joint_dataloader.dataset))
    mse1_loss = running_mse1 / (len(joint_dataloader.dataset))
    mse2_loss = running_mse2 / (len(joint_dataloader.dataset))
    kld_loss = running_kld / (len(joint_dataloader.dataset))
#     writer.add_scalar("validation/loss", test_loss, epoch)
#     writer.add_scalar("validation/MSE1", mse1_loss, epoch)
#     writer.add_scalar("validation/MSE2", mse2_loss, epoch)
#     writer.add_scalar("validation/KLD", kld_loss, epoch)
    latent_repMNIST = torch.vstack(latent_repMNIST).cpu().numpy()
    latent_repSVHN  = torch.vstack(latent_repSVHN).cpu().numpy()
    label_mnist     = torch.hstack(label_mnist).cpu().numpy()
    label_svhn      = torch.hstack(label_svhn).cpu().numpy()
    return test_loss,latent_repMNIST,latent_repSVHN,label_mnist,label_svhn

In [None]:
train_loss = []
test_loss = []
epochs = 5
writer=SummaryWriter(SUMMARY_WRITER_PATH)
for epoch in range(epochs):
    print(f"Epoch {epoch+1} of {epochs}")
    train_epoch_loss = train(model,joint_dataset_train_loader,epoch)
    test_epoch_loss,latent_repMNIST,latent_repSVHN,label_mnist,label_svhn = test(model,joint_dataset_test_loader,epoch)
    train_loss.append(train_epoch_loss)
    test_loss.append(test_epoch_loss)     
    print(f"Train Loss: {train_epoch_loss:.4f}")
    print(f"Test Loss: {test_epoch_loss:.4f}")

Epoch 1 of 5
Train Loss: 63.0024
Test Loss: 62.5679
Epoch 2 of 5
Train Loss: 60.6707
Test Loss: 60.3820
Epoch 3 of 5
Train Loss: 58.1589
Test Loss: 58.7737
Epoch 4 of 5
Train Loss: 56.7420
Test Loss: 57.0718
Epoch 5 of 5


In [85]:
state = {
    'epoch': epoch,
    'state_dict': model.state_dict(),
    'optimizer': optimizer.state_dict()}
torch.save(state, PATH)