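# Full Analysis Example

End-to-end example: load the YASS templates and the probe channel map, remove hand-flagged bad templates, localize the remaining templates, and generate the featurization, positional-invariance, and clustering datasets that the model-training section consumes.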

### Imports

In [1]:
import json
import os
from pathlib import Path

import numpy as np
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping
from pl_bolts.callbacks import PrintTableMetricsCallback
import torch

from src.data.preprocess_templates import plot_templates, get_max_chan_temps, take_channel_range, localize_wfs
from src.data.make_datasets import (
    featurization_dataset, positional_invariance_dataset, clustering_dataset, 
    time_center_templates, normalize_inputs
)
from src.models.ae import *
from src.models.vae import *
from src.models.spike_vaes import *
from src.models.train import train

### Load Data

In [2]:
RAW_DATA_DIR = "/Users/johnzhou/research/spike-sorting/data/raw"
PROCESS_DATA_DIR = "/Users/johnzhou/research/spike-sorting/data/processed"

# Cleaned and denoised templates
templates_fname = "templates_yass.npy"
templates_fpath = os.path.join(RAW_DATA_DIR, templates_fname)
templates = np.load(templates_fpath)

# Probe geometry
geom_fname = "np2_channel_map.npy"
geom_fpath = os.path.join(RAW_DATA_DIR, geom_fname)
geom_array = np.load(geom_fpath)
channels_pos = geom_array[:20]

## Data Preprocessing

In [3]:
a, loc, scale = 3, 100, 500
n_channels = 20
n_samples = 10000

### Identify and Remove Bad Templates

In [4]:
num_templates, duration, num_channels = templates.shape
print("{} contains {} templates for {} timesteps across {} channels.".format(
    templates_fname, num_templates, duration, num_channels))
max_chan_temp = get_max_chan_temps(templates)
# plot_templates(templates, max_chan_temp, n_channels=n_channels)

templates_yass.npy contains 170 templates for 121 timesteps across 384 channels.


In [5]:
bad_template_idxs = [3, 6, 27, 29, 32, 35, 36, 56, 57, 58, 59, 62, 63, 64, 74, 78, 79, 80, 85, 91, 92, \
    101, 107, 109, 110, 111, 118, 119, 121, 145, 150, 151, 152, 157, 159, 164, 165, 169]
good_templates = np.delete(templates, bad_template_idxs, axis=0)
templates_chans, templates_ptp_chans = take_channel_range(good_templates, n_channels_loc=n_channels)
positions_templates = localize_wfs(templates_ptp_chans, geom_array)

100%|████████████████████████████████████████| 132/132 [00:01<00:00, 102.32it/s]


### Produce Datasets

In [6]:
from src.data.make_datasets import (
    featurization_dataset,
    positional_invariance_dataset,
    clustering_dataset
)

In [7]:
# Featurization dataset to train VAEs/PCA
featurize_experiment_name = "featurization"
featurization_dataset(
    templates_chans, positions_templates, channels_pos, a, loc, scale, n_samples=n_samples, 
    experiment_data_dir=PROCESS_DATA_DIR, experiment_name=featurize_experiment_name
)

# Positional invariance analysis dataset for visualization
position_features = ["x", "z", "y", "alpha"]
for vary_feature in position_features:
    vary_experiment_name = "{}_invariance_analysis".format(vary_feature)
    vary_samples = 100
    positional_invariance_dataset(
        templates_chans, positions_templates, channels_pos, a, loc, scale, vary_feature=vary_feature, 
        n_samples=vary_samples, experiment_data_dir=PROCESS_DATA_DIR, experiment_name=vary_experiment_name
    )

# Clustering dataset for feature evaluation
num_clusters = 20
num_samples_per_cluster = 100
cluster_experiment_name = "{}_clusters".format(num_clusters)
clustering_dataset(templates, positions_templates, channels_pos, a, loc, scale, n_clusters=num_clusters, 
                   num_samples_per_cluster=num_samples_per_cluster, experiment_data_dir=PROCESS_DATA_DIR, 
                   experiment_name=cluster_experiment_name)

100%|███████████████████████████████████| 10000/10000 [00:04<00:00, 2019.46it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization, array of size: (10000, 121, 20)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization, array of size: (10000, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization, array of size: (4, 10000)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/featurization, array of size: (10000,)


100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1354.02it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/x_invariance_analysis, array of size: (100, 121, 20)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/x_invariance_analysis, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/x_invariance_analysis, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/x_invariance_analysis, array of size: (100,)


100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1150.25it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/z_invariance_analysis, array of size: (100, 121, 20)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/z_invariance_analysis, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/z_invariance_analysis, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/z_invariance_analysis, array of size: (100,)


100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1179.18it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/y_invariance_analysis, array of size: (100, 121, 20)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/y_invariance_analysis, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/y_invariance_analysis, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/y_invariance_analysis, array of size: (100,)


100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1064.93it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/alpha_invariance_analysis, array of size: (100, 121, 20)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/alpha_invariance_analysis, array of size: (100, 20)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/alpha_invariance_analysis, array of size: (4, 100)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/alpha_invariance_analysis, array of size: (100,)


100%|███████████████████████████████████████████| 20/20 [00:01<00:00, 16.22it/s]


Saving templates to folder /Users/johnzhou/research/spike-sorting/data/processed/20_clusters, array of size: (2000, 121, 384)
Saving predicted PTPs to folder /Users/johnzhou/research/spike-sorting/data/processed/20_clusters, array of size: (2000, 384)
Saving positions to folder /Users/johnzhou/research/spike-sorting/data/processed/20_clusters, array of size: (4, 2000)
Saving unit indices to folder /Users/johnzhou/research/spike-sorting/data/processed/20_clusters, array of size: (2000,)


## Model Training