In [1]:
import os
import pickle
import models

import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import seaborn as sns
from numpyro.diagnostics import hpdi

from hbmep.nn import functional as F
from hbmep.model.utils import Site as site

from hbmep.config import Config
from hbmep.model.utils import Site as site
from scipy import stats

from models import NonHierarchicalBayesianModel

USER = os.environ["USER"]

In [2]:
src = f"/home/{USER}/repos/rat-mapping-paper/reports/non-hierarchical/J_RCML_000/non_hierarchical_bayesian_model/inference.pkl"
with open(src, "rb") as f:
    (
        df,
        encoder_dict,
        model,
        posterior_samples,
    ) = pickle.load(f)


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
subjects = encoder_dict[model.features[0]].inverse_transform(df.participant.unique())
positions = encoder_dict[model.features[1]].inverse_transform(df.compound_position.unique())

print(len(subjects), len(positions))


8 21


In [4]:
positions

array(['C5M-C5L', 'C5M-C6L', 'C7L-C8L', '-C5M', 'C7M-C8L', 'C8M-C8L',
       '-C8M', '-C6L', '-C7M', '-C8L', 'C7M-C8M', 'C5L-C6L', '-C5L',
       'C6L-C7L', '-C7L', 'C7M-C7L', 'C6M-C7M', 'C5M-C6M', '-C6M',
       'C6M-C7L', 'C6M-C6L'], dtype=object)

In [5]:
named_params = [site.a, site.b, site.L, site.ell, site.H]
params = [posterior_samples[param][:400, ...] for param in named_params]


In [6]:
x = np.arange(0, 500, 1)
x = x[:, *([None] * 4)]
x.shape


(500, 1, 1, 1, 1)

In [7]:
params = [param[None, ...] for param in params]


In [8]:
params[0].shape


(1, 400, 8, 21, 6)

In [9]:
y = F.rectified_logistic(
    x, a=params[0], b=params[1], L=params[2], ell=params[3], H=params[4]
)
y = np.array(y)

y = ma.masked_array(y, mask=np.isnan(y))
y.shape


(500, 400, 8, 21, 6)

In [10]:
# norm_y = []
# norm_x = []

# for subject_ind, subject in enumerate(subjects):
#     curr_params = [
#         param[:, subject_ind, :, :][:, :, :, None] for param in params
#     ]
#     constant = curr_params[0].mean() #do this for bvs

#     x_temp = np.linspace(0., 5 * constant, 500)
#     x_temp = x_temp[None, None, None, :]

#     temp_thresh = F.rectified_logistic(
#         x_temp,
#         *curr_params
#     )
#     norm_y.append(temp_thresh)
#     norm_x.append(x_temp)

# norm_y = np.array(norm_y)
# norm_x = np.array(norm_x)

# print(norm_y.shape)
# print(norm_x.shape)

# norm_y = np.swapaxes(norm_y, 1, 2)
# norm_x = np.swapaxes(norm_x, 1, 2)

# print(norm_y.shape)
# print(norm_x.shape)


In [11]:
cath_lat = [(i, cpos) for i, cpos in enumerate(positions) if cpos.split("-")[1][-1] == "L" and cpos.split("-")[0] == ""]
cath_lat = sorted(cath_lat, key=lambda x: x[-1])
print(cath_lat)

cath_lat = [i for (i, _) in cath_lat]
print(cath_lat)


[(12, '-C5L'), (7, '-C6L'), (14, '-C7L'), (9, '-C8L')]
[12, 7, 14, 9]


In [12]:
y = y[..., cath_lat, :]
y.shape


(500, 400, 8, 4, 6)

In [13]:
y_max = ma.max(y, axis=(0,), keepdims=True)
y_max.shape


(1, 400, 8, 4, 6)

In [14]:
y = ma.where(y, y / y_max, 0.)
y.shape


(500, 400, 8, 4, 6)

In [15]:
p = ma.sum(y, axis=-1, keepdims=True)
p = ma.where(y, y / p, 0.)
p.shape


(500, 400, 8, 4, 6)

In [16]:
plogp = ma.where(p, p * ma.log(p), 0.)
plogp.shape


(500, 400, 8, 4, 6)

In [17]:
entropy = ma.where(
    ma.any(p, axis=-1, keepdims=True),
    (
        1
        + (ma.sum(plogp, axis=-1, keepdims=True) / np.log(plogp.shape[-1]))
    ),
    0.
)

entropy = entropy[..., 0]
entropy.shape


(500, 400, 8, 4)

In [8]:
positions

array(['C5M-C5L', 'C5M-C6L', 'C7L-C8L', '-C5M', 'C7M-C8L', 'C8M-C8L',
       '-C8M', '-C6L', '-C7M', '-C8L', 'C7M-C8M', 'C5L-C6L', '-C5L',
       'C6L-C7L', '-C7L', 'C7M-C7L', 'C6M-C7M', 'C5M-C6M', '-C6M',
       'C6M-C7L', 'C6M-C6L'], dtype=object)

In [7]:
cath_mid = [(i, cpos) for i, cpos in enumerate(positions) if cpos.split("-")[1][-1] == "M" and cpos.split("-")[0] == ""]
cath_lat = [(i, cpos) for i, cpos in enumerate(positions) if cpos.split("-")[1][-1] == "L" and cpos.split("-")[0] == ""]

mid_mid = [(i, cpos) for i, cpos in enumerate(positions) if cpos.split("-")[1][-1] == "M" and cpos.split("-")[0] != "" and cpos.split("-")[0][-1] == "M"]
lat_lat = [(i, cpos) for i, cpos in enumerate(positions) if cpos.split("-")[1][-1] == "L" and cpos.split("-")[0] != "" and cpos.split("-")[0][-1] == "L"]
mid_lat = [(i, cpos) for i, cpos in enumerate(positions) if cpos.split("-")[1][-1] == "L" and cpos.split("-")[0] != "" and cpos.split("-")[0][-1] == "M"]

cath_mid = [i for (i, cpos) in cath_mid]
cath_lat = [i for (i, cpos) in cath_lat]
mid_mid = [i for (i, cpos) in mid_mid]
lat_lat = [i for (i, cpos) in lat_lat]
mid_lat = [i for (i, cpos) in mid_lat]

In [9]:
cath_lat

[7, 9, 12, 14]

In [10]:
np.array(positions)[cath_lat]

array(['-C6L', '-C8L', '-C5L', '-C7L'], dtype=object)

In [12]:
cath_mid_entropy, cath_mid_x = get_entropy(cath_mid)
cath_lat_entropy, cath_lat_x = get_entropy(cath_lat)
mid_mid_entropy, mid_mid_x = get_entropy(mid_mid)
lat_lat_entropy, lat_lat_x = get_entropy(lat_lat)
mid_lat_entropy, mid_lat_x = get_entropy(mid_lat)

In [26]:
cath_mid_entropy.shape

(8, 4, 400, 500)

In [13]:
def get_area_list(mix, entropy, x):
    area_list = []
    for subject_ind, subject in enumerate(subjects):

        for cpos_ind in range(len(entropy[0,:,0,0])):
            cpos = positions[mix[cpos_ind]]
            x_temp = x[subject_ind, 0, ...].tolist()
            
            sample_integrals = entropy[subject_ind, cpos_ind, :, :]
            sample_integrals_mean = sample_integrals.mean(axis=0)
            
            area = np.trapz(sample_integrals_mean, x_temp)
            area_list.append((area))

    # filt_positions = list(set((a[1] for a in area_list)))
    area_list = np.array(area_list)
    area_list = area_list.reshape(len(subjects), len(entropy[0,:,0,0]), *area_list.shape[1:])
    return area_list

In [16]:
cath_mid_list = get_area_list(cath_mid, cath_mid_entropy, cath_mid_x)
cath_lat_list = get_area_list(cath_lat, cath_lat_entropy, cath_lat_x)
# mid_mid_list = get_area_list(mid_mid, mid_mid_entropy, mid_mid_x)
# lat_lat_list = get_area_list(lat_lat, lat_lat_entropy, lat_lat_x)
# mid_lat_list = get_area_list(mid_lat, mid_lat_entropy, mid_lat_x)

  area_list = np.array(area_list)


In [17]:
cath_mid_list

array([[        nan,         nan,         nan,         nan],
       [        nan,         nan,         nan,         nan],
       [        nan,         nan,         nan,         nan],
       [17.98200421, 19.27162404, 23.52306667, 22.84036543],
       [ 9.41572147,  8.20576207,  9.9555965 , 10.32166155],
       [52.85556427, 72.01456922, 48.25866931, 24.22335774],
       [22.20573364, 15.94846173, 19.87388737, 32.35204555],
       [ 9.51518522, 24.798713  , 12.84131624,  6.37203329]])

In [116]:
positions

array(['C5M-C5L', 'C5M-C6L', 'C7L-C8L', '-C5M', 'C7M-C8L', 'C8M-C8L',
       '-C8M', '-C6L', '-C7M', '-C8L', 'C7M-C8M', 'C5L-C6L', '-C5L',
       'C6L-C7L', '-C7L', 'C7M-C7L', 'C6M-C7M', 'C5M-C6M', '-C6M',
       'C6M-C7L', 'C6M-C6L'], dtype=object)

In [115]:
coded_positions = df.compound_position.unique()
coded_positions


array([ 9, 10, 16,  1, 18, 20,  7,  2,  5,  6, 19,  8,  0, 12,  4, 17, 15,
       11,  3, 14, 13])

In [27]:
cath_lat

[7, 9, 12, 14]

In [44]:
df[(df["participant"] == 0) & (df['compound_position'] == 14)]

Unnamed: 0,pulse_amplitude,pulse_train_frequency,pulse_period,pulse_duration,pulse_count,train_delay,channel1_1,channel1_2,channel1_3,channel1_4,...,RBiceps,channel1_laterality,channel1_segment,channel2_laterality,channel2_segment,compound_position,compound_charge_params,participant,subdir_pattern,charge_param_error
