In [8]:
import os
import tempfile
import pandas as pd
import numpy as np
import statsmodels.api as sm
import patsy
import pickle
from typing import Optional, Dict, Union
import statsmodels.formula.api as smf 

class TrialSequence:
    """Hold the data, weight models and outcome model for one trial emulation.

    An instance is labelled with an estimand string (the driver code uses
    "PP" and "ITT").  Data and column names are attached with ``set_data``;
    logistic weight models are fitted with ``set_switch_weight_model`` /
    ``set_censor_weight_model``; ``calculate_weights`` turns them into weight
    columns; ``expand_trials`` / ``load_expanded_data`` reshape and subsample
    the data; ``fit_msm`` fits the weighted outcome model.
    """

    def __init__(self, estimand: str):
        self.estimand = estimand
        self.directory: Optional[str] = None
        self.data: Optional[pd.DataFrame] = None
        self.id_col: Optional[str] = None
        self.period_col: Optional[str] = None
        self.treatment_col: Optional[str] = None
        self.outcome_col: Optional[str] = None
        self.eligible_col: Optional[str] = None
        self.switch_weights: Optional[Dict] = None
        self.censor_weights: Optional[Dict] = None
        # Declared here so the attributes exist (as None) before
        # set_outcome_model / fit_msm / set_expansion_options are called.
        self.outcome_model = None
        self.chunk_size: Optional[int] = None

    def create_directory(self, base_dir: Optional[str] = None) -> None:
        """Create ``<base_dir>/trial_<estimand lower-cased>`` and remember it.

        Args:
            base_dir: Parent directory; defaults to the system temp directory.
        """
        if base_dir is None:
            base_dir = tempfile.gettempdir()
        dir_name = f"trial_{self.estimand.lower()}"
        self.directory = os.path.join(base_dir, dir_name)
        os.makedirs(self.directory, exist_ok=True)
        print(f"Directory created: {self.directory}")

    def set_data(self, data: pd.DataFrame, id_col: str, period_col: str, treatment_col: str, outcome_col: str, eligible_col: str) -> None:
        """Attach a data set and record which columns play which role.

        The frame is copied: later methods add columns to ``self.data``, and
        without a copy two trials built from the same frame would overwrite
        each other's weight columns (and mutate the caller's frame).
        """
        self.data = data.copy()
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col
        print(f"Data and columns assigned for {self.estimand} trial.")

    def _fit_logit_model(self, data: pd.DataFrame, formula: str, label: str, save_path: Optional[str] = None):
        """Fit a logistic regression from a patsy formula and optionally pickle it.

        Args:
            data: Frame the formula is evaluated against.
            formula: Full patsy formula (``lhs ~ rhs``).
            label: Prefix of the pickle file name (``<label>_model.pkl``).
            save_path: Directory to write the pickle into; nothing is saved
                when it is None or empty.

        Returns:
            The object returned by ``sm.Logit(...).fit()`` (a fitted results
            object, not an ``sm.Logit`` model instance).
        """
        y, X = patsy.dmatrices(formula, data, return_type="dataframe")
        model = sm.Logit(y, X).fit(disp=0)  # disp=0 to reduce verbosity
        if save_path:
            os.makedirs(save_path, exist_ok=True)
            model_file = os.path.join(save_path, f"{label}_model.pkl")
            with open(model_file, "wb") as f:
                pickle.dump(model, f)
            print(f"Model saved to {model_file}")
        return model

    @staticmethod
    def _predict_from_formula(model, formula: str, data: pd.DataFrame):
        """Build the design matrix for the right-hand side of ``formula`` and predict."""
        rhs = formula.split("~", 1)[1]
        design = patsy.dmatrix(rhs, data, return_type="dataframe")
        return model.predict(design)

    def set_switch_weight_model(self, numerator_formula: str, denominator_formula: str, save_path: Optional[str] = None) -> None:
        """Fit numerator and denominator logistic models of the treatment column.

        Args:
            numerator_formula: Right-hand side of the numerator model.
            denominator_formula: Right-hand side of the denominator model.
            save_path: Optional directory in which both models are pickled.

        Raises:
            ValueError: If the estimand is ITT (compared case-insensitively,
                matching how ``create_directory`` normalises the estimand) or
                if no data has been set.
        """
        if self.estimand.upper() == "ITT":
            raise ValueError("Switch weight models cannot be used with ITT estimand.")
        if self.data is None:
            raise ValueError("Data must be set before fitting models.")
        print("\n3.1 Censoring due to treatment switching")
        full_numerator = f"{self.treatment_col} ~ {numerator_formula}"
        full_denominator = f"{self.treatment_col} ~ {denominator_formula}"
        numerator_model = self._fit_logit_model(self.data, full_numerator, "numerator", save_path)
        denominator_model = self._fit_logit_model(self.data, full_denominator, "denominator", save_path)
        self.switch_weights = {"numerator": numerator_model, "denominator": denominator_model,
                               "numerator_formula": full_numerator,
                               "denominator_formula": full_denominator}
        print("Switch weight models fitted.")

    def set_censor_weight_model(self, censor_event: str, numerator_formula: str, denominator_formula: str, pool_models: str = "none", save_path: Optional[str] = None) -> None:
        """Fit logistic models of ``censor_event`` for the censoring weights.

        The denominator is always fitted separately for each value of the
        treatment column.  The numerator is fitted once on all rows when
        ``pool_models == "numerator"`` and per treatment value otherwise.

        Args:
            censor_event: Left-hand side of both formulas.
            numerator_formula: Right-hand side of the numerator model(s).
            denominator_formula: Right-hand side of the denominator models.
            pool_models: ``"numerator"`` pools the numerator; any other value
                fits it per treatment value.
            save_path: Optional directory in which the models are pickled.

        Raises:
            ValueError: If no data has been set.
        """
        if self.data is None:
            raise ValueError("Data must be set before fitting models.")
        print("\n3.2 Other informative censoring")
        data = self.data
        full_numerator = f"{censor_event} ~ {numerator_formula}"
        full_denominator = f"{censor_event} ~ {denominator_formula}"
        treatments = data[self.treatment_col].unique()
        if pool_models == "numerator":
            numerator_model = self._fit_logit_model(data, full_numerator, "numerator_pooled", save_path)
        else:
            numerator_model = {
                treatment: self._fit_logit_model(data[data[self.treatment_col] == treatment],
                                                 full_numerator,
                                                 f"numerator_treatment_{treatment}", save_path)
                for treatment in treatments
            }
        denominator_model = {
            treatment: self._fit_logit_model(data[data[self.treatment_col] == treatment],
                                             full_denominator,
                                             f"denominator_treatment_{treatment}", save_path)
            for treatment in treatments
        }
        self.censor_weights = {"numerator": numerator_model, "denominator": denominator_model,
                               "numerator_formula": full_numerator,
                               "denominator_formula": full_denominator,
                               "pool_models": pool_models}
        print("Censor weight models fitted.")

    def calculate_weights(self) -> None:
        """Add ``switch_weight`` and/or ``censor_weight`` columns to ``self.data``.

        Each weight is the numerator model's prediction divided by the
        denominator model's prediction.  Censor weights are computed one
        treatment value at a time so every row is scored by the denominator
        model fitted for its own treatment value; a pooled numerator model is
        shared by all treatment values.  Rows whose treatment value has no
        denominator model keep a censor weight of 0.

        Raises:
            ValueError: If no data has been set.
        """
        # NOTE(review): the weights are ratios of the predicted probability
        # that the modelled column equals 1.  Stabilised inverse-probability
        # weights are usually built from the probability of the *observed*
        # treatment / of remaining uncensored -- confirm this is intended.
        if self.data is None:
            raise ValueError("Data must be set before calculating weights")
        if self.switch_weights:
            print("\nCalculating weights for treatment switching models...")
            num_pred = self._predict_from_formula(
                self.switch_weights["numerator"], self.switch_weights["numerator_formula"], self.data)
            denom_pred = self._predict_from_formula(
                self.switch_weights["denominator"], self.switch_weights["denominator_formula"], self.data)
            self.data["switch_weight"] = num_pred / denom_pred
            print("Weights for treatment switching models calculated and stored as 'switch_weight'.")
        if self.censor_weights:
            print("\nCalculating weights for censor models...")
            numerator = self.censor_weights["numerator"]
            numerator_formula = self.censor_weights["numerator_formula"]
            denominator_formula = self.censor_weights["denominator_formula"]
            censor_weights = np.zeros(len(self.data))
            for treatment, denominator_model in self.censor_weights["denominator"].items():
                mask = (self.data[self.treatment_col] == treatment).to_numpy()
                if not mask.any():
                    # No rows left for this treatment value (e.g. after sampling).
                    continue
                subset = self.data[mask]
                numerator_model = numerator[treatment] if isinstance(numerator, dict) else numerator
                num_pred = self._predict_from_formula(numerator_model, numerator_formula, subset)
                denom_pred = self._predict_from_formula(denominator_model, denominator_formula, subset)
                censor_weights[mask] = np.asarray(num_pred / denom_pred)
            self.data["censor_weight"] = censor_weights
            print("Weights for censor models calculated and stored as 'censor_weight'.")

    def show_weight_models(self) -> None:
        """Print the summary of every fitted switch and censor weight model."""
        if self.switch_weights:
            print("\nWeight Models for Treatment Switching")
            print("--------------------------------------")
            for key in ["numerator", "denominator"]:
                print(f"\nModel: {key.capitalize()} model")
                print(self.switch_weights[key].summary())
        if self.censor_weights:
            print("\nWeight Models for Informative Censoring")
            print("--------------------------------------")
            for key in ["numerator", "denominator"]:
                print(f"\nModel: {key.capitalize()} model")
                fitted = self.censor_weights[key]
                if isinstance(fitted, dict):
                    # Per-treatment models: one summary per treatment value.
                    for treatment, sub_model in fitted.items():
                        print(f"\nTreatment: {treatment}")
                        print(sub_model.summary())
                else:
                    print(fitted.summary())

    def set_outcome_model(self, adjustment_terms: Optional[str] = None) -> None:
        """Fit a Gaussian GLM of the outcome on treatment plus optional terms.

        Args:
            adjustment_terms: Extra right-hand-side terms appended to the
                formula with ``+``; omitted when None or empty.

        Raises:
            ValueError: If no data has been set.
        """
        if self.data is None:
            raise ValueError("Data must be set before specifying the outcome model.")
        formula = f"{self.outcome_col} ~ {self.treatment_col}" if not adjustment_terms else f"{self.outcome_col} ~ {self.treatment_col} + {adjustment_terms}"
        y, X = patsy.dmatrices(formula, self.data, return_type="dataframe")
        self.outcome_model = sm.GLM(y, X, family=sm.families.Gaussian()).fit()
        print(f"Outcome model specified with formula: {formula}")

    def __repr__(self) -> str:
        """Describe the estimand, directory, column roles and any weight models."""
        switch_weight_info = ""
        if self.switch_weights:
            models_fitted = all(self.switch_weights.get(k) for k in ["numerator", "denominator"])
            switch_weight_info = (f"\n - Numerator formula: {self.switch_weights['numerator_formula']} \n"
                                  f" - Denominator formula: {self.switch_weights['denominator_formula']} \n"
                                  f" - Model fitter type: te_stats_glm_logit \n"
                                  f" - Weight models {'fitted' if models_fitted else 'not fitted. Use calculate_weights()'}")
        censor_weight_info = ""
        if self.censor_weights:
            models_fitted = all(self.censor_weights.get(k) for k in ["numerator", "denominator"])
            censor_weight_info = (f"\n - Numerator formula: {self.censor_weights['numerator_formula']} \n"
                                  f" - Denominator formula: {self.censor_weights['denominator_formula']} \n"
                                  f" - Model fitter type: te_stats_glm_logit \n"
                                  f" - Weight models {'fitted' if models_fitted else 'not fitted. Use calculate_weights()'}"
                                  f"\n - Numerator pooling: {self.censor_weights['pool_models']}")
        return (f"TrialSequence(estimand={self.estimand}, directory={self.directory}, "
                f"id_col={self.id_col}, period_col={self.period_col}, "
                f"treatment_col={self.treatment_col}, outcome_col={self.outcome_col}, "
                f"eligible_col={self.eligible_col})" + switch_weight_info + censor_weight_info)

    def set_expansion_options(self, chunk_size: int = 500):
        """Store ``chunk_size``; no other method in this class reads it."""
        self.chunk_size = chunk_size
        print(f"Expansion options set: chunk size = {self.chunk_size}")

    def expand_trials(self, max_period=10):
        """Repeat every row once per period ``0..max_period`` and return the result.

        Each copy of a row gets ``trial_period`` and ``followup_time`` both
        set to the period number; copies of one row stay adjacent, in period
        order.  The expansion is vectorised, so column dtypes are kept, and
        the result carries a fresh ``0..n-1`` index instead of repeated labels.

        Args:
            max_period: Largest period to generate; a negative value yields
                an empty frame.

        Raises:
            ValueError: If no data has been set.
        """
        if self.data is None:
            raise ValueError("Data must be set before expanding trials.")
        n_periods = max(int(max_period) + 1, 0)
        n_rows = len(self.data)
        row_positions = np.repeat(np.arange(n_rows), n_periods)
        periods = np.tile(np.arange(n_periods), n_rows)
        expanded = self.data.iloc[row_positions].reset_index(drop=True)
        expanded["trial_period"] = periods
        expanded["followup_time"] = periods
        self.data = expanded
        print("Trials expanded.")
        return self.data

    def load_expanded_data(self, seed: int = 1234, p_control: float = 1.0):
        """Optionally subsample the rows whose outcome is not 1, then return the data.

        When ``p_control < 1.0`` every row with outcome equal to 1 is kept and
        each remaining row is kept when a uniform draw falls below
        ``p_control``.  The draws come from a local ``RandomState(seed)``,
        which produces the same sequence as seeding the global generator but
        leaves the global generator's state untouched.

        Raises:
            ValueError: If no data has been set.
        """
        if self.data is None:
            raise ValueError("Expanded data must exist before loading.")
        if p_control < 1.0:
            rng = np.random.RandomState(seed)
            keep = (self.data[self.outcome_col] == 1) | (rng.rand(len(self.data)) < p_control)
            self.data = self.data[keep].reset_index(drop=True)
        print(f"Loaded expanded data with {len(self.data)} observations.")
        return self.data

    def fit_msm(self, weight_cols, modify_weights=None, formula: Optional[str] = None):
        """Fit a weighted binomial GLM and store it as ``self.outcome_model``.

        Args:
            weight_cols: Names of weight columns to multiply into
                ``final_weight``; empty or None means a weight of 1.0.
            modify_weights: Optional callable applied to the ``final_weight``
                column (for example to truncate large weights).
            formula: Model formula.  Defaults to the outcome column regressed
                on ``assigned_treatment``, ``x2`` and linear plus quadratic
                terms in ``followup_time`` and ``trial_period``; those
                columns must already exist in the data.

        Returns:
            The fitted GLM results object.

        Raises:
            ValueError: If no data has been set.
            KeyError: If any of ``weight_cols`` is missing from the data.
        """
        if self.data is None:
            raise ValueError("Data must be set before fitting MSM.")
        weight_cols = list(weight_cols) if weight_cols else []
        missing_cols = [col for col in weight_cols if col not in self.data.columns]
        if missing_cols:
            raise KeyError(f"Missing required weight columns: {missing_cols}")
        self.data["final_weight"] = self.data[weight_cols].prod(axis=1) if weight_cols else 1.0
        if modify_weights:
            self.data["final_weight"] = modify_weights(self.data["final_weight"])
        if formula is None:
            outcome = self.outcome_col or "outcome"
            formula = (f"{outcome} ~ assigned_treatment + x2 + followup_time + I(followup_time**2) "
                       "+ trial_period + I(trial_period**2)")
        model = smf.glm(formula=formula, data=self.data, family=sm.families.Binomial(), freq_weights=self.data["final_weight"]).fit()
        self.outcome_model = model
        print("Marginal Structural Model fitted.")
        print(model.summary())
        return model

# Main Execution Steps
# Step 1: create one trial object per estimand and a working directory for each.
print('\n======================================== STEP 1 ========================================\n')
trial_pp = TrialSequence(estimand="PP")
trial_itt = TrialSequence(estimand="ITT")
trial_pp.create_directory()
trial_itt.create_directory()
print(trial_pp)
print(trial_itt)

# Step 2: load the data and attach it to both trials.
print('\n======================================== STEP 2 ========================================\n')
data_censored = pd.read_csv("data_censored.csv")
print(data_censored.head())
# Each trial gets its own copy: calculate_weights() adds columns to the trial's
# frame, so sharing one object would let the ITT trial pick up the PP trial's
# 'switch_weight' and overwrite its 'censor_weight'.
trial_pp.set_data(data_censored.copy(), "id", "period", "treatment", "outcome", "eligible")
trial_itt.set_data(data_censored.copy(), "id", "period", "treatment", "outcome", "eligible")
print(trial_itt.data)

# Step 3.1: treatment-switching weight models (per-protocol trial only).
print('\n======================================== STEP 3.1 ========================================\n')
trial_pp.set_switch_weight_model("age", "age + x1 + x3", trial_pp.directory)

# Step 3.2: censoring weight models; the ITT trial pools its numerator model.
print('\n======================================== STEP 3.2 ========================================\n')
trial_pp.set_censor_weight_model("censored", "x2", "x2 + x1", "none", trial_pp.directory)
trial_itt.set_censor_weight_model("censored", "x2", "x2 + x1", "numerator", trial_itt.directory)

# Step 4: turn the fitted models into weight columns and show the model summaries.
print('\n======================================== STEP 4 ========================================\n')
trial_pp.calculate_weights()
trial_itt.calculate_weights()
print('\nWeight Models for Trial PP:')
trial_pp.show_weight_models()
print('\nWeight Models for Trial ITT:')
trial_itt.show_weight_models()

# Step 5: unweighted logistic outcome models on the unexpanded data.
print('\n======================================== STEP 5 ========================================\n')
trial_pp_data = trial_pp.data
trial_itt_data = trial_itt.data
# fit_msm's formula refers to 'assigned_treatment', so the column is created here.
trial_pp_data["assigned_treatment"] = trial_pp_data["treatment"]
trial_itt_data["assigned_treatment"] = trial_itt_data["treatment"]
formula_pp = "outcome ~ assigned_treatment + period + I(period**2)"
outcome_model_pp = smf.logit(formula_pp, data=trial_pp_data).fit()
formula_itt = "outcome ~ assigned_treatment + x2 + period + I(period**2)"
outcome_model_itt = smf.logit(formula_itt, data=trial_itt_data).fit()
print("\nPer-Protocol (PP) Model Summary:")
print(outcome_model_pp.summary())
print("\nIntention-to-Treat (ITT) Model Summary:")
print(outcome_model_itt.summary())

# Step 6: expand every row over periods 0..10.
print('\n======================================== STEP 6 ========================================\n')
trial_pp.set_expansion_options(chunk_size=500)
trial_itt.set_expansion_options(chunk_size=500)
trial_pp_expanded = trial_pp.expand_trials(max_period=10)
trial_itt_expanded = trial_itt.expand_trials(max_period=10)
print("\nExpanded Per-Protocol (PP) Trial Data Sample:")
print(trial_pp_expanded.head())
print("\nExpanded Intention-to-Treat (ITT) Trial Data Sample:")
print(trial_itt_expanded.head())

# Step 7: keep all outcome == 1 rows and a random half of the remaining rows.
print('\n======================================== STEP 7 ========================================\n')
trial_itt_sampled = trial_itt.load_expanded_data(seed=1234, p_control=0.5)
print("\nLoaded Intention-to-Treat (ITT) Trial Data Sample:")
print(trial_itt_sampled.head())

# Step 8: fit the weighted outcome model on the expanded, sampled ITT data.
print('\n======================================== STEP 8 ========================================\n')
# Recalculate weights after data expansion and sampling
trial_itt.calculate_weights()
available_weight_cols = [col for col in ["switch_weight", "censor_weight"] if col in trial_itt.data.columns]
if not available_weight_cols:
    print("Warning: No weight columns available. Fitting MSM without weights.")
    trial_itt_msm = trial_itt.fit_msm(weight_cols=[], modify_weights=None)
else:
    trial_itt_msm = trial_itt.fit_msm(
        weight_cols=available_weight_cols,
        # Cap the final weights at their 99th percentile.
        modify_weights=lambda w: np.minimum(w, np.quantile(w, 0.99))
    )



Directory created: C:\Users\Katrina\AppData\Local\Temp\trial_pp
Directory created: C:\Users\Katrina\AppData\Local\Temp\trial_itt
TrialSequence(estimand=PP, directory=C:\Users\Katrina\AppData\Local\Temp\trial_pp, id_col=None, period_col=None, treatment_col=None, outcome_col=None, eligible_col=None)
TrialSequence(estimand=ITT, directory=C:\Users\Katrina\AppData\Local\Temp\trial_itt, id_col=None, period_col=None, treatment_col=None, outcome_col=None, eligible_col=None)


   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         