In [None]:
import pylab as pl
%reload_ext autoreload
%autoreload 2

import pickle

import numpy as np
import scipy.stats as stats
from scipy.optimize import minimize_scalar
from scipy.special import logit, expit

from early_markers.cribsy.common.constants import PKL_DIR, RAND_STATE
from early_markers.cribsy.common.bayes import BayesianData


np.random.seed(RAND_STATE)


def beta_hdi(alpha, beta, ci=0.95):
    """Return the highest-density interval of a Beta(alpha, beta) distribution.

    The interval is located numerically: among all intervals holding
    probability mass ``ci``, the lower-tail probability is chosen so that the
    distance between the two bounding quantiles is as small as possible.

    Parameters
    ----------
    alpha, beta : float
        Shape parameters of the Beta distribution.
    ci : float, default=0.95
        Probability mass the interval must contain.

    Returns
    -------
    numpy.ndarray
        Two-element array ``[lower, upper]`` with the interval end points.
    """
    quantile = stats.beta.ppf

    def span(tail_mass):
        # Width of the interval whose lower tail holds ``tail_mass``.
        left = quantile(tail_mass, alpha, beta)
        right = quantile(tail_mass + ci, alpha, beta)
        return abs(right - left)

    # The lower tail may hold anything between zero and all of the mass that
    # is left outside the interval.
    best = minimize_scalar(span, bounds=(0, 1-ci), method='bounded')
    tail = best.x
    return quantile([tail, tail + ci], alpha, beta)

# ========================================================================
# 1. Bayesian Assurance Method (BAM) for Model Development Sample Size
# ========================================================================
def revised_bam_model(pilot_data, hdi_width=0.15, ci=0.95,
                     target_assurance=0.8, simulations=2000,
                     max_sample=10000):
    """Find the smallest sample size that reaches a target HDI-width assurance.

    For each candidate sample size ``n`` the function repeatedly
    (1) draws Beta hyperparameters from Gamma hyperpriors built from the pilot
    data, (2) draws a success probability from the Beta distribution obtained
    by adding the pilot successes/failures to those hyperparameters,
    (3) simulates ``n`` new Bernoulli outcomes, and (4) checks whether the
    ``ci`` highest-density interval of the resulting Beta posterior is no
    wider than ``hdi_width``.  The fraction of simulations that pass is the
    assurance.  Candidate sizes are explored by binary search, which assumes
    the assurance does not decrease as ``n`` grows.

    Parameters
    ----------
    pilot_data : sequence of 0/1
        Binary pilot outcomes; must not be empty.
    hdi_width : float, default=0.15
        Largest acceptable width of the posterior HDI.
    ci : float, default=0.95
        Probability mass of the HDI.
    target_assurance : float, default=0.8
        Required fraction of simulations meeting the width criterion.
    simulations : int, default=2000
        Number of simulated data sets per candidate sample size.
    max_sample : int, default=10000
        Upper end of the search range.

    Returns
    -------
    int
        Smallest sample size found that meets ``target_assurance``.
        ``max_sample`` is returned when no candidate in the search range
        meets it, so that value alone does not prove the target was reached.

    Raises
    ------
    ValueError
        If ``pilot_data`` is empty or ``simulations`` is not positive.
    """
    n_pilot = len(pilot_data)
    if n_pilot == 0:
        raise ValueError("pilot_data must contain at least one observation")
    if simulations <= 0:
        raise ValueError("simulations must be a positive integer")

    # Pilot summaries are loop-invariant, so compute them once.
    pilot_successes = np.sum(pilot_data)
    pilot_failures = n_pilot - pilot_successes
    p_pilot = np.mean(pilot_data)
    pilot_ess = n_pilot

    # Gamma hyperpriors whose means sit close to the pilot success/failure
    # proportions (shape = count + 1, scale = 1 / pilot size).
    alpha_shape = pilot_ess * p_pilot + 1
    beta_shape = pilot_ess * (1 - p_pilot) + 1
    alpha_hyper = stats.gamma(alpha_shape, scale=1/pilot_ess)
    beta_hyper = stats.gamma(beta_shape, scale=1/pilot_ess)

    # Binary search over the candidate sample sizes.
    low, high = max(50, int(n_pilot*0.5)), max_sample
    best_n = max_sample

    while low <= high:
        mid = (low + high) // 2
        valid = 0

        for _ in range(simulations):
            # Hierarchical sampling: hyperparameters first, then the
            # success probability given hyperparameters and pilot counts.
            a_prior = alpha_hyper.rvs() + pilot_successes
            b_prior = beta_hyper.rvs() + pilot_failures
            theta = stats.beta(a_prior, b_prior).rvs()

            # Simulate the future study and update the Beta parameters.
            k = np.random.binomial(mid, theta)
            lower, upper = beta_hdi(a_prior + k, b_prior + mid - k, ci)
            if (upper - lower) <= hdi_width:
                valid += 1

        assurance = valid / simulations
        if assurance >= target_assurance:
            # Target met: remember this size and look for a smaller one.
            best_n = mid
            high = mid - 1
        else:
            low = mid + 1

    return best_n


def ipw_informed_bam_performance(pilot_se, pilot_sp, prevalence_prior=(8,32),
                             hdi_width=0.1, ci=0.95, target_assurance=0.8,
                             simulations=1000, max_sample=5000,
                             strata_props=None, strata_prev=None,
                             strata_se_var=None, strata_sp_var=None,
                             use_optimal_allocation=True):
    """
    Sample size determination for diagnostic accuracy studies using inverse probability weighting.

    For each candidate total sample size the function simulates stratified
    studies: the sample is split across strata with a multinomial draw,
    sensitivity and specificity are drawn from Beta priors built from the
    pilot counts, TP/FN/TN/FP counts are simulated per stratum, multiplied by
    the stratum's inverse probability weight and added to the Beta priors.
    A simulation succeeds when the HDIs of both sensitivity and specificity
    are no wider than ``hdi_width``.  Candidate sizes are explored by binary
    search, which assumes the assurance does not decrease with sample size.

    Parameters:
    -----------
    pilot_se : tuple (TP, FN)
        Pilot data for sensitivity (true positives, false negatives)
    pilot_sp : tuple (TN, FP)
        Pilot data for specificity (true negatives, false positives)
    prevalence_prior : tuple (alpha, beta), default=(8, 32)
        Beta prior parameters for disease prevalence.  Only used when
        strata_prev is None.
    hdi_width : float, default=0.1
        Target width for the highest density interval
    ci : float, default=0.95
        Probability mass of the highest density interval
    target_assurance : float, default=0.8
        Target probability of achieving the desired HDI width
    simulations : int, default=1000
        Number of simulations for the assurance calculation
    max_sample : int, default=5000
        Maximum sample size to consider
    strata_props : list or array, default=None
        Proportions of different strata in the target population
        If None, assumes a single stratum (homogeneous population)
    strata_prev : list or array, default=None
        Expected prevalence in each stratum; explicit values are always used
        as given.  If None and there is a single stratum, prevalence is drawn
        from prevalence_prior in every simulation; if None and there are
        several strata, the mean of prevalence_prior is used for each stratum.
    strata_se_var : list or array, default=None
        Expected variance of sensitivity in each stratum
        If None, the variance of the pilot-based Beta prior is used
    strata_sp_var : list or array, default=None
        Expected variance of specificity in each stratum
        If None, the variance of the pilot-based Beta prior is used
    use_optimal_allocation : bool, default=True
        If True, sampling proportions are proportional to
        stratum proportion * sqrt(variance), taking the larger of the
        sensitivity and specificity terms per stratum.
        If False, uses proportional allocation based on stratum sizes

    Returns:
    --------
    tuple (int, dict)
        The smallest sample size found that reaches the target assurance
        (``max_sample`` if none does), and a dict with the keys
        'sampling_props', 'ipw', 'design_effect', 'ess_factor' and
        'estimated_reduction'.

    Raises:
    -------
    ValueError
        If simulations is not positive, a per-stratum input does not have one
        entry per stratum, or a stratum would receive a non-positive sampling
        proportion.
    """
    if simulations <= 0:
        raise ValueError("simulations must be a positive integer")

    # Derive Beta priors from pilot data
    se_alpha, se_beta = pilot_se[0] + 1, pilot_se[1] + 1
    sp_alpha, sp_beta = pilot_sp[0] + 1, pilot_sp[1] + 1

    # Set up strata
    if strata_props is None:
        strata_props = [1.0]  # Single stratum
    n_strata = len(strata_props)

    # Remember whether the caller supplied prevalences *before* filling in
    # defaults, so an explicit single-stratum value is not mistaken for
    # "not supplied" and silently replaced by draws from the prior.
    sample_prev_from_prior = strata_prev is None and n_strata == 1
    if strata_prev is None:
        prev_mean = prevalence_prior[0] / (prevalence_prior[0] + prevalence_prior[1])
        strata_prev = [prev_mean] * n_strata

    # Default variances: variance of the pilot-based Beta priors
    if strata_se_var is None:
        se_var = (se_alpha * se_beta) / ((se_alpha + se_beta)**2 * (se_alpha + se_beta + 1))
        strata_se_var = [se_var] * n_strata

    if strata_sp_var is None:
        sp_var = (sp_alpha * sp_beta) / ((sp_alpha + sp_beta)**2 * (sp_alpha + sp_beta + 1))
        strata_sp_var = [sp_var] * n_strata

    for label, values in (("strata_prev", strata_prev),
                          ("strata_se_var", strata_se_var),
                          ("strata_sp_var", strata_sp_var)):
        if len(values) != n_strata:
            raise ValueError(
                f"{label} must have one entry per stratum "
                f"(expected {n_strata}, got {len(values)})"
            )

    # Determine allocation of samples to strata
    if use_optimal_allocation:
        # Allocation proportional to stratum size * sqrt(variance); the larger
        # of the sensitivity and specificity terms is used for each stratum.
        opt_alloc = [
            max(prop * np.sqrt(se_v), prop * np.sqrt(sp_v))
            for prop, se_v, sp_v in zip(strata_props, strata_se_var, strata_sp_var)
        ]
        alloc_total = sum(opt_alloc)
        if not alloc_total > 0:
            raise ValueError("optimal allocation is undefined: all strata have zero weight")
        sampling_props = [a / alloc_total for a in opt_alloc]
    else:
        # Proportional allocation
        sampling_props = strata_props

    # A stratum that is never sampled has no finite inverse probability weight.
    if any(not p > 0 for p in sampling_props):
        raise ValueError("every stratum must receive a positive sampling proportion")

    # Inverse probability weights: population share / sampling share
    ipw = [prop / samp for prop, samp in zip(strata_props, sampling_props)]

    # Design effect and the corresponding effective-sample-size factor
    deff = sum((w**2) * samp for w, samp in zip(ipw, sampling_props))
    ess_factor = 1 / deff

    # Frozen priors are loop-invariant; build them once.
    se_prior = stats.beta(se_alpha, se_beta)
    sp_prior = stats.beta(sp_alpha, sp_beta)
    prev_prior = stats.beta(*prevalence_prior) if sample_prev_from_prior else None

    # Binary search setup
    low = max(100, int(100 * ess_factor))
    high = max_sample
    best_n = max_sample

    while low <= high:
        mid = (low + high) // 2
        valid = 0
        # Each class needs at least 5% of the sample (and never fewer than 10)
        min_cases = max(10, int(0.05*mid))

        for _ in range(simulations):
            # Allocate samples to strata based on sampling proportions
            stratum_sizes = np.random.multinomial(mid, sampling_props)

            weighted_tp = weighted_fn = weighted_tn = weighted_fp = 0
            total_pos = total_neg = 0

            for i, size in enumerate(stratum_sizes):
                # Skip empty strata
                if size == 0:
                    continue

                prev = prev_prior.rvs() if sample_prev_from_prior else strata_prev[i]

                # Sample sensitivity and specificity
                se = se_prior.rvs()
                sp = sp_prior.rvs()

                # At least one positive per non-empty stratum
                n_pos = max(1, np.random.binomial(size, prev))
                n_neg = size - n_pos
                total_pos += n_pos
                total_neg += n_neg

                tp = np.random.binomial(n_pos, se)
                fn = n_pos - tp

                if n_neg > 0:
                    tn = np.random.binomial(n_neg, sp)
                    fp = n_neg - tn
                else:
                    tn, fp = 0, 0

                # Apply inverse probability weight
                w = ipw[i]
                weighted_tp += tp * w
                weighted_fn += fn * w
                weighted_tn += tn * w
                weighted_fp += fp * w

            # A simulation with too few cases in either class counts as a
            # failure: it is skipped without incrementing ``valid``.
            if total_pos < min_cases or total_neg < min_cases:
                continue

            # Posteriors: pilot-based priors plus the weighted counts
            se_hdi = beta_hdi(se_alpha + weighted_tp, se_beta + weighted_fn, ci)
            sp_hdi = beta_hdi(sp_alpha + weighted_tn, sp_beta + weighted_fp, ci)

            if (se_hdi[1] - se_hdi[0] <= hdi_width and
                sp_hdi[1] - sp_hdi[0] <= hdi_width):
                valid += 1

        assurance = valid / simulations

        if assurance >= target_assurance:
            best_n = mid
            high = mid - 1
        else:
            low = mid + 1

    # Return additional information about the design
    info = {
        'sampling_props': sampling_props,
        'ipw': ipw,
        'design_effect': deff,
        'ess_factor': ess_factor,
        # best_n / (best_n / ess_factor) reduces to ess_factor
        'estimated_reduction': 1 - ess_factor
    }

    return best_n, info


# def informed_bam_performance(pilot_se, pilot_sp, prevalence_prior=(8,32),
#                             hdi_width=0.1, ci=0.95, target_assurance=0.8,
#                             simulations=1000, max_sample=5000):
#     """
#     Optimized BAM implementation for performance evaluation
#     
#     Parameters:
#     pilot_se (tuple): (TP, FN) from pilot data
#     pilot_sp (tuple): (TN, FP) from pilot data
#     prevalence_prior (tuple): Beta parameters for prevalence
#     hdi_width (float): Desired HDI width
#     ci (float): Credible interval level
#     target_assurance (float): Required joint assurance probability
#     simulations (int): Number of simulations per sample size
#     max_sample (int): Maximum allowable sample size
#     
#     Returns:
#     Optimal sample size (int)
#     """
#     # Derive Beta priors from pilot data
#     se_alpha, se_beta = pilot_se[0] + 1, pilot_se[1] + 1
#     sp_alpha, sp_beta = pilot_sp[0] + 1, pilot_sp[1] + 1
#     
#     # Binary search setup with realistic bounds
#     low, high = 100, max_sample
#     best_n = max_sample
#     
#     while low <= high:
#         mid = (low + high) // 2
#         valid = 0
#         min_cases = max(10, int(0.05*mid))  # At least 5% cases per class
#         
#         for _ in range(simulations):
#             # Sample parameters from informed priors
#             se = stats.beta(se_alpha, se_beta).rvs()
#             sp = stats.beta(sp_alpha, sp_beta).rvs()
#             prev = stats.beta(*prevalence_prior).rvs()
#             
#             # Generate synthetic data with prevalence floor
#             n_pos = max(min_cases, np.random.binomial(mid, prev))
#             n_neg = max(min_cases, mid - n_pos)
#             
#             tp = np.random.binomial(n_pos, se)
#             tn = np.random.binomial(n_neg, sp)
#             
#             # Calculate posteriors
#             se_post_alpha = se_alpha + tp
#             se_post_beta = se_beta + (n_pos - tp)
#             sp_post_alpha = sp_alpha + tn
#             sp_post_beta = sp_beta + (n_neg - tn)
#             
#             # Calculate HDI widths
#             se_hdi = beta_hdi(se_post_alpha, se_post_beta, ci)
#             sp_hdi = beta_hdi(sp_post_alpha, sp_post_beta, ci)
#             
#             if (se_hdi[1] - se_hdi[0] <= hdi_width and 
#                 sp_hdi[1] - sp_hdi[0] <= hdi_width):
#                 valid += 1
#                 
#         assurance = valid / simulations
#         
#         if assurance >= target_assurance:
#             best_n = mid
#             high = mid - 1
#         else:
#             low = mid + 1
#             
#     return best_n


In [None]:
import pickle

import polars as pl
from polars import DataFrame

from early_markers.cribsy.common.constants import PKL_DIR, RAND_STATE
from early_markers.cribsy.common.bayes import BayesianData


# Load the pickled BayesianData object.  pickle.load can execute arbitrary
# code, so only files produced by this project should be loaded here.
with open(PKL_DIR / "bd_real.pkl", "rb") as f:
    bd: BayesianData = pickle.load(f)

metrics = bd.metrics("real_k_19")

# perf = (
#     metrics.metrics
#     .filter(pl.col("thresh").round(5) == round(metrics.threshold_j, 5))
#     .to_dicts()[0]
# )

# Row of primitive counts at the threshold stored in ``metrics.threshold_j``
# (thresholds are matched after rounding to 5 decimals), extended with
# ``hits`` = TP + TN.
prims = (
    metrics.primitives
    .filter(pl.col("thresh").round(5) == round(metrics.threshold_j, 5))
    .with_columns((pl.col("tp") + pl.col("tn")).alias("hits"))
    .to_dicts()[0]
)
total = metrics.test_n

# Binary pilot outcomes: one 1 per hit, one 0 for each remaining test case.
pos = [1] * prims["hits"]
neg = [0] * (total - prims["hits"])
pilot_data = pos + neg

dev_sample = revised_bam_model(pilot_data, hdi_width=0.2, max_sample=1000)
print(f"BAM Model Development Sample Size: {dev_sample}")

# Hand-picked (artificially inflated) per-stratum variances: the first
# stratum gets the high sensitivity variance, the second the high
# specificity variance.
strata_se_var = [0.25, 0.1]
strata_sp_var = [0.1, 0.25]

# Previously noted values:
# sens 0.8265714285714286 = TP / (TP + FN)
# spec 0.6857142857142857

# Pilot counts come from the primitives row selected above.
perf_sample, design_info = ipw_informed_bam_performance(
    pilot_se=(prims["tp"], prims["fn"]),
    pilot_sp=(prims["tn"], prims["fp"]),
    # Beta(4, 32) has mean 4/36 ≈ 0.11; the function only consults this
    # prior when strata_prev is not given.
    prevalence_prior=(4, 32),
    hdi_width=0.2,
    strata_props=[0.10, 0.90],
    strata_prev=[0.90, 0.05],
    strata_se_var=strata_se_var,
    strata_sp_var=strata_sp_var,
    simulations=2000,
)

# sample_size, design_info = ipw_informed_bam_performance(
#     pilot_se=(45,5), 
#     pilot_sp=(90,10),
#     strata_props=[0.1, 0.9],
#     sampling_props=[0.5, 0.5],
#     use_optimal_allocation=False
# )

print(f"BAM Performance Evaluation Sample Size: {perf_sample}")

# CI Width: 0.25
# BAM Model Development Sample Size: 50
# BAM Performance Evaluation Sample Size: 372

In [None]:
# Mean of a Beta(4, 28) distribution, i.e. 4 / (4 + 28).
stats.beta.mean(4, 28)

In [None]:


# Simulated pilot study: 50 Bernoulli outcomes with success probability 0.6.
pilot_data = np.random.binomial(n=1, p=0.6, size=50)

# Step 1 - sample size for model development.
# NOTE(review): ``bam_model_development`` is not defined in the visible part
# of this notebook - confirm where it comes from before running this cell.
dev_sample = bam_model_development(pilot_data, hdi_width=0.15)
print(f"BAM Model Development Sample Size: {dev_sample}")

# Step 2 - sample size for performance evaluation.
# (10, 2) are the Gamma hyperparameters handed to the Se and Sp hyperpriors;
# prevalence keeps the function's default Beta(1, 1) (uniform) prior.
perf_sample = bam_performance_evaluation(
    prior_se=(10, 2),
    prior_sp=(10, 2),
)
print(f"BAM Performance Evaluation Sample Size: {perf_sample}")


In [None]:
import numpy as np
import scipy.stats as stats
from scipy.optimize import minimize_scalar


# Example usage
if __name__ == "__main__":
    # Illustrative pilot counts: 80 TP, 20 FN (sens 0.8), 90 TN, 10 FP (spec 0.9)
    #
    # ``informed_bam_performance`` only exists as commented-out code in this
    # notebook, so calling it raised NameError.  The stratified variant
    # ``ipw_informed_bam_performance`` is called instead with its default
    # single stratum.  It returns ``(sample_size, design_info)``; only the
    # sample size is needed here.
    # NOTE(review): the two simulations handle the minimum-cases rule
    # differently, so results need not match the commented-out version.
    sample_size = ipw_informed_bam_performance(
        pilot_se=(80, 20),
        pilot_sp=(90, 10),
        prevalence_prior=(8, 32),  # Beta(8, 32) has mean 8/40 = 0.2
        hdi_width=0.2,
        simulations=2000,
    )[0]
    print(f"Optimized sample size: {sample_size}")


In [None]:

def bam_performance_evaluation(prior_se=(10, 2), prior_sp=(10, 2), prior_prev=(1, 1),
                              hdi_width=0.1, ci=0.95, target_assurance=0.8,
                              simulations=500, max_sample=10000):
    """
    BAM implementation for classification performance evaluation

    For each candidate sample size, every simulation draws the Beta prior
    parameters of sensitivity and specificity from Gamma hyperpriors, draws
    the true sensitivity, specificity and prevalence, simulates a study, and
    checks whether the HDIs of both posteriors are no wider than
    ``hdi_width``.  Candidate sizes are explored by binary search, which
    assumes the assurance does not decrease with sample size.

    Parameters:
    prior_se (tuple): (shape, rate) of the Gamma hyperprior from which both
        Beta parameters of the sensitivity prior are drawn
    prior_sp (tuple): (shape, rate) of the Gamma hyperprior from which both
        Beta parameters of the specificity prior are drawn
    prior_prev (tuple): Beta parameters for prevalence
    hdi_width (float): Desired HDI width for both Se and Sp
    ci (float): Credible interval level
    target_assurance (float): Required joint assurance probability
    simulations (int): Number of prior samples per evaluation
    max_sample (int): Maximum allowable sample size

    Returns:
    Optimal sample size (int); ``max_sample`` when no candidate size reaches
    the target assurance

    Raises:
    ValueError: if ``simulations`` is not positive
    """
    if simulations <= 0:
        raise ValueError("simulations must be a positive integer")

    # scipy's gamma takes (a, loc, scale) positionally, so ``gamma(*prior)``
    # would turn the second hyperparameter into a location shift.  Pass it
    # explicitly as a rate (scale = 1 / rate) instead.
    se_shape, se_rate = prior_se
    sp_shape, sp_rate = prior_sp
    se_hyper = stats.gamma(se_shape, scale=1 / se_rate)
    sp_hyper = stats.gamma(sp_shape, scale=1 / sp_rate)
    prev_prior = stats.beta(*prior_prev)

    # Binary search setup
    low, high = 10, max_sample
    best_n = max_sample

    while low <= high:
        mid = (low + high) // 2
        joint_assurance = 0

        for _ in range(simulations):
            # Sample hyperparameters
            a_se = se_hyper.rvs()
            b_se = se_hyper.rvs()
            a_sp = sp_hyper.rvs()
            b_sp = sp_hyper.rvs()

            # Sample true parameters
            se = stats.beta(a_se, b_se).rvs()
            sp = stats.beta(a_sp, b_sp).rvs()
            prev = prev_prior.rvs()

            # Generate synthetic data
            n_pos = np.random.binomial(mid, prev)
            n_neg = mid - n_pos

            tp = np.random.binomial(n_pos, se)
            fn = n_pos - tp
            tn = np.random.binomial(n_neg, sp)
            fp = n_neg - tn

            # Posterior HDIs: sampled prior parameters plus simulated counts
            se_lower, se_upper = beta_hdi(a_se + tp, b_se + fn, ci)
            sp_lower, sp_upper = beta_hdi(a_sp + tn, b_sp + fp, ci)

            if (se_upper - se_lower <= hdi_width) and (sp_upper - sp_lower <= hdi_width):
                joint_assurance += 1

        assurance_prob = joint_assurance / simulations

        if assurance_prob >= target_assurance:
            # Target met: remember this size and look for a smaller one.
            best_n = mid
            high = mid - 1
        else:
            low = mid + 1

    return best_n