In [7]:
import subprocess
import sys
import os
import pandas as pd
import glob
import math
import emcee
import string
import corner
import numpy as np
from tqdm import tqdm
from numba import jit
import concurrent.futures
from multiprocessing import Pool, cpu_count
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import matplotlib.ticker as mticker
from scipy.integrate import odeint
import seaborn as sns
import scipy.stats as stats

def parse_data_filename(path):
    """Extract the metadata encoded in a data file's name.

    The base name is split on the literal markers ``_c`` and ``_t``:

    * everything before the first ``_c`` is the data type (group key);
    * the first ``_``-separated token of the data type is the treatment;
    * the text after the last ``_c``, up to the first ``_t``, prefixed with
      ``'c'``, is the mouse name;
    * every ``_t<int>`` token before ``.txt`` is a treatment day.

    Parameters
    ----------
    path : str
        Path to a data file; only its base name is inspected.

    Returns
    -------
    tuple
        ``(data_type, mouse_name, treatment, t_days)`` where ``t_days`` is an
        integer NumPy array (empty when the name has no ``_t`` token).

    Raises
    ------
    ValueError
        If a ``_t`` token is not followed by an integer.
    """
    base = os.path.basename(path)
    data_type = base.split('_c')[0]
    tail = base.split('_c')[-1]
    # NOTE(review): when the name has no `_t` token the `.txt` suffix stays
    # in the mouse name (e.g. 'c4.txt') -- confirm whether such files exist.
    mouse_name = 'c' + tail.split('_t')[0]
    day_tokens = tail.split('.txt')[0].split('_t')[1:]
    # Explicit dtype: an empty token list would otherwise yield float64.
    t_days = np.array([int(t) for t in day_tokens], dtype=int)
    treatment = data_type.split('_')[0]
    return data_type, mouse_name, treatment, t_days


# Load every data file and group the records by the data type in its name.
data_files = sorted(glob.glob("./data/*.txt"))
full_data = {}

for file in data_files:
    data_type, mouse_name, treatment, t_days = parse_data_filename(file)
    data = np.loadtxt(file)
    full_data.setdefault(data_type, []).append({
        'name': mouse_name,
        'data': data,
        'treatment': treatment,
        'treatment_days': t_days
    })

def calculate_n_total(full_data, group):
    """Return the summed length of the 'data' entry of every record in a group.

    Parameters
    ----------
    full_data : dict
        Maps a group key to a list of records, each holding a sized object
        under the key ``'data'``.
    group : hashable
        Key of the group to total.

    Returns
    -------
    int
        Sum of ``len(record['data'])`` over the group; 0 for an empty group.

    Raises
    ------
    KeyError
        If ``group`` is not a key of ``full_data``.
    """
    return sum(len(record['data']) for record in full_data[group])

# Filename suffixes of the calibrated models to compare.
all_models = [
    '_exp',
    '_mendel'
]

# Per-group, per-model tables, pre-filled with placeholders (0 / NaN) that are
# overwritten below once the calibration output has been read.
group_ll_total = {
    group: {model: 0 for model in all_models}
    for group in ('control_sensitive', 'control_resistant')
}
group_params_total = {
    group: {model: 0 for model in all_models}
    for group in ('control_sensitive', 'control_resistant')
}
group_obs_total = {'control_sensitive': 0, 'control_resistant': 0}
group_bic = {
    group: {model: np.nan for model in all_models}
    for group in ('control_sensitive', 'control_resistant')
}

for group in ['control_sensitive', 'control_resistant']:
    # Total length of all data arrays in the group; used as n in the BIC.
    n_total = calculate_n_total(full_data, group)
    group_obs_total[group] = n_total

    for model in all_models:
        files_location = f'./Output_Calibration/multi_ll_pars_{group}{model}.npz'
        # An NpzFile keeps the archive open until it is closed; the context
        # manager releases the handle once both arrays have been read.
        with np.load(files_location) as npzfile:
            max_ll = npzfile['max_ll']  # presumably the maximised log-likelihood
            theta = npzfile['pars']
        k = len(theta)  # number of entries in the stored parameter array

        group_ll_total[group][model] = max_ll
        group_params_total[group][model] = k

        # BIC = k * ln(n) - 2 * max_ll; left as NaN when the group has no data.
        bic_global = k * np.log(n_total) - 2 * max_ll if n_total > 0 else np.nan
        group_bic[group][model] = bic_global

def _latex_row(label, values, spec):
    """Build one table row: the label followed by one ``& value `` cell each.

    NaN entries are rendered as ``---``; all others are formatted with the
    format spec ``spec``. The row ends with ``\\\\ \\hline``.
    """
    cells = "".join(
        "& --- " if np.isnan(value) else f"& {value:{spec}} "
        for value in values
    )
    return f"{label} {cells}" + "\\\\ \\hline"


# Both tables share the same columns: one per model, underscores stripped.
model_columns = " & ".join(name.replace('_', '') for name in all_models)

# NOTE(review): group labels are printed with raw underscores; if these rows
# are pasted into a LaTeX table the underscores may need escaping -- confirm.
print("% --- BIC ---")
header = "group & " + model_columns + " \\\\ \\hline"
print(header)

for group in ['control_sensitive', 'control_resistant']:
    print(_latex_row(group, [group_bic[group][model] for model in all_models], ".2f"))

print("\n% --- WBIC (BIC weights) ---")
header_wbic = "group & " + model_columns + " \\\\ \\hline"
print(header_wbic)

for group in ['control_sensitive', 'control_resistant']:
    bics = np.array([group_bic[group][model] for model in all_models], dtype=float)
    mask = ~np.isnan(bics)

    # Start from all-NaN so models without a BIC show up as '---'.
    weights = np.full_like(bics, np.nan, dtype=float)
    if mask.any():
        # Weights are exp(-delta/2) relative to the smallest BIC (which gets
        # exp(0) = 1), normalised to sum to one over the models with a BIC.
        relative = np.exp(-0.5 * (bics[mask] - np.min(bics[mask])))
        weights[mask] = relative / np.sum(relative)

    print(_latex_row(group, weights, ".3f"))


% --- BIC ---
group & exp & mendel \\ \hline
control_sensitive & 422.33 & 396.75 \\ \hline
control_resistant & 556.07 & 564.47 \\ \hline

% --- WBIC (BIC weights) ---
group & exp & mendel \\ \hline
control_sensitive & 0.000 & 1.000 \\ \hline
control_resistant & 0.985 & 0.015 \\ \hline
