In [6]:
from torch.utils.data import DataLoader
import torch
import torch.nn.functional as F
import torch.optim as optim

from src.data import Dataset
from src.covariance import compute_cov, compute_edge_dist
from src.network import MLP
from src.dataloader import EdgeDistributionDataset

### Loading dataset

In [7]:
# Load the dataset. NOTE(review): this is an absolute, machine-specific path;
# consider moving it to a configurable DATA_DIR.
dataset = Dataset(
    '/Users/mehdiazabou/Documents/Nerds/cell_type/data',
    force_process=False,
)

# Mapping from fine-grained cell class -> aggregated class label.
# Excitatory sub-classes are merged per layer (e23, e4, e5, e6); inhibitory
# classes map to themselves. Entry order is kept as originally written.
aggr_dict = {
    'e23Cux2': 'e23',
    'i5Sst': 'i5Sst',
    'i5Htr3a': 'i5Htr3a',
    'e4Scnn1a': 'e4',
    'e4Rorb': 'e4',
    'e4other': 'e4',
    'e4Nr5a1': 'e4',
    'i6Htr3a': 'i6Htr3a',
    'i6Sst': 'i6Sst',
    'e6Ntsr1': 'e6',
    'i23Pvalb': 'i23Pvalb',
    'i23Htr3a': 'i23Htr3a',
    'i1Htr3a': 'i1Htr3a',
    'i4Sst': 'i4Sst',
    'e5Rbp4': 'e5',
    'e5noRbp4': 'e5',
    'i23Sst': 'i23Sst',
    'i4Htr3a': 'i4Htr3a',
    'i6Pvalb': 'i6Pvalb',
    'i5Pvalb': 'i5Pvalb',
    'i4Pvalb': 'i4Pvalb',
}
dataset.aggregate_cell_classes(aggr_dict)

# Train/val/test splits, first over cells and then over trials (same seed).
dataset.split_cell_train_val_test(test_size=0.2, val_size=0.2, seed=1234)
dataset.split_trial_train_val_test(
    test_size=0.2,
    val_size=0.2,
    temp=True,
    seed=1234,
)

# Binning. If bin_size is in seconds, 0.1 corresponds to 100 ms.
# NOTE(review): an earlier comment here said "10ms"; confirm whether the
# intended value is 0.1 (100 ms) or 0.01 (10 ms).
dataset.set_bining_parameters(bin_size=0.1)

Found processed pickle. Loading from '/Users/mehdiazabou/Documents/Nerds/cell_type/data/processed/dataset.pkl'.


In [30]:
# Build the PyTorch datasets for each split.
# The val/test datasets hold a single item each and are pinned to a given
# cell_random_seed and trial_id (presumably so evaluation stays fixed across
# epochs -- confirm against EdgeDistributionDataset).
train_dataset = EdgeDistributionDataset(
    dataset,
    num_bins=10,
    dataset_size=100,
    mode='train',
    sampler='U20',
)
val_dataset = EdgeDistributionDataset(
    dataset,
    num_bins=10,
    dataset_size=1,
    mode='val',
    sampler='U20',
    cell_random_seed=2,
    trial_id=70,
)
test_dataset = EdgeDistributionDataset(
    dataset,
    num_bins=10,
    dataset_size=1,
    mode='test',
    sampler='U20',
    cell_random_seed=1,
    trial_id=90,
)

# Wrap them in loaders: only the training loader shuffles.
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

In [33]:
def train(model, device, train_loader, optimizer, epoch, log_interval=None):
    """Run one optimisation pass over ``train_loader``.

    Each batch is moved to ``device``, flattened so that every leading
    dimension is merged into a single sample axis, and used for one
    optimiser step.

    Args:
        model: Module mapping a ``(N, features)`` float tensor to per-class
            scores of shape ``(N, n_class)``.
        device: Device the batch tensors are moved to.
        train_loader: Iterable of ``(data, target)`` tensor pairs with a
            ``len()``.
        optimizer: Optimiser holding ``model``'s parameters.
        epoch: Epoch number, used only in the progress message.
        log_interval: If truthy, print the loss every ``log_interval``
            batches; otherwise stay silent.
    """
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        # Merge all leading dimensions into one sample axis. ``reshape`` also
        # copes with non-contiguous tensors, where ``view`` would raise.
        data = data.reshape(-1, data.size(-1)).float()
        target = target.reshape(-1).long()
        optimizer.zero_grad()
        output = model(data)
        # ``cross_entropy`` applies ``log_softmax`` itself, so it is correct
        # when the model emits raw logits. If the model already emits
        # log-probabilities the value is unchanged, because ``log_softmax``
        # is idempotent on normalised log-probabilities. ``nll_loss`` alone
        # is only valid in the latter case.
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()
        if log_interval and batch_idx % log_interval == 0:
            print('Train Epoch: {} [({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, 100. * batch_idx / len(train_loader), loss.item()))

def test(model, device, loader, tag):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            # concat batches
            data = data.view((-1, data.size(2))).float()
            target = target.view((-1,)).long()
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            total += len(target)

    print('{} set: Accuracy: {}/{} ({:.0f}%)'.format(tag,
        correct, total,
        100. * correct / total))

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNG before the model is built so weight initialisation and
# dropout masks are repeatable across "Restart & Run All".
# NOTE(review): this only covers torch; if the dataset sampler draws from
# numpy / random, seed those as well.
torch.manual_seed(1234)

# n_class=17 matches the number of distinct aggregated labels in aggr_dict.
model = MLP(input_dims=train_dataset.num_bins, n_hiddens=[20, 20, 20], n_class=17, dropout_p=0.2).to(device)
print(model)

lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)

epochs=100
for epoch in range(1, epochs + 1):
    train(model, device, train_loader, optimizer, epoch, log_interval=5)
    # Report accuracy on every split after each epoch.
    # NOTE(review): the test split is evaluated every epoch; avoid using
    # these numbers for model selection.
    test(model, device, train_loader, '\nTrain')
    test(model, device, val_loader, 'Val')
    test(model, device, test_loader, 'Test')

MLP(
  (model): Sequential(
    (fc1): Linear(in_features=10, out_features=20, bias=True)
    (relu1): ReLU()
    (drop1): Dropout(p=0.2, inplace=False)
    (fc2): Linear(in_features=20, out_features=20, bias=True)
    (relu2): ReLU()
    (drop2): Dropout(p=0.2, inplace=False)
    (fc3): Linear(in_features=20, out_features=20, bias=True)
    (relu3): ReLU()
    (drop3): Dropout(p=0.2, inplace=False)
    (out): Linear(in_features=20, out_features=17, bias=True)
  )
)

Train set: Accuracy: 2001/34000 (6%)
Val set: Accuracy: 20/340 (6%)
Test set: Accuracy: 20/340 (6%)

Train set: Accuracy: 1965/34000 (6%)
Val set: Accuracy: 18/340 (5%)
Test set: Accuracy: 18/340 (5%)

Train set: Accuracy: 3678/34000 (11%)
Val set: Accuracy: 39/340 (11%)
Test set: Accuracy: 39/340 (11%)

Train set: Accuracy: 3619/34000 (11%)
Val set: Accuracy: 36/340 (11%)
Test set: Accuracy: 37/340 (11%)

Train set: Accuracy: 3609/34000 (11%)
Val set: Accuracy: 36/340 (11%)
Test set: Accuracy: 38/340 (11%)

Train set: Acc