In [3]:
import os
os.chdir('/Users/federicoferoggio/Documents/vs_code/latent-communication')

import torch
import torchvision.transforms as transforms
from torch.optim import Adam
from utils.dataloader_mnist_single import DataLoaderMNIST
from utils.dataloader_fnist_single import DataLoaderFNIST
from utils.dataloader_cifrar100_single import DataLoaderCIFAR100
from utils.dataloader_cifrar10_single import DataLoaderCIFAR10
from models.definitions.ae import LightningAutoencoder
from models.definitions.ae_more_channels import LightningAutoencoderV2



datasets_list = ['CIFAR10', 'CIFAR100']
seeds = [1, 2, 3, 3, 4 ,4]
paths = ['models/checkpoints/AE/CIFAR10/', 'models/checkpoints/AE/CIFAR100/']
dataloader_l = [DataLoaderCIFAR10, DataLoaderCIFAR100]
epochs = [5, 5, 5, 10, 1, 10]
DEVICE = torch.device("mps")
augmentations = [transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))]

for n, data in enumerate(datasets_list):
    dataloader_current = dataloader_l[n]
    batch_size = 128
    DataLoaders = dataloader_l[n]
    dataloader = DataLoaders(batch_size=batch_size, transformation= augmentations)

    test_loader = dataloader.get_test_loader()
    train_loader = dataloader.get_train_loader()
    for m, d in enumerate(seeds):
        config = {
            'model_name': 'AE',
            'dataset': data,
            # Variance and Mean for the weight initialization
            'weight_var': 1,
            'weight_mean': 0,
            'seed': d,
            # Model setup 
            'input_dim': 784,
            'dims': [256, 128, 64, 32],
            'distribution_dim': 16,
            # Training setup
            'batch_size': 128,
            'num_epochs': epochs[m],
            'learning_rate': 0.001,
            'path': paths[n]
        }
        # Set the seed
        torch.manual_seed(config['seed'])
        model = LightningAutoencoderV2()
        model.to(DEVICE)
        optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-4)
        for epoch in range(config['num_epochs']):
            overall_loss = 0
            model.train()  # set the model to training mode
            for batch_idx, (x, _) in enumerate(train_loader):
                x = x.to(DEVICE)

                optimizer.zero_grad()
                loss = model.training_step(x)
                
                overall_loss += loss.item()
                
                loss.backward()
                optimizer.step()
                
            print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / (len(train_loader)*batch_size))
                    
        model.eval()  # set the model to evaluation mode
        with torch.no_grad():  # disable gradient calculation
            test_loss = 0
            for x_test, _ in test_loader:  # assuming you have a separate test loader
                x_test = x_test.to(DEVICE)
                test_loss += model.validation_step(x_test).item()

            print("\tTest Loss: ", test_loss / ((len(test_loader)*batch_size)))
                    
        print("Finish!!")

        # Save the model
        name = str(config['dataset'])+ '_' + str(config['model_name']) + '_' + str(config['seed']) + '_' + str(config['num_epochs']) + '.pth'
        print(name)
        # Model Path
        path = config['path'] + name

        torch.save(model.state_dict(), path)

Files already downloaded and verified
Files already downloaded and verified


/Users/federicoferoggio/Documents/vs_code/latent-communication/.zeroshot/lib/python3.9/site-packages/lightning/pytorch/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`


	Epoch 1 complete! 	Average Loss:  0.001480917947407326
	Epoch 2 complete! 	Average Loss:  0.001276107942637847
	Epoch 3 complete! 	Average Loss:  0.0012483366930623875
	Epoch 4 complete! 	Average Loss:  0.0012326492937496099
	Epoch 5 complete! 	Average Loss:  0.001219926783374375
	Test Loss:  0.0011917626517647995
Finish!!
CIFAR10_AE_1_5.pth
	Epoch 1 complete! 	Average Loss:  0.0015148517191576798
	Epoch 2 complete! 	Average Loss:  0.0012783019489649198
	Epoch 3 complete! 	Average Loss:  0.00124848061961734
	Epoch 4 complete! 	Average Loss:  0.0012317393907009984
	Epoch 5 complete! 	Average Loss:  0.0012197884393955968
	Test Loss:  0.0011929268506650306
Finish!!
CIFAR10_AE_2_5.pth
	Epoch 1 complete! 	Average Loss:  0.0014651704015677123
	Epoch 2 complete! 	Average Loss:  0.0012725981942656667
	Epoch 3 complete! 	Average Loss:  0.001245795468063763
	Epoch 4 complete! 	Average Loss:  0.0012316243553920018
	Epoch 5 complete! 	Average Loss:  0.0012184836990450083
	Test Loss:  0.0011909879

/Users/federicoferoggio/Documents/vs_code/latent-communication/.zeroshot/lib/python3.9/site-packages/lightning/pytorch/core/module.py:436: You are trying to `self.log()` but the `self.trainer` reference is not registered on the model yet. This is most likely because the model hasn't been passed to the `Trainer`
	Epoch 1 complete! 	Average Loss:  0.007029166570239102
	Epoch 2 complete! 	Average Loss:  0.006620410752337751
	Epoch 3 complete! 	Average Loss:  0.006593725072152452
	Epoch 4 complete! 	Average Loss:  0.0065810882639306695
	Epoch 5 complete! 	Average Loss:  0.006572851465383509
	Test Loss:  0.006555338278177041
Finish!!
FMNIST_AE_1_5.pth
	Epoch 1 complete! 	Average Loss:  0.006961854703938846
	Epoch 2 complete! 	Average Loss:  0.00661639380616261
	Epoch 3 complete! 	Average Loss:  0.006594270342655147
	Epoch 4 complete! 	Average Loss:  0.006580616891590644
	Epoch 5 complete! 	Average Loss:  0.006572537388263354
	Test Loss:  0.006551660998146745
Finish!!
FMNIST_AE_2_5.pth
	Epoch 1 complete! 	Average Loss:  0.006978094329569004
	Epoch 2 complete! 	Average Loss:  0.006619723167206878
	Epoch 3 complete! 	Average Loss:  0.006596427908274474
	Epoch 4 complete! 	Average Loss:  0.0065816252438752635
	Epoch 5 complete! 	Average Loss:  0.006573408149055708
	Test Loss:  0.0065538919695853435
Finish!!
FMNIST_AE_3_5.pth
	Epoch 1 complete! 	Average Loss:  0.006978094329569004
	Epoch 2 complete! 	Average Loss:  0.006619723167206878
	Epoch 3 complete! 	Average Loss:  0.006596427908274474
	Epoch 4 complete! 	Average Loss:  0.0065816252438752635
	Epoch 5 complete! 	Average Loss:  0.006573408149055708
	Epoch 6 complete! 	Average Loss:  0.006567447325155171
	Epoch 7 complete! 	Average Loss:  0.006563040653247632
	Epoch 8 complete! 	Average Loss:  0.006559921372959863
	Epoch 9 complete! 	Average Loss:  0.006557290498882151
	Epoch 10 complete! 	Average Loss:  0.00655529145667675
	Test Loss:  0.006538598108565128
Finish!!
FMNIST_AE_3_10.pth
	Epoch 1 complete! 	Average Loss:  0.007069153271353385
	Test Loss:  0.006633507913049263
Finish!!
FMNIST_AE_4_1.pth
	Epoch 1 complete! 	Average Loss:  0.007069153271353385
	Epoch 2 complete! 	Average Loss:  0.006618808575673526
	Epoch 3 complete! 	Average Loss:  0.0065925988916760445
	Epoch 4 complete! 	Average Loss:  0.006580502830390165
	Epoch 5 complete! 	Average Loss:  0.0065728001135673475
	Epoch 6 complete! 	Average Loss:  0.00656704473366806
	Epoch 7 complete! 	Average Loss:  0.00656309546624769
	Epoch 8 complete! 	Average Loss:  0.006560077121270809
	Epoch 9 complete! 	Average Loss:  0.006557785113578412
	Epoch 10 complete! 	Average Loss:  0.006555454884923852
	Test Loss:  0.006539197125814006
Finish!!
FMNIST_AE_4_10.pth
	Epoch 1 complete! 	Average Loss:  0.0048762068156399194
	Epoch 2 complete! 	Average Loss:  0.0046089222785760595
	Epoch 3 complete! 	Average Loss:  0.004585484623599217
	Epoch 4 complete! 	Average Loss:  0.004572693035125669
	Epoch 5 complete! 	Average Loss:  0.004563915694176134
	Test Loss:  0.004547707960481131
Finish!!
MNIST_AE_1_5.pth
	Epoch 1 complete! 	Average Loss:  0.004806873547449422
	Epoch 2 complete! 	Average Loss:  0.004601690826862097
	Epoch 3 complete! 	Average Loss:  0.004580576711499106
	Epoch 4 complete! 	Average Loss:  0.004567994766914323
	Epoch 5 complete! 	Average Loss:  0.004558863717991152
	Test Loss:  0.00454367640652234
Finish!!
MNIST_AE_2_5.pth
	Epoch 1 complete! 	Average Loss:  0.0048170165369299045
	Epoch 2 complete! 	Average Loss:  0.004602856104021896
	Epoch 3 complete! 	Average Loss:  0.004580189375631781
	Epoch 4 complete! 	Average Loss:  0.004568471197984112
	Epoch 5 complete! 	Average Loss:  0.004559944934849101
	Test Loss:  0.004543465226181323
Finish!!
MNIST_AE_3_5.pth
	Epoch 1 complete! 	Average Loss:  0.0048170165369299045
	Epoch 2 complete! 	Average Loss:  0.004602856104021896
	Epoch 3 complete! 	Average Loss:  0.004580189375631781
	Epoch 4 complete! 	Average Loss:  0.004568471197984112
	Epoch 5 complete! 	Average Loss:  0.004559944934849101
	Epoch 6 complete! 	Average Loss:  0.004553412528100934
	Epoch 7 complete! 	Average Loss:  0.004548212255178484
	Epoch 8 complete! 	Average Loss:  0.004543802730723231
	Epoch 9 complete! 	Average Loss:  0.004539757607572241
	Epoch 10 complete! 	Average Loss:  0.0045367768778999855
	Test Loss:  0.004522914408788651
Finish!!
MNIST_AE_3_10.pth
	Epoch 1 complete! 	Average Loss:  0.0048799010744290565
	Test Loss:  0.00461422679120604
Finish!!
MNIST_AE_4_1.pth
	Epoch 1 complete! 	Average Loss:  0.0048799010744290565
	Epoch 2 complete! 	Average Loss:  0.00460798678292172
	Epoch 3 complete! 	Average Loss:  0.004584960452815109
	Epoch 4 complete! 	Average Loss:  0.0045720770963028805
	Epoch 5 complete! 	Average Loss:  0.004562147325083518
	Epoch 6 complete! 	Average Loss:  0.0045554388268813015
	Epoch 7 complete! 	Average Loss:  0.0045496448368104155
	Epoch 8 complete! 	Average Loss:  0.004545229927960363
	Epoch 9 complete! 	Average Loss:  0.004541579848413528
	Epoch 10 complete! 	Average Loss:  0.004538391515223393
	Test Loss:  0.0045248305106747756
Finish!!
MNIST_AE_4_10.pth