In [18]:
## Importing libraries

## Importing libraries
import sys, os
import numpy as np
import pandas as pd
import pickle
import pickle as pkl
import scipy
from scipy.stats import norm

import torch
from torch.utils.data import DataLoader, Subset
import torch.utils.data as data
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torchvision.datasets import FashionMNIST

from sklearn.neural_network import MLPClassifier

from auxiliary_functions import *

In [19]:
ROOT_PATH = '/'.join(os.getcwd().split('/')[:-1])
methods_dir = f'{ROOT_PATH}/Methods/MNIST_domains/'

In [20]:
sys.path.append(methods_dir)
from methods import VariationalAutoencoder,z_VariationalAutoencoder, set_seed, replace_point_by_underscore,CsikorDataset, MNISTShuffDataset
import medmnist
from medmnist import INFO, Evaluator

In [21]:

## Van Hateren Dataset
natural40_dir = f'{ROOT_PATH}/Datasets/VanHateren/'

nat_test_labels = pkl.load(open(natural40_dir+'test_labels.pkl','rb'))
nat_test_images = torch.tensor(np.load(natural40_dir+'test_images/test_images.npy').astype(np.float32).reshape(64000,1,40,40))
nat_n_test = len(nat_test_labels)

nat_train_pixs_mean = np.load(natural40_dir+'nat_train_pixs_mean.npy')
nat_train_pixs_std = np.load(natural40_dir+'nat_train_pixs_std.npy')
 
def rescale_nat2num(img):
    old_mean = nat_train_pixs_mean
    old_std = nat_train_pixs_std
    
    img = (1/(6*old_std)*(img - old_mean) + 0.5).clamp(0,1)
    return img

data_transform_natural_rescaled = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.Lambda(lambda x: rescale_nat2num(x)) 
]) 
batch_size = 128
nat_test_rescaled_dataset = CsikorDataset(nat_test_labels,nat_test_images,transform = data_transform_natural_rescaled)
nat_test_rescaled_dataloader = DataLoader(nat_test_rescaled_dataset, batch_size=batch_size, shuffle=False)

## MNIST Dataset
data_dir = f'{ROOT_PATH}/Datasets/Mnist/'

img_transform = transforms.Compose([
transforms.Resize((32, 32)),
transforms.ToTensor(),  
])

dataset = MNIST(root=data_dir+'MNIST', download=True, train=True, transform=img_transform)
num_train_idxs = np.load(data_dir+'train_idxs.npy')
num_val_idxs = np.load(data_dir+'val_idxs.npy')
num_train_dataset = Subset(dataset,num_train_idxs)
num_val_dataset = Subset(dataset,num_val_idxs)
num_test_dataset = MNIST(root=data_dir+'MNIST', download=True, train=False, transform=img_transform)

num_n_val = len(num_val_dataset)
num_n_train = len(dataset)-num_n_val

num_train_dataloader = DataLoader(num_train_dataset, batch_size=batch_size, shuffle=False)
num_validation_dataloader = DataLoader(num_val_dataset, batch_size=batch_size, shuffle=False)
num_test_dataloader = DataLoader(num_test_dataset, batch_size=batch_size, shuffle=False)

## FashionMNIST Dataset
data_dir = f'{ROOT_PATH}/Datasets/FashionMnist/'

fash_test_dataset = FashionMNIST(root=data_dir+'FashionMNIST', download=True, train=False, transform=img_transform)
fash_test_dataloader = DataLoader(fash_test_dataset, batch_size=batch_size, shuffle=False)

#ChestMNIST Dataset
data_flag = 'chestmnist'
download = True

BATCH_SIZE = 128

info = INFO[data_flag]
task = info['task']
n_channels = info['n_channels']
n_classes = len(info['label'])

DataClass = getattr(medmnist, info['python_class'])

med_test_dataset = DataClass(split='test', transform=img_transform, download=download)
med_n_test = len(med_test_dataset)
med_test_dataloader = data.DataLoader(dataset=med_test_dataset, batch_size=BATCH_SIZE, shuffle=False)

rnd_pix_shuff_imgs = torch.zeros(10000,1,32,32)
# shuffle pixels across first axis
for k in range(10000):
    rnd_pix_shuff_imgs[k] = num_test_dataset[k][0]
for i in range(32):
    for j in range(32):
        torch.manual_seed(32*i+j)
        rnd_pix_shuff_imgs[:,0,i,j] = rnd_pix_shuff_imgs[:,0,i,j][torch.randperm(10000)]
        
batch_size = 128

num_test_ShuffPix_dataset = MNISTShuffDataset(num_test_dataset.targets,rnd_pix_shuff_imgs)
num_test_ShuffPix_dataloader = DataLoader(num_test_ShuffPix_dataset, batch_size=batch_size, shuffle=False)

Using downloaded and verified file: /home/jcatoni/.medmnist/chestmnist.npz


In [22]:
use_manuscript_training = True

input_dim = 32*32
capacity = 64

if use_manuscript_training:
    latent_dim_vae = 5
    latent_dim_eavae = 4
    variational_beta_y_vae = 4
    variational_beta_y_eavae = 4
    variational_beta_z_eavae = 1
    learning_rate = 1e-3
    in_location_VAE = f'{ROOT_PATH}/MNIST/Model_checkpoints/manuscript/VAE/'
    in_location_EAVAE = f'{ROOT_PATH}/MNIST/Model_checkpoints/manuscript/EA-VAE/'
else:
    latent_dim_vae = 5 #complete
    latent_dim_eavae = 4 #complete
    variational_beta_y_vae = 4 #complete
    variational_beta_y_eavae = 4 #complete
    variational_beta_z_eavae = 1 #complete
    learning_rate = 1e-3 #complete
    in_location_VAE = f'{ROOT_PATH}/MNIST/Model_checkpoints/personal/latent_dim_{latent_dim_vae}/VAE/beta_y_'+replace_point_by_underscore(str(variational_beta_y_vae))+'/lr_'+replace_point_by_underscore(str(learning_rate))+'/'
    in_location_EAVAE = f'{ROOT_PATH}/MNIST/Model_checkpoints/personal/latent_dim_{latent_dim_eavae}/EA-VAE/beta_y_'+replace_point_by_underscore(str(variational_beta_y_eavae))+'/beta_z_'+replace_point_by_underscore(str(variational_beta_z_eavae))+'/lr_'+replace_point_by_underscore(str(learning_rate))+'/'


In [24]:
if use_manuscript_training:
   out_location = f'{ROOT_PATH}/Plot_results/MNIST/data/manuscript/prueba/'
else:
   out_location = f'{ROOT_PATH}/Plot_results/MNIST/data/personal/'

if not os.path.exists(out_location):
   os.makedirs(out_location)
   print("Se creó el directorio: ", out_location)

## loading models
# checking if torch.cuda is available
use_gpu = True
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")

model_vae = VariationalAutoencoder(num_n_train,num_n_val,input_dim,latent_dim_vae,capacity,learning_rate,rec_loss_method='BCE', device=device)
model_eavae = z_VariationalAutoencoder(num_n_train,num_n_val,input_dim,latent_dim_eavae,capacity,learning_rate,rec_loss_method='BCE', device=device)

state_dict_vae = torch.load(in_location_VAE+'bestnet.pth')
state_dict_eavae = torch.load(in_location_EAVAE+'bestnet.pth')

model_vae.load_state_dict(state_dict_vae)
model_eavae.load_state_dict(state_dict_eavae)

model_vae = model_vae.to(device)
model_eavae = model_eavae.to(device)

model_vae.eval()
model_eavae.eval()

#cost_vae = torch.load(in_location_VAE+'finalloss.pth')
#cost_eavae = torch.load(in_location_EAVAE+'finalloss.pth')

  state_dict_vae = torch.load(in_location_VAE+'bestnet.pth')
  state_dict_eavae = torch.load(in_location_EAVAE+'bestnet.pth')


z_VariationalAutoencoder(
  (encoder): z_VEncoder(
    (conv1): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (fc_mu): Linear(in_features=4096, out_features=4, bias=True)
    (fc_logvar): Linear(in_features=4096, out_features=4, bias=True)
    (fc_intz): Linear(in_features=16384, out_features=20, bias=True)
    (fc_zmu): Linear(in_features=20, out_features=1, bias=True)
    (fc_zlogvar): Linear(in_features=20, out_features=1, bias=True)
  )
  (decoder): z_VDecoder_BCE(
    (fc): Linear(in_features=4, out_features=4096, bias=True)
    (conv3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv2): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (conv1): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  )
)

In [25]:
num_avimg = torch.zeros(1,32,32)
for k, (x, y) in enumerate(num_test_dataloader):
    num_avimg += x.sum(dim=0)
num_avimg = num_avimg/len(num_test_dataloader.dataset)

num_test_mu_vae, num_test_var_vae, num_test_zmu_vae, num_test_zvar_vae = data_posteriors(model_vae, num_test_dataloader, latent_dim_vae, False)
num_test_mu_eavae, num_test_var_eavae, num_test_zmu_eavae, num_test_zvar_eavae = data_posteriors(model_eavae, num_test_dataloader, latent_dim_eavae, True)

num_test_av_post_mu_vae, num_test_av_post_var_vae = average_posterior(num_test_mu_vae, num_test_var_vae)
num_test_av_post_mu_eavae, num_test_av_post_var_eavae = average_posterior(num_test_mu_eavae, num_test_var_eavae)

num_uncertain_mu_vae, num_uncertain_var_vae, num_uncertain_zmu_vae, num_uncertain_zvar_vae = individual_posterior(model_vae, num_avimg, False)
num_uncertain_mu_eavae, num_uncertain_var_eavae, num_uncertain_zmu_eavae, num_uncertain_zvar_eavae = individual_posterior(model_eavae, num_avimg, True)

med_test_mu_vae, med_test_var_vae, med_test_zmu_vae, med_test_zvar_vae = data_posteriors(model_vae, med_test_dataloader, latent_dim_vae, False)
med_test_mu_eavae, med_test_var_eavae, med_test_zmu_eavae, med_test_zvar_eavae = data_posteriors(model_eavae, med_test_dataloader, latent_dim_eavae, True)

nat_test_mu_vae, nat_test_var_vae, nat_test_zmu_vae, nat_test_zvar_vae = data_posteriors(model_vae, nat_test_rescaled_dataloader, latent_dim_vae, False)
nat_test_mu_eavae, nat_test_var_eavae, nat_test_zmu_eavae, nat_test_zvar_eavae = data_posteriors(model_eavae, nat_test_rescaled_dataloader, latent_dim_eavae, True)

fash_test_mu_vae, fash_test_var_vae, fash_test_zmu_vae, fash_test_zvar_vae = data_posteriors(model_vae, fash_test_dataloader, latent_dim_vae, False)
fash_test_mu_eavae, fash_test_var_eavae, fash_test_zmu_eavae, fash_test_zvar_eavae = data_posteriors(model_eavae, fash_test_dataloader, latent_dim_eavae, True)

num_test_ShuffPix_mu_vae, num_test_ShuffPix_var_vae, num_test_ShuffPix_zmu_vae, num_test_ShuffPix_zvar_vae = data_posteriors(model_vae, num_test_ShuffPix_dataloader, latent_dim_vae, False)
num_test_ShuffPix_mu_eavae, num_test_ShuffPix_var_eavae, num_test_ShuffPix_zmu_eavae, num_test_ShuffPix_zvar_eavae = data_posteriors(model_eavae, num_test_ShuffPix_dataloader, latent_dim_eavae, True)

num_uncertainty_vae_df = pd.DataFrame({'uncertainty':np.sqrt(num_test_var_vae).mean(axis=1), 'model':'VAE', 'data-set':'MNIST'})
med_uncertainty_vae_df = pd.DataFrame({'uncertainty':np.sqrt(med_test_var_vae).mean(axis=1), 'model':'VAE', 'data-set':'ChestMNIST'})
nat_uncertainty_vae_df = pd.DataFrame({'uncertainty':np.sqrt(nat_test_var_vae).mean(axis=1), 'model':'VAE', 'data-set':'van Hateren'})
fash_uncertainty_vae_df = pd.DataFrame({'uncertainty':np.sqrt(fash_test_var_vae).mean(axis=1), 'model':'VAE', 'data-set':'FashionMNIST'})
num_ShuffPix_uncertainty_vae_df = pd.DataFrame({'uncertainty':np.sqrt(num_test_ShuffPix_var_vae).mean(axis=1), 'model':'VAE', 'data-set':'MNIST permuted pixels'})

num_uncertainty_eavae_df = pd.DataFrame({'uncertainty':np.sqrt(num_test_var_eavae).mean(axis=1), 'model':'EA-VAE', 'data-set':'MNIST'})
med_uncertainty_eavae_df = pd.DataFrame({'uncertainty':np.sqrt(med_test_var_eavae).mean(axis=1), 'model':'EA-VAE', 'data-set':'ChestMNIST'})
nat_uncertainty_eavae_df = pd.DataFrame({'uncertainty':np.sqrt(nat_test_var_eavae).mean(axis=1), 'model':'EA-VAE', 'data-set':'van Hateren'})
fash_uncertainty_eavae_df = pd.DataFrame({'uncertainty':np.sqrt(fash_test_var_eavae).mean(axis=1), 'model':'EA-VAE', 'data-set':'FashionMNIST'})
num_ShuffPix_uncertainty_eavae_df = pd.DataFrame({'uncertainty':np.sqrt(num_test_ShuffPix_var_eavae).mean(axis=1), 'model':'EA-VAE', 'data-set':'MNIST permuted pixels'})

fash_zmu_vae_df = pd.DataFrame({'z_mu':fash_test_zmu_vae, 'model':'VAE', 'data-set':'FashionMNIST'})
med_zmu_vae_df = pd.DataFrame({'z_mu':med_test_zmu_vae, 'model':'VAE', 'data-set':'ChestMNIST'})
num_zmu_vae_df = pd.DataFrame({'z_mu':num_test_zmu_vae, 'model':'VAE', 'data-set':'MNIST'})
nat_zmu_vae_df = pd.DataFrame({'z_mu':nat_test_zmu_vae, 'model':'VAE', 'data-set':'van Hateren'})
num_ShuffPix_zmu_vae_df = pd.DataFrame({'z_mu':num_test_ShuffPix_zmu_vae, 'model':'VAE', 'data-set':'MNIST permuted pixels'})

fash_zmu_eavae_df = pd.DataFrame({'z_mu':np.exp(fash_test_zmu_eavae + fash_test_zvar_eavae/2), 'model':'EA-VAE', 'data-set':'FashionMNIST'})
med_zmu_eavae_df = pd.DataFrame({'z_mu':np.exp(med_test_zmu_eavae + med_test_zvar_eavae/2), 'model':'EA-VAE', 'data-set':'ChestMNIST'})
num_zmu_eavae_df = pd.DataFrame({'z_mu':np.exp(num_test_zmu_eavae + num_test_zvar_eavae/2), 'model':'EA-VAE', 'data-set':'MNIST'})
nat_zmu_eavae_df = pd.DataFrame({'z_mu':np.exp(nat_test_zmu_eavae + nat_test_zvar_eavae/2), 'model':'EA-VAE', 'data-set':'van Hateren'})
num_ShuffPix_zmu_eavae_df = pd.DataFrame({'z_mu':np.exp(num_test_ShuffPix_zmu_eavae+num_test_ShuffPix_zvar_eavae/2), 'model':'EA-VAE', 'data-set':'MNIST permuted pixels'})

uncertainty_df = pd.concat([num_uncertainty_vae_df,   num_ShuffPix_uncertainty_vae_df,   fash_uncertainty_vae_df,   med_uncertainty_vae_df,   nat_uncertainty_vae_df,
                            num_uncertainty_eavae_df, num_ShuffPix_uncertainty_eavae_df, fash_uncertainty_eavae_df, med_uncertainty_eavae_df, nat_uncertainty_eavae_df])

zmu_eavae_df = pd.concat([num_zmu_eavae_df,num_ShuffPix_zmu_eavae_df, fash_zmu_eavae_df, med_zmu_eavae_df, nat_zmu_eavae_df])

uncertainty_vae_eavae_df = pd.concat([num_uncertainty_vae_df, num_uncertainty_eavae_df])


In [10]:
model_vae.eval()
model_eavae.eval()

x_ori = []
x_vae = []
x_eavae = []
for x, y in num_test_dataloader:
    with torch.no_grad():
        x_hat_eavae, _, _, _, _ = model_eavae(x.to(model_eavae.device), only_mu=True, only_zmu=True)
        x_hat_vae, _, _ = model_vae(x.to(model_vae.device), only_mu=True)
    for i,k in enumerate(range(3,6)):            
        x_ori.append(x.detach().cpu().numpy()[k])
        x_vae.append(x_hat_vae.detach().cpu().numpy()[k])
        x_eavae.append(x_hat_eavae.detach().cpu().numpy()[k])
    break # just one batch

x_med_ori = []
x_med_vae = []
x_med_eavae = []
for x, y in med_test_dataloader:
    with torch.no_grad():
        x_hat_eavae, _, _, _, _ = model_eavae(x.to(model_eavae.device), only_mu=True, only_zmu=True)
        x_hat_vae, _, _ = model_vae(x.to(model_vae.device), only_mu=True)
    for i,k in enumerate(range(3,6)):
        x_med_ori.append(x.detach().cpu().numpy()[k])
        x_med_vae.append(x_hat_vae.detach().cpu().numpy()[k])
        x_med_eavae.append(x_hat_eavae.detach().cpu().numpy()[k])
    break # just one batch

x_fash_ori = []
x_fash_vae = []
x_fash_eavae = []
for x, y in fash_test_dataloader:
    with torch.no_grad():
        x_hat_eavae, _, _, _, _ = model_eavae(x.to(model_eavae.device), only_mu=True, only_zmu=True)
        x_hat_vae, _, _ = model_vae(x.to(model_vae.device), only_mu=True)
    for i,k in enumerate(range(3,6)):
        x_fash_ori.append(x.detach().cpu().numpy()[k])
        x_fash_vae.append(x_hat_vae.detach().cpu().numpy()[k])
        x_fash_eavae.append(x_hat_eavae.detach().cpu().numpy()[k])
    break # just one batch

x_nat_ori = []
x_nat_vae = []
x_nat_eavae = []
for x, y in nat_test_rescaled_dataloader:
    with torch.no_grad():
        x_hat_eavae, _, _, _, _ = model_eavae(x.to(model_eavae.device), only_mu=True, only_zmu=True)
        x_hat_vae, _, _ = model_vae(x.to(model_vae.device), only_mu=True)
    for i,k in enumerate(range(5,8)):
        x_nat_ori.append(x.detach().cpu().numpy()[k])
        x_nat_vae.append(x_hat_vae.detach().cpu().numpy()[k])
        x_nat_eavae.append(x_hat_eavae.detach().cpu().numpy()[k])
    break # just one batch

x_shuff_ori = []
x_shuff_vae = []
x_shuff_eavae = []
for x, y in num_test_ShuffPix_dataloader:
    with torch.no_grad():
        x_hat_eavae, _, _, _, _ = model_eavae(x.to(model_eavae.device), only_mu=True, only_zmu=True)
        x_hat_vae, _, _ = model_vae(x.to(model_vae.device), only_mu=True)
    for i,k in enumerate(range(5,8)):
        x_shuff_ori.append(x.detach().cpu().numpy()[k])
        x_shuff_vae.append(x_hat_vae.detach().cpu().numpy()[k])
        x_shuff_eavae.append(x_hat_eavae.detach().cpu().numpy()[k])
    break # just one batch

In [11]:
with open(out_location+'x_ori.pkl', 'wb') as f:
    pickle.dump(x_ori, f)
with open(out_location+'x_vae.pkl', 'wb') as f:
    pickle.dump(x_vae, f)
with open(out_location+'x_eavae.pkl', 'wb') as f:
    pickle.dump(x_eavae, f)
    
with open(out_location+'x_med_ori.pkl', 'wb') as f:
    pickle.dump(x_med_ori, f)
with open(out_location+'x_med_vae.pkl', 'wb') as f:
    pickle.dump(x_med_vae, f)
with open(out_location+'x_med_eavae.pkl', 'wb') as f:
    pickle.dump(x_med_eavae, f)
    
with open(out_location+'x_fash_ori.pkl', 'wb') as f:
    pickle.dump(x_fash_ori, f)
with open(out_location+'x_fash_vae.pkl', 'wb') as f:
    pickle.dump(x_fash_vae, f)
with open(out_location+'x_fash_eavae.pkl', 'wb') as f:
    pickle.dump(x_fash_eavae, f)
    
with open(out_location+'x_nat_ori.pkl', 'wb') as f:
    pickle.dump(x_nat_ori, f)
with open(out_location+'x_nat_vae.pkl', 'wb') as f:
    pickle.dump(x_nat_vae, f)
with open(out_location+'x_nat_eavae.pkl', 'wb') as f:
    pickle.dump(x_nat_eavae, f)
    
with open(out_location+'x_shuff_ori.pkl', 'wb') as f:
    pickle.dump(x_shuff_ori, f)
with open(out_location+'x_shuff_vae.pkl', 'wb') as f:
    pickle.dump(x_shuff_vae, f)
with open(out_location+'x_shuff_eavae.pkl', 'wb') as f:
    pickle.dump(x_shuff_eavae, f)

In [12]:
uncertainty_df.to_pickle(out_location+'uncertainty_df')
zmu_eavae_df.to_pickle(out_location+'zmu_eavae_df')
uncertainty_vae_eavae_df.to_pickle(out_location+'uncertainty_vae_eavae_df')
np.save(out_location+'num_uncertain_var_vae.npy',num_uncertain_var_vae)
np.save(out_location+'num_uncertain_var_eavae.npy',num_uncertain_var_eavae)
np.save(out_location+'num_avimg.npy',num_avimg)

In [13]:
test_indexes_per_category = [np.where(num_test_dataset.targets==j)[0] for j in range(10)]

test_category_zmu_eavae = [num_test_zmu_eavae[test_indexes_per_category[j]].mean() for j in range(10)]
test_category_mu_eavae = [num_test_mu_eavae[test_indexes_per_category[j]].mean(axis=0) for j in range(10)]
test_category_cov_eavae = [np.diag(num_test_var_eavae[test_indexes_per_category[j]].mean(axis=0))+np.cov(num_test_mu_eavae[test_indexes_per_category[j]].T) for j in range(10)]

test_category_mu_vae = [num_test_mu_vae[test_indexes_per_category[j]].mean(axis=0) for j in range(10)]
test_category_cov_vae = [np.diag(num_test_var_vae[test_indexes_per_category[j]].mean(axis=0))+np.cov(num_test_mu_vae[test_indexes_per_category[j]].T) for j in range(10)]


In [14]:
#save test_category_mu_vae,test_category_cov_vae,test_indexes_per_category,num_test_mu_vae,num_test_var_vae to pickle
with open(out_location+'test_category_mu_vae.pkl', 'wb') as f:
    pickle.dump(test_category_mu_vae, f)
with open(out_location+'test_category_cov_vae.pkl', 'wb') as f:
    pickle.dump(test_category_cov_vae, f)
with open(out_location+'test_indexes_per_category.pkl', 'wb') as f:
    pickle.dump(test_indexes_per_category, f)
with open(out_location+'num_test_mu_vae.pkl', 'wb') as f:
    pickle.dump(num_test_mu_vae, f)
with open(out_location+'num_test_var_vae.pkl', 'wb') as f:
    pickle.dump(num_test_var_vae, f)
#same with eavae
with open(out_location+'test_category_mu_eavae.pkl', 'wb') as f:
    pickle.dump(test_category_mu_eavae, f)
with open(out_location+'test_category_cov_eavae.pkl', 'wb') as f:
    pickle.dump(test_category_cov_eavae, f)
with open(out_location+'num_test_mu_eavae.pkl', 'wb') as f:
    pickle.dump(num_test_mu_eavae, f)
with open(out_location+'num_test_var_eavae.pkl', 'wb') as f:
    pickle.dump(num_test_var_eavae, f)

In [15]:
model_eavae.eval()
model_vae.eval()

category_recons_vae,category_recons_eavae=[],[]
with torch.no_grad():
    
    for i in range(10):
        category_recons_eavae.append(model_eavae.decoder(torch.tensor(test_category_mu_eavae[i]).unsqueeze(0).float().to(model_eavae.device),torch.tensor(test_category_zmu_eavae[i]).unsqueeze(0).float().to(model_eavae.device)).detach().cpu().numpy())        
    for category_mu in test_category_mu_vae:
        category_recons_vae.append(model_vae.decoder(torch.tensor(category_mu).unsqueeze(0).float().to(model_vae.device)).detach().cpu().numpy())
        
vae_fit = np.zeros((10,10,3))
eavae_fit = np.zeros((10,10,3))

int_cat_unc_vaes = []
int_cat_unc_eavaes = []
int_cat_zmu_eavaes = []
interpolation_img_in_vaes = []
interpolation_img_in_eavaes = []

l=0
for i in range(10):
    for j in range(i+1,10):
        interpolation_img_in_eavae, interpolation_img_out_eavae, interpolation_var_eavae, interpolation_mu_eavae, interpolation_zmu_eavae, interpolation_zvar_eavae = eavae_categoric_interpolation_line_pixel(model_eavae, category_recons_eavae, i, j, 50)
        interpolation_img_in_vae,interpolation_img_out_vae , interpolation_var_vae, interpolation_mu_vae = vae_categoric_interpolation_line_pixel(model_vae, category_recons_vae, i, j, 50)
        
        lam_int = np.linspace(0, 1, 51)
        int_cat_unc_eavae = np.sqrt(interpolation_var_eavae).mean(axis=1)
        int_cat_zmu_eavae = np.exp(interpolation_zmu_eavae[:]+interpolation_zvar_eavae[:]/2)
        eavae_fit[i,j] = np.polyfit(lam_int, int_cat_unc_eavae, 2)
        int_cat_unc_vae = np.sqrt(interpolation_var_vae).mean(axis=1)
        vae_fit[i,j] = np.polyfit(lam_int, int_cat_unc_vae, 2)

        int_cat_unc_vaes.append(int_cat_unc_vae)
        int_cat_unc_eavaes.append(int_cat_unc_eavae)
        int_cat_zmu_eavaes.append(int_cat_zmu_eavae)
        interpolation_img_in_vaes.append(interpolation_img_in_vae)
        interpolation_img_in_eavaes.append(interpolation_img_in_eavae)
        
        l+=1

In [16]:
vae_in = []
eavae_in = []
for thr in np.linspace(0,.5,51):
    vae_in.append(((.5-thr<((-vae_fit[:,:,1]/(2*vae_fit[:,:,0])).flatten())) * (((-vae_fit[:,:,1]/(2*vae_fit[:,:,0])).flatten())<.5+thr) * ((vae_fit[:,:,0]).flatten()<0)).sum()/45)
    eavae_in.append(((.5-thr<((-eavae_fit[:,:,1]/(2*eavae_fit[:,:,0])).flatten())) * (((-eavae_fit[:,:,1]/(2*eavae_fit[:,:,0])).flatten())<.5+thr) * ((eavae_fit[:,:,0]).flatten()<0)).sum()/45)

  vae_in.append(((.5-thr<((-vae_fit[:,:,1]/(2*vae_fit[:,:,0])).flatten())) * (((-vae_fit[:,:,1]/(2*vae_fit[:,:,0])).flatten())<.5+thr) * ((vae_fit[:,:,0]).flatten()<0)).sum()/45)
  eavae_in.append(((.5-thr<((-eavae_fit[:,:,1]/(2*eavae_fit[:,:,0])).flatten())) * (((-eavae_fit[:,:,1]/(2*eavae_fit[:,:,0])).flatten())<.5+thr) * ((eavae_fit[:,:,0]).flatten()<0)).sum()/45)


In [17]:
np.save(out_location+'int_cat_unc_vaes.npy', np.array(int_cat_unc_vaes, dtype=float), allow_pickle=True)
np.save(out_location+'int_cat_unc_eavaes.npy', np.array(int_cat_unc_eavaes, dtype=float), allow_pickle=True)
np.save(out_location+'int_cat_zmu_eavaes.npy', np.array(int_cat_zmu_eavaes, dtype=float), allow_pickle=True)

np.save(out_location+'interpolation_img_in_vaes.npy', np.array(interpolation_img_in_vaes, dtype=float), allow_pickle=True)
np.save(out_location+'interpolation_img_in_eavaes.npy', np.array(interpolation_img_in_eavaes, dtype=float), allow_pickle=True)

np.save(out_location+'vae_fit.npy',vae_fit)
np.save(out_location+'eavae_fit.npy',eavae_fit)

Classifying

In [48]:
num_train_mu_vae, num_train_var_vae, num_train_zmu_vae, num_train_zvar_vae = data_posteriors(model_vae, num_train_dataloader, latent_dim_vae, False)
num_train_mu_eavae, num_train_var_eavae, num_train_zmu_eavae, num_train_zvar_eavae = data_posteriors(model_eavae, num_train_dataloader, latent_dim_eavae, True)

num_val_mu_vae, num_val_var_vae, num_val_zmu_vae, num_val_zvar_vae = data_posteriors(model_vae, num_validation_dataloader, latent_dim_vae, False)
num_val_mu_eavae, num_val_var_eavae, num_val_zmu_eavae, num_val_zvar_eavae = data_posteriors(model_eavae, num_validation_dataloader, latent_dim_eavae, True)

num_test_mu_vae, num_test_var_vae, num_test_zmu_vae, num_test_zvar_vae = data_posteriors(model_vae, num_test_dataloader, latent_dim_vae, False)
num_test_mu_eavae, num_test_var_eavae, num_test_zmu_eavae, num_test_zvar_eavae = data_posteriors(model_eavae, num_test_dataloader, latent_dim_eavae, True)

In [49]:
W_train_vae = np.hstack((num_train_mu_vae,num_train_var_vae))
W_val_vae = np.hstack((num_val_mu_vae,num_val_var_vae))
W_test_vae = np.hstack((num_test_mu_vae,num_test_var_vae))

y_train_vae = np.array(dataset.targets[num_train_dataset.indices].clone().detach())
y_val_vae = np.array(dataset.targets[num_val_dataset.indices].clone().detach())
y_test_vae = np.array(num_test_dataset.targets.clone().detach())


W_train_eavae = np.hstack((num_train_mu_eavae,num_train_var_eavae))
W_val_eavae = np.hstack((num_val_mu_eavae,num_val_var_eavae))
W_test_eavae = np.hstack((num_test_mu_eavae,num_test_var_eavae))

y_train_eavae = np.array(dataset.targets[num_train_dataset.indices].clone().detach())
y_val_eavae = np.array(dataset.targets[num_val_dataset.indices].clone().detach())
y_test_eavae = np.array(num_test_dataset.targets.clone().detach())

In [None]:
#RUNS FOR ABOUT 1.5 HOURS, RESULTS CAN BE DIRECTLY IMPORTED BELOW

clf_ws_vae_ = MLPClassifier(hidden_layer_sizes=(50, 50))  # Initialize the classifier
score_train_vae_ = []
score_val_vae_ = []
best_val_score_vae = 0  # Start with the lowest possible score
best_clf_vae = None  # Placeholder for the best classifier

for k in range(1000):
    # Generate synthetic training and validation data
    X_s_train_vae = norm.rvs(size=(len(W_train_vae), latent_dim_vae)) * np.sqrt(W_train_vae[:, latent_dim_vae:]) + W_train_vae[:, 0:latent_dim_vae]
    X_s_val_vae = norm.rvs(size=(len(W_val_vae), latent_dim_vae)) * np.sqrt(W_val_vae[:, latent_dim_vae:]) + W_val_vae[:, 0:latent_dim_vae]
    # Update the model
    clf_ws_vae_.partial_fit(X_s_train_vae, y_train_vae, classes=np.arange(10))
    # Evaluate scores
    train_score = clf_ws_vae_.score(X_s_train_vae, y_train_vae)
    val_score = clf_ws_vae_.score(X_s_val_vae, y_val_vae)
    # Save scores
    score_train_vae_.append(train_score)
    score_val_vae_.append(val_score)
    # Check if current model is the best so far
    if val_score > best_val_score_vae:
        best_val_score_vae = val_score
        best_clf_vae = clf_ws_vae_  # Save the current model
    
    
clf_ws_eavae_ = MLPClassifier(hidden_layer_sizes=(50, 50))  # Initialize the classifier
score_train_eavae_ = []
score_val_eavae_ = []
best_val_score_eavae = 0  # Start with the lowest possible score
best_clf_eavae = None  # Placeholder for the best classifier

for k in range(1000):
    # Generate synthetic training and validation data
    X_s_train_eavae = norm.rvs(size=(len(W_train_eavae), latent_dim_eavae)) * np.sqrt(W_train_eavae[:, latent_dim_eavae:]) + W_train_eavae[:, 0:latent_dim_eavae]
    X_s_val_eavae = norm.rvs(size=(len(W_val_eavae), latent_dim_eavae)) * np.sqrt(W_val_eavae[:, latent_dim_eavae:]) + W_val_eavae[:, 0:latent_dim_eavae]
    # Update the model
    clf_ws_eavae_.partial_fit(X_s_train_eavae, y_train_eavae, classes=np.arange(10))
    # Evaluate scores
    train_score = clf_ws_eavae_.score(X_s_train_eavae, y_train_eavae)
    val_score = clf_ws_eavae_.score(X_s_val_eavae, y_val_eavae)
    # Save scores
    score_train_eavae_.append(train_score)
    score_val_eavae_.append(val_score)
    # Check if current model is the best so far
    if val_score > best_val_score_eavae:
        best_val_score_eavae = val_score
        best_clf_eavae = clf_ws_eavae_  # Save the current model

#save
with open(out_location+'clf_ws_vae.pkl','wb') as f:
    pickle.dump(best_clf_vae,f)
with open(out_location+'clf_ws_eavae.pkl','wb') as f:
    pickle.dump(best_clf_eavae,f)

In [54]:
# IF USING MANUSCRIPT DATA, LOAD CLASSIFIERS
with open(out_location+'clf_ws_vae.pkl', 'rb') as f:
    clf_ws_vae = pkl.load(f)
with open(out_location+'clf_ws_eavae.pkl', 'rb') as f:
    clf_ws_eavae = pkl.load(f)

In [56]:
med_mat_vae = np.zeros((len(med_test_mu_vae),10))
med_mat_eavae = np.zeros((len(med_test_mu_eavae),10))

Nsamp=1000
for j in range(Nsamp):
    np.random.seed(j)
    med_mat_vae+=clf_ws_vae.predict_proba(norm.rvs(size=(len(med_test_mu_vae),latent_dim_vae))*np.sqrt(med_test_var_vae)+med_test_mu_vae)/Nsamp
    np.random.seed(j)
    med_mat_eavae+=clf_ws_eavae.predict_proba(norm.rvs(size=(len(med_test_mu_eavae),latent_dim_eavae))*np.sqrt(med_test_var_eavae)+med_test_mu_eavae)/Nsamp
shuff_mat_vae = np.zeros((len(num_test_ShuffPix_mu_vae),10))
shuff_mat_eavae = np.zeros((len(num_test_ShuffPix_mu_eavae),10))

Nsamp=1000
for j in range(Nsamp):
    np.random.seed(j)
    shuff_mat_vae+=clf_ws_vae.predict_proba(norm.rvs(size=(len(num_test_ShuffPix_mu_vae),latent_dim_vae))*np.sqrt(num_test_ShuffPix_var_vae)+num_test_ShuffPix_mu_vae)/Nsamp
    np.random.seed(j)
    shuff_mat_eavae+=clf_ws_eavae.predict_proba(norm.rvs(size=(len(num_test_ShuffPix_mu_eavae),latent_dim_eavae))*np.sqrt(num_test_ShuffPix_var_eavae)+num_test_ShuffPix_mu_eavae)/Nsamp

In [60]:
np.save(out_location+'med_mat_vae.npy',med_mat_vae)
np.save(out_location+'med_mat_eavae.npy',med_mat_eavae)
np.save(out_location+'shuff_mat_vae.npy',shuff_mat_vae)
np.save(out_location+'shuff_mat_eavae.npy',shuff_mat_eavae)

In [None]:
entropies_all_vae = []
entropies_all_eavae = []

for l in range(10):
    for i in range(10):
        if i>l:
            predicts_vae = []
            predicts_eavae = []
            entropies_vae = []
            entropies_eavae = []
            for k in np.linspace(0,1,11):
                model_vae.eval()
                ins_vae, outs_vae, mus_vae, vars_vae = vae_fusion_numbers_pixel(model_vae, num_val_dataset, cat1=l, cat2=i, N=50, lam = k)
                matt_vae = np.zeros((len(mus_vae),10))
                for j in range(1000):
                    X_s_vae = norm.rvs(size=(len(mus_vae),latent_dim_vae))*np.sqrt(vars_vae)+mus_vae
                    matt_vae+=clf_ws_vae.predict_proba(X_s_vae)/1000
                
                predicts_vae.append(matt_vae.mean(axis=0))
                entropies_vae.append(scipy.stats.entropy(matt_vae,axis=1).mean())
                
                model_eavae.eval()
                ins_eavae, outs_eavae, mus_eavae, vars_eavae, zmus_eavae, zvars_eavae = eavae_fusion_numbers_pixel(model_eavae, num_val_dataset, cat1=l, cat2=i, N=50, lam = k)
                matt_eavae = np.zeros((len(mus_eavae),10))
                for j in range(1000):        
                    X_s_eavae = norm.rvs(size=(len(mus_eavae),latent_dim_eavae))*np.sqrt(vars_eavae)+mus_eavae
                    matt_eavae+=clf_ws_eavae.predict_proba(X_s_eavae)/1000
                predicts_eavae.append(matt_eavae.mean(axis=0))
                
                entropies_eavae.append(scipy.stats.entropy(matt_eavae,axis=1).mean())
            entropies_all_vae.append(np.array(entropies_vae))
            entropies_all_eavae.append(np.array(entropies_eavae))
            if l==2 and i==5:
                np.save(out_location+f'predicts_vae_{l}_{i}.npy', np.array(predicts_vae), allow_pickle=True)
                np.save(out_location+f'predicts_eavae_{l}_{i}.npy', np.array(predicts_eavae), allow_pickle=True)

np.save(out_location+f'entropies_all_vae.npy', np.array(entropies_all_vae), allow_pickle=True)
np.save(out_location+f'entropies_all_eavae.npy', np.array(entropies_all_eavae), allow_pickle=True) 