# Full Analysis Example

### Imports

In [2]:
import json
import os
from pathlib import Path

import numpy as np
from omegaconf import OmegaConf
import pytorch_lightning as pl
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import TQDMProgressBar, ModelCheckpoint, EarlyStopping
from pl_bolts.callbacks import PrintTableMetricsCallback
import torch

from src.data.preprocess_templates import plot_templates, get_max_chan_temps, take_channel_range, localize_wfs
from src.data.make_datasets import (
    featurization_dataset, positional_invariance_dataset, clustering_dataset, 
    time_center_templates, normalize_inputs
)
from src.models.ae import *
from src.models.vae import *
from src.models.spike_vaes import *
from src.models.train import train

### Load Data

In [4]:
# Data locations. The defaults are the original author's local checkout; set the
# SPIKE_SORTING_RAW_DATA_DIR / SPIKE_SORTING_PROCESS_DATA_DIR environment variables
# to run this notebook on another machine without editing the cell.
RAW_DATA_DIR = os.environ.get(
    "SPIKE_SORTING_RAW_DATA_DIR",
    "/Users/johnzhou/research/spike-sorting/data/raw",
)
PROCESS_DATA_DIR = os.environ.get(
    "SPIKE_SORTING_PROCESS_DATA_DIR",
    "/Users/johnzhou/research/spike-sorting/data/processed",
)

# Cleaned and denoised templates
templates_fname = "templates_yass.npy"
templates_fpath = os.path.join(RAW_DATA_DIR, templates_fname)
templates = np.load(templates_fpath)

# Probe geometry
geom_fname = "np2_channel_map.npy"
geom_fpath = os.path.join(RAW_DATA_DIR, geom_fname)
geom_array = np.load(geom_fpath)
# Keep only the first 20 rows of the channel map.
# NOTE(review): 20 equals the n_channels value set in the preprocessing
# configuration cell; keep the two in sync if either changes.
channels_pos = geom_array[:20]

## Data Preprocessing

In [5]:
# Configuration shared by the dataset-generation cells below.
# NOTE(review): a / loc / scale are forwarded unchanged to the dataset builders;
# presumably they parameterize a sampling distribution; confirm in
# src.data.make_datasets.
a = 3
loc = 100
scale = 500
n_channels = 20      # forwarded to take_channel_range as n_channels_loc
n_samples = 10_000   # forwarded to featurization_dataset as n_samples

### Identify and Remove Bad Templates

In [6]:
# Unpack the template array dimensions and report them before any filtering.
num_templates, duration, num_channels = templates.shape
print(
    "{} contains {} templates for {} timesteps across {} channels.".format(
        templates_fname, *templates.shape
    )
)
max_chan_temp = get_max_chan_temps(templates)
# Plotting is disabled here; uncomment the next line to inspect the templates visually.
# plot_templates(templates, max_chan_temp, n_channels=n_channels)

templates_yass.npy contains 170 templates for 121 timesteps across 384 channels.


In [7]:
# Indices (along axis 0, the template axis) of the templates flagged as bad.
bad_template_idxs = [
    3, 6, 27, 29, 32, 35, 36, 56, 57, 58, 59, 62, 63, 64, 74, 78, 79, 80, 85, 91, 92,
    101, 107, 109, 110, 111, 118, 119, 121, 145, 150, 151, 152, 157, 159, 164, 165, 169,
]

# Drop the bad templates; np.delete returns a new array and leaves `templates` untouched.
good_templates = np.delete(templates, bad_template_idxs, axis=0)

# Restrict each remaining template to a channel range, then localize the
# resulting per-channel PTP values against the full probe geometry.
templates_chans, templates_ptp_chans = take_channel_range(
    good_templates, n_channels_loc=n_channels
)
positions_templates = localize_wfs(templates_ptp_chans, geom_array)

100%|████████████████████████████████████████| 132/132 [00:01<00:00, 110.54it/s]


### Produce Featurization Dataset

In [7]:
# Build the featurization dataset from the channel-restricted templates and
# their localized positions; the builder returns four arrays.
(
    featurize_templates,
    featurize_predicted_ptps,
    featurize_positions,
    featurize_idx_units,
) = featurization_dataset(
    templates_chans,
    positions_templates,
    channels_pos,
    a,
    loc,
    scale,
    n_samples=n_samples,
)

100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2575.01it/s]


### Produce Positional Invariance Analysis Dataset

In [8]:
# Settings forwarded to positional_invariance_dataset as vary_feature / n_samples.
vary_feature = "x"
vary_samples = 100

(
    invar_templates,
    invar_predicted_ptps,
    invar_positions,
    invar_idx_units,
) = positional_invariance_dataset(
    templates_chans,
    positions_templates,
    channels_pos,
    a,
    loc,
    scale,
    vary_feature=vary_feature,
    n_samples=vary_samples,
)

100%|███████████████████████████████████████| 100/100 [00:00<00:00, 1828.16it/s]


### Produce Clustering Dataset

In [9]:
# Settings forwarded to clustering_dataset as n_clusters / num_samples_per_cluster.
num_clusters = 20
num_samples_per_cluster = 100

(
    cluster_templates,
    cluster_predicted_ptps,
    cluster_positions,
    cluster_idx_units,
) = clustering_dataset(
    templates_chans,
    positions_templates,
    channels_pos,
    a,
    loc,
    scale,
    n_clusters=num_clusters,
    num_samples_per_cluster=num_samples_per_cluster,
)

100%|███████████████████████████████████████████| 20/20 [00:00<00:00, 27.27it/s]


In [1]:
def saver(func):
    """Decorate a dataset builder so that calling it reports the arrays it produced.

    Args:
        func: Callable returning a 4-tuple ``(templates, predicted_ptps,
            positions, idx_units)`` whose items expose a ``.shape`` attribute.

    Returns:
        A wrapper that calls ``func`` with the given positional and keyword
        arguments, prints the shape of each returned array, and returns
        ``None``. The ``np.save`` calls are currently commented out, so nothing
        is written to disk.
    """
    import functools  # local import so this cell does not depend on the imports cell

    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper_saver(*args, **kwargs):
        # Forward keyword arguments as well: they used to be dropped, so a
        # caller's n_samples=... was silently replaced by func's default.
        templates, predicted_ptps, positions, idx_units = func(*args, **kwargs)
        print("Saving templates, array of size: {}".format(templates.shape))
        # np.save("{}/templates.npy".format(PROCESS_DATA_DIR), templates)
        print("Saving predicted PTPs, array of size: {}".format(predicted_ptps.shape))
        # np.save("{}/predicted_ptps.npy".format(PROCESS_DATA_DIR), predicted_ptps)
        print("Saving positions, array of size: {}".format(positions.shape))
        # np.save("{}/positions.npy".format(PROCESS_DATA_DIR), positions)
        print("Saving unit indices, array of size: {}".format(idx_units.shape))
        # np.save("{}/unit_idxs.npy".format(PROCESS_DATA_DIR), idx_units)
    return wrapper_saver

In [17]:
import scipy.stats as stats
from tqdm import tqdm

@saver
def save_featurization_dataset(templates, positions_templates, channels_pos, a, loc, scale, n_samples=10000):
    """Build the featurization dataset from the given arguments.

    All arguments are forwarded unchanged to ``featurization_dataset``; the
    ``@saver`` decorator then prints the shape of each returned array.

    Returns:
        The 4-tuple returned by ``featurization_dataset`` (consumed by the
        ``@saver`` wrapper).
    """
    # Use the `templates` argument: the body previously read the notebook-global
    # `templates_chans`, so whatever the caller passed here was ignored.
    return featurization_dataset(templates, positions_templates, channels_pos, a, loc, scale,
                                 n_samples=n_samples)

In [18]:
# Run the decorated builder; it prints the shape of each generated array.
save_featurization_dataset(
    templates_chans,
    positions_templates,
    channels_pos,
    a,
    loc,
    scale,
    n_samples=n_samples,
)

100%|███████████████████████████████████| 10000/10000 [00:03<00:00, 2508.68it/s]


Saving templates, array of size: (10000, 121, 20)
Saving predicted PTPs, array of size: (10000, 20)
Saving positions, array of size: (4, 10000)
Saving unit indices, array of size: (10000,)


## Model Training