In [1]:
import os
import pickle
import models

import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import seaborn as sns
from numpyro.diagnostics import hpdi

from hbmep.nn import functional as F
from hbmep.model.utils import Site as site

from hbmep.config import Config
from hbmep.model.utils import Site as site
from scipy import stats

from models import NonHierarchicalBayesianModel

USER = os.environ["USER"]

In [2]:
src = f"/home/andres/repos/rat-mapping-paper/reports/J_SHAP/non_hierarchical_bayesian_model/inference.pkl"
with open(src, "rb") as f:
    (
        df,
        encoder_dict,
        model,
        posterior_samples,
    ) = pickle.load(f)

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
a = posterior_samples[site.a]
b = posterior_samples[site.b]
L = posterior_samples[site.L]
ell = posterior_samples[site.ell]
H = posterior_samples[site.H]

x = np.linspace(0, 500, 1000)

named_params = [site.a, site.b, site.L, site.ell, site.H]
params = [posterior_samples[param][ ...] for param in named_params]

for named_param, param in zip(named_params, params):
    
    print(named_param, param.shape)

a (4000, 8, 16, 4, 6)
b (4000, 8, 16, 4, 6)
L (4000, 8, 16, 4, 6)
ℓ (4000, 8, 16, 4, 6)
H (4000, 8, 16, 4, 6)


In [4]:
subjects = df['participant'].unique()
subjects = sorted(encoder_dict['participant'].inverse_transform(subjects))


In [5]:
norm_y = []
norm_x = []

for subject_ind, subject in enumerate(subjects):
    curr_params = [
        param[:, subject_ind, ...][..., None] for param in params
    ]
    constant = np.nanmedian(np.nanmean(curr_params[0], axis=0))

    x_temp = np.linspace(0., 5 * constant, 500)
    x_temp = x_temp[None, None, None, None, :]

    temp_thresh = F.rectified_logistic(
        x_temp,
        *curr_params
    )
    temp_thresh = temp_thresh - curr_params[2]
    norm_y.append(temp_thresh)
    norm_x.append(x_temp)

norm_y = np.array(norm_y)
norm_x = np.array(norm_x)

# print(norm_y.shape)
# print(norm_x.shape)

  constant = np.nanmedian(np.nanmean(curr_params[0], axis=0))
  constant = np.nanmedian(np.nanmean(curr_params[0], axis=0))


In [39]:
norm_y.shape

(8, 4000, 16, 4, 6, 500)

In [40]:
compound_position = encoder_dict[model.features[1]].inverse_transform(sorted(df['compound_position'].unique()))
mono = [c for c in compound_position if c.split("-")[0] == "" or c.split("-")[1] == ""]
mono = np.array(mono)

In [8]:
model.features[1]

'compound_position'

In [42]:
mono

array(['-C6L', '-C7L', '-C8L', 'C6L-', 'C7L-', 'C8L-'], dtype='<U4')

In [43]:
mono = encoder_dict[model.features[1]].transform(mono)


In [None]:
compound_charge = encoder_dict[model.features[2]].inverse_transform(sorted(df['compound_charge_params'].unique()))
print(compound_charge)

['20-0-80-25' '50-0-50-0' '50-0-50-100' '80-0-20-400']


In [21]:
ind = [0,1,2]

In [32]:
mono[:3]

array([0, 1, 2])

In [24]:
arr = arr.swapaxes(0, 1)


In [26]:
arr = arr[:, :, ind, ...]

In [57]:
arr = norm_y
arr = ma.masked_invalid(arr)
arr.shape

(8, 4000, 16, 4, 6, 500)

In [47]:
test = norm_y[...,mono,1,:,:]
test.shape

(8, 4000, 6, 6, 500)

In [58]:
def get_entropy(elec_ind):
    if elec_ind == 0:
        elec_ind = mono[:3]
    else:
        elec_ind = mono[3:]
    y = arr[...,elec_ind,1,:,:]
    x = norm_x

    y_max = ma.max(y, axis=(2,3, -1), keepdims=True) 

    y = ma.where(y, y / y_max, 0.)

    p = ma.sum(y, axis=-2, keepdims=True)
    p = ma.where(y, y / p, 0.)

    plogp = ma.where(p, p * ma.log(p), 0.)

    entropy = ma.where(
        ma.any(p, axis=-2, keepdims=True),
        (
            1
            + (ma.sum(plogp, axis=-2, keepdims=True) / np.log(plogp.shape[-2]))
        ),
        0.
    )
    entropy = entropy[..., 0, :]
    entropy.shape

    return entropy, x

In [59]:
cathodic_entropy, cathodic_x = get_entropy(0)
anodic_entropy, anodic_x = get_entropy(1)

In [35]:
pseudo25_entropy, pseudo25_x = get_entropy(0)
mono_entropy, mono_x = get_entropy(1)
bi_entropy, bi_x = get_entropy(2)
pseudo80_entropy, pseudo80_x = get_entropy(3)


In [52]:
def AUC(entropy, x):
    area_list = []
    for subject_ind, subject in enumerate(subjects):
        x_temp = x[subject_ind, 0,0, ...]
        sample_integrals = entropy[:, subject_ind, ...]
        area = np.trapz(sample_integrals, x_temp)
        area_list.append(area)

    area_list = np.array(area_list)
    area_list = area_list.reshape(len(subjects), *area_list.shape[1:])
    area_list = ma.masked_values(area_list, 0)
    area_list_mean = ma.mean(area_list, axis=(-1, -2))
    
    return area_list_mean
    

In [60]:
cathodic = AUC(cathodic_entropy, cathodic_x)
anodic = AUC(anodic_entropy, anodic_x)

In [48]:
pseudo25 = AUC(pseudo25_entropy, pseudo25_x)
mono = AUC(mono_entropy, mono_x)
bi = AUC(bi_entropy, bi_x)
pseudo80 = AUC(pseudo80_entropy, pseudo80_x)

In [54]:
auc_list = [cathodic , anodic]

In [61]:
anodic

masked_array(data=[401.8133227804154, 498.29304357476747,
                   526.457555664568, 216.19598743128378,
                   105.27595208439043, 199.48292263557556,
                   114.48068324463604, 58.940062908339925],
             mask=[False, False, False, False, False, False, False, False],
       fill_value=1e+20)

In [62]:
src = f"/home/andres/repos/rat-mapping-paper/notebooks/J_SHAP/cath_anode_selectivity_means.pkl"

with open(src, "wb") as f:
    pickle.dump([cathodic , anodic], f)

In [68]:
for a1_ind, auc1 in enumerate(auc_list):
    for a2_ind, auc2 in enumerate(auc_list):
        if a1_ind == a2_ind:
            continue
        t = auc1 - auc2
        ttest = stats.ttest_1samp(t, popmean=0, axis=0, alternative='less')
        pvalue = ttest.pvalue
        print(compound_charge[a1_ind], compound_charge[a2_ind], pvalue)

20-0-80-25 50-0-50-0 0.13942974345812995
20-0-80-25 50-0-50-100 0.04215669430193781
20-0-80-25 80-0-20-400 0.02894819264413264
50-0-50-0 20-0-80-25 0.86057025654187
50-0-50-0 50-0-50-100 0.20342394860276447
50-0-50-0 80-0-20-400 0.03345596220632771
50-0-50-100 20-0-80-25 0.9578433056980622
50-0-50-100 50-0-50-0 0.7965760513972355
50-0-50-100 80-0-20-400 0.041965577763092204
80-0-20-400 20-0-80-25 0.9710518073558674
80-0-20-400 50-0-50-0 0.9665440377936723
80-0-20-400 50-0-50-100 0.9580344222369078
