# In [None]:
import os
os.chdir('/Users/federicoferoggio/Documents/vs_code/latent-communication')

import torch
import numpy as np
import torch.nn as nn
import matplotlib.pyplot as plt
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from mpl_toolkits.axes_grid1 import ImageGrid
from torchvision.utils import save_image, make_grid
from torch.optim import Adam
from sklearn.manifold import  TSNE
from sklearn.decomposition import PCA
from utils.dataloader_mnist_single import DataLoaderMNIST
from utils.dataloader_fnist_single import DataLoaderFNIST

# Experiment grid.
#
# `datasets_list`, `paths` and `dataloader_l` are indexed in parallel by the
# outer training loop, so entry i of each list must describe the same dataset.
# `seeds` and `epochs` are likewise paired element-wise: run i uses seeds[i]
# and trains for epochs[i] epochs (a seed may appear twice with different
# epoch counts).
datasets_list = ['FMNIST', 'MNIST']
seeds = [1, 2, 3, 3, 4, 4]
paths = ['./models/checkpoints/AE/FMNIST/', './models/checkpoints/AE/MNIST/']
# Fixed ordering: previously this was [DataLoaderMNIST, DataLoaderFNIST], which
# paired the 'FMNIST' name/path with the MNIST loader and vice versa.
# NOTE(review): assumes DataLoaderFNIST serves Fashion-MNIST, as its module
# name (dataloader_fnist_single) suggests — confirm.
dataloader_l = [DataLoaderFNIST, DataLoaderMNIST]
epochs = [10, 10, 10, 20, 3, 20]

# Train one autoencoder per (dataset, seed/epoch pair), evaluate it on the test
# split and save its weights under the dataset's checkpoint directory.

# Model is defined in models/definitions/ae.py
from models.definitions.ae import LightningAutoencoder

# Use CUDA when available, otherwise fall back to the CPU. Selected once,
# since it does not change between runs.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Print the current working directory (relative checkpoint paths depend on it)
print(os.getcwd())

for n, dataset_name in enumerate(datasets_list):
    # seeds and epochs are paired element-wise: each seed has its own epoch budget.
    for seed, num_epochs in zip(seeds, epochs):
        config = {
            'model_name': 'AE',
            'dataset': dataset_name,
            'dataloaders': dataloader_l[n],
            # Variance and Mean for the weight initialization
            'weight_var': 1,
            'weight_mean': 0,
            'seed': seed,
            # Model setup
            'input_dim': 784,
            'dims': [256, 128, 64, 32],
            'distribution_dim': 16,
            # Training setup
            'batch_size': 128,
            'num_epochs': num_epochs,
            'learning_rate': 0.0002,
            'path': paths[n]
        }

        # Set the seed before building the data pipeline and the model so the
        # run is reproducible for a given seed.
        torch.manual_seed(config['seed'])

        # Data Transformations
        augmentations = [transforms.ToTensor(),
                         transforms.RandomRotation(10),
                         transforms.RandomHorizontalFlip(),
                         transforms.Normalize((0.5,), (0.5,))
                         ]

        # Create DataLoader
        batch_size = config['batch_size']
        DataLoaders = config['dataloaders']
        dataloader = DataLoaders(batch_size=batch_size, transformation=augmentations)

        test_loader = dataloader.get_test_loader()
        train_loader = dataloader.get_train_loader()

        print(len(train_loader), len(test_loader))
        plt.show()

        # Move the model to the same device as the batches; without this the
        # forward pass fails with a device mismatch whenever CUDA is selected.
        model = LightningAutoencoder().to(DEVICE)

        optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-4)

        for epoch in range(config['num_epochs']):
            overall_loss = 0
            model.train()  # set the model to training mode
            for x, _ in train_loader:
                x = x.to(DEVICE)

                optimizer.zero_grad()
                loss = model.training_step(x)

                overall_loss += loss.item()

                loss.backward()
                optimizer.step()

            print("\tEpoch", epoch + 1, "complete!", "\tAverage Loss: ", overall_loss / (len(train_loader)*batch_size))

        model.eval()  # set the model to evaluation mode
        with torch.no_grad():  # disable gradient calculation
            test_loss = 0
            for x_test, _ in test_loader:
                x_test = x_test.to(DEVICE)
                test_loss += model.validation_step(x_test).item()
            # Parenthesised so the test loss uses the same normalisation as the
            # training loss above; the unparenthesised form divided by the
            # number of batches and then multiplied by batch_size.
            print("\tTest Loss: ", test_loss / (len(test_loader)*batch_size))

        print("Finish!!")

        # Save the model as <dataset>_<model_name>_<seed>_<num_epochs>.pth
        name = config['dataset'] + '_' + config['model_name'] + '_' + str(config['seed']) + '_' + str(config['num_epochs']) + '.pth'
        print(name)

        # Create the checkpoint directory if needed so torch.save does not
        # fail on a fresh checkout.
        os.makedirs(config['path'], exist_ok=True)
        path = os.path.join(config['path'], name)

        torch.save(model.state_dict(), path)