In [1]:
import os
import pickle

import pandas as pd
import numpy as np

from hbmep.config import Config
from hbmep.model.utils import Site as site

from models import NonHierarchicalBayesianModel
from constants import (
    TOML_PATH,
    DATA_PATH,
    BUILD_DIR,
)


In [2]:
# Load the fitted posterior together with the (subject, position, charge)
# combinations it was fit on, then the raw dataset.
# NOTE: pickle.load can execute arbitrary code -- only load an inference.pkl
# produced by this project's own build, never one from an untrusted source.
src = os.path.join(BUILD_DIR, "inference.pkl")

with open(src, "rb") as f:
    combinations, posterior_samples = pickle.load(f)

# remove_combinations filters posterior arrays by position along axis 1, so
# that axis must line up one-to-one with `combinations`. Fail loudly here
# rather than silently misaligning samples later.
assert posterior_samples[site.a].shape[1] == len(combinations), (
    "posterior axis 1 does not match the number of combinations"
)

df = pd.read_csv(DATA_PATH)
print(df.shape)


(16440, 44)


In [3]:
def remove_combinations(
    remove,
    combinations,
    posterior_samples
):
    """Drop the given combinations and their slices of the posterior.

    Parameters
    ----------
    remove : sequence
        Combinations to drop. Membership is tested with ``in`` against each
        entry of ``combinations``.
    combinations : sequence
        Combinations aligned position-by-position with axis 1 of every array
        in ``posterior_samples``.
    posterior_samples : dict
        Maps each site to an array whose axis 1 indexes ``combinations``.

    Returns
    -------
    tuple
        ``(kept_combinations, filtered_samples)``. ``filtered_samples`` is a
        new dict; the input dict is not modified, so the caller still has the
        unfiltered posterior.
    """
    # Build the keep-mask once and reuse it for both the labels and the
    # arrays, so the two can never disagree.
    keep = np.array([c not in remove for c in combinations], dtype=bool)
    kept_combinations = [c for c, k in zip(combinations, keep) if k]

    filtered_samples = {
        site_name: samples[:, keep, ...]
        for site_name, samples in posterior_samples.items()
    }
    return kept_combinations, filtered_samples


In [4]:
# Inspect the available columns (DataFrame.keys() returns the column Index).
df.keys()


Index(['pulse_amplitude', 'pulse_train_frequency', 'pulse_period',
       'pulse_duration', 'pulse_count', 'train_delay', 'channel1_1',
       'channel1_2', 'channel1_3', 'channel1_4', 'channel2_1', 'channel2_2',
       'channel2_3', 'channel2_4', 'charge_params_1', 'charge_params_2',
       'charge_params_3', 'charge_params_4', 'bank_check', 'run', 'enabled',
       'channel_failA', 'channel_failB', 'channel_fail_comb', 'time_pulse',
       'time', 'ix_onsets', 'LBiceps', 'LFCR', 'LECR', 'LTriceps', 'LADM',
       'LDeltoid', 'LBicepsFemoris', 'RBiceps', 'channel1_laterality',
       'channel1_segment', 'channel2_laterality', 'channel2_segment',
       'compound_position', 'compound_charge_params', 'participant',
       'subdir_pattern', 'charge_param_error'],
      dtype='object')

In [5]:
# Drop every combination recorded with the P Mono charge parameters
# 80-0-20-400 (third element of each combination).
remove = [
    combination
    for combination in combinations
    if combination[2] == "80-0-20-400"
]
combinations, posterior_samples = remove_combinations(
    remove, combinations, posterior_samples
)

print(len(combinations))
print(posterior_samples[site.a].shape)


240
(4000, 240, 6)


In [6]:
# Drop ground contacts: a position string (second element) with an empty
# segment after splitting on "-".
remove = [
    combination
    for combination in combinations
    if any(segment == "" for segment in combination[1].split("-"))
]
combinations, posterior_samples = remove_combinations(
    remove, combinations, posterior_samples
)

print(len(combinations))
print(posterior_samples[site.a].shape)


144
(4000, 144, 6)


In [7]:
# Group subjects by the (position, charge) pair they were recorded with,
# preserving the order in which subjects appear in `combinations`.
dict_position_charge_to_subject = {}
for subject, position, charge in combinations:
    dict_position_charge_to_subject.setdefault((position, charge), []).append(subject)


In [8]:
# Number of distinct (position, charge) pairs left after filtering.
len(dict_position_charge_to_subject)

30

In [9]:
# A (position, charge) pair is a complete case when the number of subject
# entries recorded for it equals the number of participants in the dataset.
# NOTE(review): this counts entries, not distinct subjects -- it assumes each
# subject appears at most once per pair; confirm against how combinations are built.
n_participants = df.participant.nunique()  # hoisted: invariant across pairs
complete_cases = [
    position_charge
    for position_charge, subjects in dict_position_charge_to_subject.items()
    if len(subjects) == n_participants
]

complete_cases


[('C7L-C7M', '20-0-80-25'),
 ('C7L-C7M', '50-0-50-0'),
 ('C7L-C7M', '50-0-50-100'),
 ('C7M-C7L', '20-0-80-25'),
 ('C7M-C7L', '50-0-50-0'),
 ('C7M-C7L', '50-0-50-100')]

In [19]:
# Keep only the combinations whose (position, charge) pair is a complete case.
conditions = list(
    filter(lambda combination: (combination[1], combination[2]) in complete_cases, combinations)
)
len(conditions)


48