In [1]:
import pandas as pd
import numpy as np
from pymsm.datasets import prep_covid_hosp_data
from pymsm.multi_state_competing_risks_model import MultiStateModel

# from pymsm.plotting import stackplot, stackplot_state_timesteps, stackplot_state_timesteps_from_paths
from pymsm.statistics import (
    prob_visited_states,
    stats_total_time_at_states,
)

%load_ext autoreload
%autoreload 2

# Load the COVID-19 hospitalization data set  
The raw public data is available at https://github.com/JonathanSomer/covid-19-multi-state-model/blob/master/data/data_for_paper.csv

In [2]:
# Load the prepared data set (an indexable collection of patient paths).
covid_dataset = prep_covid_hosp_data()

# State numbering used throughout the notebook; 0 is the censoring marker.
states_labels = {
    0: "Censored",
    1: "OOHQ",
    2: "M&M",
    3: "Severe",
    4: "Deceased",
}

# Model configuration: absorbing state(s) and the covariates used for fitting.
covariate_cols = ["is_male", "age"]
terminal_states = [4]

100%|██████████| 2675/2675 [00:03<00:00, 873.49it/s]


Let's look at one patient's path

In [3]:
# Print the states, transition times and covariates stored for one patient.
# NOTE: 567 indexes into covid_dataset; the printed "Sample id" can differ from it.
covid_dataset[567].print_path()

Sample id: 577
States: [2 3 4]
Transition times: [ 6 31]
Covariates:
is_male     1.0
age        72.5
Name: 567, dtype: float64


# Let's fit the multi-state model

In [4]:
def default_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    """Covariate-update hook that leaves the covariates unchanged.

    Parameters
    ----------
    covariates_entering_origin_state
        Covariates the sample had when it entered the origin state.
    origin_state, target_state, time_at_origin, abs_time_entry_to_target_state
        Accepted so the hook can be called with transition details, but
        ignored by this implementation.

    Returns
    -------
    The very same ``covariates_entering_origin_state`` object (no copy is made).
    """
    # Time-invariant covariates: nothing to update on a transition.
    return covariates_entering_origin_state


# Positional arguments, in order: data set, terminal states,
# covariate-update hook, covariate column names.
multi_state_model = MultiStateModel(
    covid_dataset,
    terminal_states,
    default_update_covariates_function,
    covariate_cols,
)

# Fit the model; progress is logged per origin state and per transition.
multi_state_model.fit()


Fitting Model at State: 2
>>> Fitting Transition to State: 1, n events: 2135
>>> Fitting Transition to State: 3, n events: 275
>>> Fitting Transition to State: 4, n events: 52
Fitting Model at State: 1
>>> Fitting Transition to State: 2, n events: 98
>>> Fitting Transition to State: 3, n events: 2



>>> events = df['target_state'].astype(bool)
>>> print(df.loc[events, 'is_male'].var())
>>> print(df.loc[~events, 'is_male'].var())

A very low variance means that the column is_male completely determines whether a subject dies or not. See https://stats.stackexchange.com/questions/11109/how-to-deal-with-perfect-separation-in-logistic-regression.




Fitting Model at State: 3
>>> Fitting Transition to State: 2, n events: 193
>>> Fitting Transition to State: 1, n events: 9
>>> Fitting Transition to State: 4, n events: 135


# Single patient stats  
Let's take a look at how the model predicts transitions for a single patient: a 75-year-old female.

In [5]:
# Monte Carlo simulation for one hypothetical patient (is_male=0, age=75)
# starting in state 2 at time 0; 100 paths are sampled.
mc_paths = multi_state_model.run_monte_carlo_simulation(
    # who is simulated and where they start
    sample_covariates=pd.Series(dict(is_male=0, age=75)),
    origin_state=2,
    current_time=0,
    # simulation size and limits
    n_random_samples=100,
    max_transitions=10,
    # execution / logging options
    n_jobs=8,
    print_paths=False,
)

  probability_for_each_t / probability_for_each_t.max()
  probability_for_each_t / probability_for_each_t.max()
100%|██████████| 100/100 [00:49<00:00,  2.02it/s]


In [15]:
# Probability of ever visiting each state (state 0, "Censored", is skipped)
for state, label in states_labels.items():
    if state != 0:
        visit_prob = prob_visited_states(mc_paths, states=[state])
        print(f"Probabilty of {label} = {visit_prob}")

# Probability of reaching any terminal state
# (terminal_states is [4], i.e. "Deceased", in this notebook)
terminal_prob = prob_visited_states(
    mc_paths, states=multi_state_model.terminal_states
)
print(f"Probabilty of any terminal state = {terminal_prob}")

# Total-time-in-state statistics, one single-row frame per transient state
# (censoring marker and terminal states are excluded)
dfs = [
    pd.DataFrame(
        data=stats_total_time_at_states(mc_paths, states=[state]),
        index=[label],
    )
    for state, label in states_labels.items()
    if state != 0 and state not in terminal_states
]

# Stack the rows, round, and transpose so each state becomes a column
pd.concat(dfs).round(3).T


Probabilty of OOHQ = 0.03
Probabilty of M&M = 1.0
Probabilty of Severe = 0.31
Probabilty of Deceased = 1.0
Probabilty of any terminal state = 1.0


Unnamed: 0,OOHQ,M&M,Severe
time_in_state_mean,1.89,3.37,0.67
time_in_state_std,12.098,2.129,1.588
time_in_state_median,0.0,3.0,0.0
time_in_state_min,0.0,1.0,0.0
time_in_state_max,97.0,17.0,8.0
time_in_state_quantile_0.1,0.0,1.0,0.0
time_in_state_quantile_0.25,0.0,2.0,0.0
time_in_state_quantile_0.75,0.0,4.0,0.0
time_in_state_quantile_0.9,0.0,6.0,2.0


In [64]:
from collections import Counter

# Translate every simulated path from state numbers to readable labels.
states_list = [[states_labels[state] for state in path.states] for path in mc_paths]

# Count identical state sequences; tuples are used because lists are not hashable.
counter = Counter(tuple(labels) for labels in states_list)

# Render each sequence as "A->B->C". Joining the labels directly (instead of
# stripping characters from the tuple's repr) avoids the trailing comma that a
# single-state path would get ("('M&M',)" -> "M&M,") and keeps labels that
# contain quotes, parentheses or ", " intact.
{"->".join(labels): n_paths for labels, n_paths in counter.items()}

{'M&M->Deceased': 66,
 'M&M->Severe->Deceased': 31,
 'M&M->OOHQ->M&M->Deceased': 3}