## Load and Preprocess

In [None]:
import pandas as pd

# Load the aidssi example dataset and inspect its columns before preprocessing.
# index_col=0 makes the CSV's first column the row index instead of a data column.
# NOTE(review): the relative path assumes the notebook's working directory is a
# sibling of ../src — confirm before running from elsewhere.
aidssi = pd.read_csv("../src/skms/datasets/aidssi.csv", index_col=0)
print("Original columns:", aidssi.columns.tolist())
# `display` is the IPython/Jupyter rich-output helper, not a Python builtin.
display(aidssi.head())

Original columns: ['patnr', 'time', 'status', 'cause', 'ccr5']


Unnamed: 0,patnr,time,status,cause,ccr5
1,1,9.106,1,AIDS,WW
2,2,11.039,0,event-free,WM
3,3,2.234,1,AIDS,WW
4,4,9.878,2,SI,WM
5,5,3.819,1,AIDS,WW


In [4]:
import numpy as np
import pandas as pd


def prepare_multistate_data(df, covariate_cols=None, possible_states=(1, 2)):
    """
    Prepare multistate data with counterfactuals, preserving covariates.

    Every input row becomes a transition from origin state 0 (starting at
    time 0) to its observed target state. For each transition, one censored
    "counterfactual" row is added for every other state in
    ``possible_states``, so each subject is at risk for every target state.

    Parameters:
    -----------
    df : DataFrame
        Input data with columns: patnr, time, status, and any covariates.
        ``status`` 0 means censored; any other value is the target state
        that was reached. ``df`` itself is not modified.
    covariate_cols : sequence of str, optional
        Covariate column names to preserve. If None, every column other than
        patnr, time and status is treated as a covariate.
    possible_states : iterable of int, optional
        Target states to generate counterfactual rows for. Defaults to
        ``(1, 2)``.

    Returns:
    --------
    DataFrame
        Columns patnr, tstart, tstop, origin_state, target_state, status,
        followed by the covariate columns, sorted by patnr and target_state
        with a fresh 0..n-1 index. ``status`` is 1 for an observed transition
        and 0 for censored and counterfactual rows.
    """
    transitions = df.copy()

    # Auto-detect covariates if not specified
    if covariate_cols is None:
        standard_cols = ["patnr", "time", "status"]
        covariate_cols = [col for col in df.columns if col not in standard_cols]

    # Every transition starts in state 0 at time 0.
    transitions["tstart"] = 0
    transitions["origin_state"] = 0
    transitions = transitions.rename(
        columns={"time": "tstop", "status": "target_state"}
    )
    # Event indicator: 0 for censored rows, 1 for observed transitions.
    transitions["status"] = np.where(transitions["target_state"] == 0, 0, 1)
    # Censored rows still need a target state; state 1 is a placeholder
    # (status stays 0, so no event is recorded for it).
    transitions.loc[transitions["target_state"] == 0, "target_state"] = 1

    # Build counterfactuals one state at a time instead of row by row: every
    # row whose target differs from `state` gets a censored copy pointing at
    # `state`.
    counterfactual_frames = []
    for state in possible_states:
        counterfactual = transitions.loc[transitions["target_state"] != state].copy()
        if counterfactual.empty:
            continue
        counterfactual["target_state"] = state
        counterfactual["status"] = 0  # Counterfactuals are censored
        counterfactual_frames.append(counterfactual)

    if counterfactual_frames:
        transitions = pd.concat(
            [transitions, *counterfactual_frames], ignore_index=True
        )

    # Sort by patient number and target state for clarity
    transitions = transitions.sort_values(["patnr", "target_state"]).reset_index(
        drop=True
    )

    # Return all columns including covariates
    base_cols = ["patnr", "tstart", "tstop", "origin_state", "target_state", "status"]
    return transitions[base_cols + list(covariate_cols)]

# Prepare data with covariates.
# covariate_cols is passed explicitly: auto-detection would also keep "cause",
# which (per the head() output above) appears to restate the outcome already
# encoded in `status`. The same list is reused when converting formats below.
covariate_cols = ["ccr5"]
transitions = prepare_multistate_data(aidssi, covariate_cols=covariate_cols)
print("Transitions data shape:", transitions.shape)
print("Transitions columns:", transitions.columns.tolist())
display(transitions.head())

Transitions data shape: (658, 7)
Transitions columns: ['patnr', 'tstart', 'tstop', 'origin_state', 'target_state', 'status', 'ccr5']


Unnamed: 0,patnr,tstart,tstop,origin_state,target_state,status,ccr5
0,1,0,9.106,0,1,1,WW
1,1,0,9.106,0,2,0,WW
2,2,0,11.039,0,1,0,WM
3,2,0,11.039,0,2,0,WM
4,3,0,2.234,0,1,1,WW


In [5]:
def counterfactual_to_competing_risks_format(df, covariate_cols=None):
    """
    Convert counterfactual data to competing risks format with covariates.

    Keeps every observed transition (``status == 1``) and adds one censored
    row (``event == 0``) for each patient that has no observed transition,
    taken from that patient's first row in ``df``.

    Parameters:
    -----------
    df : DataFrame
        Counterfactual format data with columns patnr, tstart, tstop,
        origin_state, target_state, status, plus any covariates.
    covariate_cols : sequence of str, optional
        Covariate column names to preserve. If None, every column not listed
        above is treated as a covariate.

    Returns:
    --------
    DataFrame
        Columns id, duration, event followed by the covariates, with a fresh
        0..n-1 index. Observed events come first (in input order), followed
        by censored patients in order of first appearance.
    """
    # Auto-detect covariates if not specified
    if covariate_cols is None:
        standard_cols = [
            "patnr",
            "tstart",
            "tstop",
            "origin_state",
            "target_state",
            "status",
        ]
        covariate_cols = [col for col in df.columns if col not in standard_cols]

    # Get only the rows where an event actually occurred
    actual_events = df[df["status"] == 1]

    # Patients with no observed event at all are censored; keep only their
    # first row. Vectorised membership avoids a full scan per patient, and
    # drop_duplicates preserves input order (set iteration order is
    # hash-dependent, which made the output order unstable for string IDs).
    has_event = df["patnr"].isin(actual_events["patnr"].unique())
    censored = df[~has_event].drop_duplicates(subset="patnr", keep="first").copy()
    censored["status"] = 0
    censored["target_state"] = 0  # 0 marks censoring in the event column

    # Skip empty frames so concat does not have to reconcile empty dtypes;
    # always produce a fresh 0..n-1 index.
    pieces = [frame for frame in (actual_events, censored) if not frame.empty]
    if pieces:
        combined = pd.concat(pieces, ignore_index=True)
    else:
        combined = actual_events.reset_index(drop=True)

    # Rename columns to match expected format
    standard_df = combined.rename(
        columns={"patnr": "id", "tstop": "duration", "target_state": "event"}
    )

    # Keep necessary columns including covariates
    base_cols = ["id", "duration", "event"]
    return standard_df[base_cols + list(covariate_cols)]


# Convert to standard format preserving covariates.
# Result: one row per observed event (status == 1) plus one censored row
# (event 0) per patient with no observed event; reuses `covariate_cols` from
# the preparation step so both tables carry the same covariates.
standard_df = counterfactual_to_competing_risks_format(
    transitions, covariate_cols=covariate_cols
)
print("Standard format shape:", standard_df.shape)
print("Standard format columns:", standard_df.columns.tolist())
display(standard_df.head())

Standard format shape: (329, 4)
Standard format columns: ['id', 'duration', 'event', 'ccr5']


Unnamed: 0,id,duration,event,ccr5
0,1,9.106,1,WW
1,3,2.234,1,WW
2,4,9.878,2,WM
3,5,3.819,1,WW
4,6,6.801,1,WW


In [8]:
from skms.visualization import StateDiagramGenerator

# States used by `transitions`: every row starts in state 0 (event-free) and
# moves to 1 (AIDS) or 2 (SI).
state_labels = {0: "Event-free", 1: "AIDS", 2: "SI"}
# origin_state is always 0 in `transitions`, so 1 and 2 have no outgoing
# transitions. The previous value [3] named a state that exists neither in
# `state_labels` nor in the data.
# NOTE(review): assumes `terminal_states` means absorbing/end states — confirm
# against StateDiagramGenerator.
terminal_states = [1, 2]

sdg = StateDiagramGenerator(
    dataset=transitions,
    patient_id="patnr",
    from_state="origin_state",
    to_state="target_state",
    tstart="tstart",
    tstop="tstop",
    status="status",
    state_labels=state_labels,
    terminal_states=terminal_states,
)

sdg.plot_state_diagram()