In [1]:
import os
import pickle
import models

import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import seaborn as sns
from numpyro.diagnostics import hpdi

from hbmep.nn import functional as F
from hbmep.model.utils import Site as site

from hbmep.config import Config
from hbmep.model.utils import Site as site
from scipy import stats

from models import NonHierarchicalBayesianModel

USER = os.environ["USER"]

MAX_THRESHOLD_CONST = 2.24#3.17 #CONSTANT WILL BE mean of rats for stim max / threshold for EVERY possible combination... ma.mean(ma.median((stim_max/ma.mean(a, axis=0)),axis = (1,2,3)),axis =0)
NUM_THRESHOLD_POINTS = 500

# NORMALIZATION WILL BE ACROSS ELECTRODE


In [2]:
src = f"/home/andres/repos/rat-mapping-paper/reports/C_SMA_LAR/non_hierarchical_bayesian_model/inference.pkl"
with open(src, "rb") as f:
    (
        df,
        encoder_dict,
        model,
        posterior_samples,
    ) = pickle.load(f)


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
a = posterior_samples[site.a]
b = posterior_samples[site.b]
L = posterior_samples[site.L]
ell = posterior_samples[site.ell]
H = posterior_samples[site.H]

x = np.linspace(0, 500, 1000)

named_params = [site.a, site.b, site.L, site.ell, site.H]
params = [posterior_samples[param][ ...] for param in named_params]

for named_param, param in zip(named_params, params):
    
    print(named_param, param.shape)

a (4000, 8, 70, 3, 6)
b (4000, 8, 70, 3, 6)
L (4000, 8, 70, 3, 6)
ℓ (4000, 8, 70, 3, 6)
H (4000, 8, 70, 3, 6)


In [4]:
subjects = sorted(df['participant'].unique())
subjects = encoder_dict['participant'].inverse_transform(subjects)
subjects

array(['amap01', 'amap02', 'amap03', 'amap04', 'amap05', 'amap06',
       'amap07', 'amap08'], dtype=object)

In [5]:
positions = sorted(df['compound_position'].unique())
positions = encoder_dict['compound_position'].inverse_transform(positions)
positions


array(['-C5L', '-C5LL', '-C5LM', '-C5LM1', '-C5LM2', '-C5M', '-C6L',
       '-C6LL', '-C6LM', '-C6LM1', '-C6LM2', '-C6M', '-C6R', '-C6RM',
       '-C6RR', '-C7LM', '-C7M', '-C7R', '-C7RM', '-C7RR', 'C5L-C5LL',
       'C5LM1-C5L', 'C5LM1-C5LL', 'C5LM2-C5L', 'C5LM2-C5LL',
       'C5LM2-C5LM1', 'C5LM2-C5M', 'C5LM2-C6L', 'C5LM2-C6LL', 'C5M-C5L',
       'C5M-C5LL', 'C5M-C5LM', 'C5M-C5LM1', 'C5M-C5LM2', 'C5M-C6L',
       'C5M-C6LL', 'C6L-C6LL', 'C6LL-C6L', 'C6LM-C6L', 'C6LM-C6LL',
       'C6LM-C6M', 'C6LM-C6R', 'C6LM-C6RM', 'C6LM-C6RR', 'C6LM1-C6L',
       'C6LM1-C6LL', 'C6LM2-C6L', 'C6LM2-C6LL', 'C6LM2-C6LM1', 'C6M-C6L',
       'C6M-C6LL', 'C6M-C6LM', 'C6M-C6LM1', 'C6M-C6LM2', 'C6M-C6R',
       'C6M-C6RM', 'C6M-C6RR', 'C6R-C6RR', 'C6RM-C6R', 'C6RM-C6RR',
       'C7LM-C7M', 'C7LM-C7R', 'C7LM-C7RM', 'C7LM-C7RR', 'C7M-C7R',
       'C7M-C7RM', 'C7M-C7RR', 'C7R-C7RR', 'C7RM-C7R', 'C7RM-C7RR'],
      dtype=object)

In [6]:
sizes = sorted(df['compound_size'].unique())

In [7]:
sizes = sizes[:2]

In [8]:
muscles = model.response

In [9]:
pos_inv = encoder_dict['compound_position'].transform(positions)
p1 = pos_inv[:6]
p2 = pos_inv[6:12]
keys = [p1, p2]

In [10]:
poses = np.concatenate((p1, p2))

In [11]:
poses

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

In [12]:
stim_max = []

for s in subjects:
    for p in poses:
        for s in sizes:
            for m in muscles: 
                temp = df[(df['participant'] == s) 
                        & (df['compound_position'] == p) 
                        & (df['compound_size'] == s)
                        ]
                temp = temp.loc[:,[m, 'pulse_amplitude']]
                temp = temp.pulse_amplitude.max()
                stim_max.append(temp)
stim_max = ma.array(stim_max)
stim_max = stim_max.reshape(len(subjects), len(poses), len(sizes), len(muscles))

In [13]:
stim_max.shape

(8, 12, 2, 6)

In [14]:
a.shape

(4000, 8, 70, 3, 6)

In [15]:
atemp = a[:,:,:12,:2,:]

In [16]:
atemp = ma.mean(atemp, axis=0)

In [17]:
ma.mean(ma.median((stim_max/atemp),axis = (1,2,3)),axis = 0)

2.2355382176103276

In [18]:
def get_normalized_input_output():
    norm_y = []
    norm_x = []

    for subject_ind, subject in enumerate(subjects):
        for k in keys:
            # Fix the subject index and size index
            curr_params = [
                params[:, subject_ind, k, :2, :, None] for params in params
            ]

            # Take nanmean across posterior samples
            temp = np.nanmean(curr_params[0], axis=0)
            # Take nanmedian across compound positions and muscles
            median_threshold = np.nanmedian(temp)

            x_temp = np.linspace(0.,  median_threshold * MAX_THRESHOLD_CONST, NUM_THRESHOLD_POINTS)
            x_temp = x_temp[None, None, None, :]

            temp_thresh = F.rectified_logistic(
                x_temp, #x_temp
                *curr_params
            )
            # Remove offset
            temp_thresh = temp_thresh - curr_params[2]
            norm_y.append(temp_thresh)
            norm_x.append(x_temp)

    norm_x = np.array(norm_x)
    norm_x = norm_x.reshape(len(subjects), len(keys), *norm_x.shape[1:])
    norm_y = np.array(norm_y)
    norm_y = norm_y.reshape(len(subjects), len(keys), *norm_y.shape[1:])
    norm_y = ma.masked_invalid(norm_y)
    return norm_x, norm_y


In [19]:
x, y = get_normalized_input_output()

y.shape

  temp = np.nanmean(curr_params[0], axis=0)
  temp = np.nanmean(curr_params[0], axis=0)


(8, 2, 4000, 6, 2, 6, 500)

In [20]:
# y_max = ma.max(y, axis=(1, 3,4, -1), keepdims=True)


In [21]:
# y = ma.where(y_max, y / y_max, 0.)

In [22]:
# y.shape

In [23]:
big_ind = [0,1,2,2,5]
small_ind = [0,1,3,4,5]

In [24]:
y_big = y[...,big_ind,0,:,:]
y_small = y[...,small_ind,1,:,:]

In [25]:
y_big.shape

(8, 2, 4000, 5, 6, 500)

In [26]:
p_big = ma.sum(y_big, axis=-2, keepdims=True)
p_big = ma.where(p_big, y_big / p_big, 1 / y_big.shape[-2])
p_small = ma.sum(y_small, axis=-2, keepdims=True)
p_small = ma.where(p_small, y_small / p_small, 1 / y_small.shape[-2])

In [27]:
# p = ma.sum(y, axis=-2, keepdims=True)
# p = ma.where(p, y / p, 1 / y.shape[-2])

In [28]:
plogp_big = ma.where(p_big, p_big * ma.log(p_big), 0.)
plogp_small = ma.where(p_small, p_small * ma.log(p_small), 0.)

entropy_big = 1 + (plogp_big.sum(axis=-2) / ma.log(y_big.shape[-2]))
entropy_small = 1 + (plogp_small.sum(axis=-2) / ma.log(y_small.shape[-2]))


In [29]:
auc_big = np.trapz(y=entropy_big[...], x=np.linspace(0, MAX_THRESHOLD_CONST, NUM_THRESHOLD_POINTS), axis=-1)
auc_small = np.trapz(y=entropy_small[...], x=np.linspace(0, MAX_THRESHOLD_CONST, NUM_THRESHOLD_POINTS), axis=-1)

In [30]:
auc_big.shape

(8, 2, 4000, 5)

In [31]:
auc_big = auc_big.mean(axis=(2))
auc_small = auc_small.mean(axis=(2))
auc_big.shape

(8, 2, 5)

In [33]:
mat = auc_big - auc_small
mat = mat.mean(axis=(1,2))
mat.shape

(8,)

In [34]:
mat

masked_array(data=[-0.29247572613737083, -0.37060375221102876,
                   -0.07240368126741362, 0.08819247751086592,
                   -0.21377309659121801, -0.3076911824794817,
                   0.07821630081213153, -0.2609445116266059],
             mask=[False, False, False, False, False, False, False, False],
       fill_value=1e+20)

In [36]:
src = f"/home/andres/repos/rat-mapping-paper/notebooks/C_SMA_LAR/cst_selectivity_means.pkl"

with open(src, "wb") as f:
    pickle.dump([auc_small, auc_big], f)

In [35]:
stats.wilcoxon(mat, axis=0)

WilcoxonResult(statistic=5.0, pvalue=0.078125)

In [56]:
plogp = ma.where(p, p * ma.log(p), 0.)

entropy = 1 + (plogp.sum(axis=-2) / ma.log(y.shape[-2]))
entropy.shape

(8, 2, 4000, 6, 2, 500)

In [57]:
auc = np.trapz(y=entropy[...], x=np.linspace(0, MAX_THRESHOLD_CONST, NUM_THRESHOLD_POINTS), axis=-1)
auc.shape

(8, 2, 4000, 6, 2)

In [58]:
auc = auc.mean(axis=(2))
auc.shape

(8, 2, 6, 2)

In [59]:
mat = auc[:, :, :,0] - auc[:,:, :, 1]

In [61]:
mat = mat.mean(axis=1)

In [62]:
mat.shape

(8, 6)

In [None]:
def plot_selectivity_comparison(size_ind):
    x, y = get_normalized_input_output(size_ind)
    num_masked = y.mask.any(axis=(1, 3, -1)).sum()
    y_max = ma.max(y, axis=(1, 3, -1), keepdims=True)
    
    num_masked = y.mask.sum()
    y = ma.where(y_max, y / y_max, 0.)
    
    p = ma.sum(y, axis=-2, keepdims=True)
    p = ma.where(p, y / p, 1 / y.shape[-2])
    
    plogp = ma.where(p, p * ma.log(p), 0.)

    entropy = 1 + (plogp.sum(axis=-2) / ma.log(y.shape[-2]))
    
    auc = np.trapz(y=entropy[...], x=np.linspace(0, MAX_THRESHOLD_CONST, NUM_THRESHOLD_POINTS), axis=-1)
    auc = auc.mean(axis=(2))
    
    mat = auc[:, :, :, None] - auc[:, :, None, :]
    mat = mat.mean(axis=1)
    
    pvalues = []
    for i in range(6):
        for i2 in range(6):
            if i==i2:
                pvalues.append(1)
                continue
            temp_p = stats.wilcoxon(mat[:, i,i2], axis=0).pvalue
            pvalues.append(temp_p)

    pvalues = np.array(pvalues)
    pvalues = pvalues.reshape(6, 6)
    
    labels = ['L', 'LL', 'LM', 'LM1', 'LM2', 'M']
    mask = np.tril(np.ones_like(pvalues), k=0).astype(bool)
    sns.heatmap(pvalues, xticklabels=labels, yticklabels=labels, mask = mask, annot=False)
    sns.heatmap(pvalues, xticklabels=labels, yticklabels=labels, mask = mask,annot=np.round(mat.mean(axis=0).data, 3), annot_kws={'va':'top'}, fmt="", cbar=False)
    sns.heatmap(pvalues, xticklabels=labels, yticklabels=labels, mask = mask,annot=np.round(pvalues, 3), annot_kws={'va':'bottom'}, fmt="", cbar=False)

In [None]:
plot_selectivity_comparison(1)

In [None]:
x,y = get_normalized_input_output(1)

y_max = ma.max(y, axis=(1, 3, -1), keepdims=True)

y = ma.where(y_max, y / y_max, 0.)

p = ma.sum(y, axis=-2, keepdims=True)
p = ma.where(p, y / p, 1 / y.shape[-2])

plogp = ma.where(p, p * ma.log(p), 0.)

entropy = 1 + (plogp.sum(axis=-2) / ma.log(y.shape[-2]))

auc = np.trapz(y=entropy[...], x=np.linspace(0, MAX_THRESHOLD_CONST, NUM_THRESHOLD_POINTS), axis=-1)
auc = auc.mean(axis=(2))

mat = auc[:, :, :, None] - auc[:, :, None, :]
mat = mat.mean(axis=1) 

In [None]:
auc.shape 