## Winner Take All Convolutional Autoencoders
Let's walk through training and inspecting a winner-take-all (WTA) convolutional autoencoder on the MNIST dataset. We'll start by importing the necessary libraries and loading the dataset, and we'll monitor training in TensorBoard.

In [None]:
import torch
import numpy as np
from datetime import datetime
import torchvision
from torchvision import datasets
import torchvision.transforms as transforms
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from tools.dataset import split_dataset
import random
from tools.eval import load_model_from_checkpoint, visualize_filters

Define constants.

In [None]:
RANDOM_SEED = 42
VALIDATION_SPLIT = 0.05
BATCH_SIZE = 128

Select the GPU if one is available, and log which device is in use.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Torch running on {device}")

### Import and prepare MNIST dataset
We will use the MNIST dataset for experimentation and setup. Let's download the training and test sets with the handy `torchvision.datasets.MNIST` class. The train/validation split is prepared in the next step.

In [None]:
# convert data to torch.FloatTensor
transform = transforms.ToTensor()

# load the training and test datasets
dataset = datasets.MNIST(
    root="~/.pytorch/MNIST_data/", train=True, download=True, transform=transform
)
test_dataset = datasets.MNIST(
    root="~/.pytorch/MNIST_data/", train=False, download=True, transform=transform
)

### Prepare dataloaders
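Dataloaders let us iterate over the data in mini-batches during training. The MNIST training set is split into training and validation loaders with our `split_dataset` helper, and the test set gets its own loader.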

In [None]:
# Define dataloaders
train_loader, validation_loader = split_dataset(
    dataset,
    batch_size=BATCH_SIZE,
    validation_split=VALIDATION_SPLIT,
    random_seed=RANDOM_SEED,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=True
)
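`split_dataset` comes from our `tools.dataset` module, whose code isn't shown here. The hypothetical `split_dataset_sketch` below illustrates one common way such a helper can work, assuming it shuffles indices with a fixed seed, holds out a `validation_split` fraction for validation, and wraps each index subset in a `DataLoader` via `SubsetRandomSampler`. The real helper may differ.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler, TensorDataset


def split_dataset_sketch(dataset, batch_size, validation_split, random_seed):
    """Hypothetical stand-in for tools.dataset.split_dataset (its real code is not shown).

    Shuffles the indices with a fixed seed, holds out `validation_split` of them
    for validation, and returns one DataLoader per split.
    """
    indices = np.arange(len(dataset))
    rng = np.random.default_rng(random_seed)
    rng.shuffle(indices)
    n_val = int(np.floor(validation_split * len(dataset)))
    val_indices, train_indices = indices[:n_val].tolist(), indices[n_val:].tolist()

    # Each sampler draws (in random order) only from its own subset of indices
    train_loader = DataLoader(
        dataset, batch_size=batch_size, sampler=SubsetRandomSampler(train_indices)
    )
    validation_loader = DataLoader(
        dataset, batch_size=batch_size, sampler=SubsetRandomSampler(val_indices)
    )
    return train_loader, validation_loader


# Usage on a toy dataset of 1000 fake 1x28x28 images (distinct names so the
# notebook's real train_loader / validation_loader are not overwritten)
toy_dataset = TensorDataset(torch.rand(1000, 1, 28, 28))
toy_train_loader, toy_val_loader = split_dataset_sketch(
    toy_dataset, batch_size=128, validation_split=0.05, random_seed=42
)
print(len(toy_train_loader.sampler), len(toy_val_loader.sampler))  # 950 50
```

Under this sketch, `VALIDATION_SPLIT = 0.05` would hold out 3,000 of MNIST's 60,000 training images.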

Import `torch.nn` and our training loop.

In [None]:
from torch import nn
from tools.train import train_for_n_epochs
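`train_for_n_epochs` lives in our `tools.train` module, whose code isn't shown here. As a rough sketch of what one autoencoder training epoch involves, the hypothetical `train_one_epoch_sketch` below reconstructs each batch and minimizes the error against the input itself, ignoring the MNIST labels. The real loop presumably also runs validation, logs to TensorBoard and saves checkpoints (its `writer` and `checkpoint_dir` arguments suggest so); that part is an assumption and is left out.

```python
import torch
from torch import nn


def train_one_epoch_sketch(model, loader, optimizer, criterion, device):
    """Hypothetical single-epoch loop; returns the mean per-sample training loss."""
    model.train()
    running_loss, n_seen = 0.0, 0
    for images, _ in loader:  # labels are ignored: the input is its own target
        images = images.to(device)
        optimizer.zero_grad()
        reconstructions = model(images)
        loss = criterion(reconstructions, images)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * images.size(0)
        n_seen += images.size(0)  # count samples actually seen (works with samplers too)
    return running_loss / n_seen


# Usage with a tiny placeholder autoencoder and random "images"
torch.manual_seed(0)
toy_data = torch.utils.data.TensorDataset(torch.rand(64, 1, 28, 28), torch.zeros(64))
toy_loader = torch.utils.data.DataLoader(toy_data, batch_size=16, shuffle=True)
toy_model = nn.Sequential(
    nn.Conv2d(1, 8, 5, padding=2), nn.ReLU(), nn.ConvTranspose2d(8, 1, 5, padding=2)
)
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
losses = [
    train_one_epoch_sketch(toy_model, toy_loader, toy_optimizer, nn.MSELoss(), "cpu")
    for _ in range(3)
]
print(losses)
```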

## Train autoencoders
Let's train the autoencoders and monitor training in TensorBoard. The number of epochs is set separately for each experiment below. The goal of this exercise is to make sure the autoencoders learn something useful; we will not use them for any downstream tasks.

Import the models from our `models` package.

In [None]:
from models.wta import (
    WTAConvAutoencoder128,
    WTAConvAutoencoder64,
    ConvAutoencoder128,
    ConvAutoencoder64,
)

Define the loss criterion: the mean squared error between each input image and its reconstruction. Each optimizer is created together with its model below.

In [None]:
criterion = nn.MSELoss()
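With its default `reduction="mean"`, `nn.MSELoss()` averages the squared pixel error over every element in the batch. A quick worked check: a reconstruction that is off by 0.5 at every pixel gives a loss of 0.5² = 0.25.

```python
import torch
from torch import nn

mse = nn.MSELoss()  # reduction="mean" by default

target = torch.full((2, 1, 28, 28), 0.5)    # stand-in batch of two "images"
reconstruction = torch.zeros(2, 1, 28, 28)  # off by 0.5 at every pixel

loss = mse(reconstruction, target)
manual = ((reconstruction - target) ** 2).mean()  # same quantity, written out
print(loss.item(), manual.item())  # 0.25 0.25
```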

Train a baseline convolutional autoencoder.

In [None]:
# Define hyperparameters
N_EPOCHS = 100
VISUALIZE_EVERY = 1
LEARNING_RATE = 1e-4
CHECKPOINT_PATH = "/home/fede/Documents/datasets_that_are_not/checkpoints"

# Initialize model
model = ConvAutoencoder64()
model.to(device)
print(f"Model architecture:\n\n{model}")
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

In [None]:
# Setup summary writer and set training going
writer = SummaryWriter(
    f"logs/{model.name}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"
)
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    model,
    train_loader,
    validation_loader,
    optimizer,
    criterion,
    device,
    writer=writer,
    checkpoint_dir=CHECKPOINT_PATH
)

In [None]:
# Load a saved baseline checkpoint and visualize its decoder filters
PATH_TO_CHECKPOINTS = (
    "/home/fede/Documents/datasets_that_are_not/checkpoints/ConvAutoencoder64_20231001-153026/epoch_50.pth"
)
model = load_model_from_checkpoint(
    ConvAutoencoder64(),
    PATH_TO_CHECKPOINTS,
    device,
    eval=True,
)
visualize_filters(model, model.decoder[0].weight)
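`load_model_from_checkpoint` comes from `tools.eval` and isn't shown here. A minimal sketch of what such a helper typically does, assuming the `.pth` file holds either a bare `state_dict` or a dict with a `"model_state_dict"` entry (the real checkpoint layout is an assumption), is the hypothetical `load_model_from_checkpoint_sketch` below.

```python
import os
import tempfile

import torch
from torch import nn


def load_model_from_checkpoint_sketch(model, path, device, eval=True):
    """Hypothetical stand-in for tools.eval.load_model_from_checkpoint.

    `eval` mirrors the keyword used in the notebook (it shadows the builtin locally).
    """
    checkpoint = torch.load(path, map_location=device)
    # Accept both {"model_state_dict": ...} and a bare state_dict (assumed layouts)
    state_dict = checkpoint.get("model_state_dict", checkpoint)
    model.load_state_dict(state_dict)
    model.to(device)
    if eval:
        model.eval()  # switch off training-only behaviour such as dropout
    return model


# Usage: save a small layer's weights, then restore them into a fresh instance
source = nn.ConvTranspose2d(4, 1, 11, padding=5)
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "epoch_0.pth")
    torch.save({"model_state_dict": source.state_dict()}, ckpt_path)
    restored = load_model_from_checkpoint_sketch(
        nn.ConvTranspose2d(4, 1, 11, padding=5), ckpt_path, "cpu"
    )
print(torch.equal(source.weight, restored.weight), restored.training)  # True False
```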

## Train a Winner-Take-All Convolutional Autoencoder with Spatial and Lifetime Sparsity
Let's train our WTA autoencoder with a basic training loop to check that its outputs look reasonable. We will first train the model that uses the `128conv5-128conv5-128deconv1` architecture.
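Spatial sparsity, in the usual formulation of WTA convolutional autoencoders, keeps only the single largest activation in each feature map of each sample and zeroes everything else, so the decoder has to reconstruct the image from one "winner" per map. The sketch below shows that operation on encoder outputs of shape `(batch, channels, height, width)`; it is an illustration, and `models.wta` may implement it differently.

```python
import torch


def spatial_sparsity(feature_maps: torch.Tensor) -> torch.Tensor:
    """Keep exactly one winner (the max) per feature map; zero all other positions.

    feature_maps: (batch, channels, height, width), e.g. the encoder's last ReLU output.
    """
    flat = feature_maps.flatten(start_dim=2)               # (B, C, H*W)
    winners = flat.argmax(dim=2, keepdim=True)              # position of the max per map
    mask = torch.zeros_like(flat).scatter_(2, winners, 1.0)  # 1 at the winner, 0 elsewhere
    # Multiplying by the mask means gradients flow only through the winners
    return (flat * mask).view_as(feature_maps)


# One sample with two 2x2 feature maps
x = torch.tensor([[[[0.1, 0.9], [0.3, 0.2]],
                   [[0.5, 0.4], [0.7, 0.0]]]])
print(spatial_sparsity(x))  # channel 0 keeps only 0.9, channel 1 keeps only 0.7
```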

In [None]:
N_EPOCHS = 30
VISUALIZE_EVERY = 1
LEARNING_RATE = 1e-4
CHECKPOINT_PATH = "/home/fede/Documents/datasets_that_are_not/checkpoints"

In [None]:
# Initialize model
K_PERCENTAGE = 0.2
model = WTAConvAutoencoder128(
    k_percentage=K_PERCENTAGE,
)
model.name = f"{model.name}_k_{K_PERCENTAGE}"
model.to(device)
print(f"Model architecture:\n\n{model}")
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
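`K_PERCENTAGE` sets the lifetime-sparsity rate. Assuming `k_percentage` is the fraction of activations each feature map keeps across a mini-batch (the filter visualizations later describe runs as "keeping the top 5%/20% of the activations"), lifetime sparsity can be sketched as follows: for every channel, rank the samples by their winning activation and zero that channel's map for all but the top `k_percentage` of the batch. The rounding rule and function name are assumptions for illustration.

```python
import torch


def lifetime_sparsity(feature_maps: torch.Tensor, k_percentage: float) -> torch.Tensor:
    """Per channel, keep the maps of only the top `k_percentage` of samples in the batch.

    Intended to run after spatial sparsity, so each map is ranked by its single winner.
    """
    batch = feature_maps.shape[0]
    k = max(1, int(round(k_percentage * batch)))  # assumed rounding; keep at least one
    winners = feature_maps.amax(dim=(2, 3))        # (batch, channels)
    threshold = winners.topk(k, dim=0).values[-1:]  # k-th largest winner per channel, (1, C)
    keep = (winners >= threshold).to(feature_maps.dtype)  # ties may keep a few extra
    return feature_maps * keep[:, :, None, None]


torch.manual_seed(0)
codes = torch.rand(128, 64, 28, 28)  # stand-in for encoder output
sparse = lifetime_sparsity(codes, k_percentage=0.2)
# roughly round(0.2 * 128) = 26 samples kept per channel
print((sparse.amax(dim=(2, 3)) > 0).sum(dim=0)[:4])
```

With `k_percentage=1.0` this sketch keeps every map, i.e. no lifetime sparsity at all.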

Train the model for `N_EPOCHS` epochs.

In [None]:
# Setup summary writer and set training going
writer = SummaryWriter(
    f"logs/{model.name}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"  # model.name already includes k
)
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    model,
    train_loader,
    validation_loader,
    optimizer,
    criterion,
    device,
    writer=writer,
    checkpoint_dir=CHECKPOINT_PATH
)

Next, train the model that uses the `64conv5-64conv5-64conv5-64deconv11` architecture.

In [None]:
# Initialize model
# Fraction of activations kept by lifetime sparsity; used for both the model and its name
K_PERCENTAGE = 0.2
model = WTAConvAutoencoder64(
    k_percentage=K_PERCENTAGE,
)
model.name = f"{model.name}_k_{K_PERCENTAGE}"
model.to(device)
print(f"Model architecture:\n\n{model}")
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
print(f"Model name is {model.name}")
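One plausible reading of the `64conv5-64conv5-64conv5-64deconv11` notation: three convolutions with 64 filters of size 5×5 each, followed by a transposed convolution ("deconvolution") with 64 filters of size 11×11 that maps the 64 sparse feature maps back to one image. The layer stack below is a hypothetical illustration, not the code in `models.wta`; the padding values are assumptions chosen so every layer keeps the 28×28 resolution. For a stride-1 transposed convolution the output size is H_out = (H_in − 1) − 2·padding + kernel_size.

```python
import torch
from torch import nn

# Hypothetical layer stack for "64conv5-64conv5-64conv5-64deconv11"
encoder = nn.Sequential(
    nn.Conv2d(1, 64, kernel_size=5, padding=2), nn.ReLU(),   # 28 - 5 + 2*2 + 1 = 28
    nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
    nn.Conv2d(64, 64, kernel_size=5, padding=2), nn.ReLU(),
)
decoder = nn.Sequential(
    # stride 1: (28 - 1) - 2*5 + 11 = 28
    nn.ConvTranspose2d(64, 1, kernel_size=11, padding=5),
)

x = torch.rand(8, 1, 28, 28)
codes = encoder(x)  # the WTA sparsity steps would be applied to these codes
print(codes.shape)              # torch.Size([8, 64, 28, 28])
print(decoder(codes).shape)     # torch.Size([8, 1, 28, 28])
print(decoder[0].weight.shape)  # torch.Size([64, 1, 11, 11]): 64 filters of 11x11
```

Under this reading, `model.decoder[0].weight` holds exactly the 64 deconvolution filters that the visualization cells below plot.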

In [None]:
# Setup summary writer and set training going
N_EPOCHS = 100
writer = SummaryWriter(
    f"logs/{model.name}_{datetime.now().strftime('%Y%m%d-%H%M%S')}"  # model.name already includes k
)
train_for_n_epochs(
    N_EPOCHS,
    VISUALIZE_EVERY,
    model,
    train_loader,
    validation_loader,
    optimizer,
    criterion,
    device,
    writer=writer,
    checkpoint_dir=CHECKPOINT_PATH
)

## Visualizing learned filters from WTA Convolutional Autoencoder
Let's visualize the filters learned by the autoencoders. We plot the weights of the first decoder layer, `model.decoder[0]`, i.e. the learned deconvolution filters.
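`visualize_filters` comes from `tools.eval` and isn't shown here. A common approach, sketched below with the hypothetical helper `filters_to_grid`, is to min-max scale each filter to [0, 1] independently and tile the filters into one image. It assumes weights shaped `(n_filters, 1, k, k)`, as a `ConvTranspose2d` from n feature maps to a single output image would have.

```python
import math

import torch


def filters_to_grid(weight: torch.Tensor, padding: int = 1) -> torch.Tensor:
    """Tile (n_filters, 1, k, k) weights into one 2-D image, each filter scaled to [0, 1]."""
    w = weight.detach().cpu()[:, 0]  # (n, k, k): the single output channel
    n, k, _ = w.shape
    lo = w.amin(dim=(1, 2), keepdim=True)
    hi = w.amax(dim=(1, 2), keepdim=True)
    w = (w - lo) / (hi - lo).clamp_min(1e-8)  # per-filter contrast stretch
    cols = math.ceil(math.sqrt(n))
    rows = math.ceil(n / cols)
    # Background of zeros, with `padding` pixels between neighbouring filters
    grid = torch.zeros(rows * (k + padding) + padding, cols * (k + padding) + padding)
    for i in range(n):
        r, c = divmod(i, cols)
        top, left = padding + r * (k + padding), padding + c * (k + padding)
        grid[top:top + k, left:left + k] = w[i]
    return grid


grid = filters_to_grid(torch.randn(64, 1, 11, 11))
print(grid.shape)  # torch.Size([97, 97]): an 8x8 layout of 11x11 filters
# The grid could then be shown with e.g. matplotlib: plt.imshow(grid, cmap="gray")
```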

Visualizing the learned deconvolution filters of `WTAConvAutoencoder64` when keeping the top 5% of the activations, trained for 100 epochs (the checkpoint loaded below was saved at epoch 69).

In [None]:
PATH_TO_CHECKPOINTS = (
    "/home/fede/Documents/datasets_that_are_not/checkpoints/WTAConvAutoencoder64_k_0.05_20231001-141318/epoch_69.pth"
)
model = load_model_from_checkpoint(
    WTAConvAutoencoder64(),
    PATH_TO_CHECKPOINTS,
    device,
    eval=True,
)
visualize_filters(model, model.decoder[0].weight)

Visualizing the learned deconvolution filters of `WTAConvAutoencoder128` when keeping the top 20% of the activations (the checkpoint loaded below was saved at epoch 29).

In [None]:
PATH_TO_CHECKPOINTS = (
    "/home/fede/Documents/datasets_that_are_not/checkpoints/WTASpatialLifetimeSparseConvAutoencoder_20231001-124328/epoch_29.pth"
)
model = load_model_from_checkpoint(
    WTAConvAutoencoder128(),
    PATH_TO_CHECKPOINTS,
    device,
    eval=True,
)
visualize_filters(model, model.decoder[0].weight)