In [10]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from sklearn.decomposition import PCA
from pathlib import Path
import os
import sys
import hydra
import hydra.core.global_hydra
import pandas as pd
import torch
import torch.nn as nn
from omegaconf import DictConfig
from tqdm import tqdm
from IPython.display import HTML

# Name of the directory that marks the repository root.
PROJECT_ROOT_DIR = "latent-communication"

current_dir = os.getcwd()

# Walk up from the working directory until a directory named PROJECT_ROOT_DIR
# is reached. os.path.dirname() of a filesystem root returns the root itself
# (never an empty string), so the walk has to stop once the parent no longer
# changes; otherwise a missing project root would loop forever instead of
# raising.
while os.path.basename(current_dir) != PROJECT_ROOT_DIR:
    parent_dir = os.path.dirname(current_dir)
    if parent_dir == current_dir:
        raise FileNotFoundError(f"Project root '{PROJECT_ROOT_DIR}' not found in the directory tree.")
    current_dir = parent_dir

# Make the project root the working directory so the relative paths used
# below resolve against it.
os.chdir(current_dir)
# Put the project root and its utils/ subdirectory at the front of sys.path.
sys.path.insert(0, current_dir)
sys.path.insert(0, os.path.join(current_dir, "utils"))

from utils.dataloaders.get_dataloaders import define_dataloader
from utils.get_mapping import load_mapping
from utils.model import load_model
import matplotlib as mpl

# Choose the compute device in order of preference: CUDA, then MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    # The CPU fallback is the plain string "cpu", not a torch.device object.
    DEVICE = "cpu"

# Clear the GlobalHydra singleton instance (presumably so Hydra can be
# initialised again when this cell is re-run in the same kernel).
hydra.core.global_hydra.GlobalHydra.instance().clear()

In [12]:
# NOTE(review): dict_losses is created empty and is not filled in this cell.
dict_losses = pd.DataFrame(columns=[])

print(os.getcwd())

# Directory holding the checkpoints, named <dataset>_<model>_<latent>_<seed>.pth.
CHECKPOINT_DIR = "models/checkpoints/VAE/FMNIST"
CHECKPOINT_SUFFIX = ".pth"

# The criterion has no per-checkpoint state, so build it once.
loss = nn.MSELoss()

# os.listdir() returns entries in arbitrary order; sort for reproducible output.
for checkpoint_name in sorted(os.listdir(CHECKPOINT_DIR)):
    file, extension = os.path.splitext(checkpoint_name)
    if extension != CHECKPOINT_SUFFIX:
        # Skip anything that is not a checkpoint (e.g. .DS_Store); such names
        # would otherwise break the four-way unpacking below.
        continue

    # The same filename fields describe both the data to load and the model.
    name_dataset, name_model, size_of_the_latent, seed = file.split("_")
    images, labels, n_classes = define_dataloader(name_dataset, name_model, seed=seed, use_test_set=False)

    filepath = f"{CHECKPOINT_DIR}/{file}{CHECKPOINT_SUFFIX}"
    model2 = load_model(
        model_name=name_model,
        name_dataset=name_dataset,
        latent_size=size_of_the_latent,
        seed=seed,
        model_path=filepath,
    ).to(DEVICE)
    model2.eval()

    # Move and cast the images once; they serve as both input and target.
    inputs = images.to(DEVICE).float()

    # Evaluation only: disable autograd so no graph is kept for the forward pass.
    with torch.no_grad():
        reconstructed = model2(inputs)
        loss_value = loss(reconstructed, inputs)
    print(file, loss_value.item())
    

/Users/federicoferoggio/Documents/vs_code/latent-communication
FMNIST_VAE_8_3 0.08275162428617477
FMNIST_VAE_8_2 0.08261602371931076
FMNIST_VAE_64_1 0.07893422991037369
FMNIST_VAE_64_3 0.07954147458076477
FMNIST_VAE_8_1 0.08291251957416534
FMNIST_VAE_64_2 0.07903110980987549
FMNIST_VAE_32_3 0.0804511308670044
FMNIST_VAE_16_2 0.08289296180009842
FMNIST_VAE_16_3 0.08091039210557938
FMNIST_VAE_32_2 0.07916295528411865
FMNIST_VAE_16_1 0.07920797914266586
FMNIST_VAE_32_1 0.08049789071083069
