## Importing libraries

In [2]:
import sys, os
import numpy as np
import matplotlib.pylab as plt
import pickle as pkl

import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split

In [3]:
# Parent directory of the current working directory; every data/method/checkpoint
# path in this notebook is built relative to it.
# os.path.dirname respects the platform's path separator, whereas splitting the
# cwd on '/' yields an empty string on Windows (no '/' in the cwd).
ROOT_PATH = os.path.dirname(os.getcwd())
methods_dir = f'{ROOT_PATH}/Methods/VanHateren_Gamma/'

In [5]:
# Make the project-local `methods` module (located in methods_dir) importable.
sys.path.append(methods_dir)
from methods import (
    CsikorDataset,
    Laplace_FC_VAE,
    Gamma_free_Laplace_FC_VAE,
    set_seed,
    train_and_val,
    z_train_and_val,
    replace_point_by_underscore,
)

## Create Dataset and DataLoaders

In [7]:
natural40_dir = f'{ROOT_PATH}/Datasets/VanHateren/'

# Load the labels through a context manager so the file handle is closed
# deterministically instead of being left to the garbage collector.
# NOTE: pickle can execute arbitrary code on load; only use trusted label files.
with open(natural40_dir+'train_labels.pkl','rb') as labels_file:
    train_labels = pkl.load(labels_file)
len_trainset = len(train_labels)

# Images are cast to float32 right after loading.
train_images = np.load(natural40_dir+'train_images/train_images.npy').astype(np.float32)

n_data = len(train_images)
perc_val = 0.2  # fraction of the data held out for validation when a new split is drawn

# generating subset based on indices
# set_seed presumably seeds the global RNGs that train_test_split relies on below
# (no random_state is passed) -- TODO confirm against methods.set_seed.
set_seed(seed=2, seed_torch=True)

# True: reuse the split stored on disk; False: draw a new split and overwrite the stored one.
import_idxs = True

if import_idxs:
    train_idxs = np.load(natural40_dir+'train_idxs.npy')
    val_idxs = np.load(natural40_dir+'val_idxs.npy')

else:
    train_idxs, val_idxs = train_test_split(
        range(n_data),
        test_size=perc_val
    )
    np.save(natural40_dir+'train_idxs.npy',train_idxs)
    np.save(natural40_dir+'val_idxs.npy',val_idxs)

Random seed 2 has been set.


In [8]:
# Preprocessing switches. Mean removal is only applied when rescaling is enabled.
rescale_0_1 = True
remove_mean = True

if not rescale_0_1:
    # No preprocessing: wrap the raw images directly.
    nat_dataset = CsikorDataset(train_labels,train_images)
else:
    # Pixel statistics of the training subset: either reuse the cached values
    # or recompute them from the training indices and refresh the cache.
    compute_pix_mean_std = False
    if not compute_pix_mean_std:
        nat_train_pixs_mean = np.load(natural40_dir+'nat_train_pixs_mean.npy')
        nat_train_pixs_std = np.load(natural40_dir+'nat_train_pixs_std.npy')
    else:
        nat_train_pixs_mean = train_images[train_idxs].mean()
        nat_train_pixs_std = train_images[train_idxs].std()
        np.save(natural40_dir+'nat_train_pixs_mean.npy',nat_train_pixs_mean)
        np.save(natural40_dir+'nat_train_pixs_std.npy',nat_train_pixs_std)

    print('Will rescale images to 0-1')
    # Affine map: the training mean goes to 0.5 and +/- 3 std go to 1 and 0
    # (values further out are not clipped).
    train_images_ = (train_images - nat_train_pixs_mean)/(6*nat_train_pixs_std)+1/2

    if not remove_mean:
        nat_dataset = CsikorDataset(train_labels,train_images_)
    else:
        print('Will remove mean from images')
        # Subtract each sample's own mean taken along axis 1.
        train_images__ = train_images_ - train_images_.mean(axis=1, keepdims=True)

        nat_dataset = CsikorDataset(train_labels,train_images__)


Will rescale images to 0-1
Will remove mean from images


In [9]:
batch_size = 128

# Split the full dataset into train/validation views using the index arrays.
train_dataset = Subset(nat_dataset,train_idxs)
val_dataset = Subset(nat_dataset,val_idxs)

n_train = len(train_dataset)
n_val = len(val_dataset)

# Training batches are reshuffled every epoch; validation order is kept fixed.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
validation_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Only train and validation sets exist here, so the label must not mention a test set.
print(f'length train and validation sets: {n_train}, {n_val}')

length train, validation and test sets: 512000, 128000


## Set your models' hyperparameters (the values shown here are those used in the manuscript)

In [10]:
# Model hyperparameters, grouped by what they configure.

# --- shared by both models ---
imsize = 1600  # Fixed (presumably the flattened image size -- TODO confirm)
learning_rate = 3e-5
weight_decay = 1e-5
nepochs = 10_000

# --- VAE ---
latent_dim_vae = 1800
variational_beta_y_vae = 0.015

# --- EA-VAE ---
latent_dim_eavae = 1799
variational_beta_y_eavae = 0.015
variational_beta_z_eavae = 0.03
# k_param / theta_param are passed to Gamma_free_Laplace_FC_VAE
# (presumably Gamma shape and scale -- TODO confirm against methods.py).
k_param = 2
theta_param = float(1 / np.sqrt(2))

## Initialize VAE and EA-VAE models

In [11]:
# Pick the compute device: first CUDA GPU when requested and available, CPU otherwise.
use_gpu = True
if use_gpu and torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Build both models from the shared hyperparameters; the EA-VAE additionally
# receives k_param and theta_param.
uncomp_model_vae = Laplace_FC_VAE(
    n_train, n_val, latent_dim_vae, imsize,
    learning_rate, weight_decay,
    device=device,
)
uncomp_model_eavae = Gamma_free_Laplace_FC_VAE(
    n_train, n_val, latent_dim_eavae, imsize,
    learning_rate, weight_decay,
    k_param, theta_param,
    device=device,
)

# Move each model to the selected device.
model_vae = uncomp_model_vae.to(device)
model_eavae = uncomp_model_eavae.to(device)

## Directory where models will be saved

In [None]:
# Checkpoint directories. Each path encodes the latent size, the beta weight(s)
# and the learning rate of the run; replace_point_by_underscore is applied to the
# stringified numbers (presumably so the folder names contain no '.').
out_location_vae = (
    f'{ROOT_PATH}/VanHateren_Gamma-Laplace/Model_checkpoints/personal/latent_dim_{latent_dim_vae}/VAE'
    + '/beta_y_' + replace_point_by_underscore(str(variational_beta_y_vae))
    + '/lr_' + replace_point_by_underscore(str(learning_rate))
    + '/'
)

out_location_eavae = (
    f'{ROOT_PATH}/VanHateren_Gamma-Laplace/Model_checkpoints/personal/latent_dim_{latent_dim_eavae}/EA-VAE'
    + '/beta_y_' + replace_point_by_underscore(str(variational_beta_y_eavae))
    + '/beta_z_' + replace_point_by_underscore(str(variational_beta_z_eavae))
    + '/lr_' + replace_point_by_underscore(str(learning_rate))
    + '/'
)

# Create whichever directories are missing and report only those.
# exist_ok=True keeps makedirs from raising if the directory appears between
# the existence check and the creation (e.g. a second run started in parallel).
for _out_location in (out_location_vae, out_location_eavae):
    if not os.path.exists(_out_location):
        os.makedirs(_out_location, exist_ok=True)
        print("The directory was generated: ", _out_location)

## Train VAE (around 75 hr to complete)

In [None]:
# Train the VAE for `nepochs` epochs, weighting with variational_beta_y_vae;
# out_location_vae is handed over as the saving path.
train_and_val(
    model_vae,
    train_dataloader,
    validation_dataloader,
    nepochs,
    variational_beta_y_vae,
    saving_path=out_location_vae,
)

## Train EA-VAE (around 75 hr to complete)

In [None]:
# Train the EA-VAE for `nepochs` epochs with its two beta weights (y and z);
# out_location_eavae is handed over as the saving path.
z_train_and_val(
    model_eavae,
    train_dataloader,
    validation_dataloader,
    nepochs,
    variational_beta_y_eavae,
    variational_beta_z_eavae,
    saving_path=out_location_eavae,
)