In [2]:
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim

import sys
sys.path.append('../')

from src import Dataset, EdgeDistributionDataset
from src.network import MLP

### Loading dataset

In [3]:
dataset = Dataset('../data', force_process=False)
aggr_dict = {'e23Cux2': 'e23', 'i5Sst': 'i5Sst', 'i5Htr3a': 'i5Htr3a', 'e4Scnn1a': 'e4', 'e4Rorb': 'e4',
             'e4other': 'e4', 'e4Nr5a1': 'e4', 'i6Htr3a': 'i6Htr3a', 'i6Sst': 'i6Sst', 'e6Ntsr1': 'e6',
             'i23Pvalb': 'i23Pvalb', 'i23Htr3a': 'i23Htr3a', 'i1Htr3a': 'i1Htr3a', 'i4Sst': 'i4Sst', 'e5Rbp4': 'e5',
             'e5noRbp4': 'e5', 'i23Sst': 'i23Sst', 'i4Htr3a': 'i4Htr3a', 'i6Pvalb': 'i6Pvalb', 'i5Pvalb': 'i5Pvalb',
             'i4Pvalb': 'i4Pvalb'}
dataset.aggregate_cell_classes(aggr_dict)
# Split into train/val/test sets
dataset.split_cell_train_val_test(test_size=0.2, val_size=0.2, seed=1234)
dataset.split_trial_train_val_test(test_size=0.2, val_size=0.2, temp=True, seed=1234)

# bining
dataset.set_bining_parameters(bin_size=0.1) # in seconds, so this is 10ms

Found processed pickle. Loading from '../data/processed/dataset.pkl'.


In [3]:
# Create Pytorch datasets
train_dataset = ISIDistributionDataset(dataset, edge_dist_num_bins=10, mode='train', sampler='U20')
# fix population for validation set and test set (they will be different of course)
val_dataset = EdgeDistributionDataset(dataset, edge_dist_num_bins=10, mode='val', sampler='U20', cell_random_seed=1)
test_dataset = EdgeDistributionDataset(dataset, edge_dist_num_bins=10, mode='test', sampler='U20', cell_random_seed=1)

train_loader = DataLoader(train_dataset, shuffle=True, batch_size=5, collate_fn=train_dataset.collate_fn)
val_loader = DataLoader(val_dataset, batch_size=1, collate_fn=train_dataset.collate_fn)
test_loader = DataLoader(test_dataset, batch_size=1, collate_fn=train_dataset.collate_fn)

In [4]:
print(test_dataset)

<src.pytorch_dataset.EdgeDistributionDataset object at 0x7faff759b3d0>


In [6]:
def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, loader, tag):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)

    print('{} set: Accuracy: {}/{} ({:.0f}%)'.format(tag,
        correct, total,
        100. * correct / total))

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = MLP(input_dims=train_dataset.num_bins, n_hiddens=[20, 20, 20], n_class=17, dropout_p=0.2).to(device)
print(model)

lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)

epochs=100
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval=5)
    test(model, device, train_loader, 'Train')
    test(model, device, val_loader, 'Val')
    test(model, device, test_loader, 'Test')
    print('\n')

MLP(
  (model): Sequential(
    (fc1): Linear(in_features=10, out_features=20, bias=True)
    (relu1): ReLU()
    (drop1): Dropout(p=0.2, inplace=False)
    (fc2): Linear(in_features=20, out_features=20, bias=True)
    (relu2): ReLU()
    (drop2): Dropout(p=0.2, inplace=False)
    (fc3): Linear(in_features=20, out_features=20, bias=True)
    (relu3): ReLU()
    (drop3): Dropout(p=0.2, inplace=False)
    (out): Linear(in_features=20, out_features=17, bias=True)
  )
)
Train set: Accuracy: 2011/20400 (10%)
Val set: Accuracy: 609/6800 (9%)
Test set: Accuracy: 713/6800 (10%)


Train set: Accuracy: 1550/20400 (8%)
Val set: Accuracy: 515/6800 (8%)
Test set: Accuracy: 498/6800 (7%)


Train set: Accuracy: 2032/20400 (10%)
Val set: Accuracy: 691/6800 (10%)
Test set: Accuracy: 715/6800 (11%)


Train set: Accuracy: 2024/20400 (10%)
Val set: Accuracy: 632/6800 (9%)
Test set: Accuracy: 677/6800 (10%)


Train set: Accuracy: 2171/20400 (11%)
Val set: Accuracy: 744/6800 (11%)
Test set: Accuracy: 776/68

Train set: Accuracy: 2542/20400 (12%)
Val set: Accuracy: 683/6800 (10%)
Test set: Accuracy: 766/6800 (11%)


Train set: Accuracy: 2569/20400 (13%)
Val set: Accuracy: 738/6800 (11%)
Test set: Accuracy: 819/6800 (12%)


Train set: Accuracy: 2596/20400 (13%)
Val set: Accuracy: 813/6800 (12%)
Test set: Accuracy: 791/6800 (12%)


Train set: Accuracy: 2465/20400 (12%)
Val set: Accuracy: 745/6800 (11%)
Test set: Accuracy: 633/6800 (9%)


Train set: Accuracy: 2652/20400 (13%)
Val set: Accuracy: 818/6800 (12%)
Test set: Accuracy: 812/6800 (12%)


Train set: Accuracy: 2639/20400 (13%)
Val set: Accuracy: 791/6800 (12%)
Test set: Accuracy: 815/6800 (12%)


Train set: Accuracy: 2641/20400 (13%)
Val set: Accuracy: 793/6800 (12%)
Test set: Accuracy: 841/6800 (12%)


Train set: Accuracy: 2545/20400 (12%)
Val set: Accuracy: 786/6800 (12%)
Test set: Accuracy: 833/6800 (12%)


Train set: Accuracy: 2578/20400 (13%)
Val set: Accuracy: 788/6800 (12%)
Test set: Accuracy: 800/6800 (12%)


Train set: Accuracy:

Train set: Accuracy: 2527/20400 (12%)
Val set: Accuracy: 831/6800 (12%)
Test set: Accuracy: 816/6800 (12%)


Train set: Accuracy: 2566/20400 (13%)
Val set: Accuracy: 801/6800 (12%)
Test set: Accuracy: 806/6800 (12%)


Train set: Accuracy: 2648/20400 (13%)
Val set: Accuracy: 827/6800 (12%)
Test set: Accuracy: 813/6800 (12%)


Train set: Accuracy: 2530/20400 (12%)
Val set: Accuracy: 794/6800 (12%)
Test set: Accuracy: 850/6800 (12%)


Train set: Accuracy: 2578/20400 (13%)
Val set: Accuracy: 831/6800 (12%)
Test set: Accuracy: 813/6800 (12%)


Train set: Accuracy: 2576/20400 (13%)
Val set: Accuracy: 812/6800 (12%)
Test set: Accuracy: 808/6800 (12%)


Train set: Accuracy: 2636/20400 (13%)
Val set: Accuracy: 844/6800 (12%)
Test set: Accuracy: 831/6800 (12%)


Train set: Accuracy: 2677/20400 (13%)
Val set: Accuracy: 824/6800 (12%)
Test set: Accuracy: 816/6800 (12%)


Train set: Accuracy: 2656/20400 (13%)
Val set: Accuracy: 821/6800 (12%)
Test set: Accuracy: 814/6800 (12%)


Train set: Accuracy