### Imports and initialization

In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset

from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.metrics import confusion_matrix,f1_score
import numpy as np
import pandas as pd
import pickle

import matplotlib.pyplot as plt


import sys
import os
sys.path.append('../')

from src import Dataset, ISIDistributionDataset, EdgeDistributionDataset, FRDistributionDataset, ISIFRDistributionDataset
from src.network import MLP

# Module-level mapping from each V1 population name to its broad class:
# 'e' for names starting with "e", 'i' for names starting with "i".
aggr_dict = {
    'e23Cux2': 'e',
    'i5Sst': 'i',
    'i5Htr3a': 'i',
    'e4Scnn1a': 'e',
    'e4Rorb': 'e',
    'e4other': 'e',
    'e4Nr5a1': 'e',
    'i6Htr3a': 'i',
    'i6Sst': 'i',
    'e6Ntsr1': 'e',
    'i23Pvalb': 'i',
    'i23Htr3a': 'i',
    'i1Htr3a': 'i',
    'i4Sst': 'i',
    'e5Rbp4': 'e',
    'e5noRbp4': 'e',
    'i23Sst': 'i',
    'i4Htr3a': 'i',
    'i6Pvalb': 'i',
    'i5Pvalb': 'i',
    'i4Pvalb': 'i',
}

# Shared dataset handle used by the cells below.
# NOTE(review): force_process=False presumably reuses an already processed
# copy under '../data' instead of rebuilding it -- confirm against src.Dataset.
dataset = Dataset('../data', force_process=False)

### Split dataset and add correct labels for task

In [3]:
def _split_pop_name(pop_name):
    """Split a population name such as ``'e23Cux2'`` into ``('e', '23', 'Cux2')``."""
    body = pop_name[1:]
    marker = body.lstrip('0123456789')
    layer = body[:len(body) - len(marker)]
    return pop_name[0], layer, marker


def _build_aggregation(task):
    """Return the population -> label mapping for ``task``.

    Returns ``None`` for ``'17celltypes'`` (no merging needed) and raises
    ``ValueError`` for a task name that is not one of the supported ones.
    """
    # Key order is kept identical to the hand-written tables this replaces.
    populations = ['e23Cux2', 'i5Sst', 'i5Htr3a', 'e4Scnn1a', 'e4Rorb', 'e4other',
                   'e4Nr5a1', 'i6Htr3a', 'i6Sst', 'e6Ntsr1', 'i23Pvalb', 'i23Htr3a',
                   'i1Htr3a', 'i4Sst', 'e5Rbp4', 'e5noRbp4', 'i23Sst', 'i4Htr3a',
                   'i6Pvalb', 'i5Pvalb', 'i4Pvalb']
    relabel = {
        # excitatory cells merge per layer; inhibitory cells keep their full name
        '13celltypes': lambda kind, layer, marker: kind + layer if kind == 'e' else kind + layer + marker,
        # marker only, layers and e/i prefix dropped
        '11celltypes': lambda kind, layer, marker: marker,
        # one excitatory class plus the three inhibitory markers
        '4celltypes': lambda kind, layer, marker: 'e' if kind == 'e' else marker,
        # layer only
        '5layers': lambda kind, layer, marker: layer,
        # excitatory vs inhibitory
        '2celltypes': lambda kind, layer, marker: kind,
    }
    if task == '17celltypes':
        return None  # the 17 kept populations are already the target classes
    if task not in relabel:
        raise ValueError("Unknown task k={!r}; expected one of {}".format(
            task, ['17celltypes'] + list(relabel)))
    return {name: relabel[task](*_split_pop_name(name)) for name in populations}


def load_data():
    """Filter, relabel, split and bin the module-level ``dataset`` for the current task.

    Reads the globals ``dataset``, ``k``, ``test_size``, ``val_size``,
    ``cell_split_seed`` and ``bin_size``; they must be set before calling.
    The dataset is modified in place through its own methods.

    For ``data_source == 'v1'`` with ``labels_col == 'pop_name'``:
      1. ``drop_dead_cells(cutoff=30)`` is applied,
      2. only the 17 populations listed in ``keepers`` are kept,
      3. populations are merged according to ``k``.

    Afterwards (for any data source) the cells are split into train/val/test
    and the bin size is set.

    Returns:
        The same ``dataset`` object.

    Raises:
        ValueError: if ``k`` is not a supported task name. ``n_class`` is
            derived from ``k`` elsewhere in the notebook, so silently skipping
            the aggregation would leave labels and model output size out of
            sync. The check runs before the dataset is touched.
    """
    if dataset.data_source == 'v1' and dataset.labels_col == 'pop_name':
        aggregation = _build_aggregation(k)  # validates k before any mutation

        dataset.drop_dead_cells(cutoff=30)
        keepers = ['e5Rbp4', 'e23Cux2', 'i6Pvalb', 'e4Scnn1a', 'i23Pvalb', 'i23Htr3a',
                   'e4Rorb', 'e4other', 'i5Pvalb', 'i4Pvalb', 'i23Sst', 'i4Sst', 'e4Nr5a1',
                   'i1Htr3a', 'e5noRbp4', 'i6Sst', 'e6Ntsr1']
        dataset.drop_other_classes(classes_to_keep=keepers)

        if aggregation is not None:
            dataset.aggregate_cell_classes(aggregation)

    # Split into train/val/test sets (by cell)
    dataset.split_cell_train_val_test(test_size=test_size, val_size=val_size, seed=cell_split_seed)

    # Binning; bin_size comes from the hyperparameter grid
    # (the original comment states the unit is seconds).
    dataset.set_bining_parameters(bin_size=bin_size)
    return dataset

### Normalize data and prepare dataloaders

In [4]:
def get_scaler():
    """Fit the train-set scaler(s) and build train/val/test datasets and loaders.

    Reads the globals ``distribution``, ``dataset``, ``sampler``,
    ``isi_dist_bins``, ``fr_dist_bins``, ``cell_sample_seed`` and
    ``batch_size``.

    Supported ``distribution`` values:
      * ``'ISIFR'``  -> ``ISIFRDistributionDataset`` with separate ISI and FR scalers
      * any other value containing ``'ISI'`` (e.g. ``'ISI'``, ``'log_ISI'``)
        -> ``ISIDistributionDataset``; ``'log_ISI'`` fits the scaler on the
        ``'log_interspike_interval'`` transform
      * ``'FR'``     -> ``FRDistributionDataset``

    The validation and test datasets receive ``cell_random_seed=cell_sample_seed``;
    the training dataset does not.

    Returns:
        ``(train_scaler, train_loader, val_loader, test_loader,
        train_dataset, val_dataset, test_dataset)``. For ``'ISIFR'``,
        ``train_scaler`` is the list ``[isi_scaler, fr_scaler]``.

    Raises:
        ValueError: if ``distribution`` matches none of the supported values
            (previously this surfaced as an ``UnboundLocalError`` further down).
    """
    if distribution == 'ISIFR':
        fr_train_scaler, fr_bins = get_train_scaler(dataset, sampler=sampler, transform='firing_rate', bins=fr_dist_bins)
        isi_train_scaler, isi_bins = get_train_scaler(dataset, sampler=sampler, transform='interspike_interval', bins=isi_dist_bins)
        train_scaler = [isi_train_scaler, fr_train_scaler]
        dataset_cls = ISIFRDistributionDataset
        shared_kwargs = dict(isi_bins=isi_bins, fr_bins=fr_bins,
                             isi_scaler=isi_train_scaler, fr_scaler=fr_train_scaler)

    elif 'ISI' in distribution:
        transform = 'log_interspike_interval' if distribution == 'log_ISI' else 'interspike_interval'
        # NOTE(review): the bin edges returned by get_train_scaler are discarded
        # here and the dataset gets `isi_dist_bins` instead (unlike the ISIFR
        # branch). That differs only if `isi_dist_bins` is an int (adaptive
        # bins) -- confirm this is intended.
        train_scaler, _ = get_train_scaler(dataset, sampler=sampler, transform=transform, bins=isi_dist_bins)
        dataset_cls = ISIDistributionDataset
        shared_kwargs = dict(min_isi=0, max_isi=0.4, bins=isi_dist_bins, scaler=train_scaler)

    elif distribution == 'FR':
        # NOTE(review): returned bin edges are discarded here as well; see above.
        train_scaler, _ = get_train_scaler(dataset, sampler=sampler, transform='firing_rate', bins=fr_dist_bins)
        dataset_cls = FRDistributionDataset
        shared_kwargs = dict(bins=fr_dist_bins, scaler=train_scaler)

    else:
        raise ValueError(
            "Unsupported distribution {!r}; expected 'ISIFR', 'FR' or a name containing 'ISI'".format(distribution))

    # Create Pytorch datasets
    train_dataset = dataset_cls(dataset, mode='train', sampler=sampler, **shared_kwargs)
    # fix population for validation set and test set (they will be different of course)
    val_dataset = dataset_cls(dataset, mode='val', sampler=sampler,
                              cell_random_seed=cell_sample_seed, **shared_kwargs)
    test_dataset = dataset_cls(dataset, mode='test', sampler=sampler,
                               cell_random_seed=cell_sample_seed, **shared_kwargs)

    # Only the training loader shuffles.
    train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=train_dataset.collate_fn)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=val_dataset.collate_fn)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=test_dataset.collate_fn)
    return train_scaler, train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset

def get_train_scaler(dataset,sampler='R20',transform='interspike_interval',cell_random_seed=1,bins=list(np.arange(0,0.402,0.002)),scaler_type='Z'):
    """Fit a feature scaler on histograms of the training samples.

    Every training trial is sampled with ``dataset.sample`` (restricted to the
    training cells), each returned sample is histogrammed, and a scaler is
    fitted on the resulting (n_samples, n_bins) array.

    Args:
        dataset: object providing ``get_trials``, ``_cell_split`` and ``sample``.
        sampler: forwarded to ``dataset.sample``.
        transform: forwarded to ``dataset.sample``.
        cell_random_seed: unused; kept so existing callers keep working.
        bins: either an integer number of bins, in which case quantile-based
            edges are computed over all pooled values with ``pd.qcut``, or an
            explicit sequence of bin edges (list, tuple or ``np.ndarray``).
        scaler_type: ``'Z'`` for ``StandardScaler`` or ``'MinMax'`` for
            ``MinMaxScaler``.

    Returns:
        ``(train_scaler, bins)``: the fitted scaler and the bin edges used.
        When ``bins`` was an integer these are the computed quantile edges;
        otherwise the ``bins`` argument is returned unchanged.

    Raises:
        ValueError: unknown ``scaler_type`` or no training samples at all.
        TypeError: ``bins`` is neither an integer nor a sequence of edges.
    """
    scaler_classes = {'Z': StandardScaler, 'MinMax': MinMaxScaler}
    if scaler_type not in scaler_classes:
        raise ValueError("Unknown scaler_type {!r}; expected 'Z' or 'MinMax'".format(scaler_type))

    # Accept numpy integers too (e.g. values read from a pandas row); bool is
    # excluded because it is technically an int subclass.
    is_bin_count = isinstance(bins, (int, np.integer)) and not isinstance(bins, bool)
    if not is_bin_count and not isinstance(bins, (list, tuple, np.ndarray)):
        raise TypeError("bins must be an int or a sequence of bin edges, got {}".format(type(bins).__name__))

    trials = dataset.get_trials('train')
    train_mask = dataset._cell_split['train']
    X_bank = []
    for trial in trials:
        X, _, _ = dataset.sample(mode='train', trial_id=trial, sampler=sampler,
                                 transform=transform, preselected_mask=train_mask)
        X_bank.extend(X)
    if not X_bank:
        raise ValueError("No training samples available to fit the scaler on.")

    # Debug output: overall value range (skipped if every sample is empty).
    non_empty = [x for x in X_bank if len(x) > 0]
    if non_empty:
        print(min([np.min(x) for x in non_empty]), max([np.max(x) for x in non_empty]))

    if is_bin_count:
        # Adaptive (equal-population) edges computed over all pooled values.
        _, bin_edges = pd.qcut(np.ndarray.flatten(np.hstack(X_bank)), int(bins), retbins=True)
        bins = bin_edges
    else:
        bin_edges = bins

    # np.histogram returns all-zero counts for an empty sample, so empty
    # samples need no special casing.
    xi_hists_array = np.vstack([np.histogram(xi, bins=bin_edges)[0] for xi in X_bank])
    train_scaler = scaler_classes[scaler_type]().fit(xi_hists_array)

    return train_scaler, bins


### Run the model

In [6]:
def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    """Run one pass over ``train_loader``, updating ``model`` with ``optimizer``.

    Each batch is moved to ``device``, passed through the model, scored with
    ``F.nll_loss`` and followed by one optimizer step. The model is put into
    training mode first. Nothing is returned.

    ``epoch`` and ``log_interval`` are accepted but not used: the per-batch
    progress printout they were meant for is disabled.
    """
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

def test(model, device, loader, tag, labels=dataset.cell_type_labels):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    preds = []
    corrects = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            pred_ = np.ndarray.flatten(pred.cpu().numpy())
            targ_ = target.cpu().numpy()
            preds.append(pred_)
            corrects.append(targ_)
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)
    '''
    print('{} set: Accuracy: {}/{} ({:.0f}%)'.format(tag,
        correct, total,
        100. * correct / total))
        '''
    corrects = np.hstack(corrects)
    preds = np.hstack(preds)
    acc = f1_score(corrects,preds,average='macro')
    print('F1:',acc)
    cm = confusion_matrix(corrects,preds,normalize='true')
    return cm,acc

### Plot some figures

In [None]:
def plot_cm(cm,cell_type_labels=None,outfilename='ex_cm.png'):
    """Draw a confusion matrix as an annotated heat map.

    Args:
        cm: 2-D array of values; the colour scale is fixed to [0, 1].
        cell_type_labels: tick labels, one per class. Defaults to
            ``dataset.cell_type_labels`` looked up when the function is
            called (not when it is defined), so it follows the dataset
            currently in use.
        outfilename: path to save the figure to (300 dpi). Pass ``None`` to
            show the figure instead.

    Each cell is annotated with its value rounded to 2 decimals; diagonal
    cells are bold. The figure is closed afterwards so repeated calls do not
    accumulate open figures.
    """
    if cell_type_labels is None:
        cell_type_labels = dataset.cell_type_labels

    fig = plt.figure(figsize = (10,7))
    plt.set_cmap('Reds')
    tick_positions = range(len(cell_type_labels))
    # Use the labels that were passed in, not the global dataset's.
    plt.xticks(tick_positions, cell_type_labels, rotation='vertical')
    plt.yticks(tick_positions, cell_type_labels)
    plt.imshow(cm,vmin=0,vmax=1)
    plt.colorbar()
    for (row, col), value in np.ndenumerate(cm):
        text_kwargs = dict(ha='center', va='center', fontsize='small')
        if row == col:
            text_kwargs['weight'] = 'bold'
        # x is the column index, y is the row index
        plt.text(col, row, np.round(value, 2), **text_kwargs)
    if outfilename is None:
        plt.show()
    else:
        plt.savefig(outfilename,dpi=300)
    plt.close(fig)

def plot_accs(train_accs,val_accs,test_accs,outfilename='ex_acc.png'):
    """Plot the train / validation / test score curves on the current figure.

    The three sequences are drawn in that order with the legend entries
    'Train', 'Validate' and 'Test'. The figure is saved to ``outfilename`` at
    300 dpi, or shown when ``outfilename`` is ``None``, and then cleared.
    """
    curves = (
        ('Train', train_accs),
        ('Validate', val_accs),
        ('Test', test_accs),
    )
    for legend_name, scores in curves:
        plt.plot(scores, label=legend_name)

    plt.xlabel('Iterations')
    plt.ylabel('F-Measure')
    plt.legend()

    if outfilename is not None:
        plt.savefig(outfilename, dpi=300)
    else:
        plt.show()
    plt.clf()

### Run Hyperparameter Grid

In [7]:
hp_df = pd.read_csv('hp_grid_p6.csv') #hyperparameter grid, one row per combination
# Data-defining parameters of the previous row. None means "nothing loaded yet",
# so the first row always loads the data. (Comparing the list against a numpy
# array, as before, triggers numpy's elementwise-comparison warning.)
prev_dist_params = None

#HARDCODED HYPERPARAMETERS, IDENTICAL FOR EVERY COMBINATION
batch_size = 1
cell_sample_seed = 1
test_size, val_size = 0.4,0.4 #should leave a training set of only 20%
dropout_p = 0 #currently overfitting
epochs=50 #seems to be enough to converge
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#FOR EVERY HYPERPARAMETER COMBINATION
for index,row in hp_df.iterrows():

    #LOAD THE HYPERPARAMETERS
    # NOTE: load_data() and get_scaler() read these as module-level globals,
    # so the names below must not be changed.
    hpc = row['hp_idx']
    k,distribution,cell_split_seed,bin_size = row['k'],row['distribution'],row['cell_split_seed'],row['bin_size']
    isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step = row['isi_dist_bins_start'],row['isi_dist_bins_stop'],row['isi_dist_bins_step']
    fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step = row['fr_dist_bins_start'],row['fr_dist_bins_stop'],row['fr_dist_bins_step']
    lr = row['lr']
    # Comma-separated layer sizes, e.g. "64,32". str() also covers a grid in
    # which every row has a single layer and pandas parsed the column as ints.
    n_hiddens = [int(nh) for nh in str(row['n_hiddens']).split(',')]
    dist_params = [k,distribution,cell_split_seed,bin_size,
                   isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step,
                   fr_dist_bins_start,fr_dist_bins_stop,fr_dist_bins_step]
    isi_dist_bins = list(np.arange(isi_dist_bins_start,isi_dist_bins_stop,isi_dist_bins_step))
    fr_dist_bins = list(range(int(fr_dist_bins_start),int(fr_dist_bins_stop),int(fr_dist_bins_step)))

    # The number of classes is the number embedded in k, e.g. '17celltypes' -> 17
    n_class = int(''.join(filter(str.isdigit, k)))

    #IF THIS HYPERPARAMETER COMBO HAS DIFFERENT DATA (FEATURES OR CLASS LABELS) RELOAD AND RENORMALIZE THE DATA
    if prev_dist_params is None or dist_params != prev_dist_params:
        dataset = Dataset('../data', force_process=False)
        dataset.data_source = 'v1'
        dataset.labels_col = 'pop_name'
        sampler = 'R20'
        dataset.num_trials_in_window = 33
        print('Load')
        dataset = load_data()
        print('Scale')
        train_scaler, train_loader, val_loader, test_loader, train_dataset, val_dataset, test_dataset = get_scaler()

    #RUN THE MODEL
    model = MLP(input_dims=train_dataset.num_bins, n_hiddens=n_hiddens, n_class=n_class, dropout_p=dropout_p).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    train_accs, val_accs, test_accs = [],[],[] #used to save the f1s across iterations

    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch, log_interval=5)
        train_cm, train_acc = test(model, device, train_loader, 'Train')
        val_cm, val_acc = test(model, device, val_loader, 'Val')
        test_cm, test_acc = test(model, device, test_loader, 'Test')
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        test_accs.append(test_acc)

    #SAVE THE F1 SCORES ACROSS ITERATIONS AND FINAL CONFUSION MATRIX
    # Column order must match the zip order (train, val, test); it previously
    # labelled the validation scores 'test_acc' and the test scores 'val_acc'.
    accs_df = pd.DataFrame(list(zip(train_accs, val_accs, test_accs)),columns=['train_acc','val_acc','test_acc'])
    accs_df.to_csv('{}_hp{}_f1s.csv'.format(distribution,str(hpc)),index=False)
    # Confusion matrices of the last epoch, stacked along the third axis as (train, val, test)
    np.save('{}_hp{}_cms.npy'.format(distribution,str(hpc)),np.dstack([train_cm,val_cm,test_cm]))
    prev_dist_params = dist_params

['17celltypes', 'ISI', 1234, 0.2, 0, 0.401, 5e-05, 0, 51, 1]
Found processed pickle. Loading from '../data/processed/dataset.pkl'.


  if dist_params!=prev_dist_params:


Load
Scale
0.0026557706966627848 98.05970326641108
F1: 0.18916435244273475
F1: 0.1383919154007944
F1: 0.14799819093827582


F1: 0.4142392985710413
F1: 0.19260943685161394
F1: 0.19113063260093238


F1: 0.5305468984718305
F1: 0.20806080167133334
F1: 0.21222997146052136


F1: 0.5597980702576882
F1: 0.2020895225590915
F1: 0.22420991761557238


F1: 0.58625074661062
F1: 0.19691447533866102
F1: 0.2177106827328939


F1: 0.6141036691802909
F1: 0.21942143302719852
F1: 0.21709806705252083


F1: 0.6249156094497158
F1: 0.21635442862093854
F1: 0.2349450050464663


F1: 0.673675100446019
F1: 0.21388979325248877
F1: 0.2202373587461721


F1: 0.6720930288286349
F1: 0.20064280293782066
F1: 0.20987270700062888


F1: 0.693258013859217
F1: 0.19561967010112108
F1: 0.22445018977505282


F1: 0.6895093922781131
F1: 0.19775593539905198
F1: 0.21584603722934326


F1: 0.7079520098375183
F1: 0.20074150455838422
F1: 0.22573672316647025


F1: 0.7102082236681654
F1: 0.2102418706682043
F1: 0.22390881947345276


F1: 0.693

KeyboardInterrupt: 