In [1]:
import pandas as pd
import numpy as np
from dataclasses import dataclass
from typing import Dict, Optional, List, Tuple
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
import warnings
warnings.filterwarnings('ignore')

# Set style for better plots
plt.style.use('seaborn-v0_8')
sns.set_palette("husl")
plt.rcParams['figure.figsize'] = (10, 6)
plt.rcParams['font.size'] = 10

print("Libraries imported successfully!")

Libraries imported successfully!


In [41]:
@dataclass
class Parameter:
    """Simple parameter class with sampling capability"""
    mean: float
    std: float
    dist_type: str = 'normal'  # normal, gamma, beta

    def sample(self, n=1):
        """Sample from parameter distribution"""
        if self.std == 0:
            return np.full(n, self.mean)
        if self.dist_type == 'gamma':
            shape = (self.mean / self.std) ** 2
            scale = self.std ** 2 / self.mean
            return np.random.gamma(shape, scale, n)
        elif self.dist_type == 'beta':
            alpha = self.mean * (self.mean * (1 - self.mean) / self.std**2 - 1)
            beta = (1 - self.mean) * (self.mean * (1 - self.mean) / self.std**2 - 1)
            return np.random.beta(max(alpha, 0.1), max(beta, 0.1), n)
        else:  # normal
            return np.maximum(0, np.random.normal(self.mean, self.std, n))

print("Parameter class defined!")

Parameter class defined!


In [42]:
class GlaucomaParameters:
    """All model parameters in one place"""

    def __init__(self, scenario_name="Default",
                 mild_to_moderate_mean=0.15, mild_to_moderate_std=0.05,
                 moderate_to_severe_mean=0.12, moderate_to_severe_std=0.04,
                 severe_to_vi_mean=0.10, severe_to_vi_std=0.03,
                 true_positive_rate=0.90, tp_std=0.05,
                 true_negative_rate=0.85, tn_std=0.05,
                 false_positive_rate=0.15, fp_std=0.05,
                 false_negative_rate=0.10, fn_std=0.05,
                 sensitivity=0.775, sensitivity_std=0.066,
                 specificity=0.954, specificity_std=0.009,
                 # State-specific screening accuracy parameters
                 tp_mild=0.85, tp_mild_std=0.05,
                 tn_mild=0.90, tn_mild_std=0.05,
                 fp_mild=0.10, fp_mild_std=0.05,
                 fn_mild=0.15, fn_mild_std=0.05,
                 tp_moderate=0.90, tp_moderate_std=0.05,
                 tn_moderate=0.88, tn_moderate_std=0.05,
                 fp_moderate=0.12, fp_moderate_std=0.05,
                 fn_moderate=0.10, fn_moderate_std=0.05,
                 tp_severe=0.95, tp_severe_std=0.04,
                 tn_severe=0.85, tn_severe_std=0.05,
                 fp_severe=0.15, fp_severe_std=0.05,
                 fn_severe=0.05, fn_severe_std=0.04,
                 detection_proportion=1.0, detection_std=0.000001,
                 prevalence_general=0.05, prevalence_general_std=0.005,
                 prevalence_dr=0.07, prevalence_dr_std=0.01,
                 screening_cost=10, screening_cost_std=0.01,
                 # Separate screening costs
                 ai_screening_cost=30, ai_screening_cost_std=5,
                 human_screening_cost=75, human_screening_cost_std=15):

        self.scenario_name = scenario_name

        # COSTS (annual, in USD)
        self.costs = {
            'monitoring_mild': Parameter(352, 0.2*352, 'gamma'),
            'monitoring_moderate': Parameter(463, 0.2*463, 'gamma'),
            'monitoring_severe': Parameter(644, 0.2*644, 'gamma'),
            'monitoring_vi': Parameter(576, 0.2*576, 'gamma'),
            'treatment_mild': Parameter(303, 0.2*303, 'gamma'),
            'treatment_moderate': Parameter(429, 0.2*429, 'gamma'),
            'treatment_severe': Parameter(609, 0.2*609, 'gamma'),
            'treatment_vi': Parameter(662, 0.2*662, 'gamma'),
            'other_mild': Parameter(0, 0, 'gamma'),
            'other_moderate': Parameter(0, 0, 'gamma'),
            'other_severe': Parameter(0, 0, 'gamma'),
            'other_vi': Parameter(4186 + 1334, 0.2*(4186 + 1334), 'gamma'),
            'productivity_mild': Parameter(0, 0, 'gamma'),
            'productivity_moderate': Parameter(0, 0, 'gamma'),
            'productivity_severe': Parameter(0, 0, 'gamma'),
            'productivity_vi': Parameter(7630, 0.2*7630, 'gamma'),
            'screening': Parameter(screening_cost, screening_cost_std, 'gamma'),
            'ai_screening': Parameter(ai_screening_cost, ai_screening_cost_std, 'gamma'),
            'human_screening': Parameter(human_screening_cost, human_screening_cost_std, 'gamma'),
        }

        # UTILITIES (0-1 scale)
        self.utilities = {
            'utility_mild': Parameter(0.985, 0.023, 'beta'),
            'utility_moderate': Parameter(0.899, 0.039, 'beta'),
            'utility_severe': Parameter(0.773, 0.046, 'beta'),
            'utility_vi': Parameter(0.634, 0.052, 'beta'),
        }

        # TRANSITION PROBABILITIES
        self.transitions = {
            'mild_to_moderate': Parameter(mild_to_moderate_mean, mild_to_moderate_std, 'beta'),
            'moderate_to_severe': Parameter(moderate_to_severe_mean, moderate_to_severe_std, 'beta'),
            'severe_to_vi': Parameter(severe_to_vi_mean, severe_to_vi_std, 'beta'),
        }

        # SCREENING ACCURACY PARAMETERS (overall)
        self.screening_accuracy = {
            'true_positive_rate': Parameter(true_positive_rate, tp_std, 'beta'),
            'true_negative_rate': Parameter(true_negative_rate, tn_std, 'beta'),
            'false_positive_rate': Parameter(false_positive_rate, fp_std, 'beta'),
            'false_negative_rate': Parameter(false_negative_rate, fn_std, 'beta'),
            'sensitivity': Parameter(sensitivity, sensitivity_std, 'beta'),
            'specificity': Parameter(specificity, specificity_std, 'beta'),
        }

        # STATE-SPECIFIC SCREENING ACCURACY PARAMETERS
        self.screening_accuracy_mild = {
            'tp_mild': Parameter(tp_mild, tp_mild_std, 'beta'),
            'tn_mild': Parameter(tn_mild, tn_mild_std, 'beta'),
            'fp_mild': Parameter(fp_mild, fp_mild_std, 'beta'),
            'fn_mild': Parameter(fn_mild, fn_mild_std, 'beta'),
        }

        self.screening_accuracy_moderate = {
            'tp_moderate': Parameter(tp_moderate, tp_moderate_std, 'beta'),
            'tn_moderate': Parameter(tn_moderate, tn_moderate_std, 'beta'),
            'fp_moderate': Parameter(fp_moderate, fp_moderate_std, 'beta'),
            'fn_moderate': Parameter(fn_moderate, fn_moderate_std, 'beta'),
        }

        self.screening_accuracy_severe = {
            'tp_severe': Parameter(tp_severe, tp_severe_std, 'beta'),
            'tn_severe': Parameter(tn_severe, tn_severe_std, 'beta'),
            'fp_severe': Parameter(fp_severe, fp_severe_std, 'beta'),
            'fn_severe': Parameter(fn_severe, fn_severe_std, 'beta'),
        }

        # DETECTION AND PREVALENCE PARAMETERS
        self.screening_params = {
            'detection_proportion': Parameter(detection_proportion, detection_std, 'beta'),
            'prevalence_general': Parameter(prevalence_general, prevalence_general_std, 'beta'),
            'prevalence_dr': Parameter(prevalence_dr, prevalence_dr_std, 'beta'),
        }

        # DISCOUNT RATES
        self.discount_rates = {
            'cost_discount': Parameter(0.03, 0.01, 'beta'),
            'health_discount': Parameter(0.015, 0.005, 'beta'),
        }

    @classmethod
    def create_ai_pure_scenario(cls, **kwargs):
        """AI PURE SCENARIO - AI transition matrix + AI screening"""
        defaults = {
            'scenario_name': "AI Pure",
            'mild_to_moderate_mean': 0.058, 'mild_to_moderate_std': 0.000303,
            'moderate_to_severe_mean': 0.04, 'moderate_to_severe_std': 0.000253,
            'severe_to_vi_mean': 0.032, 'severe_to_vi_std': 0.00023,
            'true_positive_rate': 0.95, 'tp_std': 0.02,
            'true_negative_rate': 0.92, 'tn_std': 0.02,
            'false_positive_rate': 0.08, 'fp_std': 0.02,
            'false_negative_rate': 0.05, 'fn_std': 0.02,
            'sensitivity' : 0.775, 'sensitivity_std': 0.066,
            'specificity' : 0.954, 'specificity_std': 0.009,
            # AI-enhanced state-specific screening accuracy
            'tp_mild': 0.92, 'tp_mild_std': 0.03,
            'tn_mild': 0.94, 'tn_mild_std': 0.02,
            'fp_mild': 0.06, 'fp_mild_std': 0.02,
            'fn_mild': 0.08, 'fn_mild_std': 0.03,
            'tp_moderate': 0.95, 'tp_moderate_std': 0.02,
            'tn_moderate': 0.93, 'tn_moderate_std': 0.02,
            'fp_moderate': 0.07, 'fp_moderate_std': 0.02,
            'fn_moderate': 0.05, 'fn_moderate_std': 0.02,
            'tp_severe': 0.98, 'tp_severe_std': 0.01,
            'tn_severe': 0.91, 'tn_severe_std': 0.02,
            'fp_severe': 0.09, 'fp_severe_std': 0.02,
            'fn_severe': 0.02, 'fn_severe_std': 0.01,
            'detection_proportion': 0.90, 'detection_std': 0.05,
            # AI screening costs
            'ai_screening_cost': 11.5, 'ai_screening_cost_std': 3, ## AI screening costs include
            'human_screening_cost': 100, 'human_screening_cost_std': 12,
        }
        defaults.update(kwargs)
        return cls(**defaults)

    @classmethod
    def create_non_ai_pure_scenario(cls, **kwargs):
        """NON-AI PURE SCENARIO - Non-AI transition matrix + Non-AI screening"""
        defaults = {
            'scenario_name': "Non-AI Pure",
            'mild_to_moderate_mean': 0.143, 'mild_to_moderate_std': 0.0323,
            'moderate_to_severe_mean': 0.087, 'moderate_to_severe_std': 0.02603,
            'severe_to_vi_mean': 0.077, 'severe_to_vi_std': 0.02467,
            'true_positive_rate': 0.75, 'tp_std': 0.08,
            'true_negative_rate': 0.80, 'tn_std': 0.08,
            'false_positive_rate': 0.20, 'fp_std': 0.08,
            'false_negative_rate': 0.25, 'fn_std': 0.08,
            # Non-AI state-specific screening accuracy (lower performance)
            'tp_mild': 0.70, 'tp_mild_std': 0.08,
            'tn_mild': 0.85, 'tn_mild_std': 0.06,
            'fp_mild': 0.15, 'fp_mild_std': 0.06,
            'fn_mild': 0.30, 'fn_mild_std': 0.08,
            'tp_moderate': 0.78, 'tp_moderate_std': 0.07,
            'tn_moderate': 0.82, 'tn_moderate_std': 0.06,
            'fp_moderate': 0.18, 'fp_moderate_std': 0.06,
            'fn_moderate': 0.22, 'fn_moderate_std': 0.07,
            'tp_severe': 0.88, 'tp_severe_std': 0.05,
            'tn_severe': 0.78, 'tn_severe_std': 0.07,
            'fp_severe': 0.22, 'fp_severe_std': 0.07,
            'fn_severe': 0.12, 'fn_severe_std': 0.05,
            'detection_proportion': 0.70, 'detection_std': 0.10,
            # Human screening costs (higher due to specialist time)
            'ai_screening_cost': 0.01, 'ai_screening_cost_std': 0.0001,
            'human_screening_cost': 0.01, 'human_screening_cost_std': 0.0000001,
        }
        defaults.update(kwargs)
        instance = cls(**defaults)

        instance._set_non_ai_cost_structure()
        return instance

    def _set_non_ai_cost_structure(self):
        """Set Non-AI cost structure: ONLY VI patients incur costs"""
        zero_cost = Parameter(0, 0, 'gamma')

        # Zero out all costs except VI
        self.costs.update({
            # NO costs for undetected cases
            'monitoring_mild': zero_cost,
            'monitoring_moderate': zero_cost,
            'monitoring_severe': zero_cost,
            'treatment_mild': zero_cost,
            'treatment_moderate': zero_cost,
            'treatment_severe': zero_cost,
            'other_mild': zero_cost,
            'other_moderate': zero_cost,
            'other_severe': zero_cost,
            'productivity_mild': zero_cost,
            'productivity_moderate': zero_cost,
            'productivity_severe': zero_cost,

            # VI costs remain the same (clinically obvious)
            # 'monitoring_vi': unchanged
            # 'treatment_vi': unchanged
            # 'other_vi': unchanged
            # 'productivity_vi': unchanged
        })

        print(f"Applied Non-AI cost structure: Only VI patients incur costs")

    def sample_all(self):
        """Sample all parameters once"""
        sample = {}
        for category in [self.costs, self.utilities, self.transitions,
                        self.screening_accuracy, self.screening_accuracy_mild,
                        self.screening_accuracy_moderate, self.screening_accuracy_severe,
                        self.screening_params, self.discount_rates]:
            for name, param in category.items():
                sample[name] = param.sample(1)[0]
        return sample

    def get_summary(self):
        """Get parameter summary as DataFrame"""
        data = []
        for category_name, category in [('Costs', self.costs),
                                       ('Utilities', self.utilities),
                                       ('Transitions', self.transitions),
                                       ('Screening_Accuracy', self.screening_accuracy),
                                       ('Screening_Accuracy_Mild', self.screening_accuracy_mild),
                                       ('Screening_Accuracy_Moderate', self.screening_accuracy_moderate),
                                       ('Screening_Accuracy_Severe', self.screening_accuracy_severe),
                                       ('Screening_Params', self.screening_params),
                                       ('Discount_Rates', self.discount_rates)]:
            for name, param in category.items():
                data.append({
                    'Category': category_name,
                    'Parameter': name,
                    'Mean': param.mean,
                    'Std': param.std,
                    'Distribution': param.dist_type
                })
        return pd.DataFrame(data)

    def get_screening_cost(self, screening_type='combined'):
        """Get screening cost based on type"""
        if screening_type == 'ai_only':
            return self.costs['ai_screening']
        elif screening_type == 'human_only':
            return self.costs['human_screening']
        elif screening_type == 'combined':
            return self.costs['screening']
        else:
            raise ValueError(f"Unknown screening type: {screening_type}. Use 'ai_only', 'human_only', or 'combined'")

    def get_state_specific_accuracy(self, state):
        """Get screening accuracy parameters for a specific state"""
        if state == 'mild':
            return self.screening_accuracy_mild
        elif state == 'moderate':
            return self.screening_accuracy_moderate
        elif state == 'severe':
            return self.screening_accuracy_severe
        else:
            raise ValueError(f"Unknown state: {state}. Use 'mild', 'moderate', or 'severe'")

print("Enhanced GlaucomaParameters class defined with state-specific screening accuracy!")

Enhanced GlaucomaParameters class defined with state-specific screening accuracy!


In [51]:
class BaseGlaucomaModel:
    """Base class for Glaucoma Health Economic Models"""

    def __init__(self, params=None, starting_age=60, mortality_table=None):
        self.params = params or GlaucomaParameters()
        self.states = ['Mild', 'Moderate', 'Severe', 'VI', 'Dead']
        self.scenario_name = self.params.scenario_name
        self.starting_age = starting_age
        self.mortality_table = mortality_table or self._get_default_mortality_table()

    def _get_default_mortality_table(self):
        """
        Default mortality table with q(x) values
        This should be replaced with actual life table data
        Based on typical developed country life tables
        """
        return {
            40: 0.00143, 41: 0.00154, 42: 0.00166, 43: 0.00179, 44: 0.00194,
            45: 0.00210, 46: 0.00227, 47: 0.00246, 48: 0.00268, 49: 0.00292,
            50: 0.00319, 51: 0.00349, 52: 0.00382, 53: 0.00419, 54: 0.00460,
            55: 0.00505, 56: 0.00555, 57: 0.00610, 58: 0.00671, 59: 0.00738,
            60: 0.00812, 61: 0.00894, 62: 0.00983, 63: 0.01082, 64: 0.01190,
            65: 0.01309, 66: 0.01439, 67: 0.01581, 68: 0.01737, 69: 0.01907,
            70: 0.02094, 71: 0.02298, 72: 0.02522, 73: 0.02767, 74: 0.03035,
            75: 0.03329, 76: 0.03650, 77: 0.04002, 78: 0.04387, 79: 0.04808,
            80: 0.05270, 81: 0.05776, 82: 0.06331, 83: 0.06940, 84: 0.07607,
            85: 0.08339, 86: 0.09141, 87: 0.10020, 88: 0.10982, 89: 0.12035,
            90: 0.13187, 91: 0.14446, 92: 0.15821, 93: 0.17321, 94: 0.18956,
            95: 0.20736, 96: 0.22672, 97: 0.24775, 98: 0.27057, 99: 0.29531,
            100: 0.32210, 101: 0.35109, 102: 0.38243, 103: 0.41628, 104: 0.45281,
            105: 0.49220, 106: 0.53464, 107: 0.58032, 108: 0.62946, 109: 0.68227,
            110: 1.00000  # Assume death at 110
        }

    def load_mortality_table_from_file(self, filepath_male, filepath_female, male_proportion=0.5, age_col='Age', qx_col='qx'):
        """
        Load mortality table from CSV file
        
        Parameters:
        -----------
        filepath : str
            Path to CSV file containing mortality table
        age_col : str
            Name of the age column (default: 'Age')
        qx_col : str
            Name of the q(x) column (default: 'qx')
        
        Returns:
        --------
        dict : Dictionary mapping age to mortality probability
        """
        import pandas as pd
        df_male = pd.read_csv(filepath_male)
        df_female = pd.read_csv(filepath_female)
        df = pd.DataFrame()
        df[age_col] = df_male[age_col]
        df[qx_col] = (male_proportion * df_male [qx_col]) + ((1 - male_proportion) * df_female[qx_col])  
        mortality_dict = dict(zip(df[age_col], df[qx_col]))
        self.mortality_table = mortality_dict
        return mortality_dict

    def get_age_specific_mortality(self, age, sample):
        """
        Get age-specific mortality rate q(x) from mortality table
        
        Parameters:
        -----------
        age : float or int
            Current age
        sample : dict
            Dictionary of sampled parameters (for uncertainty around mortality)
        
        Returns:
        --------
        float : Probability of dying in the next year
        """
        age_int = int(round(age))
        
        # Get base mortality from table
        if age_int in self.mortality_table:
            base_qx = self.mortality_table[age_int]
        else:
            # If age not in table, use nearest available or extrapolate
            if age_int < min(self.mortality_table.keys()):
                base_qx = self.mortality_table[min(self.mortality_table.keys())]
            else:
                base_qx = self.mortality_table[max(self.mortality_table.keys())]
        
        # Optional: Apply uncertainty/variation to mortality rates in PSA
        # mortality_adjustment = sample.get('mortality_adjustment_factor', 1.0)
        # base_qx = base_qx * mortality_adjustment
        
        return np.clip(base_qx, 0, 1)

    def get_state_specific_mortality_multiplier(self, state, sample):
        """
        Get mortality multiplier for each health state (hazard ratio vs general population)
        These represent excess mortality risk due to glaucoma severity
        """
        multipliers = {
            'Mild': sample.get('mortality_multiplier_mild', 1.0), ## modify the mortality multipliers as needed
            'Moderate': sample.get('mortality_multiplier_moderate', 1.05),
            'Severe': sample.get('mortality_multiplier_severe', 1.10),
            'VI': sample.get('mortality_multiplier_vi', 1.20),
            'Dead': 0.0
        }
        return multipliers.get(state, 1.0)

    def get_transition_matrix(self, sample, age):
        """Build transition matrix from sampled parameters including age-specific mortality"""
        p1 = sample['mild_to_moderate']
        p2 = sample['moderate_to_severe']
        p3 = sample['severe_to_vi']
        p1, p2, p3 = np.clip([p1, p2, p3], 0, 1)
        
        # Get age-specific base mortality q(x) from mortality table
        base_qx = self.get_age_specific_mortality(age, sample)
        
        # Apply state-specific mortality multipliers (hazard ratios)
        mort_mild = base_qx * self.get_state_specific_mortality_multiplier('Mild', sample)
        mort_moderate = base_qx * self.get_state_specific_mortality_multiplier('Moderate', sample)
        mort_severe = base_qx * self.get_state_specific_mortality_multiplier('Severe', sample)
        mort_vi = base_qx * self.get_state_specific_mortality_multiplier('VI', sample)
        
        # Clip mortality rates to valid range [0, 1]
        mort_mild, mort_moderate, mort_severe, mort_vi = np.clip(
            [mort_mild, mort_moderate, mort_severe, mort_vi], 0, 1
        )
        
        # Ensure transition probabilities + mortality don't exceed 1
        p1 = np.clip(p1, 0, 1 - mort_mild)
        p2 = np.clip(p2, 0, 1 - mort_moderate)
        p3 = np.clip(p3, 0, 1 - mort_severe)
        
        # Build transition matrix with age-specific mortality
        # Each row sums to 1
        return np.array([
            # From Mild: stay, progress to Moderate, skip, skip, die
            [1 - p1 - mort_mild, p1, 0, 0, mort_mild],
            # From Moderate: skip, stay, progress to Severe, skip, die
            [0, 1 - p2 - mort_moderate, p2, 0, mort_moderate],
            # From Severe: skip, skip, stay, progress to VI, die
            [0, 0, 1 - p3 - mort_severe, p3, mort_severe],
            # From VI: skip, skip, skip, stay, die
            [0, 0, 0, 1 - mort_vi, mort_vi],
            # From Dead: stay dead
            [0, 0, 0, 0, 1]
        ])

    def simulate_cohort(self, initial_dist, years, sample):
        """Simulate cohort over time with age-dependent mortality"""
        n_states = len(self.states)
        cohort = np.zeros((years + 1, n_states))
        cohort[0] = initial_dist
        
        # Store age-specific transition matrices and ages for tracing
        transition_matrices = []
        ages = []
        
        for year in range(years):
            current_age = self.starting_age + year
            ages.append(current_age)
            trans_matrix = self.get_transition_matrix(sample, current_age)
            transition_matrices.append(trans_matrix)
            cohort[year + 1] = cohort[year] @ trans_matrix
        
        return cohort, transition_matrices, ages

    def create_detailed_traces(self, cohort, costs, qalys, costs_disc, qalys_disc,
                              state_costs, state_utilities, sample, screening_costs,
                              transition_matrices=None, ages=None):
        """Create detailed year-by-year traces"""
        years = len(cohort)
        trace_data = []

        for year in range(years):
            current_age = ages[year] if ages and year < len(ages) else self.starting_age + year
            cost_discount_factor = 1 / (1 + sample['cost_discount']) ** year
            health_discount_factor = 1 / (1 + sample['health_discount']) ** year
            state_costs_year = cohort[year][:4] * state_costs  # Only living states have costs
            state_utilities_year = cohort[year][:4] * state_utilities  # Only living states have utilities

            row = {
                'Year': year,
                'Age': current_age,
                'Prop_Mild': cohort[year][0],
                'Prop_Moderate': cohort[year][1],
                'Prop_Severe': cohort[year][2],
                'Prop_VI': cohort[year][3],
                'Prop_Dead': cohort[year][4],
                'Prop_Alive': 1 - cohort[year][4],
                'Total_Cost': costs[year],
                'Total_QALY': qalys[year],
                'Total_Cost_Disc': costs_disc[year],
                'Total_QALY_Disc': qalys_disc[year],
                'Screening_Cost': screening_costs[year],
                'Cost_Mild': state_costs_year[0],
                'Cost_Moderate': state_costs_year[1],
                'Cost_Severe': state_costs_year[2],
                'Cost_VI': state_costs_year[3],
                'QALY_Mild': state_utilities_year[0],
                'QALY_Moderate': state_utilities_year[1],
                'QALY_Severe': state_utilities_year[2],
                'QALY_VI': state_utilities_year[3],
                'Cost_Discount_Factor': cost_discount_factor,
                'Health_Discount_Factor': health_discount_factor,
            }
            
            # Add age-specific mortality rates if transition matrices available
            if transition_matrices and year < len(transition_matrices):
                row['Mortality_Rate_Mild'] = transition_matrices[year][0, 4]
                row['Mortality_Rate_Moderate'] = transition_matrices[year][1, 4]
                row['Mortality_Rate_Severe'] = transition_matrices[year][2, 4]
                row['Mortality_Rate_VI'] = transition_matrices[year][3, 4]
                
                # Also store base mortality from table
                base_mort = self.get_age_specific_mortality(current_age, sample)
                row['Base_Mortality_qx'] = base_mort
            
            trace_data.append(row)

        return pd.DataFrame(trace_data)

    # Abstract methods to be implemented by subclasses
    def calculate_outcomes(self, cohort, sample, include_screening=True, population_type='general', 
                          transition_matrices=None, ages=None):
        raise NotImplementedError("Subclasses must implement calculate_outcomes")

    def run_deterministic(self, initial_dist=None, years=10, include_screening=True, 
                         population_type='general', starting_age=None):
        raise NotImplementedError("Subclasses must implement run_deterministic")

    def run_probabilistic(self, n_iterations=1000, initial_dist=None, years=10,
                         include_screening=False, population_type='general', random_seed=42, 
                         return_traces=False, starting_age=None):
        raise NotImplementedError("Subclasses must implement run_probabilistic")


class AIGlaucomaModel(BaseGlaucomaModel):
    """AI-Enhanced Glaucoma Model with advanced screening and early detection"""

    def __init__(self, params=None, starting_age=60, mortality_table=None):
        if params is None:
            params = GlaucomaParameters.create_ai_pure_scenario()
        super().__init__(params, starting_age, mortality_table)
        self.model_type = "AI_Enhanced"

    def calculate_outcomes(self, cohort, sample, include_screening=True, population_type='general',
                          transition_matrices=None, ages=None):
        """AI model: All detected cases incur monitoring and treatment costs"""
        years = len(cohort)
        costs = np.zeros(years)
        qalys = np.zeros(years)
        costs_discounted = np.zeros(years)
        qalys_discounted = np.zeros(years)
        screening_costs = np.zeros(years)

        cost_discount_rate = sample['cost_discount']
        health_discount_rate = sample['health_discount']

        # AI Model: ALL living states incur costs when detected (comprehensive care)
        state_costs = [
            sample['monitoring_mild'] + sample['treatment_mild'] + sample['other_mild'] + sample['productivity_mild'],
            sample['monitoring_moderate'] + sample['treatment_moderate'] + sample['other_moderate'] + sample['productivity_moderate'],
            sample['monitoring_severe'] + sample['treatment_severe'] + sample['other_severe'] + sample['productivity_severe'],
            sample['monitoring_vi'] + sample['treatment_vi'] + sample['other_vi'] + sample['productivity_vi']
        ]

        state_utilities = [
            sample['utility_mild'],
            sample['utility_moderate'],
            sample['utility_severe'],
            sample['utility_vi']
        ]

        if include_screening:
            annual_screening_cost = sample['ai_screening']
            detection_multiplier = sample['detection_proportion']
        else:
            annual_screening_cost = 0
            detection_multiplier = 1.0

        for year in range(years):
            # Only apply costs to living states (first 4 states)
            costs[year] = np.sum(cohort[year][:4] * state_costs) * detection_multiplier
            # Only living patients accumulate QALYs
            qalys[year] = np.sum(cohort[year][:4] * state_utilities)

            if include_screening and year == 0:
                screening_costs[year] = annual_screening_cost
                costs[year] += screening_costs[year]

            cost_discount_factor = 1 / (1 + cost_discount_rate) ** year
            health_discount_factor = 1 / (1 + health_discount_rate) ** year
            costs_discounted[year] = costs[year] * cost_discount_factor
            qalys_discounted[year] = qalys[year] * health_discount_factor

        return costs, qalys, costs_discounted, qalys_discounted, state_costs, state_utilities, screening_costs

    def run_deterministic(self, initial_dist=None, years=10, include_screening=True, 
                         population_type='general', starting_age=None):
        """Run AI model with mean parameter values"""
        if starting_age is not None:
            self.starting_age = starting_age
            
        if initial_dist is None:
            initial_dist = [1, 0, 0, 0, 0]  # All start in Mild, none dead

        sample = {}
        for category in [self.params.costs, self.params.utilities, self.params.transitions,
                        self.params.screening_accuracy, self.params.screening_accuracy_mild,
                        self.params.screening_accuracy_moderate, self.params.screening_accuracy_severe,
                        self.params.screening_params, self.params.discount_rates]:
            for name, param in category.items():
                sample[name] = param.mean

        # Add mortality multiplier parameters if they exist
        if hasattr(self.params, 'mortality_multipliers'):
            for name, param in self.params.mortality_multipliers.items():
                sample[name] = param.mean

        cohort, transition_matrices, ages = self.simulate_cohort(initial_dist, years, sample)
        costs, qalys, costs_disc, qalys_disc, state_costs, state_utilities, screening_costs = self.calculate_outcomes(
            cohort, sample, include_screening, population_type, transition_matrices, ages)

        traces = self.create_detailed_traces(cohort, costs, qalys, costs_disc, qalys_disc,
                                           state_costs, state_utilities, sample, screening_costs,
                                           transition_matrices, ages)

        return {
            'cohort': cohort,
            'costs': costs,
            'qalys': qalys,
            'costs_discounted': costs_disc,
            'qalys_discounted': qalys_disc,
            'total_cost': np.sum(costs),
            'total_qalys': np.sum(qalys),
            'total_cost_discounted': np.sum(costs_disc),
            'total_qalys_discounted': np.sum(qalys_disc),
            'traces': traces,
            'state_costs': state_costs,
            'state_utilities': state_utilities,
            'screening_costs': screening_costs,
            'sample_params': sample,
            'scenario_name': self.scenario_name,
            'model_type': self.model_type,
            'starting_age': self.starting_age,
            'transition_matrices': transition_matrices,
            'ages': ages
        }

    def run_probabilistic(self, n_iterations=1000, initial_dist=None, years=10,
                         include_screening=True, population_type='general', random_seed=42, 
                         return_traces=False, starting_age=None):
        """Run AI model probabilistic sensitivity analysis"""
        if starting_age is not None:
            self.starting_age = starting_age
            
        if initial_dist is None:
            initial_dist = [1, 0, 0, 0, 0]  # All start in Mild, none dead

        np.random.seed(random_seed)

        results = {
            'total_costs': [],
            'total_qalys': [],
            'total_costs_discounted': [],
            'total_qalys_discounted': [],
            'iterations': [],
            'parameters': [],
            'scenario_name': self.scenario_name,
            'model_type': self.model_type,
            'starting_age': self.starting_age
        }

        trace_tensor = None
        trace_variable_names = None

        if return_traces:
            # Define trace variables we want to store (now including Dead state, Age, and mortality rates)
            trace_vars = [
                'Year', 'Age', 'Prop_Mild', 'Prop_Moderate', 'Prop_Severe', 'Prop_VI', 'Prop_Dead', 'Prop_Alive',
                'Total_Cost', 'Total_QALY', 'Total_Cost_Disc', 'Total_QALY_Disc',
                'Screening_Cost', 'Cost_Mild', 'Cost_Moderate', 'Cost_Severe', 'Cost_VI',
                'QALY_Mild', 'QALY_Moderate', 'QALY_Severe', 'QALY_VI',
                'Cost_Discount_Factor', 'Health_Discount_Factor',
                'Mortality_Rate_Mild', 'Mortality_Rate_Moderate', 'Mortality_Rate_Severe', 'Mortality_Rate_VI',
                'Base_Mortality_qx'
            ]

            # Initialize 3D tensor: [iterations, years, variables]
            trace_tensor = np.zeros((n_iterations, years + 1, len(trace_vars)))
            trace_variable_names = trace_vars
            results['trace_variable_names'] = trace_variable_names

        print(f"Running {n_iterations} PSA iterations for AI Enhanced Model (Starting Age: {self.starting_age})...")

        for i in range(n_iterations):
            if (i + 1) % 100 == 0:
                print(f"  AI Model Iteration {i + 1}/{n_iterations}")

            sample = self.params.sample_all()
            cohort, transition_matrices, ages = self.simulate_cohort(initial_dist, years, sample)
            costs, qalys, costs_disc, qalys_disc, state_costs, state_utilities, screening_costs = self.calculate_outcomes(
                cohort, sample, include_screening, population_type, transition_matrices, ages)

            results['total_costs'].append(np.sum(costs))
            results['total_qalys'].append(np.sum(qalys))
            results['total_costs_discounted'].append(np.sum(costs_disc))
            results['total_qalys_discounted'].append(np.sum(qalys_disc))
            results['iterations'].append(i)
            results['parameters'].append(sample)

            # Store detailed traces if requested
            if return_traces:
                traces_df = self.create_detailed_traces(cohort, costs, qalys, costs_disc, qalys_disc,
                                                       state_costs, state_utilities, sample, screening_costs,
                                                       transition_matrices, ages)

                # Extract values for each trace variable and store in tensor
                for year_idx in range(years + 1):
                    year_data = traces_df.iloc[year_idx]
                    for var_idx, var_name in enumerate(trace_variable_names):
                        trace_tensor[i, year_idx, var_idx] = year_data[var_name]

        for key in ['total_costs', 'total_qalys', 'total_costs_discounted', 'total_qalys_discounted']:
            results[key] = np.array(results[key])

        results['trace_tensor'] = trace_tensor
        results['trace_variable_names'] = trace_variable_names

        return results


class NonAIGlaucomaModel(BaseGlaucomaModel):
    """Traditional/Non-AI Glaucoma Model with conventional screening and late detection"""

    def __init__(self, params=None, starting_age=60, mortality_table=None):
        if params is None:
            params = GlaucomaParameters.create_non_ai_pure_scenario()
        super().__init__(params, starting_age, mortality_table)
        self.model_type = "Traditional_NonAI"

    def calculate_outcomes(self, cohort, sample, include_screening=False, population_type='general',
                          transition_matrices=None, ages=None):
        """Non-AI model: Only VI patients incur costs (late detection model)"""
        years = len(cohort)
        costs = np.zeros(years)
        qalys = np.zeros(years)
        costs_discounted = np.zeros(years)
        qalys_discounted = np.zeros(years)
        screening_costs = np.zeros(years)

        cost_discount_rate = sample['cost_discount']
        health_discount_rate = sample['health_discount']

        # Non-AI Model: ONLY VI patients incur costs (early stages undetected)
        state_costs = [
            sample['monitoring_mild'] + sample['treatment_mild'] + sample['other_mild'] + sample['productivity_mild'],
            sample['monitoring_moderate'] + sample['treatment_moderate'] + sample['other_moderate'] + sample['productivity_moderate'],
            sample['monitoring_severe'] + sample['treatment_severe'] + sample['other_severe'] + sample['productivity_severe'],
            sample['monitoring_vi'] + sample['treatment_vi'] + sample['other_vi'] + sample['productivity_vi']
        ]

        state_utilities = [
            sample['utility_mild'],
            sample['utility_moderate'],
            sample['utility_severe'],
            sample['utility_vi']
        ]

        if include_screening:
            annual_screening_cost = sample['human_screening']
            detection_multiplier = 1.0
        else:
            annual_screening_cost = 0
            detection_multiplier = 1.0

        for year in range(years):
            # Only apply costs to living states (first 4 states)
            costs[year] = np.sum(cohort[year][:4] * state_costs)
            # Only living patients accumulate QALYs
            qalys[year] = np.sum(cohort[year][:4] * state_utilities)

            if include_screening and year == 0:
                screening_costs[year] = annual_screening_cost
                costs[year] += screening_costs[year]

            cost_discount_factor = 1 / (1 + cost_discount_rate) ** year
            health_discount_factor = 1 / (1 + health_discount_rate) ** year
            costs_discounted[year] = costs[year] * cost_discount_factor
            qalys_discounted[year] = qalys[year] * health_discount_factor

        return costs, qalys, costs_discounted, qalys_discounted, state_costs, state_utilities, screening_costs

    def run_deterministic(self, initial_dist=None, years=10, include_screening=False, 
                         population_type='general', starting_age=None):
        """Run Non-AI model with mean parameter values"""
        if starting_age is not None:
            self.starting_age = starting_age
            
        if initial_dist is None:
            initial_dist = [1, 0, 0, 0, 0]  # All start in Mild, none dead

        sample = {}
        for category in [self.params.costs, self.params.utilities, self.params.transitions,
                        self.params.screening_accuracy, self.params.screening_accuracy_mild,
                        self.params.screening_accuracy_moderate, self.params.screening_accuracy_severe,
                        self.params.screening_params, self.params.discount_rates]:
            for name, param in category.items():
                sample[name] = param.mean

        # Add mortality multiplier parameters if they exist
        if hasattr(self.params, 'mortality_multipliers'):
            for name, param in self.params.mortality_multipliers.items():
                sample[name] = param.mean

        cohort, transition_matrices, ages = self.simulate_cohort(initial_dist, years, sample)
        costs, qalys, costs_disc, qalys_disc, state_costs, state_utilities, screening_costs = self.calculate_outcomes(
            cohort, sample, include_screening, population_type, transition_matrices, ages)

        traces = self.create_detailed_traces(cohort, costs, qalys, costs_disc, qalys_disc,
                                           state_costs, state_utilities, sample, screening_costs,
                                           transition_matrices, ages)

        return {
            'cohort': cohort,
            'costs': costs,
            'qalys': qalys,
            'costs_discounted': costs_disc,
            'qalys_discounted': qalys_disc,
            'total_cost': np.sum(costs),
            'total_qalys': np.sum(qalys),
            'total_cost_discounted': np.sum(costs_disc),
            'total_qalys_discounted': np.sum(qalys_disc),
            'traces': traces,
            'state_costs': state_costs,
            'state_utilities': state_utilities,
            'screening_costs': screening_costs,
            'sample_params': sample,
            'scenario_name': self.scenario_name,
            'model_type': self.model_type,
            'starting_age': self.starting_age,
            'transition_matrices': transition_matrices,
            'ages': ages
        }

    def run_probabilistic(self, n_iterations=1000, initial_dist=None, years=10,
                         include_screening=False, population_type='general', random_seed=42, 
                         return_traces=False, starting_age=None):
        """Run Non-AI model probabilistic sensitivity analysis"""
        if starting_age is not None:
            self.starting_age = starting_age
            
        if initial_dist is None:
            initial_dist = [1, 0, 0, 0, 0]  # All start in Mild, none dead

        np.random.seed(random_seed)

        results = {
            'total_costs': [],
            'total_qalys': [],
            'total_costs_discounted': [],
            'total_qalys_discounted': [],
            'iterations': [],
            'parameters': [],
            'scenario_name': self.scenario_name,
            'model_type': self.model_type,
            'starting_age': self.starting_age
        }

        trace_tensor = None
        trace_variable_names = None

        if return_traces:
            # Define trace variables we want to store (now including Dead state, Age, and mortality rates)
            trace_vars = [
                'Year', 'Age', 'Prop_Mild', 'Prop_Moderate', 'Prop_Severe', 'Prop_VI', 'Prop_Dead', 'Prop_Alive',
                'Total_Cost', 'Total_QALY', 'Total_Cost_Disc', 'Total_QALY_Disc',
                'Screening_Cost', 'Cost_Mild', 'Cost_Moderate', 'Cost_Severe', 'Cost_VI',
                'QALY_Mild', 'QALY_Moderate', 'QALY_Severe', 'QALY_VI',
                'Cost_Discount_Factor', 'Health_Discount_Factor',
                'Mortality_Rate_Mild', 'Mortality_Rate_Moderate', 'Mortality_Rate_Severe', 'Mortality_Rate_VI',
                'Base_Mortality_qx'
            ]

            # Initialize 3D tensor: [iterations, years, variables]
            trace_tensor = np.zeros((n_iterations, years + 1, len(trace_vars)))
            trace_variable_names = trace_vars
            results['trace_variable_names'] = trace_variable_names

        print(f"Running {n_iterations} PSA iterations for Traditional Non-AI Model (Starting Age: {self.starting_age})...")

        for i in range(n_iterations):
            if (i + 1) % 100 == 0:
                print(f"  Non-AI Model Iteration {i + 1}/{n_iterations}")

            sample = self.params.sample_all()
            cohort, transition_matrices, ages = self.simulate_cohort(initial_dist, years, sample)
            costs, qalys, costs_disc, qalys_disc, state_costs, state_utilities, screening_costs = self.calculate_outcomes(
                cohort, sample, include_screening, population_type, transition_matrices, ages)

            results['total_costs'].append(np.sum(costs))
            results['total_qalys'].append(np.sum(qalys))
            results['total_costs_discounted'].append(np.sum(costs_disc))
            results['total_qalys_discounted'].append(np.sum(qalys_disc))
            results['iterations'].append(i)
            results['parameters'].append(sample)

            # Store detailed traces if requested
            if return_traces:
                traces_df = self.create_detailed_traces(cohort, costs, qalys, costs_disc, qalys_disc,
                                                       state_costs, state_utilities, sample, screening_costs,
                                                       transition_matrices, ages)

                # Extract values for each trace variable and store in tensor
                for year_idx in range(years + 1):
                    year_data = traces_df.iloc[year_idx]
                    for var_idx, var_name in enumerate(trace_variable_names):
                        trace_tensor[i, year_idx, var_idx] = year_data[var_name]

        for key in ['total_costs', 'total_qalys', 'total_costs_discounted', 'total_qalys_discounted']:
            results[key] = np.array(results[key])

        results['trace_tensor'] = trace_tensor
        results['trace_variable_names'] = trace_variable_names

        return results


# Comparison and utility functions (same as before)
def compare_ai_vs_nonai_models(results_ai, results_non_ai, discounted=True):
    """Compare AI vs Non-AI model results and calculate incremental metrics"""
    if discounted:
        costs_ai = results_ai['total_costs_discounted']
        qalys_ai = results_ai['total_qalys_discounted']
        costs_non_ai = results_non_ai['total_costs_discounted']
        qalys_non_ai = results_non_ai['total_qalys_discounted']
    else:
        costs_ai = results_ai['total_costs']
        qalys_ai = results_ai['total_qalys']
        costs_non_ai = results_non_ai['total_costs']
        qalys_non_ai = results_non_ai['total_qalys']

    incremental_costs = costs_ai - costs_non_ai
    incremental_qalys = qalys_ai - qalys_non_ai

    icer_values = np.where(incremental_qalys != 0,
                          incremental_costs / incremental_qalys,
                          np.inf)

    comparison = {
        'incremental_costs_mean': np.mean(incremental_costs),
        'incremental_costs_std': np.std(incremental_costs),
        'incremental_qalys_mean': np.mean(incremental_qalys),
        'incremental_qalys_std': np.std(incremental_qalys),
        'icer_mean': np.mean(icer_values[np.isfinite(icer_values)]),
        'icer_std': np.std(icer_values[np.isfinite(icer_values)]),
        'costs_ai_mean': np.mean(costs_ai),
        'qalys_ai_mean': np.mean(qalys_ai),
        'costs_non_ai_mean': np.mean(costs_non_ai),
        'qalys_non_ai_mean': np.mean(qalys_non_ai),
        'ai_model_type': results_ai.get('model_type', 'AI'),
        'non_ai_model_type': results_non_ai.get('model_type', 'Non-AI'),
        'discounted': discounted
    }

    return comparison


def run_full_ai_vs_nonai_analysis(years=10, n_iterations=1000, return_traces=False, 
                                   starting_age=60, mortality_table=None):
    """Run complete analysis comparing separate AI and Non-AI models"""

    print("=== Running Full AI vs Non-AI Model Comparison ===")

    ai_model = AIGlaucomaModel(starting_age=starting_age, mortality_table=mortality_table)
    nonai_model = NonAIGlaucomaModel(starting_age=starting_age, mortality_table=mortality_table)

    print(f"AI Model: {ai_model.model_type} (Starting Age: {starting_age})")
    print(f"Non-AI Model: {nonai_model.model_type} (Starting Age: {starting_age})")

    print("\n1. Running deterministic analyses...")
    det_ai = ai_model.run_deterministic(years=years)
    det_nonai = nonai_model.run_deterministic(years=years)

    print("\n2. Running probabilistic analyses...")
    psa_ai = ai_model.run_probabilistic(n_iterations=n_iterations, years=years, return_traces=return_traces)
    psa_nonai = nonai_model.run_probabilistic(n_iterations=n_iterations, years=years, return_traces=return_traces)

    print("\n3. Comparing results...")
    comparison = compare_ai_vs_nonai_models(psa_ai, psa_nonai, discounted=True)

    return {
        'ai_model': ai_model,
        'nonai_model': nonai_model,
        'deterministic_ai': det_ai,
        'deterministic_nonai': det_nonai,
        'probabilistic_ai': psa_ai,
        'probabilistic_nonai': psa_nonai,
        'comparison': comparison
    }


def quick_model_comparison(starting_age=60, mortality_table=None):
    """Quick 2-line comparison of AI vs Non-AI models"""
    ai_model = AIGlaucomaModel(starting_age=starting_age, mortality_table=mortality_table)
    nonai_model = NonAIGlaucomaModel(starting_age=starting_age, mortality_table=mortality_table)

    ai_results = ai_model.run_deterministic()
    nonai_results = nonai_model.run_deterministic()

    print(f"Starting Age: {starting_age}")
    print(f"AI Model Total Cost (Discounted): ${ai_results['total_cost_discounted']:,.0f}")
    print(f"AI Model Total QALYs (Discounted): {ai_results['total_qalys_discounted']:.2f}")
    print(f"Non-AI Model Total Cost (Discounted): ${nonai_results['total_cost_discounted']:,.0f}")
    print(f"Non-AI Model Total QALYs (Discounted): {nonai_results['total_qalys_discounted']:.2f}")

    return ai_results, nonai_results


# Utility functions for working with trace tensors (same as before)
def get_trace_summary_stats(trace_tensor, trace_variable_names, variable_name, year=None):
    """Get summary statistics for a specific variable across all simulations"""
    var_idx = trace_variable_names.index(variable_name)

    if year is not None:
        data = trace_tensor[:, year, var_idx]
        return {
            'mean': np.mean(data),
            'std': np.std(data),
            'median': np.median(data),
            'q25': np.percentile(data, 25),
            'q75': np.percentile(data, 75),
            'min': np.min(data),
            'max': np.max(data)
        }
    else:
        data = trace_tensor[:, :, var_idx]
        return {
            'mean': np.mean(data, axis=0),
            'std': np.std(data, axis=0),
            'median': np.median(data, axis=0),
            'q25': np.percentile(data, 25, axis=0),
            'q75': np.percentile(data, 75, axis=0),
            'min': np.min(data, axis=0),
            'max': np.max(data, axis=0)
        }


def get_trace_percentiles(trace_tensor, trace_variable_names, variable_name, percentiles=[5, 25, 50, 75, 95]):
    """Get specified percentiles for a variable across all years and simulations"""
    var_idx = trace_variable_names.index(variable_name)
    data = trace_tensor[:, :, var_idx]

    result = {}
    for p in percentiles:
        result[f'p{p}'] = np.percentile(data, p, axis=0)

    return result


def extract_trace_variable(trace_tensor, trace_variable_names, variable_name):
    """Extract a specific variable from the trace tensor"""
    var_idx = trace_variable_names.index(variable_name)
    return trace_tensor[:, :, var_idx]


def compare_trace_variables(trace_tensor_ai, trace_tensor_nonai, trace_variable_names, variable_name):
    """Compare a specific variable between AI and Non-AI models"""
    ai_data = extract_trace_variable(trace_tensor_ai, trace_variable_names, variable_name)
    nonai_data = extract_trace_variable(trace_tensor_nonai, trace_variable_names, variable_name)

    incremental = ai_data - nonai_data

    return {
        'ai_mean': np.mean(ai_data, axis=0),
        'nonai_mean': np.mean(nonai_data, axis=0),
        'incremental_mean': np.mean(incremental, axis=0),
        'incremental_std': np.std(incremental, axis=0),
        'incremental_median': np.median(incremental, axis=0),
        'incremental_q25': np.percentile(incremental, 25, axis=0),
        'incremental_q75': np.percentile(incremental, 75, axis=0)
    }


def create_trace_dataframe(trace_tensor, trace_variable_names, iteration=0):
    """Convert a single iteration's trace tensor to a DataFrame for easy viewing"""
    data = trace_tensor[iteration, :, :]
    return pd.DataFrame(data, columns=trace_variable_names)


print("Enhanced models with age-dependent mortality from life tables defined!")

Enhanced models with age-dependent mortality from life tables defined!


In [52]:
model_ai = AIGlaucomaModel(starting_age=55)
model_ai.load_mortality_table_from_file(filepath_male='data/male_mortality_2023.csv', 
                                        filepath_female='data/female_mortality_2023.csv', 
                                        male_proportion=0.5)

results_ai = model_ai.run_deterministic(years=10)

In [53]:
results_ai['traces'].head()

Unnamed: 0,Year,Age,Prop_Mild,Prop_Moderate,Prop_Severe,Prop_VI,Prop_Dead,Prop_Alive,Total_Cost,Total_QALY,...,QALY_Moderate,QALY_Severe,QALY_VI,Cost_Discount_Factor,Health_Discount_Factor,Mortality_Rate_Mild,Mortality_Rate_Moderate,Mortality_Rate_Severe,Mortality_Rate_VI,Base_Mortality_qx
0,0,55,1.0,0.0,0.0,0.0,0.0,1.0,601.0,0.985,...,0.0,0.0,0.0,1.0,1.0,0.00437,0.004589,0.004807,0.005244,0.00437
1,1,56,0.93763,0.058,0.0,0.0,0.00437,0.99563,599.295285,0.975708,...,0.052142,0.0,0.0,0.970874,0.985222,0.00483,0.005072,0.005313,0.005796,0.00483
2,2,57,0.878719,0.109768,0.00232,0.0,0.009193,0.990807,608.743008,0.966013,...,0.098682,0.001793,0.0,0.942596,0.970662,0.00541,0.00568,0.005951,0.006492,0.00541
3,3,58,0.822999,0.15572,0.006623,7.4e-05,0.014584,0.985416,618.599614,0.955813,...,0.139992,0.005119,4.7e-05,0.915142,0.956317,0.005755,0.006043,0.006331,0.006906,0.005755
4,4,59,0.770529,0.196284,0.012598,0.000286,0.020304,0.979696,629.708867,0.945349,...,0.176459,0.009738,0.000181,0.888487,0.942184,0.005475,0.005749,0.006023,0.00657,0.005475


In [17]:
model_ai.params.get_summary()

Unnamed: 0,Category,Parameter,Mean,Std,Distribution
0,Costs,monitoring_mild,352.0,70.4,gamma
1,Costs,monitoring_moderate,463.0,92.6,gamma
2,Costs,monitoring_severe,644.0,128.8,gamma
3,Costs,monitoring_vi,576.0,115.2,gamma
4,Costs,treatment_mild,303.0,60.6,gamma
5,Costs,treatment_moderate,429.0,85.8,gamma
6,Costs,treatment_severe,609.0,121.8,gamma
7,Costs,treatment_vi,662.0,132.4,gamma
8,Costs,other_mild,0.0,0.0,gamma
9,Costs,other_moderate,0.0,0.0,gamma


In [None]:
results_ai['traces']



Unnamed: 0,Year,Age,Prop_Mild,Prop_Moderate,Prop_Severe,Prop_VI,Prop_Dead,Prop_Alive,Total_Cost,Total_QALY,...,QALY_Moderate,QALY_Severe,QALY_VI,Cost_Discount_Factor,Health_Discount_Factor,Mortality_Rate_Mild,Mortality_Rate_Moderate,Mortality_Rate_Severe,Mortality_Rate_VI,Base_Mortality_qx
0,0,55,1.0,0.0,0.0,0.0,0.0,1.0,601.0,0.985,...,0.0,0.0,0.0,1.0,1.0,0.00505,0.005302,0.005555,0.00606,0.00505
1,1,56,0.93695,0.058,0.0,0.0,0.00505,0.99495,598.894425,0.975038,...,0.052142,0.0,0.0,0.970874,0.985222,0.00555,0.005828,0.006105,0.00666,0.00555
2,2,57,0.877407,0.109685,0.00232,0.0,0.010588,0.989412,607.902791,0.964646,...,0.098607,0.001793,0.0,0.942596,0.970662,0.0061,0.006405,0.00671,0.00732,0.0061
3,3,58,0.821165,0.155485,0.006618,7.4e-05,0.016658,0.983342,617.323978,0.953791,...,0.139781,0.005115,4.7e-05,0.915142,0.956317,0.00671,0.007045,0.007381,0.008052,0.00671
4,4,59,0.768027,0.195797,0.012576,0.000285,0.023313,0.976687,627.816557,0.942431,...,0.176022,0.009722,0.000181,0.888487,0.942184,0.00738,0.007749,0.008118,0.008856,0.00738
5,5,60,0.717814,0.230994,0.019904,0.000685,0.030603,0.969397,639.913,0.93053,...,0.207664,0.015386,0.000434,0.862609,0.92826,0.00812,0.008526,0.008932,0.009744,0.00812
6,6,61,0.670352,0.261418,0.028329,0.001316,0.038586,0.961414,654.020685,0.918044,...,0.235015,0.021898,0.000834,0.837484,0.914542,0.00894,0.009387,0.009834,0.010728,0.00894
7,7,62,0.625479,0.287388,0.0376,0.002208,0.047325,0.952675,670.427934,0.904923,...,0.258362,0.029065,0.0014,0.813092,0.901027,0.00983,0.010322,0.010813,0.011796,0.00983
8,8,63,0.583052,0.309204,0.047486,0.003385,0.056873,0.943127,689.323059,0.891134,...,0.277974,0.036707,0.002146,0.789409,0.887711,0.01082,0.011361,0.011902,0.012984,0.01082
9,9,64,0.542927,0.32714,0.05777,0.004861,0.067303,0.932697,710.77251,0.876619,...,0.294099,0.044656,0.003082,0.766417,0.874592,0.0119,0.012495,0.01309,0.01428,0.0119


In [18]:
model_non_ai = NonAIGlaucomaModel(starting_age=55)
results_non_ai = model_non_ai.run_deterministic(years=10)

Applied Non-AI cost structure: Only VI patients incur costs


In [44]:
df_male = pd.read_csv("data/mltper_1x1.txt", delim_whitespace=True, skiprows=1)
df_male.head()


Unnamed: 0,Year,Age,mx,qx,ax,lx,dx,Lx,Tx,ex
0,1940,0,0.15228,0.1376,0.3,100000,13760,90357,4907719,49.08
1,1940,1,0.05115,0.04987,0.5,86240,4301,84090,4817363,55.86
2,1940,2,0.01991,0.01971,0.5,81939,1615,81132,4733273,57.77
3,1940,3,0.00999,0.00994,0.5,80324,798,79925,4652141,57.92
4,1940,4,0.00683,0.00681,0.5,79526,541,79255,4572216,57.49


In [19]:
results_non_ai['traces']

Unnamed: 0,Year,Age,Prop_Mild,Prop_Moderate,Prop_Severe,Prop_VI,Prop_Dead,Prop_Alive,Total_Cost,Total_QALY,...,QALY_Moderate,QALY_Severe,QALY_VI,Cost_Discount_Factor,Health_Discount_Factor,Mortality_Rate_Mild,Mortality_Rate_Moderate,Mortality_Rate_Severe,Mortality_Rate_VI,Base_Mortality_qx
0,0,55,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.985,...,0.0,0.0,0.0,1.0,1.0,0.00505,0.005302,0.005555,0.00606,0.00505
1,1,56,0.85195,0.143,0.0,0.0,0.00505,0.99495,0.0,0.967728,...,0.128557,0.0,0.0,0.970874,0.985222,0.00555,0.005828,0.006105,0.00666,0.00555
2,2,57,0.725393,0.251555,0.012441,0.0,0.010612,0.989388,0.0,0.950276,...,0.226148,0.009617,0.0,0.942596,0.970662,0.0061,0.006405,0.00671,0.00732,0.0061
3,3,58,0.617237,0.331789,0.033285,0.000958,0.016731,0.983269,13.783085,0.932593,...,0.298279,0.025729,0.000607,0.915142,0.956317,0.00671,0.007045,0.007381,0.008052,0.00671
4,4,59,0.52483,0.388851,0.059342,0.003513,0.023464,0.976536,50.547543,0.914633,...,0.349577,0.045871,0.002227,0.888487,0.942184,0.00738,0.007749,0.008118,0.008856,0.00738
5,5,60,0.445906,0.427058,0.088121,0.008051,0.030863,0.969137,115.843322,0.896365,...,0.383925,0.068117,0.005105,0.862609,0.92826,0.00812,0.008526,0.008932,0.009744,0.00812
6,6,61,0.378521,0.450028,0.117702,0.014758,0.038991,0.961009,212.341493,0.877759,...,0.404575,0.090984,0.009357,0.837484,0.914542,0.00894,0.009387,0.009834,0.010728,0.00894
7,7,62,0.321008,0.460779,0.146634,0.023663,0.047915,0.952085,340.463268,0.858785,...,0.414241,0.113348,0.015002,0.813092,0.901027,0.00983,0.010322,0.010813,0.011796,0.00983
8,8,63,0.271949,0.46184,0.173846,0.034675,0.057691,0.942309,498.899823,0.83943,...,0.415194,0.134383,0.021984,0.789409,0.887711,0.01082,0.011361,0.011902,0.012984,0.01082
9,9,64,0.230118,0.455302,0.198571,0.047611,0.0684,0.9316,685.021632,0.819662,...,0.409316,0.153495,0.030185,0.766417,0.874592,0.0119,0.012495,0.01309,0.01428,0.0119


In [24]:
initial_dist = [38, 9, 7, 0, 0] ## Addition of Mortality state
initial_dist = np.array(initial_dist) / np.sum(initial_dist)

In [25]:
ai_model = AIGlaucomaModel()
ai_results = ai_model.run_deterministic( initial_dist= initial_dist)

non_ai_model = NonAIGlaucomaModel()
non_ai_results = non_ai_model.run_deterministic(initial_dist = initial_dist)

Applied Non-AI cost structure: Only VI patients incur costs


In [26]:
ai_results['traces']

Unnamed: 0,Year,Age,Prop_Mild,Prop_Moderate,Prop_Severe,Prop_VI,Prop_Dead,Prop_Alive,Total_Cost,Total_QALY,...,QALY_Moderate,QALY_Severe,QALY_VI,Cost_Discount_Factor,Health_Discount_Factor,Mortality_Rate_Mild,Mortality_Rate_Moderate,Mortality_Rate_Severe,Mortality_Rate_VI,Base_Mortality_qx
0,0,60,0.703704,0.166667,0.12963,0.0,0.0,1.0,706.316667,0.943185,...,0.149833,0.100204,0.0,1.0,1.0,0.00812,0.008526,0.008932,0.009744,0.00812
1,1,61,0.657175,0.199394,0.13099,0.004148,0.008293,0.991707,748.910865,0.930458,...,0.179255,0.101255,0.00263,0.970874,0.985222,0.00894,0.009387,0.009834,0.010728,0.00894
2,2,62,0.613184,0.227662,0.133486,0.008295,0.017372,0.982628,802.189498,0.917098,...,0.204669,0.103185,0.005259,0.942596,0.970662,0.00983,0.010322,0.010813,0.011796,0.00983
3,3,63,0.571591,0.251771,0.136878,0.012469,0.027291,0.972709,854.895857,0.903071,...,0.226342,0.105807,0.007905,0.915142,0.956317,0.01082,0.011361,0.011902,0.012984,0.01082
4,4,64,0.532254,0.271992,0.140939,0.016687,0.038127,0.961873,907.142724,0.888317,...,0.244521,0.108946,0.01058,0.888487,0.942184,0.0119,0.012495,0.01309,0.01428,0.0119
5,5,65,0.49505,0.288584,0.145464,0.020959,0.049943,0.950057,958.949599,0.872793,...,0.259437,0.112444,0.013288,0.862609,0.92826,0.01309,0.013744,0.014399,0.015708,0.01309
6,6,66,0.459857,0.301788,0.150258,0.025285,0.062813,0.937187,1010.222284,0.856446,...,0.271307,0.11615,0.01603,0.837484,0.914542,0.01439,0.01511,0.015829,0.017268,0.01439
7,7,67,0.426568,0.311828,0.155143,0.029656,0.076805,0.923195,1060.776667,0.83923,...,0.280333,0.119925,0.018802,0.813092,0.901027,0.01581,0.016601,0.017391,0.018972,0.01581
8,8,68,0.395083,0.318919,0.159953,0.034058,0.091987,0.908013,1110.335458,0.821102,...,0.286708,0.123644,0.021593,0.789409,0.887711,0.01737,0.018239,0.019107,0.020844,0.01737
9,9,69,0.365305,0.323261,0.164535,0.038467,0.108432,0.891568,1158.521957,0.802011,...,0.290611,0.127186,0.024388,0.766417,0.874592,0.01907,0.020023,0.020977,0.022884,0.01907


In [27]:
non_ai_results['traces']

Unnamed: 0,Year,Age,Prop_Mild,Prop_Moderate,Prop_Severe,Prop_VI,Prop_Dead,Prop_Alive,Total_Cost,Total_QALY,...,QALY_Moderate,QALY_Severe,QALY_VI,Cost_Discount_Factor,Health_Discount_Factor,Mortality_Rate_Mild,Mortality_Rate_Moderate,Mortality_Rate_Severe,Mortality_Rate_VI,Base_Mortality_qx
0,0,60,0.703704,0.166667,0.12963,0.0,0.0,1.0,0.0,0.943185,...,0.149833,0.100204,0.0,1.0,1.0,0.00812,0.008526,0.008932,0.009744,0.00812
1,1,61,0.59736,0.251375,0.13299,0.009981,0.008293,0.991707,143.613556,0.923516,...,0.225986,0.102801,0.006328,0.970874,0.985222,0.00894,0.009387,0.009834,0.010728,0.00894
2,2,62,0.506597,0.312568,0.143312,0.020115,0.017408,0.982592,289.409627,0.90353,...,0.280999,0.11078,0.012753,0.942596,0.970662,0.00983,0.010322,0.010813,0.011796,0.00983
3,3,63,0.429174,0.354592,0.157921,0.030912,0.027401,0.972599,444.76753,0.883186,...,0.318778,0.122073,0.019598,0.915142,0.956317,0.01082,0.011361,0.011902,0.012984,0.01082
4,4,64,0.363158,0.381086,0.174731,0.042671,0.038354,0.961646,613.949198,0.862428,...,0.342596,0.135067,0.027053,0.888487,0.942184,0.0119,0.012495,0.01309,0.01428,0.0119
5,5,65,0.306905,0.395102,0.192144,0.055516,0.050334,0.949666,798.761995,0.841222,...,0.355196,0.148527,0.035197,0.862609,0.92826,0.01309,0.013744,0.014399,0.015708,0.01309
6,6,66,0.259,0.399185,0.208956,0.069439,0.06342,0.93658,999.086471,0.819529,...,0.358867,0.161523,0.044024,0.837484,0.914542,0.01439,0.01511,0.015829,0.017268,0.01439
7,7,67,0.218236,0.395461,0.224288,0.084329,0.077685,0.922315,1213.331383,0.797322,...,0.35552,0.173374,0.053465,0.813092,0.901027,0.01581,0.016601,0.017391,0.018972,0.01581
8,8,68,0.183578,0.385699,0.237522,0.1,0.093201,0.906799,1438.795049,0.774572,...,0.346743,0.183605,0.0634,0.789409,0.887711,0.01737,0.018239,0.019107,0.020844,0.01737
9,9,69,0.154138,0.37136,0.24825,0.116204,0.110047,0.889953,1671.949842,0.75125,...,0.333853,0.191898,0.073674,0.766417,0.874592,0.01907,0.020023,0.020977,0.022884,0.01907


In [28]:
ai_results['total_cost_discounted'], ai_results['total_qalys_discounted']

(8974.613642225198, 8.904358050030082)

In [29]:
non_ai_results['total_cost_discounted'], non_ai_results['total_qalys_discounted']

(7714.500979121389, 8.607375389710308)

In [30]:
ai_psa_results = ai_model.run_probabilistic(
    n_iterations=5000,
    initial_dist=initial_dist,
    years=10,
    include_screening=True,
    population_type='general',
    random_seed=42,
    return_traces=True, starting_age=60
)

Running 5000 PSA iterations for AI Enhanced Model (Starting Age: 60)...
  AI Model Iteration 100/5000
  AI Model Iteration 200/5000
  AI Model Iteration 300/5000
  AI Model Iteration 400/5000
  AI Model Iteration 500/5000
  AI Model Iteration 600/5000
  AI Model Iteration 700/5000
  AI Model Iteration 800/5000
  AI Model Iteration 900/5000
  AI Model Iteration 1000/5000
  AI Model Iteration 1100/5000
  AI Model Iteration 1200/5000
  AI Model Iteration 1300/5000
  AI Model Iteration 1400/5000
  AI Model Iteration 1500/5000
  AI Model Iteration 1600/5000
  AI Model Iteration 1700/5000
  AI Model Iteration 1800/5000
  AI Model Iteration 1900/5000
  AI Model Iteration 2000/5000
  AI Model Iteration 2100/5000
  AI Model Iteration 2200/5000
  AI Model Iteration 2300/5000
  AI Model Iteration 2400/5000
  AI Model Iteration 2500/5000
  AI Model Iteration 2600/5000
  AI Model Iteration 2700/5000
  AI Model Iteration 2800/5000
  AI Model Iteration 2900/5000
  AI Model Iteration 3000/5000
  AI Mo

In [36]:
non_ai_psa_results = non_ai_model.run_probabilistic(
    n_iterations=5000,
    initial_dist=initial_dist,
    years=30,
    include_screening=True,
    population_type='general',
    random_seed=42,
    return_traces=True, starting_age=60
)

Running 5000 PSA iterations for Traditional Non-AI Model (Starting Age: 60)...
  Non-AI Model Iteration 100/5000
  Non-AI Model Iteration 200/5000
  Non-AI Model Iteration 300/5000
  Non-AI Model Iteration 400/5000
  Non-AI Model Iteration 500/5000
  Non-AI Model Iteration 600/5000
  Non-AI Model Iteration 700/5000
  Non-AI Model Iteration 800/5000
  Non-AI Model Iteration 900/5000
  Non-AI Model Iteration 1000/5000
  Non-AI Model Iteration 1100/5000
  Non-AI Model Iteration 1200/5000
  Non-AI Model Iteration 1300/5000
  Non-AI Model Iteration 1400/5000
  Non-AI Model Iteration 1500/5000
  Non-AI Model Iteration 1600/5000
  Non-AI Model Iteration 1700/5000
  Non-AI Model Iteration 1800/5000
  Non-AI Model Iteration 1900/5000
  Non-AI Model Iteration 2000/5000
  Non-AI Model Iteration 2100/5000
  Non-AI Model Iteration 2200/5000
  Non-AI Model Iteration 2300/5000
  Non-AI Model Iteration 2400/5000
  Non-AI Model Iteration 2500/5000
  Non-AI Model Iteration 2600/5000
  Non-AI Model Itera

array([[0.00000000e+00, 6.00000000e+01, 7.03703704e-01, 1.66666667e-01,
        1.29629630e-01, 0.00000000e+00, 0.00000000e+00, 1.00000000e+00,
        9.99995305e-03, 9.58451374e-01, 9.99995305e-03, 9.58451374e-01,
        9.99995305e-03, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        0.00000000e+00, 7.02433531e-01, 1.50589518e-01, 1.05428325e-01,
        0.00000000e+00, 1.00000000e+00, 1.00000000e+00, 8.12000000e-03,
        8.52600000e-03, 8.93200000e-03, 9.74400000e-03, 8.12000000e-03],
       [1.00000000e+00, 6.10000000e+01, 6.24192528e-01, 2.26985180e-01,
        1.29935498e-01, 1.05938688e-02, 8.29292593e-03, 9.91707074e-01,
        1.44445091e+02, 9.41214083e-01, 1.40682447e+02, 9.26581597e-01,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
        1.44445091e+02, 6.23065871e-01, 2.05089533e-01, 1.05677088e-01,
        7.38159014e-03, 9.73951044e-01, 9.84453606e-01, 8.94000000e-03,
        9.38700000e-03, 9.83400000e-03, 1.07280000e-02, 8.94000