# Train Simple Autoencoder

Pip installs for Colab

In [None]:
!pip install torch
!pip install lightning
!pip install matplotlib
!pip install tqdm

In [None]:
import torch
from torch.utils.data import DataLoader
import lightning as L
from lightning.pytorch.callbacks import TQDMProgressBar

from datasets import CustomMNIST
from autoencoder import AutoencoderSimple
from utils import plot_mnist_samples, plot_reconstruction_comparison

## Load the data and construct the dataloaders

In [None]:
train_dataset = CustomMNIST(root="./datasets", train=True)
test_dataset = CustomMNIST(root="./datasets", train=False)

# Define the sizes for train, validation, and test sets.
# The training set is split 80/20 into train/validation; VAL_SIZE is computed
# as the remainder so the two sizes always sum to len(train_dataset).
TRAIN_SIZE = int(0.8 * len(train_dataset))
VAL_SIZE = len(train_dataset) - TRAIN_SIZE
TEST_SIZE = len(test_dataset)
BATCH_SIZE = 100

# Seed for the train/validation split so that a fresh kernel run assigns the
# same samples to each subset.
SPLIT_SEED = 42

train_data, val_data = torch.utils.data.random_split(
    train_dataset,
    [TRAIN_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

Plot some sample images and their labels from the training dataset.


In [None]:
# Number of leading training samples to preview.
NUM_PREVIEW_SAMPLES = 8

# Fetch each sample once and reuse it for both the image and the label,
# instead of indexing the dataset twice per item.
preview_samples = [train_dataset[i] for i in range(NUM_PREVIEW_SAMPLES)]

# Element 0 of each sample is the image tensor, element 1 its label.
sample_images = [sample[0].squeeze().numpy() for sample in preview_samples]
sample_labels = [sample[1] for sample in preview_samples]
sample_captions = [f"Label: {label}" for label in sample_labels]
plot_mnist_samples(sample_images, sample_labels, sample_captions)

## Train the Autoencoder

In [None]:
autoencoder = AutoencoderSimple()

# Initialize a trainer.
# The Trainer in the `lightning` package does not accept the legacy `gpus=`
# and `progress_bar_refresh_rate=` keyword arguments: device selection is
# expressed with `accelerator`/`devices`, and the progress-bar refresh rate is
# configured through the TQDMProgressBar callback.
trainer = L.Trainer(
    max_epochs=100,
    # Use a single CUDA GPU when one is available, otherwise fall back to CPU.
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    devices=1,
    callbacks=[TQDMProgressBar(refresh_rate=20)],
)

# Train the model, validating on the held-out validation loader.
trainer.fit(autoencoder, train_loader, val_loader)

## Test Autoencoder

In [None]:
def test_autoencoder(autoencoder, test_loader, device="cuda" if torch.cuda.is_available() else "cpu"):
    """
    Tests the autoencoder on a batch from the test_loader and plots original and reconstructed images.
    :param autoencoder: The autoencoder model.
    :param test_loader: DataLoader for the test dataset.
    :param device: The device to run the model on.
    """
    autoencoder.eval()
    
    # Get a batch of test images
    images, _ = next(iter(test_loader))
    images = images.to(device)
    
    # Reconstruct images using the autoencoder
    with torch.no_grad():
        reconstructed_images = autoencoder(images)
    
    # Prepare images for display
    original_images = images.cpu()
    reconstructed_images = reconstructed_images.cpu()
    
    # Plot original and reconstructed images
    plot_reconstruction_comparison(original_images, reconstructed_images)