In [1]:
import os
import pickle

import pandas as pd
from pandas.testing import assert_frame_equal
import numpy as np
import numpy.ma as ma
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats

from hbmep.config import Config
from hbmep.model.utils import Site as site

from models import HB
from constants import (
    TOML_PATH,
    DATA_PATH,
    BUILD_DIR,
    INFERENCE_FILE
)


In [2]:
# Merge the per-response inference results found under DIR/response_<i>/ into a
# single posterior-samples dict. Every response must have been fit on the same
# preprocessed dataset (checked with assert_frame_equal).
# NOTE(review): machine-specific absolute path; consider deriving it from a
# configurable root (possibly BUILD_DIR — confirm) so the notebook runs elsewhere.
DIR = "/home/vishu/repos/rat-mapping-paper/reports/log-hierarchical/L_CIRC/hb__4000W_4000S_4C_4T_20D_0.95A_mixtureTrue"
N_RESPONSES = 6  # number of response_<i> sub-folders to merge

df = None
# Per-site list of sample arrays, one entry per response. Concatenating once at
# the end avoids re-copying all previously merged data on every iteration.
collected_samples = {}
for response_ind in range(N_RESPONSES):
    src = os.path.join(DIR, f"response_{response_ind}", INFERENCE_FILE)
    # NOTE: pickle.load can execute arbitrary code — only load trusted files.
    with open(src, "rb") as f:
        # The encoder dict and model stored here are not needed: the cell that
        # rebuilds the dataset recreates both.
        df_current, _encoder_dict, _model, posterior_samples_current = pickle.load(f)

    if df is None:
        df = df_current.copy()
        # Keep only multi-dimensional sites; the first response fixes the key set.
        collected_samples = {
            u: [v] for u, v in posterior_samples_current.items() if v.ndim > 1
        }
    else:
        # Each response must come from the identical preprocessed dataset.
        assert_frame_equal(df, df_current)
        for u, chunks in collected_samples.items():
            chunks.append(posterior_samples_current[u])

# Stack responses along the last axis (presumably the response axis — the merged
# shapes printed below end in N_RESPONSES; confirm per-file shapes end in 1).
posterior_samples = {
    u: np.concatenate(chunks, axis=-1) for u, chunks in collected_samples.items()
}


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [3]:
# Report the merged array shape of every retained posterior site.
for site_name, samples in posterior_samples.items():
    print(site_name, samples.shape)


H (4000, 8, 21, 6)
H_raw (4000, 8, 21, 6)
L (4000, 8, 21, 6)
L_raw (4000, 8, 21, 6)
a (4000, 8, 21, 6)
a_raw (4000, 8, 21, 6)
b (4000, 8, 21, 6)
b_raw (4000, 8, 21, 6)
c_1_raw (4000, 8, 21, 6)
c_2_raw (4000, 8, 21, 6)
c₁ (4000, 8, 21, 6)
c₂ (4000, 8, 21, 6)
ell_raw (4000, 8, 21, 6)
µ (4000, 8946, 6)
α (4000, 8946, 6)
β (4000, 8946, 6)
ℓ (4000, 8, 21, 6)


In [4]:
# Re-derive the preprocessed dataset from the raw CSV and verify it matches the
# frame stored with the per-response inference results (`df`). No inference is
# run here; this only reconstructs `data`, `encoder_dict` and a fresh `model`.
raw_data = pd.read_csv(DATA_PATH)

config = Config(toml_path=TOML_PATH)
model = HB(config=config)
# The save step writes its output files into model.build_dir.
model.build_dir = DIR

# Keep strictly positive intensities so the log transform below is defined
# (NaN intensities also fail the comparison and are dropped).
positive_intensity = raw_data[model.intensity] > 0
data = raw_data.loc[positive_intensity].reset_index(drop=True).copy()
data, encoder_dict = model.load(df=data)
data[model.intensity] = np.log(data[model.intensity])

# Guard: the reconstruction must be identical to what the models were fit on.
assert_frame_equal(data, df)


In [5]:
# Persist the merged results into the model's build directory:
#   INFERENCE_FILE -> (data, encoder_dict, model attribute dict, merged posterior samples)
#   "model.pkl"    -> the model object itself, wrapped in a 1-tuple
outputs = [
    (INFERENCE_FILE, (data, encoder_dict, model.__dict__, posterior_samples,)),
    ("model.pkl", (model,)),
]
for filename, payload in outputs:
    dest = os.path.join(model.build_dir, filename)
    with open(dest, "wb") as f:
        pickle.dump(payload, f)
    print(f"Saved to {dest}")


Saved to /home/vishu/repos/rat-mapping-paper/reports/log-hierarchical/L_CIRC/hb__4000W_4000S_4C_4T_20D_0.95A_mixtureTrue/inference.pkl
Saved to /home/vishu/repos/rat-mapping-paper/reports/log-hierarchical/L_CIRC/hb__4000W_4000S_4C_4T_20D_0.95A_mixtureTrue/model.pkl
