# Benchmarks

## Initialize

In [1]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output
import ray

import warnings
import lifelines
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [2]:
lifelines.__version__

'0.26.0'

In [3]:
# Pick the storage root from the host name: `!hostname` captures the shell
# output as a list, and hosts whose name contains "sc" use the /sc-projects mount.
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level directories derived from the storage root.
project_label = "22_retina_phewas"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

##### BEGIN ADAPT #####
# second best model
# wandb_name = 'aug++_convnext_s_mlp'
# wandb_id = '8ngm6apd'
# best model
wandb_name = 'aug++_convnext_s_mlp+'
wandb_id = '3p3smraz'
# Partitions to process; the commented alternative enumerates 0..21.
partitions = [0] # [i for i in range(22)]
##### END   ADAPT #####

# All outputs of this run are stored below a directory named after the wandb id.
experiment = wandb_id
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)
print(experiment_path)

/sc-projects/sc-proj-ukb-cvd
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/3p3smraz


In [4]:
!ls {output_path}

3p3smraz			       eligable_eids_2022-05-24.feather
8ngm6apd			       eligable_eids_2022-05-25.feather
baseline_outcomes_220223.feather       eligable_eids_220511.feather
baseline_outcomes_220412	       eligable_eids_long_2022-05-24.feather
baseline_outcomes_220412.feather       eligable_eids_long_2022-05-25.feather
baseline_outcomes_long_220412.feather  eligable_eids_long_220511.feather
baseline_outcomes_wide_220301	       phecode_defs_220306.feather
baseline_outcomes_wide_220301.feather  retina_endpoints_220301.feather
baseline_outcomes_wide_220306.feather  test_experiment


In [5]:
print(output_path)
data_outcomes = pd.read_feather(f"{output_path}/baseline_outcomes_220412.feather").set_index("eid")
data_outcomes

/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data


Unnamed: 0_level_0,OMOP_4306655_prev,phecode_401_prev,phecode_401-1_prev,phecode_202_prev,phecode_475_prev,phecode_202-2_prev,phecode_713_prev,phecode_718_prev,phecode_460_prev,phecode_713-3_prev,...,phecode_546-2_time,phecode_902_time,phecode_361-4_time,phecode_401-2_time,phecode_596-3_time,phecode_168-211_time,phecode_714-32_time,phecode_719-4_time,phecode_684-12_time,phecode_240_time
eid,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
1000018,False,True,True,False,False,False,False,False,False,False,...,11.866089,11.866089,11.866089,11.866089,11.866089,11.866089,11.866089,11.866089,11.866089,11.866089
1000020,False,False,False,False,False,False,False,False,False,False,...,13.596446,13.596446,13.596446,13.596446,13.596446,13.596446,13.596446,13.596446,13.596446,13.596446
1000037,False,False,False,False,False,False,True,True,False,True,...,12.868163,12.868163,12.868163,12.868163,12.868163,12.868163,12.868163,12.868163,12.868163,12.868163
1000043,False,True,True,False,False,False,True,False,False,True,...,12.309629,12.309629,12.309629,12.309629,12.309629,12.309629,12.309629,12.309629,12.309629,12.309629
1000051,False,False,False,True,False,True,False,False,False,False,...,15.291210,15.291210,15.291210,15.291210,15.291210,15.291210,15.291210,15.291210,15.291210,15.291210
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
6025150,False,False,False,False,False,False,False,False,False,False,...,14.237117,14.237117,14.237117,14.237117,14.237117,14.237117,14.237117,14.237117,14.237117,14.237117
6025165,False,False,False,False,False,False,False,True,True,False,...,13.059816,13.059816,13.059816,13.059816,13.059816,13.059816,13.059816,13.059816,13.059816,13.059816
6025173,False,False,False,False,False,False,False,False,True,False,...,13.018748,13.018748,13.018748,13.018748,13.018748,13.018748,13.018748,13.018748,13.018748,13.018748
6025182,False,False,False,False,True,False,True,True,False,True,...,11.233632,11.233632,11.233632,11.233632,11.233632,11.233632,11.233632,11.233632,11.233632,11.233632


In [6]:
import pandas as pd

# Endpoint labels come from the frequent-endpoints table of the medical-records
# project; the "_prevalent" suffix is stripped and the labels are sorted.
all_endpoints = sorted(
    label.replace('_prevalent', '')
    for label in pd.read_csv(
        '/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/220413/frequent_endpoints.csv'
    ).endpoint.values
)
#all_endpoints = sorted(endpoints_all_md.endpoint.to_list())
print(len(all_endpoints))

# Endpoints to exclude. Currently empty; the commented lines show how the list
# was previously read from the experiment directory.
endpoints_not_overlapping_with_preds = []
#endpoints_not_overlapping_with_preds_md = pd.read_csv(f"{experiment_path}/endpoints_not_overlapping.csv", header=None)
#print(len(endpoints_not_overlapping_with_preds_md))
#endpoints_not_overlapping_with_preds = list(endpoints_not_overlapping_with_preds_md[0].values)

# Keep every endpoint that is not on the exclusion list.
endpoints = [
    name for name in all_endpoints
    if name not in endpoints_not_overlapping_with_preds
]
print(len(endpoints))

498
498


In [7]:
splits = ["train", "valid", 'test'] # "test_left", 'test_right'

In [8]:
# Endpoint definitions restricted to the selected endpoints, ordered and
# indexed by endpoint label.
endpoint_defs = (
    pd.read_feather(f"{output_path}/phecode_defs_220306.feather")
    .query("endpoint==@endpoints")
    .sort_values("endpoint")
    .set_index("endpoint")
)

In [9]:
from datetime import date
today = str(date.today())

In [10]:
# The file name embeds `today` (the current date as a string), so this read only
# succeeds when a file stamped with the day the notebook is run exists.
# NOTE(review): consider pinning the date in the config cell for reproducibility.
eligable_eids = pd.read_feather(f"{output_path}/eligable_eids_{today}.feather") # TODO CHANGE!
# Map each endpoint label to its `eid_list` entry.
eids_dict = eligable_eids.set_index("endpoint")["eid_list"].to_dict()

In [11]:
%env MKL_NUM_THREADS=4
%env NUMEXPR_NUM_THREADS=4
%env OMP_NUM_THREADS=4

env: MKL_NUM_THREADS=4
env: NUMEXPR_NUM_THREADS=4
env: OMP_NUM_THREADS=4


In [12]:
ray.shutdown()

In [13]:
import ray

ray.init(num_cpus=24)#, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

{'node_ip_address': '10.32.105.6',
 'raylet_ip_address': '10.32.105.6',
 'redis_address': '10.32.105.6:46209',
 'object_store_address': '/tmp/ray/session_2022-05-25_11-30-34_720418_3768856/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-05-25_11-30-34_720418_3768856/sockets/raylet',
 'webui_url': None,
 'session_dir': '/tmp/ray/session_2022-05-25_11-30-34_720418_3768856',
 'metrics_export_port': 56609,
 'node_id': 'f143ee874e5f2564900490d610cb26a701c223f90ad6eca21980de80'}

In [14]:
AgeSex = ["age_at_recruitment_f21022_0_0", "sex_f31_0_0"]

## Train Cox models

In [15]:
# Directory with the per-model Cox input data of this experiment.
in_path = pathlib.Path(experiment_path) / "coxph" / "input"
in_path.mkdir(parents=True, exist_ok=True)

# Directory that receives the fitted models (kept as a plain string).
model_path = f"{experiment_path}/coxph/models"
pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)

In [16]:
# Every sub-directory of the input folder (notebook checkpoint folders excluded)
# is treated as one model. Path.iterdir() yields entries in arbitrary order, so
# the names are sorted to keep the model order reproducible between runs.
models = sorted(
    f.name for f in in_path.iterdir()
    if f.is_dir() and "ipynb_checkpoints" not in str(f)
)
models

['ImageTraining_[]_ConvNeXt_MLPHead']

In [17]:
from formulaic.errors import FactorEvaluationError

In [18]:
in_path

PosixPath('/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/3p3smraz/coxph/input')

In [19]:
from lifelines import CoxPHFitter
from lifelines.exceptions import ConvergenceError
import zstandard
import pickle

def get_features(endpoint):
    """Build the covariate sets to fit for ``endpoint``, keyed by model name.

    Parameters
    ----------
    endpoint : str
        Endpoint label; it is also used as the name of the covariate column
        in the "Retina" feature sets.

    Returns
    -------
    dict
        ``{model: {feature_set_label: [covariate column names]}}`` with one
        entry for every name in the module-level ``models`` list.
    """
    # One entry per model instead of only ``models[0]``: ``fit_endpoint`` looks
    # up ``features[model]`` for each model it iterates over, so a single
    # hard-coded key raises KeyError as soon as a second model is present.
    return {
        model: {
            "Age+Sex": AgeSex,
            "Retina": [endpoint],
            "Age+Sex+Retina": AgeSex + [endpoint],
            #"Age+Sex+MedicalHistory+I(Age*MH)": AgeSex + [endpoint]
        }
        for model in models
    }

def get_train_data(in_path, partition, models, mapping):
    """Load the training split of each model for one partition.

    For every name in ``models`` the file
    ``{in_path}/{model}/{partition}/train.feather`` is read, indexed by
    ``eid``, left-joined on the index with the module-level ``data_outcomes``
    frame and passed through ``DataFrame.replace(mapping)``.

    Returns a dict mapping model name to the resulting DataFrame.
    """
    frames_by_model = {}
    for model_name in models:
        predictions = pd.read_feather(f"{in_path}/{model_name}/{partition}/train.feather").set_index("eid")
        with_outcomes = predictions.merge(
            data_outcomes, left_index=True, right_index=True, how="left"
        )
        frames_by_model[model_name] = with_outcomes.replace(mapping)
    return frames_by_model

def fit_cox(data_fit, feature_set, covariates, endpoint, penalizer, step_size=1):
    """Fit a lifelines ``CoxPHFitter`` for one endpoint and return it.

    Parameters
    ----------
    data_fit : pandas.DataFrame
        Frame holding the covariate columns plus ``{endpoint}_time`` and
        ``{endpoint}_event``. It is not modified.
    feature_set : str
        Feature-set label. ``"Age+Sex+MedicalHistory+I(Age*MH)"`` selects the
        formula-based fit with age (and, if present, sex) interaction terms;
        any other label fits on all columns of ``data_fit``.
    covariates : list of str
        Covariate column names; only inspected for ``"sex_f31_0_0"`` when the
        formula is built.
    endpoint : str
        Endpoint label used to derive the duration and event column names.
    penalizer : float
        Passed to ``CoxPHFitter``.
    step_size : float, default 1
        Passed to ``CoxPHFitter.fit``.

    Returns
    -------
    CoxPHFitter
        The fitted model. Exceptions raised by ``fit`` propagate to the caller.
    """
    cph = CoxPHFitter(penalizer=penalizer)
    if feature_set == "Age+Sex+MedicalHistory+I(Age*MH)":
        # Hyphens are stripped from the endpoint label and from all column
        # names before they are used in the formula (presumably because "-"
        # would be parsed as an operator - TODO confirm).
        endpoint_label = endpoint.replace("-", "")
        # Rename on a copy: assigning to ``data_fit.columns`` would silently
        # rename the columns of the caller's frame as well.
        data_fit = data_fit.rename(columns=lambda c: c.replace("-", ""))
        if "sex_f31_0_0" in covariates:
            formula = f"age_at_recruitment_f21022_0_0*{endpoint_label}+sex_f31_0_0*{endpoint_label}"
        else:
            formula = f"age_at_recruitment_f21022_0_0*{endpoint_label}"
        cph.fit(data_fit, f"{endpoint_label}_time", f"{endpoint_label}_event", formula=formula, step_size=step_size)
    else:
        cph.fit(data_fit, f"{endpoint}_time", f"{endpoint}_event", step_size=step_size)

    return cph

def save_pickle(data, data_path):
    """Pickle ``data`` and write it zstandard-compressed to ``data_path``."""
    with open(data_path, "wb") as raw_file:
        # Stream the pickled bytes through a zstd compressor into the file.
        with zstandard.ZstdCompressor().stream_writer(raw_file) as zst_out:
            payload = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
            zst_out.write(payload)
            
def load_pickle(fp):
    """Read a zstandard-compressed pickle from ``fp`` and return the object.

    Unpickling executes code embedded in the file, so only load files that
    were written by a trusted source.
    """
    with open(fp, "rb") as raw_file, zstandard.ZstdDecompressor().stream_reader(raw_file) as zst_in:
        return pickle.loads(zst_in.read())

@ray.remote
def fit_endpoint(data_partition, eids_dict, endpoint_defs, endpoint, partition, models, model_path):
    """Fit and store the Cox models of one endpoint for one partition.

    For every model in ``models`` and every feature set returned by
    ``get_features(endpoint)`` a model is fitted with ``fit_cox`` and written
    to ``{model_path}/{endpoint}_{feature_set}_{partition}.p``. A model file
    that already exists and can be loaded is kept as is.

    Parameters
    ----------
    data_partition : dict
        Maps model name to the training DataFrame of that model.
    eids_dict : dict
        Maps endpoint to an array-like (must offer ``.tolist()``) of the
        index values to keep for that endpoint.
    endpoint_defs : pandas.DataFrame
        Indexed by endpoint; the ``"sex"`` entry decides whether the sex
        covariate is used (dropped unless the value is ``"Both"``).
    endpoint : str
        Endpoint label.
    partition
        Partition identifier, used in the output file names.
    models : list of str
        Model names to process.
    model_path : str
        Directory that receives the fitted models.

    Returns
    -------
    bool
        Always ``True``. Fit failures are reported via ``print``; when all
        step sizes fail, the fit data is stored as
        ``{experiment_path}/coxph/errordata_{endpoint}_{feature_set}_{partition}.p``
        (``experiment_path`` is read from module scope).
    """
    eids_incl = eids_dict[endpoint].tolist()
    features = get_features(endpoint)
    eligibility = endpoint_defs.loc[endpoint]["sex"]
    # Failures that trigger a retry with a smaller step size.
    fit_errors = (ValueError, ConvergenceError, KeyError, FactorEvaluationError)
    for model in models:
        data_model = data_partition[model]
        for feature_set, covariates in features[model].items():
            cph_path = f"{model_path}/{endpoint}_{feature_set}_{partition}.p"
            if os.path.isfile(cph_path):
                try:
                    # Loading only serves as an integrity check of the stored model.
                    load_pickle(cph_path)
                    continue
                except Exception:
                    # Unreadable file: fall through and refit. ``Exception``
                    # instead of a bare ``except`` so that KeyboardInterrupt
                    # and SystemExit are not swallowed.
                    pass

            if (eligibility != "Both") and ("sex_f31_0_0" in covariates):
                covariates = [c for c in covariates if c != "sex_f31_0_0"]
            data_endpoint = data_model[covariates + [f"{endpoint}_event", f"{endpoint}_time"]].astype(np.float32)
            data_endpoint = data_endpoint[data_endpoint.index.isin(eids_incl)]

            # Try the default step size first, then progressively smaller ones.
            for step_size in (1, 0.5, 0.1):
                try:
                    cph = fit_cox(data_endpoint, feature_set, covariates, endpoint, penalizer=0.0, step_size=step_size)
                    save_pickle(cph, cph_path)
                except fit_errors:
                    if step_size == 1:
                        print("ConvergenceError", model, endpoint, feature_set, partition, "problem: reduce step size")
                    else:
                        print("ConvergenceError", model, endpoint, feature_set, partition, f"trying with reduced step size ... {step_size} failed")
                    continue
                if step_size != 1:
                    print("ConvergenceError", model, endpoint, feature_set, partition, f"trying with reduced step size ... {step_size} successfull")
                break
            else:
                # Every step size failed: keep the data for later inspection.
                save_pickle(data_endpoint, f"{experiment_path}/coxph/errordata_{endpoint}_{feature_set}_{partition}.p")
    return True

In [20]:
f"{experiment_path}/coxph"

'/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/3p3smraz/coxph'

In [21]:
model_list =  !ls $model_path
#model_list = [m for m in model_list if "I(" in m]
model_list = [m for m in model_list]

In [22]:
model_list

[]

In [23]:
1+1

2

In [24]:
# Numeric encoding of the sex column, applied when the training data is loaded.
mapping = {"sex_f31_0_0": {"Female":0, "Male":1}}

# Put the objects shared by all tasks into the ray object store once.
ray_eids = ray.put(eids_dict)
ray_endpoint_defs = ray.put(endpoint_defs)
# Iterate over the partitions configured in the ADAPT section instead of a
# hard-coded [0].
for partition in tqdm(partitions):
    # Drop the reference to the previous partition's data before loading the
    # next one; on the first pass the name does not exist yet.
    try:
        del ray_partition
    except NameError:
        print("Ray object not yet initialised")
    try:
        data_partition = get_train_data(in_path, partition, models, mapping)
        ray_partition = ray.put(data_partition)
        # One remote task per endpoint.
        progress = [
            fit_endpoint.remote(ray_partition, ray_eids, ray_endpoint_defs, endpoint, partition, models, model_path)
            for endpoint in endpoints
        ]
        # Block until every task has finished; the results are not needed.
        for task in tqdm(progress):
            ray.get(task)
    except FileNotFoundError:
        print('file not found')

  0%|          | 0/1 [00:00<?, ?it/s]

Ray object not yet initialised


  0%|          | 0/498 [00:00<?, ?it/s]

[2m[36m(pid=3769047)[0m 
[2m[36m(pid=3769047)[0m >>> events = df['phecode_608-1_event'].astype(bool)
[2m[36m(pid=3769047)[0m >>> print(df.loc[events, 'sex_f31_0_0'].var())
[2m[36m(pid=3769047)[0m >>> print(df.loc[~events, 'sex_f31_0_0'].var())
[2m[36m(pid=3769047)[0m 
[2m[36m(pid=3769047)[0m A very low variance means that the column sex_f31_0_0 completely determines whether a subject dies or not. See https://stats.stackexchange.com/questions/11109/how-to-deal-with-perfect-separation-in-logistic-regression.
[2m[36m(pid=3769047)[0m 
[2m[36m(pid=3769047)[0m 
[2m[36m(pid=3769058)[0m 
[2m[36m(pid=3769058)[0m >>> events = df['phecode_614-55_event'].astype(bool)
[2m[36m(pid=3769058)[0m >>> print(df.loc[events, 'sex_f31_0_0'].var())
[2m[36m(pid=3769058)[0m >>> print(df.loc[~events, 'sex_f31_0_0'].var())
[2m[36m(pid=3769058)[0m 
[2m[36m(pid=3769058)[0m A very low variance means that the column sex_f31_0_0 completely determines whether a subject dies or 

In [33]:
load_pickle("/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/test_experiment/coxph/models/phecode_841_Retina_0.p")

<lifelines.CoxPHFitter: fitted with 47954 total observations, 46466 right-censored observations>

In [22]:
# NOTE(review): this cell carries execution count 22 although it sits at the end
# of the notebook, and it indexes a model key that is not in the `models` list
# displayed above - presumably left over from an earlier session; confirm or remove.
data_partition['Identity(Records)+MLP']['phecode_977']

eid
1303905    0.507814
1303918    0.033095
1303920   -0.946619
1303937    0.042184
1303943   -0.270619
             ...   
5316968   -0.614777
5316970    0.009382
5316985   -0.361461
5316994   -0.194883
5317002   -0.581893
Name: phecode_977, Length: 401263, dtype: float64