In [None]:
from pathlib import Path
import os

os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")

import itertools

import torch
import numpy as np
from tqdm import tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr

from models.definitions.PCKTAE import PocketAutoencoder
from utils.dataloaders.full_dataloaders import DataLoaderMNIST
from utils.visualization import (
    visualize_mapping_error,
    visualize_latent_space_pca,
    plot_latent_space,
    highlight_cluster,
)
from utils.sampler import *

# Always a torch.device (never a bare string), so every consumer sees one type
# regardless of whether CUDA is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Transform pipeline handed to DataLoaderMNIST: convert to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
augmentations = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]


class Config:
    """Attribute bag: every keyword argument becomes an instance attribute."""

    def __init__(self, **entries):
        for key, value in entries.items():
            setattr(self, key, value)

In [None]:
def create_mapping(cfg, latents1, latents2):
    """Build the (not yet fitted) latent-space mapping selected by ``cfg.mapping``.

    Args:
        cfg: Object exposing ``mapping`` ("Linear", "Affine" or "NeuralNetwork")
            and ``lamda``. The "NeuralNetwork" option additionally reads
            ``hidden_size``, ``learning_rate`` and ``epochs``.
        latents1: Latents passed to the fitting class as its first argument.
        latents2: Latents passed to the fitting class as its second argument.

    Returns:
        The fitting object from ``optimization.optimizer``. It is only
        constructed here; the callers in this notebook call ``fit()`` on it.

    Raises:
        ValueError: If ``cfg.mapping`` is not one of the supported names.
    """
    # The optimizer imports stay lazy so only the selected class is loaded.
    if cfg.mapping == "Linear":
        from optimization.optimizer import LinearFitting

        return LinearFitting(latents1, latents2, lamda=cfg.lamda)

    if cfg.mapping == "Affine":
        from optimization.optimizer import AffineFitting

        return AffineFitting(latents1, latents2, lamda=cfg.lamda)

    if cfg.mapping == "NeuralNetwork":
        from optimization.optimizer import NeuralNetworkFitting

        return NeuralNetworkFitting(
            latents1,
            latents2,
            hidden_dim=cfg.hidden_size,
            lamda=cfg.lamda,
            learning_rate=cfg.learning_rate,
            epochs=cfg.epochs,
        )

    # Name the offending value: the sweep iterates over many mapping types and
    # a generic message makes typos hard to locate.
    raise ValueError(
        f"Invalid mapping {cfg.mapping!r}; expected 'Linear', 'Affine' or 'NeuralNetwork'"
    )

def calculate__MSE_ssim_psnr(original, reconstructed, data_range=1.0):
    """Compare one image with its reconstruction using MSE, SSIM and PSNR.

    Args:
        original: Torch tensor holding the reference image; singleton
            dimensions are squeezed away before comparison.
        reconstructed: NumPy array whose shape must equal the squeezed
            shape of ``original``.
        data_range: Value range forwarded to the SSIM and PSNR metrics.
            NOTE(review): the default of 1.0 presumes pixel values span a range
            of 1; images produced with the Normalize((0.5,), (0.5,)) transform
            declared in this notebook likely span about 2 -- confirm.

    Returns:
        Tuple ``(mse, ssim, psnr)``.

    Raises:
        ValueError: If the squeezed original and the reconstruction differ in shape.
    """
    original_np = original.detach().cpu().numpy().squeeze()

    # Validate before any arithmetic: otherwise NumPy either broadcasts the
    # mismatched arrays silently or fails with its own, less specific error.
    if original_np.shape != reconstructed.shape:
        raise ValueError(f"Shape mismatch: original shape {original_np.shape}, reconstructed shape {reconstructed.shape}")

    mse_value = np.mean((original_np - reconstructed) ** 2)
    ssim_value = ssim(original_np, reconstructed, data_range=data_range)
    psnr_value = psnr(original_np, reconstructed, data_range=data_range)
    return mse_value, ssim_value, psnr_value

def visualize_modified_latent_space_pca(
    latents_trans,
    latents_2,
    labels,
    fig_path=None,
    anchors=None,
    pca=None,
    size=10,
    bg_alpha=1,
    alpha=1,
    title="2D PCA of Latent Space",
):
    """
    Plots two sets of latents together in one shared, min-max normalized 2D PCA projection.

    Args:
        latents_trans: A list, array or tensor of shape (N, dim) representing the first set of latent points.
        latents_2: A list, array or tensor of shape (N, dim) representing the second set of latent points.
        labels: Array-like of shape (N,) with the label of each latent point; the same labels are used for both sets.
        fig_path: Optional; Path to save the figure.
        anchors: Optional; An array or tensor of shape (M, dim) representing anchor points in the latent space.
        pca: Optional; A fitted PCA object to use for transforming the latent space.
            When omitted, a 2-component PCA is fitted on both sets combined.
        size: Optional; Size of the points in the plot.
        bg_alpha: Optional; Alpha value for the background points.
        alpha: Optional; Alpha value for the highlighted points.
        title: Optional; Title of the figure.
    """

    def _to_cpu_tensor(values):
        # Lists keep the original torch.tensor conversion; arrays and tensors
        # are wrapped as-is. Detach and move to CPU so scikit-learn can consume them.
        tensor = torch.tensor(values) if isinstance(values, list) else torch.as_tensor(values)
        return tensor.detach().cpu()

    latents_trans = _to_cpu_tensor(latents_trans)
    latents_2 = _to_cpu_tensor(latents_2)
    labels = np.asarray(labels)

    # Project both sets with the same PCA so they share one coordinate system
    latents = torch.cat([latents_trans, latents_2], dim=0).numpy()

    if pca is None:
        pca = PCA(n_components=2)
        latents_2d = pca.fit_transform(latents)
    else:
        latents_2d = pca.transform(latents)

    # Min-max normalize: divide by the range (max - min), not by the raw
    # maximum, so that every axis ends up in [0, 1].
    minimum = latents_2d.min(axis=0)
    span = latents_2d.max(axis=0) - minimum
    span = np.where(span == 0, 1.0, span)  # degenerate axis: avoid division by zero
    latents_2d = (latents_2d - minimum) / span

    # Separate the two datasets
    n_trans = len(latents_trans)
    latents_trans_2d = latents_2d[:n_trans]
    latents_2_2d = latents_2d[n_trans:]

    # Create DataFrames for easy plotting
    latent_df_trans = pd.DataFrame(latents_trans_2d, columns=["x", "y"])
    latent_df_trans["target"] = labels

    latent_df_2 = pd.DataFrame(latents_2_2d, columns=["x", "y"])
    latent_df_2["target"] = labels

    # Plot the 2D latent space; the two sets differ only in their marker
    fig, ax = plt.subplots(figsize=(6, 6))
    cmap = plt.get_cmap("tab10")
    norm = plt.Normalize(
        latent_df_trans["target"].min(), latent_df_trans["target"].max()
    )
    targets = np.unique(labels)

    for frame, marker in ((latent_df_trans, "2"), (latent_df_2, "1")):
        ax = plot_latent_space(
            ax,
            frame,
            targets=targets,
            size=size,
            cmap=cmap,
            norm=norm,
            bg_alpha=bg_alpha,
            alpha=alpha,
            marker=marker,
        )

    if anchors is not None:
        # Anchors get the same projection and normalization as the latents
        # and are drawn with a star marker.
        anchors_2d = pca.transform(_to_cpu_tensor(anchors).numpy())
        anchors_2d = (anchors_2d - minimum) / span
        ax.scatter(anchors_2d[:, 0], anchors_2d[:, 1], marker="*", s=50, c="black")

    ax.set_title(title)
    if fig_path is not None:
        fig.savefig(fig_path)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures
    plt.close(fig)

# Function to clear GPU memory
def clear_memory():
    """Ask PyTorch to release the cached CUDA memory it is not currently using."""
    release_cached_blocks = torch.cuda.empty_cache
    release_cached_blocks()

def plot_and_calculate_all(model1, model2, images, labels, sampling_strategy, sampled_images, parameters, file1, file2, df, num_samples, lamda, losses_per_class):
    """Fit a latent mapping from model1 to model2, save diagnostic plots and record reconstruction metrics.

    The mapping is fitted on the latents of ``sampled_images`` and then applied
    to the latents of all ``images``. For each digit 0-9 the images are encoded
    with model1, mapped, decoded with model2 and compared with the originals.

    Args:
        model1: Source model; must provide ``eval``, ``get_latent_space`` and ``encode``.
        model2: Target model; must provide ``eval``, ``get_latent_space`` and ``decode``.
        images: Tensor with all images. Presumably shaped (N, 1, 28, 28): the
            per-pixel heatmap is reshaped to 28x28 and channel 0 is compared -- TODO confirm.
        labels: Labels aligned with ``images``; used to build boolean masks per digit.
        sampling_strategy: Name of the sampling strategy; used in the output path and file names.
        sampled_images: Images whose latents are used to fit the mapping.
        parameters: Dict turned into a ``Config``; must contain ``mapping``,
            ``lamda`` and ``num_samples`` (plus whatever ``create_mapping`` reads).
        file1: Checkpoint file name of model1; its last four characters (".pth") are dropped for the path.
        file2: Checkpoint file name of model2; handled like ``file1``.
        df: DataFrame that receives one row with ``model1``, ``model2`` and ``mapping``.
        num_samples: Used only to build the output path.
        lamda: Used only to build the output path.
        losses_per_class: DataFrame that receives one metrics row per digit plus a "Total" row.

    Returns:
        Tuple ``(df, losses_per_class)`` with the new rows appended.
    """
    model1.eval()
    model2.eval()

    # Latents of the full dataset (numpy) and of the sampled subset (tensors)
    latent_left = model1.get_latent_space(images).detach().cpu().numpy()
    latent_right = model2.get_latent_space(images).detach().cpu().numpy()
    latent_left_sampled_equally = torch.tensor(model1.get_latent_space(sampled_images).detach().cpu().numpy())
    latent_right_sampled_equally = torch.tensor(model2.get_latent_space(sampled_images).detach().cpu().numpy())

    # Create, fit and persist the mapping
    cfg = Config(**parameters)
    mapping = create_mapping(cfg, latent_left_sampled_equally, latent_right_sampled_equally)
    mapping.fit()
    storage_path = f'results/transformations/SMALLAE/{cfg.mapping}_{file1[:-len(".pth")]}_->_{file2[:-len(".pth")]}_{num_samples}_{lamda}/{sampling_strategy}/'
    Path(storage_path).mkdir(parents=True, exist_ok=True)
    mapping.save_results(storage_path + "mapping")
    df = pd.concat([df, pd.DataFrame({"model1": [file1], "model2": [file2], "mapping": [storage_path + "mapping"]})], ignore_index=True)
    transformed_latent_space = mapping.transform(latent_left)

    # Latent-space figures; the PCA returned for model2 is reused for the transformed latents
    _, latents1_2d = visualize_latent_space_pca(latents=latent_left, labels=labels, fig_path= storage_path + "latent_left_sampled_equally.png", anchors=latent_left_sampled_equally, title="Model 1 Equal Points", alpha=0.7, show_fig=False)
    pca, _ = visualize_latent_space_pca(latents=latent_right, labels=labels, fig_path=storage_path + "latent_right_sampled_equally.png", anchors=latent_right_sampled_equally, title="Model 2 Equal Points", alpha=0.7, show_fig=False)
    visualize_latent_space_pca(transformed_latent_space, labels, storage_path + "latents1_transformed_sampled_equally.png", anchors=mapping.transform(latent_left_sampled_equally), pca=pca, title=f"Transformation with an {cfg.mapping} Mapping, lambda={cfg.lamda}, num_samples={cfg.num_samples} Equal Points", alpha=0.7, show_fig=False)
    errors = np.linalg.norm(transformed_latent_space - latent_right, axis=1)
    visualize_mapping_error(latents1_2d, errors, storage_path + f"mapping_error_{cfg.mapping}_{sampling_strategy}_.png", show_fig=False)

    # Created once, outside the per-digit loop
    per_pixel_dir = Path(storage_path) / "mean_errors_for_class"
    per_pixel_dir.mkdir(parents=True, exist_ok=True)

    # Rows are collected in a list and appended in one concat at the end.
    records = []
    # Running sums over *all* evaluated images, used for the "Total" row.
    mse_sum, ssim_sum, psnr_sum, n_evaluated = 0.0, 0.0, 0.0, 0

    for i in range(10):
        images_sel = images[labels == i]
        sample_len = len(images_sel)
        if sample_len == 0:
            # Nothing to encode or plot for an absent digit; keep the zero row.
            records.append({"Transformation": storage_path, "Class": i, "MSE": 0, "SSIM": 0, "PSNR": 0})
            continue

        # One batched forward pass per digit; the per-image metrics below reuse
        # it instead of re-encoding and re-decoding every image individually.
        with torch.no_grad():
            latents1 = model1.encode(images_sel)
            recomposed = model2.decode(mapping.transform(latents1.float()).float())
        recomposed_np = recomposed.detach().cpu().numpy()
        images_np = images_sel.detach().cpu().numpy()

        # Mean absolute error per pixel: average over the batch, then over the
        # next (presumably channel) axis.
        errors_per_pixel = np.abs(images_np - recomposed_np).mean(axis=0).mean(axis=0)

        ## Plot matrix as heatmap
        fig, ax = plt.subplots()
        sns.heatmap(errors_per_pixel.reshape(28, 28), cmap='Reds', ax=ax)
        ax.set_title(f"Mean error per pixel for digit {i}")
        fig.savefig(per_pixel_dir / f"error_per_pixel_digit_{i}.png")
        plt.close(fig)

        # Local accumulators are named so they do not shadow the imported
        # `ssim` / `psnr` metric functions.
        class_mse, class_ssim, class_psnr = 0.0, 0.0, 0.0
        for x, x_reconstructed in zip(images_sel, recomposed_np):
            mse_temp, ssim_temp, psnr_temp = calculate__MSE_ssim_psnr(x.unsqueeze(0), x_reconstructed[0])
            class_mse += mse_temp
            class_ssim += ssim_temp
            class_psnr += psnr_temp

        mse_sum += class_mse
        ssim_sum += class_ssim
        psnr_sum += class_psnr
        n_evaluated += sample_len
        records.append({
            "Transformation": storage_path,
            "Class": i,
            "MSE": class_mse / sample_len,
            "SSIM": class_ssim / sample_len,
            "PSNR": class_psnr / sample_len,
        })

    # "Total" is the mean over every evaluated image. (Summing the ten class
    # means and dividing by the dataset size would not be an average of anything.)
    if n_evaluated > 0:
        total_mse_error = mse_sum / n_evaluated
        total_ssim_error = ssim_sum / n_evaluated
        total_psnr_error = psnr_sum / n_evaluated
    else:
        total_mse_error, total_ssim_error, total_psnr_error = 0, 0, 0
    records.append({"Transformation": storage_path, "Class": "Total", "MSE": total_mse_error, "SSIM": total_ssim_error, "PSNR": total_psnr_error})

    losses_per_class = pd.concat([losses_per_class, pd.DataFrame(records)], ignore_index=True)

    clear_memory()
    return df, losses_per_class

def _plot_class_error_kde(errors_by_class, title, fig_path):
    """Draw one KDE curve per class, save the figure to ``fig_path``, then show and close it."""
    fig, ax = plt.subplots(figsize=(12, 6))
    for class_id, class_errors in enumerate(errors_by_class):
        sns.kdeplot(class_errors, label=f'Class {class_id}', ax=ax)
    ax.set_title(title)
    ax.set_xlabel('Error')
    ax.set_ylabel('Density')
    ax.legend()
    # Save through the figure handle *before* showing: saving after plt.show()
    # can write an empty canvas because the shown figure may already be released.
    fig.savefig(fig_path)
    plt.show()
    # The sweep creates many figures; close each one to keep memory bounded.
    plt.close(fig)


def get_latent_variances(model1, model2, images, labels, sampling_strategy, sampled_images, parameters, file1, file2, df, num_samples, lamda, n_classes):
    """Fit a latent mapping from model1 to model2 and plot per-class decoding-error distributions.

    The mapping is fitted on the latents of ``sampled_images`` and applied to
    the latents of all ``images``. For every class, the mapped latents are
    decoded with model2 and compared (per-image L2 norm) with
    (a) the original images and (b) model2's decoding of its own latents.

    Args:
        model1: Source model; must provide ``eval`` and ``get_latent_space``.
        model2: Target model; must provide ``eval``, ``get_latent_space`` and ``decode``.
        images: Tensor with all images (its ``device`` is used for decoding).
        labels: Tensor of labels aligned with ``images``.
        sampling_strategy: Name of the sampling strategy; used in the output path.
        sampled_images: Images whose latents are used to fit the mapping.
        parameters: Dict turned into a ``Config`` for ``create_mapping``.
        file1: Checkpoint file name of model1; its last four characters (".pth") are dropped for the path.
        file2: Checkpoint file name of model2; handled like ``file1``.
        df: DataFrame that receives one row with ``model1``, ``model2`` and ``mapping``.
        num_samples: Used only to build the output path.
        lamda: Used only to build the output path.
        n_classes: Number of classes; labels ``0 .. n_classes - 1`` are evaluated.

    Returns:
        ``df`` with the new row appended.
    """
    model1.eval()
    model2.eval()
    # Get latent
    latent_left = model1.get_latent_space(images).detach().cpu().numpy()
    latent_right = model2.get_latent_space(images).detach().cpu().numpy()
    latent_left_sampled_equally = torch.tensor(model1.get_latent_space(sampled_images).detach().cpu().numpy())
    latent_right_sampled_equally = torch.tensor(model2.get_latent_space(sampled_images).detach().cpu().numpy())
    # Create mapping and persist it
    cfg = Config(**parameters)
    mapping = create_mapping(cfg, latent_left_sampled_equally, latent_right_sampled_equally)
    mapping.fit()
    storage_path = f'results/transformations/SMALLAE/{cfg.mapping}_{file1[:-len(".pth")]}_->_{file2[:-len(".pth")]}_{num_samples}_{lamda}/{sampling_strategy}/'
    Path(storage_path).mkdir(parents=True, exist_ok=True)
    mapping.save_results(storage_path + "mapping")
    df = pd.concat([df, pd.DataFrame({"model1": [file1], "model2": [file2], "mapping": [storage_path + "mapping"]})], ignore_index=True)
    transformed_latent_space = mapping.transform(latent_left)

    # Calculate errors per class
    errors_original = []
    errors_reconstructed = []
    for i in range(n_classes):
        class_indices = (labels == i).nonzero(as_tuple=True)[0]
        # A plain NumPy index array for the latent arrays: a one-element torch
        # index could otherwise be interpreted as a scalar index.
        latent_indices = class_indices.cpu().numpy()
        class_images = images[class_indices]
        class_transformed_latents = transformed_latent_space[latent_indices]
        class_latent_right = latent_right[latent_indices]

        # float32 cast mirrors the `.float()` applied before decoding in
        # plot_and_calculate_all; no_grad avoids building an autograd graph.
        with torch.no_grad():
            decoded_transformed = model2.decode(torch.as_tensor(class_transformed_latents, dtype=torch.float32).to(images.device)).detach().cpu().numpy()
            decoded_latent_right = model2.decode(torch.as_tensor(class_latent_right, dtype=torch.float32).to(images.device)).detach().cpu().numpy()

        # Per-image L2 norm over every non-batch axis. np.linalg.norm cannot be
        # used here: its `axis` only accepts an int or a 2-tuple and raises
        # ValueError for (1, 2, 3).
        sample_axes = tuple(range(1, decoded_transformed.ndim))
        diff_to_original = decoded_transformed - class_images.detach().cpu().numpy()
        diff_to_reconstructed = decoded_transformed - decoded_latent_right
        error_compared_to_original_images = np.sqrt((diff_to_original ** 2).sum(axis=sample_axes))
        error_compared_to_reconstructed_images = np.sqrt((diff_to_reconstructed ** 2).sum(axis=sample_axes))

        errors_original.append(error_compared_to_original_images)
        errors_reconstructed.append(error_compared_to_reconstructed_images)

    # Plotting: one file per comparison so the second plot no longer overwrites the first
    _plot_class_error_kde(
        errors_original,
        'Error Distribution Compared to Original Images by Class',
        storage_path + "error_distribution_original.png",
    )
    _plot_class_error_kde(
        errors_reconstructed,
        'Error Distribution Compared to Reconstructed Images by Class',
        storage_path + "error_distribution_reconstructed.png",
    )

    return df

In [None]:

# define pandas dataframe to store paths of models and mapping
df = pd.DataFrame(columns=["model1", "model2", "mapping"])
# NOTE: only plot_and_calculate_all appends to this frame. The loop below calls
# get_latent_variances instead, so the CSV written at the end holds just the header.
losses_per_class = pd.DataFrame(columns=["Transformation", "Class", "MSE", "SSIM", "PSNR"])


# Load data
data_loader = DataLoaderMNIST(64, transformation=augmentations)
images, labels = data_loader.get_full_train_dataset()
n_of_classes = len(np.unique(labels))
# Generate combinations of parameters
pbar = tqdm([10,50,100,200,300])
mapping_list = ["Linear", "Affine"]
lamda_list = [0,0.1, 0.01]
combinations = list(itertools.product(mapping_list, lamda_list))

# Change working directory
# NOTE(review): machine-specific absolute path (also set in the import cell);
# adjust it when running elsewhere.
os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")
operating_path = "models/checkpoints/SMALLAE/MNIST/"

# Only checkpoint files take part in the sweep. The directory also receives the
# CSV files written at the end of this cell, which must never be loaded as models.
checkpoint_files = sorted(f for f in os.listdir(operating_path) if f.endswith(".pth"))

# Loop through combinations
for num_samples in pbar:
    for mapping, lamda in combinations:
        parameters = {"num_samples": num_samples, "mapping": mapping, "lamda": lamda}
        # Every ordered pair of distinct checkpoints
        for file1, file2 in itertools.permutations(checkpoint_files, 2):
            model1 = PocketAutoencoder(path=file1)
            model2 = PocketAutoencoder(path=file2)
            # map_location="cpu" lets checkpoints saved on a GPU load on CPU-only machines
            model1.load_state_dict(torch.load(operating_path + file1, map_location="cpu"))
            model2.load_state_dict(torch.load(operating_path + file2, map_location="cpu"))

            pbar.set_description("Sampling equally per class")
            images_sampled_equally, _ = sample_equally_per_class_images(num_samples, images, labels)
            pbar.set_description("Sampling removing outliers")
            images_sampled_max_distance, _ = sample_removing_outliers(num_samples, images, labels, model2)
            pbar.set_description("Sampling worst classes")
            images_sampled_worst_classes, _ = sample_with_half_worst_classes_images(num_samples, images, labels, model2)
            pbar.set_description("Sampling convex hull")
            images_sampled_best_classes, _ = sample_convex_hulls_images(num_samples, images, labels, model1)
            pbar.set_description("Processing %s and %s" % (file1, file2))

            # Strategy name (used in the output path) -> images the mapping is fitted on
            sampled_by_strategy = {
                "equally": images_sampled_equally,
                "outliers": images_sampled_max_distance,
                "worst_classes": images_sampled_worst_classes,
                "convex_hull": images_sampled_best_classes,
            }
            for strategy_name, strategy_images in sampled_by_strategy.items():
                df = get_latent_variances(model1, model2, images, labels, strategy_name, strategy_images, parameters, file1, file2, df, num_samples, lamda, n_of_classes)
            pbar.set_description("Processed %s and %s" % (file1, file2))
df.to_csv(operating_path + "transformations.csv", index=False)
losses_per_class.to_csv(operating_path + "losses_per_class.csv", index=False)
