In [6]:
import os
import torch
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from utils.dataloaders.dataloader_mnist_single import DataLoaderMNIST
from models.definitions.PocketAutoencoder import PocketAutoencoder
import pandas as pd
from tqdm import tqdm
import itertools

os.chdir('/Users/federicoferoggio/Documents/vs_code/latent-communication')

# Initialize DataFrame with additional columns
loss_dataset = pd.DataFrame(columns=['Dataset', 'Model', 'Seed', 'Epochs', 'Learning Rate', 'Batch Size', 'Loss'])

datasets_list = ['MNIST']
seeds = [1, 2, 3, 4]
paths = ['models/checkpoints/SMALLAE/MNIST/']
dataloader_l = [DataLoaderMNIST]
epochs = [5]
batch_sizes = [64, 128]
learning_rates = [0.01, 0.001]

combinations1 = [datasets_list, seeds, paths, dataloader_l, epochs, batch_sizes, learning_rates]
combinations1 = list(itertools.product(*combinations1))

datasets_list = ['MNIST']
seeds = [3, 4]
paths = ['models/checkpoints/SMALLAE/MNIST/']
dataloader_l = [DataLoaderMNIST]
epochs = [1, 10]
batch_sizes = [32, 64, 128]
learning_rates = [0.01, 0.001, 0.0001]

combinations2 = [datasets_list, seeds, paths, dataloader_l, epochs, batch_sizes, learning_rates]
combinations2 = list(itertools.product(*combinations2))

combinations = combinations1 + combinations2

DEVICE = torch.device("mps")
augmentations = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
print(len(combinations))

for dataset, seed, paths, dataloader, epochs, batch_size, learning_rate in tqdm(combinations):
    dataloader = dataloader(batch_size=batch_size, transformation=augmentations, seed=seed)
    test_loader = dataloader.get_test_loader()
    train_loader = dataloader.get_train_loader()
    config = {
        'model_name': 'SMLLAE',
        'dataset': dataset,
        'weight_var': 1,
        'weight_mean': 0,
        'seed': seed,
        'batch_size': batch_size,
        'num_epochs': epochs,
        'learning_rate': learning_rate,
        'path': paths
    }
    
    torch.manual_seed(config['seed'])
    model = PocketAutoencoder()
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
    
    for epoch in range(config['num_epochs']):
        overall_loss = 0
        model.train()  # Set the model to training mode
        
        for batch_idx, (x, _) in enumerate(train_loader):
            x = x.to(DEVICE)
            optimizer.zero_grad()
            loss = model.training_step(x)
            overall_loss += loss.item()
            loss.backward()
            optimizer.step()
        
        avg_loss = overall_loss / (len(train_loader) * batch_size)
        new_row = pd.DataFrame({'Dataset': [config['dataset']],
                                'Model': [config['model_name']],
                                'Seed': [config['seed']],
                                'Epochs': [epoch],
                                'Learning Rate': [config['learning_rate']],
                                'Batch Size': [config['batch_size']],
                                'Loss': [avg_loss]})
        loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)
        print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", avg_loss)
    
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation
        test_loss = 0
        for x_test, _ in test_loader:
            x_test = x_test.to(DEVICE)
            test_loss += model.validation_step(x_test).item()
        
        avg_test_loss = test_loss / (len(test_loader) * batch_size)
        scheduler.step(avg_test_loss)  # Update the learning rate based on the test loss
        new_row = pd.DataFrame({'Dataset': [config['dataset']],
                                'Model': [config['model_name']],
                                'Seed': [config['seed']],
                                'Epochs': ['Test'],
                                'Learning Rate': [config['learning_rate']],
                                'Batch Size': [config['batch_size']],
                                'Loss': [avg_test_loss]})
        loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)
    
    
    # Save the model
    name = f"{config['dataset']}_{config['model_name']}_{config['learning_rate']}_{config['batch_size']}_{config['num_epochs']}_{config['seed']}.pth"
    print(name)
    path = config['path'] + name
    torch.save(model.state_dict(), path)

loss_dataset.to_csv('models/checkpoints/SMALLAE/losses.csv', index=False)


52


  loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)


	Epoch 1 complete! 	Average Loss:  0.014295315334219922
	Epoch 2 complete! 	Average Loss:  0.013471051630601763
	Epoch 3 complete! 	Average Loss:  0.013367264171136912
	Epoch 4 complete! 	Average Loss:  0.013359573416546908
	Epoch 5 complete! 	Average Loss:  0.013356479577847254


  2%|▏         | 1/52 [01:30<1:16:56, 90.52s/it]

MNIST_SMLLAE_0.01_64_5_1.pth
	Epoch 1 complete! 	Average Loss:  0.014714868470176515
	Epoch 2 complete! 	Average Loss:  0.013862573777013687
	Epoch 3 complete! 	Average Loss:  0.013591269830436403
	Epoch 4 complete! 	Average Loss:  0.013448886207934382
	Epoch 5 complete! 	Average Loss:  0.01341128998350646


  4%|▍         | 2/52 [03:00<1:15:05, 90.11s/it]

MNIST_SMLLAE_0.001_64_5_1.pth
	Epoch 1 complete! 	Average Loss:  0.007206655884864552
	Epoch 2 complete! 	Average Loss:  0.006740905529162141
	Epoch 3 complete! 	Average Loss:  0.006678439953775484
	Epoch 4 complete! 	Average Loss:  0.006666744267865857
	Epoch 5 complete! 	Average Loss:  0.006663902720281565


  6%|▌         | 3/52 [04:02<1:03:14, 77.44s/it]

MNIST_SMLLAE_0.01_128_5_1.pth
	Epoch 1 complete! 	Average Loss:  0.0075174807868739054
	Epoch 2 complete! 	Average Loss:  0.007156756362998918
	Epoch 3 complete! 	Average Loss:  0.00695778059163518
	Epoch 4 complete! 	Average Loss:  0.006794732149396496
	Epoch 5 complete! 	Average Loss:  0.006727197299848424


  8%|▊         | 4/52 [05:01<56:10, 70.22s/it]  

MNIST_SMLLAE_0.001_128_5_1.pth
	Epoch 1 complete! 	Average Loss:  0.013923359738944817
	Epoch 2 complete! 	Average Loss:  0.013381644184869935
	Epoch 3 complete! 	Average Loss:  0.013359250480940601
	Epoch 4 complete! 	Average Loss:  0.013351790211808833
	Epoch 5 complete! 	Average Loss:  0.013350434403723554


 10%|▉         | 5/52 [06:32<1:00:40, 77.46s/it]

MNIST_SMLLAE_0.01_64_5_2.pth
	Epoch 1 complete! 	Average Loss:  0.014521742497743574
	Epoch 2 complete! 	Average Loss:  0.013553963543207788
	Epoch 3 complete! 	Average Loss:  0.013429796106732094
	Epoch 4 complete! 	Average Loss:  0.013376718257933157
	Epoch 5 complete! 	Average Loss:  0.01334023198434539


 12%|█▏        | 6/52 [08:02<1:02:49, 81.95s/it]

MNIST_SMLLAE_0.001_64_5_2.pth
	Epoch 1 complete! 	Average Loss:  0.007128109783132765
	Epoch 2 complete! 	Average Loss:  0.006743482846234526
	Epoch 3 complete! 	Average Loss:  0.0066839733472796896
	Epoch 4 complete! 	Average Loss:  0.006668673577045263
	Epoch 5 complete! 	Average Loss:  0.0066651310731988475


 13%|█▎        | 7/52 [09:01<55:52, 74.50s/it]  

MNIST_SMLLAE_0.01_128_5_2.pth
	Epoch 1 complete! 	Average Loss:  0.007528818107204143
	Epoch 2 complete! 	Average Loss:  0.006906906490339272
	Epoch 3 complete! 	Average Loss:  0.006798897343657927
	Epoch 4 complete! 	Average Loss:  0.006770161533557467
	Epoch 5 complete! 	Average Loss:  0.00673644801280074


 15%|█▌        | 8/52 [10:00<51:00, 69.57s/it]

MNIST_SMLLAE_0.001_128_5_2.pth
	Epoch 1 complete! 	Average Loss:  0.01404481735636494
	Epoch 2 complete! 	Average Loss:  0.013386585297924814
	Epoch 3 complete! 	Average Loss:  0.01336554453662559
	Epoch 4 complete! 	Average Loss:  0.013358125252637274
	Epoch 5 complete! 	Average Loss:  0.013355930065756033


 17%|█▋        | 9/52 [11:31<54:35, 76.17s/it]

MNIST_SMLLAE_0.01_64_5_3.pth
	Epoch 1 complete! 	Average Loss:  0.014717728616411624
	Epoch 2 complete! 	Average Loss:  0.013693687724453936
	Epoch 3 complete! 	Average Loss:  0.013446104550944653
	Epoch 4 complete! 	Average Loss:  0.01338917277018621
	Epoch 5 complete! 	Average Loss:  0.01336253014232304


 19%|█▉        | 10/52 [13:01<56:20, 80.49s/it]

MNIST_SMLLAE_0.001_64_5_3.pth
	Epoch 1 complete! 	Average Loss:  0.007266542625063454
	Epoch 2 complete! 	Average Loss:  0.00678092165431107
	Epoch 3 complete! 	Average Loss:  0.0066915021718803374
	Epoch 4 complete! 	Average Loss:  0.006670213722463038
	Epoch 5 complete! 	Average Loss:  0.006663954504993933


 21%|██        | 11/52 [14:01<50:33, 74.00s/it]

MNIST_SMLLAE_0.01_128_5_3.pth
	Epoch 1 complete! 	Average Loss:  0.007526669676091943
	Epoch 2 complete! 	Average Loss:  0.007129180543958696
	Epoch 3 complete! 	Average Loss:  0.006862399920916506
	Epoch 4 complete! 	Average Loss:  0.006765948293774303
	Epoch 5 complete! 	Average Loss:  0.006721909473668029


 23%|██▎       | 12/52 [15:00<46:16, 69.42s/it]

MNIST_SMLLAE_0.001_128_5_3.pth
	Epoch 1 complete! 	Average Loss:  0.01391412727217049
	Epoch 2 complete! 	Average Loss:  0.013436063997018566
	Epoch 3 complete! 	Average Loss:  0.013376562770749969
	Epoch 4 complete! 	Average Loss:  0.013366346217707785
	Epoch 5 complete! 	Average Loss:  0.013361999714401548


 25%|██▌       | 13/52 [16:30<49:10, 75.66s/it]

MNIST_SMLLAE_0.01_64_5_4.pth
	Epoch 1 complete! 	Average Loss:  0.014377051853199504
	Epoch 2 complete! 	Average Loss:  0.013512781122202938
	Epoch 3 complete! 	Average Loss:  0.013400144494776087
	Epoch 4 complete! 	Average Loss:  0.013349439117755654
	Epoch 5 complete! 	Average Loss:  0.013319728205969402


 27%|██▋       | 14/52 [18:00<50:45, 80.14s/it]

MNIST_SMLLAE_0.001_64_5_4.pth
	Epoch 1 complete! 	Average Loss:  0.007190725767115222
	Epoch 2 complete! 	Average Loss:  0.00675023056125876
	Epoch 3 complete! 	Average Loss:  0.0066801606357764845
	Epoch 4 complete! 	Average Loss:  0.00666625525500538
	Epoch 5 complete! 	Average Loss:  0.006662661204738086


 29%|██▉       | 15/52 [18:59<45:26, 73.69s/it]

MNIST_SMLLAE_0.01_128_5_4.pth
	Epoch 1 complete! 	Average Loss:  0.00748489843979319
	Epoch 2 complete! 	Average Loss:  0.006885871741928653
	Epoch 3 complete! 	Average Loss:  0.0067617408891540096
	Epoch 4 complete! 	Average Loss:  0.006719181147306713
	Epoch 5 complete! 	Average Loss:  0.006695649964905687


 31%|███       | 16/52 [19:58<41:33, 69.27s/it]

MNIST_SMLLAE_0.001_128_5_4.pth
	Epoch 1 complete! 	Average Loss:  0.027454067893822987


 33%|███▎      | 17/52 [20:29<33:48, 57.94s/it]

MNIST_SMLLAE_0.01_32_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.028649589363733928


 35%|███▍      | 18/52 [21:01<28:18, 49.95s/it]

MNIST_SMLLAE_0.001_32_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.03173647263844808


 37%|███▋      | 19/52 [21:33<24:27, 44.48s/it]

MNIST_SMLLAE_0.0001_32_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.01404481735636494


 38%|███▊      | 20/52 [21:52<19:42, 36.95s/it]

MNIST_SMLLAE_0.01_64_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.014717728616411624


 40%|████      | 21/52 [22:11<16:20, 31.62s/it]

MNIST_SMLLAE_0.001_64_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.01703710093605779


 42%|████▏     | 22/52 [22:31<14:01, 28.07s/it]

MNIST_SMLLAE_0.0001_64_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.007266542625063454


 44%|████▍     | 23/52 [22:44<11:22, 23.52s/it]

MNIST_SMLLAE_0.01_128_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.007526669676091943


 46%|████▌     | 24/52 [22:56<09:26, 20.22s/it]

MNIST_SMLLAE_0.001_128_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.00950174984126997


 48%|████▊     | 25/52 [23:09<08:03, 17.89s/it]

MNIST_SMLLAE_0.0001_128_1_3.pth
	Epoch 1 complete! 	Average Loss:  0.027454067893822987
	Epoch 2 complete! 	Average Loss:  0.026831529101729392
	Epoch 3 complete! 	Average Loss:  0.026811624216039977
	Epoch 4 complete! 	Average Loss:  0.026806506638725597
	Epoch 5 complete! 	Average Loss:  0.026803672764698663
	Epoch 6 complete! 	Average Loss:  0.026803912051518757
	Epoch 7 complete! 	Average Loss:  0.026797309777140618
	Epoch 8 complete! 	Average Loss:  0.026796994694073994
	Epoch 9 complete! 	Average Loss:  0.026795483962694804
	Epoch 10 complete! 	Average Loss:  0.026794185261925063


 50%|█████     | 26/52 [28:18<45:41, 105.44s/it]

MNIST_SMLLAE_0.01_32_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.028649589363733928
	Epoch 2 complete! 	Average Loss:  0.026906486244996388
	Epoch 3 complete! 	Average Loss:  0.026753301899631817
	Epoch 4 complete! 	Average Loss:  0.026641209758321443
	Epoch 5 complete! 	Average Loss:  0.026585692855715753
	Epoch 6 complete! 	Average Loss:  0.026567875307798385
	Epoch 7 complete! 	Average Loss:  0.026551405408978462
	Epoch 8 complete! 	Average Loss:  0.02654039367934068
	Epoch 9 complete! 	Average Loss:  0.02653295478026072
	Epoch 10 complete! 	Average Loss:  0.026528906763593357


 52%|█████▏    | 27/52 [33:32<1:09:57, 167.88s/it]

MNIST_SMLLAE_0.001_32_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.03173647263844808
	Epoch 2 complete! 	Average Loss:  0.02886401378015677
	Epoch 3 complete! 	Average Loss:  0.028587549089392025
	Epoch 4 complete! 	Average Loss:  0.027565429803729058
	Epoch 5 complete! 	Average Loss:  0.027165669505794843
	Epoch 6 complete! 	Average Loss:  0.02697430467903614
	Epoch 7 complete! 	Average Loss:  0.026871254283189774
	Epoch 8 complete! 	Average Loss:  0.026795387203494708
	Epoch 9 complete! 	Average Loss:  0.026741917686661086
	Epoch 10 complete! 	Average Loss:  0.026709992335240046


 54%|█████▍    | 28/52 [38:42<1:24:14, 210.62s/it]

MNIST_SMLLAE_0.0001_32_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.01404481735636494
	Epoch 2 complete! 	Average Loss:  0.013386585297924814
	Epoch 3 complete! 	Average Loss:  0.01336554453662559
	Epoch 4 complete! 	Average Loss:  0.013358125252637274
	Epoch 5 complete! 	Average Loss:  0.013355930065756033
	Epoch 6 complete! 	Average Loss:  0.01335246759607816
	Epoch 7 complete! 	Average Loss:  0.013351549766759183
	Epoch 8 complete! 	Average Loss:  0.013349086546631002
	Epoch 9 complete! 	Average Loss:  0.013348809674755532
	Epoch 10 complete! 	Average Loss:  0.013348076793390996


 56%|█████▌    | 29/52 [41:43<1:17:17, 201.62s/it]

MNIST_SMLLAE_0.01_64_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.014717728616411624
	Epoch 2 complete! 	Average Loss:  0.013693687724453936
	Epoch 3 complete! 	Average Loss:  0.013446104550944653
	Epoch 4 complete! 	Average Loss:  0.01338917277018621
	Epoch 5 complete! 	Average Loss:  0.01336253014232304
	Epoch 6 complete! 	Average Loss:  0.013338830075792667
	Epoch 7 complete! 	Average Loss:  0.013317811025826852
	Epoch 8 complete! 	Average Loss:  0.013303299055202429
	Epoch 9 complete! 	Average Loss:  0.013295027448225822
	Epoch 10 complete! 	Average Loss:  0.013289963825544251


 58%|█████▊    | 30/52 [44:44<1:11:38, 195.41s/it]

MNIST_SMLLAE_0.001_64_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.01703710093605779
	Epoch 2 complete! 	Average Loss:  0.01449135983529598
	Epoch 3 complete! 	Average Loss:  0.014445639359377531
	Epoch 4 complete! 	Average Loss:  0.014414341943556947
	Epoch 5 complete! 	Average Loss:  0.014290309015677365
	Epoch 6 complete! 	Average Loss:  0.01393809658326288
	Epoch 7 complete! 	Average Loss:  0.01374096487031213
	Epoch 8 complete! 	Average Loss:  0.013627873918315622
	Epoch 9 complete! 	Average Loss:  0.013548263607781007
	Epoch 10 complete! 	Average Loss:  0.01349251566113217


 60%|█████▉    | 31/52 [47:44<1:06:48, 190.88s/it]

MNIST_SMLLAE_0.0001_64_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.007266542625063454
	Epoch 2 complete! 	Average Loss:  0.00678092165431107
	Epoch 3 complete! 	Average Loss:  0.0066915021718803374
	Epoch 4 complete! 	Average Loss:  0.006670213722463038
	Epoch 5 complete! 	Average Loss:  0.006663954504993933
	Epoch 6 complete! 	Average Loss:  0.0066620229250952
	Epoch 7 complete! 	Average Loss:  0.006660168655891853
	Epoch 8 complete! 	Average Loss:  0.006659526169411282
	Epoch 9 complete! 	Average Loss:  0.006658213584423701
	Epoch 10 complete! 	Average Loss:  0.006658056228638076


 62%|██████▏   | 32/52 [49:42<56:17, 168.86s/it]  

MNIST_SMLLAE_0.01_128_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.007526669676091943
	Epoch 2 complete! 	Average Loss:  0.007129180543958696
	Epoch 3 complete! 	Average Loss:  0.006862399920916506
	Epoch 4 complete! 	Average Loss:  0.006765948293774303
	Epoch 5 complete! 	Average Loss:  0.006721909473668029
	Epoch 6 complete! 	Average Loss:  0.006699072000489179
	Epoch 7 complete! 	Average Loss:  0.006684755486473918
	Epoch 8 complete! 	Average Loss:  0.006674145271918222
	Epoch 9 complete! 	Average Loss:  0.006666190188521071
	Epoch 10 complete! 	Average Loss:  0.006659881065863727


 63%|██████▎   | 33/52 [51:39<48:34, 153.41s/it]

MNIST_SMLLAE_0.001_128_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.00950174984126997
	Epoch 2 complete! 	Average Loss:  0.007386878294659767
	Epoch 3 complete! 	Average Loss:  0.007251438246900911
	Epoch 4 complete! 	Average Loss:  0.00723170671564366
	Epoch 5 complete! 	Average Loss:  0.007222511460071307
	Epoch 6 complete! 	Average Loss:  0.007213867470018391
	Epoch 7 complete! 	Average Loss:  0.007201985078754583
	Epoch 8 complete! 	Average Loss:  0.0071807878910065456
	Epoch 9 complete! 	Average Loss:  0.0071292876669052825
	Epoch 10 complete! 	Average Loss:  0.007031703357900511


 65%|██████▌   | 34/52 [53:36<42:46, 142.58s/it]

MNIST_SMLLAE_0.0001_128_10_3.pth
	Epoch 1 complete! 	Average Loss:  0.027467248126864432


 67%|██████▋   | 35/52 [54:10<31:05, 109.75s/it]

MNIST_SMLLAE_0.01_32_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.027989606668551763


 69%|██████▉   | 36/52 [54:42<23:04, 86.51s/it] 

MNIST_SMLLAE_0.001_32_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.03149251182675362


 71%|███████   | 37/52 [55:14<17:31, 70.08s/it]

MNIST_SMLLAE_0.0001_32_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.01391412727217049


 73%|███████▎  | 38/52 [55:33<12:46, 54.75s/it]

MNIST_SMLLAE_0.01_64_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.014377051853199504


 75%|███████▌  | 39/52 [55:52<09:32, 44.05s/it]

MNIST_SMLLAE_0.001_64_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.01683956883482333


 77%|███████▋  | 40/52 [56:12<07:24, 37.04s/it]

MNIST_SMLLAE_0.0001_64_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.007190725767115222


 79%|███████▉  | 41/52 [56:25<05:26, 29.72s/it]

MNIST_SMLLAE_0.01_128_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.00748489843979319


 81%|████████  | 42/52 [56:38<04:06, 24.61s/it]

MNIST_SMLLAE_0.001_128_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.009374442009894705


 83%|████████▎ | 43/52 [56:50<03:09, 21.02s/it]

MNIST_SMLLAE_0.0001_128_1_4.pth
	Epoch 1 complete! 	Average Loss:  0.027467248126864432
	Epoch 2 complete! 	Average Loss:  0.02684157291750113
	Epoch 3 complete! 	Average Loss:  0.026813610881567
	Epoch 4 complete! 	Average Loss:  0.026807075793544452
	Epoch 5 complete! 	Average Loss:  0.026806024954716366
	Epoch 6 complete! 	Average Loss:  0.02680357968211174
	Epoch 7 complete! 	Average Loss:  0.02680332838992278
	Epoch 8 complete! 	Average Loss:  0.02680427881081899
	Epoch 9 complete! 	Average Loss:  0.026805119377374648
	Epoch 10 complete! 	Average Loss:  0.02680271965165933


 85%|████████▍ | 44/52 [1:02:01<14:22, 107.85s/it]

MNIST_SMLLAE_0.01_32_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.027989606668551763
	Epoch 2 complete! 	Average Loss:  0.026807100077470145
	Epoch 3 complete! 	Average Loss:  0.026670848566293717
	Epoch 4 complete! 	Average Loss:  0.026590073439478876
	Epoch 5 complete! 	Average Loss:  0.026566647137204805
	Epoch 6 complete! 	Average Loss:  0.026555790369709332
	Epoch 7 complete! 	Average Loss:  0.02654871385594209
	Epoch 8 complete! 	Average Loss:  0.026543216965595882
	Epoch 9 complete! 	Average Loss:  0.02653994532128175
	Epoch 10 complete! 	Average Loss:  0.026537012592951457


 87%|████████▋ | 45/52 [1:07:12<19:42, 168.88s/it]

MNIST_SMLLAE_0.001_32_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.03149251182675362
	Epoch 2 complete! 	Average Loss:  0.0282554923504591
	Epoch 3 complete! 	Average Loss:  0.02745947249432405
	Epoch 4 complete! 	Average Loss:  0.027162082609534264
	Epoch 5 complete! 	Average Loss:  0.027011254115899403
	Epoch 6 complete! 	Average Loss:  0.02691289305090904
	Epoch 7 complete! 	Average Loss:  0.026840431077281634
	Epoch 8 complete! 	Average Loss:  0.026787269205848376
	Epoch 9 complete! 	Average Loss:  0.026747838108738264
	Epoch 10 complete! 	Average Loss:  0.02671565698583921


 88%|████████▊ | 46/52 [1:12:21<21:05, 210.91s/it]

MNIST_SMLLAE_0.0001_32_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.01391412727217049
	Epoch 2 complete! 	Average Loss:  0.013436063997018566
	Epoch 3 complete! 	Average Loss:  0.013376562770749969
	Epoch 4 complete! 	Average Loss:  0.013366346217707785
	Epoch 5 complete! 	Average Loss:  0.013361999714401548
	Epoch 6 complete! 	Average Loss:  0.013359892700336128
	Epoch 7 complete! 	Average Loss:  0.013357244081485437
	Epoch 8 complete! 	Average Loss:  0.013357081209251812
	Epoch 9 complete! 	Average Loss:  0.013356173400090002
	Epoch 10 complete! 	Average Loss:  0.013355251994412908


 90%|█████████ | 47/52 [1:15:22<16:49, 201.98s/it]

MNIST_SMLLAE_0.01_64_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.014377051853199504
	Epoch 2 complete! 	Average Loss:  0.013512781122202938
	Epoch 3 complete! 	Average Loss:  0.013400144494776087
	Epoch 4 complete! 	Average Loss:  0.013349439117755654
	Epoch 5 complete! 	Average Loss:  0.013319728205969402
	Epoch 6 complete! 	Average Loss:  0.013296603543886435
	Epoch 7 complete! 	Average Loss:  0.01327786815644645
	Epoch 8 complete! 	Average Loss:  0.013271228816392007
	Epoch 9 complete! 	Average Loss:  0.013267569254829623
	Epoch 10 complete! 	Average Loss:  0.013264783328450692


 92%|█████████▏| 48/52 [1:18:23<13:02, 195.56s/it]

MNIST_SMLLAE_0.001_64_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.01683956883482333
	Epoch 2 complete! 	Average Loss:  0.014429507977656846
	Epoch 3 complete! 	Average Loss:  0.01421829554707066
	Epoch 4 complete! 	Average Loss:  0.013919128451559907
	Epoch 5 complete! 	Average Loss:  0.013722611110665396
	Epoch 6 complete! 	Average Loss:  0.013619092191809784
	Epoch 7 complete! 	Average Loss:  0.013552690168092055
	Epoch 8 complete! 	Average Loss:  0.013505488462341049
	Epoch 9 complete! 	Average Loss:  0.013471202542588337
	Epoch 10 complete! 	Average Loss:  0.013444786058313875


 94%|█████████▍| 49/52 [1:21:23<09:32, 190.95s/it]

MNIST_SMLLAE_0.0001_64_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.007190725767115222
	Epoch 2 complete! 	Average Loss:  0.00675023056125876
	Epoch 3 complete! 	Average Loss:  0.0066801606357764845
	Epoch 4 complete! 	Average Loss:  0.00666625525500538
	Epoch 5 complete! 	Average Loss:  0.006662661204738086
	Epoch 6 complete! 	Average Loss:  0.006661542431155501
	Epoch 7 complete! 	Average Loss:  0.006660503873438723
	Epoch 8 complete! 	Average Loss:  0.006660562769166315
	Epoch 9 complete! 	Average Loss:  0.006659372222186851
	Epoch 10 complete! 	Average Loss:  0.006659471866752166


 96%|█████████▌| 50/52 [1:23:20<05:37, 168.80s/it]

MNIST_SMLLAE_0.01_128_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.00748489843979319
	Epoch 2 complete! 	Average Loss:  0.006885871741928653
	Epoch 3 complete! 	Average Loss:  0.0067617408891540096
	Epoch 4 complete! 	Average Loss:  0.006719181147306713
	Epoch 5 complete! 	Average Loss:  0.006695649964905687
	Epoch 6 complete! 	Average Loss:  0.006681848885312772
	Epoch 7 complete! 	Average Loss:  0.006672086936832745
	Epoch 8 complete! 	Average Loss:  0.006663112913065755
	Epoch 9 complete! 	Average Loss:  0.006654698399703767
	Epoch 10 complete! 	Average Loss:  0.006649073192488347


 98%|█████████▊| 51/52 [1:25:17<02:33, 153.25s/it]

MNIST_SMLLAE_0.001_128_10_4.pth
	Epoch 1 complete! 	Average Loss:  0.009374442009894705
	Epoch 2 complete! 	Average Loss:  0.007340980528879649
	Epoch 3 complete! 	Average Loss:  0.0072259333586371915
	Epoch 4 complete! 	Average Loss:  0.007187414782316382
	Epoch 5 complete! 	Average Loss:  0.007134252504856665
	Epoch 6 complete! 	Average Loss:  0.0070492568751499216
	Epoch 7 complete! 	Average Loss:  0.006963831557234975
	Epoch 8 complete! 	Average Loss:  0.00689666945117909
	Epoch 9 complete! 	Average Loss:  0.006850970373836471
	Epoch 10 complete! 	Average Loss:  0.006820131329569354


100%|██████████| 52/52 [1:27:14<00:00, 100.67s/it]

MNIST_SMLLAE_0.0001_128_10_4.pth





In [8]:
import os
import torch
import torchvision.transforms as transforms
from torch.optim import Adam
from utils.dataloaders.dataloader_mnist_single import DataLoaderMNIST
from models.definitions.PocketAutoencoder import PocketAutoencoder
from models.definitions.PocketAutoencoder import PocketAutoencoder

import os

os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")

# Assuming you have already defined and initialized test_loader and DEVICE

losses_per_class = pd.DataFrame(columns=["Model", "Class", "Loss"])

DataLoaders = DataLoaderMNIST
batch_size = 64
augmentations = [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]

dataloader = DataLoaders(batch_size=batch_size, transformation=augmentations, seed=seed)

test_loader = dataloader.get_test_loader()
train_loader = dataloader.get_train_loader()

for file in os.listdir("models/checkpoints/SMALLAE/MNIST/"):
    if file.endswith(".pth"):  # Check if the file is a PyTorch model file
        # Load the model
        model = PocketAutoencoder()
        model.load_state_dict(torch.load("models/checkpoints/SMALLAE/MNIST/" + file))
        print("Model loaded:", file)

        # Iterate through each class (0 to 9)
        for n in range(10):
            desired_class = n  # Specify the class you want to filter
            filtered_samples = []

            # Filter samples from the test loader based on the desired class
            for data, label in test_loader:
                indices = torch.nonzero(label == desired_class, as_tuple=False)
                if indices.numel() > 0:
                    for idx in indices:
                        filtered_samples.append((data[idx], label[idx]))

            test_loss_filtered = 0  # Initialize test loss for the current class

            with torch.no_grad():  # Disable gradient calculation
                for x_test, _ in filtered_samples:
                    test_loss_filtered += model.validation_step(x_test).item()

            # Calculate average reconstruction loss for the current class
            if len(filtered_samples) > 0:
                avg_loss = test_loss_filtered / len(filtered_samples)
            else:
                avg_loss = (
                    0  # Handle the case when there are no samples for the current class
                )

            print("\tTest Loss for class", n, ":", avg_loss)
            # Concatenate the results to the DataFrame
            losses_per_class = pd.concat(
                [
                    losses_per_class,
                    pd.DataFrame({"Model": [file], "Class": [n], "Loss": [avg_loss]}),
                ],
                ignore_index=True,
            )

# Save the results to a CSV file
losses_per_class.to_csv("models/checkpoints/SMALLAE/losses_per_class.csv", index=False)

Model loaded: MNIST_SMLLAE_0.001_128_5_4.pth
	Test Loss for class 0 : 0.8365662784600745


  losses_per_class = pd.concat(


	Test Loss for class 1 : 1.0393805159871274
	Test Loss for class 2 : 0.8830900924034821
	Test Loss for class 3 : 0.9042004975351957
	Test Loss for class 4 : 0.9444078854655054
	Test Loss for class 5 : 0.9092364298522205
	Test Loss for class 6 : 0.9007483102334566
	Test Loss for class 7 : 0.9739232790957165
	Test Loss for class 8 : 0.9081481354314933
	Test Loss for class 9 : 0.9573835132143306
Model loaded: MNIST_SMLLAE_0.001_64_1_4.pth
	Test Loss for class 0 : 0.8528710160936628
	Test Loss for class 1 : 1.1349865086278201
	Test Loss for class 2 : 0.8843246402088986
	Test Loss for class 3 : 0.9126271761880063
	Test Loss for class 4 : 1.0431147588610405
	Test Loss for class 5 : 0.9216891945477559
	Test Loss for class 6 : 0.8874548055807086
	Test Loss for class 7 : 1.0649184425749203
	Test Loss for class 8 : 0.931528761707537
	Test Loss for class 9 : 1.0597748974978747
Model loaded: MNIST_SMLLAE_0.001_128_10_4.pth
	Test Loss for class 0 : 0.8180888764712275
	Test Loss for class 1 : 0.9218

In [12]:
import os
import torch
import pandas as pd
import torchvision.transforms as transforms
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from models.definitions.PocketAutoencoder import PocketAutoencoder

# Assuming you have already defined and initialized test_loader and DEVICE

os.chdir("/Users/federicoferoggio/Documents/vs_code/latent-communication")
losses_per_class = pd.DataFrame(columns=["Model", "Class", "MSE", "SSIM", "PSNR"])


# Function to calculate SSIM and PSNR
def calculate_ssim_psnr(original, reconstructed, data_range=1.0):
    original_np = original.cpu().numpy().squeeze()
    reconstructed_np = reconstructed.cpu().numpy().squeeze()
    ssim_value = ssim(original_np, reconstructed_np, data_range=data_range)
    psnr_value = psnr(original_np, reconstructed_np, data_range=data_range)
    return ssim_value, psnr_value


# Iterate through each model file in the specified directory
for file in os.listdir("models/checkpoints/SMALLAE/MNIST/"):
    if file.endswith(".pth"):  # Check if the file is a PyTorch model file
        # Load the model
        model = PocketAutoencoder()
        model.load_state_dict(
            torch.load("models/checkpoints/SMALLAE/MNIST/" + file, map_location=DEVICE),
            strict=False,
        )
        model.to(DEVICE)
        print("Model loaded:", file)

        # Iterate through each class (0 to 9)
        for n in range(10):
            desired_class = n  # Specify the class you want to filter
            filtered_samples = []

            # Filter samples from the test loader based on the desired class
            for data, label in test_loader:
                indices = torch.nonzero(label == desired_class, as_tuple=False)
                if indices.numel() > 0:
                    for idx in indices:
                        filtered_samples.append((data[idx], label[idx]))

            mse_loss_filtered = 0  # Initialize MSE loss for the current class
            ssim_loss_filtered = 0  # Initialize SSIM loss for the current class
            psnr_loss_filtered = 0  # Initialize PSNR loss for the current class

            with torch.no_grad():  # Disable gradient calculation
                for x_test, _ in filtered_samples:
                    x_test = x_test.to(DEVICE)
                    x_reconstructed = model(x_test)
                    mse_loss_filtered += torch.nn.functional.mse_loss(
                        x_reconstructed, x_test
                    ).item()
                    ssim_value, psnr_value = calculate_ssim_psnr(
                        x_test, x_reconstructed
                    )
                    ssim_loss_filtered += ssim_value
                    psnr_loss_filtered += psnr_value

            # Calculate average losses for the current class
            num_samples = len(filtered_samples)
            if num_samples > 0:
                avg_mse_loss = mse_loss_filtered / num_samples
                avg_ssim_loss = ssim_loss_filtered / num_samples
                avg_psnr_loss = psnr_loss_filtered / num_samples
            else:
                avg_mse_loss = (
                    0  # Handle the case when there are no samples for the current class
                )
                avg_ssim_loss = 0
                avg_psnr_loss = 0

            print(
                f"\tMetrics for class {n} - MSE: {avg_mse_loss}, SSIM: {avg_ssim_loss}, PSNR: {avg_psnr_loss}"
            )
            # Concatenate the results to the DataFrame
            losses_per_class = pd.concat(
                [
                    losses_per_class,
                    pd.DataFrame(
                        {
                            "Model": [file],
                            "Class": [n],
                            "MSE": [avg_mse_loss],
                            "SSIM": [avg_ssim_loss],
                            "PSNR": [avg_psnr_loss],
                        }
                    ),
                ],
                ignore_index=True,
            )

# Save the results to a CSV file
losses_per_class.to_csv("models/checkpoints/SMALLAE/MNIST/more_metrics.csv", index=False)

Model loaded: MNIST_SMLLAE_0.001_128_5_4.pth
	Metrics for class 0 - MSE: 0.8369951843607183, SSIM: -0.16097904340033933, PSNR: 0.7779342676748042


  losses_per_class = pd.concat(


	Metrics for class 1 - MSE: 1.0395673592710284, SSIM: -0.23371325305179277, PSNR: -0.16037558986774278
	Metrics for class 2 - MSE: 0.8835762621935948, SSIM: -0.1835876184230727, PSNR: 0.5506155197559055
	Metrics for class 3 - MSE: 0.90332203318577, SSIM: -0.22636077990095746, PSNR: 0.45570925101892745
	Metrics for class 4 - MSE: 0.9446686668939843, SSIM: -0.2258157755824825, PSNR: 0.2597230016664452
	Metrics for class 5 - MSE: 0.9084933037982393, SSIM: -0.25572115042537275, PSNR: 0.4295579999674039
	Metrics for class 6 - MSE: 0.9010694963324791, SSIM: -0.18473424473568056, PSNR: 0.46948514602135094
	Metrics for class 7 - MSE: 0.9728489723998749, SSIM: -0.26055974667325466, PSNR: 0.1288607795634047
	Metrics for class 8 - MSE: 0.9085603426124526, SSIM: -0.19359579711920039, PSNR: 0.42806539692316664
	Metrics for class 9 - MSE: 0.9579845519202906, SSIM: -0.24348689703740048, PSNR: 0.19584358435079216
Model loaded: MNIST_SMLLAE_0.001_64_1_4.pth
	Metrics for class 0 - MSE: 0.853505191815142