# Benchmarks

## Initialize

In [1]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output

import warnings
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [2]:
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

project_label = "22_retina_phewas"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

experiment = '221108'
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)


name_dict = {
    "predictions_cropratio0.66": "ConvNextSmall(Retina)+MLP_cropratio0.66",
}

partitions = [i for i in range(22)]
# partitions = [ 4, 15,  3,  2, 21, 14, 13, 20,  1,  0, 12, 11, 19, 18, 10, 17, 16]
partitions

/sc-projects/sc-proj-ukb-cvd


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]

In [3]:
today = '221109'

In [4]:
model_path = f"{experiment_path}/coxph/models"
model_list =  !ls $model_path

In [5]:
import pandas as pd
endpoints = sorted([l.replace('_prevalent', '') for l in list(pd.read_csv('/sc-projects/sc-proj-ukb-cvd/results/projects/22_retinal_risk/data/220602/endpoints.csv').endpoint.values)])

In [6]:
splits = ["train", "valid", 'test'] # "test_left", 'test_right']

In [7]:
endpoint_defs = pd.read_feather(f"{output_path}/phecode_defs_220306.feather").query("endpoint==@endpoints").sort_values("endpoint").set_index("endpoint")

In [8]:
from datetime import date
today = str(date.today()) if today is None else today


In [9]:
eligable_eids = pd.read_feather(f"{output_path}/eligable_eids_{today}.feather")
eids_dict = eligable_eids.set_index("endpoint")["eid_list"].to_dict()

In [10]:
%env MKL_NUM_THREADS=4
%env NUMEXPR_NUM_THREADS=4
%env OMP_NUM_THREADS=4

env: MKL_NUM_THREADS=4
env: NUMEXPR_NUM_THREADS=4
env: OMP_NUM_THREADS=4


In [11]:
import ray
ray.shutdown()
#ray start --head --port=6379 --num-cpus 64
ray.init(address='auto')
#ray.init(num_cpus=24)#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

RayContext(dashboard_url='', python_version='3.9.7', ray_version='1.13.0', ray_commit='e4ce38d001dbbe09cd21c497fedd03d692b2be3e', address_info={'node_ip_address': '10.32.105.4', 'raylet_ip_address': '10.32.105.4', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-11-09_00-47-11_884248_1810431/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-11-09_00-47-11_884248_1810431/sockets/raylet', 'webui_url': '', 'session_dir': '/tmp/ray/session_2022-11-09_00-47-11_884248_1810431', 'metrics_export_port': 62434, 'gcs_address': '10.32.105.4:6321', 'address': '10.32.105.4:6321', 'node_id': 'ac31dcab439be3cd133e03d321b37ad63bd336fc955ff53097d3e85d'})

# Predict COX

In [12]:
in_path = pathlib.Path(f"{experiment_path}/coxph/input")
model_path = f"{experiment_path}/coxph/models"

out_path = f"{experiment_path}/coxph/predictions"
pathlib.Path(out_path).mkdir(parents=True, exist_ok=True)

In [13]:
in_path

PosixPath('/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/221108/coxph/input')

In [14]:
models = [f.name for f in in_path.iterdir() if f.is_dir() and "ipynb_checkpoints" not in str(f)]
models

['ImageTraining_[]_ConvNeXt_MLPHead_predictions_cropratio0.66']

In [15]:
#AgeSex = ["age_at_recruitment_f21022_0_0", "sex_f31_0_0"]

In [16]:
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError
import zstandard
import pickle
import os

def get_score_defs():

    with open(r'/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/score_definitions.yaml') as file:
        score_defs = yaml.full_load(file)
    
    return score_defs

def get_features(endpoint, score_defs):
    features = {
        model: {
            "Age+Sex": score_defs["AgeSex"],
            "Retina": [endpoint],
#             "SCORE2": score_defs["SCORE2"],
#             "ASCVD": score_defs["ASCVD"],
#             "QRISK3": score_defs["QRISK3"],
            "Age+Sex+Retina": score_defs["AgeSex"] + [endpoint],
#             "SCORE2+Retina": score_defs["SCORE2"] + [endpoint],
#             "ASCVD+Retina": score_defs["ASCVD"] + [endpoint],
#             "QRISK3+Retina": score_defs["QRISK3"] + [endpoint],
            }
        for model in models}
    return features

#def get_test_data(in_path, partition, models, mapping):
def get_test_data(in_path, partition, models):
    data = {model: pd.read_feather(f"{in_path}/{model}/{partition}/test.feather").set_index("eid")#.replace(mapping)
            for model in models}
    return data
    #left_data = {model: pd.read_feather(f"{in_path}/{model}/{partition}/test_left.feather").set_index("eid").replace(mapping)for model in models}
    #right_data = {model: pd.read_feather(f"{in_path}/{model}/{partition}/test_right.feather").set_index("eid").replace(mapping)for model in models}
    #return (left_data, right_data)
            
def load_pickle(fp):
    with open(fp, "rb") as fh:
        dctx = zstandard.ZstdDecompressor()
        with dctx.stream_reader(fh) as decompressor:
            data = pickle.loads(decompressor.read())
    return data

def predict_cox(cph, data_endpoint, endpoint, feature_set, partition, pred_path, model):
    times = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    time_cols = {t: f"Ft_{t}" for t in times}
    
    if feature_set=="Age+Sex+MedicalHistory+I(Age*MH)":
        data_endpoint.columns = [c.replace("-", "") for c in data_endpoint.columns]
    
    surv_test = 1-cph.predict_survival_function(data_endpoint, times=times)
    temp_pred = data_endpoint.reset_index()[["eid"]].assign(endpoint=endpoint, features=feature_set, partition=partition)
    for t, col in time_cols.items(): 
        temp_pred[col] = surv_test.T[t].to_list()
    
    temp_pred.to_feather(f"{out_path}/{endpoint}_{feature_set}_{model}_{partition}.feather") 

# for both eyes
def predict_cox_both_eyes(cph, data_endpoint_left, data_endpoint_right, endpoint, feature_set, partition, pred_path):
    times = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]
    time_cols = {t: f"Ft_{t}" for t in times}
    
    if feature_set=="Age+Sex+MedicalHistory+I(Age*MH)":
        data_endpoint_left.columns = [c.replace("-", "") for c in data_endpoint_left.columns]
        data_endpoint_right.columns = [c.replace("-", "") for c in data_endpoint_right.columns]

    # left eye
    surv_test_left = 1-cph.predict_survival_function(data_endpoint_left, times=times)
    temp_pred_left = data_endpoint_left.reset_index()[["eid"]].assign(endpoint=endpoint, features=feature_set, partition=partition)
    for t, col in time_cols.items(): 
        temp_pred_left[col] = surv_test_left.T[t].to_list()
    
    temp_pred_left.to_feather(f"{out_path}/{endpoint}_{feature_set}_{partition}_left.feather")
    
    # right eye
    surv_test_right = 1-cph.predict_survival_function(data_endpoint_right, times=times)
    temp_pred_right = data_endpoint_right.reset_index()[["eid"]].assign(endpoint=endpoint, features=feature_set, partition=partition)
    for t, col in time_cols.items(): 
        temp_pred_right[col] = surv_test_right.T[t].to_list()
    
    temp_pred_right.to_feather(f"{out_path}/{endpoint}_{feature_set}_{partition}_right.feather")
    
    # mean both eyes
    eids = list(surv_test_left.columns.values)
    surv_test_mean = pd.concat((surv_test_left, surv_test_right), axis=1)
    surv_test_mean.columns = [f'{col}_left' for col in list(surv_test_left.columns.values)] + [f'{col}_right' for col in list(surv_test_right.columns.values)]
    for eid in eids:
        eid_columns = [col for col in list(surv_test_mean.columns.values) if str(eid) in col]
        surv_test_mean[f'{eid}_mean'] = surv_test_mean[eid_columns].mean(axis=1)
    surv_test_mean = surv_test_mean[[col for col in list(surv_test_mean.columns.values) if 'mean' in col]].rename(columns={col: col.replace('_mean', '') for col in list(surv_test_mean.columns.values)})
    
    temp_pred_mean = data_endpoint_left.reset_index()[["eid"]].assign(endpoint=endpoint, features=feature_set, partition=partition)
    for t, col in time_cols.items(): 
        temp_pred_mean[col] = surv_test_mean.T[t].to_list()
     
    temp_pred_mean.to_feather(f"{out_path}/{endpoint}_{feature_set}_{partition}_mean.feather")

@ray.remote
def predict_endpoint(data_partition, eids_dict, endpoint, partition, models, features, model_path, out_path):
    #data_partition_left, data_partition_right = data_partition
    eids_incl = eids_dict[endpoint].tolist()
    results = []
    for model in models:
        data_model = data_partition[model]
        #data_model_left = data_partition_left[model]
        #data_model_right = data_partition_right[model]
        for feature_set, covariates in features[model].items():
            identifier = f"{endpoint}_{feature_set}_{model}_{partition}"
            pred_path = f"{out_path}/{identifier}.feather"
            if not os.path.isfile(pred_path):
                try:
                    cph = load_pickle(f"{model_path}/{identifier}.p")
                    data_endpoint = data_model[data_model.index.isin(eids_incl)]
                    #data_endpoint_left = data_model_left[data_model_left.index.isin(eids_incl)]
                    #data_endpoint_right = data_model_right[data_model_right.index.isin(eids_incl)]
                    predict_cox(cph, data_endpoint, endpoint, feature_set, partition, pred_path, model)
                    #predict_cox_both_eyes(cph, data_endpoint_left, data_endpoint_right, endpoint, feature_set, partition, pred_path)
                except FileNotFoundErrorundError:
                    print(f"{identifier} not available")
    return True

In [17]:
import yaml

#mapping = {"sex_f31_0_0": {"Female":0, "Male":1}}
score_defs = get_score_defs()

ray_eids = ray.put(eids_dict)
for partition in tqdm(partitions):
    try:
        del ray_partition
    except:
        print("Ray object not yet initialised")
    #ray_partition = ray.put(get_test_data(in_path, partition, models, mapping))
    ray_partition = ray.put(get_test_data(in_path, partition, models))
    progress = []
    for endpoint in endpoints:
        features = get_features(endpoint, score_defs)
        progress.append(predict_endpoint.remote(ray_partition, ray_eids, endpoint, partition, models, features, model_path, out_path))
    [ray.get(s) for s in tqdm(progress)]

  0%|          | 0/22 [00:00<?, ?it/s]

Ray object not yet initialised


  0%|          | 0/1171 [00:00<?, ?it/s]

[2m[33m(raylet)[0m [2022-11-09 00:49:28,007 E 1810492 1810492] (raylet) worker_pool.cc:502: Some workers of the worker process(1818111) have not registered within the timeout. The process is still alive, probably it's hanging during start.
[2m[33m(raylet)[0m [2022-11-09 00:49:28,010 E 1810492 1810492] (raylet) worker_pool.cc:502: Some workers of the worker process(1818114) have not registered within the timeout. The process is still alive, probably it's hanging during start.
[2m[33m(raylet)[0m [2022-11-09 00:49:28,014 E 1810492 1810492] (raylet) worker_pool.cc:502: Some workers of the worker process(1818124) have not registered within the timeout. The process is still alive, probably it's hanging during start.
[2m[33m(raylet)[0m [2022-11-09 00:49:28,018 E 1810492 1810492] (raylet) worker_pool.cc:502: Some workers of the worker process(1818129) have not registered within the timeout. The process is still alive, probably it's hanging during start.
[2m[33m(raylet)[0m [2022-

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]

  0%|          | 0/1171 [00:00<?, ?it/s]