In [1]:
from pathlib import Path
import os
import itertools

os.chdir('/Users/federicoferoggio/Documents/vs_code/latent-communication')
import torch
import numpy as np
from tqdm import tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

from models.definitions.PCKTAE import PocketAutoencoder
from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderFashionMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100
from utils.visualization import (
    visualize_mapping_error,
    visualize_latent_space_pca,
    plot_latent_space,
    highlight_cluster,
)
from utils.sampler import *
from optimization.fit_mapping import create_mapping
from utils.metrics import calculate_MSE_ssim_psnr
from utils.model import load_model, get_transformations

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [2]:
# Function to clear GPU memory
def clear_memory():
    torch.cuda.empty_cache()

def define_dataloader(file, file2, use_test_set=False):
    if file.strip("_")[0] != file2.strip("_")[0]:
        logging.error("The datasets are different")
    # Define the dataloaders
    name_dataset, name_model, size_of_the_latent, seed = file.strip(".pth").split("_")
    augumentation = get_transformations(name_model)
    if name_dataset.lower() == "mnist":
        dataloader = DataLoaderMNIST(transformation=augumentation, batch_size=64, seed=int(seed))
    if name_dataset.lower() == "fmnist":
        dataloader = DataLoaderFashionMNIST(transformation=augumentation,batch_size=64, seed=int(seed))
    if name_dataset.lower() == "cifar10":
        dataloader = DataLoaderCIFAR10(transformation=augumentation,batch_size=64, seed=int(seed))
    if name_dataset.lower() == "cifar100":
        dataloader = DataLoaderCIFAR100(transformation=augumentation,batch_size=64, seed=int(seed))
    if use_test_set:
        full_dataset_images, full_dataset_labels = dataloader.get_full_test_dataset()
    else:
        full_dataset_images, full_dataset_labels = dataloader.get_full_train_dataset()
    return full_dataset_images, full_dataset_labels, len(np.unique(full_dataset_labels.numpy()))


def calculate_and_save_mapping(model1, model2, sampling_strategy, sampled_images, parameters, file1, file2, transformations_database, num_samples, lamda, DEVICE):

    name_dataset1, name_model1, size_of_the_latent1, seed1 = file1.strip(".pth").split("_")
    name_dataset2, name_model2, size_of_the_latent2, seed2 = file2.strip(".pth").split("_")

    # Set the model to evaluation and sends them to the DEVICE 
    model1.to(torch.float32).to(DEVICE).eval()
    model2.to(torch.float32).to(DEVICE).eval()
    # Get latent of the sampled images
    latent_left_sampled_equally = model1.get_latent_space(sampled_images)
    latent_right_sampled_equally = model2.get_latent_space(sampled_images)
    latent_left_sampled_equally = latent_left_sampled_equally.to(torch.float32).cpu().detach().numpy()
    latent_right_sampled_equally = latent_right_sampled_equally.to(torch.float32).cpu().detach().numpy()
    # Create mapping and visualize
    cfg = Config(**parameters)
    mapping = create_mapping(cfg, latent_left_sampled_equally, latent_right_sampled_equally, do_print=False)
    mapping.fit()
    storage_path = f'../results/transformations/mapping_files/{name_model2}/'
    Path(storage_path).mkdir(parents=True, exist_ok=True)
    filename = f"{file1.strip('.pth')}>{file2.strip('.pth')}>{cfg.mapping}_{num_samples}_{lamda}_{sampling_strategy}"
    mapping.save_results(storage_path +  filename)
    transformations_database = pd.concat([transformations_database, pd.DataFrame({"model1": [file1], "model2": [file2], "mapping": [storage_path]})], ignore_index=True)
    return transformations_database

In [7]:
try:
    df_save_mappings = pd.read_csv("/results/transformations/mapping_files/transfomations_index.csv")
except:
    df_save_mappings = pd.DataFrame(columns=["model1", "model2", "mapping"])

## Here is the part that you have to modify however you want
## Define directories where you want to ieratively create the mapping, and then write down the parameters you want to use
import os 
# Set the directories where the models are stored

folder1 = "models/checkpoints/PCKTAE/FMNIST"
folder2 = "models/checkpoints/PCKTAE/FMNIST"
number_samples = [10,50,100,200, 300]    #[10,50,100,200,300]
mapping_list = ["Linear", "Affine"]
lamda_list = [0,0.1,0.01]    #[0,0.1, 0.01]
use_test_set = False
filter1 = '_' #write here if you want that the processed files contain this string (example "_50_" to only process the files with latent size 50)
filter2 = '_' #write here if you want that the processed files contain this string (example "_50_" to only process the files with latent size 50)
recalculate = False #If you want to recalculate the mappings, set this to True
use_same_sampling = True #If you want to use the same sampling points wherever possible for all the models, set this to True


## this autiomatically creates all teh possible setups with the paramenters and the files you speicified, and sets up the correct dataset
files1 = [f for f in os.listdir(folder1) if f.endswith(".pth") and filter1 in f]
files2 = [f for f in os.listdir(folder2) if f.endswith(".pth") and filter2 in f]
list_of_files = [(f1, f2) for f1, f2 in itertools.product(files1, files2) if f1 != f2]
combinations_parameters = list(itertools.product(number_samples, mapping_list, lamda_list))
combinations = list(itertools.product(list_of_files, combinations_parameters))
flattened_combinations = [(file1, file2, param1, param2, param3) for ((file1, file2), (param1, param2, param3)) in combinations]

# Sort the flattened list by all elements
sorted_combinations = sorted(flattened_combinations, key=lambda x: (x[2], x[1], x[0], x[3], x[4]))
print(sorted_combinations)
pbar = tqdm(sorted_combinations)

images, labels, n_classes = define_dataloader(files1[0], files2[0], use_test_set)
images = images.type(torch.float32)
labels = labels.type(torch.float32)

images_sampled_equally, labels_sampled_equally, images_sampled_drop_outliers, labels_sampled_drop_outliers, images_sampled_worst_classes, labels_sampled_worst_classes, images_sampled_best_classes, labels_sampled_convex_hull = [None]*8
past_num_samples, past_file1, past_file2 = None, None, None

# Loop through combinations
for file1, file2, num_samples, mapping, lamda in pbar:
    parameters = {"num_samples": num_samples, "mapping": mapping, "lamda": lamda} #This is done to go around some hydra stuff (<3 kai)
    name_dataset1, name_model1, size_of_the_latent1, seed1 = file1.strip(".pth").split("_")
    name_dataset2, name_model2, size_of_the_latent2, seed2 = file2.strip(".pth").split("_")

    model1 = load_model(model_name=name_model1, name_dataset=name_dataset1, latent_size=int(size_of_the_latent1), seed=int(seed1), model_path = folder1 + '/' + file1)
    model2 = load_model(model_name=name_model2, name_dataset=name_dataset2, latent_size=int(size_of_the_latent2), seed=int(seed2), model_path = folder1 + '/' + file2)
    if recalculate or str(f"{file1.strip('.pth')}>{file2.strip('.pth')}>{mapping}_{num_samples}_{lamda}_{'equally'}") not in os.listdir(f'results/transformations/mapping_files/{name_model2}'):
        if past_num_samples != num_samples or not use_same_sampling:
            pbar.set_description("Sampling equally per class")
            images_sampled_equally, labels_sampled_equally = sample_equally_per_class_images(num_samples, images, labels)        
        df_save_mappings = calculate_and_save_mapping(model1, model2, "equally", images_sampled_equally, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    if recalculate or str(f"{file1.strip('.pth')}>{file2.strip('.pth')}>{mapping}_{num_samples}_{lamda}_{'outliers'}") not in os.listdir(f'results/transformations/mapping_files/{name_model2}'):
        if past_num_samples != num_samples or not use_same_sampling or past_file2 != file2:    
            pbar.set_description("Sampling removing outliers")
            images_sampled_drop_outliers, labels_sampled_drop_outliers = sample_removing_outliers(num_samples, images, labels, model2)
        df_save_mappings = calculate_and_save_mapping(model1, model2, "outliers", images_sampled_drop_outliers, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    if recalculate or str(f"{file1.strip('.pth')}>{file2.strip('.pth')}>{mapping}_{num_samples}_{lamda}_{'worst_classes'}") not in os.listdir(f'results/transformations/mapping_files/{name_model2}'):
        if past_num_samples != num_samples or not use_same_sampling or past_file2 != file2:
            pbar.set_description("Sampling worst classes")
            images_sampled_worst_classes, labels_sampled_worst_classes = sample_with_half_worst_classes_images(num_samples, images, labels, model2)
        df_save_mappings = calculate_and_save_mapping(model1, model2, "worst_classes", images_sampled_worst_classes, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    if (recalculate) or str(f"{file1.strip('.pth')}>{file2.strip('.pth')}>{mapping}_{num_samples}_{lamda}_{'best_classes'}") not in os.listdir(f'results/transformations/mapping_files/{name_model2}'):
        if past_num_samples != num_samples or not use_same_sampling or past_file1 != file1:
            pbar.set_description("Sampling convex hull")
            images_sampled_best_classes, labels_sampled_convex_hull = sample_convex_hulls_images(num_samples, images, labels, model1)
        df_save_mappings = calculate_and_save_mapping(model1, model2, "convex_hull", images_sampled_best_classes, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    past_num_samples, past_file1, past_file2 = num_samples, file1, file2
df_save_mappings.to_csv("results/transformations/mapping_files/transfomations_index.csv", index=False)


[('FMNIST_PCKTAE_10_2.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0), ('FMNIST_PCKTAE_10_2.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0.01), ('FMNIST_PCKTAE_10_2.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0.1), ('FMNIST_PCKTAE_10_2.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Linear', 0), ('FMNIST_PCKTAE_10_2.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Linear', 0.01), ('FMNIST_PCKTAE_10_2.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Linear', 0.1), ('FMNIST_PCKTAE_10_3.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0), ('FMNIST_PCKTAE_10_3.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0.01), ('FMNIST_PCKTAE_10_3.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0.1), ('FMNIST_PCKTAE_10_3.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Linear', 0), ('FMNIST_PCKTAE_10_3.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Linear', 0.01), ('FMNIST_PCKTAE_10_3.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Linear', 0.1), ('FMNIST_PCKTAE_30_1.pth', 'FMNIST_PCKTAE_10_1.pth', 10, 'Affine', 0), ('FMNIST_PCKTAE_30_1.pth', 'FMNIST_PCKTAE_10_1.pth', 10,

Sampling equally per class:  18%|█▊        | 393/2160 [02:43<12:16,  2.40it/s]  

Failure:interrupted





SolverError: Solver 'SCS' failed. Try another solver, or solve with verbose=True for more information.