In [1]:
import os
import pickle
import models

import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import seaborn as sns
from numpyro.diagnostics import hpdi

from hbmep.nn import functional as F
from hbmep.model.utils import Site as site

from hbmep.config import Config
from hbmep.model.utils import Site as site
from scipy import stats

from models import NonHierarchicalBayesianModel

# Current user name taken from the environment; raises KeyError when the
# USER variable is not set.
# NOTE(review): not referenced by any cell visible here, and the paths below
# hard-code a home directory instead -- confirm whether this is still needed.
USER = os.environ["USER"]

In [36]:
# Location of the pickled inference results for the non-hierarchical model.
# NOTE(review): absolute, machine-specific path -- consider a configurable
# base directory.
src = "/home/andres/repos/rat-mapping-paper/reports/non-hierarchical/L_SHIE/non_hierarchical_bayesian_model/inference.pkl"

# pickle.load can execute arbitrary code; only open files from a trusted source.
with open(src, "rb") as f:
    df, encoder_dict, model, posterior_samples = pickle.load(f)

In [42]:
# Show the class labels stored on the compound_charge_params encoder.
encoder_dict['compound_charge_params'].classes_

array(['20-0-80-25', '50-0-50-100'], dtype=object)

In [38]:
# Curve parameters to pull out of the posterior, in the positional order that
# later cells rely on (index 0 is `a`, index 2 is `L`).
named_params = [site.a, site.b, site.L, site.ell, site.H]
a, b, L, ell, H = (posterior_samples[name] for name in named_params)

x = np.linspace(0, 500, 1000)

params = [posterior_samples[name][...] for name in named_params]

# Report the shape of every parameter array.
for named_param, param in zip(named_params, params):
    print(named_param, param.shape)

a (4000, 8, 4, 2, 6)
b (4000, 8, 4, 2, 6)
L (4000, 8, 4, 2, 6)
ℓ (4000, 8, 4, 2, 6)
H (4000, 8, 4, 2, 6)


In [39]:
# Decode the participant codes present in the data and keep the labels sorted.
subjects = sorted(
    encoder_dict['participant'].inverse_transform(df['participant'].unique())
)

In [40]:
# Distinct (still encoded) compound_position values present in the data.
positions = df['compound_position'].unique()

In [33]:
# Decode the position codes back to their labels and list them in sorted order.
sorted(encoder_dict['compound_position'].inverse_transform(positions))

['-C6LC', 'C6LC-', 'C6LC-C6LX', 'C6LX-C6LC']

In [43]:
# Index of the charge condition of interest.
# NOTE(review): not read by any cell visible here (get_entropy hard-codes
# index 0 on that axis) -- confirm whether it should be passed through.
charge_ind = 0

In [44]:
# Per-subject curves evaluated on a per-subject x grid.
norm_y = []
norm_x = []

for subject_ind in range(len(subjects)):
    # Select this subject on axis 1 and append a trailing length-1 axis so the
    # parameters broadcast against the x grid.
    curr_params = [param[:, subject_ind, ...][..., None] for param in params]

    # Grid scale: median (over the remaining axes) of the sample-mean of the
    # first parameter (`a`), NaNs ignored. The grid spans 0 .. 5x that value.
    constant = np.nanmedian(np.nanmean(curr_params[0], axis=0))
    x_temp = np.linspace(0., 5 * constant, 500)[None, None, None, None, :]

    # Evaluate the curve and subtract the third parameter (`L`).
    norm_y.append(F.rectified_logistic(x_temp, *curr_params) - curr_params[2])
    norm_x.append(x_temp)

norm_y = np.array(norm_y)
norm_x = np.array(norm_x)

In [45]:
# Print the shape of the stacked per-subject curves.
print(norm_y.shape)

(8, 4000, 4, 2, 6, 500)


In [46]:
# Mask NaN/inf entries so the masked reductions below skip them.
arr = ma.masked_invalid(norm_y)


In [47]:
# Check the shape left after fixing one index on each of the two axes that
# get_entropy slices away.
arr[...,1,1,:,:].shape

(8, 4000, 6, 500)

In [49]:
def get_entropy(charge, charge_ind=0):
    """Compute an entropy-based curve for one slice of the global ``arr``.

    Reads the module-level masked array ``arr`` and the x grid ``norm_x``.

    Parameters
    ----------
    charge : int
        Index applied to the fourth-from-last axis of ``arr``.
        NOTE(review): despite the name, the notebook calls this with 0..3 and
        stores the results under compound-position names, and that axis has
        length 4 in the printed shapes -- so this appears to select the
        compound position, not the charge. Name kept for compatibility.
    charge_ind : int, optional
        Index applied to the third-from-last axis of ``arr`` (length 2 in the
        printed shapes). Defaults to 0, the value that used to be hard-coded.

    Returns
    -------
    entropy : masked array
        ``1 + sum(p * log(p)) / log(n)`` taken along the second-to-last axis
        of the selected slice (``n`` is that axis's length), or 0 where every
        ``p`` along that axis is zero. That axis is removed from the result.
    x : array
        ``norm_x`` unchanged.
    """
    y = arr[..., charge, charge_ind, :, :]
    x = norm_x

    # Scale by the maximum taken jointly over axis 2 and the last axis.
    # Assumes the slice is 4-D so that axis 2 is the same axis as the -2 used
    # below (true for the (8, 4000, 6, 500) slices shown in the notebook).
    y_max = ma.max(y, axis=(2, -1), keepdims=True)

    # Zero entries stay zero instead of being divided.
    y = ma.where(y, y / y_max, 0.)

    # Turn the values along the second-to-last axis into proportions.
    p = ma.sum(y, axis=-2, keepdims=True)
    p = ma.where(y, y / p, 0.)

    # Use 0 for p == 0 terms so log(0) never enters the sum.
    plogp = ma.where(p, p * ma.log(p), 0.)

    entropy = ma.where(
        ma.any(p, axis=-2, keepdims=True),
        (
            1
            + (ma.sum(plogp, axis=-2, keepdims=True) / np.log(plogp.shape[-2]))
        ),
        0.
    )
    # Drop the reduced (length-1) axis.
    entropy = entropy[..., 0, :]

    return entropy, x

In [13]:
# Decode the sorted compound_position codes into their labels.
# NOTE(review): the encoder is looked up via model.features[1]; presumably
# that is 'compound_position' -- confirm.
compound_position = encoder_dict[model.features[1]].inverse_transform(
    sorted(df['compound_position'].unique())
)

In [50]:
# Display the decoded compound_position labels.
compound_position


array(['-C6LC', 'C6LC-', 'C6LC-C6LX', 'C6LX-C6LC'], dtype=object)

In [51]:
# One get_entropy call per index 0..3, unpacked in that order.
(
    (dashc_entropy, dashc_x),
    (cdash_entropy, cdash_x),
    (c_x_entropy, c_x_x),
    (x_c_entropy, x_c_x),
) = (get_entropy(position_ind) for position_ind in range(4))

In [52]:
# Shape of the x grid returned alongside the entropy.
x_c_x.shape

(8, 1, 1, 1, 1, 500)

In [53]:
# Shape of one entropy result.
c_x_entropy.shape

(8, 4000, 500)

In [54]:
def AUC(entropy, x):
    """Integrate each entropy curve over its x grid and average per subject.

    Parameters
    ----------
    entropy : array-like
        Indexed as ``entropy[subject]`` to give curves whose last axis lines
        up with the x grid (shape ``(subjects, samples, points)`` in the
        notebook).
    x : array
        Indexed as ``x[subject, 0, 0, 0, ...]`` to obtain that subject's grid,
        so it needs at least four leading axes (``(subjects, 1, 1, 1, 1,
        points)`` in the notebook).

    Returns
    -------
    masked array
        Mean area over the last axis of the per-subject areas, with areas
        equal to 0 masked out so they do not enter the mean.
    """
    # np.trapz was renamed np.trapezoid in NumPy 2.0; use whichever exists so
    # the notebook runs on both old and new NumPy.
    integrate = np.trapezoid if hasattr(np, "trapezoid") else np.trapz

    # Iterate over the leading axis of `entropy` itself rather than a global
    # subject list, so the function does not depend on notebook state.
    area_list = []
    for subject_ind in range(len(entropy)):
        x_temp = x[subject_ind, 0, 0, 0, ...]
        sample_integrals = entropy[subject_ind, ...]
        area_list.append(integrate(sample_integrals, x_temp))

    # np.array yields a plain ndarray; zero areas are masked on the next line.
    area_list = np.array(area_list)
    area_list = ma.masked_values(area_list, 0)
    area_list_mean = ma.mean(area_list, axis=(-1))

    return area_list_mean

In [55]:
# Area under each entropy curve, in the same order as the entropy results.
dashc, cdash, c_x, x_c = (
    AUC(entropy_curve, x_grid)
    for entropy_curve, x_grid in (
        (dashc_entropy, dashc_x),
        (cdash_entropy, cdash_x),
        (c_x_entropy, c_x_x),
        (x_c_entropy, x_c_x),
    )
)

In [56]:
# Destination for the per-subject mean areas.
# NOTE(review): absolute, machine-specific path -- consider a configurable
# base directory.
src = "/home/andres/repos/rat-mapping-paper/notebooks/L_SHIE/selectivity_means.pkl"

# Persist the four results as one list, in this fixed order.
with open(src, "wb") as f:
    pickle.dump([dashc, cdash, c_x, x_c], f)