In [None]:
import pandas as pd
import numpy as np
from enum import Enum
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, Callable
import statsmodels.formula.api as smf

# Define estimand types as enum
class EstimandType(str, Enum):
    """Closed set of estimand labels accepted for a trial sequence.

    Members are also plain strings (the class mixes in ``str``), so they
    compare equal to their textual value.
    """

    # Intention-to-treat
    ITT = "ITT"
    # Per-protocol
    PP = "PP"
    # As-treated
    AT = "AT"

# Data structures using dataclasses
@dataclass
class DataStore:
    """Minimal store record carrying a single integer counter."""

    # Size counter, zero by default. Presumably the number of stored
    # (expanded) observations -- confirm against the code that fills it.
    N: int = 0

@dataclass
class DataUnset:
    """Field-less placeholder standing in for data that has not been set."""

    def __str__(self) -> str:
        # Text shown when this placeholder is rendered in a report.
        return "Data not set"

@dataclass
class WeightsUnset:
    """Field-less placeholder standing in for weights that were not calculated."""

    def __str__(self) -> str:
        # Text shown when this placeholder is rendered in a report.
        return "Weights not calculated"

@dataclass
class ExpansionUnset:
    """Placeholder used while data expansion has not been performed.

    Each instance gets its own fresh ``DataStore`` (whose ``N`` defaults
    to 0), so code that inspects ``expansion.datastore`` still works.
    """

    datastore: DataStore = field(default_factory=DataStore)

    def __str__(self) -> str:
        # Text shown when this placeholder is rendered in a report.
        return "Data expansion not performed"

@dataclass
class ModelUnset:
    """Field-less placeholder standing in for an outcome model that was not fit."""

    def __str__(self) -> str:
        # Text shown when this placeholder is rendered in a report.
        return "Model not fit"

def _load_default_data():
    """Read the default dataset used when no ``data`` argument is supplied."""
    # NOTE(review): the path is relative, so it resolves against the current
    # working directory each time an instance is built without ``data``.
    return pd.read_csv('data_censored.csv')


@dataclass
class TrialSequenceData:
    """Container bundling every component of one trial sequence.

    Only ``estimand`` is required. Every other field falls back to a
    default: ``data`` is read from ``data_censored.csv``, the weight,
    expansion and model slots start as their "unset" placeholders, and the
    optional slots start as ``None``.
    """

    # Which estimand this sequence targets.
    estimand: EstimandType
    # Input dataset; loaded from disk when not passed explicitly.
    data: Any = field(default_factory=_load_default_data)
    # Censoring-weight component; placeholder until something is assigned.
    censor_weights: Any = field(default_factory=WeightsUnset)
    # Treatment-switch weight component; None means "not present".
    switch_weights: Optional[Any] = None
    # Expansion component; placeholder until something is assigned.
    expansion: Any = field(default_factory=ExpansionUnset)
    # Outcome model component; placeholder until something is assigned.
    outcome_model: Any = field(default_factory=ModelUnset)
    # Outcome data; None means "not present".
    outcome_data: Optional[Any] = None

# Functions for creating and manipulating trial sequence data
# Model fitter --------------------------------------------------------------------------------
def model_fitter(formula, data):
    """Fit a logit model through the statsmodels formula API.

    Args:
        formula: Model formula string handed to ``smf.logit``.
        data: Dataset handed to ``smf.logit`` alongside the formula.

    Returns:
        Whatever ``.fit()`` returns for the constructed logit model.
    """
    unfitted = smf.logit(formula=formula, data=data)
    return unfitted.fit()

def create_trial_sequence(estimand: str, **kwargs) -> TrialSequenceData:
    """Build a ``TrialSequenceData`` for the requested estimand.

    Args:
        estimand: Estimand label; must match one of the ``EstimandType``
            values.
        **kwargs: Extra field values forwarded unchanged to the
            ``TrialSequenceData`` constructor.

    Returns:
        A new ``TrialSequenceData`` whose ``estimand`` is the matching
        ``EstimandType`` member.

    Raises:
        ValueError: If ``estimand`` is not a valid ``EstimandType`` value.
            The enum's own lookup error is attached as ``__cause__``.
    """
    try:
        estimand_type = EstimandType(estimand)
    except ValueError as err:
        valid_values = ", ".join(e.value for e in EstimandType)
        # Chain explicitly so the traceback shows the enum lookup failure as
        # the cause, not as an unrelated error raised during handling.
        raise ValueError(
            f"{estimand} is not a valid estimand type. Must be one of: {valid_values}"
        ) from err

    return TrialSequenceData(estimand=estimand_type, **kwargs)

def format_trial_sequence(seq: TrialSequenceData) -> str:
    """Render a trial sequence as a multi-line, human-readable report.

    The report lists, in order: a header, the estimand, the data, the
    censoring weights, the treatment-switch weights (only when they are not
    ``None``), the expansion (skipped when ``seq.data`` is a ``DataUnset``),
    the outcome model, and finally the outcome data (only when the expansion
    exposes a ``datastore`` whose ``N`` is positive).

    Args:
        seq: The trial sequence to describe.

    Returns:
        The report sections joined with newlines.
    """
    report = [
        "Trial Sequence Object",
        f"Estimand: {seq.estimand.value}",
        "",
        "Data:",
        str(seq.data),
        "",
        "IPW for informative censoring:",
        str(seq.censor_weights),
    ]

    # Switch weights are optional; show the section only when present.
    if seq.switch_weights is not None:
        report += [
            "",
            "IPW for treatment switch censoring:",
            str(seq.switch_weights),
        ]

    # Blank separator is emitted whether or not the expansion section follows.
    report.append("")
    data_is_set = not isinstance(seq.data, DataUnset)
    if data_is_set:
        report += [str(seq.expansion), ""]

    report += ["Outcome model:", str(seq.outcome_model), ""]

    # Outcome data is shown only once the expansion's store is non-empty.
    if hasattr(seq.expansion, "datastore") and seq.expansion.datastore.N > 0:
        report.append(str(seq.outcome_data))

    return "\n".join(report)

def print_trial_sequence(seq: TrialSequenceData) -> None:
    """Write the formatted report for ``seq`` to standard output."""
    rendered = format_trial_sequence(seq)
    print(rendered)

def update_trial_sequence(seq: TrialSequenceData, **kwargs) -> TrialSequenceData:
    """Return a copy of ``seq`` with the given fields overridden.

    The input object is not modified. Field values that are not overridden
    are carried over by reference (a shallow copy).

    Args:
        seq: The trial sequence to copy.
        **kwargs: Field names mapped to their replacement values.

    Returns:
        A new instance of ``seq``'s class with the updated fields.

    Raises:
        TypeError: If a keyword does not name a field of the dataclass.
    """
    # Local import: the file's top-level import block does not bring
    # ``replace`` into scope.
    from dataclasses import replace

    # ``replace`` copies only declared dataclass fields. Rebuilding from
    # ``seq.__dict__`` would fail as soon as the instance carried any extra
    # attribute, and would always construct the base class.
    return replace(seq, **kwargs)

def calculate_weights(seq: TrialSequenceData, weight_func: Callable) -> TrialSequenceData:
    """Compute censoring weights from ``seq.data`` and return an updated sequence.

    Args:
        seq: Trial sequence whose ``data`` is passed to ``weight_func``.
        weight_func: Callable taking the data and returning the weights.

    Returns:
        The result of ``update_trial_sequence`` with ``censor_weights`` set
        to the value returned by ``weight_func``.
    """
    return update_trial_sequence(seq, censor_weights=weight_func(seq.data))

def set_censor_weight_model(trial, censor_event, numerator, denominator, pool_models="none", model_fitter=None):
    """Attach a censoring-weight model specification to ``trial``.

    The specification is a plain dict stored on ``trial.censor_weights``.
    No model is fitted here. ``trial`` is modified in place and also
    returned.

    Args:
        trial: Object that receives the ``censor_weights`` attribute.
        censor_event: Name of the censoring indicator; the formulas use
            ``1 - <censor_event>`` as their left-hand side.
        numerator: Right-hand side of the numerator formula.
        denominator: Right-hand side of the denominator formula.
        pool_models: Pooling option stored verbatim in the specification.
        model_fitter: Callable ``(formula, data)`` stored in the
            specification; defaults to a statsmodels logit fitter.

    Returns:
        The same ``trial`` object that was passed in.
    """
    # Fall back to a statsmodels logit fit when no fitter is supplied.
    fitter = (
        model_fitter
        if model_fitter is not None
        else (lambda formula, data: smf.logit(formula=formula, data=data).fit())
    )

    # Both formulas share the same left-hand side.
    # NOTE(review): formula parsers such as patsy may read ``1 - x`` on the
    # left-hand side as term algebra rather than arithmetic -- confirm whether
    # ``I(1 - x)`` is needed before these strings reach ``smf.logit``.
    response = f"1 - {censor_event}"

    trial.censor_weights = dict(
        numerator=f"{response} ~ {numerator}",
        denominator=f"{response} ~ {denominator}",
        pool_models=pool_models,
        model_fitter=fitter,
    )
    return trial


# Example usage
if __name__ == "__main__":
    # Everything in this demo stays under the guard: the later statements use
    # names created by the earlier ones, so running them at import time would
    # raise NameError. The variables are still module-level when run as a
    # script or notebook cell.

    # Create trial sequences
    trial_pp = create_trial_sequence(estimand="PP")
    trial_itt = create_trial_sequence(estimand="ITT")

    # Display the sequences
    print_trial_sequence(trial_pp)
    print("\n" + "="*50 + "\n")
    print_trial_sequence(trial_itt)

    # Apply the function to set censor weights
    trial_pp = set_censor_weight_model(
        trial=trial_pp,
        censor_event="censored",
        numerator="x2",
        denominator="x2 + x1",
        pool_models="none",
        model_fitter=model_fitter,
    )

    trial_itt = set_censor_weight_model(
        trial=trial_itt,
        censor_event="censored",
        numerator="x2",
        denominator="x2 + x1",
        pool_models="numerator",
        model_fitter=model_fitter,
    )

    # Print the censor weights
    print(trial_pp.censor_weights)

Trial Sequence Object
Estimand: PP

Data:
     id  period  treatment  x1        x2  x3        x4  age     age_s  \
0     1       0          1   1  1.146148   0  0.734203   36  0.083333   
1     1       1          1   1  0.002200   0  0.734203   37  0.166667   
2     1       2          1   0 -0.481762   0  0.734203   38  0.250000   
3     1       3          1   0  0.007872   0  0.734203   39  0.333333   
4     1       4          1   1  0.216054   0  0.734203   40  0.416667   
..   ..     ...        ...  ..       ...  ..       ...  ...       ...   
720  99       3          0   0 -0.747906   1  0.575268   68  2.750000   
721  99       4          0   0 -0.790056   1  0.575268   69  2.833333   
722  99       5          1   1  0.387429   1  0.575268   70  2.916667   
723  99       6          1   1 -0.033762   1  0.575268   71  3.000000   
724  99       7          0   0 -1.340497   1  0.575268   72  3.083333   

     outcome  censored  eligible  
0          0         0         1  
1          