## Notebook for training autoencoders
We will use this notebook to train our autoencoders. Its aim is to set up the training procedure explained in the Datasets That Are Not paper.

In [None]:
import torch
import numpy as np
from datetime import datetime
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
import random

Log which device PyTorch will run on:

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Torch running on {device}")

### Import and prepare MNIST dataset
We will work with the MNIST dataset for experimentation and setup. Let's download it using the handy `torchvision.datasets.MNIST` class, then prepare our training and validation splits.

In [None]:
# convert data to torch.FloatTensor
transform = transforms.ToTensor()

# Load the MNIST training set; the validation split is carved out of it below
mnist_dataset = datasets.MNIST(
    root="~/.pytorch/MNIST_data/", train=True, download=True, transform=transform
)

# Define the split ratio for validation data
validation_split = 0.2
num_train = len(mnist_dataset)
indices = list(range(num_train))
split = int(validation_split * num_train)

# Shuffle the indices
random.seed(42)  # For reproducibility
random.shuffle(indices)

# Split the indices into training and validation sets
train_indices, val_indices = indices[split:], indices[:split]
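As a sanity check, the index split can be reproduced with the standard library alone. This is a minimal sketch assuming the 60,000-image MNIST training set and the same seed and ratio as above; `split_indices` is a hypothetical helper, not part of the repository. A local `random.Random(42)` produces the same shuffle as seeding the global generator, without touching global state.

```python
import random


def split_indices(num_items, validation_split, seed=42):
    """Shuffle indices deterministically and split them into train/validation lists."""
    indices = list(range(num_items))
    split = int(validation_split * num_items)
    rng = random.Random(seed)  # local RNG so the global random state is untouched
    rng.shuffle(indices)
    return indices[split:], indices[:split]


train_idx, val_idx = split_indices(60_000, 0.2)
print(len(train_idx), len(val_idx))  # 48000 12000

# The two splits are disjoint and together cover every index exactly once.
assert set(train_idx).isdisjoint(val_idx)
assert len(set(train_idx) | set(val_idx)) == 60_000
```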

### Prepare dataloaders
The dataloaders let us iterate over the training and validation subsets in batches.

In [None]:
# Create data loaders for training and validation
train_sampler = torch.utils.data.SubsetRandomSampler(train_indices)
val_sampler = torch.utils.data.SubsetRandomSampler(val_indices)

batch_size = 32
train_dataloader = torch.utils.data.DataLoader(
    mnist_dataset, batch_size=batch_size, sampler=train_sampler
)
validation_dataloader = torch.utils.data.DataLoader(
    mnist_dataset, batch_size=batch_size, sampler=val_sampler
)
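To see how a `SubsetRandomSampler` restricts a `DataLoader` to one split, here is a small sketch on a synthetic `TensorDataset` (fake images, not MNIST; the `demo_*` names are hypothetical). The loader only draws from the sampler's indices, so its length is `ceil(len(indices) / batch_size)`.

```python
import math

import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset

# Stand-in dataset: 100 fake 1x28x28 "images" with dummy labels.
demo_dataset = TensorDataset(torch.rand(100, 1, 28, 28), torch.zeros(100, dtype=torch.long))
demo_train_indices = list(range(20, 100))  # 80 samples form the "training" subset
demo_batch_size = 32

demo_loader = DataLoader(
    demo_dataset, batch_size=demo_batch_size, sampler=SubsetRandomSampler(demo_train_indices)
)

# ceil(80 / 32) = 3 batches; the last one is partial.
print(len(demo_loader), math.ceil(len(demo_train_indices) / demo_batch_size))  # 3 3
batch_sizes = [x.shape[0] for x, _ in demo_loader]
print(batch_sizes)  # [32, 32, 16]
```

With the notebook's split, 48,000 / 32 gives 1,500 training batches and 12,000 / 32 gives 375 validation batches, so no batch is partial. Note that `sampler` and `shuffle=True` are mutually exclusive in `DataLoader`; the sampler already randomizes the order.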

Visualize one image from a training batch:

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Obtain one batch of training images
images, labels = next(iter(train_dataloader))
images = images.numpy()

# Get one image from the batch and drop its channel dimension
img = np.squeeze(images[0])

fig = plt.figure(figsize=(2, 2))
ax = fig.add_subplot(111)
ax.imshow(img, cmap="gray")

Import the implemented models and the training loop.

In [None]:
from models.convautoencoders import (
    ConvAutoencoder,
    WTASpatialConvAutoencoder,
    WTALifetimeSparseConvAutoencoder,
    WTASpatialLifetimeSparseConvAutoencoder
)
from torch import nn
from train import train_for_n_epochs

## Train Vanilla Convolutional Autoencoder
Let's train our vanilla convolutional autoencoder with a basic training loop to verify that its outputs look reasonable.

Instantiate the model from our `models` package and move it to the device:

In [None]:
model = ConvAutoencoder()
model.to(device)
print(f"Model architecture:\n\n{model}")

Set up the basic training hyperparameters and the TensorBoard visualization interval:

In [None]:
LEARNING_RATE = 0.001
N_EPOCHS = 10
VISUALIZE_EVERY = 1

Configure the Adam optimizer and a mean squared error (MSE) loss:

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
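For an autoencoder the reconstruction target is the input itself, so the loss is presumably computed as `criterion(model(images), images)` inside `train_for_n_epochs` (an assumption; that loop lives in `train.py`). A small worked example of what `nn.MSELoss` computes with its default `reduction="mean"`:

```python
import torch
from torch import nn

criterion = nn.MSELoss()  # default reduction="mean": average over every element
reconstruction = torch.tensor([[0.0, 0.5], [1.0, 1.0]])
target = torch.tensor([[0.0, 1.0], [1.0, 0.0]])

# Squared errors are 0, 0.25, 0 and 1, so the mean is 1.25 / 4 = 0.3125
loss = criterion(reconstruction, target)
print(loss.item())  # 0.3125
```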

Train the model for `N_EPOCHS` epochs:

In [None]:
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    criterion,
    device,
)
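`train_for_n_epochs` is defined in `train.py` and is not shown here. As a rough picture of what one epoch of such a loop typically does, here is a minimal sketch with a hypothetical `run_epoch` helper, a tiny stand-in model and fake data; it is not the repository's implementation.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


def run_epoch(model, dataloader, criterion, device, optimizer=None):
    """One pass over `dataloader`: trains if an optimizer is given, else evaluates.

    Returns the mean per-batch reconstruction loss (hypothetical helper).
    """
    training = optimizer is not None
    model.train(training)
    total, n_batches = 0.0, 0
    with torch.set_grad_enabled(training):
        for images, _ in dataloader:  # labels are ignored: the target is the input
            images = images.to(device)
            loss = criterion(model(images), images)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total += loss.item()
            n_batches += 1
    return total / n_batches


torch.manual_seed(0)
demo_device = torch.device("cpu")
# Tiny stand-in autoencoder and fake data, only to exercise the loop.
demo_model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28 * 28, 32), nn.ReLU(),
    nn.Linear(32, 28 * 28), nn.Sigmoid(),
    nn.Unflatten(1, (1, 28, 28)),
).to(demo_device)
demo_data = TensorDataset(torch.rand(64, 1, 28, 28), torch.zeros(64, dtype=torch.long))
demo_loader = DataLoader(demo_data, batch_size=32, shuffle=True)
demo_criterion = nn.MSELoss()
demo_optimizer = torch.optim.Adam(demo_model.parameters(), lr=1e-3)

for epoch in range(3):
    train_loss = run_epoch(demo_model, demo_loader, demo_criterion, demo_device, demo_optimizer)
    val_loss = run_epoch(demo_model, demo_loader, demo_criterion, demo_device)  # same data, for brevity
    print(f"epoch {epoch}: train {train_loss:.4f}  val {val_loss:.4f}")
```

Passing no optimizer switches the helper to evaluation mode with gradients disabled, which is how a validation pass would normally run.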

## Train Winner Takes All Convolutional Autoencoder with Spatial Sparsity
Next, train the winner-takes-all (WTA) autoencoder with spatial sparsity, using the same training loop to verify its outputs.
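Spatial sparsity in WTA convolutional autoencoders is commonly described as keeping only the single largest activation in each feature map of each sample and zeroing the rest. The function below is a minimal sketch of that idea under that assumption; `spatial_wta` is a hypothetical name, and `WTASpatialConvAutoencoder` may implement it differently.

```python
import torch


def spatial_wta(feature_maps: torch.Tensor) -> torch.Tensor:
    """Keep only the largest activation in each (sample, channel) feature map.

    feature_maps: (N, C, H, W). All other spatial positions are set to zero.
    """
    n, c, h, w = feature_maps.shape
    flat = feature_maps.reshape(n, c, h * w)
    winners = flat.argmax(dim=2, keepdim=True)  # (N, C, 1) position of each map's max
    mask = torch.zeros_like(flat).scatter_(2, winners, 1.0)  # one-hot over positions
    return (flat * mask).reshape(n, c, h, w)  # gradients flow only through the winners


x = torch.tensor([[[[1.0, 5.0], [3.0, 2.0]]]])  # one sample, one channel, 2x2 map
print(spatial_wta(x))  # only the 5.0 survives; every other position becomes 0
```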

Instantiate the model from our `models` package and move it to the device:

In [None]:
wta_model = WTASpatialConvAutoencoder()
wta_model.to(device)
print(f"Model architecture:\n\n{wta_model}")

Set up the basic training hyperparameters and the TensorBoard visualization interval:

In [None]:
LEARNING_RATE = 0.001
N_EPOCHS = 10
VISUALIZE_EVERY = 1

Configure the Adam optimizer and a mean squared error (MSE) loss:

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(wta_model.parameters(), lr=LEARNING_RATE)

Train the model for `N_EPOCHS` epochs:
- TODO: This cell still fails because the tensor sizes inside the model don't match. This needs investigating; one hypothesis is that the max-pooling layers change the spatial dimensions somewhere inside the architecture.
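One way to test the max-pooling hypothesis is to push a 28x28 tensor through the pooling and upsampling stages and print the shapes. The layers below are illustrative stand-ins, not the actual architecture: three 2x2 poolings take 28 to 14, 7 and then 3 (7 / 2 is floored), so three 2x upsamplings only get back to 24.

```python
import torch
from torch import nn

# Hypothetical encoder/decoder stacks, only to show how pooling can break the round trip.
encoder = nn.Sequential(nn.MaxPool2d(2), nn.MaxPool2d(2), nn.MaxPool2d(2))
decoder = nn.Sequential(
    nn.Upsample(scale_factor=2), nn.Upsample(scale_factor=2), nn.Upsample(scale_factor=2)
)

x = torch.rand(1, 1, 28, 28)
h = x
for layer in encoder:
    h = layer(h)
    print("encoder", tuple(h.shape))  # 14x14, then 7x7, then 3x3
for layer in decoder:
    h = layer(h)
    print("decoder", tuple(h.shape))  # 6x6, 12x12, then 24x24

# 24 != 28: the reconstruction no longer matches the input, so MSELoss cannot compare them.
```

With `nn.MaxPool2d(2, ceil_mode=True)` the 7x7 map pools to 4x4 instead and the decoder returns 32x32, so the sizes would still need reconciling, for example by cropping or padding.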

In [None]:
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    wta_model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    criterion,
    device,
)

## Train Winner Takes All Convolutional Autoencoder with Lifetime Sparsity
Next, train the WTA autoencoder with lifetime sparsity, using the same training loop.
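Lifetime sparsity is commonly described as a constraint across the mini-batch: for each channel, only the `k_percentage` fraction of samples with the strongest response keep their activations. The sketch below assumes a sample's strength is its peak activation in that feature map and that the winner count is rounded to the nearest integer (at least one); `lifetime_wta` is a hypothetical name, and `WTALifetimeSparseConvAutoencoder` may rank or round differently.

```python
import torch


def lifetime_wta(feature_maps: torch.Tensor, k_percentage: float) -> torch.Tensor:
    """Per channel, keep only the k% of samples in the batch with the strongest response.

    feature_maps: (N, C, H, W). Weaker samples have that whole feature map zeroed.
    """
    n = feature_maps.shape[0]
    k = max(1, int(round(k_percentage * n)))  # winners per channel, at least one
    strength = feature_maps.amax(dim=(2, 3))  # (N, C) peak activation per map
    threshold = strength.topk(k, dim=0).values[-1]  # (C,) k-th largest per channel
    mask = (strength >= threshold).to(feature_maps.dtype)  # ties can admit extra samples
    return feature_maps * mask[:, :, None, None]


x = torch.arange(32, dtype=torch.float32).reshape(4, 2, 2, 2)  # batch of 4, 2 channels
out = lifetime_wta(x, k_percentage=0.5)  # 2 of the 4 samples survive per channel
print(out[:, 0].flatten(1))  # samples 0 and 1 are zeroed; samples 2 and 3 are unchanged
```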

Instantiate the model from our `models` package and move it to the device:

In [None]:
wta_lifetime_model = WTALifetimeSparseConvAutoencoder(k_percentage=0.9)
wta_lifetime_model.to(device)
print(f"Model architecture:\n\n{wta_lifetime_model}")

Set up the basic training hyperparameters and the TensorBoard visualization interval:

In [None]:
LEARNING_RATE = 0.001
N_EPOCHS = 30
VISUALIZE_EVERY = 1

Configure the Adam optimizer and a mean squared error (MSE) loss:

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(wta_lifetime_model.parameters(), lr=LEARNING_RATE)

Train the model for `N_EPOCHS` epochs:
- TODO: This cell still fails because the tensor sizes inside the model don't match. This needs investigating; one hypothesis is that the max-pooling layers change the spatial dimensions somewhere inside the architecture.
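To find where the sizes diverge in the real model, forward hooks can record every layer's output shape up to the point of failure. This is a minimal sketch; `trace_shapes` is a hypothetical helper and the model below is a stand-in with a deliberate mismatch. On the notebook's model it could be called as `trace_shapes(wta_lifetime_model, next(iter(train_dataloader))[0].to(device))`.

```python
import torch
from torch import nn


def trace_shapes(model: nn.Module, example: torch.Tensor):
    """Record (module name, output shape) for each leaf module reached in a forward pass.

    Returns the records plus the RuntimeError raised (or None), so the shapes leading
    up to a size mismatch are still available.
    """
    records, handles = [], []

    def make_hook(name):
        def hook(module, inputs, output):
            # Some modules (e.g. pooling with return_indices=True) return tuples.
            shape = tuple(output.shape) if torch.is_tensor(output) else type(output).__name__
            records.append((name, shape))
        return hook

    for name, module in model.named_modules():
        if not list(module.children()):  # leaf modules only
            handles.append(module.register_forward_hook(make_hook(name)))
    error = None
    try:
        with torch.no_grad():
            model(example)
    except RuntimeError as exc:  # e.g. a shape mismatch inside a layer
        error = exc
    finally:
        for handle in handles:
            handle.remove()  # detach hooks so later forward passes are unaffected
    return records, error


# Stand-in model with a deliberate bug: the Linear layer expects 7x7 maps but gets 14x14.
demo_model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(8 * 7 * 7, 10),
)
records, error = trace_shapes(demo_model, torch.rand(2, 1, 28, 28))
for name, shape in records:
    print(name, shape)
print("error:", error)
```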

In [None]:
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    wta_lifetime_model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    criterion,
    device,
)

## Train Winner Takes All Convolutional Autoencoder with Spatial Sparsity and Lifetime Sparsity
Finally, combine both mechanisms: train the WTA autoencoder with spatial and lifetime sparsity.

We'll do the following:
- Set up the basic training hyperparameters, including the lifetime sparsity level `K_PERCENTAGE`
- Instantiate the model from our `models` package and configure the Adam optimizer with an MSE loss
- Create a TensorBoard `SummaryWriter` and train the model
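To get a feel for what `K_PERCENTAGE` means per batch, the arithmetic below converts it into the number of samples per channel that keep their activations in a batch of 32. It assumes the fraction is rounded to the nearest integer with a minimum of one winner; `winners_per_channel` is a hypothetical helper, and the repository's models may floor instead (which would give 1 rather than 2 for `0.05`).

```python
def winners_per_channel(k_percentage: float, batch_size: int) -> int:
    """Samples per channel that survive lifetime sparsity (rounded, at least one)."""
    return max(1, int(round(k_percentage * batch_size)))


for k in (0.05, 0.1, 0.5, 0.9, 1.0):
    print(f"k% = {k:.2f} -> {winners_per_channel(k, 32)} of 32 samples per channel")
```

With `K_PERCENTAGE = 0.05`, roughly 2 of the 32 samples in each batch drive each channel, which is a strong sparsity constraint compared with `0.9` earlier.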

In [None]:
# Set up simple hyperparameters
LEARNING_RATE = 0.001
N_EPOCHS = 10
VISUALIZE_EVERY = 1
K_PERCENTAGE = 0.05

In [None]:
# Initialize model
wta_spatial_lifetime_model = WTASpatialLifetimeSparseConvAutoencoder(k_percentage=K_PERCENTAGE)
wta_spatial_lifetime_model.to(device)
print(f"Model architecture:\n\n{wta_spatial_lifetime_model}")
# Define loss function and optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(wta_spatial_lifetime_model.parameters(), lr=LEARNING_RATE)

Train the model for `N_EPOCHS` epochs, logging to TensorBoard:

In [None]:
# Set up a TensorBoard writer and start training
writer = SummaryWriter(
    f"logs/{wta_spatial_lifetime_model.name}_{K_PERCENTAGE}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"
)
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    wta_spatial_lifetime_model,
    train_dataloader,
    validation_dataloader,
    optimizer,
    criterion,
    device,
    writer,
)
# Flush pending events to disk and release the log file
writer.close()

Sweep the lifetime sparsity level over ten values of `k_percentage`, from `0.1` to `1.0` in steps of `0.1`:

In [None]:
# Round to 2 decimals so values like 0.30000000000000004 become 0.3
K_PERCENTAGES = np.round(np.linspace(0.1, 1.0, 10), 2)

In [None]:
for K_PERCENTAGE in K_PERCENTAGES:
    print(f"Training WTA autoencoder with spatial and lifetime sparsity, k% = {K_PERCENTAGE:.2f}.")
    # Initialize model
    wta_spatial_lifetime_model = WTASpatialLifetimeSparseConvAutoencoder(k_percentage=K_PERCENTAGE)
    wta_spatial_lifetime_model.to(device)
    print(f"Model architecture:\n\n{wta_spatial_lifetime_model}")
    # Define loss function and optimizer
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(wta_spatial_lifetime_model.parameters(), lr=LEARNING_RATE)
    # Set up a TensorBoard writer (k% formatted to 2 decimals for readable run names) and train
    writer = SummaryWriter(
        f"logs/{wta_spatial_lifetime_model.name}_{K_PERCENTAGE:.2f}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"
    )
    train_for_n_epochs(
        N_EPOCHS,
        VISUALIZE_EVERY,
        wta_spatial_lifetime_model,
        train_dataloader,
        validation_dataloader,
        optimizer,
        criterion,
        device,
        writer,
    )
    # Close each run's writer so its events are flushed before the next run starts
    writer.close()