## Packages

In [1]:
!pip install omegaconf pytorch_lightning



In [2]:
import argparse
import sys
from pathlib import Path
from omegaconf import OmegaConf

from operator import mul
from functools import reduce
import numpy as np
import torch
from torch import nn
from torch.utils.data import TensorDataset, random_split
from torch.utils.data.dataloader import DataLoader

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint
from pytorch_lightning.loggers import TensorBoardLogger

ModuleNotFoundError: No module named 'omegaconf'

## Encoder/Decoder

In [None]:
class ConvEncoder(nn.Module):
    """1-D convolutional encoder.

    Each entry of ``conv_encoder_layers`` produces a ``Conv1d`` layer, optionally
    followed by ``BatchNorm1d``, and always followed by ``LeakyReLU(0.05)``.

    Args:
        in_channels (int): Number of channels of the input signal.
        conv_encoder_layers (list): One ``(out_channels, kernel, stride)`` triple
            per convolution, in the order the layers are applied.
        use_batch_norm (bool, optional): Whether to add a BatchNorm1d layer after
            every convolution. Defaults to False.
    """

    def __init__(self, in_channels, conv_encoder_layers, use_batch_norm=False):
        super().__init__()

        modules = []
        prev_channels = in_channels
        for spec in conv_encoder_layers:
            n_filters, kernel_size, step = spec
            modules.append(
                nn.Conv1d(prev_channels, n_filters, kernel_size, stride=step))
            if use_batch_norm:
                modules.append(nn.BatchNorm1d(n_filters))
            modules.append(nn.LeakyReLU(0.05))

            # The next convolution consumes what this one produced.
            prev_channels = n_filters

        self.layers = nn.Sequential(*modules)

    def forward(self, x):
        """Apply the convolutional stack to ``x`` and return the result."""
        return self.layers(x)


class ConvDecoder(nn.Module):
    def __init__(self,
                 in_features,
                 encoder_output_dim,
                 conv_decoder_layers,
                 use_batch_norm=False):
        """1-D Convolutional Decoder

        A linear layer first expands the input to the flattened encoder output
        size, the result is reshaped to ``encoder_output_dim``, and then a stack
        of ConvTranspose1d layers is applied. LeakyReLU(0.05) follows every
        transposed convolution except the last one.

        Args:
            in_features (int): Number of input features
            encoder_output_dim (sequence of int): Shape of the unflattened output
                from the encoder (num_channels x ...). Any sequence of integers
                is accepted.
            conv_decoder_layers (iterable): ``(out_channels, kernel, stride,
                output_padding)`` tuples specifying ConvTranspose1d layers
            use_batch_norm (bool, optional): Whether to add a BatchNorm1d layer
                after every ConvTranspose1d layer (including the last one).
                Defaults to False.
        """
        super().__init__()

        # nn.Unflatten type-checks its target shape, so hand it a plain tuple of
        # Python ints regardless of which sequence type the caller supplied.
        unflattened_shape = tuple(int(size) for size in encoder_output_dim)
        # Materialise once so len() below works for any iterable of layer specs.
        layer_specs = list(conv_decoder_layers)

        layers = [
            nn.Linear(in_features, reduce(mul, unflattened_shape)),
            nn.Unflatten(1, unflattened_shape)
        ]

        in_channels = unflattened_shape[0]
        last_index = len(layer_specs) - 1
        for i, (out_channels, kernel, stride, output_padding) in enumerate(layer_specs):
            layers.append(
                nn.ConvTranspose1d(
                    in_channels,
                    out_channels,
                    kernel,
                    stride=stride,
                    output_padding=output_padding))
            if use_batch_norm:
                layers.append(nn.BatchNorm1d(out_channels))

            # Don't add activation for last layer
            if i != last_index:
                layers.append(nn.LeakyReLU(0.05))

            in_channels = out_channels

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Decode the flat feature batch ``x`` and return the result."""
        return self.layers(x)


class AE(nn.Module):
    """Autoencoder: ``encoder`` -> linear bottleneck -> ``decoder``.

    Args:
        encoder (nn.Module): Module applied to the input.
        decoder (nn.Module): Module applied to the bottleneck encoding.
        encoder_output_dim (list): Per-sample shape of the encoder output; its
            product is the input size of the bottleneck layer.
        encoding_dim (int): Size of the bottleneck encoding.
    """

    def __init__(self, encoder, decoder, encoder_output_dim, encoding_dim):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.encoding = nn.Linear(reduce(mul, encoder_output_dim), encoding_dim)

    def encode(self, x):
        """Return the bottleneck encoding of ``x``."""
        features = self.encoder(x)
        # Keep the batch dimension, collapse everything else.
        return self.encoding(features.flatten(start_dim=1))

    def decode(self, encoding):
        """Run ``encoding`` through the decoder."""
        return self.decoder(encoding)

    def forward(self, x):
        """Encode ``x`` and decode the result."""
        code = self.encode(x)
        return self.decode(code)

## VAE

In [None]:
class VAE(nn.Module):
    """Variational autoencoder wrapper around an encoder and a decoder.

    Two linear heads map the flattened encoder output to the mean and the
    log-variance of the latent distribution.

    Args:
        encoder (nn.Module): Module applied to the input.
        decoder (nn.Module): Module applied to the sampled latent vector.
        encoder_output_dim (list): Per-sample shape of the encoder output; its
            product is the input size of both linear heads.
        latent_dim (int): Size of the latent vector.
    """

    def __init__(self, encoder, decoder, encoder_output_dim, latent_dim):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder

        n_features = reduce(mul, encoder_output_dim)
        self.latent_mean = nn.Linear(n_features, latent_dim)
        self.latent_log_var = nn.Linear(n_features, latent_dim)

    def encode(self, x):
        """Return ``(mean, log_var)`` of the latent distribution for ``x``."""
        features = self.encoder(x).flatten(start_dim=1)
        mean = self.latent_mean(features)
        log_var = self.latent_log_var(features)
        return mean, log_var

    def sample(self, mean, log_var):
        """Draw a latent sample as ``mean + std * noise`` (reparameterisation)."""
        sigma = (log_var / 2).exp()
        noise = torch.randn_like(sigma)
        return noise * sigma + mean

    def decode(self, z):
        """Run the latent vector ``z`` through the decoder."""
        return self.decoder(z)

    def forward(self, x):
        """Return ``(mean, log_var, reconstruction)`` for the batch ``x``."""
        mean, log_var = self.encode(x)
        reconstruction = self.decode(self.sample(mean, log_var))
        return mean, log_var, reconstruction

## SpikeSortingVAE LightningModule

In [None]:
class SpikeSortingVAE(pl.LightningModule):
    """LightningModule that trains a convolutional VAE on an array loaded from disk.

    The module builds a ``ConvEncoder``/``ConvDecoder`` pair wrapped in a ``VAE``,
    loads the data file named in the configuration, splits it into train and
    validation subsets, and optimises reconstruction NLL plus KL divergence.

    Args:
        config (dict-like): Configuration with a ``"model"`` section
            (``in_channels``, ``conv_encoder_layers``, ``conv_decoder_layers``,
            ``encoder_output_dim``, ``latent_dim``, ``use_batch_norm``), a
            ``"data"`` section (``path``, ``train_val_split``,
            ``train_batch_size``, ``val_batch_size``) and an optional
            ``"learning_rate"``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.save_hyperparameters(config)

        model_config = config["model"]
        encoder = ConvEncoder(model_config["in_channels"],
                              model_config["conv_encoder_layers"],
                              use_batch_norm=model_config["use_batch_norm"])
        decoder = ConvDecoder(model_config["latent_dim"],
                              model_config["encoder_output_dim"],
                              model_config["conv_decoder_layers"],
                              use_batch_norm=model_config["use_batch_norm"])
        self.model = VAE(
            encoder,
            decoder,
            model_config["encoder_output_dim"],
            model_config["latent_dim"]
        )
        # Learnable log standard deviation of the Gaussian reconstruction likelihood.
        self.log_scale = nn.Parameter(torch.tensor(0.0))

        # Filled in by setup().
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Load the data file and create the train/validation split.

        This lives in ``setup`` rather than ``prepare_data`` because Lightning
        runs ``prepare_data`` on a single process only and advises against
        assigning state there; ``setup`` runs on every process. The split is
        built once, so repeated ``setup`` calls keep the same subsets.

        Args:
            stage (str, optional): Stage name passed by the Trainer. Unused.
        """
        if self.train_dataset is not None and self.val_dataset is not None:
            return

        data_config = self.config["data"]
        templates = np.load(data_config["path"])
        x = torch.from_numpy(templates).float()
        # Input and target are the same tensor: the model reconstructs its input.
        dataset = TensorDataset(x, x)

        total_samples = len(x)
        n_train = int(data_config["train_val_split"] * total_samples)
        n_val = total_samples - n_train
        self.train_dataset, self.val_dataset = random_split(dataset, [n_train, n_val])

    def train_dataloader(self):
        """Return the DataLoader over the training subset."""
        return DataLoader(
            self.train_dataset, batch_size=self.config["data"]["train_batch_size"])

    def val_dataloader(self):
        """Return the DataLoader over the validation subset."""
        return DataLoader(
            self.val_dataset, batch_size=self.config["data"]["val_batch_size"])

    def forward(self, x):
        """Return ``(mean, log_var, reconstruction)`` from the wrapped VAE."""
        return self.model(x)

    def gaussian_nll(self, x, x_hat):
        """Per-sample negative log-likelihood of ``x`` under ``Normal(x_hat, exp(log_scale))``.

        The log-probabilities are summed over every non-batch dimension, so the
        result has one entry per sample.
        """
        scale = torch.exp(self.log_scale)
        predicted_dist = torch.distributions.Normal(x_hat, scale)
        batch_size = x.shape[0]
        return -predicted_dist.log_prob(x).view(batch_size, -1).sum(dim=1)

    def kl_divergence(self, mu, log_var):
        """KL divergence from ``N(mu, exp(log_var))`` to a standard normal.

        Summed over dimension 1 and averaged over dimension 0 (a scalar for
        2-D inputs).
        """
        return torch.mean(
            0.5 * torch.sum(log_var.exp() - log_var + mu**2 - 1, dim=1), dim=0)

    def configure_optimizers(self):
        """Return an Adam optimizer; the learning rate defaults to 1e-4.

        The default applies when ``"learning_rate"`` is missing from the
        configuration or set to a falsy value.
        """
        learning_rate = self.config.get("learning_rate") or 1e-4
        return torch.optim.Adam(self.parameters(), lr=learning_rate)

    def _shared_step(self, batch, prefix):
        """Compute the loss for one batch and the metrics to log.

        Args:
            batch: ``(input, target)`` pair; only the input is used.
            prefix (str): Metric name prefix, e.g. ``"train"`` or ``"val"``.

        Returns:
            tuple: ``(loss, metrics)`` where ``metrics`` maps
            ``{prefix}_loss``, ``{prefix}_recon_loss`` and ``{prefix}_kld`` to
            their values.
        """
        x, _ = batch
        mean, log_var, x_hat = self.model(x)

        reconstruction_loss = self.gaussian_nll(x, x_hat)
        kl_divergence = self.kl_divergence(mean, log_var)
        # Reconstruction NLL plus KL term, averaged over the batch.
        elbo = (reconstruction_loss + kl_divergence).mean()

        metrics = {
            f"{prefix}_loss": elbo,
            f"{prefix}_recon_loss": reconstruction_loss.mean(),
            f"{prefix}_kld": kl_divergence.mean()
        }
        return elbo, metrics

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch."""
        elbo, metrics = self._shared_step(batch, "train")
        self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
        return elbo

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        elbo, metrics = self._shared_step(batch, "val")
        self.log_dict(metrics, on_epoch=True, prog_bar=True)
        return elbo

## Configurations

In [None]:
# Settings shared by every experiment in this notebook.
base_config = OmegaConf.create({
    "random_seed": 4995,
    "model": {
        "in_channels": 20,
        "use_batch_norm": True,
    },
    "learning_rate": 1e-4,
    "data": {
        "path": "templates.npy",
        "train_val_split": 0.8,
        "train_batch_size": 100,
        "val_batch_size": 100,
    },
    # Passed verbatim as keyword arguments to the Lightning Trainer.
    "trainer": {
        "gpus": 1,
        "max_epochs": 60,
    },
})

# Architecture with two convolutions in the encoder and two transposed
# convolutions in the decoder. Encoder entries are [out_channels, kernel, stride];
# decoder entries are [out_channels, kernel, stride, output_padding].
# NOTE(review): encoder_output_dim must equal the encoder's per-sample output
# shape, which depends on the input length — confirm against the data.
two_conv_config = OmegaConf.create({
    "model": {
        "conv_encoder_layers": [[32, 5, 2], [16, 5, 2]],
        "conv_decoder_layers": [[16, 5, 2, 0], [20, 5, 2, 0]],
        "encoder_output_dim": [16, 28],
    },
})

# One merged configuration per latent size under study; later arguments to
# OmegaConf.merge are layered on top of earlier ones.
latent_dim_ablation_configs = [
    OmegaConf.merge(
        base_config,
        two_conv_config,
        OmegaConf.create({
            "name": f"vae_2conv_{latent_dim}latent",
            "model": {"latent_dim": latent_dim},
        }),
    )
    for latent_dim in (12, 10, 8, 6)
]

## Training

In [None]:
def train(config):
    """Train a SpikeSortingVAE for one configuration and export its outputs.

    Seeds the RNGs, fits the model with checkpointing and TensorBoard logging
    under ``./experiments/<name>``, then saves the latent means and the
    reconstructions of the train and validation subsets as ``.npy`` files in
    the same directory.

    Args:
        config (dict-like): Experiment configuration. Reads ``"random_seed"``,
            ``"name"`` and ``"trainer"`` (keyword arguments for the Trainer);
            the whole object is passed on to ``SpikeSortingVAE``.
    """
    seed_everything(config["random_seed"])
    system = SpikeSortingVAE(config)

    experiment_name = config["name"]
    experiment_dir = Path(f"./experiments/{experiment_name}")
    experiment_dir.mkdir(parents=True, exist_ok=True)

    tb_logger = TensorBoardLogger("experiments", name=experiment_name)

    checkpoint_callback = ModelCheckpoint(
        dirpath=experiment_dir,
        filename="{epoch}",
        auto_insert_metric_name=True
    )

    progress_bar_callback = TQDMProgressBar(refresh_rate=20)
    trainer = Trainer(
        **config["trainer"],
        callbacks=[checkpoint_callback, progress_bar_callback],
        logger=tb_logger
    )
    trainer.fit(system)

    # Inference only from here on: eval mode makes BatchNorm use its running
    # statistics instead of the statistics of this single full-dataset pass
    # (and stops it updating them), and no_grad avoids building an autograd
    # graph over the whole dataset.
    system.eval()
    splits = (("train", system.train_dataset), ("val", system.val_dataset))
    with torch.no_grad():
        for split_name, dataset in splits:
            templates, _ = dataset[:]
            # Move the inputs to wherever the trained module currently lives.
            latent_reps, _, reconstructions = system(templates.to(system.device))
            np.save(experiment_dir / f"{split_name}_latent_reps.npy",
                    latent_reps.cpu().numpy())
            np.save(experiment_dir / f"{split_name}_reconstructions.npy",
                    reconstructions.cpu().numpy())

In [None]:
# Train one model per ablation configuration, printing a header for each run.
for ablation in latent_dim_ablation_configs:
    # train() and the Lightning module receive plain Python containers.
    config = OmegaConf.to_container(ablation)
    header = config["name"]
    print(header)
    print('-' * 50)
    train(config)

## Latent Space Visualization

In [None]:
# Restore two trained models from Lightning checkpoint files.
# NOTE(review): both paths are bare file names resolved against the current
# working directory, whereas train() writes its checkpoints under
# ./experiments/<name>/ — confirm these files exist where the notebook runs.
# NOTE(review): going by the variable names these are presumably the 10- and
# 12-dimensional latent models; the file names only encode the epoch — verify.
ldims_10 = SpikeSortingVAE.load_from_checkpoint("epoch=49.ckpt")
ldims_12 = SpikeSortingVAE.load_from_checkpoint("epoch=59.ckpt")

In [None]:
# NOTE(review): SpikeSortingVAE.forward(self, x) requires an input batch `x`;
# called without arguments, as here, it raises TypeError. Pass the tensor of
# inputs whose latent representation should be visualised.
ldims_10.forward()