## Importing libraries

In [1]:
import sys, os
from pathlib import Path

import numpy as np

import torch
from torch.utils.data import DataLoader, Subset
import torchvision.transforms as transforms
from torchvision.datasets import MNIST

from sklearn.model_selection import train_test_split

from auxiliary_functions import remove_mean_transform, scale_0_1_transform

In [2]:
# Project root = parent directory of the notebook's current working directory.
# os.path is used instead of splitting on a literal '/' so the result is also
# correct with Windows-style path separators.
ROOT_PATH = os.path.dirname(os.getcwd())

# Both directories keep a trailing separator: later cells build file paths by
# plain string concatenation (e.g. data_dir + 'train_idxs.npy').
data_dir = os.path.join(ROOT_PATH, 'Datasets', 'Mnist', '')
methods_dir = os.path.join(ROOT_PATH, 'Methods', 'MNIST_domains', '')

In [None]:
# Make the project's `methods` module (located in methods_dir) importable.
# The membership check keeps sys.path free of duplicate entries when this
# cell is re-run in the same kernel.
if methods_dir not in sys.path:
    sys.path.append(methods_dir)

from methods import (
    VariationalAutoencoder,
    z_VariationalAutoencoder,
    set_seed,
    train_and_val,
    z_train_and_val,
    replace_point_by_underscore,
)

## Create Dataset and DataLoaders

In [4]:
# Optional per-image preprocessing switches. At most one extra step is added:
# if both flags are True, remove_mean takes precedence and scale_0_1 is ignored.
remove_mean = False
scale_0_1 = False

# Steps shared by every configuration: resize to 32x32, then convert to a tensor.
transform_steps = [
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
]

# The helper functions are passed to Lambda directly; wrapping them in an
# extra `lambda x: f(x)` adds nothing.
if remove_mean:
    transform_steps.append(transforms.Lambda(remove_mean_transform))
elif scale_0_1:
    transform_steps.append(transforms.Lambda(scale_0_1_transform))

img_transform = transforms.Compose(transform_steps)

# MNIST training split (train=True), stored under data_dir + 'MNIST';
# download=True is passed so the data can be fetched when missing.
dataset = MNIST(root=data_dir+'MNIST', download=True, train=True, transform=img_transform)
#test_dataset = MNIST(root=data_dir+'MNIST', download=True, train=False, transform=img_transform)

In [5]:
batch_size = 128

n_data = len(dataset)
perc_val = 0.2  # fraction of the data held out for validation when a new split is drawn

# set_seed comes from the project's `methods` module; it is called before any
# splitting/shuffling (presumably it seeds numpy and torch -- confirm in methods).
set_seed(seed=2, seed_torch=True)

# True  -> reuse the train/validation indices stored in data_dir.
# False -> draw a new stratified split and store it in data_dir.
import_idxs = True

if import_idxs:
    train_idxs = np.load(data_dir+'train_idxs.npy')
    val_idxs = np.load(data_dir+'val_idxs.npy')
else:
    # Stratify on the labels so both subsets keep the class proportions.
    train_idxs, val_idxs, _, _ = train_test_split(
        range(n_data),
        dataset.targets,
        stratify=dataset.targets,
        test_size=perc_val
    )
    np.save(data_dir+'train_idxs.npy', train_idxs)
    np.save(data_dir+'val_idxs.npy', val_idxs)

# Sanity check: a sample must never be in both subsets (would leak validation data).
assert np.intersect1d(train_idxs, val_idxs).size == 0, \
    'train and validation indices overlap'

train_dataset = Subset(dataset, train_idxs)
val_dataset = Subset(dataset, val_idxs)

n_train = len(train_dataset)
n_val = len(val_dataset)
#n_test = len(test_dataset)

# Only the training loader shuffles; validation keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
#test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# The message names exactly the two sizes that are reported (no test set is loaded here).
print(f'length train and validation sets: {n_train}, {n_val}')

Random seed 2 has been set.
length train, validation and test sets: 48000, 12000


## Set your models' hyperparameters (the values shown here are those used in the manuscript)

In [6]:
# --- Data ---
# Pixels per image; images are resized to 32x32 by the transform pipeline.
input_dim = 32 ** 2

# --- Latent sizes (one per model) ---
latent_dim_vae = 5
latent_dim_eavae = 4

# --- Loss weights handed to the training routines ---
# (presumably the beta weights of the y / z variational terms -- confirm in methods)
variational_beta_y_vae = 4
variational_beta_y_eavae = 4
variational_beta_z_eavae = 1

# --- Architecture / optimisation settings shared by both models ---
capacity = 64
learning_rate = 0.001
nepochs = 500

## Initialize VAE and EA-VAE models

In [7]:
# Use the first CUDA device when requested and available, otherwise the CPU.
use_gpu = True
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")

# Both models share input size, capacity, learning rate and the 'BCE'
# reconstruction loss; they differ in class and latent dimension.
uncomp_model_eavae = z_VariationalAutoencoder(n_train,n_val,input_dim,latent_dim_eavae,capacity,learning_rate,rec_loss_method='BCE', device=device)
uncomp_model_vae = VariationalAutoencoder(n_train,n_val,input_dim,latent_dim_vae,capacity,learning_rate,rec_loss_method='BCE', device=device)

# torch.compile is not applied here: the uncompiled models are moved to the
# target device and used directly for training.
model_vae = uncomp_model_vae.to(device)
model_eavae = uncomp_model_eavae.to(device)

## Directory where models will be saved

In [None]:
# Checkpoint directories, one per model. The folder names encode the
# hyperparameters; replace_point_by_underscore is applied to the stringified
# values before they are embedded in the path.
lr_tag = replace_point_by_underscore(str(learning_rate))
beta_y_vae_tag = replace_point_by_underscore(str(variational_beta_y_vae))
beta_y_eavae_tag = replace_point_by_underscore(str(variational_beta_y_eavae))
beta_z_eavae_tag = replace_point_by_underscore(str(variational_beta_z_eavae))

out_location_vae = (
    f'{ROOT_PATH}/MNIST/Model_checkpoints/personal/latent_dim_{latent_dim_vae}'
    f'/VAE/beta_y_{beta_y_vae_tag}/lr_{lr_tag}/'
)
out_location_eavae = (
    f'{ROOT_PATH}/MNIST/Model_checkpoints/personal/latent_dim_{latent_dim_eavae}'
    f'/EA-VAE/beta_y_{beta_y_eavae_tag}/beta_z_{beta_z_eavae_tag}/lr_{lr_tag}/'
)

# Create each directory if it is missing and report only newly created ones.
# exist_ok=True keeps the call safe if the directory appears between the
# existence check and the creation.
for out_location in (out_location_vae, out_location_eavae):
    if not os.path.exists(out_location):
        os.makedirs(out_location, exist_ok=True)
        print("The directory was generated: ", out_location)

## Train VAE (around 1 hr to complete)

In [None]:
# Run the VAE training/validation routine for `nepochs` epochs; results are
# written under out_location_vae via the saving_path argument.
train_and_val(
    model_vae,
    train_dataloader,
    validation_dataloader,
    nepochs,
    variational_beta_y_vae,
    saving_path=out_location_vae,
)

## Train EA-VAE (around 1 hr to complete)

In [None]:
# Run the EA-VAE training/validation routine for `nepochs` epochs with both
# the y and z beta weights; results are written under out_location_eavae.
z_train_and_val(
    model_eavae,
    train_dataloader,
    validation_dataloader,
    nepochs,
    variational_beta_y_eavae,
    variational_beta_z_eavae,
    saving_path=out_location_eavae,
)