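# Benchmarks

For every frequent OMOP record (present in at least 50 participants), compare the event frequency of the selected endpoints between eligible participants who have the record and those who do not.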

## Initialize

In [1]:
import os
import math
import pathlib
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from IPython.display import clear_output

import warnings
from lifelines.utils import CensoringType
from lifelines.utils import concordance_index

In [2]:
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

project_label = "22_retina_phewas_220603_fullrun"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

experiment = '220603_fullrun'
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

name_dict = {
    "predictions_cropratio0.3": "ConvNextSmall(Retina)+MLP_cropratio0.3",
    "predictions_cropratio0.5": "ConvNextSmall(Retina)+MLP_cropratio0.5",
    "predictions_cropratio0.8": "ConvNextSmall(Retina)+MLP_cropratio0.8",
}

partitions = [i for i in range(22)]
partitions

/sc-projects/sc-proj-ukb-cvd


[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21]

In [3]:
# Endpoint definition table (phecode definitions, per the file name), ordered by endpoint.
endpoint_defs = (
    pd.read_feather(f"{output_path}/phecode_defs_220306.feather")
    .sort_values("endpoint")
)

In [4]:
# Endpoints to evaluate. Only OMOP_4306655 is active in this run; the phecodes
# below are kept commented out for other runs. To evaluate every endpoint in the
# outcome table instead, derive the list from its columns (this needs
# `data_outcomes`, which is only loaded in the following cell):
#     endpoints = [e[:-6] for e in data_outcomes.columns if "_event" in e]
endpoints = [
    "OMOP_4306655",
#    "phecode_008",
#    "phecode_092-2",
#    "phecode_105",
#    "phecode_107-2",
#    "phecode_164",
#    "phecode_202-2",
#    "phecode_284",
#    "phecode_292",
#    "phecode_324-11",
#    "phecode_328",
#    "phecode_371",
#    "phecode_401",
#    "phecode_404",
#    "phecode_424",
#    "phecode_440-11",
#    "phecode_468",
#    "phecode_474",
#    "phecode_522-1",
#    "phecode_542-1",
#    "phecode_581-1",
#    "phecode_583",
#    "phecode_665",
#    "phecode_705-1",
]

In [5]:
# Outcome table indexed by eid, reduced to the "<endpoint>_event" columns of the
# selected endpoints. "_event" is matched as a *suffix* so that the c[:-6] slice
# strips exactly that suffix (a substring test could match "_event" elsewhere in
# the name and then compare a wrongly truncated name).
data_outcomes = pd.read_feather(f"{output_path}/baseline_outcomes_220627.feather").set_index("eid")
data_outcomes = data_outcomes[
    [c for c in data_outcomes.columns if c.endswith("_event") and c[:-6] in endpoints]
]

In [6]:
# Baseline record table, indexed by participant eid.
records_file = f"{output_path}/baseline_records_220627.feather"
data_records = pd.read_feather(records_file).set_index("eid")

In [7]:
# Keep only columns whose name contains "OMOP_" (tqdm shows progress over the column scan).
omop_columns = [col for col in tqdm(data_records.columns.to_list()) if "OMOP_" in col]
data_records = data_records[omop_columns]

  0%|          | 0/73871 [00:00<?, ?it/s]

In [8]:
# Names of the record columns that survived the OMOP filter.
records = list(data_records.columns)

In [9]:
# Left join on the eid index: every eid in data_records is kept, and the
# endpoint "_event" columns from data_outcomes are attached where available.
data_all = pd.merge(data_records, data_outcomes, how="left", left_index=True, right_index=True)

In [10]:
# Eligible participants per endpoint; the file name keeps the original spelling.
eligable_eids = pd.read_feather(f"{output_path}/eligable_eids_2022-07-01.feather")
# endpoint -> value of its "eid_list" column (later indexes would win on duplicate endpoints)
eids_dict = dict(zip(eligable_eids["endpoint"], eligable_eids["eid_list"]))

In [11]:
# Column sums per record, in descending order, restricted to records whose sum
# is at least 50 (sort first, then filter, so the order of ties is preserved).
record_freqs = data_records.sum().sort_values(ascending=False).loc[lambda counts: counts >= 50]
record_freqs

OMOP_4081598    307739
OMOP_4052351    270116
OMOP_4061103    263319
OMOP_4144272    247882
OMOP_4057411    221203
                 ...  
OMOP_4039277        50
OMOP_4116240        50
OMOP_4050692        50
OMOP_4209141        50
OMOP_4171619        50
Length: 15595, dtype: int64

In [12]:
import ray

# Start a local Ray runtime limited to 16 CPUs, without the dashboard.
# To expose the dashboard instead, pass include_dashboard=True together with
# dashboard_host="0.0.0.0" and dashboard_port=24763.
ray.init(include_dashboard=False, num_cpus=16)

RayContext(dashboard_url=None, python_version='3.9.7', ray_version='1.12.1', ray_commit='4863e33856b54ccf8add5cbe75e41558850a1b75', address_info={'node_ip_address': '10.32.105.8', 'raylet_ip_address': '10.32.105.8', 'redis_address': None, 'object_store_address': '/tmp/ray/session_2022-07-06_11-23-54_669161_2112768/sockets/plasma_store', 'raylet_socket_name': '/tmp/ray/session_2022-07-06_11-23-54_669161_2112768/sockets/raylet', 'webui_url': None, 'session_dir': '/tmp/ray/session_2022-07-06_11-23-54_669161_2112768', 'metrics_export_port': 63549, 'gcs_address': '10.32.105.8:59865', 'address': '10.32.105.8:59865', 'node_id': '17fc500e6c07bf523452eb9424446cf711ecb9f8da4ac84f265d1ab9'})

In [13]:
@ray.remote
def calc_ratio(data_all, eids_dict, record, eids_record, eids_nonrecord, endpoints):
    """Compare endpoint event counts/frequencies between eids with and without one record.

    Runs as a Ray task. For every endpoint, the endpoint's eligible eids are split
    into those contained in ``eids_record`` and those contained in
    ``eids_nonrecord``. For each group, the number of events (sum of the
    ``"{endpoint}_event"`` column of ``data_all``) and the event frequency
    (events / group size) are computed.

    Args:
        data_all: DataFrame indexed by eid with ``"{endpoint}_event"`` columns.
        eids_dict: mapping endpoint -> eligible eids for that endpoint.
        record: name of the record the split is based on (copied into the output).
        eids_record: eids that have the record.
        eids_nonrecord: eids that do not have the record.
        endpoints: endpoint names to evaluate.

    Returns:
        list of dicts, one per endpoint, with keys ``endpoint``, ``n_eligable``,
        ``record``, ``n_records``, ``n_events_record``, ``freq_events_record``,
        ``n_events_nonrecord`` and ``freq_events_nonrecord``. A frequency is NaN
        when its group contains no eligible eids.
    """
    r_ds = []

    for endpoint in endpoints:
        # asarray so boolean-mask indexing also works if the stored eids are a plain list.
        eids_endpoint = np.asarray(eids_dict[endpoint])
        events = data_all[f"{endpoint}_event"]

        stats = {}
        for key, eids_group in (("record", eids_record), ("nonrecord", eids_nonrecord)):
            # Eligible eids that fall into this group; assume_unique relies on the
            # eid arrays containing no duplicates (as the original code did).
            in_group = np.isin(eids_endpoint, eids_group, assume_unique=True)
            s = events.loc[eids_endpoint[in_group]]
            n = s.sum()
            # Empty group: report NaN instead of dividing by zero.
            freq = n / len(s) if len(s) > 0 else np.nan
            stats[key] = (n, freq)

        n_record, freq_record = stats["record"]
        n_nonrecord, freq_nonrecord = stats["nonrecord"]

        r_ds.append({"endpoint": endpoint, "n_eligable": len(eids_endpoint),
                     "record": record, "n_records": len(eids_record),
                     "n_events_record": n_record, "freq_events_record": freq_record,
                     "n_events_nonrecord": n_nonrecord, "freq_events_nonrecord": freq_nonrecord})
    return r_ds

In [14]:
# Put the large shared inputs into the Ray object store once; every task then
# receives the same references instead of its own serialized copy.
ref_data_all = ray.put(data_all)
ref_eids_dict = ray.put(eids_dict)

# Submit one calc_ratio task per frequent record, passing the eids whose value
# for that record is True and those whose value is False.
d_nested = []
for record in tqdm(record_freqs.index):
    s_record = data_all[record]
    eids_record = s_record.index.values[s_record.eq(True).to_numpy()]
    eids_nonrecord = s_record.index.values[s_record.eq(False).to_numpy()]
    d_nested.append(
        calc_ratio.remote(ref_data_all, ref_eids_dict, record, eids_record, eids_nonrecord, endpoints)
    )

# Collect task results in submission order.
d_nested = [ray.get(ref) for ref in tqdm(d_nested)]

# Drop the driver's references to the shared objects.
del ref_data_all, ref_eids_dict

  0%|          | 0/15595 [00:00<?, ?it/s]

  0%|          | 0/15595 [00:00<?, ?it/s]

In [15]:
from itertools import chain

# Flatten the per-record lists of per-endpoint dicts into one list of rows.
d = list(chain.from_iterable(d_nested))

In [16]:
# One row per result dict (record x endpoint); columns come from the dict keys.
# Built directly with the constructor instead of calling the from_dict classmethod
# through a throwaway empty DataFrame instance.
endpoints_freqs = pd.DataFrame(d)

In [17]:
# Persist the frequency table in the experiment folder.
freq_table_path = f"{experiment_path}/records_inc_disease_freq.feather"
endpoints_freqs.to_feather(freq_table_path)

In [18]:
# Display the frequency table: one row per (record, endpoint) pair.
endpoints_freqs

Unnamed: 0,endpoint,n_eligable,record,n_records,n_events_record,freq_events_record,n_events_nonrecord,freq_events_nonrecord
0,OMOP_4306655,502453,OMOP_4081598,307739,12031,0.039095,25666,0.131810
1,OMOP_4306655,502453,OMOP_4052351,270116,10840,0.040132,26857,0.115592
2,OMOP_4306655,502453,OMOP_4061103,263319,7190,0.027306,30507,0.127570
3,OMOP_4306655,502453,OMOP_4144272,247882,7614,0.030717,30083,0.118168
4,OMOP_4306655,502453,OMOP_4057411,221203,12112,0.054756,25585,0.090967
...,...,...,...,...,...,...,...,...
15590,OMOP_4306655,502453,OMOP_4039277,50,1,0.020000,37696,0.075031
15591,OMOP_4306655,502453,OMOP_4116240,50,6,0.120000,37691,0.075021
15592,OMOP_4306655,502453,OMOP_4050692,50,8,0.160000,37689,0.075017
15593,OMOP_4306655,502453,OMOP_4209141,50,0,0.000000,37697,0.075033
