## Load and Prep Data

In [1]:
import pandas as pd

aidssi = pd.read_csv("/home/azureuser/cloudfiles/code/Users/draetta.edoardo/pymsm/src/pymsm/datasets/aidssi.csv", index_col=0)
aidssi

Unnamed: 0,patnr,time,status,cause,ccr5
1,1,9.106,1,AIDS,WW
2,2,11.039,0,event-free,WM
3,3,2.234,1,AIDS,WW
4,4,9.878,2,SI,WM
5,5,3.819,1,AIDS,WW
...,...,...,...,...,...
325,325,0.112,2,SI,WW
326,326,9.068,1,AIDS,WW
327,327,5.314,2,SI,WW
328,328,10.117,1,AIDS,WW


In [2]:
import numpy as np

def prepare_multistate_data(df):

    transitions = df.copy()

    # Add tstart column
    transitions['tstart'] = 0

    # Add origin state column
    transitions['origin_state'] = 0

    # Rename columns
    transitions.rename(columns={'time': 'tstop', 'status':'target_state'}, inplace=True)

    # Add censoring column
    transitions["status"] = np.where(
        transitions["target_state"] == 0, 0, 1
    )

    # Add dummy target state to censored observations
    transitions.loc[transitions["target_state"] == 0, "target_state"] = 1 

    transitions[["patnr", "tstart", "tstop", "origin_state", "target_state", "status"]]


    # Add counterfactuals for each observation
    counterfactuals = []
    for _, row in transitions.iterrows():
        # For each transition to state k, add counterfactuals to all other possible states
        possible_states = [1, 2]  # Adjust based on your actual states
        current_target = row['target_state']
        
        for state in possible_states:
            if state != current_target:
                # Create counterfactual transition
                counterfactual = row.copy()
                counterfactual['target_state'] = state
                counterfactual['status'] = 0  # Counterfactuals are censored
                counterfactuals.append(counterfactual)

    # Combine original transitions with counterfactuals
    if counterfactuals:
        counterfactuals_df = pd.DataFrame(counterfactuals)
        transitions = pd.concat([transitions, counterfactuals_df], ignore_index=True)

    # Sort by patient number and target state for clarity
    transitions = transitions.sort_values(['patnr', 'target_state']).reset_index(drop=True)

    return transitions[["patnr", "tstart", "tstop", "origin_state", "target_state", "status"]]

transitions = prepare_multistate_data(aidssi)
transitions

Unnamed: 0,patnr,tstart,tstop,origin_state,target_state,status
0,1,0,9.106,0,1,1
1,1,0,9.106,0,2,0
2,2,0,11.039,0,1,0
3,2,0,11.039,0,2,0
4,3,0,2.234,0,1,1
...,...,...,...,...,...,...
653,327,0,5.314,0,2,1
654,328,0,10.117,0,1,1
655,328,0,10.117,0,2,0
656,329,0,2.631,0,1,0


In [3]:
def make_transition_table(df):
    # Select true transitions
    filtered = transitions[transitions['status'] == 1]
    transition_counts = filtered.groupby(['origin_state', 'target_state']).size().unstack(fill_value=0)

    # Fill in missing rows/columns if any states are missing
    states = sorted(set(transitions['origin_state']).union(set(transitions['target_state'])))
    transition_table = transition_counts.reindex(index=states, columns=states, fill_value=0)

    return transition_table

make_transition_table(transitions)

target_state,0,1,2
origin_state,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
0,0,114,108
1,0,0,0
2,0,0,0


In [4]:
def make_censoring_table(df):
    # First, get patients that never have status = 1
    censored_patients = df.groupby('patnr')['status'].max() == 0
    patients_to_keep = censored_patients[censored_patients].index
    filtered_df = df[df['patnr'].isin(patients_to_keep)]
    filtered_df = filtered_df.drop_duplicates(subset=['patnr', 'origin_state'])

    # count censoring from each state
    censoring_table = filtered_df[filtered_df['status'] == 0].groupby('origin_state').size().reset_index(name='censored')
    return censoring_table

make_censoring_table(transitions)

Unnamed: 0,origin_state,censored
0,0,107


## State Diagram Generator

In [5]:
transitions = prepare_multistate_data(aidssi)
transitions

Unnamed: 0,patnr,tstart,tstop,origin_state,target_state,status
0,1,0,9.106,0,1,1
1,1,0,9.106,0,2,0
2,2,0,11.039,0,1,0
3,2,0,11.039,0,2,0
4,3,0,2.234,0,1,1
...,...,...,...,...,...,...
653,327,0,5.314,0,2,1
654,328,0,10.117,0,1,1
655,328,0,10.117,0,2,0
656,329,0,2.631,0,1,0


In [6]:
print(transitions)

     patnr  tstart   tstop  origin_state  target_state  status
0        1       0   9.106             0             1       1
1        1       0   9.106             0             2       0
2        2       0  11.039             0             1       0
3        2       0  11.039             0             2       0
4        3       0   2.234             0             1       1
..     ...     ...     ...           ...           ...     ...
653    327       0   5.314             0             2       1
654    328       0  10.117             0             1       1
655    328       0  10.117             0             2       0
656    329       0   2.631             0             1       0
657    329       0   2.631             0             2       1

[658 rows x 6 columns]


In [7]:
from pymsm.visualization import StateDiagramGenerator

state_labels = {0: "Event-free", 1: "AIDS", 2: "SI"}
terminal_states = [3]

sdg = StateDiagramGenerator(
    state_labels=state_labels,
    terminal_states=terminal_states
)

sdg.load_data(transitions)

In [8]:
sdg.get_transition_matrix()

Unnamed: 0,0,1,2
0,0,114,108
1,0,0,0
2,0,0,0


In [9]:
sdg.plot_state_diagram()