In [5]:
import pandas as pd
import numpy as np
from pymsm.datasets import prep_covid_hosp_data
from pymsm.multi_state_competing_risks_model import MultiStateModel

# from pymsm.plotting import stackplot, stackplot_state_timesteps, stackplot_state_timesteps_from_paths
from pymsm.statistics import (
    prob_visited_states,
    stats_total_time_at_states,
)

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load the COVID-19 hospitalization dataset
Raw public data are available at https://github.com/JonathanSomer/covid-19-multi-state-model/blob/master/data/data_for_paper.csv

In [6]:
covid_dataset = prep_covid_hosp_data()

# Some path definitions
covariate_cols = ["is_male", "age"]
# Human-readable state names. The backslash in label 1 is escaped on purpose:
# a bare "\R" is an invalid escape sequence (SyntaxWarning on Python 3.12+).
# The resulting label text is still "Discharged\Recovered".
states_labels_long = {
    0: "Censored",
    1: "Discharged\\Recovered",
    2: "Mild or Moderate",
    3: "Severe",
    4: "Deceased",
}
# One-letter codes, used to build compact path strings such as "M->S->D".
states_labels = {0: "C", 1: "R", 2: "M", 3: "S", 4: "D"}
# Terminal state(s): only Deceased; the fit log shows no model fitted from state 4.
terminal_states = [4]

100%|██████████| 2675/2675 [00:09<00:00, 289.74it/s]


Let's look at one patient's path.

In [7]:
# Show the observed path of the sample at position 567. Note that the sample id
# printed below (577) differs from the list position used to fetch it.
example_patient = covid_dataset[567]
example_patient.print_path()

Sample id: 577
States: [2 3 4]
Transition times: [ 6 31]
Covariates:
is_male     1.0
age        72.5
Name: 567, dtype: float64


# Let's fit the multi-state model

In [8]:
def default_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    """Covariate-update rule that leaves covariates untouched.

    Handed to ``MultiStateModel`` as the rule for updating covariates on a
    transition. The state and timing arguments are accepted (presumably so the
    signature matches what the model supplies) but are ignored, so the
    covariates used here (``is_male``, ``age``) stay fixed along a path.

    Returns
    -------
    The same object passed in as ``covariates_entering_origin_state``.
    """
    unchanged_covariates = covariates_entering_origin_state
    return unchanged_covariates


# Build the model from the dataset, the terminal states, the identity
# covariate-update rule and the covariate columns, then fit it.
# NOTE(review): the fit log contains a low-variance / perfect-separation warning
# for `is_male`, probably from a sparse transition (e.g. 1 -> 3 has only 2
# events). Confirm which one, and treat its coefficients with caution.
multi_state_model = MultiStateModel(
    covid_dataset,
    terminal_states,
    default_update_covariates_function,
    covariate_cols,
)
multi_state_model.fit()


Fitting Model at State: 2
>>> Fitting Transition to State: 1, n events: 2135
>>> Fitting Transition to State: 3, n events: 275
>>> Fitting Transition to State: 4, n events: 52
Fitting Model at State: 1
>>> Fitting Transition to State: 2, n events: 98
>>> Fitting Transition to State: 3, n events: 2
Fitting Model at State: 3
>>> Fitting Transition to State: 2, n events: 193



>>> events = df['target_state'].astype(bool)
>>> print(df.loc[events, 'is_male'].var())
>>> print(df.loc[~events, 'is_male'].var())

A very low variance means that the column is_male completely determines whether a subject dies or not. See https://stats.stackexchange.com/questions/11109/how-to-deal-with-perfect-separation-in-logistic-regression.




>>> Fitting Transition to State: 1, n events: 9
>>> Fitting Transition to State: 4, n events: 135


# Single-patient statistics
Let's look at how the model simulates transitions for a single patient: a 75-year-old female.

In [9]:
# Monte Carlo simulation of future paths for one hypothetical patient:
# a 75-year-old female starting in state 2 ("Mild or Moderate") at time 0.
# NOTE(review): no random seed is set here and only 10 paths are drawn, so the
# estimates below are coarse (multiples of 0.1) and presumably vary between runs.
mc_paths = multi_state_model.run_monte_carlo_simulation(
    origin_state=2,
    current_time=0,
    sample_covariates=pd.Series(dict(is_male=0, age=75)),
    max_transitions=10,
    n_random_samples=10,
    n_jobs=-1,
    print_paths=False,
)

100%|██████████| 10/10 [00:00<00:00, 48.92it/s]


In [11]:
# Estimated probability that a simulated path ever visits each state
# (state 0, "Censored", is skipped).
for state, state_label in states_labels_long.items():
    if state == 0:
        continue
    print(
        f"Probability of ever being {state_label} = "
        f"{prob_visited_states(mc_paths, states=[state])}"
    )


# Summary statistics of the total time spent in each non-terminal state
# (censored state 0 excluded); after transposing, one column per state.
dfs = [
    pd.DataFrame(
        data=stats_total_time_at_states(mc_paths, states=[state]),
        index=[state_label],
    )
    for state, state_label in states_labels_long.items()
    if not (state == 0 or state in terminal_states)
]
pd.concat(dfs).round(3).T


Probabilty of ever being Discharged\Recovered = 0.1
Probabilty of ever being Mild or Moderate = 1.0
Probabilty of ever being Severe = 0.3
Probabilty of ever being Deceased = 1.0


Unnamed: 0,Discharged\Recovered,Mild or Moderate,Severe
time_in_state_mean,2.2,3.0,0.5
time_in_state_std,6.6,2.145,1.025
time_in_state_median,0.0,2.0,0.0
time_in_state_min,0.0,1.0,0.0
time_in_state_max,22.0,8.0,3.0
time_in_state_quantile_0.1,0.0,1.0,0.0
time_in_state_quantile_0.25,0.0,2.0,0.0
time_in_state_quantile_0.75,0.0,3.0,0.0
time_in_state_quantile_0.9,2.2,6.2,2.1


In [12]:
from collections import Counter


def get_path_frequencies(paths, states_labels=None):
    """Count how often each distinct sequence of visited states occurs.

    Parameters
    ----------
    paths : iterable
        Objects exposing a ``states`` attribute holding the ordered states
        visited along one path (e.g. the samples in ``covid_dataset``).
    states_labels : dict, optional
        Mapping from state value to a display label. When ``None`` the raw
        state values are used.

    Returns
    -------
    pd.Series
        Integer counts indexed by path strings such as ``"M->S->D"``, sorted
        from most to least frequent; ties are ordered stably by first appearance.
    """
    counter = Counter()
    for path in paths:
        states = path.states
        if states_labels is not None:
            states = [states_labels[state] for state in states]
        # Join with str() instead of post-processing a tuple repr: repr doubles
        # backslashes in labels, would strip commas/quotes/parentheses that are
        # part of a label, and renders NumPy 2 scalars as e.g. "np.int64(2)".
        counter["->".join(str(state) for state in states)] += 1
    # Explicit dtype keeps an empty input well-typed; stable sort makes tie order
    # deterministic.
    return pd.Series(dict(counter), dtype="int64").sort_values(
        ascending=False, kind="stable"
    )

In [17]:
# Empirical frequency of each observed state sequence in the dataset,
# rendered with the one-letter state codes.
path_freqs = get_path_frequencies(paths=covid_dataset, states_labels=states_labels)
path_freqs

M->R                               1906
M                                   202
M->S                                 76
S->D                                 74
M->S->D                              59
M->S->M->R                           46
M->D                                 44
M->R->M->R                           42
S                                    42
M->R->M                              36
M->S->M                              35
S->M->R                              28
S->M                                 27
M->S->M->S                            8
S->R                                  6
S->M->S                               5
S->M->S->M->R                         3
S->M->D                               3
M->S->M->S->M                         3
M->S->M->S->M->R                      3
M->R->M->R->M->R                      3
M->S->R                               3
S->M->S->D                            2
M->S->M->D                            2
M->R->M->S                            2
