In [1]:
%load_ext autoreload
%autoreload 2

%autosave 10

#%load_ext lab_black

Autosaving every 10 seconds


In [2]:
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join("..")))

## Imports

In [3]:
from lfp_analysis.data import *
from lfp_analysis.process import *
from lfp_analysis.resnet2d import *
from lfp_analysis.resnet1d import *
from lfp_analysis.svm import *
from lfp_analysis.report import *

from fastai.vision.all import *
import torch.nn.functional as F
from torchvision.transforms import ToPILImage, ToTensor

In [4]:
import numpy as np
import pandas as pd
import h5py

from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import to_hex
%matplotlib widget

import seaborn as sns


In [5]:
def plot_save_modelPs(ts,LFPs,label):
    """Plot every LFP channel against time with the label trace overlaid.

    Parameters
    ----------
    ts : array-like
        Time axis shared by all channels.
    LFPs : array-like
        Indexed per channel as ``LFPs[i]``; ``LFPs.shape[0]`` is the number
        of subplot rows.
    label : array-like
        Label trace; on each channel's axis it is scaled by half of that
        channel's maximum so it is visible next to the signal.

    Returns
    -------
    fig, ax
        The figure and a 1-D array of Axes (one per channel). ``ax`` is an
        array even for a single channel.
    """
    n_rows = LFPs.shape[0]
    # squeeze=False always yields a 2-D Axes grid; without it a single channel
    # returns a bare Axes object and ax[i] below would raise.
    fig, ax_grid = plt.subplots(n_rows, 1, squeeze=False)
    ax = ax_grid[:, 0]

    for i in range(n_rows):
        ax[i].plot(ts, LFPs[i])
        ax[i].plot(ts, label * 0.5 * np.max(LFPs[i]))

    return fig, ax


def get_metric(results,metric):
    """Extract a single metric from each classifier's score container.

    ``results`` maps long container names ('svm_scores', 'b_lda_scores',
    'th_lda_scores', '1d_cnn_scores', '2d_cnn_scores') to objects indexable
    by ``metric``. Returns a dict keyed by the short classifier names used
    for plotting, in the order svm, b-lda, th-lda, 1d-cnn, 2d-cnn.
    """
    short_to_container = (
        ('svm', 'svm_scores'),
        ('b-lda', 'b_lda_scores'),
        ('th-lda', 'th_lda_scores'),
        ('1d-cnn', '1d_cnn_scores'),
        ('2d-cnn', '2d_cnn_scores'),
    )
    return {short: results[container][metric] for short, container in short_to_container}


def plot_roc_instance(roc_curve_cont,ax=None,color=None,label=None):
    """Plot a mean ROC curve with a shaded +/- 1 std band onto ``ax``.

    Parameters
    ----------
    roc_curve_cont : mapping
        Needs 'mean' and 'std' entries of equal length. 'mean' is drawn as the
        TPR against an evenly spaced FPR grid on [0, 1] with one point per
        entry. This assumes the curves were interpolated on such a grid;
        confirm with the code that produces these scores.
    ax : matplotlib Axes, optional
        Target axes. A new figure/axes is created when omitted.
    color, label : optional
        Passed through to the mean line (label) and to both line and band (color).

    Returns
    -------
    The Axes drawn into. Callers that ignored the old ``None`` return are
    unaffected.
    """
    mean_tpr = np.asarray(roc_curve_cont['mean'])
    std_tpr = np.asarray(roc_curve_cont['std'])
    # Size the FPR grid from the data instead of assuming exactly 100 samples;
    # a hard-coded length breaks plotting for any other interpolation size.
    mean_fpr = np.linspace(0, 1, len(mean_tpr))
    if ax is None:
        _, ax = plt.subplots()

    # Chance-level diagonal.
    ax.plot([0,1],[0,1],color='k',alpha=0.5,linestyle='--',linewidth=0.4)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")

    ax.plot(mean_fpr,mean_tpr,color=color,label=label)
    ax.fill_between(mean_fpr,
                    mean_tpr-std_tpr,
                    mean_tpr+std_tpr,
                    alpha=0.2,color=color
                   )
    return ax
    
    
def plot_conf_mats(conf_mats):
    """Draw one heatmap per classifier's confusion matrix, stacked vertically.

    Parameters
    ----------
    conf_mats : dict
        Keys 'svm', 'b-lda', 'th-lda' map to {"mean": matrix, "std": matrix}.
        These are annotated as "mean +/- std". Keys '1d-cnn', '2d-cnn' map
        directly to a single matrix. All panels share one colour scale
        (global min/max over every matrix shown).

    Returns
    -------
    fig, ax
        The figure and its array of Axes, one row per classifier. Only the
        bottom panel carries the "Predicted Label" x-axis label.
    """
    ml_keys = ['svm','b-lda','th-lda']
    nn_keys = ['1d-cnn','2d-cnn']
    n_ml = len(ml_keys)

    # One row per classifier, derived from the key lists instead of hard-coding 5.
    fig,ax = plt.subplots(n_ml + len(nn_keys),1,figsize=(4,18))

    ml_means = [np.asarray(conf_mats[key]["mean"]) for key in ml_keys]
    nn_mats = [np.asarray(conf_mats[key]) for key in nn_keys]
    # Shared colour limits so the panels are visually comparable.
    max_val = max(np.max(mat) for mat in ml_means + nn_mats)
    min_val = min(np.min(mat) for mat in ml_means + nn_mats)

    def _tick_labels(mat):
        """Row/column class labels 0..n-1 taken from the matrix shape (not fixed to binary)."""
        return list(range(mat.shape[0])), list(range(mat.shape[1]))

    for ii,(cls_name,means) in enumerate(zip(ml_keys,ml_means)):
        stds = np.asarray(conf_mats[cls_name]["std"])
        annot = pd.DataFrame([[f"{mean:.1f} +/- {std:.2f}" for mean,std in zip(mean_row,std_row)]
                              for mean_row,std_row in zip(means,stds)])

        rows, cols = _tick_labels(means)
        df = pd.DataFrame(means,index=rows,columns=cols)

        # fmt="s" because the annotations are pre-formatted strings.
        sns.heatmap(df,annot=annot,vmin=min_val,vmax=max_val,ax=ax[ii],fmt="s",cmap="rocket")
        ax[ii].set_title(cls_name)
        ax[ii].set_ylabel("True Label")

    for ii,(cls_name,mat) in enumerate(zip(nn_keys,nn_mats)):
        rows, cols = _tick_labels(mat)
        df = pd.DataFrame(mat,index=rows,columns=cols)

        row_ax = ax[n_ml+ii]
        sns.heatmap(df,annot=True,vmin=min_val,vmax=max_val,ax=row_ax,fmt=".2f",cmap="rocket")
        row_ax.set_title(cls_name)
        row_ax.set_ylabel("True Label")

    ax[-1].set_xlabel("Predicted Label")
    return fig,ax


def plot_svm_coefs(svm_coefs, fig=None, label=None, n_chan=None, **kwargs):
    """Plot per-channel classifier coefficients (mean +/- std) as plotly scatter traces.

    Parameters
    ----------
    svm_coefs : mapping
        Must provide "mean", "std" (equal-length sequences) and "names" (one
        feature name per coefficient). A coefficient whose name contains
        ``pow{i}`` (not followed by another digit) is drawn on subplot row ``i + 1``.
    fig : plotly figure, optional
        Figure with at least ``n_chan`` subplot rows. When omitted, a new
        ``make_subplots(rows=n_chan, cols=1)`` figure is created.
    label : str, optional
        Trace name. 'SVM' -> colour C0, 'b-LDA' -> C1, anything else -> C2.
    n_chan : int, optional
        Number of channel rows. The previous version read a global ``n_chan``
        that is not defined in this notebook. When omitted, it is inferred as
        1 + the largest ``pow<k>`` index found in the feature names.
    **kwargs
        Accepted for call compatibility; unused.

    Returns
    -------
    The figure with one trace per channel row added. The legend is shown
    once, on the first row.

    Raises
    ------
    ValueError
        If ``n_chan`` is omitted and no ``pow<k>`` feature names exist.
    """
    import re

    color_dict = {0:to_hex('C0'),1:to_hex('C1'),2:to_hex('C2')}
    cls_id = 0 if label == 'SVM' else (1 if label =='b-LDA' else 2)

    coefs_df = pd.DataFrame(np.stack([svm_coefs["mean"],svm_coefs["std"]]).T,
                            columns=["mean","std"],
                            index=svm_coefs["names"])

    if n_chan is None:
        chan_ids = [int(match.group(1))
                    for match in (re.search(r"pow(\d+)", str(name)) for name in coefs_df.index)
                    if match]
        if not chan_ids:
            raise ValueError("n_chan not given and no 'pow<k>' feature names found in svm_coefs['names']")
        n_chan = max(chan_ids) + 1

    # NOTE(review): make_subplots and go (plotly) are not imported in the visible
    # notebook; presumably they come from one of the lfp_analysis star imports.
    if fig is None:
        fig = make_subplots(rows=n_chan,cols=1)

    for i in range(n_chan):
        # (?!\d) prevents 'pow1' from also selecting 'pow10', 'pow11', ...
        chan_pat = re.compile(rf"pow{i}(?!\d)")
        this_df = coefs_df.iloc[[bool(chan_pat.search(str(coef))) for coef in coefs_df.index],:]

        fig.add_trace(go.Scatter(
            name=label,
            showlegend=(i == 0),
            x=this_df.index.values,
            y=this_df["mean"],
            mode='markers',
            marker={"color":color_dict[cls_id]},
            error_y=dict(
                type='data',
                array=this_df["std"],
                visible=True)
        ),row=i+1,col=1)

    fig.update_layout(showlegend=True)

    return fig

# Load data:

In [45]:
EXPERIMENT = "add_reg"

In [46]:
data1 = Patient(2).Pegboard_off.load_1d()
data1

Dataset: ET2 - Pegboard_off
    n_chan = 5, duration = 12.58m


In [47]:
data2 = Patient(2).Pouring_off.load_1d()
data2

Dataset: ET2 - Pouring_off
    n_chan = 5, duration = 5.11m


## Build the windowed label DataFrame:

In [48]:
# Window length used to chop data1's label into classification samples
# (the _SEC suffix indicates seconds).
WIN_LEN_SEC = 0.750
# .window(...) is chained onto the constructor and the cell displays a
# "Windower object", so .window returns the Windower itself.
windower = Windower(WIN_LEN_SEC).window(data1.label)

data_df = windower.data_df
# NOTE(review): data_decim_df is not referenced again in the visible notebook.
data_decim_df = windower.data_decim_df

windower



Windower object
    Overall class balance: 
        ['0 -> 558', '1 -> 447']
        ['0 -> 0.56%', '1 -> 0.44%']
    On Valid:
        ['0 -> 120', '1 -> 81']
        ['0 -> 0.60%', '1 -> 0.40%']

In [49]:
# Window the second task (data2) with the WIN_LEN_SEC already defined for
# data1, rather than redefining the constant here. bl_cls.cross_task(bl_cls2)
# later combines classifiers built on both windowers, so the two window
# lengths must not silently drift apart.
windower2 = Windower(WIN_LEN_SEC).window(data2.label)

windower2



Windower object
    Overall class balance: 
        ['0 -> 231', '1 -> 176']
        ['0 -> 0.57%', '1 -> 0.43%']
    On Valid:
        ['0 -> 44', '1 -> 37']
        ['0 -> 0.54%', '1 -> 0.46%']

## Normalize LFP data:

*TODO: this section is currently empty — no normalization step is implemented here yet.*

## Baseline classifier:

In [60]:
# Baseline (non-deep) classifiers on periodogram features of data1.
bl_cls = BLClassifier(data1.LFP, data_df, extract_method='periodogram')

# The predictions are read from bl_cls.y_pred right after each run. The
# original commented-out pattern suggests each classify_many call replaces it.
svm_scores = bl_cls.classify_many(method="SVM")
y_hat_svm = bl_cls.y_pred

# b_scores / th_scores are required downstream (score display and the Reporter),
# so they must be computed here. Leaving them commented out makes a fresh-kernel
# run fail with NameError.
b_scores = bl_cls.classify_many(method="beta")
y_hat_b = bl_cls.y_pred

th_scores = bl_cls.classify_many(method="theta")
y_hat_th = bl_cls.y_pred


In [65]:
svm_scores

Score Summary Object
        Train: Scorer Object --- ds_type = train --- n_runs = 10 --- n_samp = 804

        Acc: {'mean': 0.8911691542288558, 'std': 0.0011467095096135212}
        AUC: {'mean': 0.9812485964518303, 'std': 0.0003294344867757392}
        precision: {'mean': 0.988569272186863, 'std': 0.0008453400054939021}
        recall: {'mean': 0.8095890410958905, 'std': 0.00151444054354126}
        loss: {'mean': 0.34, 'std': 0.01}

        
        Valid: Scorer Object --- ds_type = valid --- n_runs = 10 --- n_samp = 201

        Acc: {'mean': 0.9149253731343284, 'std': 0.0014925373134328623}
        AUC: {'mean': 0.986820987654321, 'std': 0.0004679656549606575}
        precision: {'mean': 0.990467032967033, 'std': 2.7472527472527374e-05}
        recall: {'mean': 0.8658333333333333, 'std': 0.0025000000000000243}
        loss: {'mean': 0.33, 'std': 0.01}

        
        

In [62]:
# Baseline classifier for the second task (data2), using windower2's windows.
bl_cls2 = BLClassifier(data2.LFP, windower2.data_df, extract_method='periodogram')

# Single classification run. `ret` is not used later in the visible notebook;
# the trailing ';' suppresses the cell's display.
ret = bl_cls2.classify();

In [64]:
# Cross-task evaluation between the data1 classifier and the data2 classifier.
# NOTE(review): the returned Scorer displays ds_type = train — confirm which
# task is used for fitting and which for scoring.
bl_cls.cross_task(bl_cls2)

Scorer Object --- ds_type = train --- n_runs = 1 --- n_samp = 326

        Acc: 0.93
        AUC: 0.96
        precision: 0.95
        recall: 0.88
        loss: 0.366

        

In [10]:
svm_scores.lin_svm_acc

(0.7056910569105691, 0.018106550773430954)

In [11]:
svm_scores, b_scores, th_scores

(Score Summary Object
         Train: Scorer Object --- ds_type = train --- n_runs = 10 --- n_samp = 492
 
         Acc: {'mean': 0.8485772357723578, 'std': 0.008871200342599654}
         AUC: {'mean': 0.9230105683090706, 'std': 0.0068171322723862265}
         precision: {'mean': 0.9068380693440377, 'std': 0.008287763149105782}
         recall: {'mean': 0.8414473684210526, 'std': 0.012290487955440393}
         loss: {'mean': 0.45, 'std': 0.06}
 
         
         Valid: Scorer Object --- ds_type = valid --- n_runs = 10 --- n_samp = 123
 
         Acc: {'mean': 0.6951219512195121, 'std': 0.02525564970245377}
         AUC: {'mean': 0.801923076923077, 'std': 0.0059428649522946545}
         precision: {'mean': 0.8481273357679132, 'std': 0.026067974628509474}
         recall: {'mean': 0.5746478873239436, 'std': 0.03551977524765771}
         loss: {'mean': 0.59, 'std': 0.05}
 
         
         ,
 Score Summary Object
         Train: Scorer Object --- ds_type = train --- n_runs = 10 --- n_

## Prepare and train 1D-CNN:

In [18]:
# 1-D CNN trainer on data1, using the windows computed from data1.label.
# layers=[2] and wd=50 are hand-set hyperparameters; wandb logging is disabled.
trainer = Trainer1d(log_wandb=False, layers=[2],wd=50, experiment=EXPERIMENT).prepare_dls(data1,windower).prepare_learner()

In [13]:
trainer.learn.lr_find()

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

SuggestedLRs(lr_min=0.00831763744354248, lr_steep=0.0012022644514217973)

In [19]:
# Positional args 250 and 0.3*1e-2 (= 3e-3) are presumably the epoch budget and
# the learning rate — confirm against Trainer1d.train. The logged output shows
# several consecutive fits, each ended by early stopping, with train accuracy
# at 1.0 while valid accuracy stays around 0.64-0.68 (strong overfitting).
trainer.train(250, 0.3*1e-2)

epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,0.661461,0.617188,0.71657,0.577236,00:01
1,0.656633,0.640625,0.705808,0.577236,00:01
2,0.655683,0.644531,0.697069,0.577236,00:01
3,0.65825,0.601562,0.690953,0.577236,00:01
4,0.660571,0.59375,0.686379,0.577236,00:01
5,0.663079,0.578125,0.682497,0.577236,00:01
6,0.665396,0.570312,0.679327,0.577236,00:01
7,0.664424,0.621094,0.677067,0.577236,00:01
8,0.659826,0.671875,0.675359,0.577236,00:01
9,0.659827,0.609375,0.674155,0.577236,00:01


epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,0.145693,1.0,0.616413,0.674797,00:01
1,0.148609,1.0,0.612073,0.682927,00:01
2,0.146031,1.0,0.608274,0.674797,00:01
3,0.148052,1.0,0.606329,0.666667,00:01
4,0.147551,1.0,0.604009,0.658536,00:01
5,0.147091,1.0,0.605341,0.650406,00:01
6,0.146162,1.0,0.613194,0.666667,00:01
7,0.145793,1.0,0.649709,0.666667,00:01


No improvement since epoch 4: early stopping


epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,0.13904,1.0,0.64004,0.658536,00:01
1,0.147836,1.0,0.634752,0.674797,00:01
2,0.146886,1.0,0.629685,0.666667,00:01
3,0.146396,1.0,0.626183,0.666667,00:01
4,0.146753,1.0,0.624135,0.666667,00:01
5,0.145432,1.0,0.620805,0.666667,00:01
6,0.144239,1.0,0.616304,0.674797,00:01
7,0.143756,1.0,0.615708,0.682927,00:01
8,0.142753,1.0,0.621034,0.682927,00:01
9,0.141914,1.0,0.627698,0.682927,00:01


No improvement since epoch 6: early stopping


epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,0.14018,1.0,0.620952,0.674797,00:01
1,0.140242,1.0,0.61687,0.674797,00:01
2,0.144099,0.996094,0.615047,0.666667,00:01
3,0.14139,1.0,0.612015,0.666667,00:01
4,0.139707,1.0,0.609608,0.658536,00:01
5,0.13958,1.0,0.609432,0.658536,00:01
6,0.140088,1.0,0.611042,0.666667,00:01
7,0.1397,1.0,0.612184,0.658536,00:01


No improvement since epoch 4: early stopping


epoch,train_loss,train_accuracy,valid_loss,valid_accuracy,time
0,0.134312,1.0,0.609892,0.658536,00:01
1,0.136774,1.0,0.608265,0.650406,00:01
2,0.135686,1.0,0.606718,0.650406,00:01
3,0.134418,1.0,0.605819,0.642276,00:01
4,0.134765,1.0,0.605466,0.642276,00:01
5,0.133686,1.0,0.605027,0.642276,00:01
6,0.133276,1.0,0.604898,0.650406,00:01
7,0.133166,1.0,0.605026,0.650406,00:01


No improvement since epoch 4: early stopping


In [None]:
trainer.save_model()

In [20]:
cnn_1d_scores = trainer.score()
cnn_1d_scores

Score Summary Object
        Train: Scorer Object --- ds_type = train --- n_runs = 1 --- n_samp = 494

        Acc: 1.00
        AUC: 1.00
        precision: 1.00
        recall: 1.00
        loss: 0.143

        
        Valid: Scorer Object --- ds_type = train --- n_runs = 1 --- n_samp = 123

        Acc: 0.65
        AUC: 0.72
        precision: 0.66
        recall: 0.80
        loss: 0.605

        
        

In [None]:
# Load the 2-D representation of the SAME recording as data1 (Patient 2,
# Pegboard_off). trainer2 is built with `windower`, whose windows were computed
# from data1.label, and the Reporter pairs the 2D-CNN scores with
# data1/windower. Loading a different patient/task here makes the windows
# index into an unrelated recording.
# NOTE: this rebinds `data2`, replacing the Pouring_off 1-D data loaded earlier.
data2 = Patient(2).Pegboard_off.load_2d()

In [None]:
# 2-D CNN trainer (wd=25, wandb logging disabled).
# NOTE(review): dataloaders pair `data2` with `windower`, whose windows were
# computed from data1.label. data2 must be the 2-D representation of the same
# recording as data1 for the windows to line up.
trainer2 = Trainer2d(log_wandb=False, wd=25, experiment=EXPERIMENT).prepare_dls(data2,windower).prepare_learner()

In [None]:
trainer2.learn.lr_find()

In [None]:
trainer2.train(60,0.5*1e-2)

In [None]:
trainer2.save_model()

In [None]:
cnn_2d_scores = trainer2.score()
cnn_2d_scores

In [None]:
# Bundle every classifier's scores with data1/windower under this EXPERIMENT name.
# NOTE(review): b_scores and th_scores must have been produced by the beta/theta
# classify_many runs of the baseline-classifier cell; otherwise this raises NameError.
rec = Reporter(svm_scores, cnn_1d_scores, cnn_2d_scores, b_scores,th_scores, data1, windower, experiment=EXPERIMENT)

In [None]:
import pickle

# Load previously saved results for comparison.
# SECURITY NOTE: pickle.load can execute arbitrary code while unpickling —
# only load result files this project produced itself.
# The context manager closes the file handle; the bare open() leaked it.
with open("./../data/results/ET1/run_all/Pegboard/results.p", "rb") as fh:
    rec_load = pickle.load(fh)

In [None]:
rec_load

In [None]:
rec.save()

In [None]:
rec.save_plots();

In [None]:
plt.close('all')

In [None]:
rec.plot_svm_coefs()

In [None]:
rec.plot_losses()

In [None]:
rec.plot_roc_curves()

In [None]:
rec.plot_accs()

In [None]:
rec.plot_conf_mats()

In [None]:
plt.close('all')

In [None]:
# Optional: persist a trained model (disabled by default).
# NOTE(review): DATA_PATH, PAT_ID, DATASET_NAME and `learn` are not defined in
# the visible notebook — apparently leftovers from an earlier version. Enabling
# this flag as-is would raise NameError. The Trainer1d/Trainer2d objects
# expose save_model() for the same purpose.
model_save = False

if model_save:
    MODEL_DIR = DATA_PATH / 'results' / PAT_ID / "trained"
    learn.model_dir = MODEL_DIR

    learn.save(DATASET_NAME)

### Get validation preds:

In [None]:
y = df_data[df_data["is_valid"]==1]["label"].astype(float).values

In [None]:
# Validation predictions from the legacy `learn` object.
# NOTE(review): `learn` is not defined in the visible notebook.
preds_val = learn.get_preds()
# Hard class prediction = argmax over the last axis of the prediction tensor.
y_pred = torch.argmax(preds_val[0],-1).numpy()
# ROC/AUC score = column 1 of the prediction tensor; presumably the
# positive-class probability — confirm.
y_score = preds_val[0][:,1].numpy()

In [None]:
y.shape

In [None]:
# CNN performance on Train:
# preds_val = learn.get_preds(ds_idx=0)


In [None]:
wandb.run.finish()

In [None]:
preds_val[0].shape

In [None]:
cnn1d_scores = get_scores(y,y_pred,y_score)
cnn1d_scores;

## 2D-CNN:

In [None]:
from lfp_analysis.resnet2d import *

In [None]:
TF.shape

In [None]:
dblock.summary(df_data)

In [None]:
np.repeat(1,5)

In [None]:
# NOTE(review): `label` and `TF` are not defined in the visible notebook
# (legacy names). TF is sliced as TF[:, :, start:end], so it is at least 3-D.
df_data = make_label_df(label,WIN_LEN_SEC)

def get_x(row):
    """Return this row's window of TF (last axis sliced by id_start:id_end) as a float tensor."""
    return torch.tensor(TF[:,:, row["id_start"] : row["id_end"]]).float()

def get_y(row):
    """Return the row's class label."""
    return row["label"]

def splitter(df):
    """Split df's index labels into train (is_valid == 0) and valid (is_valid == 1) lists.

    Assumes df has a default RangeIndex so labels equal positions — confirm.
    """
    train = df.index[df["is_valid"] == 0].tolist()
    valid = df.index[df["is_valid"] == 1].tolist()
    return train, valid


def LFP_block2d(n_chan=None, mean=0.5, std=0.098):
    """TransformBlock for 2-D LFP inputs: resize to 160x160 and normalise per channel.

    The normalisation stats need one entry per input channel. They were
    hard-coded for 5 channels; they now default to TF.shape[0], the same
    channel count ResNet2d is built with. DataBlock calls this with no
    arguments, so the defaults apply.
    """
    if n_chan is None:
        n_chan = TF.shape[0]
    return TransformBlock(
        item_tfms=[Resizer((160, 160)), IntToFloatTensor],
        batch_tfms=LFPNormalizer2d(([mean] * n_chan, [std] * n_chan)),
    )

dblock = DataBlock(
    blocks=(LFP_block2d, CategoryBlock), get_x=get_x, get_y=get_y, splitter=splitter,
)

dls = dblock.dataloaders(df_data, bs=32)
# Grab one batch to sanity-check shapes (also reused for the initial-loss check).
xb, yb = dls.one_batch()
yb.shape, xb.shape

In [None]:
# NOTE(review): wandb is not imported in the visible notebook (it may come from a
# star import), and PAT_ID / DATASET_NAME are undefined here.
wandb.init(project='lfp-decoding')
wandb.run.name = str(PAT_ID)+'/'+str(DATASET_NAME)+'_2D'
# Input channels = TF.shape[0], 2 output classes; [2, 2, 1] is presumably the
# number of blocks per stage — confirm against ResNet2d.
resnet2d = ResNet2d(TF.shape[0], 2, [2, 2, 1])

learn2d = Learner(
    dls,
    resnet2d,
    wd=0.3,
    metrics=[accuracy],
    loss_func=F.cross_entropy,
    # Early stopping halts training after 4 epochs without a >= 0.01 improvement.
    cbs=[WandbCallback(),EarlyStoppingCallback(min_delta=0.01,patience=4)],)

# Also compute the metrics on the training set.
learn2d.recorder.train_metrics = True

In [None]:
init_loss = learn2d.loss_func(learn2d.model(xb), yb)
init_loss

In [None]:
learn2d.lr_find(start_lr=1e-5,end_lr=0.05,num_it=100)

In [None]:
# NOTE(review): the literal 10e-4 equals 1e-3, not 1e-4 — confirm the intended max LR.
learn2d.fit_one_cycle(8, 10e-4)

In [None]:
# NOTE(review): the literal 10e-6 equals 1e-5, not 1e-6 — confirm the intended max LR.
learn2d.fit_one_cycle(8, 10e-6)

In [None]:
learn2d.cbs

In [None]:
# Remove every callback from position 3 onward. The list comprehension is used
# only for its side effect; its displayed result is not meaningful.
# NOTE(review): relies on learn2d.cbs[3:] being a copy, so removing while iterating is safe.
[learn2d.remove_cb(cb) for cb in learn2d.cbs[3:]]

In [None]:
preds_2d_val = learn2d.get_preds()
y_2d_pred = torch.argmax(preds_2d_val[0],-1).numpy()
y_2d_score = preds_2d_val[0][:,1].numpy()

In [None]:
cnn2d_scores = get_scores(y,y_2d_pred,y_2d_score)
cnn2d_scores;

# Process and Visualize results:

In [None]:
# Persist results to disk:
import pickle

# NOTE(review): DATA_PATH, PAT_ID and DATASET_NAME are not defined in the visible
# notebook (legacy names). DATASET_NAME must be a pathlib.Path for .with_suffix
# to work. Likewise, b_lda_scores / th_lda_scores differ from the b_scores /
# th_scores names produced by the baseline-classifier cell.
RESULTS_DIR = DATA_PATH / 'results' / PAT_ID
PKL_TARGET = RESULTS_DIR / DATASET_NAME.with_suffix('.p')
FIG_TARGET = RESULTS_DIR / DATASET_NAME.with_suffix('.png')

data_container = {"win_len_sec":WIN_LEN_SEC,
                  "svm_scores":svm_scores,
                  "b_lda_scores":b_lda_scores,
                  "th_lda_scores":th_lda_scores,
                  "1d_cnn_scores":cnn1d_scores,
                  "2d_cnn_scores":cnn2d_scores
                }

# Ensure the per-patient results folder exists. The context manager guarantees
# the pickle is flushed and the handle closed; the bare open() leaked it.
PKL_TARGET.parent.mkdir(parents=True, exist_ok=True)
with open(PKL_TARGET, "wb") as fh:
    pickle.dump(data_container, fh)

wandb.log(data_container)

## ROC Curves:

In [None]:
fig,ax = plt.subplots()
# Classical baselines: mean ROC with a +/- std band, legend shows validation AUC mean +/- std.
# NOTE(review): these use dict-style access (["roc_curve"], ['AUC']['valid'])
# and the b_lda_scores / th_lda_scores names from an earlier version of the notebook.
plot_roc_instance(svm_scores["roc_curve"],ax=ax,label=f"SVM (AUC: {svm_scores['AUC']['valid']['mean']:0.2f} +/- {svm_scores['AUC']['valid']['std']:0.2f})")
plot_roc_instance(b_lda_scores["roc_curve"],ax=ax,label=f"b-LDA (AUC: {b_lda_scores['AUC']['valid']['mean']:0.2f} +/- {b_lda_scores['AUC']['valid']['std']:0.2f})")
plot_roc_instance(th_lda_scores["roc_curve"],ax=ax,label=f"th-LDA (AUC: {th_lda_scores['AUC']['valid']['mean']:0.2f} +/- {th_lda_scores['AUC']['valid']['std']:0.2f})")

# CNNs: a single curve each; roc_curve[0] / roc_curve[1] are presumably FPR / TPR.
ax.plot(cnn1d_scores["roc_curve"][0],cnn1d_scores["roc_curve"][1],label=f"1d-CNN (AUC: {cnn1d_scores['AUC']['valid']:.2f})")
ax.plot(cnn2d_scores["roc_curve"][0],cnn2d_scores["roc_curve"][1],label=f"2d-CNN (AUC: {cnn2d_scores['AUC']['valid']:.2f})")

ax.legend()

## Confusion Matrices:

In [None]:
fig_c,ax_c = plot_conf_mats(get_metric(data_container,'conf_mat_norm'))

In [None]:
wandb.log({'conf_mat':wandb.Image(fig_c)})
wandb.log({'roc_curves':wandb.Image(fig)})

In [None]:
# Overlay ground truth and each classifier's predicted labels. Each prediction
# trace is scaled by a different factor so the binary traces sit at distinct heights.
# NOTE(review): y_hat_blda / y_hat_thlda are not assigned anywhere in the visible
# notebook (the baseline cell's predictions are named y_hat_b / y_hat_th).
# Rebinding `fig` here also replaces the ROC figure object.
fig,ax = plt.subplots(figsize=(8,5))
ax.plot(y,label='Ground Truth')
ax.plot(y_pred*0.8,label='1D-CNN')
ax.plot(y_2d_pred*0.6,label='2D-CNN')
ax.plot(y_hat_svm*0.4,label='SVM')
ax.plot(y_hat_blda*0.2,label='b-LDA')
ax.plot(y_hat_thlda*0.1,label='th-LDA')

# Place the legend outside the top-right corner of the axes.
ax.legend(bbox_to_anchor=(1, 1))
#plt.legend(handles=[p1, p2], title='title', bbox_to_anchor=(1.05, 1), loc='upper left', prop=fontP)
plt.tight_layout()

In [None]:
wandb.log({'timeseries_preds':wandb.Image(fig)})