# Full Analysis Example

### Imports

In [1]:
import json
import os
from pathlib import Path

import numpy as np
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping
from pl_bolts.callbacks import PrintTableMetricsCallback
import torch

from src.data.preprocess_templates import plot_templates, get_max_chan_temps, take_channel_range, localize_wfs
from src.data.make_datasets import (
    featurization_dataset, positional_invariance_dataset, clustering_dataset, 
    time_center_templates, normalize_inputs
)
from src.models.ae import *
from src.models.vae import *
from src.models.spike_vaes import *
from src.models.utils.train import train

In [2]:
REPO_PATH = "/Users/johnzhou/research/spike-sorting"
RAW_DATA_DIR = Path(f"{REPO_PATH}/data/raw")
PROCESS_DATA_DIR = Path(f"{REPO_PATH}/data/processed")
EXPT_DIR = Path(f"{REPO_PATH}/experiments")

### Load Data

In [3]:
# Cleaned and denoised templates
templates_fname = "templates_yass.npy"
templates_fpath = Path(f"{RAW_DATA_DIR}/{templates_fname}")
templates = np.load(templates_fpath)

# Probe geometry
geom_fname = "channel_map_np2.npy"
geom_fpath = Path(f"{RAW_DATA_DIR}/{geom_fname}")
geom_array = np.load(geom_fpath)
channels_pos = geom_array[:20]

## Data Preprocessing and Synthetic Dataset Generation

In [4]:
a, loc, scale = 3, 100, 500
n_channels = 20

### Identify and Remove Bad Templates

In [5]:
num_templates, duration, num_channels = templates.shape
print("{} contains {} templates for {} timesteps across {} channels.".format(
    templates_fname, num_templates, duration, num_channels))
max_chan_temp = get_max_chan_temps(templates)
# plot_templates(templates, max_chan_temp, n_channels=n_channels)

templates_yass.npy contains 170 templates for 121 timesteps across 384 channels.


In [6]:
bad_template_idxs = [3, 6, 27, 29, 32, 35, 36, 56, 57, 58, 59, 62, 63, 64, 74, 78, 79, 80, 85, 91, 92, \
    101, 107, 109, 110, 111, 118, 119, 121, 145, 150, 151, 152, 157, 159, 164, 165, 169]
good_templates = np.delete(templates, bad_template_idxs, axis=0)
templates_chans, templates_ptp_chans = take_channel_range(good_templates, n_channels_loc=n_channels)
positions_templates = localize_wfs(templates_ptp_chans, geom_array)

100%|████████████████████████████████████████| 132/132 [00:01<00:00, 114.39it/s]


In [7]:
print(good_templates.shape)

(132, 121, 384)


### Produce Datasets

In [8]:
from src.data.make_datasets import (
    featurization_dataset,
    positional_invariance_dataset,
    clustering_dataset
)

In [9]:
# Featurization dataset to train VAEs/PCA
n_samples = 100000
n_train_samples = round(n_samples * 0.8)
n_val_samples = round(n_samples * 0.2)
# Train
featurize_train_experiment_name = "featurization_train"
featurization_dataset(
    templates_chans, positions_templates, channels_pos, a, loc, scale, n_samples=n_train_samples, 
    experiment_data_dir=PROCESS_DATA_DIR, experiment_name=featurize_train_experiment_name
)
# Validation
featurize_val_experiment_name = "featurization_val"
featurization_dataset(
    templates_chans, positions_templates, channels_pos, a, loc, scale, n_samples=n_val_samples, 
    experiment_data_dir=PROCESS_DATA_DIR, experiment_name=featurize_val_experiment_name
)

# Positional invariance analysis dataset for visualization
position_features = ["x", "z", "y", "alpha"]
for vary_feature in position_features:
    vary_experiment_name = "invariance_analysis_{}".format(vary_feature)
    vary_samples = 100
    positional_invariance_dataset(
        templates_chans, positions_templates, channels_pos, a, loc, scale, vary_feature=vary_feature, 
        n_samples=vary_samples, experiment_data_dir=PROCESS_DATA_DIR, experiment_name=vary_experiment_name
    )

# Clustering dataset for feature evaluation
num_clusters = 20
num_samples_per_cluster = 100
cluster_experiment_name = "clusters_k={}".format(num_clusters)
clustering_dataset(templates, positions_templates, channels_pos, a, loc, scale, n_clusters=num_clusters, 
                   num_samples_per_cluster=num_samples_per_cluster, experiment_data_dir=PROCESS_DATA_DIR, 
                   experiment_name=cluster_experiment_name)

100%|███████████████████████████████████| 80000/80000 [00:57<00:00, 1397.72it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_train, array of size: (80000, 20, 121)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_train, array of size: (80000, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_train, array of size: (4, 80000)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_train, array of size: (80000,)


100%|███████████████████████████████████| 20000/20000 [00:11<00:00, 1746.37it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_val, array of size: (20000, 20, 121)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_val, array of size: (20000, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_val, array of size: (4, 20000)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization_val, array of size: (20000,)


100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1455.60it/s]

Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_x, array of size: (100, 20, 121)





Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_x, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_x, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_x, array of size: (100,)


100%|████████████████████████████████████████| 100/100 [00:00<00:00, 969.77it/s]

Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_z, array of size: (100, 20, 121)





Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_z, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_z, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_z, array of size: (100,)


100%|████████████████████████████████████████| 100/100 [00:00<00:00, 904.91it/s]

Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_y, array of size: (100, 20, 121)





Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_y, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_y, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_y, array of size: (100,)


100%|████████████████████████████████████████| 100/100 [00:00<00:00, 819.18it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_alpha, array of size: (100, 20, 121)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_alpha, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_alpha, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/invariance_analysis_alpha, array of size: (100,)


100%|███████████████████████████████████████████| 20/20 [00:01<00:00, 11.75it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/clusters_k=20, array of size: (2000, 384, 121)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/clusters_k=20, array of size: (2000, 384)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/clusters_k=20, array of size: (4, 2000)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/clusters_k=20, array of size: (2000,)


## Model Training

In [10]:
"""
If GPUs are available, set "gpus" parameter. Modify data paths to match own file structure.
"""

train_template_path = f"{PROCESS_DATA_DIR}/{featurize_train_experiment_name}/templates.npy"
train_labels_path = f"{PROCESS_DATA_DIR}/{featurize_train_experiment_name}/positions.npy"
val_template_path = f"{PROCESS_DATA_DIR}/{featurize_val_experiment_name}/templates.npy"
val_labels_path = f"{PROCESS_DATA_DIR}/{featurize_val_experiment_name}/positions.npy"


base_config = OmegaConf.create({
    "random_seed": 4995,
    "model": {
        "in_channels": 20,
        "conv_encoder_layers": [[32, 5, 2], [16, 5, 2]],
        "conv_decoder_layers": [[16, 5, 2, 0], [20, 5, 2, 0]],
        "encoder_output_dim": [16, 28],
        "use_batch_norm": True
    },
    "learning_rate": 1e-4,
    "data": {
        "train_data_path": train_template_path,
        "val_data_path": val_template_path,
        "train_batch_size": 100,
        "val_batch_size": 100
    },
    "trainer": {
        "gpus": 0,
        "max_epochs": 100
    }

})

psvae_base_config = OmegaConf.merge(base_config, {
    "data": {
        "train_label_path": "data/train_labels.npy",
        "val_label_path": "data/val_labels.npy",
    },
    "anneal_epochs": 50
})

### VAE

In [11]:
vae_configs = [OmegaConf.merge(base_config, c) for c in [
    {
        "name": "vae_10latent",
        "model": {
            "latent_dim": 10
        }
    },
    {
        "name": "vae_8latent",
        "model": {
            "latent_dim": 8
        }
    },
    {
        "name": "vae_6latent",
        "model": {
            "latent_dim": 6
        }
    },
]]

expt_dir = f"{EXPT_DIR}/vae"
for config in vae_configs:
    system, trainer = train(
        SpikeSortingVAE,
        OmegaConf.to_container(config),
        experiment_dir=expt_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{expt_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)

Global seed set to 4995
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name  | Type | Params
-------------------------------
0 | model | VAE  | 22.8 K
-------------------------------
22.8 K    Trainable params
0         Non-trainable params
22.8 K    Total params
0.091     Total estimated model params size (MB)


Validation sanity check:   0%|                            | 0/2 [00:00<?, ?it/s]

  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


                                                                                

Global seed set to 4995
  f"The dataloader, {name}, does not have many workers which may be a bottleneck."


Epoch 0:  80%|████████████▊   | 800/1000 [00:28<00:07, 28.24it/s, loss=1.63e+04]
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 0:  84%|█████████████▍  | 840/1000 [00:28<00:05, 29.36it/s, loss=1.63e+04][A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 66.18it/s][A
Epoch 0:  88%|██████████████  | 880/1000 [00:29<00:03, 30.10it/s, loss=1.63e+04][A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 68.82it/s][A
Epoch 0:  92%|██████████████▋ | 920/1000 [00:29<00:02, 30.90it/s, loss=1.63e+04][A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 71.50it/s][A
Epoch 0:  96%|███████████████▎| 960/1000 [00:30<00:01, 31.67it/s, loss=1.63e+04][A
Validating:  80%|███████████████████████▏     | 160/200 [00:02<00:00, 72.13it/s][A
Epoch 0: 100%|███████████████| 1000/1000 [00:30<00:00, 32.40it/s, loss=1.63e+04][A
Epoch 0: 100%|█| 1000/1000 [00:31<00:00, 32.1

Epoch 7: 100%|█| 1000/1000 [00:30<00:00, 33.08it/s, loss=1.24e+04, val_loss=1.13[A
Epoch 7: 100%|█| 1000/1000 [00:30<00:00, 32.76it/s, loss=1.24e+04, val_loss=1.08[A
Epoch 8:  80%|▊| 800/1000 [00:29<00:07, 27.44it/s, loss=1.2e+04, val_loss=1.08e+[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 8:  84%|▊| 840/1000 [00:29<00:05, 28.54it/s, loss=1.2e+04, val_loss=1.08e+[A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 69.92it/s][A
Epoch 8:  88%|▉| 880/1000 [00:30<00:04, 29.30it/s, loss=1.2e+04, val_loss=1.08e+[A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 68.98it/s][A
Epoch 8:  92%|▉| 920/1000 [00:30<00:02, 30.08it/s, loss=1.2e+04, val_loss=1.08e+[A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 71.25it/s][A
Epoch 8:  96%|▉| 960/1000 [00:31<00:01, 30.84it/s, loss=1.2e+04, val_loss=1.08e+[A
Validating:  80%|███████████████████████▏ 

Epoch 15:  96%|▉| 960/1000 [00:32<00:01, 29.62it/s, loss=9.63e+03, val_loss=8.48[A
Validating:  80%|███████████████████████▏     | 160/200 [00:02<00:00, 74.04it/s][A
Epoch 15: 100%|█| 1000/1000 [00:32<00:00, 30.34it/s, loss=9.63e+03, val_loss=8.4[A
Epoch 15: 100%|█| 1000/1000 [00:33<00:00, 30.09it/s, loss=9.63e+03, val_loss=8.2[A
Epoch 16:  80%|▊| 800/1000 [00:31<00:07, 25.71it/s, loss=9.36e+03, val_loss=8.23[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 16:  84%|▊| 840/1000 [00:31<00:05, 26.75it/s, loss=9.36e+03, val_loss=8.23[A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 72.04it/s][A
Epoch 16:  88%|▉| 880/1000 [00:31<00:04, 27.55it/s, loss=9.36e+03, val_loss=8.23[A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 73.47it/s][A
Epoch 16:  92%|▉| 920/1000 [00:32<00:02, 28.33it/s, loss=9.36e+03, val_loss=8.23[A
Validating:  60%|█████████████████▍       

Epoch 23:  92%|▉| 920/1000 [00:25<00:02, 35.60it/s, loss=7.85e+03, val_loss=6.72[A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 78.12it/s][A
Epoch 23:  96%|▉| 960/1000 [00:26<00:01, 36.48it/s, loss=7.85e+03, val_loss=6.72[A
Validating:  80%|███████████████████████▏     | 160/200 [00:02<00:00, 82.01it/s][A
Epoch 23: 100%|█| 1000/1000 [00:26<00:00, 37.35it/s, loss=7.85e+03, val_loss=6.7[A
Epoch 23: 100%|█| 1000/1000 [00:27<00:00, 37.03it/s, loss=7.85e+03, val_loss=6.5[A
Epoch 24:  80%|▊| 800/1000 [00:24<00:06, 32.54it/s, loss=7.68e+03, val_loss=6.55[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 24:  84%|▊| 840/1000 [00:24<00:04, 33.75it/s, loss=7.68e+03, val_loss=6.55[A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 78.44it/s][A
Epoch 24:  88%|▉| 880/1000 [00:25<00:03, 34.73it/s, loss=7.68e+03, val_loss=6.55[A
Validating:  40%|████████████             

Epoch 31:  88%|▉| 880/1000 [00:43<00:05, 20.07it/s, loss=6.76e+03, val_loss=5.59[A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 61.06it/s][A
Epoch 31:  92%|▉| 920/1000 [00:44<00:03, 20.70it/s, loss=6.76e+03, val_loss=5.59[A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 64.88it/s][A
Epoch 31:  96%|▉| 960/1000 [00:45<00:01, 21.32it/s, loss=6.76e+03, val_loss=5.59[A
Validating:  80%|███████████████████████▏     | 160/200 [00:02<00:00, 67.16it/s][A
Epoch 31: 100%|█| 1000/1000 [00:45<00:00, 21.93it/s, loss=6.76e+03, val_loss=5.5[A
Epoch 31: 100%|█| 1000/1000 [00:45<00:00, 21.80it/s, loss=6.76e+03, val_loss=5.4[A
Epoch 32:  80%|▊| 800/1000 [00:26<00:06, 30.10it/s, loss=6.66e+03, val_loss=5.48[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 32:  84%|▊| 840/1000 [00:26<00:05, 31.28it/s, loss=6.66e+03, val_loss=5.48[A
Validating:  20%|██████                   

Epoch 39:  84%|▊| 840/1000 [00:26<00:05, 31.83it/s, loss=6.17e+03, val_loss=4.94[A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 76.11it/s][A
Epoch 39:  88%|▉| 880/1000 [00:26<00:03, 32.71it/s, loss=6.17e+03, val_loss=4.94[A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 76.77it/s][A
Epoch 39:  92%|▉| 920/1000 [00:27<00:02, 33.55it/s, loss=6.17e+03, val_loss=4.94[A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 77.12it/s][A
Epoch 39:  96%|▉| 960/1000 [00:27<00:01, 34.36it/s, loss=6.17e+03, val_loss=4.94[A
Validating:  80%|███████████████████████▏     | 160/200 [00:02<00:00, 76.66it/s][A
Epoch 39: 100%|█| 1000/1000 [00:28<00:00, 35.13it/s, loss=6.17e+03, val_loss=4.9[A
Epoch 39: 100%|█| 1000/1000 [00:28<00:00, 34.81it/s, loss=6.17e+03, val_loss=4.8[A
Epoch 40:  80%|▊| 800/1000 [00:26<00:06, 30.57it/s, loss=6.12e+03, val_loss=4.89[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                         

Epoch 47:  80%|▊| 800/1000 [00:27<00:06, 29.16it/s, loss=5.88e+03, val_loss=4.6e[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 47:  84%|▊| 840/1000 [00:27<00:05, 30.36it/s, loss=5.88e+03, val_loss=4.6e[A
Validating:  20%|██████                        | 40/200 [00:00<00:01, 89.91it/s][A
Epoch 47:  88%|▉| 880/1000 [00:28<00:03, 31.31it/s, loss=5.88e+03, val_loss=4.6e[A
Validating:  40%|████████████                  | 80/200 [00:00<00:01, 91.67it/s][A
Epoch 47:  92%|▉| 920/1000 [00:28<00:02, 32.23it/s, loss=5.88e+03, val_loss=4.6e[A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:00, 92.03it/s][A
Epoch 47:  96%|▉| 960/1000 [00:28<00:01, 33.14it/s, loss=5.88e+03, val_loss=4.6e[A
Validating:  80%|███████████████████████▏     | 160/200 [00:01<00:00, 93.94it/s][A
Epoch 47: 100%|█| 1000/1000 [00:29<00:00, 34.03it/s, loss=5.88e+03, val_loss=4.6[A
Epoch 47: 100%|█| 1000/1000 [00:29<00:00, 

Epoch 54: 100%|█| 1000/1000 [00:28<00:00, 35.41it/s, loss=5.75e+03, val_loss=4.4[A
Epoch 54: 100%|█| 1000/1000 [00:28<00:00, 35.09it/s, loss=5.75e+03, val_loss=4.4[A
Epoch 55:  80%|▊| 800/1000 [00:26<00:06, 30.13it/s, loss=5.74e+03, val_loss=4.44[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 55:  84%|▊| 840/1000 [00:26<00:05, 31.32it/s, loss=5.74e+03, val_loss=4.44[A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 78.10it/s][A
Epoch 55:  88%|▉| 880/1000 [00:27<00:03, 32.21it/s, loss=5.74e+03, val_loss=4.44[A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 78.14it/s][A
Epoch 55:  92%|▉| 920/1000 [00:27<00:02, 33.05it/s, loss=5.74e+03, val_loss=4.44[A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 79.01it/s][A
Epoch 55:  96%|▉| 960/1000 [00:28<00:01, 33.89it/s, loss=5.74e+03, val_loss=4.44[A
Validating:  80%|███████████████████████▏ 

Epoch 62:  96%|▉| 960/1000 [00:35<00:01, 27.01it/s, loss=5.68e+03, val_loss=4.37[A
Validating:  80%|███████████████████████▏     | 160/200 [00:01<00:00, 89.02it/s][A
Epoch 62: 100%|█| 1000/1000 [00:35<00:00, 27.79it/s, loss=5.68e+03, val_loss=4.3[A
Epoch 62: 100%|█| 1000/1000 [00:36<00:00, 27.61it/s, loss=5.68e+03, val_loss=4.3[A
Epoch 63:  80%|▊| 800/1000 [00:23<00:05, 33.48it/s, loss=5.68e+03, val_loss=4.36[A
Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Epoch 63:  84%|▊| 840/1000 [00:24<00:04, 34.81it/s, loss=5.68e+03, val_loss=4.36[A
Validating:  20%|██████                        | 40/200 [00:00<00:01, 86.91it/s][A
Epoch 63:  88%|▉| 880/1000 [00:24<00:03, 35.79it/s, loss=5.68e+03, val_loss=4.36[A
Validating:  40%|████████████                  | 80/200 [00:00<00:01, 86.77it/s][A
Epoch 63:  92%|▉| 920/1000 [00:25<00:02, 36.73it/s, loss=5.68e+03, val_loss=4.36[A
Validating:  60%|█████████████████▍       

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
  f"`.{fn}(ckpt_path=None)` was called without a model."
Restoring states from the checkpoint path at /Users/johnzhou/research/spike-sorting/experiments/vae_10latent/model-v1.ckpt
Loaded model weights from checkpoint at /Users/johnzhou/research/spike-sorting/experiments/vae_10latent/model-v1.ckpt



Validating: 0it [00:00, ?it/s][A
Validating:   0%|                                       | 0/200 [00:00<?, ?it/s][A
Validating:  10%|███                           | 20/200 [00:00<00:02, 63.23it/s][A
Validating:  20%|██████                        | 40/200 [00:00<00:02, 64.74it/s][A
Validating:  30%|█████████                     | 60/200 [00:00<00:01, 70.02it/s][A
Validating:  40%|████████████                  | 80/200 [00:01<00:01, 69.48it/s][A
Validating:  50%|██████████████▌              | 100/200 [00:01<00:01, 71.48it/s][A
Validating:  60%|█████████████████▍           | 120/200 [00:01<00:01, 71.39it/s][A
Validating:  70%|████████████████████▎        | 140/200 [00:01<00:00, 74.15it/s][A
Validating:  80%|███████████████████████▏     | 160/200 [00:02<00:00, 78.15it/s][A
Validating:  90%|██████████████████████████   | 180/200 [00:02<00:00, 81.36it/s][A
Validating: 100%|█████████████████████████████| 200/200 [00:02<00:00, 83.55it/s][A------------------------------------------

NameError: name 'exp_dir' is not defined

### PS-VAE Selecting $\alpha$

In [None]:
psvae_alpha_selection = [OmegaConf.merge(psvae_base_config, c) for c in [
    {
        "name": "psvae_10latent_alpha=1_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=10_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 10,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=25_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 25,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=50_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 50,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=100_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 100,
        "beta": 1
    },
]]

expt_dir = f"{EXPT_DIR}/psvae_alpha_selection"
for config in psvae_alpha_selection:
    system, trainer = train(
        SpikeSortingPSVAE,
        OmegaConf.to_container(config),
        experiment_dir=expt_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{expt_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)

### PS-VAE Selecting $\beta$

In [None]:
psvae_beta_selection = [OmegaConf.merge(psvae_base_config, c) for c in [
    {
        "name": "psvae_10latent_alpha=1_beta=1",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 1
    },
    {
        "name": "psvae_10latent_alpha=1_beta=5",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 5
    },
    {
        "name": "psvae_10latent_alpha=1_beta=10",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 10
    },
    {
        "name": "psvae_10latent_alpha=1_beta=20",
        "model": {
            "latent_dim": 10,
            "label_dim": 4
        },
        "alpha": 1,
        "beta": 20
    },
]]

expt_dir = f"{EXPT_DIR}/psvae_beta_selection"
for config in psvae_beta_selection:
    system, trainer = train(
        SpikeSortingPSVAE,
        OmegaConf.to_container(config),
        experiment_dir=expt_dir,
        checkpoint_name="model")
    val_losses = trainer.validate()
    with open(f"{expt_dir}/{config['name']}/val_losses.json", "w") as f:
        json.dump(val_losses[0], f)