In [9]:
import os
import pathlib
import socket

import numpy as np
import pandas as pd
from tqdm.auto import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
#from torch_geometric import seed_everything

import ray

In [10]:
# Resolve the storage root from the host name: hosts whose name contains "sc" use the
# /sc-projects mount, every other host uses the /data/analysis share.
# socket.gethostname() is used instead of the IPython-only `!hostname` so this cell is
# plain Python and does not shell out.
hostname = socket.gethostname()
base_path = (
    "/sc-projects/sc-proj-ukb-cvd"
    if "sc" in hostname
    else "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
)
print(base_path)

# Project-level locations for figures and derived data.
project_label = "22_retina_phewas"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Experiment tag; this notebook writes its outputs below {output_path}/{experiment}.
experiment = '230905'
experiment_path = f"{output_path}/{experiment}"
print('experiment path:', experiment_path)

for directory in (figure_path, output_path, experiment_path):
    pathlib.Path(directory).mkdir(parents=True, exist_ok=True)

# Prediction-file stem (<stem>.feather inside each run's predictions directory) -> label.
name_dict = {
#     "predictions_cropratio0.3": "ConvNextSmall(Retina)+MLP_cropratio0.3",
#     "predictions_cropratio0.5": "ConvNextSmall(Retina)+MLP_cropratio0.5",
#    "predictions_cropratio0.66": "ConvNextSmall(Retina)+MLP_cropratio0.66",
    "predictions": "ConvNextSmall(Retina)+MLP_cropratio0.66",
}

# Partitions with eye test centers (all partitions would be list(range(22))).
partitions = [4, 5, 7, 9, 10, 20]

/sc-projects/sc-proj-ukb-cvd
experiment path: /sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905


In [11]:
# Endpoint list taken from the 22_retinal_risk project's min100_endpoints.csv; the
# "_prevalent" marker is removed from every name. The root uses base_path (not a
# hard-coded /sc-projects prefix) so the cell works on either host.
endpoints_csv = f"{base_path}/results/projects/22_retinal_risk/data/230905/min100_endpoints.csv"
endpoint_columns = sorted(
    name.replace('_prevalent', '') for name in pd.read_csv(endpoints_csv)["endpoint"]
)
len(endpoint_columns)

773

In [12]:
#ray.shutdown()
#ray.init(num_cpus=24)
# ray.init(address='auto')

In [13]:
import wandb

# Fetch every run of the cardiors/retina W&B project carrying the experiment tag.
entity, project = "cardiors", "retina"
tag = '230905'
api = wandb.Api()
runs = api.runs(f"{entity}/{project}", filters={"tags": {"$in": [tag]}})

In [14]:
def _run_to_record(run):
    """Flatten one W&B run into the plain dict used to build ``runs_df``.

    Keys: id, name, tags, partition (read from the run's datamodule config), config
    (without underscore-prefixed keys such as ``_content``), summary, and path. ``path``
    is the run's ``predictions_path`` config entry when present; otherwise it is
    ``{base_path}/results/models/retina/<run id>/predictions/``.
    """
    config = run.config
    # NOTE(review): eval() executes whatever string is stored in the run config. If the
    # datamodule config only ever holds Python literals, ast.literal_eval is the safe choice.
    datamodule_cfg = eval(config['_content']['datamodule'])
    if "predictions_path" in config:
        predictions_path = str(pathlib.Path(config["predictions_path"]))
    else:
        # base_path keeps this consistent with the host-dependent root chosen above.
        predictions_path = f'{base_path}/results/models/retina/{run.id}/predictions/'
    return {
        "id": run.id,
        "name": run.name,
        "tags": run.tags,
        "partition": datamodule_cfg['partition'],
        "config": {k: v for k, v in config.items() if not k.startswith('_')},
        "summary": run.summary._json_dict,
        "path": predictions_path,
    }


run_list = [_run_to_record(run) for run in tqdm(runs)]

  0%|          | 0/6 [00:00<?, ?it/s]

In [15]:
# Keep only the runs trained on one of the selected partitions (those with eye test centers).
all_runs_df = pd.DataFrame(run_list)
runs_df = all_runs_df.loc[all_runs_df["partition"].isin(partitions)]

In [16]:
# Selected runs: one row per W&B run with id, name, tags, partition, config, summary and predictions path.
runs_df

Unnamed: 0,id,name,tags,partition,config,summary,path
0,20innqco,230905_fullrun_retina,"[230905, baseline_data, image]",20,{'losses': ['<retinalrisk.models.loss_wrapper....,{'valid/phecode_520-2 - Diaphragmatic hernia [...,/sc-projects/sc-proj-ukb-cvd/results/models/re...
1,3egdzwli,230905_fullrun_retina,"[230905, baseline_data, image]",10,{'losses': ['<retinalrisk.models.loss_wrapper....,{'valid/phecode_395 - Other diseases of inner ...,/sc-projects/sc-proj-ukb-cvd/results/models/re...
2,3u8acsmy,230905_fullrun_retina,"[230905, baseline_data, image]",9,{'losses': ['<retinalrisk.models.loss_wrapper....,{'valid/phecode_800 - Chest pain_CIndex': 0.56...,/sc-projects/sc-proj-ukb-cvd/results/models/re...
3,32rned74,230905_fullrun_retina,"[230905, baseline_data, image]",7,{'losses': ['<retinalrisk.models.loss_wrapper....,{'valid/phecode_430-2 - Nontraumatic intracere...,/sc-projects/sc-proj-ukb-cvd/results/models/re...
4,2owxdpk9,230905_fullrun_retina,"[230905, baseline_data, image]",5,{'losses': ['<retinalrisk.models.loss_wrapper....,{'valid/phecode_678 - Other skin and connectiv...,/sc-projects/sc-proj-ukb-cvd/results/models/re...
5,1dra1ycg,230905_fullrun_retina,"[230905, baseline_data, image]",4,{'losses': ['<retinalrisk.models.loss_wrapper....,{'gradients/encoder.features.5.8.block.2.weigh...,/sc-projects/sc-proj-ukb-cvd/results/models/re...


In [17]:
# Spot-check the predictions directory resolved for the first selected run.
first_run = runs_df.iloc[0]
print(first_run["path"])

/sc-projects/sc-proj-ukb-cvd/results/models/retina/20innqco/predictions/


## Process Predictions

In [18]:
# Identifier columns of the processed prediction tables; prepare_predictions casts each to "category".
id_vars = ["eid", "model", "partition", "split"]

In [19]:
# Root directory for the per-model / per-partition / per-split prediction files.
out_path = f"{experiment_path}/loghs"
os.makedirs(out_path, exist_ok=True)

In [20]:
# Show the output root the processed prediction files are written below.
out_path

'/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs'

In [21]:
def _collapse_duplicate_eids(preds):
    """Collapse rows sharing an ``eid`` index value into one row per eid.

    Numeric columns are averaged over the duplicated rows (e.g. the left- and right-eye
    predictions of the same participant); every other column keeps the value of the
    first row seen for that eid. The result is indexed by eid in groupby order.
    """
    # numeric_only=True is explicit on purpose: pandas >= 2.0 no longer silently drops
    # non-numeric columns in groupby().mean() and would raise a TypeError on the string
    # metadata columns (module, encoder, head, split, ...).
    averaged = preds.groupby(level=0).mean(numeric_only=True)
    # Recover only columns that were not averaged; re-adding averaged names would yield
    # duplicated "<col>_x" / "<col>_y" columns after the join.
    other_cols = [c for c in preds.columns if c not in averaged.columns]
    first_rows = preds.loc[~preds.index.duplicated(), other_cols]
    return averaged.join(first_rows, how='left')


# @ray.remote
def prepare_predictions(in_path, out_path, skip_existing=False):
    """Split a run's prediction files into per-(model, partition, split) feather files.

    For every key ``cr`` of the global ``name_dict``, ``{in_path}/{cr}.feather`` is read,
    rows sharing an eid are collapsed into one row, helper/metadata columns are
    dropped, ``id_vars`` columns are cast to category, and the rows of each split are
    written to ``{out_path}/{model}_{cr}/{partition}/{split}.feather``.

    Parameters
    ----------
    in_path : str
        Directory holding the ``<cr>.feather`` prediction files of one run.
    out_path : str
        Root output directory.
    skip_existing : bool, default False
        If True, do not rewrite outputs for a file whose train/valid/test feather files
        all exist already (the input is still read to determine model and partition).

    Raises
    ------
    ValueError
        If a prediction file does not contain exactly one model and one partition.
    """
    splits = ("train", "valid", "test")  # "test_left" / "test_right" are not exported
    for cr in name_dict:
        preds = (
            pd.read_feather(os.path.join(in_path, f'{cr}.feather'))
            .rename(columns={"index": "eid"})
            .set_index('eid')
        )
        preds = _collapse_duplicate_eids(preds).reset_index(drop=False)

        # Drop columns containing "ft" or "St"; strip the "1_10_Ft__" prefix from the
        # remaining columns containing "Ft".
        # NOTE(review): these are plain substring tests, so any column whose name merely
        # contains "ft"/"St" (e.g. "left...") is dropped as well — confirm the pattern.
        cols_to_drop = []
        cols_to_rename = {}
        for col in preds.columns.values:
            if 'ft' in col or 'St' in col:
                cols_to_drop.append(col)
            elif 'Ft' in col:
                cols_to_rename[col] = col.replace('1_10_Ft__', '')

        # Model label = <module>_<covariate_cols>_<encoder>_<head>.
        preds["model"] = (
            preds["module"].astype(str) + "_" + preds["covariate_cols"].astype(str)
            + "_" + preds["encoder"].astype(str) + "_" + preds["head"].astype(str)
        ).astype("category")
        # NOTE(review): name_dict is keyed by prediction-file stem, not by this model
        # label, so this replace does not match (outputs land under e.g.
        # "ImageTraining_[]_ConvNeXt_MLPHead_predictions"). Kept to preserve output paths.
        preds = (
            preds.replace({"model": name_dict})
            .drop(columns=["module", "encoder", "head", "covariate_cols"])
            .drop(columns=["record_cols"], errors="ignore")  # drop only if present
            .drop(columns=cols_to_drop)
            .rename(columns=cols_to_rename)
            .astype({c: "category" for c in id_vars})
        )

        # Every file must describe a single model and partition; otherwise the output
        # directory would silently be named after whichever value happened to come first.
        models = preds["model"].unique()
        partitions_found = preds["partition"].unique()
        if len(models) != 1 or len(partitions_found) != 1:
            raise ValueError(
                f"{os.path.join(in_path, f'{cr}.feather')}: expected exactly one model and "
                f"one partition, got models={list(models)}, partitions={list(partitions_found)}"
            )
        model = f'{models[0]}_{cr}'
        partition = partitions_found[0]

        fp_out = f"{out_path}/{model}/{partition}"
        targets = {split: f"{fp_out}/{split}.feather" for split in splits}
        if skip_existing and all(os.path.exists(target) for target in targets.values()):
            print(f'skipping {fp_out} as already exists')
            continue

        pathlib.Path(fp_out).mkdir(parents=True, exist_ok=True)
        for split, target in targets.items():
            preds[preds["split"] == split].reset_index(drop=True).to_feather(target)
            print(target)

In [22]:
# Process the prediction files of every selected run sequentially.
# Only the predictions directory is needed; the previous loop also assigned an unused
# `id` variable that shadowed the builtin id() for the rest of the notebook.
for predictions_dir in tqdm(runs_df['path'], desc="runs"):
    prepare_predictions(predictions_dir, out_path)

  0%|          | 0/6 [00:00<?, ?it/s]

/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/20/train.feather
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/20/valid.feather
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/20/test.feather
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/10/train.feather
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/10/valid.feather
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/10/test.feather
/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs/ImageTraining_[]_ConvNeXt_MLPHead_predictions/9/train.fea

In [23]:
# Output root again, for the directory listing below.
out_path

'/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas/data/230905/loghs'

In [24]:
# Sanity check: one sub-directory per processed partition for this (hard-coded) model folder name.
!ls -lah {out_path}/'ImageTraining_[]_ConvNeXt_MLPHead_predictions'

total 256K
drwxrwx--- 8 loockl posix-nogroup 116 Sep 13 15:56 .
drwxrwx--- 3 loockl posix-nogroup  63 Sep 13 15:50 ..
drwxrwx--- 2 loockl posix-nogroup  92 Sep 13 15:51 10
drwxrwx--- 2 loockl posix-nogroup  92 Sep 13 15:50 20
drwxrwx--- 2 loockl posix-nogroup  92 Sep 13 15:56 4
drwxrwx--- 2 loockl posix-nogroup  92 Sep 13 15:54 5
drwxrwx--- 2 loockl posix-nogroup  92 Sep 13 15:53 7
drwxrwx--- 2 loockl posix-nogroup  92 Sep 13 15:52 9
