In [76]:
import os
import tempfile
import pandas as pd
import numpy as np
import statsmodels.api as sm
import patsy
import pickle
from typing import Optional, Dict, Union

class TrialSequence:
    """
    Hold the data, column mapping and weight models for one emulated-trial sequence.

    Typical use in this file: ``create_directory`` -> ``set_data`` ->
    ``set_switch_weight_model`` (rejected for the ITT estimand) ->
    ``set_censor_weight_model``.
    """

    # Accepted values for the ``pool_models`` argument of ``set_censor_weight_model``.
    _POOL_OPTIONS = ("none", "numerator", "both")

    def __init__(self, estimand: str):
        """
        Initialize a TrialSequence object with the specified estimand.

        :param estimand: The type of estimand (e.g., "ITT", "PP").
        """
        self.estimand = estimand
        self.directory: Optional[str] = None
        self.data: Optional[pd.DataFrame] = None
        self.id_col: Optional[str] = None
        self.period_col: Optional[str] = None
        self.treatment_col: Optional[str] = None
        self.outcome_col: Optional[str] = None
        self.eligible_col: Optional[str] = None
        # Filled by set_switch_weight_model / set_censor_weight_model.
        self.switch_weights: Optional[Dict] = None
        self.censor_weights: Optional[Dict] = None

    def create_directory(self, base_dir: Optional[str] = None) -> None:
        """
        Create a directory for the trial sequence.

        The directory is named ``trial_<estimand in lower case>`` and is
        created if it does not already exist.

        :param base_dir: The base directory where the trial directory will be created.
                         If not provided, the system's temporary directory is used.
        """
        if base_dir is None:
            base_dir = tempfile.gettempdir()

        dir_name = f"trial_{self.estimand.lower()}"
        self.directory = os.path.join(base_dir, dir_name)
        os.makedirs(self.directory, exist_ok=True)
        print(f"Directory created: {self.directory}")

    def set_data(self, data: pd.DataFrame, id_col: str, period_col: str, treatment_col: str, outcome_col: str, eligible_col: str) -> None:
        """
        Assign data and column mappings to the trial sequence.

        :param data: The DataFrame containing the trial data.
        :param id_col: The column name for the ID.
        :param period_col: The column name for the period.
        :param treatment_col: The column name for the treatment.
        :param outcome_col: The column name for the outcome.
        :param eligible_col: The column name for eligibility.
        :raises ValueError: If any named column is missing from ``data``.
                            No attribute is modified in that case.
        """
        # Fail fast on a misspelled column instead of deep inside a patsy/statsmodels fit.
        missing = [
            col
            for col in (id_col, period_col, treatment_col, outcome_col, eligible_col)
            if col not in data.columns
        ]
        if missing:
            raise ValueError(f"Columns not found in data: {missing}")

        self.data = data
        self.id_col = id_col
        self.period_col = period_col
        self.treatment_col = treatment_col
        self.outcome_col = outcome_col
        self.eligible_col = eligible_col
        print(f"Data and columns assigned for {self.estimand} trial.")

    def _fit_logit_model(
        self,
        data: pd.DataFrame,
        formula: str,
        label: str,
        save_path: Optional[str] = None,
    ) -> "statsmodels.discrete.discrete_model.BinaryResultsWrapper":
        """
        Fit a logistic regression model using statsmodels.

        :param data: The DataFrame containing the data.
        :param formula: A patsy formula of the form ``"response ~ predictors"``.
        :param label: Name stem for the saved file (``<label>_model.pkl``).
        :param save_path: If given, directory (created if needed) where the
                          fitted results are pickled.
        :return: The fitted results object returned by ``sm.Logit(...).fit()``
                 (not the unfitted ``sm.Logit`` model).
        """
        y, X = patsy.dmatrices(formula, data, return_type="dataframe")
        model = sm.Logit(y, X).fit()

        if save_path:
            os.makedirs(save_path, exist_ok=True)
            model_file = os.path.join(save_path, f"{label}_model.pkl")
            with open(model_file, "wb") as f:
                pickle.dump(model, f)
            print(f"Model saved to {model_file}")

        return model

    def _fit_per_treatment(self, data: pd.DataFrame, formula: str, label_prefix: str, save_path: Optional[str]) -> Dict:
        """Fit one logit model per observed treatment level and return ``{level: fitted results}``."""
        # dropna(): comparing against NaN selects no rows, which would make the fit crash.
        levels = data[self.treatment_col].dropna().unique()
        return {
            level: self._fit_logit_model(
                data=data[data[self.treatment_col] == level],
                formula=formula,
                label=f"{label_prefix}_treatment_{level}",
                save_path=save_path,
            )
            for level in levels
        }

    def set_switch_weight_model(self, numerator_formula: str, denominator_formula: str, save_path: Optional[str] = None) -> None:
        """
        Set up the switch weight models for censoring due to treatment switching.

        Both models use the treatment column as the response.

        :param numerator_formula: Right-hand side of the numerator model formula.
        :param denominator_formula: Right-hand side of the denominator model formula.
        :param save_path: The path where the models should be saved.
        :raises ValueError: If the estimand is "ITT" or no data has been set.
        """
        if self.estimand == "ITT":
            raise ValueError("Switch weight models cannot be used with ITT estimand.")

        if self.data is None:
            raise ValueError("Data must be set before fitting models.")

        print("\n3.1 Censoring due to treatment switching")
        print("We specify model formulas to be used for calculating the probability of receiving treatment in the current period.")
        print("Separate models are fitted for patients who had treatment = 1 and those who had treatment = 0 in the previous period.")
        print("Stabilized weights are used by fitting numerator and denominator models.\n")

        # NOTE(review): despite the message above, each model below is fitted on
        # ALL rows, not stratified by previous-period treatment. Stratifying needs
        # a lagged treatment column per id; confirm the intended design before changing.
        num_formula = f"{self.treatment_col} ~ {numerator_formula}"
        den_formula = f"{self.treatment_col} ~ {denominator_formula}"

        numerator_model = self._fit_logit_model(
            data=self.data,
            formula=num_formula,
            label="numerator",
            save_path=save_path,
        )

        denominator_model = self._fit_logit_model(
            data=self.data,
            formula=den_formula,
            label="denominator",
            save_path=save_path,
        )

        self.switch_weights = {
            "numerator": numerator_model,
            "denominator": denominator_model,
            "numerator_formula": num_formula,
            "denominator_formula": den_formula,
        }
        print("Switch weight models fitted.")

    def set_censor_weight_model(self, censor_event: str, numerator_formula: str, denominator_formula: str, pool_models: str = "none", save_path: Optional[str] = None) -> None:
        """
        Fit the models used for weights against other informative censoring.

        :param censor_event: Column used as the response of both models.
        :param numerator_formula: Right-hand side of the numerator model formula.
        :param denominator_formula: Right-hand side of the denominator model formula.
        :param pool_models: "none" fits both models separately per treatment level;
                            "numerator" pools only the numerator model across levels;
                            "both" pools the numerator and denominator models.
        :param save_path: Directory where fitted models are pickled, if given.
        :raises ValueError: If no data has been set or ``pool_models`` is not
                            one of "none", "numerator", "both".
        """
        if self.data is None:
            raise ValueError("Data must be set before fitting models.")
        if pool_models not in self._POOL_OPTIONS:
            # Previously any unknown value silently behaved like "none".
            raise ValueError(
                f"pool_models must be one of {self._POOL_OPTIONS}, got {pool_models!r}."
            )
        print("\n3.2 Other informative censoring")

        data = self.data
        num_formula = f"{censor_event} ~ {numerator_formula}"
        den_formula = f"{censor_event} ~ {denominator_formula}"

        if pool_models in ("numerator", "both"):
            numerator_model = self._fit_logit_model(
                data=data,
                formula=num_formula,
                label="numerator_pooled",
                save_path=save_path,
            )
        else:
            numerator_model = self._fit_per_treatment(data, num_formula, "numerator", save_path)

        if pool_models == "both":
            denominator_model = self._fit_logit_model(
                data=data,
                formula=den_formula,
                label="denominator_pooled",
                save_path=save_path,
            )
        else:
            denominator_model = self._fit_per_treatment(data, den_formula, "denominator", save_path)

        self.censor_weights = {
            "numerator": numerator_model,
            "denominator": denominator_model,
            "numerator_formula": num_formula,
            "denominator_formula": den_formula,
            "pool_models": pool_models,
        }
        print("Censor weight models fitted.")

    @staticmethod
    def _models_fitted(weights: Dict) -> bool:
        """Return True when both the 'numerator' and 'denominator' entries hold fitted models."""

        def is_fitted(entry) -> bool:
            if entry is None:
                return False
            if isinstance(entry, dict):
                # Per-treatment models: need at least one level and no missing fit.
                return bool(entry) and all(m is not None for m in entry.values())
            return True

        return all(is_fitted(weights.get(key)) for key in ("numerator", "denominator"))

    def __repr__(self) -> str:
        """
        String representation of the TrialSequence object.
        """
        switch_weight_info = ""
        if self.switch_weights:
            models_fitted = self._models_fitted(self.switch_weights)
            switch_weight_info = (
                f"\n - Numerator formula: {self.switch_weights['numerator_formula']} \n"
                f" - Denominator formula: {self.switch_weights['denominator_formula']} \n"
                f" - Model fitter type: te_stats_glm_logit \n"
                f" - Weight models {'fitted' if models_fitted else 'not fitted. Use calculate_weights()'}"
            )

        censor_weight_info = ""
        if self.censor_weights:
            models_fitted = self._models_fitted(self.censor_weights)
            censor_weight_info = (
                f"\n - Numerator formula: {self.censor_weights['numerator_formula']} \n"
                f" - Denominator formula: {self.censor_weights['denominator_formula']} \n"
                f" - Model fitter type: te_stats_glm_logit \n"
                f" - Weight models {'fitted' if models_fitted else 'not fitted. Use calculate_weights()'}"
            )
            pool = self.censor_weights["pool_models"]
            if pool == "numerator":
                censor_weight_info += "\n - Numerator model is pooled across treatment arms. Denominator model is not pooled."
            elif pool == "both":
                censor_weight_info += "\n - Numerator and denominator models are pooled across treatment arms."

        return (
            f"TrialSequence(estimand={self.estimand}, directory={self.directory}, "
            f"id_col={self.id_col}, period_col={self.period_col}, "
            f"treatment_col={self.treatment_col}, outcome_col={self.outcome_col}, "
            f"eligible_col={self.eligible_col})" + switch_weight_info + censor_weight_info
        )


# STEP 1 ===============================================
# Build one TrialSequence per estimand and give each its own working directory.
print('\n======================================== STEP 1 ========================================\n')
trial_pp = TrialSequence(estimand="PP")  # Per-protocol
trial_itt = TrialSequence(estimand="ITT")  # Intention-to-treat
_trials = (trial_pp, trial_itt)

for _trial in _trials:
    _trial.create_directory()

# Echo both objects for a quick visual check.
for _trial in _trials:
    print(_trial)


# STEP 2 ===============================================
# Load the source data and show its first rows.
print('\n======================================== STEP 2 ========================================\n')
data_censored = pd.read_csv("data_censored.csv")
print(data_censored.head())

# Both trials share the same DataFrame and the same column mapping.
_column_map = {
    "id_col": "id",
    "period_col": "period",
    "treatment_col": "treatment",
    "outcome_col": "outcome",
    "eligible_col": "eligible",
}
for _trial in _trials:
    _trial.set_data(data=data_censored, **_column_map)

print(trial_itt.data)

# STEP 3 ===============================================
print('\n======================================== STEP 3.1 ========================================\n')
# Censoring due to treatment switching; only the PP trial gets these models
# (set_switch_weight_model rejects the ITT estimand).
# Formulas below give the right-hand side (predictors) only.
trial_pp.set_switch_weight_model(
    numerator_formula="age",
    denominator_formula="age + x1 + x3",
    save_path=trial_pp.directory,
)

print('\n======================================== STEP 3.2 ========================================\n')
# Other informative censoring: PP fits the numerator per treatment arm,
# ITT pools the numerator across arms.
for _trial, _pooling in ((trial_pp, "none"), (trial_itt, "numerator")):
    _trial.set_censor_weight_model(
        censor_event="censored",
        numerator_formula="x2",
        denominator_formula="x2 + x1",
        pool_models=_pooling,
        save_path=_trial.directory,
    )

# Show the final state of both trials.
for _trial in _trials:
    print(_trial)



Directory created: C:\Users\Katrina\AppData\Local\Temp\trial_pp
Directory created: C:\Users\Katrina\AppData\Local\Temp\trial_itt
TrialSequence(estimand=PP, directory=C:\Users\Katrina\AppData\Local\Temp\trial_pp, id_col=None, period_col=None, treatment_col=None, outcome_col=None, eligible_col=None)
TrialSequence(estimand=ITT, directory=C:\Users\Katrina\AppData\Local\Temp\trial_itt, id_col=None, period_col=None, treatment_col=None, outcome_col=None, eligible_col=None)


   id  period  treatment  x1        x2  x3        x4  age     age_s  outcome  \
0   1       0          1   1  1.146148   0  0.734203   36  0.083333        0   
1   1       1          1   1  0.002200   0  0.734203   37  0.166667        0   
2   1       2          1   0 -0.481762   0  0.734203   38  0.250000        0   
3   1       3          1   0  0.007872   0  0.734203   39  0.333333        0   
4   1       4          1   1  0.216054   0  0.734203   40  0.416667        0   

   censored  eligible  
0         0         