In [1]:
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import ConcatDataset

from sklearn.preprocessing import StandardScaler
import numpy as np
import pandas as pd

import sys
sys.path.append('../')

from src import Dataset, ISIDistributionDataset, EdgeDistributionDataset, FRDistributionDataset
from src.network import MLP

### Loading dataset

In [2]:
# Load the dataset rooted at '../data' and select the data source / label column
# that the filtering below keys on.
dataset = Dataset('../data', force_process=False)
dataset.data_source = 'v1'
dataset.labels_col = 'pop_name'

# Experiment configuration (read by the cells further down).
k = '17celltypes'            # class grouping; selects the aggregation branch below
distribution = 'ISI'         # which distribution dataset is built later ('ISI' or 'FR')
sampler = 'R20'
dataset.num_trials_in_window = 100
batch_size = 1
n_class = int(''.join(filter(str.isdigit, k)))  # class count = the digits embedded in k
cell_sample_seed = 1
cell_split_seed = 1234
test_size, val_size = 0.2,0.2
bin_size = 0.2
isi_dist_bins = list(np.arange(0,0.402,0.002))  # ISI histogram edges: 0 .. ~0.4 in steps of 0.002
fr_dist_bins = list(range(0,51,1))              # firing-rate histogram edges: 0 .. 50 in steps of 1

if dataset.data_source == 'v1':
    if dataset.labels_col == 'pop_name':
        dataset.drop_dead_cells(cutoff=30)
        # Only these 17 populations are kept; every grouping below is derived from them.
        keepers = ['e5Rbp4', 'e23Cux2', 'i6Pvalb', 'e4Scnn1a', 'i23Pvalb', 'i23Htr3a',
         'e4Rorb', 'e4other', 'i5Pvalb', 'i4Pvalb', 'i23Sst', 'i4Sst', 'e4Nr5a1',
         'i1Htr3a', 'e5noRbp4', 'i6Sst', 'e6Ntsr1']
        dataset.drop_other_classes(classes_to_keep=keepers)
        if k == '17celltypes':
            pass #all filtering done above
        elif k == '13celltypes':
            aggr_dict = {'e23Cux2': 'e23', 'i5Sst': 'i5Sst', 'i5Htr3a': 'i5Htr3a', 'e4Scnn1a': 'e4', 'e4Rorb': 'e4',
                     'e4other': 'e4', 'e4Nr5a1': 'e4', 'i6Htr3a': 'i6Htr3a', 'i6Sst': 'i6Sst', 'e6Ntsr1': 'e6',
                     'i23Pvalb': 'i23Pvalb', 'i23Htr3a': 'i23Htr3a', 'i1Htr3a': 'i1Htr3a', 'i4Sst': 'i4Sst', 'e5Rbp4': 'e5',
                     'e5noRbp4': 'e5', 'i23Sst': 'i23Sst', 'i4Htr3a': 'i4Htr3a', 'i6Pvalb': 'i6Pvalb', 'i5Pvalb': 'i5Pvalb',
                     'i4Pvalb': 'i4Pvalb'}
            dataset.aggregate_cell_classes(aggr_dict)
        elif k == '11celltypes':
            aggr_dict = {'e23Cux2': 'Cux2', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'Scnn1a', 'e4Rorb': 'Rorb',
                     'e4other': 'other', 'e4Nr5a1': 'Nr5a1', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'Ntsr1',
                     'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a', 'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'Rbp4',
                     'e5noRbp4': 'noRbp4', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a', 'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb',
                     'i4Pvalb': 'Pvalb'}
            dataset.aggregate_cell_classes(aggr_dict)
        elif k == '4celltypes':
            aggr_dict = {'e23Cux2': 'e', 'i5Sst': 'Sst', 'i5Htr3a': 'Htr3a', 'e4Scnn1a': 'e', 'e4Rorb': 'e', 'e4other': 'e',
                     'e4Nr5a1': 'e', 'i6Htr3a': 'Htr3a', 'i6Sst': 'Sst', 'e6Ntsr1': 'e', 'i23Pvalb': 'Pvalb', 'i23Htr3a': 'Htr3a',
                     'i1Htr3a': 'Htr3a', 'i4Sst': 'Sst', 'e5Rbp4': 'e', 'e5noRbp4': 'e', 'i23Sst': 'Sst', 'i4Htr3a': 'Htr3a',
                     'i6Pvalb': 'Pvalb', 'i5Pvalb': 'Pvalb', 'i4Pvalb': 'Pvalb'}
            dataset.aggregate_cell_classes(aggr_dict)
        elif k == '5layers':
            aggr_dict = {'e23Cux2': '23', 'i5Sst': '5', 'i5Htr3a': '5', 'e4Scnn1a': '4', 'e4Rorb': '4', 'e4other': '4',
                     'e4Nr5a1': '4', 'i6Htr3a': '6', 'i6Sst': '6', 'e6Ntsr1': '6', 'i23Pvalb': '23', 'i23Htr3a': '23',
                     'i1Htr3a': '1', 'i4Sst': '4', 'e5Rbp4': '5', 'e5noRbp4': '5', 'i23Sst': '23', 'i4Htr3a': '4',
                     'i6Pvalb': '6', 'i5Pvalb': '5', 'i4Pvalb': '4'}
            dataset.aggregate_cell_classes(aggr_dict)
        else:
            # The raise used to sit inside the '5layers' branch, which made that grouping
            # unusable and let unknown values of k slip through unnoticed.
            raise NotImplementedError("Unsupported class grouping k={!r}".format(k))


# Split into train/val/test sets
dataset.split_cell_train_val_test(test_size=test_size, val_size=val_size, seed=cell_split_seed)
#dataset.split_trial_train_val_test(test_size=0.2, val_size=0.2, temp=True, seed=1234)

# binning
dataset.set_bining_parameters(bin_size=bin_size) # in seconds, so this is 200ms

Found processed pickle. Loading from '../data/processed/dataset.pkl'.


In [3]:
def get_train_scaler(dataset,sampler='R20',transform='interspike_interval',cell_random_seed=1,bins=list(np.arange(0,0.402,0.002))):
    """Fit a ``StandardScaler`` on per-sample histograms of the training split.

    Every trial returned by ``dataset.get_trials('train')`` is sampled with
    ``dataset.sample(mode='train', ...)``; each element of the returned ``X`` is
    histogrammed and the scaler is fit on the stacked histograms (one row per element).

    Parameters
    ----------
    dataset : object exposing ``get_trials(mode)`` and ``sample(...)``.
    sampler, transform, cell_random_seed : forwarded unchanged to ``dataset.sample``.
    bins : int, or list / tuple / numpy array of bin edges
        * int  -> that many quantile bins are derived from all pooled training values
          with ``pd.qcut`` and the resulting edges are used.
        * edges -> used as given for ``np.histogram``.

    Returns
    -------
    (train_scaler, bins)
        The fitted scaler and the bin edges that were used (the computed quantile edges
        when ``bins`` was an int, otherwise the ``bins`` argument itself).

    Raises
    ------
    ValueError : if sampling produced no training elements at all.
    TypeError  : if ``bins`` is neither an integer nor a sequence of edges.
    """
    X_bank = []
    for trial in dataset.get_trials('train'):
        X, _, _ = dataset.sample(mode='train', trial_id=trial, sampler=sampler,
                                 transform=transform, cell_random_seed=cell_random_seed)
        X_bank.extend(X)

    if not X_bank:
        # np.hstack / np.vstack would raise a far less helpful ValueError below.
        raise ValueError('no training samples were returned; cannot fit a scaler')

    if isinstance(bins, (int, np.integer)) and not isinstance(bins, bool):
        # Quantile-based edges computed over every pooled training value.
        _, edges = pd.qcut(np.hstack(X_bank).flatten(), bins, retbins=True)
        bins = edges
    elif isinstance(bins, (list, tuple, np.ndarray)):
        edges = bins
    else:
        raise TypeError(
            'bins must be an int or a list/tuple/ndarray of edges, got {}'.format(type(bins).__name__))

    # np.histogram returns an all-zero row for an empty element, so no special case is needed.
    xi_hists_array = np.vstack([np.histogram(xi, bins=edges)[0] for xi in X_bank])
    train_scaler = StandardScaler().fit(xi_hists_array)

    return train_scaler, bins


In [6]:
# Pick the transform, bin edges and dataset class for the configured distribution.
# Unknown values fail here instead of surfacing later as a NameError on train_dataset.
if distribution == 'ISI':
    transform, dist_bins = 'interspike_interval', isi_dist_bins
    dataset_cls, dataset_kwargs = ISIDistributionDataset, {'min_isi': 0, 'max_isi': 0.4}
elif distribution == 'FR':
    transform, dist_bins = 'firing_rate', fr_dist_bins
    dataset_cls, dataset_kwargs = FRDistributionDataset, {}
else:
    raise ValueError("Unsupported distribution {!r}; expected 'ISI' or 'FR'".format(distribution))

# Fit the scaler with the configured seed and bin edges (previously re-typed as literals here,
# which could silently drift from cell_sample_seed / isi_dist_bins / fr_dist_bins).
train_scaler, bins = get_train_scaler(dataset, sampler=sampler, transform=transform,
                                      cell_random_seed=cell_sample_seed, bins=dist_bins)

# Create Pytorch datasets. All three splits share the train-fitted scaler, the bin edges the
# scaler was fit with, and the same cell_random_seed
# (original note: fixes the population for validation and test; they differ from train).
train_dataset, val_dataset, test_dataset = (
    dataset_cls(dataset, bins=bins, scaler=train_scaler, mode=mode, sampler=sampler,
                cell_random_seed=cell_sample_seed, **dataset_kwargs)
    for mode in ('train', 'val', 'test')
)

# Only the training loader shuffles; each loader uses its own dataset's collate_fn.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, collate_fn=train_dataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=batch_size, collate_fn=val_dataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=batch_size, collate_fn=test_dataset.collate_fn)

In [7]:
def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    """Run one pass over ``train_loader``, taking an optimiser step per batch.

    The model is switched to training mode, each batch is moved to ``device``, and the
    negative log-likelihood loss of the model output is back-propagated.

    Parameters
    ----------
    model : callable module whose output is fed to ``F.nll_loss``.
    device : target passed to ``.to(...)`` for inputs and labels.
    train_loader : iterable of ``(data, target)`` batches; ``len()`` is only needed when logging.
    optimizer : optimiser holding the model parameters.
    epoch : value echoed in the progress line.
    log_interval : when truthy, a progress line is printed every ``log_interval`` batches
        (starting with batch 0); when falsy nothing is printed.
    """
    model.train()
    for step, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()

        # Skip reporting unless logging is enabled and this step falls on the interval.
        if not log_interval or step % log_interval != 0:
            continue
        progress_pct = 100. * step / len(train_loader)
        print('Train Epoch: {} [({:.0f}%)]\tLoss: {:.6f}'.format(
            epoch, progress_pct, batch_loss.item()))

def test(model, device, loader, tag):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)

    print('{} set: Accuracy: {}/{} ({:.0f}%)'.format(tag,
        correct, total,
        100. * correct / total))

In [None]:
# Seed torch's global RNG so weight initialisation, dropout masks and the shuffled batch
# order are repeatable across Restart & Run All (no seed was set before, so every run
# produced different accuracies).
model_seed = 0
torch.manual_seed(model_seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MLP over the histogram bins: three hidden layers of 20 units, dropout 0.2, n_class outputs.
model = MLP(input_dims=train_dataset.num_bins, n_hiddens=[20,20,20], n_class=n_class, dropout_p=0.2).to(device)

lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)

# After every training epoch, report accuracy on all three splits.
epochs=100
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval=5)
    test(model, device, train_loader, 'Train')
    test(model, device, val_loader, 'Val')
    test(model, device, test_loader, 'Test')
    print('\n')

Train set: Accuracy: 9986/34000 (29%)
Val set: Accuracy: 8647/34000 (25%)
Test set: Accuracy: 7028/34000 (21%)


Train set: Accuracy: 12198/34000 (36%)
Val set: Accuracy: 9322/34000 (27%)
Test set: Accuracy: 7402/34000 (22%)


Train set: Accuracy: 12648/34000 (37%)
Val set: Accuracy: 8847/34000 (26%)
Test set: Accuracy: 7059/34000 (21%)


Train set: Accuracy: 14530/34000 (43%)
Val set: Accuracy: 9628/34000 (28%)
Test set: Accuracy: 7683/34000 (23%)


Train set: Accuracy: 14231/34000 (42%)
Val set: Accuracy: 9212/34000 (27%)
Test set: Accuracy: 8453/34000 (25%)




In [None]:
# Second training run: a fresh MLP with the same architecture and optimiser settings as the
# previous cell, trained and evaluated on the "*_loader2" loaders.
# NOTE(review): train_loader2 / val_loader2 / test_loader2 are not defined in any cell visible
# in this notebook, so on a fresh kernel this cell raises NameError unless they are created
# elsewhere. TODO: restore the cell that builds them, or point this run at the intended loaders.
# NOTE(review): input_dims is read from train_dataset (not a "2" dataset); confirm the second
# set of loaders yields inputs with the same number of bins.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MLP(input_dims=train_dataset.num_bins, n_hiddens=[20, 20, 20], n_class=n_class, dropout_p=0.2).to(device)
print(model)

lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)

# After every training epoch, report accuracy on the three "*_loader2" splits.
epochs=100
for epoch in range(1, epochs + 1):
    train(model, device, train_loader2, optimizer, epoch, log_interval=5)
    test(model, device, train_loader2, 'Train')
    test(model, device, val_loader2, 'Val')
    test(model, device, test_loader2, 'Test')
    print('\n')