In [1]:
import numpy as np
import pandas as pd
from pandas import Series
from sklearn.preprocessing import OneHotEncoder
from pymsm.multi_state_competing_risks_model import PathObject, MultiStateModel


def get_categorical_columns(df, cat_cols):
    """One-hot encode categorical columns, dropping one reference level per column.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame containing the categorical columns to encode.
    cat_cols : list of str
        Names of the columns in ``df`` to encode.

    Returns
    -------
    pandas.DataFrame
        Integer (0/1) indicator columns named ``"<column>_<level>"`` (as produced
        by ``OneHotEncoder.get_feature_names_out``). The frame shares ``df``'s
        index so it can be concatenated column-wise with ``df`` without
        misaligning rows.
    """
    encoder = OneHotEncoder(drop="first")
    encoded = encoder.fit_transform(df[cat_cols])
    # Densify explicitly rather than passing ``sparse=False``: that keyword was
    # renamed to ``sparse_output`` in scikit-learn 1.2 and removed in 1.4, so
    # converting the (default) sparse result works across versions.
    if hasattr(encoded, "toarray"):
        encoded = encoded.toarray()
    return pd.DataFrame(
        encoded,
        columns=encoder.get_feature_names_out(cat_cols),
        # Keep the caller's index: pd.concat(axis=1) aligns on index labels,
        # so a fresh RangeIndex would silently misalign rows whenever ``df``
        # does not have a default 0..n-1 index.
        index=df.index,
        dtype=int,
    )


In [2]:
# Load the long-format transition data (first CSV column is the row index).
longdata = pd.read_csv("msebmt.csv", index_col=0)

# Right censoring: a row with status 0 did not end in an observed transition,
# so its target state is set to 0.
longdata.loc[longdata["status"] == 0, "to"] = 0

# Collapse the data to one row per (subject, entry time, origin state):
#  - without `trans`, censored rows that differ only in that column become
#    exact duplicates and are removed;
#  - sorting by `status` and keeping the last row prefers an observed
#    transition (status 1) over censored rows for the same origin/entry time.
longdata = (
    longdata.drop(columns="trans")
    .drop_duplicates()
    .sort_values(["id", "Tstart", "from", "status"])
    .drop_duplicates(["id", "Tstart", "from"], keep="last")
    .reset_index(drop=True)
)

longdata.head(10)


Unnamed: 0,id,from,to,Tstart,Tstop,time,status,match,proph,year,agecl
0,1,1,2,0.0,22.0,22.0,1,no gender mismatch,no,1995-1998,20-40
1,1,2,0,22.0,995.0,973.0,0,no gender mismatch,no,1995-1998,20-40
2,2,1,3,0.0,12.0,12.0,1,no gender mismatch,no,1995-1998,20-40
3,2,3,4,12.0,29.0,17.0,1,no gender mismatch,no,1995-1998,20-40
4,2,4,5,29.0,422.0,393.0,1,no gender mismatch,no,1995-1998,20-40
5,3,1,3,0.0,27.0,27.0,1,no gender mismatch,no,1995-1998,20-40
6,3,3,0,27.0,1264.0,1237.0,0,no gender mismatch,no,1995-1998,20-40
7,4,1,3,0.0,42.0,42.0,1,gender mismatch,no,1995-1998,20-40
8,4,3,4,42.0,50.0,8.0,1,gender mismatch,no,1995-1998,20-40
9,4,4,5,50.0,84.0,34.0,1,gender mismatch,no,1995-1998,20-40


In [3]:
# One-hot encode the categorical covariates; the indicator columns become the
# model covariates and replace the raw categorical columns in `data`.
cat_cols = ["match", "proph", "year", "agecl"]
cat_df = get_categorical_columns(longdata, cat_cols)
covariate_cols = cat_df.columns
data = pd.concat((longdata.drop(columns=cat_cols), cat_df), axis=1)
data


Unnamed: 0,id,from,to,Tstart,Tstop,time,status,match_no gender mismatch,proph_yes,year_1990-1994,year_1995-1998,agecl_<=20,agecl_>40
0,1,1,2,0.0,22.0,22.0,1,1,0,0,1,0,0
1,1,2,0,22.0,995.0,973.0,0,1,0,0,1,0,0
2,2,1,3,0.0,12.0,12.0,1,1,0,0,1,0,0
3,2,3,4,12.0,29.0,17.0,1,1,0,0,1,0,0
4,2,4,5,29.0,422.0,393.0,1,1,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...
4626,2278,1,2,0.0,15.0,15.0,1,1,0,0,0,1,0
4627,2278,2,0,15.0,676.0,661.0,0,1,0,0,0,1,0
4628,2279,1,2,0.0,18.0,18.0,1,0,0,0,1,1,0
4629,2279,2,4,18.0,30.0,12.0,1,0,0,0,1,1,0


In [4]:
# Map the long-format column names onto the names used for the competing-risk
# dataset passed to MultiStateModel (presumably pymsm's expected schema —
# confirm against the pymsm docs).
rename_cols = {
    "id": "sample_id",
    "from": "origin_state",
    "to": "target_state",
    "Tstart": "time_entry_to_origin",
    "Tstop": "time_transition_to_target",
}

# Select the renamed columns and the covariates in a single step. An explicit
# list is used because indexing with a `dict_keys` view relies on pandas
# accepting a set-like view as a column indexer.
competing_risk_dataset = data[[*rename_cols, *covariate_cols]].rename(
    columns=rename_cols
)

competing_risk_dataset

Unnamed: 0,sample_id,origin_state,target_state,time_entry_to_origin,time_transition_to_target,match_no gender mismatch,proph_yes,year_1990-1994,year_1995-1998,agecl_<=20,agecl_>40
0,1,1,2,0.0,22.0,1,0,0,1,0,0
1,1,2,0,22.0,995.0,1,0,0,1,0,0
2,2,1,3,0.0,12.0,1,0,0,1,0,0
3,2,3,4,12.0,29.0,1,0,0,1,0,0
4,2,4,5,29.0,422.0,1,0,0,1,0,0
...,...,...,...,...,...,...,...,...,...,...,...
4626,2278,1,2,0.0,15.0,1,0,0,0,1,0
4627,2278,2,0,15.0,676.0,1,0,0,0,1,0
4628,2279,1,2,0.0,18.0,0,0,0,1,1,0
4629,2279,2,4,18.0,30.0,0,0,0,1,1,0


In [5]:
def default_update_covariates_function(
    covariates_entering_origin_state,
    origin_state=None,
    target_state=None,
    time_at_origin=None,
    abs_time_entry_to_target_state=None,
):
    """Covariate-update hook that leaves covariates unchanged.

    Passed to ``MultiStateModel`` as its covariate-update function. The
    transition-context arguments (origin/target state and times) are accepted
    but ignored.

    Returns
    -------
    The ``covariates_entering_origin_state`` object itself (not a copy).
    """
    unchanged_covariates = covariates_entering_origin_state
    return unchanged_covariates


# Terminal (absorbing) states: no transitions out of them are modelled.
# Presumably relapse (5) and death (6) in the EBMT coding — TODO confirm.
terminal_states = [5, 6]



In [6]:
# dataset = []
# final_states = []

# for sample_id in data.id.unique():
#     sample_df = data[data["id"] == sample_id]
#     # add covariates
#     path = PathObject(
#         covariates=(sample_df.iloc[0][covariate_cols]), sample_id=sample_id,
#     )

#     # add transitions
#     for i, row in sample_df.iterrows():
#         path.states.append(row["from"].astype(int))
#         path.time_at_each_state.append(row["time"])
#     # append final state
#     final_state = row["to"].astype(int)
#     if final_state in terminal_states:
#         path.states.append(final_state)
#         final_states.append(final_state)
#         dataset.append(path)


# print(type(path))
# print("\n------covariates------")
# print(path.covariates)
# print("\n-------states---------")
# print(path.states)
# print("\n--time at each state--")
# print(path.time_at_each_state)
# print("\n------sample id-------")
# print(path.sample_id)


In [7]:
# Build the multi-state model from the long-format competing-risk dataset.
# MultiStateModel is already imported in the first cell, so it is not
# re-imported here. The first positional argument is None, presumably the
# PathObject dataset slot (cf. the commented-out cell above), because the data
# is supplied via `competing_risk_dataset` instead — confirm against pymsm.
# NOTE(review): covariate_cols is a pandas Index; confirm pymsm accepts it
# (otherwise pass list(covariate_cols)).
multi_state_model = MultiStateModel(
    None,
    terminal_states,
    default_update_covariates_function,
    covariate_cols,
    competing_risk_dataset=competing_risk_dataset,
)



In [8]:
# Fit the per-state competing-risk models. Progress is printed for each origin
# state and for each transition out of it, along with its number of events.
multi_state_model.fit()


Fitting Model at State: 1
>>> Fitting Transition to State: 2, n events: 785
>>> Fitting Transition to State: 3, n events: 907
>>> Fitting Transition to State: 5, n events: 95
>>> Fitting Transition to State: 6, n events: 160
Fitting Model at State: 2
>>> Fitting Transition to State: 5, n events: 112
>>> Fitting Transition to State: 6, n events: 39
>>> Fitting Transition to State: 4, n events: 227
Fitting Model at State: 3
>>> Fitting Transition to State: 4, n events: 433
>>> Fitting Transition to State: 6, n events: 197
>>> Fitting Transition to State: 5, n events: 56
Fitting Model at State: 4
>>> Fitting Transition to State: 5, n events: 107
>>> Fitting Transition to State: 6, n events: 137


In [9]:
# Fitting completed without errors: transition models were fitted for every non-terminal state (1–4).
