In [None]:
'''
@Author: Jiqiao Lu, George
@Email: george6.lu@polyu.edu.hk 
@Description: the main script implements all the cleaning steps, generates the results and saves the models. For some reason, the Jupyter
built-in editor does not support correct spacing and indentation, and I have no idea how to use vim (without plugins), so I have to
use notebook cells as the script file. It is supposed to be a runnable script that accepts a user-supplied argument for the time spectrum;
if you look at this code outside of HA, please convert it to a script as described above, thank you.

'''
import pandas as pd
import os
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pickle
from scipy import stats
import util.cleaning_tools as tools


import matplotlib.pyplot as plt

from imblearn.over_sampling import RandomOverSampler, SMOTE

import sklearn
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score, RepeatedStratifiedKFold
from sklearn.metrics import balanced_accuracy_score, classification_report, confusion_matrix, ConfusionMatrixDisplay,\
precision_recall_curve, auc, roc_auc_score, roc_curve, recall_score, precision_score, accuracy_score
from sklearn.feature_selection import SelectKBest, f_classif
from sklearn.preprocessing import StandardScaler, RobustScaler
from sklearn.linear_model import LogisticRegression, SGDClassifier, Lasso
from sklearn.tree import DecisionTreeClassifier
from sklearn.pipeline import Pipeline
from sklearn.neighbors import KNeighborsClassifier
from sklearn.svm import SVC
from sklearn.utils import resample
from sklearn.ensemble import RandomForestClassifier, VotingClassifier, GradientBoostingClassifier, AdaBoostClassifier
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import IterativeImputer
from random import sample
import time
import warnings
import tensorflow as tf
from tensorflow import keras
import scipy.stats as stats
import json
import os
import glob

from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
from sklearn.metrics import roc_auc_score
# from sklearn.inspection import permutation_importance
from typing import Callable

warnings.filterwarnings("ignore")

In [None]:
# Mapping from lab-test term id to a readable feature name: the keys are used
# as the term ids to select and the values as the feature column names.
# NOTE(review): the mapping is empty here, so no lab test is selected --
# presumably the real content was removed before sharing; confirm before running.
name_dict = {}

# term ids of the indicators of interest (the keys of name_dict)
term_id = list(name_dict.keys())

# Keras metrics reported during training and evaluation; make_model uses this
# list as the default value of its `metrics` argument.
METRICS = [
    keras.metrics.TruePositives(name='tp'),
    keras.metrics.FalsePositives(name='fp'),
    keras.metrics.TrueNegatives(name='tn'),
    keras.metrics.FalseNegatives(name='fn'),
    keras.metrics.BinaryAccuracy(name='accuracy'),
    keras.metrics.Precision(name='precision'),
    keras.metrics.Recall(name='recall'), # recall is the metric we focus on
    keras.metrics.AUC(name='auc'), # area under the ROC curve (Keras default curve)
    keras.metrics.AUC(name='prc', curve='PR') # area under the precision-recall curve
]

In [None]:
class AUCStopping(keras.callbacks.Callback):
    '''
    Keras callback that stops training at the end of an epoch as soon as both
    the validation AUC and the validation recall reach their thresholds.

    Args:
        auc_threshold: minimum 'val_auc' required to stop (default 0.81)
        recall_threshold: minimum 'val_recall' required to stop (default 0.85)
    '''
    def __init__(self, auc_threshold=0.81, recall_threshold=0.85):
        super().__init__()
        self.auc_threshold = auc_threshold
        self.recall_threshold = recall_threshold

    def on_epoch_end(self, epoch, logs=None):
        # `logs` defaults to None instead of a shared mutable dict.
        logs = logs or {}
        val_auc = logs.get('val_auc')
        val_recall = logs.get('val_recall')
        # Without validation data (or without these metrics) the keys are
        # missing; comparing None with a float would raise a TypeError.
        if val_auc is None or val_recall is None:
            return
        if val_auc >= self.auc_threshold and val_recall >= self.recall_threshold:
            print("\n Early stopping because validation auc reached {} and validation recall reached {}".format(
                self.auc_threshold, self.recall_threshold))
            self.model.stop_training = True

In [None]:
def map_age(df, fields, dob="dob_Y"):
    '''
    Add an age column (in years) to df, in place, for every date column in fields.

    The new column is named after the part of the field name before the first
    underscore plus "_age" (e.g. "pre_dtm" -> "pre_age"). Missing dates give NaN.

    Args:
        df: dataframe holding the date columns; it is modified in place
        fields: names of the date columns to convert into ages
        dob: name of the date-of-birth column
    '''
    # Length of the mean Gregorian year (365.2425 days), which is what
    # np.timedelta64(1, "Y") stands for. Dividing seconds by this constant
    # avoids the "Y" timedelta unit, which recent pandas releases reject.
    seconds_per_year = 365.2425 * 24 * 3600
    birth = pd.to_datetime(df[dob])
    for f in fields:
        age_field = f.split('_')[0] + "_age"
        df[age_field] = (pd.to_datetime(df[f]) - birth).dt.total_seconds() / seconds_per_year
        

def cls_mapper(prog_pd: float, PERIOD_LONG) -> int:
    '''
    Return 1 when prog_pd is strictly below PERIOD_LONG, otherwise 0.
    A NaN prog_pd compares False and therefore maps to 0.
    '''
    return 1 if prog_pd < PERIOD_LONG else 0

      # define helper functions
def plot_barchart(s:pd.Series, n_bin:int, bin_width:int=1, sort=False, title=""):
    '''
    Bin the values of s into equal-width intervals and draw the counts as a bar chart.

    Args:
        s: values to bin
        n_bin: number of bin edges (0, bin_width, ..., (n_bin - 1) * bin_width),
            which gives n_bin - 1 bars
        bin_width: width of each bin
        sort: passed to value_counts; when True the bars are ordered by count
        title: title of the chart
    '''
    edges = [bin_width * i for i in range(n_bin)]
    binned = pd.cut(s, bins=edges)
    out = binned.value_counts(sort=sort)
    ax = out.plot.bar(rot=0, color='b', figsize=(6,4))
    # One label per bar, taken from the bar's own interval, so the number of
    # labels always matches the number of ticks and stays right when sort=True.
    ax.set_xticklabels([interval.left for interval in out.index])
    ax.set_title(title)

def plot_boxplot(df, cat, title):
    '''
    Draw one box plot of "death_age" per category, side by side.

    Args:
        df: dataframe with a "label" column and a "death_age" column
        cat: category names; the position of a name in this sequence is the
            value of "label" that selects its rows
        title: title of the whole figure
    '''
    # squeeze=False always returns a 2-D array of axes, so indexing also works
    # when there is a single category.
    fig, axs = plt.subplots(1, len(cat), squeeze=False)
    fig.suptitle(title)
    for idx, name in enumerate(cat):
        # boxplot does not skip NaN: a single missing age would blank the box
        data = df[df.label == idx]["death_age"].dropna()
        ax = axs[0, idx]
        ax.boxplot(data, notch=False, sym='')  # sym='' hides the outlier markers
        ax.set_title(name)
        
def make_model(X_train, metrics=METRICS, output_bias=None):
    '''
    Build and compile a small fully connected binary classifier.

    Args:
        X_train: training input; only X_train.shape[-1] (the number of
            features) is used, to size the input layer
        metrics: Keras metrics reported during training and evaluation
        output_bias: optional initial bias of the output unit; when None the
            bias starts at zero

    Return:
        the compiled keras.Sequential model
    '''
    if output_bias is not None:
        bias_initializer = tf.keras.initializers.Constant(output_bias)
    else:
        # Passing bias_initializer=None does not mean "default": Keras then
        # falls back to a random initializer. Ask for zeros explicitly.
        bias_initializer = 'zeros'

    # two hidden ReLU layers followed by a sigmoid output unit; the dropout
    # layer in front of the output is there to reduce overfitting
    model = keras.Sequential([
        keras.layers.Dense(16, activation='relu', input_shape=(X_train.shape[-1],)),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dropout(0.5),
        keras.layers.Dense(1, activation='sigmoid', bias_initializer=bias_initializer)
    ])

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=1e-3),
        loss=keras.losses.BinaryCrossentropy(),
        metrics=metrics
    )
    return model
  
def plot_metrics(history):
    '''
    Plot the training and validation curves of loss, accuracy, recall and
    precision on a 2x2 grid and save the figure as
    <OUT_PATH>/charts/training_history.png.

    Relies on the module-level names `colors` and `OUT_PATH`.

    Args:
        history: Keras History object returned by model.fit
    '''
    line_color = colors[0]
    for position, metric in enumerate(('loss', 'accuracy', "recall", "precision"), start=1):
        axis = plt.subplot(2, 2, position)
        plt.gcf().set_size_inches(10,10)
        axis.plot(history.epoch, history.history[metric], color=line_color, label="Train")
        axis.plot(history.epoch, history.history['val_'+metric], color=line_color, linestyle="--", label="Val")
        axis.set_xlabel("Epoch")
        axis.set_ylabel(metric.replace("_", " ").capitalize())
        # the loss keeps its automatic upper limit, the rates are shown on [0, 1]
        if metric == 'loss':
            lower, upper = 0, axis.get_ylim()[1]
        elif metric == 'auc':
            lower, upper = 0.8, 1
        else:
            lower, upper = 0, 1
        axis.set_ylim([lower, upper])
        axis.legend()
    plt.savefig(os.path.join(OUT_PATH, "charts", "training_history.png"), bbox_inches='tight', dpi=400)
        
def plot_roc(name, labels, predictions, **kwargs):
    '''
    Draw a ROC curve (in percent) on the current axes.

    Args:
        name: curve name shown in the legend together with the area under the curve
        labels: ground-truth labels
        predictions: prediction scores
        kwargs: extra keyword arguments forwarded to plt.plot

    Return:
        the false positive rates, true positive rates and thresholds of the curve
    '''
    fp, tp, threshold = sklearn.metrics.roc_curve(labels, predictions)
    legend_label = "{} (area = {:.3f})".format(name, auc(fp, tp))
    plt.plot(100*fp, 100*tp, linewidth=2, label=legend_label, **kwargs)
    axes = plt.gca()
    axes.set_xlabel('False positives [%]')
    axes.set_ylabel('True positives [%]')
    axes.set_xlim([0,100])
    axes.set_ylim([0,100])
    axes.grid(True)
    axes.set_aspect('equal')
    return fp, tp, threshold

def plot_feature_cpf(feature_name:str):
    '''
    Plot the normalised reverse-cumulative histogram of one feature of the
    module-level `dataset` and save both the curve (csv) and the figure (png)
    under <OUT_PATH>/charts.

    Args:
        feature_name: name of the column of `dataset` to plot
    '''
    y, x, _ = plt.hist(dataset[feature_name], ec="black", bins=500, cumulative=-1, histtype="step", density=True)
    # Write the csv and the png to the same directory, the one used by every
    # other chart of the run.
    chart_dir = os.path.join(OUT_PATH, "charts")
    # x holds the bin edges (one more than y); x[1:] pairs each bin with its right edge
    pd.DataFrame({"x": x[1:], "y": y}).to_csv(os.path.join(chart_dir, "baseline_{}_dist.csv".format(feature_name)))
    plt.savefig(os.path.join(chart_dir, "baseline_{}_dist.png".format(feature_name)))
    # reset the figure so the next feature does not draw on top of this one
    plt.cla()
    plt.clf()
    
def permutation_importance(model, X: pd.DataFrame, y: pd.Series, n_repeats: int, metric_fn, random_state=None, **params) -> dict:
    '''
    Permutation feature importance: the mean drop of the score when one
    feature column is shuffled, averaged over n_repeats shuffles.

    Assumption & Method
    It assumes that the higher the score, the better the performance of the model is. The importance of a feature is
    original_score - permuted_score, so a large positive value means that the model relies on the feature.

    Arg:
        model: object with a predict method whose output is accepted by metric_fn
        X: input of the validation set to be permuted (it is not modified)
        y: ground truth
        n_repeats: number of shuffles per feature
        metric_fn: callable computing the score as metric_fn(y, y_pred, **params)
        random_state: optional seed that makes the shuffles reproducible; when None the global numpy random
            state is used
        params: additional parameters passed to metric_fn

    Return:
        dict mapping each feature name to its mean score drop over the n_repeats shuffles (empty when n_repeats is 0).
    '''
    rng = np.random.RandomState(random_state) if random_state is not None else None
    features = list(X.columns)
    original_score = metric_fn(y, model.predict(X), **params)
    imp = {}
    for _ in range(n_repeats):
        for f in features:
            X_new = X.copy()  # work on a copy so that X itself is left untouched
            # to_list() drops the index: assigning the shuffled Series directly
            # would re-align it on the index and undo the shuffle
            X_new[f] = X[f].sample(frac=1, replace=False, random_state=rng).to_list()
            score = metric_fn(y, model.predict(X_new), **params)
            imp[f] = imp.get(f, 0) + (original_score - score)
    for f in imp:
        imp[f] = imp[f] / n_repeats
    return imp

# confusion-matrix chart for the deep learning model
def plot_cm(y, predictions, name, threshold=0.5):
    '''
    Plot the confusion matrix of thresholded prediction scores and save it as
    <OUT_PATH>/charts/<name>.png.

    Args:
        y: ground-truth labels
        predictions: prediction scores; a score above threshold counts as positive
        name: file name (without extension) of the saved chart
        threshold: decision threshold applied to the scores
    '''
    predicted_positive = predictions > threshold
    matrix = confusion_matrix(y, predicted_positive)
    figure, axis = plt.subplots(figsize=(5,5))
    chart = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=['no incidence', 'incidence ocurred'])
    chart.plot(ax=axis, values_format='d')
    plt.title('Confusion Matrix for Weighted Neural Network')
    plt.ylabel('True label')
    plt.xlabel("Predicted label")
    target_file = os.path.join(OUT_PATH, "charts", name + ".png")
    plt.savefig(target_file, bbox_inches='tight', dpi=1000)
    
def get_scores(model, X_test, y_test):
    '''
    Save the confusion-matrix chart of a fitted model on the test set and
    return its classification report.

    Args:
        model: fitted estimator with a predict method
        X_test: test inputs
        y_test: test ground truth

    Return:
        the classification report as a dict (output_dict=True)
    '''
    class_names = ['no incidence', 'incidence ocurred']
    model_name = model.__class__.__name__
    predicted = model.predict(X_test)

    # confusion matrix chart
    matrix = confusion_matrix(y_test, predicted, labels=[0.0,1.0])
    chart = ConfusionMatrixDisplay(confusion_matrix=matrix, display_labels=class_names)
    chart.plot(values_format='d')
    chart.figure_.set_figwidth(5)
    chart.figure_.set_figheight(5)
    plt.grid(False)
    plt.title("Confusion Matrix for " + model_name, fontsize=12, fontweight='bold')
    target_file = os.path.join(OUT_PATH, "charts", "Confusion Matrix for " + model_name + ".png")
    plt.savefig(target_file, bbox_inches='tight', dpi=400)

    return classification_report(y_test, predicted, target_names=class_names, output_dict=True)

def compute_metrics(model, X, y, k=100):
    '''
    Bootstrap the evaluation of a fitted classifier: resample (X, y) with
    replacement k times and summarise each metric with CI().

    The ROC AUC is computed from model.decision_function when the model has
    one, otherwise from the second column of model.predict_proba
    (assumed to be the positive class of a binary problem -- TODO confirm).
    Only when the model offers neither is the former sentinel -1 recorded.

    Args:
        model: fitted estimator with a predict method
        X: evaluation inputs
        y: evaluation ground truth
        k: number of bootstrap samples

    Return:
        dict with the keys "auc", "recall", "precision" and "accuracy", each
        holding the value returned by CI() for the k bootstrap scores
    '''
    # Pick the scoring function once, instead of catching AttributeError in
    # every iteration (which also hid unrelated AttributeErrors).
    if hasattr(model, "decision_function"):
        score_fn = model.decision_function
    elif hasattr(model, "predict_proba"):
        score_fn = lambda data: model.predict_proba(data)[:, 1]
    else:
        score_fn = None

    roc_auc = []
    recall = []
    precision = []
    accuracy = []
    for _ in range(k):
        X_bs, y_bs = resample(X, y, replace=True)
        if score_fn is not None:
            fpr, tpr, thresholds = roc_curve(y_bs, score_fn(X_bs))
            roc_auc.append(auc(fpr, tpr))
        else:
            roc_auc.append(-1)

        y_pred = model.predict(X_bs)
        recall.append(recall_score(y_bs, y_pred))
        precision.append(precision_score(y_bs, y_pred))
        accuracy.append(accuracy_score(y_bs, y_pred))

    return {
        "auc": CI(roc_auc),
        "recall": CI(recall),
        "precision": CI(precision),
        "accuracy": CI(accuracy)
    }

def compute_metrics_dl(model, X, y, k=100):
    '''
    Bootstrap the evaluation of a compiled Keras model: resample (X, y) with
    replacement k times, call model.evaluate on each sample and summarise each
    metric with CI().

    The model must report metrics named "auc", "recall", "precision" and
    "accuracy". Relies on the module-level BATCH_SIZE.

    Args:
        model: compiled Keras model
        X: evaluation inputs
        y: evaluation ground truth
        k: number of bootstrap samples

    Return:
        dict with the keys "auc", "recall", "precision" and "accuracy", each
        holding the value returned by CI() for the k bootstrap scores
    '''
    roc_auc = []
    recall = []
    precision = []
    accuracy = []
    for _ in range(k):
        X_bs, y_bs = resample(X, y, replace=True)
        evaluation = model.evaluate(X_bs, y_bs, batch_size=BATCH_SIZE, verbose=2)
        # Name the values with the metrics of the model being evaluated, not
        # with those of an unrelated global model.
        eva = dict(zip(model.metrics_names, evaluation))
        roc_auc.append(eva["auc"])
        recall.append(eva["recall"])
        precision.append(eva["precision"])
        accuracy.append(eva["accuracy"])
    return {
        "auc": CI(roc_auc),
        "recall": CI(recall),
        "precision": CI(precision),
        "accuracy": CI(accuracy)
    }
        
def plot_auc(model, X_test, y_test):
    '''
    Plot the ROC curve of a fitted classifier on the test set and save it as
    <OUT_PATH>/charts/<ModelClassName>ROCAUC.png.

    The scores come from model.decision_function when the model has one,
    otherwise from the second column of model.predict_proba (assumed to be the
    positive class of a binary problem -- TODO confirm).

    Args:
        model: fitted estimator
        X_test: test inputs
        y_test: test ground truth
    '''
    plt.figure(figsize=(5,5))
    if hasattr(model, "decision_function"):
        y_test_score = model.decision_function(X_test)
    else:
        y_test_score = model.predict_proba(X_test)[:, 1]
    fpr, tpr, thresholds = roc_curve(y_test, y_test_score)
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, 'b', label='AUC = %0.4f' % roc_auc)
    plt.legend(loc='lower right')
    plt.plot([0,1],[0,1], 'r--')  # chance level
    plt.xlim([0, 1.0])
    plt.ylim([0, 1.01])
    plt.title("ROC for model " + model.__class__.__name__)
    plt.ylabel('True Positive Rate')
    plt.xlabel('False Positive Rate')
    plt.savefig(os.path.join(OUT_PATH, "charts", model.__class__.__name__ + "ROCAUC.png"))

In [None]:
# Directory of the raw data files and the names of the two lab-result sources
# that are read from it.
file_path = r'../DATAFILE'
labresult_cps_path = 'lis_cps_result_data'
labresult_hms_path = 'lis_hms_result_data'
# read the patient demographic information
# NOTE(review): tools.fileReader is a project helper that is not shown here;
# it is presumably given a directory and a data-set name and returns a
# DataFrame -- confirm.
patient_info = tools.fileReader(r"../DATAFILE", 'patient_data')
# read the grouped patients (the first csv column is used as the index)
group_patient = pd.read_csv(r"../tables/output/group_patient.csv", index_col=0)


# term_id mapping tables; tid_to_eid is later merged on 'term_id' and its
# result on 'entity_id'
tid_desc = tools.fileReader(r'../DATAFILE', r"iams_concept")
tid_to_eid = tools.fileReader(r'../DATAFILE', r'iams_entity_concept')

# lab results: read only the needed columns from both sources and stack them
usecols = ["pseudo_patient_key", "reference_dtm", "diff_in_hour_reference_dtm", "result_str", "entity_id", "si_unit", "si_numeric"]
labresult_cps = tools.fileReader(file_path, labresult_cps_path, usecols=usecols)
labresult_hms = tools.fileReader(file_path, labresult_hms_path, usecols=usecols)
labresult = pd.concat([labresult_cps, labresult_hms])

In [None]:
# Attach the demographic columns to every grouped patient (left join, so
# patients without demographic information are kept).
left = group_patient
right = patient_info[["pseudo_patient_key", "dob_Y", "sex", "death_date_Y", "diff_in_hour_death_date"]]
diab_patients_info = pd.merge(left=left, right=right, how='left', on='pseudo_patient_key')

# cells holding the literal two-character string "" are treated as missing
diab_patient_age = diab_patients_info.replace(r'""', np.nan)
# add pre_age, diab_age and death_age (in years) in place
map_age(diab_patient_age, ["pre_dtm", "diab_dtm", "death_date_Y"])
# keep only the year of birth: the first four characters of dob_Y, as an integer
# (this must run after map_age, which still needs the full date of birth)
diab_patient_age["dob_Y"] = diab_patient_age["dob_Y"].apply(lambda x : x[:4]).astype("int")
# progression period: diab_diff_hour minus pre_diff_hour (hours, going by the column names)
diab_patient_age["prog_pd"] = diab_patient_age["diab_diff_hour"] - diab_patient_age["pre_diff_hour"] 
# cast to float so the column can be compared numerically in the later queries
diab_patient_age["diff_in_hour_death_date"] = diab_patient_age["diff_in_hour_death_date"].astype("float")

In [None]:
import scipy.stats as st
from typing import List
class LogMetrics:
    '''
    Container for the metric values of one model, one list entry per parsed log.
    '''
    auc: List[float]
    recall: List[float]
    precision: List[float]
    accuracy: List[float]

    def __init__(self):
        # Bind the lists to the instance: plain local assignments would leave
        # the attributes undefined and every later append would fail.
        self.auc = []
        self.recall = []
        self.precision = []
        self.accuracy = []

    def to_dict(self):
        '''Return the four metric lists as a dict keyed by metric name.'''
        return {
            "auc": self.auc,
            "recall": self.recall,
            "precision": self.precision,
            "accuracy": self.accuracy
        }

def parse_log(logs):
    '''
    Collect the metrics of every model over a sequence of run logs.

    Args:
        logs: iterable of log dicts. The metrics of the machine learning
            models are read from info["ml_info"][<model>]["metrics"] (or,
            when that key is absent, from ["evaluation"], the key this
            notebook writes); those of the deep learning model are read
            directly from info["dl_info"].

    Return:
        dict mapping each model key to a LogMetrics holding one value per log
    '''
    ml_models = ["logreg_info", "dt_info", "rf_info", "adb_info", "xgb_info"]
    metric_names = ("auc", "recall", "precision", "accuracy")
    ml_metrics = {ml: LogMetrics() for ml in ml_models + ["dl_info"]}

    for info in logs:
        for ml, metrics in ml_metrics.items():
            if ml != "dl_info":
                entry = info["ml_info"][ml]
                scores = entry["metrics"] if "metrics" in entry else entry["evaluation"]
            else:
                scores = info[ml]
            for metric_name in metric_names:
                getattr(metrics, metric_name).append(scores[metric_name])
    return ml_metrics

def CI(samples, alpha = 0.95):
    '''
    Normal-approximation confidence interval of the mean of samples.

    Args:
        samples: sequence of numeric values
        alpha: confidence level of the interval (0.95 gives a 95% interval)

    Return:
        [lower bound, centre of the interval (the sample mean), upper bound]
    '''
    values = np.asarray(samples, dtype=float)
    center = np.mean(values)
    spread = st.sem(values)
    if spread == 0:
        # All samples are equal: the interval collapses to that value
        # (a zero scale would otherwise produce NaN bounds).
        return [float(center)] * 3
    # The confidence level is passed positionally: its keyword was renamed
    # from `alpha` to `confidence` in SciPy, and the old name no longer works.
    left, right = st.norm.interval(alpha, loc=center, scale=spread)
    return [float(left), float((left + right) / 2), float(right)]

In [None]:
for t in [2, 5, 10]:
    
    TIME_SPEC = t
    YEAR = 2019 - TIME_SPEC
    CUT_OFF = '{year}-12-31'.format(year=YEAR)
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    mpl.rcParams['figure.figsize'] = (12,10)
    OUT_PATH = os.path.join(f"../output/george/spec-{TIME_SPEC}year")
    PERIOD_LONG = TIME_SPEC * 365.25 * 24

    #clean the directory
    files = glob.glob('../output/george/spec-{}year/**/*.*'.format(t), recursive=True)
    for f in files:
        os.remove(f)
    log = {}
    ############################################################cleaning################################################
    # replace null value to np.nan

    # exclusion
    diab_patient_excluded = diab_patient_age.copy() # create a new reference
    patient_num = {}
    def show_num():
        '''
        Return [total rows, rows with label 0, rows with label 1] of the
        current diab_patient_excluded frame.
        '''
        cohort = diab_patient_excluded
        per_label = [len(cohort.query(condition)) for condition in ("label == 0", "label == 1")]
        return [len(cohort)] + per_label

    patient_num["No_filtering"] = show_num()

    # Enrolment in the first year and last three year (pre_dtm > 2003-12-31 and pre_dtm <= 2014-12-31)
    diab_patient_excluded = diab_patient_excluded.query(f"pre_dtm <= '{CUT_OFF}' or label == 2")
    patient_num["cutoff_filtering"] = show_num()

    # Follow up time is less than {{TIME_SPEC}}
    death_diff_hour = TIME_SPEC * 365.25 * 24
    diab_patient_excluded = diab_patient_excluded.query(f"diff_in_hour_death_date.isnull() or diff_in_hour_death_date - pre_diff_hour > {death_diff_hour} \
or label > 0", engine='python')
    patient_num["death_filtering"] = show_num()

    # Diabetes only.
    diab_patient_excluded = diab_patient_excluded.query("label < 2")
    patient_num["diabonly_filtering"] = show_num()

    # Patients younger than 18(pre_age < 18)
    diab_patient_excluded = diab_patient_excluded.query("pre_age >= 18")
    patient_num["age_filtering"] = show_num()

    log["patient_num"] = patient_num
    # # Follow-up time less than 1 month i.e. 30.5*24 hours
    # diab_patient_age = diab_patient_age.query(f"prog_pd > {MIN_FT} | prog_pd.isnull()", engine='python')
    diab_patient_excluded["cls"] = diab_patient_excluded["prog_pd"].apply(cls_mapper, PERIOD_LONG =PERIOD_LONG)
    diab_patient_excluded.to_csv(f"../tables/output/diab_patient-{TIME_SPEC}year.csv")
#     diab_patient_excluded = pd.read_csv(f"../tables/output/diab_patient-{TIME_SPEC}year.csv", index_col=0)
    ############################################################plot###########################################################
    # patient number of each group
    fig, ax = plt.subplots(figsize=(6,4))
    count = diab_patient_excluded.groupby("label")["label"].count()
    group = ["pre-diabetes only", "pre-diabetes to diabetes"]
    bar_color = ['tab:red', 'tab:blue']
    ax.bar(group, count, label = group, color = bar_color)
    ax.set_ylabel("patients number")
    ax.set_title("patients number of each group")
    for i, v in enumerate(count.to_list()):
        plt.text(i, v, "{:,}".format(v), ha = 'center')
    plt.savefig(os.path.join(OUT_PATH, "charts", "patients_number.png"))

    death = diab_patient_excluded[diab_patient_age.death_age.notnull()]
    pre = diab_patient_excluded[diab_patient_age.label == 0]
    pre2diab = diab_patient_excluded[diab_patient_age.label == 1]
    diab = diab_patient_excluded[diab_patient_age.label == 2]

    # death number of each group
    fig, ax = plt.subplots(figsize=(6,4))
    count = death.groupby("label")["label"].count()
    group = ["pre-diabetes only", "pre-diabetes to diabetes"]
    bar_color = ['tab:green', 'tab:purple', 'tab:pink']
    ax.bar(group, count, label = group, color = bar_color)
    ax.set_ylabel("patients number")
    ax.set_title("death number of each class")
    for i, v in enumerate(count.to_list()):
        plt.text(i, v, "{:,}".format(v), ha = 'center')
    plt.savefig(os.path.join(OUT_PATH,"charts", "death_class.png"))

    plt.cla()
    plt.clf()
    
    # prediabetes distribution against age
    fig, ax = plt.subplots(figsize=(6,4))
    bin = pd.cut(pre.pre_age, bins=[5 * (i) for i in range(3,23)])
    out = bin.value_counts(sort=False)
    ax = out.plot.bar(rot=30, color='b', figsize=(8,4))
    ax.set_xticklabels(["18-20"] + [f"{i*5}-{i*5+5}" for i in range(4, 22)])
    ax.set_xlabel("Age")
    ax.set_ylabel("Number")
    _ = ax.set_title("Distribution of prediabetes against age")
    plt.savefig(os.path.join(OUT_PATH, "charts", "distribution_pre_age.png"))
    plt.cla()
    plt.clf()
    
    # distribution of pre to diabetes patients for the baseline cohort
    fig, ax = plt.subplots(figsize=(6,4))
    bin = pd.cut(pre2diab.pre_age, bins=[5 * (i) for i in range(3,23)])
    out = bin.value_counts(sort=False)
    out.to_csv(os.path.join(OUT_PATH,"charts", "distribution_pre2diab_age.csv"))
    ax = out.plot.bar(rot=30, color='b', figsize=(8,4))
    ax.set_xticklabels(["18-20"] + [f"{i*5}-{i*5+5}" for i in range(4, 22)])
    ax.set_xlabel("Age")
    ax.set_ylabel("Number")
    _ = ax.set_title("Distribution of prediabetes against age")
    plt.savefig(os.path.join(OUT_PATH, "charts", "distribution_pre2diab_age.png"))

    plt.cla()
    plt.clf()
    
    # patient portion
    temp = diab_patient_excluded.assign(pre_year = diab_patient_age["pre_dtm"].apply(lambda s : str(s)[:4]))
    idx_to_dx = {0: "No T2DM Incidence", 1: "With T2DM Incidence"}
    temp["Label"] = temp["cls"].apply(lambda x : idx_to_dx[x])
    temp_group = temp.groupby(["pre_year", "Label"]).size().unstack()
    percent = (temp_group["With T2DM Incidence"] / (temp_group["With T2DM Incidence"] + temp_group["No T2DM Incidence"])) * 100
    ax = temp_group.plot(kind="bar", stacked=True, colormap="Set2", figsize=(8,5))
    mid = temp_group["With T2DM Incidence"] / 2 + temp_group["No T2DM Incidence"]
    for con in ax.containers:
        plt.setp(con, width=0.5)
    x0, x1 = ax.get_xlim()
    ax.set_xlim(x0-1, x1+1)
    for i,per in enumerate(percent):
        plt.text(i, mid[i], str(np.round(per,1)) + '%', va='center', ha='center')
    ax.set_ylabel("Number of patients")
    ax.set_xlabel("Year of confirming prediabetes")
    _ = ax.set_title("Patients of different class portion against prediabetes confirmation time(year)")
    temp_group["incidence_rate"] = percent
    temp_group.to_csv(os.path.join(OUT_PATH,"charts", "patient_portion_each_class.csv"))
    plt.savefig(os.path.join(OUT_PATH,"charts", "patient_portion_each_class.png"))
    plt.cla()
    plt.clf()
    
#     # T2DM progression survival curve
#     years = [f'{x}-01-01' for x in range(2004,2018)]
#     cnt = []
#     INTV = 1
#     age = [INTV*i for i in range(20//INTV,90//INTV + 1)]
#     for a in age:
#         cnt.append(diab_patient_age.query(f"diab_age < {a} and not (death_age < {a})")["pseudo_patient_key"].count())
#     surv = [diab_patient_age.shape[0] - c for c in cnt]
#     fig, ax = plt.subplots(figsize=(10,5))
#     ax.step(age, surv, 'k-', color='r')
#     ax.set_xticks([5*i for i in range(4, 19)])
#     ax.tick_params(axis='x', labelrotation=45)
#     ax.set_xlabel("Prediabetes age")
#     ax.set_ylabel("Number of patients")
#     _ = ax.set_title("T2DM progression survival curve")
#     plt.savefig(os.path.join(OUT_PATH,"charts", "survival_curve.png"))
#     plt.cla()
#     plt.clf()
    
    # investigate the progress free period against the age
    data = pre2diab.assign(period = pre2diab.diab_age - pre2diab.pre_age)[["pre_age", "diab_age", "death_age", "period"]]
    bin = pd.cut(data.pre_age, bins = [5 * i for i in range(3,23)])
    data = data.assign(bin = bin)
    out = data.groupby("bin").agg({"period":["count","mean"]})
    out.to_csv(os.path.join(OUT_PATH,"charts", "progress_free_period.csv"))
    ###################### plot ##########################
    ax = out["period"]["mean"].plot.bar(rot=30, color='b', figsize=(10,4))
    ax.plot(["18-20"] + [f"{i*5}-{i*5+5}" for i in range(4, 22)],out["period"]["mean"].tolist())
    ax.set_xticklabels(["18-20"] + [f"{i*5}-{i*5+5}" for i in range(4, 22)])
    ax.set_xlabel("Age")
    ax.set_ylabel("Year")
    ax.set_ylim([0,6])
    _ = ax.set_title("Mean progression period with respect to prediabetes age")
    # write to disk
    plt.savefig(os.path.join(OUT_PATH,"charts", "progress_free_period.png"))
    out.to_csv(os.path.join(OUT_PATH,"charts", "progress_free_period.csv"))
    log["mean_progression_year"] = data["period"].mean()
    plt.cla()
    plt.clf()
    
    ##########################################################analysis###########################################################

    target_id = pd.Series(term_id).rename('term_id')
    patients = diab_patient_excluded
    target_tid_mapping = pd.merge(target_id,tid_to_eid,how='left',on='term_id')
    labresult_filtered = pd.merge(labresult,target_tid_mapping,how='inner',on='entity_id')
    #replace the null value with np.nan
    labresult_filtered.replace(r'""', np.nan, inplace=True)
    #left join the patients tables and test tables
    patients_test = pd.merge(left=patients[["pseudo_patient_key", "pre_diff_hour", "diab_diff_hour"]], 
                             right=labresult_filtered[["pseudo_patient_key", "term_id", "reference_dtm", "diff_in_hour_reference_dtm", "si_unit","si_numeric" ]], 
                             on="pseudo_patient_key", 
                             how="left")
    #rename
    patients_test.rename(columns = {"diff_in_hour_reference_dtm": "test_diff_hour", "reference_dtm": "test_dtm"}, inplace=True)
    # select the observations that the test is within six month before the 
    # prediabetes diagnosis to three months after
    # patients_test_filtered = patients_test.query("test_diff_hour > (pre_diff_hour - 6 * 30 * 24) and test_diff_hour < (pre_diff_hour + 3 * 60 * 24)")
    patients_test_filtered = patients_test.query("test_diff_hour <= pre_diff_hour and test_diff_hour > (pre_diff_hour - 6 * 30 * 24)")
    # drop the observations that the test is later than diagnosis of diabetes
    # patients_test_filtered = patients_test_filtered.query("diab_diff_hour.isnull() or test_diff_hour < diab_diff_hour", engine="python")

    # only keep the patient id, term_id and test results
    patients_test_filtered = patients_test_filtered[["pseudo_patient_key", "term_id", "test_diff_hour", "si_numeric"]]
    # casting type for si_numeric
    patients_test_filtered["si_numeric"] = patients_test_filtered["si_numeric"].astype("float")
    patients_test_group = patients_test_filtered\
    .sort_values(by=["pseudo_patient_key", "test_diff_hour"])\
    .groupby(["pseudo_patient_key", "term_id"], as_index=False)\
    .agg({"si_numeric":"mean"})
    patients_features_pivoted = patients_test_group.pivot_table(index="pseudo_patient_key", columns="term_id", values="si_numeric")
    # reset index
    patients_features_pivoted = patients_features_pivoted.reset_index()
    # rename all the tests out of interest
    patients_features_pivoted.rename(columns=name_dict, inplace=True)
    dataset = pd.merge(left=patients_features_pivoted, right=patients, how="inner", on="pseudo_patient_key") # join the with the patient information
    dataset = dataset.query("HBA1C < 6.4 or HBA1C.isnull()", engine='python')
    dataset = dataset.query("cholesLDL_1 > 0 or cholesLDL_1.isnull()", engine='python')
    dataset.to_csv(f"../tables/output/dataset-{TIME_SPEC}year")
    
    # plot feature distribution
    def plot_feature_cpf(feature_name:str):
        '''
        Plot the normalised reverse-cumulative histogram of one column of
        `dataset`, save the curve (csv) and the figure (png) under
        <OUT_PATH>/charts, then reset the current figure.
        '''
        csv_path = os.path.join(OUT_PATH, "charts", "baseline_{}_dist.csv".format(feature_name))
        png_path = os.path.join(OUT_PATH, "charts", "baseline_{}_dist.png".format(feature_name))
        heights, edges, _ = plt.hist(dataset[feature_name], ec="black", bins=500, cumulative=-1, histtype="step", density=True)
        # edges has one more entry than heights; each bin is paired with its right edge
        curve = pd.DataFrame({"x": edges[1:], "y": heights})
        curve.to_csv(csv_path)
        plt.savefig(png_path)
        plt.cla()
        plt.clf()
    
        
    
    mean = []
    std = []
    iqr = []
    mad = []
    missing_rate = []
    pvalue = []
    indicators = ["pre_age", "sex"] + list(name_dict.values())
    pos_mean = []
    neg_mean = []
    pos_std = []
    neg_std = []
    pos_iqr = []
    neg_iqr = []
    pos_mad = []
    neg_mad = []
    
    # Per-indicator descriptive statistics, overall and per class, plus the
    # p-value of the difference between the two classes (cls == 0 / cls == 1).
    # MAD statistics are not collected: mad, pos_mad and neg_mad stay empty.
    for ind in indicators:
        if ind == "sex":
            temp = dataset[['sex', 'cls']]
            # Use row counts (scalars) for the 2x2 table. DataFrame.count()
            # returns one value per column, which turned the table into a
            # 2x2x2 array and made chi2_contingency test a three-way table.
            male_0 = len(temp.query("sex == 'M' and cls == 0"))
            male_1 = len(temp.query("sex == 'M' and cls == 1"))
            female_0 = len(temp.query("sex == 'F' and cls == 0"))
            female_1 = len(temp.query("sex == 'F' and cls == 1"))
            result = stats.chi2_contingency([[male_0, female_0], [male_1, female_1]])
            p_value = result[1]
            pvalue.append(p_value)
            # mean / std / iqr are not defined for a categorical indicator
            for column in (mean, std, iqr, pos_mean, neg_mean, pos_std, neg_std, pos_iqr, neg_iqr):
                column.append(np.nan)
        else:
            temp = dataset[[ind, "cls"]]
            temp = temp[temp[ind].notnull()]
            t0 = temp.query("cls == 0")[ind]
            t1 = temp.query("cls == 1")[ind]
            # two-sample t-test between the negative and the positive class
            result = stats.ttest_ind(t0, t1)
            pvalue.append(result.pvalue)
            mean.append(temp[ind].mean())
            std.append(temp[ind].std())
            iqr.append(stats.iqr(temp[ind]))
            pos_mean.append(t1.mean())
            neg_mean.append(t0.mean())
            pos_std.append(t1.std())
            neg_std.append(t0.std())
            neg_iqr.append(stats.iqr(t0))
            pos_iqr.append(stats.iqr(t1))
            plot_feature_cpf(ind)
            
    stt = pd.DataFrame({
        "feauture": indicators,
        "mean": mean,
        "standard deviance": std,
        "iqr": iqr,
#         "mad": mad,
        "neg_mean": neg_mean,
        "neg_std": neg_std,
        "neg_iqr": neg_iqr,
#         "neg_mad": neg_mad,
        "pos_mean": pos_mean,
        "pos_std": pos_std,
        "pos_iqr": pos_iqr,
#         "pos_mad": pos_mad,
        "p-value": pvalue
    })

    demo_info =['pseudo_patient_key',
            'pre_dtm', 
            'pre_diff_hour', 
            'sex',
            'pre_age',
           'cls']
    # select missingness less than 30% tets and HBA1C

    tests_name = list(name_dict.values())
    features = dataset.copy()[demo_info+tests_name]
    missing = features.isnull().sum()
    percent = features.isnull().sum() / features.isnull().count()
    valid = features.notnull().sum()
    missing_data = pd.concat([missing, valid,percent], axis=1, keys=["Missing","Valid", "Missing_percent"])
    missing_data.sort_values("Missing_percent")
    # add to statistic result dataframe
    stt["Missing Rate"] = missing_data.loc[indicators, "Missing_percent"].to_list()

    # write to disk
    stt.to_csv(os.path.join(OUT_PATH, "tables", "overall_statistics.csv"))
    valid_tests = missing_data.loc[tests_name].query("Missing_percent < 0.3").index.to_list()

    if not 'HBA1C' in valid_tests:
        valid_tests.append('HBA1C')

    ds = dataset.copy()[demo_info + valid_tests]
    # drop all the missing data
    ds = ds.dropna(how="any")
    sex_mapper = {'F':0, 'M':1}
    ds["sex"] = ds["sex"].apply(lambda x : sex_mapper[x])
    df_train, df_test = train_test_split(ds, test_size=0.1, random_state=42)

    # normalize the data using RobustScaler()
    scaler = RobustScaler()
    df_train[valid_tests] = scaler.fit_transform(df_train[valid_tests])
    df_test[valid_tests] = scaler.transform(df_test[valid_tests])
    # write to disk
    df_train.describe().to_csv(os.path.join(OUT_PATH, "tables", "training_data_desc.csv"))
    df_test.describe().to_csv(os.path.join(OUT_PATH, "tables", "test_data_desc.csv"))
    # save the scaler
    file = os.path.join(OUT_PATH, "models", "scaler.pkl")
    pickle.dump(scaler, open(file, 'wb'))

    X_train = df_train[valid_tests + ["pre_age", "sex"]]
    y_train = df_train["cls"]
    X_test = df_test[valid_tests + ["pre_age", "sex"]]
    y_test = df_test["cls"]

    #######################################################feature selection#########################################
    fs_conf = {}
    lr = LogisticRegression(penalty='l1', solver='liblinear')
    grid = {"C": [0.001, 0.01, 0.1, 1, 10 ,100, 1000]}
    search = GridSearchCV(estimator=lr, param_grid=grid)
    search.fit(X_train, y_train)
    fs_conf["C"] = search.best_params_["C"]
    features = np.array(X_train.columns).reshape(1,-1)
    gs_model = LogisticRegression(**search.best_params_, penalty='l1', solver='liblinear')
    gs_model.fit(X_train, y_train)
    coefficients = gs_model.coef_
    importance = np.abs(coefficients)

    valid_features=features[importance>0]

    fs_conf["valid_features"] = list(valid_features)
    log["feature_selection"] = fs_conf

    # oversample
    ros = RandomOverSampler(random_state=0)
    X_train, y_train = ros.fit_resample(X_train, y_train)
    dataset = {
        "training_size": X_train.shape[0],
        "test_size": X_test.shape[0],
        "positive_rate": y_test.mean()
    }
    log["dataset"] = dataset
    
    
    
    ##################################################Machine Learning Model###################################################
    ml_info={}

    ##### Logistic regression ########
    logreg_info = {}
    logreg = LogisticRegression()

    grid = {"C": [0.001, 0.01, 0.1, 1, 10 ,100, 1000], 'penalty':['l1', 'l2']}

    logreg_cv = GridSearchCV(logreg, grid, cv=5, scoring='balanced_accuracy')
    logreg_cv.fit(X_train, y_train)
    print("Best parameters " , logreg_cv.best_params_)
    logreg2 = LogisticRegression(**logreg_cv.best_params_)
    logreg2.fit(X_train, y_train)
    logreg_info["parametes"] = logreg_cv.best_params_
    logreg_info["evaluation"] = compute_metrics(logreg2, X_test, y_test)
    plot_auc(logreg2, X_test, y_test)
    ml_info["logreg_info"] = logreg_info
    # write to disk
    file = os.path.join(OUT_PATH, "models", "logreg.pkl")
    pickle.dump(logreg2, open(file, 'wb'))
    ##### Decision tree ##############
    dt_info = {}
    dt = DecisionTreeClassifier()
    dt.fit(X_train, y_train)
    dt_info["evaluation"] = compute_metrics(dt, X_test, y_test)
    ml_info["dt_info"] = dt_info
    # write to disk
    file = os.path.join(OUT_PATH, "models", "dt.pkl")
    pickle.dump(dt, open(file, 'wb'))
    ##########Random Forest ##############
    rf_info = {}
    rfc = RandomForestClassifier(class_weight={0.0:1,1.0:3}, random_state=42)
    rfc.fit(X_train, y_train)
    rf_info["evaluation"] = compute_metrics(rfc, X_test, y_test)
    rf_info["configuration"] = {"weight": {0:1,1:3}}
    ml_info["rf_info"] = rf_info
    # write to disk
    file = os.path.join(OUT_PATH, "models", "rf.pkl")
    pickle.dump(rfc, open(file, 'wb'))
    #######Adaboost Classifier###########
    adb_info = {}
    adb = AdaBoostClassifier(n_estimators=100)
    adb.fit(X_train, y_train)
    adb_info["configuration"] = {
        "n_estimators": 100
    }
    adb_info["evaluation"] = compute_metrics(adb, X_test, y_test)
    plot_auc(adb, X_test, y_test)
    ml_info["adb_info"] = adb_info
    # write to disk
    file = os.path.join(OUT_PATH, "models", "adb.pkl")
    pickle.dump(adb, open(file, 'wb'))
    ##########Gradient Boosting Classifier###########           
    xgb_info = {}
    xgb= GradientBoostingClassifier(n_estimators=100, max_depth=1, random_state=42).fit(X_train, y_train)
    xgb_info["configuration"] = {
        "n_estimators":100,
        "max_depth": 1
    }
    xgb_info["evaluation"] = compute_metrics(xgb, X_test, y_test)
    plot_auc(xgb, X_test, y_test)
    ml_info["xgb_info"] = xgb_info
    # write to disk
    file = os.path.join(OUT_PATH, "models", "xgb.pkl")
    pickle.dump(xgb, open(file, 'wb'))
    ###################SVM#############################
    # SVM_info = {}
    # SVM = SVC(gamma="auto").fit(X_train, y_train)
    # SVM_info["configuration"] = {"kernel": "rbf"}
    # SVM_info["evaluation"] = get_scores(SVM)
    # ml_info["SVM_info"] = SVM_info

    log["ml_info"] = ml_info
    ##################################################Deep Learning Model###################################################
    # early stopping callback on AUC
    dl_info = {}
    EPOCHS = 100
    BATCH_SIZE = 2000 # make sure each batch containes positive cases
    callbacks: AUCStopping = AUCStopping()


    weighted_model = make_model(X_train)
    class_weight = {0: 1, 1: 1.5}
    weighted_history = weighted_model.fit(
        X_train,
        y_train.astype(np.int64),
        batch_size=100,
        epochs=100,
        validation_data=(X_test, y_test),
        callbacks = [callbacks],
        class_weight=class_weight
    )
    weighted_model.save(os.path.join(OUT_PATH,"models", "weighted_model"))
    model_conf = {}
    model_conf["epochs"] = EPOCHS
    model_conf["BATCH_SIZE"] = BATCH_SIZE
    model_conf["weight"] = class_weight
    dl_info["model_conf"] = model_conf
    dl_info["evaluation"] = compute_metrics_dl(weighted_model, X_test, y_test)
    log["dl_info"] = dl_info

    #################feature importance##########
    imp = permutation_importance(weighted_model, X_test, y_test, 30, roc_auc_score, average="weighted")
    log["feature_importance"] = imp

    #################plot########################
    colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
    mpl.rcParams['figure.figsize'] = (12,10)
    train_predictions_weighted = weighted_model.predict(X_train, batch_size=BATCH_SIZE)
    test_predictions_weighted = weighted_model.predict(X_test, batch_size=BATCH_SIZE)

    plot_metrics(weighted_history)
    plt.cla()
    plt.clf()

    #####################################################PLOT ROC#############################
    _ = plot_roc("Train Weighted", y_train, train_predictions_weighted, color=colors[0], linestyle='--')
    fp, tp, threshold = plot_roc("Test Weighted", y_test, test_predictions_weighted, color=colors[1], linestyle='--')
    plt.legend(loc='lower right')
    plt.savefig(os.path.join(OUT_PATH, "charts", "Model ROC"))


    # get optimal threshold
    gmean = np.sqrt(tp * (1-fp))
    index = np.argmax(gmean)
    threshold_op = round(threshold[index], ndigits=4)
    fp_op = round(fp[index], ndigits=4)
    tp_op = round(tp[index], ndigits=4)
    pos_sample = test_predictions_weighted[y_test == 1]
    pos_dist = pd.DataFrame(pos_sample).describe()
    log["threshold"] = float(threshold_op)

    ## plot confusion matrix
    plot_cm(y_test, weighted_model.predict(X_test), name="Weighted Neural Network Confusion Matrix")
#     X_test_valid = X_test[y_test == 1]
    X_test_valid = X_test
    pred_series_valid = weighted_model.predict(X_test_valid)
    pred_series_valid = pd.Series(pred_series_valid[:,0])
    density = stats.gaussian_kde(pred_series_valid)
    x = np.linspace(0,1,100)
    d = density(x)
    log["density"] = d.tolist()
    log["risk_level"] = [pos_dist.loc["25%"][0], pos_dist.loc["50%"][0], pos_dist.loc["75%"][0]]
    json_log = json.dumps(log)
    with open(os.path.join(OUT_PATH, "result.json"), 'w', encoding='utf-8')as f:
        json.dump(json_log, f, ensure_ascii=False, indent=4)   