## Load and Preprocess

In [4]:
import pandas as pd

# Never truncate wide frames when they are rendered below.
pd.options.display.max_columns = None

# Raw hospitalization table; the first CSV column is used as the row index.
covid = pd.read_csv("../src/skms/datasets/covid_hosp_data.csv", index_col=0)

# Inspect the available columns before reshaping the data.
column_names = list(covid.columns)
print("Original columns:", column_names)
display(covid.head())

Original columns: ['sex', 'type0', 'current_type', 'current_time', 'time1', 'type1', 'time2', 'type2', 'time3', 'type3', 'time4', 'type4', 'time5', 'type5', 'time6', 'type6', 'type7', 'time7', 'type8', 'time8', 'time9', 'type9', 'type10', 'time10', 'type11', 'time11', 'first_date', 'new_time1', 'new_time10', 'new_time2', 'new_time3', 'new_time4', 'new_time5', 'new_time6', 'new_time7', 'new_time8', 'new_time9', 'new_type0', 'new_type1', 'new_type10', 'new_type2', 'new_type3', 'new_type4', 'new_type5', 'new_type6', 'new_type7', 'new_type8', 'new_type9', 'age_group']


Unnamed: 0_level_0,sex,type0,current_type,current_time,time1,type1,time2,type2,time3,type3,time4,type4,time5,type5,time6,type6,type7,time7,type8,time8,time9,type9,type10,time10,type11,time11,first_date,new_time1,new_time10,new_time2,new_time3,new_time4,new_time5,new_time6,new_time7,new_time8,new_time9,new_type0,new_type1,new_type10,new_type2,new_type3,new_type4,new_type5,new_type6,new_type7,new_type8,new_type9,age_group
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1,Unnamed: 23_level_1,Unnamed: 24_level_1,Unnamed: 25_level_1,Unnamed: 26_level_1,Unnamed: 27_level_1,Unnamed: 28_level_1,Unnamed: 29_level_1,Unnamed: 30_level_1,Unnamed: 31_level_1,Unnamed: 32_level_1,Unnamed: 33_level_1,Unnamed: 34_level_1,Unnamed: 35_level_1,Unnamed: 36_level_1,Unnamed: 37_level_1,Unnamed: 38_level_1,Unnamed: 39_level_1,Unnamed: 40_level_1,Unnamed: 41_level_1,Unnamed: 42_level_1,Unnamed: 43_level_1,Unnamed: 44_level_1,Unnamed: 45_level_1,Unnamed: 46_level_1,Unnamed: 47_level_1,Unnamed: 48_level_1,Unnamed: 49_level_1
1,Male,2,1,29,3,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2020-04-03,3.0,0,,,,,,,,0,23,16,0,0,0,0,0,0,0,0,0,"[55.0, 60.0)"
2,Female,2,2,20,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2020-04-14,,0,,,,,,,,0,23,0,0,0,0,0,0,0,0,0,0,"[75.0, 80.0)"
3,Female,2,2,9,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2020-04-24,,0,,,,,,,,0,23,0,0,0,0,0,0,0,0,0,0,"[80.0, 105.0)"
4,Male,2,1,20,6,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2020-04-13,6.0,0,,,,,,,,0,23,16,0,0,0,0,0,0,0,0,0,"[70.0, 75.0)"
5,Female,2,1,32,5,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2020-03-30,5.0,0,,,,,,,,0,23,16,0,0,0,0,0,0,0,0,0,"[80.0, 105.0)"


In [5]:
import pandas as pd
import numpy as np
from tqdm import tqdm


def prep_covid_hosp_data_multistate(data_path="../src/skms/datasets/covid_hosp_data.csv"):
    """
    Transform COVID hospitalization data into multi-state event history format.

    Every state a patient visits is expanded into one row per transition that
    is possible from that state. The row whose ``to`` equals the state the
    patient actually moved to next gets ``status == 1``; all other candidate
    transitions from that state get ``status == 0``.

    Parameters
    ----------
    data_path : str or path-like, optional
        Location of the raw CSV. It must provide the columns ``id``, ``sex``,
        ``age_group``, ``current_time``, ``new_type0`` .. ``new_type10`` and
        ``new_time1`` .. ``new_time10``. Defaults to the dataset path used by
        this notebook.

    Returns
    -------
    result_df : pandas.DataFrame
        Long-format table sorted by ``id``, ``tstart``, ``from``, ``to`` with
        columns:

        - id: subject identifier
        - from: initial state
        - to: possible end state
        - tstart: time entered risk set for transition
        - tstop: time exited from state (or censored)
        - status: 1 if transition happened, 0 if not
        - sex: patient sex (Male/Female)
        - age: numeric age taken from the ``age_group`` lookup below
          (NaN when the group is not in the lookup)
        - is_male: binary indicator (1 for "Male", 0 for "Female")
        - from_label / to_label: human-readable names of ``from`` / ``to``
    state_labels : dict
        Mapping from state code (1-4) to its display label.
    """

    # Column mappings. new_type{k} is read as the k-th state of the trajectory
    # and new_time{k} as the duration spent in state k-1 (see the loop below).
    state_cols = [f"new_type{k}" for k in range(11)]
    time_cols = [f"new_time{k}" for k in range(1, 11)]

    # State mappings: raw codes in the CSV -> compact codes used in the output.
    # Code 0 has no label and no outgoing transitions, so it never produces
    # rows (presumably it marks unused trajectory slots; confirm against the
    # data dictionary).
    states_mapper = {0: 0, 16: 1, 23: 2, 4: 3, 5: 4}
    state_labels = {1: "Discharged or Recovered", 2: "Mild or Moderate", 3: "Severe", 4: "Deceased"}

    # Define possible transitions (from_state -> [possible_to_states])
    possible_transitions = {
        1: [2, 3],  # From Discharged/Recovered (re-admission)
        2: [1, 3, 4],  # From Mild/Moderate
        3: [1, 2, 4],  # From Severe
        4: [],  # Deceased is terminal
    }

    # Only Deceased is absorbing; Discharged/Recovered patients may re-enter.
    terminal_states = [4]

    # Age and sex mappings
    age_mapper = {
        "[55.0, 60.0)": 57.5,
        "[75.0, 80.0)": 77.5,
        "[80.0, 105.0)": 92.5,
        "[70.0, 75.0)": 72.5,
        "[45.0, 50.0)": 47.5,
        "[25.0, 30.0)": 27.5,
        "[60.0, 65.0)": 62.5,
        "[35.0, 40.0)": 37.5,
        "[20.0, 25.0)": 22.5,
        "[0.0, 20.0)": 10,
        "[65.0, 70.0)": 67.5,
        "[40.0, 45.0)": 42.5,
        "[50.0, 55.0)": 52.5,
        "[30.0, 35.0)": 32.5,
    }
    sex_mapper = {"Male": 1, "Female": 0}

    # Column order of the per-transition records. Passing it explicitly keeps
    # the result well-formed (and the label/sort steps working) even when no
    # record is produced.
    record_columns = ["id", "from", "to", "tstart", "tstop", "status", "sex", "age", "is_male"]

    # Read data
    df = pd.read_csv(data_path)

    # Process mappings
    df["age"] = df["age_group"].map(age_mapper)
    df["is_male"] = df["sex"].map(sex_mapper)

    # Rename states. A raw code missing from states_mapper becomes NaN and
    # makes astype(int) raise, so unknown codes fail loudly here.
    for col in state_cols:
        df[col] = df[col].map(states_mapper).astype(int)

    # Process each patient
    all_transitions = []

    for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing patients"):
        patient_id = row["id"]

        # Parse patient trajectory
        states = row[state_cols].values.astype(int)
        times = row[time_cols].values.astype(float)

        # The first NaN duration marks the end of the recorded trajectory:
        # keep the durations before it and one more state than durations.
        first_nan = np.where(np.isnan(times))[0]
        if len(first_nan) > 0:
            first_nan = first_nan[0]
            states = states[: (first_nan + 1)]
            times = times[:first_nan].astype(int)

        # Handle current time and censoring
        total_transitions_time = np.sum(times)
        current_time = row["current_time"]

        # Add remaining follow-up time if the last state is not terminal
        if (current_time > total_transitions_time) and (states[-1] not in terminal_states):
            times = np.append(times, current_time - total_transitions_time)

        # Edge case: no follow-up time is left but the last non-terminal state
        # still lacks a duration; give it a nominal stay of 1.
        if (len(states) != len(times)) and (states[-1] not in terminal_states):
            times = np.append(times, 1)

        # Fix zero transition times so every interval has positive length
        times[times == 0] = 1

        # Build transition records
        cumulative_time = 0

        for i in range(len(states)):
            current_state = states[i]

            # Time at entry and exit for this state
            tstart = cumulative_time
            if i < len(times):
                tstop = cumulative_time + times[i]
            else:
                tstop = current_time  # Censored

            # Get possible next states
            possible_next = possible_transitions.get(current_state, [])

            # Determine actual next state (if any)
            actual_next = None
            if i < len(states) - 1:
                actual_next = states[i + 1]

            # Create records for all possible transitions
            for next_state in possible_next:
                transition_record = {
                    "id": patient_id,
                    "from": current_state,
                    "to": next_state,
                    "tstart": tstart,
                    "tstop": tstop,
                    "status": 1 if next_state == actual_next else 0,
                    "sex": row["sex"],
                    "age": row["age"],
                    "is_male": row["is_male"],
                }
                all_transitions.append(transition_record)

            # Update cumulative time
            if i < len(times):
                cumulative_time += times[i]

    # Create final dataframe
    result_df = pd.DataFrame(all_transitions, columns=record_columns)

    # Add state labels for readability
    result_df["from_label"] = result_df["from"].map(state_labels)
    result_df["to_label"] = result_df["to"].map(state_labels)

    # Sort by patient and time
    result_df = result_df.sort_values(["id", "tstart", "from", "to"]).reset_index(drop=True)

    return result_df, state_labels


# Build the long-format transition table and keep the state-code -> label map.
df_multistate, state_labels = prep_covid_hosp_data_multistate()
display(df_multistate.head(5))

# Number of observed (status == 1) transitions for each from/to pair.
observed_counts = df_multistate.groupby(["from_label", "to_label"])["status"].sum()
print(f"\nTransition summary:\n{observed_counts}")


Processing patients: 100%|██████████| 2675/2675 [00:00<00:00, 3820.60it/s]


Unnamed: 0,id,from,to,tstart,tstop,status,sex,age,is_male,from_label,to_label
0,1,2,1,0.0,3.0,1,Male,57.5,1,Mild or Moderate,Discharged or Recovered
1,1,2,3,0.0,3.0,0,Male,57.5,1,Mild or Moderate,Severe
2,1,2,4,0.0,3.0,0,Male,57.5,1,Mild or Moderate,Deceased
3,1,1,2,3.0,29.0,0,Male,57.5,1,Discharged or Recovered,Mild or Moderate
4,1,1,3,3.0,29.0,0,Male,57.5,1,Discharged or Recovered,Severe



Transition summary:
from_label               to_label               
Discharged or Recovered  Mild or Moderate             98
                         Severe                        2
Mild or Moderate         Deceased                     52
                         Discharged or Recovered    2135
                         Severe                      275
Severe                   Deceased                    135
                         Discharged or Recovered       9
                         Mild or Moderate            193
Name: status, dtype: int64


In [7]:
from skms.visualization import StateDiagramGenerator

# State 4 ("Deceased" in state_labels) is the only state without outgoing transitions.
terminal_states = [4]

# Tell the generator which column of the long-format table plays which role.
sdg = StateDiagramGenerator(
    dataset=df_multistate,
    state_labels=state_labels,
    terminal_states=terminal_states,
    patient_id="id",
    from_state="from",
    to_state="to",
    tstart="tstart",
    tstop="tstop",
    status="status",
)

sdg.plot_state_diagram()