In [1]:
%load_ext autoreload
%autoreload 2

%autosave 10

#%load_ext lab_black

Autosaving every 10 seconds


In [2]:
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join("..")))

In [3]:
import logging

logging.basicConfig(filename="mylog.log", format="%(message)s", level=logging.INFO)
logging.info("-- Starting run")

In [4]:
from lfp_analysis.data import *
from lfp_analysis.process import *
from lfp_analysis.resnet2d import *
from lfp_analysis.resnet1d import *
from lfp_analysis.svm import *
from lfp_analysis.report import *

from fastai.vision.all import *
import torch.nn.functional as F
from torchvision.transforms import ToPILImage, ToTensor

In [24]:
import numpy as np
import pandas as pd
import h5py
import json

from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.colors import to_hex
from matplotlib.lines import Line2D


%matplotlib widget

import matplotlib

font = {"size": 9}
matplotlib.rc("font", **font)

import seaborn as sns

WIN_LEN_SEC = 0.750

In [6]:
pat_nums = list(range(1, 9))
tasks = ["Pegboard", "Pouring", "Posture"]
stim_conds = ["ON", "OFF"]

In [7]:
def pr(*args):
    print(args)

    logging.info("")
    logging.info(args)
    logging.info("--------")

# Import data: 

In [8]:
with open('within_cond_svm.json', 'r') as f:
    within_svm = json.load(f)
    
with open('within_cond_lda.json', 'r') as f:
    within_lda = json.load(f)
    
with open('within_cond_cnn.json', 'r') as f:
    within_cnn = json.load(f)

In [9]:
with open('across_cond_svm.json', 'r') as f:
    across_svm = json.load(f)
    
with open('across_cond_lda.json', 'r') as f:
    across_lda = json.load(f)
    
with open('across_cond_cnn.json', 'r') as f:
    across_cnn = json.load(f)

In [10]:
with open('across_cond_svm_same_norm.json', 'r') as f:
    across_svm_same_norm = json.load(f)
    
with open('across_cond_lda_same_norm.json', 'r') as f:
    across_lda_same_norm = json.load(f)

In [None]:
fig, ax = plt.subplots()

svm_res = across_svm
lda_res = across_lda
cnn_res = across_cnn

for pat_num in pat_nums:
    for task in tasks:
        for stim_cond in stim_conds:
            
            lda_val = lda_res[str(pat_num)][task][stim_cond]
            if lda_val is None:
                continue
            lda_val = lda_val['AUC']['mean']
            svm_val = svm_res[str(pat_num)][task][stim_cond]['AUC']['mean']
            if lda_val is not None and svm_val is not None:
                ax.plot([0,1],[svm_val, lda_val],c='k')

## Two cols:

In [115]:
plt.close('all')

In [122]:
fig, ax = plt.subplots(figsize=(4,4))

data1 = within_lda
data2 = across_lda

metric = 'f1'

set1, set2 = [],[]
for pat_num in pat_nums:
    for task in tasks:
        for stim_cond in stim_conds:
            
            data1_vals = data1[str(pat_num)][task][stim_cond]
            if data1_vals is None:
                continue
            data2_vals = data2[str(pat_num)][task][stim_cond]
            
            data1_val = data1_vals[metric]['mean']
            data2_val = data2_vals[metric]['mean']
            
            set1.append(data1_val); set2.append(data2_val)
            if data1_val is not None:
                c = 'C1' if stim_cond=='ON' else 'C0'
                ax.plot([0,1],[data1_val, data2_val],c=c)
                ax.scatter([0,1],[data1_val, data2_val],c=c,marker='s',s=15)
            
lines = [Line2D([0], [0], label = 'StimOFF',color='C0'), Line2D([0], [0], label = 'StimON',color='C1')]
plt.legend(lines, ['StimOFF', 'StimON'])

ax.set_xticks([0,1])
ax.set_xticklabels(['LDA \n within cond.','LDA \n across cond.'])
ax.set_ylabel(metric)

plt.tight_layout()
#ax.set_ylim([ax.get_ylim()[0], 1.15])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [71]:
fig, ax = plt.subplots(figsize=(4,4))

data1 = within_svm
data2 = within_cnn
data3 = within_lda

metric = 'f1'

set1, set2, set3 = [],[],[]
for pat_num in pat_nums:
    for task in tasks:
        for stim_cond in stim_conds:
            
            data1_vals = data1[str(pat_num)][task][stim_cond]
            if data1_vals is None:
                continue
            data2_vals = data2[str(pat_num)][task][stim_cond]
            data3_vals = data3[str(pat_num)][task][stim_cond]
            
            data1_val = data1_vals[metric]['mean']
            data2_val = data2_vals[metric]['mean']
            data3_val = data3_vals[metric]['mean']
            
            set1.append(data1_val); set2.append(data2_val); set3.append(data3_val)
            if data1_val is not None:
                c = 'C1' if stim_cond=='ON' else 'C0'
                ax.plot([0,1,2],[data1_val, data2_val, data3_val],c=c)
                ax.scatter([0,1,2],[data1_val, data2_val, data3_val],c=c,marker='s',s=15)
            
lines = [Line2D([0], [0], label = 'StimOFF',color='C0'), Line2D([0], [0], label = 'StimON',color='C1')]
plt.legend(lines, ['StimOFF', 'StimON'])

ax.set_xticks([0,1,2])
ax.set_xticklabels(['SVM','CNN','LDA'])
ax.set_ylabel(metric)

ax.set_ylim([ax.get_ylim()[0], 1.15])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(0.5113930135520915, 1.15)

In [95]:
fig, ax = plt.subplots(figsize=(4,4))

data1 = within_svm
data2 = within_cnn
data3 = within_lda

metric = 'f1'

set1, set2, set3 = [],[],[]
for pat_num in pat_nums:
    for task in tasks:
        for stim_cond in stim_conds:
            
            data1_vals = data1[str(pat_num)][task][stim_cond]
            if data1_vals is None:
                continue
            data2_vals = data2[str(pat_num)][task][stim_cond]
            data3_vals = data3[str(pat_num)][task][stim_cond]
            
            for i in range(5):
                data1_val = data1_vals[metric]['folds'][i]
                data2_val = data2_vals[metric]['folds'][i]
                data3_val = data3_vals[metric]['folds'][i]

                set1.append(data1_val); set2.append(data2_val); set3.append(data3_val)
                if data1_val is not None:
                    c = 'C1' if stim_cond=='ON' else 'C0'
                    ax.plot([0,1,2],[data1_val, data2_val, data3_val],c=c, alpha=0.3,linewidth=0.3)
                    ax.scatter([0,1,2],[data1_val, data2_val, data3_val],c=c,marker='s',s=15, alpha=0.3)
            
lines = [Line2D([0], [0], label = 'StimOFF',color='C0'), Line2D([0], [0], label = 'StimON',color='C1')]
plt.legend(lines, ['StimOFF', 'StimON'])

ax.set_xticks([0,1,2])
ax.set_xticklabels(['SVM','CNN','LDA'])
ax.set_ylabel(metric)

ax.set_ylim([ax.get_ylim()[0], 1.15])

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

(0.10102328181095083, 1.15)

In [86]:
plt.close('all')

In [48]:
fig, ax = plt.subplots()

within = across_svm
across = across_cnn

set1, set2 = [],[]
for pat_num in pat_nums:
    for task in tasks:
        for stim_cond in stim_conds:
            
            within_vals = within[str(pat_num)][task][stim_cond]
            if within_vals is None:
                continue
            across_vals = across[str(pat_num)][task][stim_cond]
            
            for i in range(5):
                within_val = within_vals['AUC']['folds'][i]
                across_val = across_vals['AUC']['folds'][i]
                set1.append(within_val); set2.append(across_val)
                if within_val is not None and across_val is not None:
                    ax.plot([0,1],[within_val, across_val],c='k', linewidth=0.5)

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …

In [62]:
from scipy.stats import ttest_rel

In [97]:
ttest_rel(set1,set3,alternative='less')

Ttest_relResult(statistic=1.6746425368660773, pvalue=0.952212410604509)

In [80]:
np.sum( (np.array(set1) - np.array(set2)) > 0.05) 

2

In [82]:
np.array(set2).size

40

In [98]:
fig, ax = plt.subplots()
set1, set2 = np.array(set1),np.array(set2)


sns.violinplot(np.zeros_like(set1),set2-set1, inner='point')

Canvas(toolbar=Toolbar(toolitems=[('Home', 'Reset original view', 'home', 'home'), ('Back', 'Back to previous …



<AxesSubplot:>

In [None]:
plt.close('all')