### Imports and initialization

In [2]:
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset, TensorDataset
from torch.utils.data.sampler import WeightedRandomSampler

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import confusion_matrix,f1_score
import numpy as np
import pandas as pd
import pickle

import matplotlib.pyplot as plt


import sys
import os
sys.path.append('../')

from src import Dataset, FiringRates, ISIDistribution, ConcatFeats
from src.network import MLP

aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'i', 'i5Htr3a': 'i', 'e4Scnn1a': 'e', 'e4Rorb': 'e',
         'e4other': 'e', 'e4Nr5a1': 'e', 'i6Htr3a': 'i', 'i6Sst': 'i', 'e6Ntsr1': 'e',
         'i23Pvalb': 'i', 'i23Htr3a': 'i', 'i1Htr3a': 'i', 'i4Sst': 'i', 'e5Rbp4': 'e',
         'e5noRbp4': 'e', 'i23Sst': 'i', 'i4Htr3a': 'i', 'i6Pvalb': 'i', 'i5Pvalb': 'i',
         'i4Pvalb': 'i'}        
dataset = Dataset('../data', force_process=False)

Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.


In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] ="0"

### Split dataset and add correct labels for task

In [4]:
def load_data():
    if dataset.data_source == 'v1':
        if dataset.labels_col == 'pop_name':
            dataset.drop_dead_cells(cutoff=30)
            keepers = ['e5Rbp4', 'e23Cux2', 'i6Pvalb', 'e4Scnn1a', 'i23Pvalb', 'i23Htr3a',
             'e4Rorb', 'e4other', 'i5Pvalb', 'i4Pvalb', 'i23Sst', 'i4Sst', 'e4Nr5a1',
             'i1Htr3a', 'e5noRbp4', 'i6Sst', 'e6Ntsr1']
            dataset.drop_other_classes(classes_to_keep=keepers)
            print(np.unique(dataset.cell_type_ids))
            if k == '17celltypes':
                pass #all filtering done above
            elif k == '13celltypes':
                aggr_dict = {'e23Cux2': 'e23', 'i5Sst': 'i5Sst', 'i5Htr3a': 'i5Htr3a', 'e4Scnn1a': 'e4', 'e4Rorb': 'e4',
                         'e4other': 'e4', 'e4Nr5a1': 'e4', 'i6Htr3a': 'i6Htr3a', 'i6Sst': 'i6Sst', 'e6Ntsr1': 'e6',
                         'i23Pvalb': 'i23Pvalb', 'i23Htr3a': 'i23Htr3a', 'i1Htr3a': 'i1Htr3a', 'i4Sst': 'i4Sst', 'e5Rbp4': 'e5',
                         'e5noRbp4': 'e5', 'i23Sst': 'i23Sst', 'i4Htr3a': 'i4Htr3a', 'i6Pvalb': 'i6Pvalb', 'i5Pvalb': 'i5Pvalb',
                         'i4Pvalb': 'i4Pvalb'}        
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '11celltypes':
                aggr_dict = {'e23Cux2': 'Cux2', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'Scnn1a', 'e4Rorb': 'Rorb',
                         'e4other': 'other', 'e4Nr5a1': 'Nr5a1', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'Ntsr1',
                         'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a', 'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'Rbp4',
                         'e5noRbp4': 'noRbp4', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a', 'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb',
                         'i4Pvalb': 'Pvalb'}
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '4celltypes':
                aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'e', 'e4Rorb': 'e', 'e4other': 'e', 
                         'e4Nr5a1': 'e', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'e', 'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a',
                         'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'e', 'e5noRbp4': 'e', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a',
                         'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb', 'i4Pvalb': 'Pvalb'}
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '5layers':
                aggr_dict = {'e23Cux2': '23', 'i5Sst': '5', 'i5Htr3a': '5', 'e4Scnn1a': '4', 'e4Rorb': '4', 'e4other': '4',
                         'e4Nr5a1': '4', 'i6Htr3a': '6', 'i6Sst': '6', 'e6Ntsr1': '6', 'i23Pvalb': '23', 'i23Htr3a': '23',
                         'i1Htr3a': '1', 'i4Sst': '4', 'e5Rbp4': '5', 'e5noRbp4': '5', 'i23Sst': '23', 'i4Htr3a': '4',
                         'i6Pvalb': '6', 'i5Pvalb': '5', 'i4Pvalb': '4'}
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '2celltypes':
                aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'i', 'i5Htr3a': 'i', 'e4Scnn1a': 'e', 'e4Rorb': 'e',
                         'e4other': 'e', 'e4Nr5a1': 'e', 'i6Htr3a': 'i', 'i6Sst': 'i', 'e6Ntsr1': 'e',
                         'i23Pvalb': 'i', 'i23Htr3a': 'i', 'i1Htr3a': 'i', 'i4Sst': 'i', 'e5Rbp4': 'e',
                         'e5noRbp4': 'e', 'i23Sst': 'i', 'i4Htr3a': 'i', 'i6Pvalb': 'i', 'i5Pvalb': 'i',
                         'i4Pvalb': 'i'}    
                dataset.aggregate_cell_classes(aggr_dict)

    # Split into train/val/test sets
    dataset.split_cell_train_val_test(test_size=test_size, val_size=val_size, seed=cell_split_seed)
    #dataset.split_trial_train_val_test(test_size=0.2, val_size=0.2, temp=True, seed=1234)

    return dataset

### Run the model

In [5]:
def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        '''
        if log_interval and batch_idx % log_interval == 0:
            
            print('Train Epoch: {} [({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 100. * batch_idx / len(train_loader), loss.item()))
                '''

def test(model, device, loader, tag, labels=dataset.cell_type_labels):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    preds = []
    corrects = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            pred_ = np.ndarray.flatten(pred.cpu().numpy())
            targ_ = target.cpu().numpy()
            preds.append(pred_)
            corrects.append(targ_)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)
    '''
    print('{} set: Accuracy: {}/{} ({:.0f}%)'.format(tag,
        correct, total,
        100. * correct / total))
        '''
    corrects = np.hstack(corrects)
    preds = np.hstack(preds)
    acc = f1_score(corrects,preds,average='macro')
    #print('F1:',acc)
    cm = confusion_matrix(corrects,preds,normalize='true')
    return cm,acc

In [6]:
def aug_viz(X_train,y_train,aug,num_samples=20):
    y_sort_ord = np.argsort(y_train)
    sorted_Xt = X_train[y_sort_ord]
    sorted_yt = y_train[y_sort_ord]
    uniq_inds = np.unique(sorted_yt,return_index=True)[1]
    for idx, uniq_idx in enumerate(uniq_inds):
        sampled_Xt = sorted_Xt[uniq_idx:uniq_idx+num_samples]
        np.save('{}_ct{}_aug{}_{}bins_{}samples.npy'.format(k,str(idx),str(aug),str(sampled_Xt.shape[1]),str(num_samples)),sampled_Xt)

### Run Hyperparameter Grid

In [7]:
hp_df = pd.read_csv('hp_regrid111_v2_augviz.csv') #this is a hyperparameter grid
prev_dist_params = np.asarray([0]) #just empty initialization
aug = '1'
#FOR EVERY HYPERPARAMETER COMBINATION
for index,row in hp_df.iterrows(): #just skipping the first row because I ran previously
    
    #LOAD THE HYPERPARAMETERS
    hpc = row['hp_idx']
    k,distribution,cell_split_seed,bin_size = row['k'],row['distribution'],int(row['cell_split_seed']),row['bin_size']
    isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step,fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step = [row['isi_dist_bins_start'],row['isi_dist_bins_stop'],row['isi_dist_bins_step'],row['fr_dist_bins_start'],row['fr_dist_bins_stop'],row['fr_dist_bins_step']]
    lr,n_hiddens = row['lr'],row['n_hiddens']
    n_hiddens = [int(nh) for nh in n_hiddens.rsplit(',')]
    dist_params = [k,distribution,cell_split_seed,bin_size,isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step,fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step,]
    if len(list(set([isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step]))) > 1:
        isi_dist_bins = list(np.arange(isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step))
    else:
        isi_dist_bins = isi_dist_bins_start
    if len(list(set([fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step]))) > 1:    
        fr_dist_bins = list(range(int(fr_dist_bins_start),int(fr_dist_bins_stop),int(fr_dist_bins_step)))
    else:
        fr_dist_bins = fr_dist_bins_start    
    #SET SOME HARDCODED HYPERPARAMETERS
    batch_size = 1
    n_class = int(''.join(filter(str.isdigit, k)))
    cell_sample_seed = 1
    test_size, val_size = 0.2,0.2 
    
    if 'dropout_p' in hp_df.columns and 'weight_decay' in hp_df.columns:
        dropout_p = row['dropout_p']
        weight_decay = row['weight_decay']
    else:
        dropout_p = 0 #currently overfitting
        weight_decay = 0
    


    if 'augmentation_perc1' in hp_df.columns and 'augmentation_perc2' in hp_df.columns:
        augmentation_percs = [row['augmentation_perc1'],row['augmentation_perc2']]
    else:
        augmentation_percs = [0,0]

        
    epochs = 50 #seems to be enough to converge
    threshold = 30 #minimum number of spikes per trial
    
    #IF THIS HYPERPARAMETER COMBO HAS DIFFERENT DATA (FEATURES OR CLASS LABELS) RELOAD AND RENORMALIZE THE DATA
    if dist_params!=prev_dist_params:

        print(1)
        dataset = Dataset('../data', force_process=False)
        dataset.data_source = 'v1'
        dataset.labels_col = 'pop_name'
        dataset.num_trials_in_window = 33
        print('Load')
        dataset = load_data()
        print(dataset.cell_type_ids)

        fr_transform = FiringRates(window_size=3, bin_size=bin_size)
        if distribution == 'ISI':
            aug1_train_transform = ISIDistribution(bins=isi_dist_bins, min_isi=0, max_isi=0.4, augmentation_percs=[1,0])
            aug2_train_transform = ISIDistribution(bins=isi_dist_bins, min_isi=0, max_isi=0.4, augmentation_percs=[0,1])
            aug0_train_transform = ISIDistribution(bins=isi_dist_bins, min_isi=0, max_isi=0.4, augmentation_percs=[0,0])
        elif distribution == 'FR':
            data_transform = fr_transform
            train_transform = data_transform
            val_test_transform = data_transform
        elif distribution == 'ISIFR':
            isi_transform = ISIDistribution(bins=isi_dist_bins, min_isi=0, max_isi=0.4)
            fr_transform = FiringRates(window_size=3, bin_size=bin_size)
            data_transform = ConcatFeats(fr_transform, isi_transform)
            train_transform = data_transform
            val_test_transform = data_transform
        print(2)
        print(2.1)
        X_train_fr, y_train_fr = dataset.get_set('train', transform=fr_transform)
        print(2.4)
        train_mask = X_train_fr.sum(axis=1) > threshold
        print(3)
        print(3.1)
        X_train, y_train = dataset.get_set('train',transform=aug0_train_transform)
        X_train, y_train = X_train[train_mask], y_train[train_mask]

        aug_viz(X_train,y_train,aug='0',num_samples=20)
        print(3.2)
        X_train, y_train = dataset.get_set('train',transform=aug1_train_transform)
        X_train, y_train = X_train[train_mask], y_train[train_mask]
        aug_viz(X_train,y_train,aug='1',num_samples=20)
        print(3.3)
        X_train, y_train = dataset.get_set('train',transform=aug2_train_transform)
        X_train, y_train = X_train[train_mask], y_train[train_mask]
        aug_viz(X_train,y_train,aug='2',num_samples=20)

  if dist_params!=prev_dist_params:


1
Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.
Load
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16.]
[ 0.  0.  0. ... 16. 16. 16.]
2
2.1
2.4
3
3.1
3.2
3.3


  if dist_params!=prev_dist_params:


1
Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.
Load
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16.]
[ 0.  0.  0. ... 16. 16. 16.]
2
2.1
2.4
3
3.1
3.2
3.3


  if dist_params!=prev_dist_params:


1
Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.
Load
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16.]
[ 0.  0.  0. ... 16. 16. 16.]
2
2.1
2.4
3
3.1
3.2
3.3


  if dist_params!=prev_dist_params:


1
Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.
Load
[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16.]
[ 0.  0.  0. ... 16. 16. 16.]
2
2.1
2.4
3
3.1
3.2
3.3
