In [19]:
import pandas as pd
import numpy as np
import statsmodels.api as sm
from lifelines import KaplanMeierFitter
from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import scipy.stats

class TrialEmulation:
    def __init__(self, data, estimand="Per-Protocol"):
        self.data = data.copy()
        self.estimand = estimand
        self.trial_data = None
        self.msm_model = None
        self.expanded_data = None
        self.weight_models = None
        self.censor_weights_specified = False
        self.censor_weights_spec = None
        self.switch_weights_specified = False
        self.switch_weights_spec = None
        self.expansion_specified = False
        self.outcome_model_specified = False

    def __str__(self):
        """Display trial sequence information"""
        output = []
        output.append("\nIPW for informative censoring:")
        if self.censor_weights_specified:
            output.append(" - Weights calculated")
        else:
            output.append(" - No weight model specified")
            
        output.append("\nIPW for treatment switch censoring:")
        if self.switch_weights_specified:
            output.append(" - Weights calculated")
        else:
            output.append(" - No weight model specified")
            
        output.append("\nSequence of Trials Data:")
        if hasattr(self, 'expanded_data') and self.expanded_data is not None:
            output.append(" - Expansion complete")
            self.expansion_specified = True
        else:
            output.append(" - Use set_expansion_options() and expand_trials() to construct the sequence of trials dataset.")

        output.append("\nOutcome model:")
        if hasattr(self, 'msm_model') and self.msm_model is not None:
            output.append(" - Model fitted")
            self.outcome_model_specified = True
            output.append(f" - Family: Binomial")
            output.append(f" - Link: logit")
            output.append(f" - Groups: {len(self.trial_data['id'].unique())}")
        else:
            output.append(" - Model not specified. Use set_outcome_model()")
            
        return "\n".join(output)
        
    def prepare_data(self):
        """Filter data according to per-protocol criteria"""
        pp_data = self.data.copy()
        pp_data = pp_data.sort_values(['id', 'period'])
        
        eligible_ids = pp_data[
            (pp_data['period'] == 0) & 
            (pp_data['eligible'] == 1)
        ]['id'].unique()
        
        pp_data = pp_data[pp_data['id'].isin(eligible_ids)]
        
        pp_data['prev_treatment'] = pp_data.groupby('id')['treatment'].shift(1)
        pp_data['treatment_switch'] = (
            (pp_data['treatment'] != pp_data['prev_treatment']) & 
            (pp_data['period'] > 0)
        )
        
        pp_data['cumulative_switch'] = pp_data.groupby('id')['treatment_switch'].cumsum()
        pp_data = pp_data[pp_data['cumulative_switch'] == 0]
        pp_data = pp_data[pp_data['censored'] == 0]
        
        pp_data = pp_data.drop(['treatment_switch', 'cumulative_switch'], axis=1)
        pp_data['prev_treatment'] = pp_data['prev_treatment'].fillna(pp_data['treatment'])
        
        self.trial_data = pp_data
        return self

    def show_weight_specs(self):
        """Display weight model specifications"""
        output = []
        
        if self.censor_weights_spec is not None:
            self.censor_weights_specified = True
            output.append("\n - Numerator formula: 1 - censored ~ " + self.censor_weights_spec['numerator'])
            output.append(" - Denominator formula: 1 - censored ~ " + self.censor_weights_spec['denominator'])
            output.append(f" - Numerator model is pooled across treatment arms: {self.censor_weights_spec['pool_models'] in ['numerator', 'both']}")
            output.append(" - Model fitter type: te_stats_glm_logit")
            output.append(" - Weight models not fitted. Use calculate_weights()\n")
        else:
            output.append(" - No weight model specified\n")

        if self.switch_weights_spec is not None:
            self.switch_weights_specified = True
            output.append("\n - Numerator formula: treatment ~ " + self.switch_weights_spec['numerator'])
            output.append(" - Denominator formula: treatment ~ " + self.switch_weights_spec['denominator'])
            output.append(" - Model fitter type: te_stats_glm_logit")
            output.append(" - Weight models not fitted. Use calculate_weights()\n")
        else:
            output.append(" - No weight model specified\n")
            
        return "\n".join(output)
    
    def calculate_weights(self):
        """Calculate IPW for censoring and treatment switching and store model statistics"""
        self.weight_models = {
            'censoring': {'numerator': None, 'denominator_0': None, 'denominator_1': None},
            'switching': {'numerator_0': None, 'numerator_1': None, 
                         'denominator_0': None, 'denominator_1': None}
        }
        
        if len(np.unique(self.trial_data['censored'])) == 1:
            censor_weights = np.ones(len(self.trial_data))
            self.weight_models['censoring']['numerator'] = {
                'coefficients': np.array([0.0]),
                'intercept': np.log(1e6),
                'X': sm.add_constant(self.trial_data[['x2']]),
                'y': np.ones(len(self.trial_data))
            }
        else:
            censor_model_num = LogisticRegression(solver='lbfgs', max_iter=1000)
            X_num = sm.add_constant(self.trial_data[['x2']])
            y = 1 - self.trial_data['censored']
            
            censor_model_num.fit(X_num, y)
            num_probs = censor_model_num.predict_proba(X_num)[:, 1]
            
            self.weight_models['censoring']['numerator'] = {
                'coefficients': censor_model_num.coef_[0],
                'intercept': censor_model_num.intercept_[0],
                'X': X_num,
                'y': y
            }
            
            censor_weights = np.ones(len(self.trial_data))
            
            for prev_treat in [0, 1]:
                mask = self.trial_data['prev_treatment'] == prev_treat
                if sum(mask) > 0:
                    X_den = sm.add_constant(self.trial_data.loc[mask, ['x2', 'x1']])
                    censor_model_den = LogisticRegression(solver='lbfgs', max_iter=1000)
                    censor_model_den.fit(X_den, y[mask])
                    den_probs = censor_model_den.predict_proba(X_den)[:, 1]
                    censor_weights[mask] = num_probs[mask] / (den_probs + 1e-8)
                    
                    self.weight_models['censoring'][f'denominator_{prev_treat}'] = {
                        'coefficients': censor_model_den.coef_[0],
                        'intercept': censor_model_den.intercept_[0],
                        'X': X_den,
                        'y': y[mask]
                    }
        
        switch_weights = np.ones(len(self.trial_data))
        
        for prev_treat in [0, 1]:
            mask = self.trial_data['prev_treatment'] == prev_treat
            if sum(mask) > 0:
                X_num = sm.add_constant(self.trial_data.loc[mask, ['age']])
                y = self.trial_data.loc[mask, 'treatment']
            
                if len(np.unique(y)) > 1:
                    switch_model_num = LogisticRegression(solver='lbfgs', max_iter=1000)
                    switch_model_num.fit(X_num, y)
                    num_probs = switch_model_num.predict_proba(X_num)[:, 1]
                    
                    X_den = sm.add_constant(self.trial_data.loc[mask, ['age', 'x1', 'x3']])
                    switch_model_den = LogisticRegression(solver='lbfgs', max_iter=1000)
                    switch_model_den.fit(X_den, y)
                    den_probs = switch_model_den.predict_proba(X_den)[:, 1]
                    
                    probs_ratio = num_probs / (den_probs + 1e-8)
                    switch_weights[mask] = probs_ratio
                    
                    self.weight_models['switching'][f'numerator_{prev_treat}'] = {
                        'coefficients': switch_model_num.coef_[0],
                        'intercept': switch_model_num.intercept_[0],
                        'X': X_num,
                        'y': y
                    }
                    
                    self.weight_models['switching'][f'denominator_{prev_treat}'] = {
                        'coefficients': switch_model_den.coef_[0],
                        'intercept': switch_model_den.intercept_[0],
                        'X': X_den,
                        'y': y
                    }
        
        self.trial_data['weight'] = censor_weights * switch_weights
        q99 = np.percentile(self.trial_data['weight'], 99)
        self.trial_data['weight'] = self.trial_data['weight'].clip(upper=q99)
        
        return self

    def show_weight_models(self):
        """Display weight model statistics"""
        output = []

        if not hasattr(self, 'weight_models') or self.weight_models is None:
            return "No weight models have been calculated yet. Use calculate_weights() first."
    
        output.append("Weight Models for Informative Censoring")
        output.append("-" * 40)
        
        model = self.weight_models['censoring']['numerator']
        if model is not None:
            output.append("\n[[n]]")
            output.append("Model: P(censor_event = 0 | X) for numerator")
            
            X = model['X']
            if isinstance(X, pd.DataFrame):
                X = X.values
            y = model['y']
            coef = np.concatenate([[model['intercept']], model['coefficients']])

            # Handle 1D arrays
            if X.ndim == 1:
                X = X.reshape(-1, 1)

            if len(coef) != X.shape[1]:
                output.append(f"Warning: Coefficient shape mismatch: {len(coef)} vs {X.shape[1]}")
                # Adjust coef or X as needed
                if len(coef) > X.shape[1]:
                    coef = coef[:X.shape[1]]
                else:
                    # Handle if X has more columns than coefficients
                    X = X[:, :len(coef)]
                
            pred = 1 / (1 + np.exp(-X @ coef))
            V = np.diag(pred * (1 - pred))

            try:
                cov = np.linalg.inv(X.T @ V @ X)
                se = np.sqrt(np.diag(cov))
                
                z_stats = coef / se
                p_values = 2 * (1 - scipy.stats.norm.cdf(np.abs(z_stats)))

                mean_y = np.mean(y)
                if mean_y in [0, 1]:
                    null_dev = 0
                else:
                    null_dev = -2 * np.sum(y * np.log(mean_y) + (1 - y) * np.log(1 - mean_y))
                
                model_dev = -2 * np.sum(y * np.log(pred) + (1 - y) * np.log(1 - pred))
                
                terms = ['(Intercept)'] + ['x2']
                output.append("\nterm        estimate    std.error  statistic  p.value")
                for term, est, std, z, p in zip(terms, coef, se, z_stats, p_values):
                    output.append(f"{term:10} {est:10.8f} {std:10.8f} {z:10.8f} {p:10.8f}")
                
                output.append(f"\nnull.deviance: {null_dev:8.4f}")
                output.append(f"deviance: {model_dev:8.4f}")
                output.append(f"nobs: {len(y)}")
            except np.linalg.LinAlgError:
                output.append("Error: Covariance matrix is singular. Cannot compute standard errors.")
    
        for prev_treat in [0, 1]:
            model = self.weight_models['censoring'][f'denominator_{prev_treat}']
            if model is not None:
                output.append(f"\n[[d{prev_treat}]]")
                output.append(f"Model: P(censor_event = 0 | X, previous treatment = {prev_treat}) for denominator")
                
                X = model['X']
                if isinstance(X, pd.DataFrame):
                    X = X.values
                y = model['y']
                if isinstance(y, pd.Series):
                    y = y.values
                coef = np.concatenate([[model['intercept']], model['coefficients']])
                
                # Handle 1D arrays
                if X.ndim == 1:
                    X = X.reshape(-1, 1)
                
                pred = 1 / (1 + np.exp(-X @ coef))
                V = np.diag(pred * (1 - pred))
                
                try:
                    cov = np.linalg.inv(X.T @ V @ X)
                    se = np.sqrt(np.diag(cov))
                    
                    z_stats = coef / se
                    p_values = 2 * (1 - scipy.stats.norm.cdf(np.abs(z_stats)))
                    
                    null_dev = -2 * np.sum(y * np.log(np.mean(y)) + (1-y) * np.log(1-np.mean(y))) if np.mean(y) not in [0, 1] else 0
                    model_dev = -2 * np.sum(y * np.log(pred) + (1-y) * np.log(1-pred))
                    
                    terms = ['(Intercept)', 'x2', 'x1']
                    output.append("\nterm        estimate    std.error  statistic  p.value")
                    for term, est, std, z, p in zip(terms, coef, se, z_stats, p_values):
                        output.append(f"{term:10} {est:10.8f} {std:10.8f} {z:10.8f} {p:10.8f}")
                    
                    output.append(f"\nnull.deviance: {null_dev:8.4f}")
                    output.append(f"deviance: {model_dev:8.4f}")
                    output.append(f"nobs: {len(y)}")
                except np.linalg.LinAlgError:
                    output.append("Error: Covariance matrix is singular. Cannot compute standard errors.")
    
        output.append("\nWeight Models for Treatment Switching")
        output.append("-------------------------------------")
        
        for prev_treat in [0, 1]:
            model = self.weight_models['switching'][f'numerator_{prev_treat}']
            if model is not None:
                output.append(f"\n[[n{prev_treat}]]")
                output.append(f"Model: P(treatment = 1 | previous treatment = {prev_treat}) for numerator")
                
                X = model['X']
                if isinstance(X, pd.DataFrame):
                    X = X.values
                y = model['y']
                if isinstance(y, pd.Series):
                    y = y.values
                coef = np.concatenate([[model['intercept']], model['coefficients']])
                
                # Handle 1D arrays
                if X.ndim == 1:
                    X = X.reshape(-1, 1)
                
                pred = 1 / (1 + np.exp(-X @ coef))
                V = np.diag(pred * (1 - pred))
                
                try:
                    cov = np.linalg.inv(X.T @ V @ X)
                    se = np.sqrt(np.diag(cov))
                    
                    z_stats = coef / se
                    p_values = 2 * (1 - scipy.stats.norm.cdf(np.abs(z_stats)))
                    
                    null_dev = -2 * np.sum(y * np.log(np.mean(y)) + (1-y) * np.log(1-np.mean(y))) if np.mean(y) not in [0, 1] else 0
                    model_dev = -2 * np.sum(y * np.log(pred) + (1-y) * np.log(1-pred))
                    
                    terms = ['(Intercept)', 'age']
                    output.append("\nterm        estimate    std.error  statistic  p.value")
                    for term, est, std, z, p in zip(terms, coef, se, z_stats, p_values):
                        output.append(f"{term:10} {est:10.8f} {std:10.8f} {z:10.8f} {p:10.8f}")

                    output.append(f"\nnull.deviance: {null_dev:8.4f}")
                    output.append(f"deviance: {model_dev:8.4f}")
                    output.append(f"nobs: {len(y)}")
                except np.linalg.LinAlgError:
                    output.append("Error: Covariance matrix is singular. Cannot compute standard errors.")
            
            model = self.weight_models['switching'][f'denominator_{prev_treat}']
            if model is not None:
                output.append(f"\n[[d{prev_treat}]]")
                output.append(f"Model: P(treatment = 1 | previous treatment = {prev_treat}) for denominator")
                
                X = model['X']
                if isinstance(X, pd.DataFrame):
                    X = X.values
                y = model['y']
                if isinstance(y, pd.Series):
                    y = y.values
                coef = np.concatenate([[model['intercept']], model['coefficients']])
                
                # Handle 1D arrays
                if X.ndim == 1:
                    X = X.reshape(-1, 1)
                
                pred = 1 / (1 + np.exp(-X @ coef))
                V = np.diag(pred * (1 - pred))
                
                try:
                    cov = np.linalg.inv(X.T @ V @ X)
                    se = np.sqrt(np.diag(cov))
                    
                    z_stats = coef / se
                    p_values = 2 * (1 - scipy.stats.norm.cdf(np.abs(z_stats)))
                    
                    null_dev = -2 * np.sum(y * np.log(np.mean(y)) + (1-y) * np.log(1-np.mean(y))) if np.mean(y) not in [0, 1] else 0
                    model_dev = -2 * np.sum(y * np.log(pred) + (1-y) * np.log(1-pred))
                    
                    terms = ['(Intercept)', 'age', 'x1', 'x3']
                    output.append("\nterm        estimate    std.error  statistic  p.value")
                    for term, est, std, z, p in zip(terms, coef, se, z_stats, p_values):
                        output.append(f"{term:10} {est:10.8f} {std:10.8f} {z:10.8f} {p:10.8f}")
                    
                    output.append(f"\nnull.deviance: {null_dev:8.4f}")
                    output.append(f"deviance: {model_dev:8.4f}")
                    output.append(f"nobs: {len(y)}")
                except np.linalg.LinAlgError:
                    output.append("Error: Covariance matrix is singular. Cannot compute standard errors.")
        
        return "\n".join(output)

    def set_expansion_options(self, chunk_size=500):
        """Set options for expanding trials"""
        self.chunk_size = chunk_size
        self.expansion_specified = True
        return self
    
    def expand_trials(self):
        """Expand observational data into sequence of trials"""
        if self.trial_data is None:
            raise ValueError("Must prepare data first using prepare_data()")
            
        expanded_data = []
    
        time_points = sorted(self.trial_data['period'].unique())
        
        for t in time_points:
            trial_data = self.trial_data[self.trial_data['period'] >= t].copy()
            trial_data['trial_time'] = trial_data['period'] - t
            trial_data['trial_id'] = t
            expanded_data.append(trial_data)
        
        self.expanded_data = pd.concat(expanded_data, ignore_index=True)
        return self
    
    def load_expanded_data(self, seed=1234, p_control=0.5):
        """Sample from expanded trials data"""
        if not hasattr(self, 'expanded_data') or self.expanded_data is None:
            raise ValueError("Must expand trials first using expand_trials()")
            
        np.random.seed(seed)
        
        control_size = int(len(self.expanded_data) * p_control)
        sampled_indices = np.random.choice(
            len(self.expanded_data), 
            size=control_size, 
            replace=False
        )
        
        self.sampled_data = self.expanded_data.iloc[sampled_indices].copy()
        return self
        
    def fit_msm_model(self):
        """Fit the MSM model and store it in self.msm_model"""
        if self.trial_data is None:
            raise ValueError("Must prepare data first using prepare_data()")
            
        X = self.trial_data[['x1', 'x2', 'x3']]
        y = self.trial_data['treatment']
    
        self.msm_model = LogisticRegression(solver='lbfgs', max_iter=1000)
        self.msm_model.fit(X, y)
        self.outcome_model_specified = True
        return self

    def plot_survival_difference(self):
        """Plot survival curves for different treatment groups"""
        if self.trial_data is None:
            raise ValueError("Must prepare data first using prepare_data()")
            
        # Check if duration column exists
        if 'duration' not in self.trial_data.columns:
            raise ValueError("The trial_data must contain a 'duration' column for survival analysis")
        
        kmf = KaplanMeierFitter()
        plt.figure(figsize=(10, 6))
        
        # Plot for each treatment group
        for treatment in self.trial_data['treatment'].unique():
            mask = self.trial_data['treatment'] == treatment
            
            # Some datasets use 'event' instead of 'censored' with opposite meaning
            if 'event' in self.trial_data.columns:
                event_col = 'event'
            else:
                event_col = 'censored'
                
            kmf.fit(
                durations=self.trial_data.loc[mask, 'duration'], 
                event_observed=self.trial_data.loc[mask, event_col],
                label=f'Treatment {treatment}'
            )
            kmf.plot()
        
        plt.title('Survival Curves by Treatment Group')
        plt.xlabel('Time')
        plt.ylabel('Survival Probability')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        plt.show()
        return self

def print_model_summary(model, X, y):
    """Print a summary of the fitted model in R-like format."""
    if X is None or y is None:
        raise ValueError("X and y must be defined before calling this function.")
    
    # Get coefficients
    coef = model.coef_[0]
    intercept = model.intercept_[0]
    
    # Make sure X is numpy array
    if isinstance(X, pd.DataFrame):
        X = X.values
    if isinstance(y, pd.Series):
        y = y.values
        
    # Calculate predictions
    pred_probs = model.predict_proba(X)[:, 1]
    
    # Calculate standard errors
    X_with_intercept = sm.add_constant(X)
    pred = 1 / (1 + np.exp(-X_with_intercept @ np.concatenate(([intercept], coef))))
    V = np.diag(pred * (1 - pred))
    
    # Calculate covariance matrix
    try:
        cov = np.linalg.inv(X_with_intercept.T @ V @ X_with_intercept)
        se = np.sqrt(np.diag(cov))
        
        # Calculate z-statistics and p-values
        z_stats = np.concatenate(([intercept], coef)) / se
        p_values = 2 * (1 - scipy.stats.norm.cdf(np.abs(z_stats)))
        
        # Create a DataFrame for the summary
        feature_names = ['(Intercept)']
        if hasattr(X, 'columns'):
            feature_names.extend(X.columns)
        else:
            feature_names.extend([f'x{i}' for i in range(1, len(coef) + 1)])
            
        summary_df = pd.DataFrame({
            'term': feature_names[:len(coef) + 1],
            'estimate': np.concatenate(([intercept], coef)),
            'std.error': se,
            'statistic': z_stats,
            'p.value': p_values
        })
        
        # Print the summary
        print("Model Summary:\n")
        print(summary_df.to_string(index=False, float_format='%.4f'))
        
        # Calculate null deviance
        mean_y = np.mean(y)
        if mean_y in [0, 1]:
            null_dev = 0
        else:
            null_dev = -2 * np.sum(y * np.log(mean_y) + (1 - y) * np.log(1 - mean_y))
        
        # Calculate model deviance
        model_dev = -2 * np.sum(y * np.log(pred_probs + 1e-10) + (1 - y) * np.log(1 - pred_probs + 1e-10))
        
        # Calculate log-likelihood
        log_lik = -model_dev / 2
        
        # Print additional model statistics
        print(f"\nnull.deviance: {null_dev:.4f}")
        print(f"deviance: {model_dev:.4f}")
        print(f"df.null: {len(y) - 1}")
        print(f"logLik: {log_lik:.4f}")
        print(f"AIC: {2 * (len(coef) + 1) - 2 * log_lik:.4f}")
        print(f"nobs: {len(y)}")
        
    except np.linalg.LinAlgError:
        print("Error: Covariance matrix is singular. Cannot compute standard errors.")
        
    except Exception as e:
        print(f"Error in calculating model summary: {str(e)}")

# Usage
if __name__ == "__main__":
    # Load data
    try:
        data = pd.read_csv("../data_censored.csv")
    except FileNotFoundError:
        print("Data file not found. Please check the file path.")
        # Create dummy data for demonstration
        import numpy as np
        
        n_patients = 100
        n_periods = 5
        rows = []
        
        for i in range(n_patients):
            age = np.random.normal(50, 10)
            x1 = np.random.normal(0, 1)
            x2 = np.random.normal(0, 1)
            x3 = np.random.normal(0, 1)
            
            for p in range(n_periods):
                treatment = np.random.binomial(1, 0.5)
                censored = np.random.binomial(1, 0.1)
                eligible = 1 if p == 0 else np.random.binomial(1, 0.9)
                
                rows.append({
                    'id': i,
                    'period': p,
                    'age': age,
                    'x1': x1,
                    'x2': x2,
                    'x3': x3,
                    'treatment': treatment,
                    'censored': censored,
                    'eligible': eligible,
                    'duration': p + np.random.exponential(2),
                    'event': np.random.binomial(1, 0.2)
                })
                
        data = pd.DataFrame(rows)
    
    # Set display options for pandas
    pd.set_option('display.max_columns', None)
    pd.set_option('display.width', None)
    pd.set_option('display.max_rows', 6)
    pd.set_option('display.expand_frame_repr', False)
    
    print(data.head(6))
    
    try:
        # Create and run analysis
        trial = TrialEmulation(data)
        
        # Set weight specifications first
        trial.censor_weights_spec = {
            'numerator': 'x2',
            'denominator': 'x2 + x1',
            'pool_models': 'numerator'
        }
        
        trial.switch_weights_spec = {
            'numerator': 'age',
            'denominator': 'age + x1 + x3'
        }
        
        # Run analysis
        (trial.prepare_data()
              .calculate_weights()
              .set_expansion_options(chunk_size=500)
              .expand_trials()
              .load_expanded_data(seed=1234, p_control=0.5)
              .fit_msm_model())
        
        # Display results
        print("\nTrial Sequence Object")
        print(f"Estimand: {trial.estimand}\n")
        print(f"Data:\n - N: {len(trial.trial_data)} observations from {trial.trial_data['id'].nunique()} patients\n")
        print(trial.trial_data.head(2))
        print("--")
        print(trial.trial_data.tail(2))
        print(trial)
        print(trial.show_weight_specs())
        print(trial.show_weight_models())
        
        # Print model summary
        print("\nModel Summary:")
        print_model_summary(trial.msm_model, trial.trial_data[['x1', 'x2', 'x3']], trial.trial_data['treatment'])
    
        # Plot survival difference
        trial.plot_survival_difference()

SyntaxError: incomplete input (873566349.py, line 653)