### Imports and initialization

In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset
from torch.utils.data.sampler import WeightedRandomSampler

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import confusion_matrix,f1_score
import numpy as np
import pandas as pd
import pickle

import matplotlib.pyplot as plt


import sys
import os
sys.path.append('../')

from src import Dataset, FiringRates, ISIDistribution
from src.network import MLP
from src.augmentations import slice_data

aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'i', 'i5Htr3a': 'i', 'e4Scnn1a': 'e', 'e4Rorb': 'e',
         'e4other': 'e', 'e4Nr5a1': 'e', 'i6Htr3a': 'i', 'i6Sst': 'i', 'e6Ntsr1': 'e',
         'i23Pvalb': 'i', 'i23Htr3a': 'i', 'i1Htr3a': 'i', 'i4Sst': 'i', 'e5Rbp4': 'e',
         'e5noRbp4': 'e', 'i23Sst': 'i', 'i4Htr3a': 'i', 'i6Pvalb': 'i', 'i5Pvalb': 'i',
         'i4Pvalb': 'i'}   
data_root = '../data/'
dataset = Dataset(data_root, force_process=False) #specify data location

Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.


In [2]:
X = np.array([[1,2,3,4,5,6], [1,2,3,4,5,6]])

In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] ="0" #specify gpu to use

### Split dataset and add correct labels for task

In [4]:
def load_data():
    if dataset.data_source == 'v1':
        if dataset.labels_col == 'pop_name':
            dataset.drop_dead_cells(cutoff=30) #each sample must have atleast 30 spikes
            keepers = ['e5Rbp4', 'e23Cux2', 'i6Pvalb', 'e4Scnn1a', 'i23Pvalb', 'i23Htr3a',
             'e4Rorb', 'e4other', 'i5Pvalb', 'i4Pvalb', 'i23Sst', 'i4Sst', 'e4Nr5a1',
             'i1Htr3a', 'e5noRbp4', 'i6Sst', 'e6Ntsr1'] #cell classes identified by Louis as not being too quiet
            dataset.drop_other_classes(classes_to_keep=keepers)
            if k == '17celltypes':
                pass #all necessary filtering done above
            elif k == '13celltypes':
                aggr_dict = {'e23Cux2': 'e23', 'i5Sst': 'i5Sst', 'i5Htr3a': 'i5Htr3a', 'e4Scnn1a': 'e4', 'e4Rorb': 'e4',
                         'e4other': 'e4', 'e4Nr5a1': 'e4', 'i6Htr3a': 'i6Htr3a', 'i6Sst': 'i6Sst', 'e6Ntsr1': 'e6',
                         'i23Pvalb': 'i23Pvalb', 'i23Htr3a': 'i23Htr3a', 'i1Htr3a': 'i1Htr3a', 'i4Sst': 'i4Sst', 'e5Rbp4': 'e5',
                         'e5noRbp4': 'e5', 'i23Sst': 'i23Sst', 'i4Htr3a': 'i4Htr3a', 'i6Pvalb': 'i6Pvalb', 'i5Pvalb': 'i5Pvalb',
                         'i4Pvalb': 'i4Pvalb'}        
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '11celltypes':
                aggr_dict = {'e23Cux2': 'Cux2', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'Scnn1a', 'e4Rorb': 'Rorb',
                         'e4other': 'other', 'e4Nr5a1': 'Nr5a1', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'Ntsr1',
                         'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a', 'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'Rbp4',
                         'e5noRbp4': 'noRbp4', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a', 'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb',
                         'i4Pvalb': 'Pvalb'}
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '4celltypes':
                aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'e', 'e4Rorb': 'e', 'e4other': 'e', 
                         'e4Nr5a1': 'e', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'e', 'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a',
                         'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'e', 'e5noRbp4': 'e', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a',
                         'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb', 'i4Pvalb': 'Pvalb'}
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '5layers':
                aggr_dict = {'e23Cux2': '23', 'i5Sst': '5', 'i5Htr3a': '5', 'e4Scnn1a': '4', 'e4Rorb': '4', 'e4other': '4',
                         'e4Nr5a1': '4', 'i6Htr3a': '6', 'i6Sst': '6', 'e6Ntsr1': '6', 'i23Pvalb': '23', 'i23Htr3a': '23',
                         'i1Htr3a': '1', 'i4Sst': '4', 'e5Rbp4': '5', 'e5noRbp4': '5', 'i23Sst': '23', 'i4Htr3a': '4',
                         'i6Pvalb': '6', 'i5Pvalb': '5', 'i4Pvalb': '4'}
                dataset.aggregate_cell_classes(aggr_dict)
            elif k == '2celltypes':
                aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'i', 'i5Htr3a': 'i', 'e4Scnn1a': 'e', 'e4Rorb': 'e',
                         'e4other': 'e', 'e4Nr5a1': 'e', 'i6Htr3a': 'i', 'i6Sst': 'i', 'e6Ntsr1': 'e',
                         'i23Pvalb': 'i', 'i23Htr3a': 'i', 'i1Htr3a': 'i', 'i4Sst': 'i', 'e5Rbp4': 'e',
                         'e5noRbp4': 'e', 'i23Sst': 'i', 'i4Htr3a': 'i', 'i6Pvalb': 'i', 'i5Pvalb': 'i',
                         'i4Pvalb': 'i'}    
                dataset.aggregate_cell_classes(aggr_dict)

    # Split into train/val/test sets
    dataset.split_cell_train_val_test(test_size=test_size, val_size=val_size, seed=cell_split_seed)
    
    return dataset

### Run the model

In [5]:
def train(model, device, train_loader, optimizer, epoch, resample=False, window_size=None, log_interval=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if resample:
            # Sample consecutive windows
#             start = np.random.randint(0, int(data.shape[1] // 2))
#             stop = start + int(data.shape[1] // 2)
#             data = data[:,start:stop,:].sum(axis=1)
            # Sample non consecutive windows
            draws = np.random.choice(np.arange(0, data.shape[1]), size=data.shape[1])
            data = data[:, draws, :].sum(axis=1)
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()

def test(model, device, loader, tag, resample=False, labels=dataset.cell_type_labels):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    preds = []
    corrects = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            if tag == 'Train' and resample:
                start = np.random.randint(0, int(data.shape[1] // 2))
                stop = start + int(data.shape[1] // 2)
                data = data[:,start:stop,:].sum(axis=1)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            pred_ = np.ndarray.flatten(pred.cpu().numpy())
            targ_ = target.cpu().numpy()
            preds.append(pred_)
            corrects.append(targ_)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)
    corrects = np.hstack(corrects)
    preds = np.hstack(preds)
    acc = f1_score(corrects,preds,average='macro')
    cm = confusion_matrix(corrects,preds,normalize='true')
    return cm,acc

### Run Hyperparameter Grid

In [6]:
hp_df = pd.read_csv('hp_grid_resampling_v2.csv',sep='\t') #this is a hyperparameter grid
prev_dist_params = np.asarray([0]) #just empty initialization

#FOR EVERY HYPERPARAMETER COMBINATION
for index,row in hp_df.iterrows(): #just skipping the first row because I ran previously
    
    #LOAD THE HYPERPARAMETERS
    print('Loading Parameters')
    hpc = row['hp_idx']
    #bin_size (below) applies to firing rate distributions only 
    k,distribution,cell_split_seed,bin_size = row['k'],row['distribution'],int(row['cell_split_seed']),row['bin_size']
    isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step,fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step = [row['isi_dist_bins_start'],row['isi_dist_bins_stop'],row['isi_dist_bins_step'],row['fr_dist_bins_start'],row['fr_dist_bins_stop'],row['fr_dist_bins_step']]
    lr,n_hiddens = row['lr'],row['n_hiddens']
    n_hiddens = [int(nh) for nh in n_hiddens.rsplit(',')]
    dist_params = [k,distribution,cell_split_seed,bin_size,isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step,fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step,]
    
    if len(list(set([isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step]))) > 1:
        isi_dist_bins = list(np.arange(isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step))
    else:
        isi_dist_bins = isi_dist_bins_start
    if len(list(set([fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step]))) > 1:    
        fr_dist_bins = list(range(int(fr_dist_bins_start),int(fr_dist_bins_stop),int(fr_dist_bins_step)))
    else:
        fr_dist_bins = fr_dist_bins_start    
    
    if 'dropout_p' in hp_df.columns and 'weight_decay' in hp_df.columns:
        dropout_p = row['dropout_p']
        weight_decay = row['weight_decay']
    else:
        dropout_p = 0 #currently overfitting
        weight_decay = 0
        
    if 'preaugmentation_perc1' in hp_df.columns and 'preaugmentation_perc2' in hp_df.columns:
        preaugmentation_percs = [row['preaugmentation_perc1'],row['preaugmentation_perc2']]
    else:
        preaugmentation_percs = [0,0]    
        
    if 'augmentation_perc1' in hp_df.columns and 'augmentation_perc2' in hp_df.columns:
        augmentation_percs = [row['augmentation_perc1'],row['augmentation_perc2']]
    else:
        augmentation_percs = [0,0]
        
    if 'window_size' in hp_df.columns:
        window_size = row['window_size']
        dist_params.append(window_size)
    else : 
        window_size = 0

    
    #SET SOME HARDCODED HYPERPARAMETERS
    batch_size = 1 #confusingly named- this is multiplied by base batch size to specify number of samples per class per batch (i.e. if base_batch_size = 20, num_classes = 5 and batch_size=1 , then 100 total samples per batch)
    base_batch_size = 20 
    n_class = int(''.join(filter(str.isdigit, k))) #determine number of classes from k in hp grid
    cell_sample_seed = 1
    test_size, val_size = 0.2,0.2   
#     epochs = 50 #seems to be enough to converge
    epochs = 100
    threshold = 30 #minimum number of spikes per trial
        
    #IF THIS HYPERPARAMETER COMBO HAS DIFFERENT DATA (FEATURES OR CLASS LABELS) RELOAD AND RENORMALIZE THE DATA
    if np.array_equal(np.asarray(dist_params),np.asarray(prev_dist_params))==False:
        dataset = Dataset(data_root, force_process=False)
        dataset.data_source = 'v1'
        dataset.labels_col = 'pop_name'
        dataset.num_trials_in_window = 33 
        print('Loading Data')
        dataset = load_data()
        print('Defining Transforms')
        fr_transform = FiringRates(window_size=3, bin_size=bin_size) #window_size is trial size in seconds
        if distribution == 'ISI':
            train_transform = ISIDistribution(bins=isi_dist_bins, min_isi=isi_dist_bins_start, 
                                              max_isi=isi_dist_bins_stop, 
                                              augmentation_percs=augmentation_percs, 
                                              preaugmentation_percs=preaugmentation_percs,
                                              window_size=window_size)
            val_test_transform = ISIDistribution(bins=isi_dist_bins, min_isi=isi_dist_bins_start, 
                                                 max_isi=isi_dist_bins_stop, augmentation_percs=[0,0], 
                                                 preaugmentation_percs=preaugmentation_percs,
                                                window_size=window_size) #don't augment test/validation set  
        elif distribution == 'FR':
            data_transform = fr_transform
            train_transform = data_transform
            val_test_transform = data_transform
        elif distribution == 'ISIFR':
            isi_transform = ISIDistribution(bins=isi_dist_bins, min_isi=isi_dist_bins_start, 
                                            max_isi=isi_dist_bins_stop)
            fr_transform = FiringRates(window_size=3, bin_size=bin_size)#window_size is trial size in seconds
            data_transform = ConcatFeats(fr_transform, isi_transform)
            train_transform = data_transform
            val_test_transform = data_transform
        
        X_train_fr, y_train_fr = dataset.get_set('train', transform=fr_transform)
        X_test_fr, y_test_fr = dataset.get_set('test', transform=fr_transform)    
        X_val_fr, y_val_fr = dataset.get_set('val', transform=fr_transform)
        
        #SPLIT TRAIN/TEST/VAL SETS
        #ENFORCE THRESHOLD ON MIN NUMBER OF SPIKES PER TRIAL
        print('Filtering')
        train_mask = X_train_fr.sum(axis=1) > threshold
        test_mask = X_test_fr.sum(axis=1) > threshold
        val_mask = X_val_fr.sum(axis=1) > threshold
        print('Transforming')
        print('...Train')
        X_train, y_train = dataset.get_set('train')
        X_train, y_train = X_train[train_mask], y_train[train_mask]
        if window_size > 0:
            X_train, y_train = slice_data(X_train, y_train, window_size)
        X_train = train_transform(np.array(X_train))

        print('...Test')
        X_test, y_test = dataset.get_set('test')
        X_test, y_test = X_test[test_mask], y_test[test_mask]
        if window_size > 0:
            X_test, y_test = slice_data(X_test, y_test, window_size)
        X_test = val_test_transform(np.array(X_test))
        
        print('...Val')
        X_val, y_val = dataset.get_set('val')
        X_val, y_val = X_val[val_mask], y_val[val_mask]
        if window_size > 0:
            X_val, y_val = slice_data(X_val, y_val, window_size)
        X_val = val_test_transform(np.array(X_val))
      
        #NORMALIZE BASED ON TRAINING SET
        print('Normalizing')
        if len(X_train.shape) == 2:
            train_scaler = StandardScaler()
            train_scaler.fit(X_train.reshape(X_train.shape[0], -1))
            print('...Train')
            X_train = train_scaler.transform(X_train.reshape(X_train.shape[0], -1)).reshape(X_train.shape)
            print('...Test')
            X_test = train_scaler.transform(X_test.reshape(X_test.shape[0], -1)).reshape(X_test.shape)
            print('...Val')
            X_val = train_scaler.transform(X_val.reshape(X_val.shape[0], -1)).reshape(X_val.shape)
        else : 
            scalers = {}
            for i in range(X_train.shape[1]):
                scalers[i] = StandardScaler()
                X_train[:, i, :] = scalers[i].fit_transform(X_train[:, i, :]) 

            for i in range(X_test.shape[1]):
                X_test[:, i, :] = scalers[i].transform(X_test[:, i, :]) 
                
            for i in range(X_val.shape[1]):
                X_val[:, i, :] = scalers[i].transform(X_val[:, i, :]) 

        # For scaling purposes ?
        if window_size > 0:
#             half_window = int(X_train.shape[1] // 2)
#             X_test = X_test[:,0:half_window,:].sum(axis=1)
#             X_val = X_val[0:,0:half_window,:].sum(axis=1)
            X_test = X_test.sum(axis=1)
            X_val = X_val.sum(axis=1)
        
        #RE-DTYPE SETS
        print('Retyping')
        X_train, y_train = torch.FloatTensor(X_train), torch.LongTensor(y_train)
        train_dataset = TensorDataset(X_train,y_train)
        X_test, y_test = torch.FloatTensor(X_test), torch.LongTensor(y_test)
        test_dataset = TensorDataset(X_test,y_test)
        X_val, y_val = torch.FloatTensor(X_val), torch.LongTensor(y_val)
        val_dataset = TensorDataset(X_val,y_val)
        
        #SAMPLE WITH WEIGHTS TO COUNTER CLASS IMBALANCE
        print('Preparing Loaders')
        class_sample_count = np.array([len(np.where(y_train == t)[0]) for t in np.unique(y_train)])
        weight = 1. / class_sample_count
        samples_weight = np.array([weight[t] for t in y_train])
        samples_weight = torch.from_numpy(samples_weight)
        samples_weight = samples_weight.double()
        sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
        train_loader = DataLoader(train_dataset, batch_size=base_batch_size*dataset.num_cell_types, sampler=sampler)
        test_loader = DataLoader(test_dataset, batch_size=base_batch_size*dataset.num_cell_types, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=base_batch_size*dataset.num_cell_types, shuffle=True)
        
#         sys.exit()
        
    #RUN THE MODEL
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLP(input_dims=X_train.shape[-1], n_hiddens=n_hiddens, n_class=n_class, dropout_p=dropout_p).to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    train_accs, val_accs, test_accs = [],[],[] #used to save the f1s across iterations
    resample = True if window_size > 0 else False
    
    print('Running HP Idx',hpc)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5, resample=resample)
        train_cm, train_acc = test(model, device, train_loader, 'Train', resample=resample)
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)
        if epoch%10==0:
            print(train_acc,val_acc,test_acc)
    
    #SAVE THE F1 SCORES ACROSS ITERATIONS AND FINAL CONFUSION MATRIX
    accs_df = pd.DataFrame(list(zip(train_accs, val_accs, test_accs)),columns=['train_acc','test_acc','val_acc'])
    try:
        accs_df.to_csv('hp_grid_f1s_{}.csv'.format(hpc),index=False)
    except:
        print('Error saving F1s:',hpc)
    try:
        np.save('ex_hp_grid_f1s_v2.npy')
    except:
        print('Error saving CM:',hpc)
    
    prev_dist_params = dist_params


Loading Parameters
Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.
Loading Data
Defining Transforms
Filtering
Transforming
...Train
...Test
...Val
Normalizing
Retyping
Preparing Loaders


  X_train, y_train = torch.FloatTensor(X_train), torch.LongTensor(y_train)
  X_test, y_test = torch.FloatTensor(X_test), torch.LongTensor(y_test)
  X_val, y_val = torch.FloatTensor(X_val), torch.LongTensor(y_val)


Running HP Idx 1
0.4557955380390162 0.2668398624543523 0.27259454420306684
0.47784991991628717 0.2780833085187512 0.28381739369872494
0.48157767570720345 0.2907498159097917 0.2952154979817085
0.4654374370041891 0.28446785532095714 0.2933372715012257
0.48208891884366323 0.29171911671421025 0.29836949745876296
0.507406210420959 0.29265803603152074 0.30674476479871393
0.4766898508843505 0.3015126171920059 0.30490050933726165
0.4770410724573675 0.2955601360836763 0.3009556785224405
0.4891569172324222 0.29719834033349196 0.30663273185973466
0.48563635082852097 0.29845426449023 0.30459279428539376
Error saving CM: 1
Loading Parameters
Found processed pickle. Loading from '../data/processed/v1_dataset.pkl'.
Loading Data
Defining Transforms
Filtering
Transforming
...Train
...Test
...Val
Normalizing
Retyping


  X_train, y_train = torch.FloatTensor(X_train), torch.LongTensor(y_train)
  X_test, y_test = torch.FloatTensor(X_test), torch.LongTensor(y_test)
  X_val, y_val = torch.FloatTensor(X_val), torch.LongTensor(y_val)


Preparing Loaders
Running HP Idx 2
0.5624641617786494 0.2820787636306288 0.28295056875068364
0.5779451163265201 0.29394994266810487 0.3034234961850877
0.595180034619076 0.28916482531676485 0.2935006841556329
0.6130403713556898 0.31035343947932803 0.3135261110364766
0.615385667870344 0.29852017826722355 0.3035551638383056
0.623289920512903 0.316710063570498 0.3192517976540746
0.6256559269205639 0.3102476091323883 0.3048186700285861
0.6201324725476839 0.3214158727096928 0.31530134442335456


KeyboardInterrupt: 