In [None]:
import lifelines
from lifelines import CoxPHFitter
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pickle import loads,dumps
import pickle
import warnings
warnings.filterwarnings("ignore")

In [None]:
def create_survival_data(data_df, id_col, index_time_col, outcome_time_col,
                         followup_end_time, followup_max_time_from_index=np.inf,
                         censoring_time_col=None):
    """Turn absolute index/outcome/censoring times into an (id, E, T) survival table.

    Every time is re-expressed relative to the row's index time. A row counts as
    an event (E=1) when its outcome time is present and does not exceed the
    earliest of: the row's censoring time, the global end of follow-up, and the
    maximum follow-up length. Otherwise it is censored (E=0) at that earliest time.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Input table; it is copied and never modified.
    id_col, index_time_col, outcome_time_col : str
        Names of the identifier, index-time and outcome-time columns.
    followup_end_time : scalar
        Absolute time at which follow-up ends for everybody.
    followup_max_time_from_index : scalar, default np.inf
        Maximum follow-up length counted from the index time.
    censoring_time_col : str or None
        Optional column holding an absolute per-row censoring time.

    Returns
    -------
    pandas.DataFrame
        Columns ``[id_col, 'E', 'T']`` (both E and T cast to int), sorted by T.
        An infinite duration is stored as 1000.
    """
    source_df = data_df.copy()
    if censoring_time_col is None:
        # No censoring column given: add an all-NaN placeholder so that the
        # row-wise minimum below simply ignores it.
        censoring_time_col = 'censoring_time_col'
        source_df[censoring_time_col] = np.nan

    survival_df = source_df[[id_col, index_time_col, outcome_time_col, censoring_time_col]].copy()
    index_time = survival_df[index_time_col]
    survival_df = survival_df.assign(
        event_time_from_index=survival_df[outcome_time_col] - index_time,
        censoring_time_from_index=survival_df[censoring_time_col] - index_time,
        followup_end_time_from_index=followup_end_time - index_time,
        max_time_from_index=followup_max_time_from_index,
    )

    event_offset = survival_df['event_time_from_index']
    # Earliest moment at which observation of the row stops (NaNs are skipped).
    censor_offset = survival_df[['censoring_time_from_index',
                                 'followup_end_time_from_index',
                                 'max_time_from_index']].min(axis=1)
    survival_df['earliest_censoring_time_from_index'] = censor_offset

    # Separate the rows into two groups: those whose event is observed before
    # every stopping time, and everybody else (censored).
    has_event = (survival_df[outcome_time_col].notna()
                 & (event_offset <= followup_max_time_from_index)
                 & (event_offset <= censor_offset))

    survival_df['E'] = has_event.astype(int)
    durations = event_offset.where(has_event, censor_offset)
    survival_df['T'] = durations.replace(np.inf, 1000).astype(int)

    return survival_df.sort_values(['T'])[[id_col, 'E', 'T']]


In [None]:
# Load the questionnaire/outcome table from a tab-separated text file.
# NOTE(review): absolute, machine-specific path; consider a configurable data directory.
quest_matrix = pd.read_csv('c:/corona_segal/quest_matrix_outcome_test_new.txt', sep="\t")

In [None]:
# Work on a copy of the raw table. reset_index() exposes the previous row label
# as an explicit 'index' column, which later cells use as the per-row key.
data_df = quest_matrix.reset_index().copy()

# Parse the YYYYMMDD-coded date columns; the sentinel -9 is mapped to NaN first so
# that it becomes NaT. The parsed values are assigned straight back to the column:
# the previous `data_df.loc[:, col].replace(..., inplace=True)` edited a temporary
# object and is not guaranteed to reach data_df (it has no effect under pandas
# Copy-on-Write), and `data_df.loc[:, col] = <datetimes>` depended on a
# dtype-upcasting fallback.
for col in ['visit_date', 'test_date','recover_date']:
    data_df[col] = pd.to_datetime(data_df[col].replace(-9, np.nan), format='%Y%m%d')

# Latest and earliest date found in any of the three date columns.
MAX_DATE = data_df[['visit_date', 'test_date','recover_date']].max().max()
MIN_DATE = data_df[['visit_date', 'test_date','recover_date']].min().min()

# Test date kept only where test_result == 1; blank (NaT) otherwise.
data_df['pos_test_date'] = data_df['test_date']
data_df.loc[(data_df['test_result']!=1), 'pos_test_date'] = np.nan

# Test date blanked where test_result > 0; kept for every other row.
data_df['neg_test_date'] = data_df['test_date']
data_df.loc[(data_df['test_result']>0), 'neg_test_date'] = np.nan

# '<col>_T': whole days elapsed since MIN_DATE (NaN where the date is missing).
for col in ['visit_date', 'pos_test_date', 'neg_test_date','recover_date']:
    data_df[col+'_T'] = (data_df[col] - MIN_DATE).dt.days

In [None]:
# Time from MIN_DATE to the test date expressed in days as a float
# (division by a one-day timedelta); missing test dates give NaN.
data_df['test_date_correction'] = (data_df['test_date']-MIN_DATE)/np.timedelta64(1,'D')

In [None]:
def prepare_time_to_pos_test_surv_df(data_df, followup_max_time_from_index=21, censoring_time_col=None):
    """Build the time-to-positive-test survival table joined with its covariates.

    The index time is 'visit_date_T' and the outcome time is 'pos_test_date_T';
    follow-up ends for everybody at ``(MAX_DATE - MIN_DATE).days``. Rows with a
    negative duration are dropped before the covariates are attached.

    Relies on the module-level names ``create_survival_data``, ``MAX_DATE`` and
    ``MIN_DATE``.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Prepared questionnaire table; must contain 'index', the '*_T' day
        columns and every covariate column listed below.
    followup_max_time_from_index : scalar, default 21
        Maximum follow-up length in days from the visit. This value is now
        forwarded to ``create_survival_data``; previously it was ignored and a
        hard-coded 21 was used regardless of the argument.
    censoring_time_col : str or None
        Optional column of ``data_df`` with a per-row censoring day.

    Returns
    -------
    pandas.DataFrame
        One row per kept questionnaire with 'index', 'E', 'T' and the covariates.
    """
    surv_df = create_survival_data(data_df,
                     id_col='index', index_time_col='visit_date_T', outcome_time_col='pos_test_date_T',
                     followup_end_time=(MAX_DATE-MIN_DATE).days,
                     followup_max_time_from_index=followup_max_time_from_index,
                     censoring_time_col=censoring_time_col)
    # Drop rows whose duration is negative.
    surv_df = surv_df[surv_df['T']>=0]

    # Indicator columns plus two extra fields carried along for later cells.
    symp_cols = ['chom_375_379', 'chom_38_40',
           'chom_up_to_374', 'symp_ayefut', 'symp_bchilot_akahot', 'symp_bilbul',
           'symp_godesh_nazelet', 'symp_keev_garon', 'symp_keev_rosh',
           'symp_keev_shririm', 'symp_kotzer_neshima', 'symp_none', 'symp_other',
           'symp_shilshul', 'symp_shiul', 'symp_shiul_leicha', 'symp_shiul_yavesh',
           'symp_taam_reach', 'symp_zmarmoret','visit_date','time_to_test']
    covariate_cols = ['id','index', 'age', 'gender','id_disease','test_date_correction','test_result'] + symp_cols

    # Left join on the row key so every survival row keeps exactly its own covariates.
    return surv_df.merge(data_df[covariate_cols], on='index', how='left')

In [None]:
# Survival table for time from visit to positive test: follow-up capped at 21 days,
# with the 'neg_test_date_T' column supplied as the per-row censoring time.
surv_df = prepare_time_to_pos_test_surv_df(data_df, followup_max_time_from_index=21, censoring_time_col='neg_test_date_T')


In [None]:
# Number of rows per 'id' and its reciprocal, so that the rows of any one id
# add up to a total weight of 1.
df_id_count = surv_df.groupby('id').size().reset_index(name="count_id")
df_id_count['weight'] = 1 / df_id_count['count_id']

# Attach each row's weight by looking its id up in the per-id table.
surv_df['weight_by_n_quest'] = surv_df['id'].map(df_id_count.set_index('id')['weight'])

In [None]:
# Unit weight for every row (the "no re-weighting" option).
surv_df['weight_1'] = 1

In [None]:
# Flag one randomly chosen row per 'id': rand_by_id is 1 for the chosen row, else 0.
# The shuffle is seeded so that a fresh "Restart & Run All" picks the same rows
# every time; without a seed the selection changed on each execution.
RANDOM_SEED = 0
random_index_per_id = (
    surv_df.sample(frac=1.0, random_state=RANDOM_SEED)
    .groupby('id')
    .head(1)                    # first row of each id in shuffled order
    .assign(random_per_id=1)    # assign() builds a new frame: no write into a slice
)
# Membership test on the 'index' row key; does not depend on row-label alignment.
surv_df['rand_by_id'] = surv_df['index'].isin(random_index_per_id['index']).astype(float)

In [None]:
# Flag each id's earliest row by visit_date: first_per_id is 1 for that row, else 0.
# groupby().first() takes the first non-null value of every column separately;
# only its 'index' value (the row key) is used below.
first_per_id = surv_df.sort_values('visit_date').groupby('id').first().reset_index()
first_per_id['first_per_id']=1
# Membership test instead of a merge on 'index': once surv_df already has a
# 'first_per_id' column (i.e. when this cell is re-run) the merge produced
# 'first_per_id_x' / 'first_per_id_y' and the column lookup raised KeyError.
surv_df['first_per_id'] = surv_df['index'].isin(first_per_id['index']).astype(float)

In [None]:
# Flag each id's latest row by visit_date: last_per_id is 1 for that row, else 0.
# groupby().first() takes the first non-null value of every column separately;
# only its 'index' value (the row key) is used below.
last_per_id = surv_df.sort_values('visit_date',ascending=False).groupby('id').first().reset_index()
last_per_id['last_per_id']=1
# Membership test instead of a merge on 'index': once surv_df already has a
# 'last_per_id' column (i.e. when this cell is re-run) the merge produced
# 'last_per_id_x' / 'last_per_id_y' and the column lookup raised KeyError.
surv_df['last_per_id'] = surv_df['index'].isin(last_per_id['index']).astype(float)

In [None]:
from lifelines import KaplanMeierFitter
from lifelines.plotting import add_at_risk_counts
# Module-level Kaplan-Meier fitter instance.
# NOTE(review): no cell visible here references `kmf` (run_km creates its own
# fitters) -- confirm it is unused before removing.
kmf = KaplanMeierFitter()

In [None]:
# Keep only rows whose test_result is non-negative
# (presumably the rows that actually have a test -- confirm the coding).
surv_df_tested = surv_df.loc[surv_df['test_result'].ge(0)]

## Kaplan-Meier

In [None]:
 # Indicator columns analysed one at a time by run_km and run_cox (one subplot each);
 # run_km compares each of them against the values 0 and 1.
 # NOTE(review): the assignment below starts with a stray leading space. IPython
 # tolerates it, but it is an IndentationError if the cell is exported to a .py script.
 symp_cols = ['chom_375_379', 'chom_38_40',
           'chom_up_to_374', 'symp_ayefut', 'symp_bchilot_akahot', 'symp_bilbul',
           'symp_godesh_nazelet', 'symp_keev_garon', 'symp_keev_rosh',
           'symp_keev_shririm', 'symp_kotzer_neshima', 'symp_none', 'symp_other',
           'symp_shilshul', 'symp_shiul', 'symp_shiul_leicha', 'symp_shiul_yavesh',
           'symp_taam_reach', 'symp_zmarmoret']

In [None]:
# Sub-directory of 'c:/corona_segal/' that receives the pickles and summaries
# written by the cells below.
folder = 'final/'

In [None]:
def run_km(data_df, use_weight, use_quest , write_to_pickle,pop):
    """Plot Kaplan-Meier curves for every column in ``symp_cols`` (value 0 vs 1).

    One subplot is drawn per column on a 5x4 grid, so at most 20 columns are
    supported. Optionally all fitted estimators are pickled, in the order
    ``(col=0, col=1)`` for each column, into a single file.

    Relies on the module-level names ``symp_cols`` and ``folder``.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Survival table with 'T', 'E', the ``symp_cols`` columns and, depending
        on the options, 'first_per_id' / 'last_per_id' / 'rand_by_id' and
        'weight_by_n_quest'.
    use_weight : str
        'weight_by_id_count' fits with the 'weight_by_n_quest' column as
        weights; any other value fits unweighted. Also part of the file name.
    use_quest : str
        'first', 'last' or 'random' keeps only the rows flagged by the matching
        column; any other value keeps every row. Also part of the file name.
    write_to_pickle : int
        When equal to 1 the fitted estimators are written to disk.
    pop : str
        Population label used in the output file name.

    Returns
    -------
    None
    """
    pickle_name = 'c:/corona_segal/'+folder+'/quest_'+ pop + '_km_outcome_pos_test_' + use_weight + '_quest_' + use_quest + '_pickle'

    # Row selection: map the option to its flag column; unknown options keep all rows.
    quest_flag_col = {'first': 'first_per_id', 'last': 'last_per_id', 'random': 'rand_by_id'}.get(use_quest)
    surv_df = data_df if quest_flag_col is None else data_df[data_df[quest_flag_col]==1]

    weighted = (use_weight=='weight_by_id_count')

    fitted = []
    fig, axes = plt.subplots(5,4, figsize=(18,20), dpi=100)
    for i, symp_col in enumerate(symp_cols):
        ax = axes[i//4, i%4]
        for value in (0, 1):
            group = surv_df[surv_df[symp_col]==value]
            kmf = KaplanMeierFitter()
            # weights=None is the fitter's default, i.e. an unweighted fit.
            kmf.fit(group['T'], group['E'], label=f'{symp_col}={value}',
                    weights=group['weight_by_n_quest'] if weighted else None)
            kmf.plot(ax=ax)
            fitted.append(kmf)
        ax.legend()

    # Write only after every fit succeeded; the context manager guarantees the
    # handle is closed even if pickling fails (it used to leak on any error).
    if (write_to_pickle==1):
        with open(pickle_name,'wb') as file:
            for kmf in fitted:
                pickle.dump(kmf,file)
        
    
    

In [None]:
# Kaplan-Meier curves on surv_df_tested using all rows: first unweighted
# ('weight_1'), then weighted by 'weight_by_n_quest' ('weight_by_id_count').
for km_weighting in ('weight_1', 'weight_by_id_count'):
    run_km(surv_df_tested, use_weight=km_weighting, use_quest='all',
           write_to_pickle=1, pop='tested')

## IPW (inverse probability weighting)

In [None]:
# Target of the weighting model: 1 where test_result is non-negative, else 0.
y_vec = surv_df['test_result'].ge(0) * 1
# Covariates of the weighting model: every symp_cols column plus age and gender.
x_mat = surv_df[[*symp_cols, 'age', 'gender']]

In [None]:
from sklearn.linear_model import LogisticRegression
from causallib.datasets import load_nhefs
%matplotlib inline
from causallib.datasets import load_nhefs
from causallib.estimation import IPW
from causallib.evaluation import PropensityEvaluator
from sklearn.linear_model import LogisticRegression
# Weighting model: a logistic regression of y_vec on x_mat wrapped in causallib's IPW.
learner = LogisticRegression(solver="liblinear")
ipw = IPW(learner)
ipw.fit(x_mat, y_vec)
# One weight per row of x_mat.
# NOTE(review): presumably the inverse of the fitted probability of each row's
# observed y_vec value -- confirm against the causallib IPW documentation.
ipw_vec = ipw.compute_weights(x_mat, y_vec)


In [None]:
# Persist the fitted IPW model. The context manager closes the file even when
# pickling raises; the explicit open()/close() pair leaked the handle in that case.
pickle_name_ipw = 'c:/corona_segal/'+folder + 'time_to_outcome_predict_test_ipw_pickle'
with open(pickle_name_ipw,'wb') as file_ipw:
    pickle.dump(ipw,file_ipw)

In [None]:
from sklearn import metrics
# Diagnostic plots requested from the propensity evaluator.
plots=["roc_curve", "pr_curve", "weight_distribution", 
       "calibration", "covariate_balance_love", "covariate_balance_slope"]
# Scoring functions keyed by display name.
# NOTE(review): this rebinds the name `metrics`, which the import just above bound
# to the sklearn module; after this line `metrics` is a dict, not the module.
metrics = {"roc_auc": metrics.roc_auc_score,
           "avg_precision": metrics.average_precision_score,}
evaluator = PropensityEvaluator(ipw)
# NOTE(review): y_vec is passed as both the second and the third positional
# argument -- confirm against the causallib evaluate_cv signature that this is intended.
results = evaluator.evaluate_cv(x_mat, y_vec, y_vec, 
                                plots=plots, metrics_to_evaluate=metrics)

In [None]:
# Store the IPW weights on the survival table.
surv_df['ipw'] = ipw_vec
# Combined weight: the IPW weight multiplied by the per-id weight 'weight_by_n_quest'.
surv_df['ipw_and_weight_sample']=ipw_vec*surv_df['weight_by_n_quest']

## Cox proportional hazards

In [None]:
def run_cox(data_df, use_weight, use_quest , write_to_pickle,pop):
    """Fit one weighted Cox model per column in ``symp_cols`` and save the results.

    Each model uses 'age', 'gender', 'id_disease', 'test_date_correction' and a
    single ``symp_cols`` column as covariates. Its partial effect for the values
    0 and 1 is drawn on its own subplot of a 5x4 grid (at most 20 columns).
    The summaries of all models are concatenated and written as CSV; optionally
    the fitted models are pickled, one after another, into a single file.

    Relies on the module-level names ``symp_cols`` and ``folder``.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Survival table with 'T', 'E', the covariates above, the weight columns
        'ipw' / 'ipw_and_weight_sample' and, depending on ``use_quest``,
        'first_per_id' / 'last_per_id' / 'rand_by_id'.
    use_weight : str
        'weight_by_id_count' weights by 'ipw_and_weight_sample'; any other value
        weights by 'ipw'. Also part of the output file names.
    use_quest : str
        'first', 'last' or 'random' keeps only the rows flagged by the matching
        column; any other value keeps every row. Also part of the file names.
    write_to_pickle : int
        When equal to 1 the fitted models are written to disk.
    pop : str
        Population label used in the output file names.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        From ``pd.concat`` when ``symp_cols`` is empty (nothing to summarise).
    """
    base_name = 'c:/corona_segal/'+folder+'/quest_'+ pop + '_cox_outcome_pos_test_' + use_weight + '_quest_' + use_quest
    pickle_name = base_name + '_pickle'
    summary_name = base_name + '_summary'

    # Row selection: map the option to its flag column; unknown options keep all rows.
    quest_flag_col = {'first': 'first_per_id', 'last': 'last_per_id', 'random': 'rand_by_id'}.get(use_quest)
    surv_dfxx = data_df if quest_flag_col is None else data_df[data_df[quest_flag_col]==1]

    weights_col = 'ipw_and_weight_sample' if use_weight=='weight_by_id_count' else 'ipw'
    base_cols = ['T', 'E', 'age', 'gender','id_disease','test_date_correction', weights_col]

    fitted = []
    summaries = []
    fig, axes = plt.subplots(5,4, figsize=(18,20), dpi=100)
    for i, symp_col in enumerate(symp_cols):
        ax = axes[i//4, i%4]

        cph = CoxPHFitter()
        # NOTE(review): the weights are non-integer (inverse-probability based) and
        # robust=False is kept as written; check the lifelines guidance on using
        # robust standard errors with such weights before relying on the p-values.
        cph.fit(surv_dfxx[base_cols + [symp_col]], robust=False, duration_col='T', event_col='E', step_size=0.01, weights_col=weights_col)

        summaries.append(cph.summary)
        cph.plot_partial_effects_on_outcome(covariates=symp_col, values=[0,1], cmap='coolwarm', ax=ax)
        fitted.append(cph)

    # Concatenate once instead of growing a frame inside the loop.
    pd.concat(summaries).to_csv(summary_name)

    # Write only after every fit succeeded; the context manager guarantees the
    # handle is closed even if pickling fails (it used to leak on any error).
    if (write_to_pickle==1):
        with open(pickle_name,'wb') as file:
            for cph in fitted:
                pickle.dump(cph,file)
       
        

In [None]:
# Same non-negative test_result filter as before, taken again now that surv_df
# carries the 'ipw' and 'ipw_and_weight_sample' columns needed by run_cox.
surv_df_tested2 = surv_df.loc[surv_df['test_result'].ge(0)]

In [None]:
# Cox models on surv_df_tested2 using all rows: first weighted by 'ipw'
# ('weight_1'), then by 'ipw_and_weight_sample' ('weight_by_id_count').
for cox_weighting in ('weight_1', 'weight_by_id_count'):
    run_cox(surv_df_tested2, use_weight=cox_weighting, use_quest='all',
            write_to_pickle=1, pop='tested')