In [18]:
import json
import time
import os
import sys
import pickle5 as pickle
import time
import traceback
from collections import Counter
from functools import partial

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import roc_auc_score
from spatiotemporal_cnn.dataset_bci import DatasetBuilder, SingleSubjectData
from spatiotemporal_cnn.models_bci import SpatiotemporalCNN
from spatiotemporal_cnn.utils import deterministic, evaluate, run_inference
from torch.utils.data import DataLoader
from utils import constants

In [2]:
# Inputs for this notebook.
# data_dir: per-subject data directory (voxel space) read by SingleSubjectData.
data_dir = os.path.join(constants.PH_SUBJECTS_DIR, 'bci', 'voxel_space', 'avg_rl_cropped_00_12')
# results_dir: hyper-parameter search output for subject 2804 (note the trailing '/').
results_dir = '/shared/rsaas/nschiou2/mindportal/bci/mot/motor_LR/spatiotemporal_cnn/voxel_space/max_abs_scale/baseline_ph_viable_vox/hyperopt/s_2804/hyperopt_search/'
# model_dir: one search trial; its directory name encodes that trial's hyper-parameters.
model_dir = os.path.join(results_dir, 'train_with_configs_fff29d68_10_D=9,F1=16,F2=7,T=75,batch_size=32,dropout=0.5,fs=52,l2=0.0001,lr=0.001_2022-06-08_11-49-44')

In [3]:
# Subject ID passed to SingleSubjectData and the sub-montage(s) passed as train_submontages.
# NOTE(review): results_dir hard-codes 's_2804'; keep it in sync with this subject ID.
subject, train_submontages = '2804', ['abc']

In [7]:
# Load this subject's data for the motor left/right task.
bci_data = SingleSubjectData(
    subject_id=subject,
    data_dir=data_dir,
    data_type='ph',
    input_space='voxel_space',
    expt_type='mot',
    classification_task='motor_LR',
    train_submontages=train_submontages,
    filter_zeros=True,
)

# Number of viable features; used below as the model's input-channel count C.
num_features = bci_data.get_num_viable_features()

# Dataset factory with fixed seeds (seed, seed_cv) — presumably so the CV splits are
# reproducible; inputs are max-abs scaled and missing values imputed with 'zero'.
db = DatasetBuilder(
    data=bci_data,
    impute='zero',
    max_abs_scale=True,
    seed=42,
    seed_cv=15,
)

In [9]:
# Hyper-parameters saved alongside the search trial.
# NOTE(review): pickle.load can execute arbitrary code — only load params.pkl from a
# trusted results directory.
params_path = os.path.join(model_dir, 'params.pkl')
with open(params_path, mode='rb') as params_file:
    model_p = pickle.load(params_file)

# Build the network for the number of viable features in the data loaded above.
model_p.update(C=num_features)

In [10]:
# Show the loaded hyper-parameters (including the overridden 'C') as a sanity check.
model_p

{'fs': 52,
 'T': 75,
 'F1': 16,
 'D': 9,
 'F2': 7,
 'lr': 0.001,
 'batch_size': 32,
 'dropout': 0.5,
 'l2': 0.0001,
 'C': 98}

In [19]:
# Instantiate the network with the trial's hyper-parameters ('dropout' maps to p).
# NOTE(review): no checkpoint is loaded here, so the weights are freshly initialised.
# That is fine for latency measurements, but predictions from this model are
# meaningless — load the trained state_dict from model_dir if outputs are needed.
model = SpatiotemporalCNN(
    C=model_p['C'],
    F1=model_p['F1'],
    D=model_p['D'],
    F2=model_p['F2'],
    p=model_p['dropout'],
    fs=model_p['fs'],
    T=model_p['T']
)

# Count trainable parameters. Tensor.numel() always returns an int, whereas
# np.prod(p.size()) yields the float 1.0 for 0-dim parameters.
params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'{params} trainable parameters', flush=True)

device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda:0'
    if torch.cuda.device_count() > 1:
        # NOTE(review): DataParallel replicates the module on every forward call,
        # which adds overhead to the batch_size=1 latencies measured below.
        model = torch.nn.DataParallel(model)
model = model.to(device)

19678 trainable parameters


In [20]:
def time_inference(model, device, data_loader):
    """Measure per-batch inference latency of ``model`` in seconds.

    For every batch the timed region covers: moving the inputs to ``device``,
    the forward pass, the sigmoid, and the 0.5 threshold. On CUDA devices, the
    work is synchronised before the clock stops.

    Args:
        model: ``torch.nn.Module``; it is switched to eval mode. Its outputs are
            treated as logits (sigmoid followed by a 0.5 threshold).
        device: device string or ``torch.device`` the inputs are moved to
            (e.g. ``'cpu'`` or ``'cuda:0'``).
        data_loader: iterable of ``(index, data, labels)`` triples. ``data`` is
            either a tensor or a list of tensors. Labels are ignored.

    Returns:
        list[float]: elapsed seconds per batch, in loader order.
    """
    # CUDA kernels launch asynchronously. Without a sync, the timer would stop
    # before the GPU has finished the work.
    needs_sync = torch.device(device).type == 'cuda'

    model.eval()
    times = []
    with torch.no_grad():
        for _, data, _labels in data_loader:
            t0 = time.perf_counter()
            data = data.to(device) \
                if isinstance(data, torch.Tensor) \
                else [x.to(device) for x in data]

            outputs = model(data).squeeze()
            # Post-processing is kept inside the timed region on purpose.
            predicted = torch.sigmoid(outputs) > 0.5  # noqa: F841
            if needs_sync:
                torch.cuda.synchronize(device)

            # append, not extend: each measurement is a single float.
            times.append(time.perf_counter() - t0)

    return times

In [None]:
# Time single-sample (batch_size=1) inference on the test split of every fold yielded
# by build_datasets, keeping one row per measured batch.
# NOTE(review): the very first forward pass may include one-off warm-up cost (lazy
# CUDA/cuDNN initialisation); consider discarding it when summarising.
timing_records = []
for i, (inner_train_valids, test_dataset) in enumerate(db.build_datasets(cv=5, nested_cv=1)):

    for j, (train_dataset, valid_dataset) in enumerate(inner_train_valids):

        # impute_chan is given the training split — presumably test channels are
        # imputed from training data; confirm in spatiotemporal_cnn.dataset_bci.
        test_dataset.impute_chan(train_dataset)
        test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

        fold_times = time_inference(model, device, test_loader)
        timing_records.extend(
            {'outer_fold': i, 'inner_fold': j, 'time_s': t} for t in fold_times
        )

# Explicit columns keep the frame well-formed even if no batches were timed.
inference_times = pd.DataFrame(timing_records, columns=['outer_fold', 'inner_fold', 'time_s'])
inference_times.groupby('outer_fold')['time_s'].describe()