# Benchmarks

## Initialize

In [2]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.feather as feather
from tqdm.auto import tqdm
from IPython.display import clear_output

import warnings
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [4]:
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level locations, all derived from base_path.
project_label = "22_retina_phewas_220603_fullrun"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Make sure both target directories exist before anything writes to them.
for required_dir in (figure_path, output_path):
    pathlib.Path(required_dir).mkdir(parents=True, exist_ok=True)

##### BEGIN ADAPT #####
# second best model
# wandb_name = 'aug++_convnext_s_mlp'
# wandb_id = '8ngm6apd'
# best model
# wandb_name = 'aug++_convnext_s_mlp+'
# wandb_id = '3p3smraz'
# transformer model
#wandb_name = '...'
#wandb_id = '2af9tvdp'
##### END   ADAPT #####

# Everything this notebook reads and writes lives under <output_path>/<experiment>.
experiment = '220603_fullrun'
experiment_path = "/".join((output_path, experiment))
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

# Prediction-folder suffix -> display label, one entry per crop ratio.
name_dict = dict([
    ("predictions_cropratio0.3", "ConvNextSmall(Retina)+MLP_cropratio0.3"),
    ("predictions_cropratio0.5", "ConvNextSmall(Retina)+MLP_cropratio0.5"),
    ("predictions_cropratio0.8", "ConvNextSmall(Retina)+MLP_cropratio0.8"),
])

# Partition indices 0..21 (displayed as the cell output).
partitions = list(range(22))
partitions

/sc-projects/sc-proj-ukb-cvd


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]

In [6]:
import ray
# Alternative kept for reference: restart Ray locally with a fixed CPU budget.
#ray.shutdown()
#ray.init(num_cpus=24)
# Attach to an already-running Ray cluster rather than starting a new one.
# NOTE(review): this cell fails on a fresh machine unless a cluster was started beforehand.
ray.init(address='auto')

RayContext(dashboard_url='', python_version='3.9.7', ray_version='1.12.1', ray_commit='4863e33856b54ccf8add5cbe75e41558850a1b75', address_info={'node_ip_address': '10.32.105.2', 'raylet_ip_address': '10.32.105.2', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-06-23_11-20-47_129565_808944/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-06-23_11-20-47_129565_808944/sockets/raylet', 'webui_url': '', 'session_dir': '/tmp/ray/session_2022-06-23_11-20-47_129565_808944', 'metrics_export_port': 48875, 'gcs_address': '10.32.105.2:6379', 'address': '10.32.105.2:6379', 'node_id': 'ce155bbd3fd7e6cf07052b035fb022bb2128648e78bb68e1b26166f2'})

In [7]:
import pandas as pd
# Endpoint names from the endpoint table, with any '_prevalent' substring removed, sorted.
# NOTE(review): this path is absolute and does not go through base_path — confirm it
# is reachable from every host this notebook runs on.
endpoints = sorted(
    label.replace('_prevalent', '')
    for label in pd.read_csv('/sc-projects/sc-proj-ukb-cvd/results/projects/22_retinal_risk/data/220602/endpoints.csv').endpoint.values
)

In [8]:
# Covariate columns that get merged onto every prediction split.
covariates = [
    "age_at_recruitment_f21022_0_0",
    "sex_f31_0_0",
    "ethnic_background_f21000_0_0",
]

In [9]:
# Covariate table indexed by eid, limited to the `covariates` columns,
# with the recruitment age stored as int32.
data_covariates = (
    pd.read_feather(f"{experiment_path}/data_covariates.feather")
    .set_index("eid")
    .loc[:, covariates]
    .astype({"age_at_recruitment_f21022_0_0": np.int32})
)

In [10]:
# Place the covariate frame in Ray's object store once; the returned reference
# is what should be handed to remote tasks.
data_covariates_ray = ray.put(data_covariates)

In [11]:
# Columns to standardize: recruitment age followed by one column per endpoint.
variables_to_norm = ["age_at_recruitment_f21022_0_0", *endpoints]
len(variables_to_norm)

1172

In [12]:
# Input: one sub-directory per model under <experiment_path>/loghs (kept as a Path object).
in_path = pathlib.Path(experiment_path, "loghs")
# Output: normalized CoxPH inputs (kept as a plain string, as the rest of the notebook expects).
out_path = f"{experiment_path}/coxph/input"

for needed_dir in (in_path, pathlib.Path(out_path)):
    needed_dir.mkdir(parents=True, exist_ok=True)

In [13]:
# Every sub-directory of in_path is treated as a model, except Jupyter checkpoint folders.
models = [
    entry.name
    for entry in in_path.iterdir()
    if entry.is_dir()
    if "ipynb_checkpoints" not in str(entry)
]
print(models)

['ImageTraining_[]_ConvNeXt_MLPHead_predictions_cropratio0.3', 'ImageTraining_[]_ConvNeXt_MLPHead_predictions_cropratio0.8', 'ImageTraining_[]_ConvNeXt_MLPHead_predictions_cropratio0.5']


In [14]:
from sklearn.preprocessing import StandardScaler
import pickle
import zstandard
import glob, os
    
def find_retina_eid_intersection(
    img_root='/sc-projects/sc-proj-ukb-cvd/data/retina/preprocessed/preprocessed',
    img_visit=0,
    img_file_extension='.png',
):
    """Summarize event counts per endpoint among participants who have a retina image.

    Image files are listed in ``img_root``; the participant id is taken as the
    integer before the first underscore of the file name, and only files whose
    name contains ``_{img_visit}_`` are kept. For each endpoint the eligible ids
    (``eids_dict[endpoint]``) are intersected with those image ids and the
    ``{endpoint}_event`` column of ``data_outcomes`` is summed and averaged.

    NOTE(review): ``endpoints``, ``eids_dict`` and ``data_outcomes`` are read from
    the notebook namespace; ``eids_dict`` and ``data_outcomes`` are not defined in
    the visible part of this notebook, so define them before calling.

    Parameters
    ----------
    img_root : str
        Directory that is searched (non-recursively) for image files.
    img_visit : int
        Visit number that must appear as ``_{img_visit}_`` in the file name.
    img_file_extension : str or None
        File suffix to match; ``None`` matches every file in ``img_root``.

    Returns
    -------
    pandas.DataFrame
        Columns ``endpoint``, ``eligable`` (sic, kept for compatibility), ``n``
        and ``freq``; only rows with ``n > 100``, sorted by endpoint.
    """
    pattern = f'*{img_file_extension}' if img_file_extension is not None else '*'
    visit_tag = f'_{img_visit}_'

    # Match the visit tag against the file name only, so that an underscore
    # pattern in a parent directory cannot make every file look like a match.
    file_names = (os.path.basename(fp) for fp in sorted(glob.glob(os.path.join(img_root, pattern))))
    eids_with_retinapic = [int(name.split('_')[0]) for name in file_names if visit_tag in name]

    rows = []
    for endpoint in tqdm(endpoints):
        # Compute the intersection once and reuse it for both selection and count.
        eligible_eids = np.intersect1d(eids_dict[endpoint], eids_with_retinapic)
        events = data_outcomes[f"{endpoint}_event"].loc[eligible_eids]
        rows.append({
            "endpoint": endpoint,
            "eligable": len(eligible_eids),
            "n": events.sum(),
            "freq": events.mean(),
        })

    # Explicit columns keep the query below valid even when `endpoints` is empty.
    endpoints_freqs = pd.DataFrame(rows, columns=["endpoint", "eligable", "n", "freq"])
    endpoints_ds = endpoints_freqs.query("n>100").sort_values("endpoint").reset_index(drop=True)

    return endpoints_ds # TODO
    

def read_merge_data(fp_in, split, data_covariates):
    """Load ``{fp_in}/{split}.feather``, index it by ``eid`` and left-join the covariates.

    Rows of the split are kept as-is; covariate columns are attached by index
    and are missing for ids that do not appear in ``data_covariates``.
    """
    split_frame = pd.read_feather(f"{fp_in}/{split}.feather").set_index("eid")
    return pd.merge(
        split_frame,
        data_covariates,
        how="left",
        left_index=True,
        right_index=True,
    )
    
def save_pickle(data, data_path):
    """Pickle ``data`` with the highest protocol and write it zstd-compressed to ``data_path``."""
    with open(data_path, "wb") as sink, zstandard.ZstdCompressor().stream_writer(sink) as zstd_out:
        zstd_out.write(pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL))
    
@ray.remote
def norm_variables(data_covariates, model, partition, variables):
    """Ray task: standardize ``variables`` for one (model, partition) and write the result.

    Reads the train/valid/test feather files from ``{in_path}/{model}/{partition}``,
    merges the covariates onto each, fits a ``StandardScaler`` on the train split
    only, applies it to every split, and writes the scaled splits plus the pickled
    scaler to ``{out_path}/{model}/{partition}``.

    The partition is skipped when the input directory is missing or when every
    expected output file already exists. Checking the files rather than the
    directory means a run that was interrupted after the directory was created
    is redone instead of being silently treated as finished.

    Parameters
    ----------
    data_covariates : pandas.DataFrame
        Covariates indexed by eid (Ray resolves an object reference passed here).
    model : str
        Model sub-directory name under ``in_path``.
    partition : int
        Partition sub-directory name under the model directory.
    variables : list of str
        Columns to standardize.

    Returns
    -------
    bool
        Always ``True``; used only as a completion signal.
    """
    splits = ["train", "valid", 'test'] # "test_left", 'test_right'

    fp_in = f"{in_path}/{model}/{partition}"
    fp_out = f"{out_path}/{model}/{partition}"

    if not pathlib.Path(fp_in).is_dir():
        return True

    expected_outputs = [f"{fp_out}/scaler.p"] + [f"{fp_out}/{split}.feather" for split in splits]
    if all(pathlib.Path(fp).is_file() for fp in expected_outputs):
        return True

    pathlib.Path(fp_out).mkdir(parents=True, exist_ok=True)
    for split in splits:
        temp = read_merge_data(fp_in, split, data_covariates)
        if split=="train":
            # Fit on train only so valid/test are scaled with train statistics.
            scaler = StandardScaler(with_mean=True, with_std=True, copy=True).fit(temp[variables].values)
            save_pickle(scaler, f"{fp_out}/scaler.p")
        temp[variables] = scaler.transform(temp[variables].values)
        temp.reset_index(drop=False).to_feather(f"{fp_out}/{split}.feather")
    return True

In [15]:
def norm_logh_and_extra(data_covariates_ray, variables):
    """Submit one ``norm_variables`` task per (model, partition) and wait for all of them.

    Parameters
    ----------
    data_covariates_ray
        Covariates to hand to every task. Pass the reference returned by
        ``ray.put`` so the frame is stored once and shared, instead of being
        serialized again for each submitted task.
    variables : list of str
        Columns to standardize in every split.

    Returns
    -------
    None
    """
    # Forward the argument that was passed in; the module-level DataFrame is
    # deliberately not used here.
    progress = [
        norm_variables.remote(data_covariates_ray, model, partition, variables)
        for model in tqdm(models)
        for partition in partitions
    ]
    # Block on each task in submission order so the progress bar advances as they finish.
    for future in tqdm(progress):
        ray.get(future)

In [16]:
# Run the normalization for every model/partition, sharing the covariates via the Ray object reference.
norm_logh_and_extra(data_covariates_ray, variables_to_norm)

  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/66 [00:00<?, ?it/s]

[2m[1m[36m(scheduler +1m0s)[0m Tip: use `ray status` to view detailed cluster status. To disable these messages, set RAY_SCHEDULER_EVENTS=0.


In [17]:
# NOTE(review): scratch cell with no effect on the analysis — candidate for deletion.
1+1

2