In [25]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [26]:
import pickle
import os
import sys
sys.path.append('..')

In [27]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Subset
from tqdm import tqdm, trange

In [28]:
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import KFold
from sklearn.linear_model import LogisticRegressionCV

In [29]:
from models.cnn import MyAlexNet
from models.resnet import ResNetBaseline
from jose.my_torch.helpers import train_one_epoch, eval, test, test_binary, test_regression, _test
from jose.eval.eval import dataset_train_test
from dataset import AccelLaughterDataset
from constants import cloud_data_path

In [30]:
examples = pd.read_csv('../dataset/computational_examples.csv')
examples = examples[examples['condition'] == 'av']
accel_ds_path = os.path.join(cloud_data_path, 'accel', 'accel_ds.pkl')
ds = AccelLaughterDataset(examples, accel_ds_path, label='pressed_key', example_len=60)

loaded 672 examples
467 have accel


In [31]:
(len(ds.examples_df), len(ds.accel))

(467, 359)

In [32]:
ds.accel['006f74addfc99845bf6c9f80d13d52ccc189341031525530762bb83dd8b713af'].shape

(49, 3)

In [33]:
def get_metrics(outputs, labels, model):
    if model in ['bce']:
        proba = torch.sigmoid(outputs)
        pred = (proba > 0.5)

        correct = pred.eq(outputs.bool()).sum().item()
        return {
            'auc': roc_auc_score(labels, proba),
            'correct': correct
        }
    elif model in ['l1', 'mse', 'mean_baseline']:
        return {
            'mse': torch.nn.functional.mse_loss(outputs, labels, reduction='mean'),
            'l1': torch.nn.functional.l1_loss(outputs, labels, reduction='mean')
        }

In [34]:
def do_fold(train_idx, test_idx, model='bce', logfile=None):
    # create datasets    
    train_ds = Subset(ds, train_idx)
    test_ds = Subset(ds, test_idx)
    
    # data loaders
    data_loader = torch.utils.data.DataLoader(
        train_ds, batch_size=100, shuffle=True, num_workers=0,
        collate_fn=None)
    data_loader_val = torch.utils.data.DataLoader(
        test_ds, batch_size=100, shuffle=False, num_workers=0,
        collate_fn=None)

    if model in ['bce', 'l1', 'mse']:
        return do_fold_cnn(data_loader, data_loader_val, model, logfile)
    elif model in ['mean_baseline']:
        return do_fold_mean_baseline(data_loader, data_loader_val)

def do_fold_mean_baseline(train_dl, test_dl):
    labels = []
    for batch_idx, (X, Y) in enumerate(train_dl):
        Y = Y.float()
        labels.append(Y.reshape(-1))

    labels = torch.cat(labels)
    return torch.full((len(test_dl.dataset),), labels.mean().item())

def do_fold_cnn(train_dl, test_dl, loss='bce', logfile=None):
    
    # model = MyAlexNet()
    model = ResNetBaseline(in_channels = 3)
    if loss == 'bce':
        loss_fn = torch.nn.BCEWithLogitsLoss(reduction='sum')
    elif loss == 'mse':
        loss_fn = torch.nn.MSELoss(reduction='sum')
    elif loss == 'l1':
        loss_fn = torch.nn.L1Loss(reduction='sum')
    else:
        raise Exception('unknown loss')
    optimizer = torch.optim.Adam(model.parameters(), lr=.001)
    
    device = torch.device('cuda')
    model = model.to(device)
    
    for epoch in range(15):
        try:
            train_one_epoch(model, loss_fn, device, train_dl, optimizer, epoch)
            eval_labels, eval_output, stats = _test(model, loss_fn, device, test_dl)
            eval_metrics = get_metrics(eval_output, eval_labels, loss)
            if logfile is not None:
                logfile.write(str(eval_metrics)+'\n')
        except KeyboardInterrupt:
            pass
        
    # testing
    all_labels, all_output, stats = _test(model, loss_fn, device, test_dl)

    metrics = get_metrics(all_output, all_labels, loss)
    return all_output

In [35]:

def do_run(model):
    seed = 22
    cv_splits = KFold(n_splits=10, random_state=seed, shuffle=True).split(range(len(ds)))

    fh = open('run_logs.log', 'w')

    outputs = torch.empty((len(ds),))
    for f, (train_idx, test_idx) in enumerate(cv_splits):
        fold_outputs = do_fold(train_idx, test_idx, model, logfile=fh)
        outputs[test_idx] = fold_outputs.cpu()

    labels = torch.Tensor(ds.get_all_labels())
    run_metrics = get_metrics(outputs, labels, model)
    fh.close()
    print(run_metrics)

    return outputs, run_metrics

In [36]:
perf = []
for i in range(10):
    _, run_metrics = do_run('bce')
    perf.append(run_metrics['auc'])
    
(np.mean(perf), np.std(perf))

{'auc': 0.7926544966018652, 'correct': 330}
{'auc': 0.7938794057215111, 'correct': 316}
{'auc': 0.7698948948948949, 'correct': 328}
{'auc': 0.8161450924608818, 'correct': 345}
{'auc': 0.808854907539118, 'correct': 336}
{'auc': 0.8093488225067172, 'correct': 327}
{'auc': 0.7558479532163743, 'correct': 347}
{'auc': 0.8108503240082188, 'correct': 337}
{'auc': 0.8005768926821558, 'correct': 318}
{'auc': 0.8224276908487435, 'correct': 340}


(0.7980480480480481, 0.019891141375635555)