In [1]:
from pathlib import Path
import os
import itertools

os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")

import torch
import numpy as np
from tqdm import tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

from models.definitions.smallae import PocketAutoencoder
from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderFashionMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100
from utils.visualization import (
    visualize_mapping_error,
    visualize_latent_space_pca,
    plot_latent_space,
    highlight_cluster,
)
from utils.sampler import *
from optimization.fit_mapping import create_mapping
from utils.metrics import calculate_MSE_ssim_psnr
from utils.model import load_model, get_transformations

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class Config:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [2]:
# Function to clear GPU memory
def clear_memory():
    torch.cuda.empty_cache()

def define_dataloader(file, file2, use_test_set=False):
    if file.strip("_")[0] != file2.strip("_")[0]:
        logging.error("The datasets are different")
    # Define the dataloaders
    name_dataset, name_model, size_of_the_latent, seed = file.strip(".pth").split("_")
    augumentation = get_transformations(name_model)
    if name_dataset.lower() == "mnist":
        dataloader = DataLoaderMNIST(transformation=augumentation, batch_size=64, seed=int(seed))
    if name_dataset.lower() == "fmnist":
        dataloader = DataLoaderFashionMNIST(transformation=augumentation,batch_size=64, seed=int(seed))
    if name_dataset.lower() == "cifar10":
        dataloader = DataLoaderCIFAR10(transformation=augumentation,batch_size=64, seed=int(seed))
    if name_dataset.lower() == "cifar100":
        dataloader = DataLoaderCIFAR100(transformation=augumentation,batch_size=64, seed=int(seed))
    if use_test_set:
        full_dataset_images, full_dataset_labels = dataloader.get_full_test_dataset()
    else:
        full_dataset_images, full_dataset_labels = dataloader.get_full_train_dataset()
    return full_dataset_images, full_dataset_labels, len(np.unique(full_dataset_labels.numpy()))


def calculate_and_save_mapping(model1, model2, sampling_strategy, sampled_images, parameters, file1, file2, transformations_database, num_samples, lamda, DEVICE):

    name_dataset1, name_model1, size_of_the_latent1, seed1 = file1.strip(".pth").split("_")
    name_dataset2, name_model2, size_of_the_latent2, seed2 = file2.strip(".pth").split("_")

    # Set the model to evaluation and sends them to the DEVICE 
    model1.to(torch.float32).to(DEVICE).eval()
    model2.to(torch.float32).to(DEVICE).eval()
    # Get latent of the sampled images
    latent_left_sampled_equally = model1.get_latent_space(sampled_images)
    latent_right_sampled_equally = model2.get_latent_space(sampled_images)
    latent_left_sampled_equally = latent_left_sampled_equally.to(torch.float32).cpu().detach().numpy()
    latent_right_sampled_equally = latent_right_sampled_equally.to(torch.float32).cpu().detach().numpy()
    # Create mapping and visualize
    cfg = Config(**parameters)
    mapping = create_mapping(cfg, latent_left_sampled_equally, latent_right_sampled_equally, do_print=False)
    mapping.fit()
    storage_path = f'results/transformations/mapping_files/{name_model2}/'
    Path(storage_path).mkdir(parents=True, exist_ok=True)
    filename = f"{file1.strip('.pth')}>{file1.strip('.pth')}>{cfg.mapping}_{num_samples}_{lamda}_{sampling_strategy}"
    mapping.save_results(storage_path + filename)
    transformations_database = pd.concat([transformations_database, pd.DataFrame({"model1": [file1], "model2": [file2], "mapping": [storage_path]})], ignore_index=True)
    return transformations_database

SyntaxError: f-string: unmatched '(' (2114703214.py, line 45)

In [None]:
try:
    df_save_mappings = pd.read_csv("results/transformations/mapping_files/transfomations_index.csv")
except:
    df_save_mappings = pd.DataFrame(columns=["model1", "model2", "mapping"])

## Here is the part that you have to modify however you want
## Define directories where you want to ieratively create the mapping, and then write down the parameters you want to use
os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")
folder1 = "models/checkpoints/SMALLAE/MNIST/"
folder2 = "models/checkpoints/SMALLAE/MNIST/"
number_samples = [10,50,100]    #[10,50,100,200,300]
mapping_list = ["Linear", "Affine"]
lamda_list = [0]    #[0,0.1, 0.01]
use_test_set = False
filter = '_50_' #write here if you want that the processed files contain this string (example "_50_" to only process the files with latent size 50)


## this autiomatically creates all teh possible setups with the paramenters and the files you speicified, and sets up the correct dataset
files1 = [f for f in os.listdir(folder1) if f.endswith(".pth") and filter in f]
files2 = [f for f in os.listdir(folder2) if f.endswith(".pth") and filter in f]
list_of_files = [(f1, f2) for f1, f2 in itertools.product(files1, files2) if f1 != f2]
combinations_parameters = list(itertools.product(number_samples, mapping_list, lamda_list))
pbar = tqdm(list(itertools.product(list_of_files, combinations_parameters)))
images, labels, n_classes = define_dataloader(files1[0], files2[0], use_test_set)
images = images.type(torch.float32)
labels = labels.type(torch.float32)

# Loop through combinations
for (file1, file2), (num_samples, mapping, lamda) in pbar:
    parameters = {"num_samples": num_samples, "mapping": mapping, "lamda": lamda} #This is done to go around some hydra stuff (<3 kai)
    name_dataset1, name_model1, size_of_the_latent1, seed1 = file1.strip(".pth").split("_")
    name_dataset2, name_model2, size_of_the_latent2, seed2 = file2.strip(".pth").split("_")

    model1 = load_model(model_name=name_model1, name_dataset=name_dataset1, latent_size=size_of_the_latent1, seed=seed1, model_path = folder1 + file1)
    model2 = load_model(model_name=name_model2, name_dataset=name_dataset2, latent_size=size_of_the_latent2, seed=seed2, model_path = folder1 + file2)
    pbar.set_description("Sampling equally per class")
    images_sampled_equally, labels_sampled_equally = sample_equally_per_class_images(num_samples, images, labels)
    pbar.set_description("Sampling removing outliers")
    images_sampled_max_distance, labels_sampled_drop_outliers = sample_removing_outliers(num_samples, images, labels, model2)
    pbar.set_description("Sampling worst classes")
    images_sampled_worst_classes, labels_sampled_worst_classes = sample_with_half_worst_classes_images(num_samples, images, labels, model2)
    pbar.set_description("Sampling convex hull")
    images_sampled_best_classes, labels_sampled_convex_hull = sample_convex_hulls_images(num_samples, images, labels, model1)
    pbar.set_description("Processing %s and %s" % (file1, file2))
    df_save_mappings = calculate_and_save_mapping(model1, model2, "equally", images_sampled_equally, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    df_save_mappings = calculate_and_save_mapping(model1, model2, "outliers", images_sampled_max_distance, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    df_save_mappings = calculate_and_save_mapping(model1, model2, "worst_classes", images_sampled_worst_classes, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    df_save_mappings = calculate_and_save_mapping(model1, model2, "convex_hull", images_sampled_best_classes, parameters, file1, file2, df_save_mappings, num_samples, lamda, DEVICE)
    pbar.set_description("Processed %s and %s" % (file1, file2))
df_save_mappings.to_csv("results/transformations/mapping_files/transfomations_index.csv", index=False)


  return torch.tensor(train_data), torch.tensor(train_labels)
Sampling worst classes:   0%|          | 3/2160 [06:23<76:39:17, 127.94s/it]                                    


KeyboardInterrupt: 