## Packages

In [None]:
!pip install omegaconf pytorch_lightning lightning-bolts



In [None]:
from omegaconf import OmegaConf
from pathlib import Path
import numpy as np
import torch
import json

import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (
    TQDMProgressBar, ModelCheckpoint, EarlyStopping
)
from pl_bolts.callbacks import PrintTableMetricsCallback

In [None]:
from src.models.ae import *
from src.models.vae import *
from src.spike_sorting import *

## Train Function

In [None]:
def train(system_class,
          config,
          experiment_dir="experiments",
          checkpoint_name="vae_{val_loss:.3f}.ckpt"):
    seed_everything(config["random_seed"])
    system = system_class(config)

    experiment_name = config["name"]
    experiment_dir = Path(f"{experiment_dir}/{experiment_name}")
    experiment_dir.mkdir(parents=True, exist_ok=True)

    checkpoint_callback = ModelCheckpoint(
        dirpath=experiment_dir,
        filename=checkpoint_name,
        auto_insert_metric_name=True
    )

    trainer = Trainer(
        **config["trainer"],
        callbacks=[
            checkpoint_callback,
            EarlyStopping(monitor="val_loss", patience=3),
            TQDMProgressBar(refresh_rate=20)
        ],
        logger=None
    )
    trainer.fit(system)
    return system, trainer

## Experiments

### Base Config

In [44]:
base_config = OmegaConf.create({
    "random_seed": 4995,
    "model": {
        "in_channels": 20,
        "conv_encoder_layers": [[32, 5, 2], [16, 5, 2]],
        "conv_decoder_layers": [[16, 5, 2, 0], [20, 5, 2, 0]],
        "encoder_output_dim": [16, 28],
        "use_batch_norm": True
    },
    "learning_rate": 1e-4,
    "data": {
        "train_data_path": "data/train_templates.npy",
        "val_data_path": "data/val_templates.npy",
        "train_batch_size": 100,
        "val_batch_size": 100
    },
    "trainer": {
        "gpus": 1,
        "max_epochs": 100
    }

})

psvae_base_config = OmegaConf.merge(base_config, {
    "data": {
        "train_label_path": "data/train_labels.npy",
        "val_label_path": "data/val_labels.npy",
    },
    "anneal_epochs": 50
})

### VAE

In [None]:
vae_configs = [OmegaConf.merge(base_config, c) for c in [
    {
        "name": "vae_10latent",
        "model": {
            "latent_dim": 10
        }
    },
    {
        "name": "vae_8latent",
        "model": {
            "latent_dim": 8
        }
    },
    {
        "name": "vae_6latent",
        "model": {
            "latent_dim": 6
        }
    },
]]

exp_dir = "experiments/vaes"
for config in vae_configs:
    system, trainer = train(
        SpikeSortingVAE,
        OmegaConf.to_container(config),
        experiment_dir=exp_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{exp_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)

### PS-VAE Selecting $\alpha$

In [None]:
psvae_alpha_selection = [OmegaConf.merge(psvae_base_config, c) for c in [
    {
        "name": "psvae_10latent_alpha=1_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=10_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 10,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=25_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 25,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=50_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 50,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=100_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 100,
        "beta": 1
    },
]]

exp_dir = "experiments/psvae_alpha_selection"
for config in psvae_alpha_selection:
    system, trainer = train(
        SpikeSortingPSVAE,
        OmegaConf.to_container(config),
        experiment_dir=exp_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{exp_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)

INFO:pytorch_lightning.utilities.seed:Global seed set to 4995
INFO:pytorch_lightning.utilities.distributed:GPU available: True, used: True
INFO:pytorch_lightning.utilities.distributed:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.distributed:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type  | Params
--------------------------------
0 | model | PSVAE | 22.9 K
--------------------------------
22.8 K    Trainable params
100       Non-trainable params
22.9 K    Total params
0.092     Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."
INFO:pytorch_lightning.utilities.seed:Global seed set to 4995
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

  f"`.{fn}(ckpt_path=None)` was called without a model."
INFO:pytorch_lightning.utilities.distributed:Restoring states from the checkpoint path at /content/experiments/psvae_alpha_selection/psvae_10latent_alpha=1/model-v1.ckpt
INFO:pytorch_lightning.accelerators.gpu:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.utilities.distributed:Loaded model weights from checkpoint at /content/experiments/psvae_alpha_selection/psvae_10latent_alpha=1/model-v1.ckpt


Validating: 0it [00:00, ?it/s]

--------------------------------------------------------------------------------
DATALOADER:0 VALIDATE RESULTS
{'val_data_recon_loss': 2254.187255859375,
 'val_label_recon_loss': 23530.369140625,
 'val_loss': 68765.96875,
 'val_supervised_latents_kld': 42944.47265625,
 'val_unsupervised_latents_kld': 36.911590576171875}
--------------------------------------------------------------------------------


### PS-VAE Selecting $\beta$

In [None]:
psvae_beta_selection = [OmegaConf.merge(psvae_base_config, c) for c in [
    {
        "name": "psvae_10latent_alpha=1_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=1_beta=5",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 5
    },
    {
        "name": "psvae_10latent_alpha=1_beta=10",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 10
    },
    {
        "name": "psvae_10latent_alpha=1_beta=20",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 20
    },
]]

exp_dir = "experiments/psvae_beta_selection"
for config in psvae_beta_selection:
    system, trainer = train(
        SpikeSortingPSVAE,
        OmegaConf.to_container(config),
        experiment_dir=exp_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{exp_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)

### PS-VAE Varying \# of Latent Dimensions

In [37]:
psvae_latent_dim_selection = [OmegaConf.merge(psvae_base_config, c) for c in [
    # Note: with latent_dim=10 was trained previously
    {
        "name": "psvae_8latent_alpha=1_beta=20",
        "model": {
            "latent_dim": 8,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 20
    },
    {
        "name": "psvae_6latent_alpha=1_beta=20",
        "model": {
            "latent_dim": 6,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 20
    },
]]

exp_dir = "experiments/psvae_latent_dim_selection"
for config in psvae_latent_dim_selection:
    system, trainer = train(
        SpikeSortingPSVAE,
        OmegaConf.to_container(config),
        experiment_dir=exp_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{exp_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)