In [3]:
import os
import torch
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100
from models.definitions.smallae import PocketAutoencoder
import pandas as pd
from tqdm import tqdm
import itertools
import sys
import logging

logging.getLogger('torchvision.datasets').setLevel(logging.ERROR)
os.chdir('/Users/federicoferoggio/Documents/vs_code/latent-communication')


# Initialize DataFrame with additional columns
loss_dataset = pd.DataFrame(columns=['Dataset', 'Model','Latent Size', 'Seed', 'Loss'])

# Define the lists
datasets_list = ['MNIST', 'CIFAR10', 'CIFAR100']
seeds = [1, 2, 3]
paths = ['models/checkpoints/SMALLAE/MNIST/', 'models/checkpoints/SMALLAE/CIFAR10/', 'models/checkpoints/SMALLAE/CIFAR100/']
dataloader_l = [DataLoaderMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100]
channels_input = [1, 3, 3]
epochs = [20]
batch_sizes = [128]
learning_rates = [0.005]
latent_sizes = [10, 30, 50, 100, 500, 1000]

# Create a list of tuples with dataset, corresponding path, dataloader, and input channels
dataset_info = list(zip(datasets_list, paths, dataloader_l, channels_input))

# Split dataset_info based on the dataset type
mnist_info = [dataset_info[0]]
cifar_info = dataset_info[1:]

# Create combinations for MNIST and CIFAR datasets separately
combinations1 = [mnist_info, seeds, epochs, batch_sizes, learning_rates, latent_sizes[:3]]
combinations1 = list(itertools.product(*combinations1))

combinations2 = [cifar_info, seeds, epochs, batch_sizes, learning_rates, latent_sizes[3:]]
combinations2 = list(itertools.product(*combinations2))

# Combine both sets of combinations
combinations = combinations2 + combinations1

# Print the combinations and their count
print(combinations, len(combinations))

# Example of setting device and augmentations (adjust according to your actual use case)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

augmentations_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
augmentations_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

for combo in tqdm(combinations):
    (dataset, path, dataloader, channels_input), seed, epoch, batch_size, learning_rate, latent_size = combo
    print(f"Dataset: {dataset},{dataloader.__name__}, Channels input: {channels_input}, Epochs: {epoch}, Batch size: {batch_size}, Learning rate: {learning_rate}, Latent size: {latent_size}")


[(('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 1, 20, 128, 0.005, 100), (('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 1, 20, 128, 0.005, 500), (('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 1, 20, 128, 0.005, 1000), (('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 2, 20, 128, 0.005, 100), (('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 2, 20, 128, 0.005, 500), (('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 2, 20, 128, 0.005, 1000), (('CIFAR10', 'models/checkpoints/SMALLAE/CIFAR10/', <class 'utils.dataloaders.full_dataloaders.DataLoaderCIFAR10'>, 3), 3, 20, 

100%|██████████| 27/27 [00:00<00:00, 1276.66it/s]

Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 100
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 500
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 1000
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 100
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 500
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 1000
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 100
Dataset: CIFAR10,DataLoaderCIFAR10, Channels input: 3, Epochs: 20, Batch size: 128, Learning rate: 0.005, Latent size: 500
Dataset: CIFAR




In [None]:
iterations_ = tqdm(combinations)
for (dataset, path, dataloader, channels_input), seed, num_epochs, batch_size, learning_rate, latent_dim in iterations_:
    augmentations = [transforms.ToTensor()]
    if channels_input == 1:
        augmentations.append(transforms.Normalize((0.5,), (0.5,)))
    else:
        augmentations.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
    dataloader = dataloader(batch_size=batch_size, transformation=augmentations, seed=seed, shuffle_train_flag = True)
    test_loader = dataloader.get_test_loader()
    train_loader = dataloader.get_train_loader()
    config = {
        'model_name': 'PCKTAE',
        'dataset': dataset,
        'weight_var': 1,
        'weight_mean': 0,
        'seed': seed,
        'batch_size': batch_size,
        'num_epochs': num_epochs,
        'learning_rate': learning_rate,
        'path': path
    }
    
    torch.manual_seed(config['seed'])
    model = PocketAutoencoder(hidden_dim=latent_dim, n_input_channels=channels_input, input_size = dataloader.input_size)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(config['num_epochs']):
        overall_loss = 0
        model.train()  # Set the model to training mode
        
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.to(DEVICE)
            optimizer.zero_grad()
            loss = model.training_step(x)
            overall_loss += loss.item()
            loss.backward()
            optimizer.step()
        
        avg_loss = overall_loss / (len(train_loader) * batch_size)
        new_row = pd.DataFrame({'Dataset': [config['dataset']],
                                'Model': [config['model_name']],
                                'Seed': [config['seed']],
                                'Latent Size': [latent_dim],
                                'Loss': [avg_loss]})
        loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)
        iterations_.set_description(f"Epoch: {epoch}, Loss: {avg_loss}")
    
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation
        test_loss = 0
        for x_test, _ in test_loader:
            x_test = x_test.to(DEVICE)
            test_loss += model.validation_step(x_test).item()
        
        avg_test_loss = test_loss / (len(test_loader) * batch_size)
        scheduler.step(avg_test_loss)  # Update the learning rate based on the test loss
        new_row = pd.DataFrame({'Dataset': [config['dataset']],
                                'Model': [config['model_name']],
                                'Seed': [config['seed']],
                                'Latent Size': [latent_dim],
                                'Loss': [avg_test_loss]})
        loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)
        iterations_.set_description(f"Test Loss: {avg_test_loss}")    
    
    # Save the model
    name = f"{config['dataset']}_{config['model_name']}_{latent_dim}_{config['seed']}.pth"
    print(name)
    path = config['path'] + name
    torch.save(model.state_dict(), path)

loss_dataset.to_csv('models/checkpoints/SMALLAE/losses.csv', index=False)


In [8]:
import os
import torch
import pandas as pd
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from models.definitions.smallae import PocketAutoencoder
from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100  # Import your data loader classes

# Assuming you have already defined and initialized DEVICE
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")
print("Current working directory:", os.getcwd())

losses_per_class = pd.DataFrame(columns=["Model", "Class", "MSE", "SSIM", "PSNR"])

# Function to calculate SSIM and PSNR
def calculate_ssim_psnr(original, reconstructed, data_range=1.0):
    original_np = original.cpu().numpy().squeeze()
    reconstructed_np = reconstructed.cpu().numpy().squeeze()
    if original.shape[1] == 3:
        mean = 0
        for i in range(3):
            mean += ssim(original_np, reconstructed_np, data_range=data_range, channel_axis=1, win_size=3)
        ssim_value = mean / 3
    else:
        ssim_value = ssim(original_np, reconstructed_np, data_range=data_range)
    psnr_value = psnr(original_np, reconstructed_np, data_range=data_range)
    return ssim_value, psnr_value

# List of data loader classes
dataloader_l = [DataLoaderMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100]

# Map model file name patterns to the appropriate data loader index
dataset_name_to_loader_idx = {
    "mnist": 0,
    "cifar10": 1,
    "cifar100": 2
}

# Iterate through each model file in the specified directory and its subdirectories
for root, dirs, files in os.walk("models/checkpoints/SMALLAE/"):
    for file in files:
        if file.endswith(".pth"):  # Check if the file is a PyTorch model file
            augmentations = [transforms.ToTensor()]
            dataset_name = None
            for name in dataset_name_to_loader_idx.keys():
                if name in file.lower():
                    dataset_name = name
                    if dataset_name == "mnist":
                        augmentations.append(transforms.Normalize((0.5,), (0.5,)))
                    else:
                        augmentations.append(transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)))
                    break
            print(file)
            model = PocketAutoencoder(path=file)
            model.load_state_dict(torch.load(os.path.join(root, file), map_location=DEVICE), strict=True)
            model.to(DEVICE)

            # Get the test loader for the corresponding dataset
            loader_idx = dataset_name_to_loader_idx[dataset_name]
            dataloader_class = dataloader_l[loader_idx]
            dataloader = dataloader_class(batch_size=128, transformation=augmentations, seed=0, shuffle_train_flag=False)  # Adjust arguments as needed
            test_loader = dataloader.get_test_loader()
            print(f"Data loader initialized for {dataset_name}")

            # Iterate through each class (0 to 9 for MNIST and CIFAR-10)
            for n in range(10):
                desired_class = n  # Specify the class you want to filter
                filtered_samples = []

                # Filter samples from the test loader based on the desired class
                for data, label in test_loader:
                    indices = torch.nonzero(label == desired_class, as_tuple=False)
                    if indices.numel() > 0:
                        for idx in indices:
                            filtered_samples.append((data[idx], label[idx]))

                mse_loss_filtered = 0  # Initialize MSE loss for the current class
                ssim_loss_filtered = 0  # Initialize SSIM loss for the current class
                psnr_loss_filtered = 0  # Initialize PSNR loss for the current class

                with torch.no_grad():  # Disable gradient calculation
                    for x_test, _ in filtered_samples:
                        x_test = x_test.to(DEVICE)
                        x_reconstructed = model(x_test)
                        mse_loss_filtered += torch.nn.functional.mse_loss(x_reconstructed, x_test).item()
                        ssim_value, psnr_value = calculate_ssim_psnr(x_test , x_reconstructed)
                        ssim_loss_filtered += ssim_value
                        psnr_loss_filtered += psnr_value

                # Calculate average losses for the current class
                num_samples = len(filtered_samples)
                if num_samples > 0:
                    avg_mse_loss = mse_loss_filtered / num_samples
                    avg_ssim_loss = ssim_loss_filtered / num_samples
                    avg_psnr_loss = psnr_loss_filtered / num_samples
                else:
                    avg_mse_loss = 0  # Handle the case when there are no samples for the current class
                    avg_ssim_loss = 0
                    avg_psnr_loss = 0

                print(f"\tMetrics for class {n} - MSE: {avg_mse_loss}, SSIM: {avg_ssim_loss}, PSNR: {avg_psnr_loss}")
                # Concatenate the results to the DataFrame
                losses_per_class = pd.concat(
                    [
                        losses_per_class,
                        pd.DataFrame(
                            {
                                "Model": [file],
                                "Class": [n],
                                "MSE": [avg_mse_loss],
                                "SSIM": [avg_ssim_loss],
                                "PSNR": [avg_psnr_loss],
                            }
                        ),
                    ],
                    ignore_index=True,
                )

# Save the results to a CSV file
output_path = "models/checkpoints/SMALLAE/more_metrics.csv"
losses_per_class.to_csv(output_path, index=False)
print(f"Results saved to {output_path}")

Current working directory: /Users/federicoferoggio/Documents/vs_code/latent-communication
MNIST_PCKTAE_50_3.pth
50 1 28
441
Data loader initialized for mnist
	Metrics for class 0 - MSE: 0.8506158621335517, SSIM: -0.0998421540716091, PSNR: 0.7038977653376761


  losses_per_class = pd.concat(


	Metrics for class 1 - MSE: 0.9449309223548956, SSIM: -0.1526665754639414, PSNR: 0.2485790207299549
	Metrics for class 2 - MSE: 0.8668672287071398, SSIM: -0.08429962857573282, PSNR: 0.6223434523054625
	Metrics for class 3 - MSE: 0.869349455479348, SSIM: -0.10438738138859815, PSNR: 0.6102812839831442
	Metrics for class 4 - MSE: 0.8860240415375976, SSIM: -0.10455956665741786, PSNR: 0.5276416338848163
	Metrics for class 5 - MSE: 0.878580017341092, SSIM: -0.12794686548923415, PSNR: 0.5650688152711784
	Metrics for class 6 - MSE: 0.8715365356094902, SSIM: -0.09193764723525412, PSNR: 0.5992424845754689
	Metrics for class 7 - MSE: 0.8932938426153205, SSIM: -0.13185500163712022, PSNR: 0.49218817827723366
	Metrics for class 8 - MSE: 0.8652298976141325, SSIM: -0.06650443878665893, PSNR: 0.6301851641521072
	Metrics for class 9 - MSE: 0.8819483372099454, SSIM: -0.10765163252033638, PSNR: 0.5472905170636664
MNIST_PCKTAE_50_2.pth
50 1 28
441
Data loader initialized for mnist
	Metrics for class 0 - MS