In [1]:
import os
import pickle

import numpy as np
import numpy.ma as ma

from hbmep.nn import functional as F
from hbmep.model.utils import Site as site

from models import (
    NonHierarchicalBayesianModel
)

# Login name of the current user; interpolated into the /home/<user>/... paths
# used by the cells that load pickled results. Raises KeyError when the USER
# environment variable is not set.
USER = os.environ["USER"]


In [2]:
# Location of the pickled inference bundle under the current user's home.
src = f"/home/{USER}/repos/rat-mapping-paper/reports/non-hierarchical/C_SMA_LAR/non_hierarchical_bayesian_model/processed_inference.pkl"

# NOTE: unpickling executes arbitrary code; only load files you trust.
# The pickle holds an 8-tuple, unpacked into the names used by later cells.
with open(src, mode="rb") as f:
    (
        df, encoder_dict, model, posterior_samples,
        y, subjects, compound_positions, compound_size,
    ) = pickle.Unpickler(f).load()


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
# One tuple per row over the model's feature columns; keep the distinct ones,
# sorted, so later cells can test which feature combinations exist in the data.
feature_rows = df[model.features].apply(tuple, axis=1)
combinations = sorted(feature_rows.unique())
len(combinations)


403

In [4]:
# Distinct (channel1, channel2) segment pairs that occur in the data.
segment_cols = ["channel1_segment", "channel2_segment"]
df[segment_cols].apply(tuple, axis=1).unique()


array([('C5', 'C5'), (nan, 'C5'), ('C6', 'C6'), ('C5', 'C6'), (nan, 'C6'),
       ('C7', 'C7'), (nan, 'C7')], dtype=object)

In [5]:
# Distinct (channel1, channel2) designation pairs that occur in the data.
designation_cols = ["channel1_designation", "channel2_designation"]
df[designation_cols].apply(tuple, axis=1).unique()


array([('M', 'LL'), (nan, 'LL'), ('M', 'LM'), (nan, 'M'), (nan, 'LM'),
       (nan, 'L'), ('M', 'L'), ('LM', 'RR'), (nan, 'LM2'), ('RM', 'RR'),
       ('M', 'RM'), ('LM1', 'L'), ('RM', 'R'), ('LM2', 'L'), ('LM', 'RM'),
       ('LM1', 'LL'), (nan, 'R'), (nan, 'LM1'), ('LM2', 'LL'),
       (nan, 'RR'), ('LM', 'M'), ('LM2', 'LM1'), ('LM2', 'M'),
       ('M', 'RR'), ('R', 'RR'), ('M', 'R'), ('L', 'LL'), (nan, 'RM'),
       ('LM', 'R'), ('LL', 'L'), ('M', 'LM2'), ('M', 'LM1'), ('LM', 'LL'),
       ('LM', 'L')], dtype=object)

In [6]:
# Axes of the selectivity array, in the order they appear in `arr`.
contacts = ["mono", "bi"]   # Monopolar and bipolar
positions = ["C5", "C6", "C7"]  # Left and right positions are the same
left_degrees = ["LL", "L", "LM", "LM1", "LM2", "M"]  # Right degree equals M in case of bipolar
size = compound_size[:-1]   # Keeps only B and S

# `y` is indexed below as (sample, subject, compound position, size, ...), so
# the per-combination block shape is its first axis plus everything after the
# fourth. This was previously hard-coded as
# (n_samples, n_response, n_time) = (400, 6, 1000).
block_shape = (y.shape[0], *y.shape[4:])

# Preallocate with NaN and fill in place: combinations absent from the data
# stay NaN (np.nan will work as mask), and we avoid holding a list of blocks
# plus a stacked copy of it in memory at the same time.
arr = np.full(
    (
        len(subjects),
        len(contacts),
        len(positions),
        len(left_degrees),
        len(size),
        *block_shape,
    ),
    np.nan,
)

n_matched = 0
for subject_ind, subject in enumerate(subjects):
    for contact_ind, contact in enumerate(contacts):
        for position_ind, position in enumerate(positions):
            for left_degree_ind, left_degree in enumerate(left_degrees):
                # Compound position label: "-<pos><deg>" for monopolar,
                # "<pos>M-<pos><deg>" for bipolar.
                if contact == "mono":
                    cpos = f"-{position}{left_degree}"
                else:
                    cpos = f"{position}M-{position}{left_degree}"

                for size_ind, size_ in enumerate(size):
                    # Use the scalar `size_` here. The original built the
                    # tuple with the whole `size` sequence, which can never
                    # equal a feature tuple, so nothing was ever filled in.
                    combination = (subject, cpos, size_)
                    if combination not in combinations:
                        continue

                    cpos_ind = [i for i, label in enumerate(compound_positions) if label == cpos]
                    assert len(cpos_ind) == 1, f"expected exactly one index for {cpos}, got {cpos_ind}"
                    cpos_ind = cpos_ind[0]

                    # A combination that exists in the data must be fully observed.
                    response = y[:, subject_ind, cpos_ind, size_ind, ...]
                    assert np.isnan(response).sum() == 0
                    arr[subject_ind, contact_ind, position_ind, left_degree_ind, size_ind] = response
                    n_matched += 1

# Guard against silently producing an all-NaN array (e.g. if the feature
# tuples in `combinations` use a different encoding than `subjects` /
# `compound_positions` / `size`).
assert n_matched > 0, "no (subject, compound position, size) combination matched the data"
arr.shape

(8, 2, 3, 6, 2, 400, 6, 1000)

In [8]:
# Wrap `arr` in a masked array with every NaN/inf entry masked, so the NaN
# placeholders are excluded from masked-array computations.
arr = ma.masked_invalid(a=arr, copy=True)
arr.shape


(8, 2, 3, 6, 2, 400, 6, 1000)

In [9]:
# Evenly spaced grid 0, 0.5, ..., 499.5 (1000 points).
x = np.arange(start=0, stop=500, step=0.5)
x.shape


(1000,)

In [11]:
# Persist the selectivity bundle next to the model's other build artifacts.
# The tuple order here must match the unpacking order used when reloading.
dest = os.path.join(model.build_dir, "processed_inference_selectivity.pkl")
with open(dest, mode="wb") as f:
    pickle.Pickler(f).dump((
        df, encoder_dict, model, posterior_samples,
        arr, subjects, contacts, positions, left_degrees, size, x,
    ))


In [12]:
# Location of the pickled selectivity bundle under the current user's home.
src = f"/home/{USER}/repos/rat-mapping-paper/reports/non-hierarchical/C_SMA_LAR/non_hierarchical_bayesian_model/processed_inference_selectivity.pkl"

# NOTE: unpickling executes arbitrary code; only load files you trust.
# The pickle holds an 11-tuple, unpacked into the names used by later cells.
with open(src, mode="rb") as f:
    (
        df, encoder_dict, model, posterior_samples,
        arr, subjects, contacts, positions, left_degrees, size, x,
    ) = pickle.Unpickler(f).load()
