# In [None]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from lifelines import KaplanMeierFitter
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
from pathlib import Path
import scipy.stats
import tempfile

class TrialEmulation:
    """Per-protocol style trial-emulation pipeline over long-format data.

    The intended workflow is a fluent chain in which every step returns
    ``self``::

        TrialEmulation(df).prepare_data().calculate_weights() \
            .set_expansion_options().expand_trials().load_expanded_data()

    Columns read from ``data``: ``id``, ``period``, ``eligible``,
    ``treatment`` and ``censored`` (``prepare_data``) plus ``x1``, ``x2``,
    ``x3`` and ``age`` (``calculate_weights``).

    NOTE: ``censor_weights_spec`` / ``switch_weights_spec`` are only echoed
    by ``show_weight_specs()``.  The covariates actually used by the weight
    models are the fixed class-level tuples defined below.

    NOTE: ``prepare_data()`` removes censored rows and every row from the
    first treatment switch onwards *before* weights are fitted.  On data
    produced by ``prepare_data()`` the censoring indicator is therefore
    constant, and ``treatment`` equals ``prev_treatment`` on every row, so
    ``calculate_weights()`` falls back to unit weights for both mechanisms.
    """

    # Covariates of the four weight models; shared by fitting and display so
    # the printed term labels can never drift from what was fitted.
    _CENSOR_NUMERATOR_COLS = ('x2',)
    _CENSOR_DENOMINATOR_COLS = ('x2', 'x1')
    _SWITCH_NUMERATOR_COLS = ('age',)
    _SWITCH_DENOMINATOR_COLS = ('age', 'x1', 'x3')

    # Added to weight denominators to avoid division by zero.
    _DENOMINATOR_EPS = 1e-8

    def __init__(self, data, estimand="Per-Protocol"):
        """Store a private copy of ``data`` and reset all pipeline state.

        Args:
            data: pandas DataFrame in long format (one row per id/period).
            estimand: Label printed in reports; not otherwise interpreted.
        """
        self.data = data.copy()
        self.estimand = estimand
        self.trial_data = None
        self.msm_model = None
        # Legacy attribute name, kept so existing readers do not break.
        # expand_trials() writes ``expanded_data``.
        self.expansion_data = None
        self.expanded_data = None
        self.sampled_data = None
        self.weight_models = None
        self.chunk_size = None
        self.censor_weights_specified = False
        self.censor_weights_spec = None
        self.switch_weights_specified = False
        self.switch_weights_spec = None
        self.expansion_specified = False
        self.outcome_model_specified = False

    def __str__(self):
        """Return a multi-line status report of the pipeline stages."""
        output = []
        output.append("\nIPW for informative censoring:")
        if self.censor_weights_specified:
            output.append(" - Weights calculated")
        else:
            output.append(" - No weight model specified")

        output.append("\nIPW for treatment switch censoring:")
        if self.switch_weights_specified:
            output.append(" - Weights calculated")
        else:
            output.append(" - No weight model specified")

        output.append("\nSequence of Trials Data:")
        if self.expansion_specified:
            output.append(" - Expansion complete")
        else:
            output.append(" - Use set_expansion_options() and expand_trials() to construct the sequence of trials dataset.")

        output.append("\nOutcome model:")
        if self.outcome_model_specified:
            output.append(" - Model fitted")
            if self.msm_model is not None:
                output.append(" - Family: Binomial")
                output.append(" - Link: logit")
                if self.trial_data is not None:
                    output.append(f" - Groups: {len(self.trial_data['id'].unique())}")
        else:
            output.append(" - Model not specified. Use set_outcome_model()")

        return "\n".join(output)

    def prepare_data(self):
        """Build ``self.trial_data`` by applying the per-protocol filters.

        Steps: keep ids that have an eligible row at period 0; drop every
        row from an id's first treatment switch onwards; drop censored
        rows.  A ``prev_treatment`` column is added (the first row of each
        id is filled with its own treatment).

        Returns:
            self, for chaining.

        Raises:
            KeyError: If a required column is missing from the data.
        """
        required = ('id', 'period', 'eligible', 'treatment', 'censored')
        missing = [col for col in required if col not in self.data.columns]
        if missing:
            raise KeyError(f"Missing required column(s): {missing}")

        pp_data = self.data.sort_values(['id', 'period'])

        # Ids with an eligible baseline (period 0) row.
        at_baseline = (pp_data['period'] == 0) & (pp_data['eligible'] == 1)
        eligible_ids = pp_data.loc[at_baseline, 'id'].unique()
        # .copy() so the column assignments below act on an independent frame.
        pp_data = pp_data[pp_data['id'].isin(eligible_ids)].copy()

        # A switch is a change of treatment relative to the previous row of
        # the same id; the first row (NaN previous value) never counts
        # because of the period > 0 condition.
        pp_data['prev_treatment'] = pp_data.groupby('id')['treatment'].shift(1)
        switched = (
            (pp_data['treatment'] != pp_data['prev_treatment'])
            & (pp_data['period'] > 0)
        )

        # Running count of switches per id: rows are kept while it is 0.
        switches_so_far = switched.groupby(pp_data['id']).cumsum()
        keep = (switches_so_far == 0) & (pp_data['censored'] == 0)
        pp_data = pp_data[keep].copy()

        pp_data['prev_treatment'] = pp_data['prev_treatment'].fillna(pp_data['treatment'])

        self.trial_data = pp_data
        return self

    def show_weight_specs(self):
        """Return the weight-model specifications as printable text.

        The trailing status line of each section reflects whether
        ``calculate_weights()`` has stored any model for that section.
        """
        models = getattr(self, 'weight_models', None) or {}

        def status_line(section):
            # True once calculate_weights() stored at least one model here.
            fitted = any(m is not None for m in models.get(section, {}).values())
            if fitted:
                return " - Weight models fitted. Use show_weight_models()\n"
            return " - Weight models not fitted. Use calculate_weights()\n"

        output = []

        if self.censor_weights_spec is not None:
            output.append("\n - Numerator formula: 1 - censored ~ " + self.censor_weights_spec['numerator'])
            output.append(" - Denominator formula: 1 - censored ~ " + self.censor_weights_spec['denominator'])
            output.append(f" - Numerator model is pooled across treatment arms: {self.censor_weights_spec['pool_models'] in ['numerator', 'both']}")
            output.append(" - Model fitter type: te_stats_glm_logit")
            output.append(status_line('censoring'))
        else:
            output.append(" - No weight model specified\n")

        # Display switch weight specifications
        if self.switch_weights_spec is not None:
            output.append("\n - Numerator formula: treatment ~ " + self.switch_weights_spec['numerator'])
            output.append(" - Denominator formula: treatment ~ " + self.switch_weights_spec['denominator'])
            output.append(" - Model fitter type: te_stats_glm_logit")
            output.append(status_line('switching'))
        else:
            output.append(" - No weight model specified\n")

        return "\n".join(output)

    @staticmethod
    def _design_matrix(frame, columns):
        """Return ``frame[columns]`` with a leading all-ones ``const`` column."""
        design = frame[list(columns)].copy()
        design.insert(0, 'const', 1.0)
        return design

    @classmethod
    def _fit_logit(cls, frame, columns, y):
        """Fit a logistic model of ``y`` on ``columns``.

        The model is fitted on the raw covariates only; the intercept comes
        from the estimator itself.  Adding an explicit constant column as
        well would give the estimator two intercept parameters and make the
        stored coefficient vector one entry longer than the design matrix.

        Returns:
            (probabilities of class 1, record dict with keys
            ``coefficients``, ``intercept``, ``X`` (design matrix including
            the constant column) and ``y``).
        """
        features = frame[list(columns)]
        estimator = LogisticRegression(solver='lbfgs', max_iter=1000)
        estimator.fit(features, y)
        probs = estimator.predict_proba(features)[:, 1]
        record = {
            'coefficients': estimator.coef_[0],
            'intercept': estimator.intercept_[0],
            'X': cls._design_matrix(frame, columns),
            'y': y,
        }
        return probs, record

    def _censoring_weights(self, data):
        """Return per-row censoring weights and record the fitted models."""
        weights = np.ones(len(data))
        if len(data) == 0:
            return weights

        uncensored = 1 - data['censored']
        columns = self._CENSOR_NUMERATOR_COLS

        if len(np.unique(data['censored'])) == 1:
            # Nothing to model: the indicator is constant.  Store a dummy
            # record whose large-magnitude intercept reproduces the observed
            # constant probability (about 1 or about 0) for display.
            all_uncensored = bool((uncensored == 1).all())
            self.weight_models['censoring']['numerator'] = {
                'coefficients': np.zeros(len(columns)),
                'intercept': np.log(1e6) if all_uncensored else -np.log(1e6),
                'X': self._design_matrix(data, columns),
                'y': uncensored.to_numpy(dtype=float),
            }
            return weights

        # Numerator: one model pooled over all rows.
        num_probs, record = self._fit_logit(data, columns, uncensored)
        self.weight_models['censoring']['numerator'] = record

        # Denominator: one model per previous-treatment group.
        for prev_treat in (0, 1):
            mask = (data['prev_treatment'] == prev_treat).to_numpy()
            if not mask.any():
                continue
            group_y = uncensored[mask]
            if len(np.unique(group_y)) < 2:
                # A constant outcome cannot be fitted; keep unit weights,
                # matching how the switching models treat this case.
                continue
            den_probs, record = self._fit_logit(
                data.loc[mask], self._CENSOR_DENOMINATOR_COLS, group_y
            )
            weights[mask] = num_probs[mask] / (den_probs + self._DENOMINATOR_EPS)
            self.weight_models['censoring'][f'denominator_{prev_treat}'] = record

        return weights

    def _switching_weights(self, data):
        """Return per-row treatment weights and record the fitted models."""
        weights = np.ones(len(data))

        for prev_treat in (0, 1):
            mask = (data['prev_treatment'] == prev_treat).to_numpy()
            if not mask.any():
                continue
            group = data.loc[mask]
            treatment = group['treatment']
            if len(np.unique(treatment)) < 2:
                continue  # constant treatment within group: unit weights

            num_p1, num_record = self._fit_logit(group, self._SWITCH_NUMERATOR_COLS, treatment)
            den_p1, den_record = self._fit_logit(group, self._SWITCH_DENOMINATOR_COLS, treatment)

            # The weight is the ratio of the probabilities of the treatment
            # that was actually received: p for treated rows, 1 - p otherwise.
            treated = treatment.to_numpy() == 1
            num_obs = np.where(treated, num_p1, 1 - num_p1)
            den_obs = np.where(treated, den_p1, 1 - den_p1)
            weights[mask] = num_obs / (den_obs + self._DENOMINATOR_EPS)

            self.weight_models['switching'][f'numerator_{prev_treat}'] = num_record
            self.weight_models['switching'][f'denominator_{prev_treat}'] = den_record

        return weights

    def calculate_weights(self):
        """Calculate IPW for censoring and treatment switching.

        Writes the product of both weights, trimmed at its 99th percentile,
        to ``self.trial_data['weight']`` and stores per-model results in
        ``self.weight_models`` for ``show_weight_models()``.

        Returns:
            self, for chaining.

        Raises:
            ValueError: If ``prepare_data()`` has not been run.
        """
        if self.trial_data is None:
            raise ValueError("Must prepare data first using prepare_data()")

        self.weight_models = {
            'censoring': {'numerator': None, 'denominator_0': None, 'denominator_1': None},
            'switching': {'numerator_0': None, 'numerator_1': None,
                          'denominator_0': None, 'denominator_1': None}
        }

        censor_weights = self._censoring_weights(self.trial_data)
        switch_weights = self._switching_weights(self.trial_data)

        # Combine weights and store
        self.trial_data['weight'] = censor_weights * switch_weights

        # Trim weights at 99th percentile (a percentile of nothing is undefined)
        if len(self.trial_data) > 0:
            q99 = np.percentile(self.trial_data['weight'], 99)
            self.trial_data['weight'] = self.trial_data['weight'].clip(upper=q99)

        # Let __str__ report this stage as done.
        self.censor_weights_specified = True
        self.switch_weights_specified = True
        return self

    @staticmethod
    def _binomial_deviance(y, prob):
        """Return -2 * Bernoulli log-likelihood of ``y`` under ``prob``.

        Probabilities are clipped away from 0 and 1 so a constant outcome
        (fitted probability exactly 0 or 1) gives a finite value, not NaN.
        """
        prob = np.clip(prob, 1e-12, 1 - 1e-12)
        return float(-2.0 * np.sum(y * np.log(prob) + (1 - y) * np.log(1 - prob)))

    @classmethod
    def _model_summary_lines(cls, model, columns):
        """Format the coefficient table and deviances of one stored model.

        Args:
            model: Record with keys ``X`` (design matrix whose first column
                is the constant), ``y``, ``intercept`` and ``coefficients``.
            columns: Covariate names, in design-matrix order, for labels.

        Returns:
            List of output lines.  If statistics cannot be computed, a
            single ``Error: ...`` line is returned instead.
        """
        from scipy import stats

        # Work on plain float arrays so DataFrame/Series inputs behave the
        # same as ndarrays (no index alignment during the matrix products).
        design = np.asarray(model['X'], dtype=float)
        if design.ndim == 1:
            design = design.reshape(-1, 1)
        y = np.asarray(model['y'], dtype=float)
        coef = np.concatenate(
            [[model['intercept']], np.ravel(model['coefficients'])]
        ).astype(float)

        if design.shape[1] != coef.size:
            return ["Error: Design matrix and coefficient vector have different sizes. Cannot compute model statistics."]

        pred = 1 / (1 + np.exp(-(design @ coef)))
        info_weights = pred * (1 - pred)
        try:
            # X' W X without materialising the n-by-n diagonal matrix W.
            # NOTE: this is the unpenalised information matrix; if the
            # estimator applied regularisation, these are approximations.
            cov = np.linalg.inv(design.T @ (design * info_weights[:, None]))
        except np.linalg.LinAlgError:
            return ["Error: Covariance matrix is singular. Cannot compute standard errors."]
        se = np.sqrt(np.diag(cov))

        z_stats = coef / se
        p_values = 2 * (1 - stats.norm.cdf(np.abs(z_stats)))

        null_dev = cls._binomial_deviance(y, np.full(y.shape, y.mean()))
        model_dev = cls._binomial_deviance(y, pred)

        lines = ["\nterm        estimate    std.error  statistic  p.value"]
        terms = ['(Intercept)', *columns]
        for term, est, std, z, p in zip(terms, coef, se, z_stats, p_values):
            lines.append(f"{term:10} {est:10.8f} {std:10.8f} {z:10.8f} {p:10.8f}")

        lines.append(f"\nnull.deviance: {null_dev:8.4f}")
        lines.append(f"deviance: {model_dev:8.4f}")
        lines.append(f"nobs: {len(y)}")
        return lines

    def show_weight_models(self):
        """Return coefficient tables for every stored weight model.

        Returns a short notice instead when ``calculate_weights()`` has not
        been run yet.
        """
        weight_models = getattr(self, 'weight_models', None)
        if not weight_models:
            return "No weight models have been calculated yet."

        output = []

        # 1. Display censoring models
        output.append("Weight Models for Informative Censoring")
        output.append("-" * 40)

        censoring = weight_models['censoring']
        model = censoring['numerator']
        if model is not None:
            output.append("\n[[n]]")
            output.append("Model: P(censor_event = 0 | X) for numerator")
            output.extend(self._model_summary_lines(model, self._CENSOR_NUMERATOR_COLS))

        for prev_treat in (0, 1):
            model = censoring[f'denominator_{prev_treat}']
            if model is not None:
                output.append(f"\n[[d{prev_treat}]]")
                output.append(f"Model: P(censor_event = 0 | X, previous treatment = {prev_treat}) for denominator")
                output.extend(self._model_summary_lines(model, self._CENSOR_DENOMINATOR_COLS))

        # 2. Display switching models
        output.append("\nWeight Models for Treatment Switching")
        output.append("-------------------------------------")

        switching = weight_models['switching']
        for prev_treat in (0, 1):
            model = switching[f'numerator_{prev_treat}']
            if model is not None:
                output.append(f"\n[[n{prev_treat}]]")
                output.append(f"Model: P(treatment = 1 | previous treatment = {prev_treat}) for numerator")
                output.extend(self._model_summary_lines(model, self._SWITCH_NUMERATOR_COLS))

            model = switching[f'denominator_{prev_treat}']
            if model is not None:
                output.append(f"\n[[d{prev_treat}]]")
                output.append(f"Model: P(treatment = 1 | previous treatment = {prev_treat}) for denominator")
                output.extend(self._model_summary_lines(model, self._SWITCH_DENOMINATOR_COLS))

        return "\n".join(output)

    def set_expansion_options(self, chunk_size=500):
        """Store ``chunk_size`` (not consulted by ``expand_trials()``)."""
        self.chunk_size = chunk_size
        return self

    def expand_trials(self):
        """Stack one copy of the prepared data per distinct period.

        For every period value ``t`` in ``self.trial_data`` the rows with
        ``period >= t`` are copied, given ``trial_time = period - t`` and
        ``trial_id = t``, and all copies are concatenated into
        ``self.expanded_data`` with a fresh 0..n-1 index.

        Returns:
            self, for chaining.

        Raises:
            ValueError: If ``prepare_data()`` has not been run.
        """
        if self.trial_data is None:
            raise ValueError("Must prepare data first using prepare_data()")

        data = self.trial_data
        trials = []
        for start in sorted(data['period'].unique()):
            trial = data[data['period'] >= start].copy()
            trial['trial_time'] = trial['period'] - start
            trial['trial_id'] = start
            trials.append(trial)

        if trials:
            self.expanded_data = pd.concat(trials, ignore_index=True)
        else:
            # No rows at all: pd.concat([]) would raise, so produce an empty
            # frame that still carries the two trial columns.
            self.expanded_data = data.assign(
                trial_time=data['period'], trial_id=data['period']
            ).reset_index(drop=True)

        self.expansion_specified = True
        return self

    def load_expanded_data(self, seed=1234, p_control=0.5):
        """Draw a random subset of the expanded data without replacement.

        Args:
            seed: Seed of the random draw.
            p_control: Fraction (0..1) of expanded rows to keep.

        Returns:
            self, for chaining; the sample is stored in ``self.sampled_data``.

        Raises:
            ValueError: If ``expand_trials()`` has not been run or
                ``p_control`` is outside [0, 1].
        """
        if getattr(self, 'expanded_data', None) is None:
            raise ValueError("Must expand trials first using expand_trials()")
        if not 0 <= p_control <= 1:
            raise ValueError("p_control must be between 0 and 1")

        # A private RandomState yields the same draw as seeding the global
        # generator would, without disturbing other users of np.random.
        rng = np.random.RandomState(seed)

        n_rows = len(self.expanded_data)
        control_size = int(n_rows * p_control)
        sampled_indices = rng.choice(n_rows, size=control_size, replace=False)

        self.sampled_data = self.expanded_data.iloc[sampled_indices].copy()
        return self
        

    # Usage
if __name__ == "__main__":
        # Load data
        data = pd.read_csv("../data_censored.csv")
        
        # Set display options for pandas
        pd.set_option('display.max_columns', None)
        pd.set_option('display.width', None)
        pd.set_option('display.max_rows', 6)
        pd.set_option('display.expand_frame_repr', False)
        
        print("Initial data:")
        print(data.head(6))
        
        # Create and run analysis
        trial = TrialEmulation(data)
        
        # Set weight specifications first
        trial.censor_weights_spec = {
            'numerator': 'x2',
            'denominator': 'x2 + x1',
            'pool_models': 'numerator'
        }
        
        trial.switch_weights_spec = {
            'numerator': 'age',
            'denominator': 'age + x1 + x3'
        }
        
        # Run analysis
        (trial.prepare_data()
              .calculate_weights()
              .set_expansion_options(chunk_size=500)
              .expand_trials()
              .load_expanded_data(seed=1234, p_control=0.5))
        
        # Display results
        print("\nTrial Sequence Object")
        print(f"Estimand: {trial.estimand}\n")
        print(f"Data:\n - N: {len(trial.trial_data)} observations from {trial.trial_data['id'].nunique()} patients\n")
        print(trial.trial_data.head(2))
        print("--")
        print(trial.trial_data.tail(2))
        print(trial)
        print(trial.show_weight_specs())
        print(trial.show_weight_models())
        
        print("\nModel Summary:")
        print(trial.msm_model.summary())
