In [1]:
from pathlib import Path
import os
import itertools
import sys
# Add the parent directory to sys.path so the project-local packages imported below resolve
sys.path.append('../')
import torch
import numpy as np
from tqdm import tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from models.definitions.PCKTAE import PocketAutoencoder
from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderFashionMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100

from utils.sampler import *
from optimization.fit_mapping import create_mapping
from utils.model import load_model, get_transformations

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

class Config:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **entries):
        # Write the keyword arguments straight into the instance namespace,
        # so that ``Config(a=1).a == 1``.
        vars(self).update(entries)

In [2]:
def clear_memory():
    """Clear GPU memory by emptying PyTorch's CUDA cache."""
    torch.cuda.empty_cache()

def define_dataloader(file, file2, use_test_set=False):
    """Load the full train or test split of the dataset named in a checkpoint file name.

    Checkpoint names are expected to have four ``_``-separated fields followed by
    an extension: ``<dataset>_<model>_<latent size>_<seed>.pth``.

    Args:
        file: Checkpoint file name. Its dataset field selects the dataloader, its
            model field selects the transformations and its seed field seeds the
            dataloader.
        file2: Second checkpoint file name. Only its dataset field is read, to
            check that both checkpoints refer to the same dataset.
        use_test_set: If True the test split is returned, otherwise the train split.

    Returns:
        Tuple ``(images, labels, n_classes)`` where ``n_classes`` is the number of
        distinct values found in ``labels``.

    Raises:
        ValueError: If ``file`` does not have exactly four ``_``-separated fields,
            or if its dataset field is not one of mnist / fmnist / cifar10 / cifar100.
    """
    # Local import: the notebook's import cell does not bring `logging` into scope.
    import logging

    # `Path(...).stem` removes the extension as a suffix. `str.strip(".pth")` would
    # instead strip any of the characters '.', 'p', 't', 'h' from both ends.
    name_dataset, name_model, _latent_size, seed = Path(file).stem.split("_")
    other_dataset = Path(file2).stem.split("_")[0]

    # Compare the whole dataset field, not just its first character
    # (CIFAR10 and CIFAR100 share a first character but are different datasets).
    if name_dataset.lower() != other_dataset.lower():
        logging.error("The datasets are different")

    # Pick the dataloader class; fail loudly on an unknown dataset instead of
    # leaving the dataloader undefined.
    dataset_key = name_dataset.lower()
    if dataset_key == "mnist":
        loader_cls = DataLoaderMNIST
    elif dataset_key == "fmnist":
        loader_cls = DataLoaderFashionMNIST
    elif dataset_key == "cifar10":
        loader_cls = DataLoaderCIFAR10
    elif dataset_key == "cifar100":
        loader_cls = DataLoaderCIFAR100
    else:
        raise ValueError(
            f"Unsupported dataset {name_dataset!r} in checkpoint name {file!r}"
        )

    augmentation = get_transformations(name_model)
    dataloader = loader_cls(transformation=augmentation, batch_size=64, seed=int(seed))

    if use_test_set:
        full_dataset_images, full_dataset_labels = dataloader.get_full_test_dataset()
    else:
        full_dataset_images, full_dataset_labels = dataloader.get_full_train_dataset()

    n_classes = len(np.unique(full_dataset_labels.numpy()))
    return full_dataset_images, full_dataset_labels, n_classes


def calculate_and_save_mapping(model1, model2, sampling_strategy, sampled_images, parameters, file1, file2, num_samples, lamda, DEVICE, seed):
    """Fit a mapping between the latent spaces of two models and save the results.

    Both models encode ``sampled_images``; a mapping from the latents of ``model1``
    to the latents of ``model2`` is created from ``parameters``, fitted, and saved
    under ``test/uncertain/<model field of file2>/``.

    Args:
        model1: Source model; must provide ``get_latent_space``.
        model2: Target model; must provide ``get_latent_space``.
        sampling_strategy: Label of the sampling strategy; used only in the output file name.
        sampled_images: Images passed to ``get_latent_space`` of both models.
        parameters: Keyword arguments turned into a ``Config``; must contain ``mapping``.
        file1: Checkpoint file name of ``model1`` (``<dataset>_<model>_<latent>_<seed>.pth``).
        file2: Checkpoint file name of ``model2`` (same layout).
        num_samples: Used only in the output file name.
        lamda: Used only in the output file name.
        DEVICE: Device both models are moved to.
        seed: Used only in the output file name.

    Returns:
        The fitted mapping object returned by ``create_mapping``.

    Raises:
        ValueError: If either file name does not have exactly four ``_``-separated fields.
    """

    def _parse(checkpoint):
        # `Path(...).stem` drops the extension as a suffix. `str.strip(".pth")` would
        # instead strip any of the characters '.', 'p', 't', 'h' from both ends.
        stem = Path(checkpoint).stem
        # Unpacking validates the four-field layout; only the model field is needed.
        _dataset, model_name, _latent, _seed = stem.split("_")
        return stem, model_name

    stem1, _ = _parse(file1)
    stem2, name_model2 = _parse(file2)

    # Put both models in evaluation mode, in float32, on the target device.
    model1.to(torch.float32).to(DEVICE).eval()
    model2.to(torch.float32).to(DEVICE).eval()

    # Encode the sampled images. The latents are only used as plain arrays, so
    # there is no need to record an autograd graph.
    with torch.no_grad():
        latent_left = model1.get_latent_space(sampled_images)
        latent_right = model2.get_latent_space(sampled_images)
    latent_left = latent_left.to(torch.float32).cpu().detach().numpy()
    latent_right = latent_right.to(torch.float32).cpu().detach().numpy()

    # Create and fit the mapping.
    cfg = Config(**parameters)
    mapping = create_mapping(cfg, latent_left, latent_right, do_print=False)
    mapping.fit()

    # Save the results next to the other runs for the target model.
    storage_path = f'test/uncertain/{name_model2}/'
    Path(storage_path).mkdir(parents=True, exist_ok=True)
    filename = f"{stem1}>{stem2}>{cfg.mapping}_{num_samples}_{lamda}_{sampling_strategy}_{seed}"
    mapping.save_results(storage_path + filename)

    # Callers bind the result of this function, so hand back the fitted mapping.
    return mapping
  

In [8]:
# Refit the mappings for the model pairs listed in filtered_worst_models.csv.
#
# The CSV, checkpoint and output paths used below are relative to the project root.
# Set the LATENT_COMMUNICATION_ROOT environment variable to run from another
# checkout; the default is the original location.
PROJECT_ROOT = os.environ.get(
    "LATENT_COMMUNICATION_ROOT",
    '/Users/federicoferoggio/Documents/vs_code/latent-communication/',
)
os.chdir(PROJECT_ROOT)
print(os.getcwd())
shit_ones = pd.read_csv('filtered_worst_models.csv')

# The FMNIST test split is loaded once and shared by every row below.
images, labels, n_classes = define_dataloader('FMNIST_VAE_10_2.pth', 'FMNIST_VAE_10_1.pth', True)
images = images.type(torch.float32)
labels = labels.type(torch.float32)
folder1 = "models/checkpoints/PCKTAE/FMNIST"

# Placeholders that are not read in this cell; kept in case other cells rely on them.
(
    images_sampled_equally_old, labels_sampled_equally_old,
    images_sampled_drop_outliers_old, labels_sampled_drop_outliers_old,
    images_sampled_worst_classes_old, labels_sampled_worst_classes_old,
    images_sampled_best_classes_old, labels_sampled_convex_hull_old,
) = (None,) * 8
past_num_samples, past_file1, past_file2 = None, None, None

for index, row in shit_ones.iterrows():
    # Access specific columns from the current row
    dataset = row['dataset']
    file1 = row['model1'].split("/")[-1]
    file2 = row['model2'].split("/")[-1]
    # NOTE(review): the row's sampling strategy is only printed; all four
    # strategies are run for every row. Confirm this is intended.
    sampling_strategy = row['sampling_strategy']
    mapping = row['mapping']
    lamda = row['lambda']
    latent_dim = row['latent_dim']
    num_samples = row['num_samples']

    print(dataset, file1, file2, sampling_strategy, mapping, lamda, latent_dim, num_samples)

    # Everything up to the seed loop depends only on the row, so build it once
    # instead of reloading both checkpoints for each of the nine seeds.
    parameters = {"num_samples": num_samples, "mapping": mapping, "lamda": lamda}
    # `Path(...).stem` drops the extension as a suffix. `str.strip(".pth")` would
    # instead strip any of the characters '.', 'p', 't', 'h' from both ends.
    name_dataset1, name_model1, size_of_the_latent1, seed1 = Path(file1).stem.split("_")
    name_dataset2, name_model2, size_of_the_latent2, seed2 = Path(file2).stem.split("_")
    model1 = load_model(model_name=name_model1, name_dataset=name_dataset1, latent_size=int(size_of_the_latent1), seed=int(seed1), model_path=f'models/checkpoints/{name_model1}/FMNIST/' + file1)
    model2 = load_model(model_name=name_model2, name_dataset=name_dataset2, latent_size=int(size_of_the_latent2), seed=int(seed2), model_path=f'models/checkpoints/{name_model2}/FMNIST/' + file2)

    for seed in range(1, 10):
        # Tie the random draws to the seed written into the output file name.
        # Assumes the samplers draw from the global torch / numpy RNGs - TODO confirm.
        torch.manual_seed(seed)
        np.random.seed(seed)

        images_sampled_equally, labels_sampled_equally = sample_equally_per_class_images(num_samples, images, labels)
        mapping_equal = calculate_and_save_mapping(model1, model2, "equally", images_sampled_equally, parameters, file1, file2, num_samples, lamda, DEVICE, seed)

        images_sampled_drop_outliers, labels_sampled_drop_outliers = sample_removing_outliers(num_samples, images, labels, model2)
        mapping_outliers = calculate_and_save_mapping(model1, model2, "outliers", images_sampled_drop_outliers, parameters, file1, file2, num_samples, lamda, DEVICE, seed)

        images_sampled_worst_classes, labels_sampled_worst_classes = sample_with_half_worst_classes_images(num_samples, images, labels, model2)
        mapping_worst = calculate_and_save_mapping(model1, model2, "worst_classes", images_sampled_worst_classes, parameters, file1, file2, num_samples, lamda, DEVICE, seed)

        images_sampled_best_classes, labels_sampled_convex_hull = sample_convex_hulls_images(num_samples, images, labels, model1, pca_components=6)
        convex_hull = calculate_and_save_mapping(model1, model2, "convex_hull", images_sampled_best_classes, parameters, file1, file2, num_samples, lamda, DEVICE, seed)
        
            

/Users/federicoferoggio/Documents/vs_code/latent-communication
FMNIST FMNIST_VAE_8_1.pth FMNIST_VAE_32_2.pth equally Affine 0.0 32 10
FMNIST FMNIST_VAE_8_1.pth FMNIST_VAE_64_1.pth worst_classes Linear 0.0 64 10
FMNIST FMNIST_VAE_8_1.pth FMNIST_VAE_64_3.pth convex_hull Affine 0.0 64 10
FMNIST FMNIST_VAE_8_1.pth FMNIST_VAE_8_3.pth convex_hull Affine 0.0 8 10
FMNIST FMNIST_VAE_8_2.pth FMNIST_VAE_16_1.pth worst_classes Linear 0.0 16 10
FMNIST FMNIST_VAE_8_2.pth FMNIST_VAE_32_3.pth convex_hull Affine 0.0 32 10
FMNIST FMNIST_VAE_8_2.pth FMNIST_VAE_32_3.pth outliers Affine 0.0 32 10
FMNIST FMNIST_VAE_8_3.pth FMNIST_VAE_32_2.pth outliers Affine 0.0 32 10
FMNIST FMNIST_VAE_8_3.pth FMNIST_VAE_8_1.pth outliers Affine 0.0 8 10
FMNIST FMNIST_VAE_8_3.pth FMNIST_VAE_8_2.pth worst_classes Affine 0.0 8 10


In [None]:
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_1.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_32_2.pth,equally,Affine,0.0,32,10,28.874072,0.0834881586,0.0797489309,1651.56704,0.00204923724,0.00176378939,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_1.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_64_1.pth,worst_classes,Linear,0.0,64,10,4.65794181,0.0834104127,0.0796288721,37.9907646,0.0020393355429999998,0.0017536056600000003,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_1.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_64_3.pth,convex_hull,Affine,0.0,64,10,28.5921986,0.0834103712,0.0801316827,1649.70649,0.00204669661,0.0017506697840000002,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_1.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_3.pth,convex_hull,Affine,0.0,8,10,13.2719976,0.0834832263,0.0834708581,291.40680299999997,0.0020473502599999997,0.0020503054399999996,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_2.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_16_1.pth,worst_classes,Linear,0.0,16,10,4.3547192699999995,0.0832747205,0.07996789039999999,30.527883099999997,0.00203345351,0.00170410079,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_2.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_32_3.pth,convex_hull,Affine,0.0,32,10,17.71162885,0.0832973844,0.0809605525,565.966632,0.00203544781,0.00174444226,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_2.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_32_3.pth,outliers,Affine,0.0,32,10,37.9047367,0.0832461314,0.0810075825,3082.8855350000003,0.00203670734,0.0017430990930000002,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_3.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_32_2.pth,outliers,Affine,0.0,32,10,3.6231780000000002,0.0834485366,0.07971695009999999,20.9996735,0.00204700517,0.001763845875,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_3.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_1.pth,outliers,Affine,0.0,8,10,7.439845579999999,0.08340290889999999,0.0834156411,94.4867519,0.0020464351860000002,0.00203552303,4.5
FMNIST,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_3.pth,models/checkpoints/VAE/FMNIST/FMNIST_VAE_8_2.pth,worst_classes,Affine,0.0,8,10,8.47010411,0.0835132505,0.08327263630000001,135.4445805,0.0020704502900000003,0.00203577258,4.5
