In [17]:
import os
import sys 
from sklearn.decomposition import PCA
from sklearn.utils.extmath import randomized_svd
import re

PROJECT_ROOT_DIR = "latent-communication"


def _find_project_root(start, root_name):
    """Return the nearest ancestor of ``start`` (inclusive) whose basename is ``root_name``.

    Args:
        start: directory to start the upward search from.
        root_name: basename of the project root directory to look for.

    Returns:
        Absolute path of the matching directory.

    Raises:
        FileNotFoundError: if the filesystem root is reached without a match.
    """
    current = os.path.abspath(start)
    while True:
        if os.path.basename(current) == root_name:
            return current
        parent = os.path.dirname(current)
        # dirname() of a filesystem root returns the root itself ("/" or "C:\\"),
        # so a "while path:" loop would spin forever; stop when we can't go higher.
        if parent == current:
            raise FileNotFoundError(f"Project root '{root_name}' not found in the directory tree.")
        current = parent


# Find the project root by walking up the directory tree from the kernel's cwd.
current_dir = _find_project_root(os.getcwd(), PROJECT_ROOT_DIR)

os.chdir(current_dir)
# Put the project root (and its utils/ folder) first on sys.path so the
# project-local imports below resolve regardless of where Jupyter was launched.
sys.path.insert(0, current_dir)
sys.path.insert(0, os.path.join(current_dir, "utils"))

print(os.getcwd())

from pathlib import Path
import torch.nn as nn

import itertools
import torch
import numpy as np
from tqdm import tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import logging
from optimization.fit_mapping import create_mapping
from utils.sampler import *
from optimization.optimizer import AffineFitting, ScalingFitting
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderFashionMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100
from utils.visualization import (
    visualize_mapping_error,
    visualize_latent_space,
    plot_latent_space,
    highlight_cluster,
)
from utils.model import load_model, get_transformations

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Config:
    """Minimal attribute container: every keyword argument becomes an instance attribute.

    Example: ``Config(lr=1e-3).lr == 1e-3``.
    """

    def __init__(self, **entries):
        # vars(self) is the instance __dict__, so each keyword is stored as a plain attribute.
        vars(self).update(entries)


/Users/mariotuci/Documents/latent-communication


In [18]:
def clear_memory():
    """Return cached-but-unused GPU memory held by PyTorch's CUDA caching allocator.

    Memory still referenced by live tensors is not released by this call.
    """
    cuda_backend = torch.cuda
    cuda_backend.empty_cache()

def define_dataloader(file, file2, use_test_set=False):
    """Load the full train or test split of the dataset a checkpoint was trained on.

    Checkpoint names are expected to look like ``<dataset>_<model>_<latent>_<seed>.pth``
    (e.g. ``FMNIST_VAE_16_1.pth``). The dataset, model and seed are parsed from ``file``.

    Args:
        file: checkpoint file name of the first model. It selects the dataset, the
            transformation (via the model name) and the loader seed.
        file2: checkpoint file name of the second model. It is only used to log an
            error when the two checkpoints name different datasets.
        use_test_set: return the test split instead of the train split.

    Returns:
        ``(images, labels, n_classes)``, where ``n_classes`` is the number of distinct labels.

    Raises:
        ValueError: if the dataset prefix is not MNIST, FMNIST, CIFAR10 or CIFAR100
            (case-insensitive), or the name does not have exactly four '_'-separated fields.
    """
    # Path(...).stem drops only the extension. str.strip(".pth") would also eat any
    # leading/trailing '.', 'p', 't', 'h' characters of the name itself.
    name_dataset, name_model, size_of_the_latent, seed = Path(file).stem.split("_")
    other_dataset = Path(file2).stem.split("_")[0]
    if name_dataset.lower() != other_dataset.lower():
        logging.error("The datasets are different")

    dataset_key = name_dataset.lower()
    if dataset_key == "mnist":
        loader_cls = DataLoaderMNIST
    elif dataset_key == "fmnist":
        loader_cls = DataLoaderFashionMNIST
    elif dataset_key == "cifar10":
        loader_cls = DataLoaderCIFAR10
    elif dataset_key == "cifar100":
        loader_cls = DataLoaderCIFAR100
    else:
        # Previously this fell through and failed later with an unbound `dataloader`.
        raise ValueError(f"Unknown dataset {name_dataset!r} in checkpoint name {file!r}")

    transformation = get_transformations(name_model)
    dataloader = loader_cls(transformation=transformation, batch_size=64, seed=int(seed))

    if use_test_set:
        full_dataset_images, full_dataset_labels = dataloader.get_full_test_dataset()
    else:
        full_dataset_images, full_dataset_labels = dataloader.get_full_train_dataset()
    return full_dataset_images, full_dataset_labels, len(np.unique(full_dataset_labels.numpy()))

def load_mapping(path, mapping):
    """Load a previously fitted latent-space mapping from ``path``.

    Args:
        path: location passed unchanged to the fitting class's ``from_file``.
        mapping: mapping type, one of ``'Linear'``, ``'Affine'``, ``'NeuralNetwork'``
            or ``'Decouple'``.

    Returns:
        The object returned by ``<FittingClass>.from_file(path)``.

    Raises:
        ValueError: if ``mapping`` is not one of the supported names.
    """
    # Imports are local so only the requested fitting class is resolved.
    if mapping == 'Linear':
        from optimization.optimizer import LinearFitting as fitting_cls
    elif mapping == 'Affine':
        from optimization.optimizer import AffineFitting as fitting_cls
    elif mapping == 'NeuralNetwork':
        from optimization.optimizer import NeuralNetworkFitting as fitting_cls
    elif mapping == 'Decouple':
        from optimization.optimizer import DecoupleFitting as fitting_cls
    else:
        raise ValueError(
            f"Invalid mapping name {mapping!r}; expected one of "
            "'Linear', 'Affine', 'NeuralNetwork', 'Decouple'"
        )
    return fitting_cls.from_file(path)

def get_same_mapping(files):
    """Index mapping files by their setup prefix and record each file's mapping seed.

    For a name such as ``'VAE/A>B>Linear_50_0_equally_7.npy'``:
    - the prefix is the name with the trailing ``_<number>.npy`` removed
      (``'VAE/A>B>Linear_50_0_equally'``);
    - the seed is the last run of digits in the name (``7``).

    Args:
        files: iterable of mapping file names.

    Returns:
        dict mapping prefix -> ``{'seed': int, 'mse': 0}``. When several files share a
        prefix, the entry of the last such file is kept.

    Raises:
        ValueError: if a file name contains no digits, so no seed can be extracted.
    """
    final_dict = {}
    for file in files:
        # Seed = last run of digits in this file's name.
        numbers = re.findall(r'\d+', file)
        if not numbers:
            raise ValueError(f"Cannot extract a seed from mapping file name {file!r}")

        # Prefix = name without the trailing "_<number>.npy".
        prefix = re.sub(r'_[\d\.]+\.npy$', '', file)

        final_dict[prefix] = {
            'seed': int(numbers[-1]),
            'mse': 0,
        }
    return final_dict

In [21]:
##############################################
# Specify here which files you want to use
folder1 = "models/checkpoints/VAE/FMNIST"
folder2 = "models/checkpoints/VAE/FMNIST"

dataset = "FMNIST"
number_samples = [50]
mapping_list = ["Linear"]
lamda_list = [0]
sampling_strategy = "equally"
mapping_seeds = range(1, 20)  # seeds of the saved mappings to evaluate (1..19)
# Only process checkpoint files whose name contains this string
# (e.g. "_50_" to only process the files with latent size 50).
file_filter = "_16_"
###############################################
# Everything below builds every (model pair x parameter) setup from the settings
# above and loads the matching evaluation dataset.


def _checkpoint_stem(file_name):
    """Return a checkpoint file name without its extension.

    ``str.strip(".pth")`` is not suitable: it removes any of '.', 'p', 't', 'h' from
    both ends of the name (e.g. 'path_VAE_16_1.pth' -> 'ath_VAE_16_1').
    """
    return Path(file_name).stem


def _decode(model, latent, device):
    """Decode a batch of latent vectors (array or tensor) with ``model``; return a CPU tensor."""
    # torch.tensor() emitted a warning on this call for the stitched latents (likely a
    # tensor returned by mapping.transform); as_tensor accepts arrays and tensors alike.
    latent_tensor = torch.as_tensor(latent, dtype=torch.float32).to(device)
    return model.decode(latent_tensor).detach().cpu()


files1 = sorted(f for f in os.listdir(folder1) if f.endswith(".pth") and file_filter in f)
files2 = sorted(f for f in os.listdir(folder2) if f.endswith(".pth") and file_filter in f)
if not files1 or not files2:
    raise FileNotFoundError(
        f"No '.pth' checkpoints containing {file_filter!r} found in {folder1!r} and/or {folder2!r}"
    )

# Every ordered pair of distinct checkpoints, crossed with every parameter combination.
list_of_files = [(f1, f2) for f1, f2 in itertools.product(files1, files2) if f1 != f2]
combinations_parameters = list(itertools.product(number_samples, mapping_list, lamda_list))
pbar = tqdm(list(itertools.product(list_of_files, combinations_parameters)))

# NOTE(review): assumes every checkpoint uses the same dataset; define_dataloader only
# logs an error when the first two files disagree.
images, labels, n_classes = define_dataloader(files1[0], files2[0], use_test_set=True)
images = images.type(torch.float32)
labels = labels.type(torch.float32)
labels_np = labels.numpy()
criterion = nn.MSELoss()
# Indices of the samples belonging to each class
class_indices = {i: np.where(labels_np == i)[0] for i in range(n_classes)}
# Class of each sample index
indices_class = {i: labels_np[i] for i in range(len(labels_np))}

result_data_list = []
# Pure evaluation: disable autograd so no computation graphs are built or kept alive.
with torch.no_grad():
    for (file1, file2), (num_samples, mapping_name, lamda) in pbar:
        stem1, stem2 = _checkpoint_stem(file1), _checkpoint_stem(file2)
        name_dataset1, name_model1, size_of_the_latent1, seed1 = stem1.split("_")
        name_dataset2, name_model2, size_of_the_latent2, seed2 = stem2.split("_")

        # Load models: each checkpoint comes from its own folder.
        model1 = load_model(model_name=name_model1, name_dataset=name_dataset1,
                            latent_size=int(size_of_the_latent1), seed=int(seed1),
                            model_path=os.path.join(folder1, file1))
        model2 = load_model(model_name=name_model2, name_dataset=name_dataset2,
                            latent_size=int(size_of_the_latent2), seed=int(seed2),
                            model_path=os.path.join(folder2, file2))

        # Setup identifier shared by all mapping seeds of this configuration.
        name = f'{name_model2}/{stem1}>{stem2}>{mapping_name}_{num_samples}_{lamda}_{sampling_strategy}'

        for seed in mapping_seeds:
            mapping = load_mapping(f'test/uncertain/{name}_{seed}', mapping_name)

            # Encoding stays inside the seed loop on purpose: the single-model errors
            # vary across seeds in the results, which suggests encoding/decoding is
            # stochastic. Hoist only if get_latent_space/decode are deterministic.
            latent_left = model1.get_latent_space(images).detach().cpu().numpy()
            latent_right = model2.get_latent_space(images).detach().cpu().numpy()
            transformed_latent_space = mapping.transform(latent_left)

            decoded_left = _decode(model1, latent_left, images.device)  # model1 reconstruction
            decoded_right = _decode(model2, latent_right, images.device)  # model2 reconstruction
            decoded_transformed = _decode(model2, transformed_latent_space, images.device)  # stitched

            result_data_list.append({
                'reconstruction_error': criterion(decoded_transformed, images).item(),
                'reconstruction_error_model1': criterion(decoded_left, images).item(),
                'reconstruction_error_model2': criterion(decoded_right, images).item(),
                'name': name,
            })

# Build the frame once after the sweep (re-creating it every iteration is quadratic).
result_data = pd.DataFrame(
    result_data_list,
    columns=['reconstruction_error', 'reconstruction_error_model1', 'reconstruction_error_model2', 'name'],
)
  


 


  0%|          | 0/6 [00:00<?, ?it/s]

  decoded_transformed = model2.decode(torch.tensor(transformed_latent_space, dtype=torch.float32).to(images.device)).detach().cpu().numpy()
100%|██████████| 6/6 [02:43<00:00, 27.29s/it]


In [24]:
# Summarise each stitching setup across its mapping seeds:
# mean and standard deviation of every reconstruction-error column.
result_data_grouped = (
    result_data
    .groupby(by='name')
    .agg(['mean', 'std'])
)

result_data_grouped

Unnamed: 0_level_0,reconstruction_error,reconstruction_error,reconstruction_error_model1,reconstruction_error_model1,reconstruction_error_model2,reconstruction_error_model2
Unnamed: 0_level_1,mean,std,mean,std,mean,std
name,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
VAE/FMNIST_VAE_16_1>FMNIST_VAE_16_2>Linear_50_0_convex_hull,0.136787,0.009693,0.079943,3.1e-05,0.0836,2.9e-05
VAE/FMNIST_VAE_16_1>FMNIST_VAE_16_3>Linear_50_0_convex_hull,0.156299,0.013303,0.079934,5.3e-05,0.081642,3.5e-05
VAE/FMNIST_VAE_16_2>FMNIST_VAE_16_1>Linear_50_0_convex_hull,0.144384,0.012714,0.083585,3.6e-05,0.079927,4e-05
VAE/FMNIST_VAE_16_2>FMNIST_VAE_16_3>Linear_50_0_convex_hull,0.120135,0.009818,0.083587,3e-05,0.081628,3.7e-05
VAE/FMNIST_VAE_16_3>FMNIST_VAE_16_1>Linear_50_0_convex_hull,0.164794,0.022424,0.081626,4.2e-05,0.079917,3.4e-05
VAE/FMNIST_VAE_16_3>FMNIST_VAE_16_2>Linear_50_0_convex_hull,0.133023,0.015326,0.081645,4.5e-05,0.083603,3.9e-05
