In [2]:
import os
import torch
import torchvision.transforms as transforms
from torch.optim import Adam
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from utils.dataloaders.full_dataloaders import DataLoaderMNIST, DataLoaderCIFAR10, DataLoaderCIFAR100, DataLoaderFashionMNIST
from models.definitions.PCKTAE import PocketAutoencoder
import pandas as pd
from tqdm import tqdm
import itertools
import sys
import logging

logging.getLogger('torchvision.datasets').setLevel(logging.ERROR)
os.chdir('../../')


# Initialize DataFrame with additional columns
loss_dataset = pd.DataFrame(columns=['Dataset', 'Model','Latent Size', 'Seed', 'Loss'])

# Define the lists
dataset = 'FMNIST'
size_input = 28
channels_input = 1
seeds = [1, 2, 3]
epochs = [20]
batch_sizes = [128]
learning_rates = [0.005]
latent_sizes = [10, 30, 50]

# Create combinations for MNIST and CIFAR datasets separately
combinations1 = [seeds, epochs, batch_sizes, learning_rates, latent_sizes[:3]]
combinations1 = list(itertools.product(*combinations1))

# Combine both sets of combinations
combinations = combinations1 

# Print the combinations and their count
print(combinations, len(combinations))

# Example of setting device and augmentations (adjust according to your actual use case)
DEVICE = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

for combo in tqdm(combinations):
    seed, epoch, batch_size, learning_rate, latent_size = combo
    print(f"Dataset: {dataset}, Channels input: {channels_input}, Epochs: {epoch}, Batch size: {batch_size}, Learning rate: {learning_rate}, Latent size: {latent_size}")


ModuleNotFoundError: No module named 'utils'

In [6]:
#Filter the datloader with specific labels
# Define a transform to normalize the data
transform = transforms.Compose([transforms.ToTensor()])

# Download and load the dataset
trainset = datasets.FashionMNIST(root=os.getcwd()+'/data/', download=True, train=True, transform=transform)
testset = datasets.FashionMNIST(root=os.getcwd()+'/data/', download=True, train=False, transform=transform)

# Specify the labels you want to keep
specific_labels = [0, 1, 2, 3,4]
labels_string = '_'.join(map(str, specific_labels))

def filter_by_label(dataset, labels):
    indices = [i for i, label in enumerate(dataset.targets) if label in labels]
    return Subset(dataset, indices)

# Filter the datasets
trainset_filtered = filter_by_label(trainset, specific_labels)
testset_filtered = filter_by_label(testset, specific_labels)

# Create DataLoaders
trainloader_filtered = DataLoader(trainset_filtered, batch_size=64, shuffle=True)
testloader_filtered = DataLoader(testset_filtered, batch_size=64, shuffle=True)
path_model = f'models/checkpoints/SMALLAE_not_all_classes/FMNIST/classes{labels_string}/'

In [4]:
labels_string

'0_1_2_3'

In [7]:
iterations_ = tqdm(combinations)
for seed, num_epochs, batch_size, learning_rate, latent_dim in iterations_:
    
    config = {
        'model_name': 'PCKTAE',
        'dataset': dataset,
        'weight_var': 1,
        'weight_mean': 0,
        'seed': seed,
        'batch_size': batch_size,
        'num_epochs': num_epochs,
        'learning_rate': learning_rate,
        'path': path_model
    }
    
    torch.manual_seed(config['seed'])
    model = PocketAutoencoder(hidden_dim=latent_dim, n_input_channels=channels_input, input_size=size_input)
    model.to(DEVICE)
    optimizer = Adam(model.parameters(), lr=config['learning_rate'], weight_decay=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)

    for epoch in range(config['num_epochs']):
        overall_loss = 0
        model.train()  # Set the model to training mode
        
        for batch_idx, (x, _) in enumerate(trainloader_filtered):
            x = x.to(DEVICE)
            optimizer.zero_grad()
            loss = model.training_step(x)
            overall_loss += loss.item()
            loss.backward()
            optimizer.step()
        
        avg_loss = overall_loss / (len(trainloader_filtered) * batch_size)
        new_row = pd.DataFrame({'Dataset': [config['dataset']],
                                'Model': [config['model_name']],
                                'Seed': [config['seed']],
                                'Latent Size': [latent_dim],
                                'Loss': [avg_loss]})
        loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)
        iterations_.set_description(f"Epoch: {epoch}, Loss: {avg_loss}")
    
    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():  # Disable gradient calculation
        test_loss = 0
        for x_test, _ in testloader_filtered:
            x_test = x_test.to(DEVICE)
            test_loss += model.validation_step(x_test).item()
        
        avg_test_loss = test_loss / (len(testloader_filtered) * batch_size)
        scheduler.step(avg_test_loss)  # Update the learning rate based on the test loss
        new_row = pd.DataFrame({'Dataset': [config['dataset']],
                                'Model': [config['model_name']],
                                'Seed': [config['seed']],
                                'Latent Size': [latent_dim],
                                'Loss': [avg_test_loss]})
        loss_dataset = pd.concat([loss_dataset, new_row], ignore_index=True)
        iterations_.set_description(f"Test Loss: {avg_test_loss}")    
    
    # Save the model
    name = f"{config['dataset']}_{config['model_name']}_{latent_dim}_{config['seed']}.pth"
    print(name)
    path = config['path'] + name
    torch.save(model.state_dict(), path)

loss_dataset.to_csv(f'models/checkpoints/SMALLAE_classes/lossesFMNIST.csv', index=False)

  0%|          | 0/9 [00:00<?, ?it/s]

Epoch: 6, Loss: 0.0004904168944156516:   0%|          | 0/9 [04:11<?, ?it/s]


KeyboardInterrupt: 