In [25]:
import numpy as np
import pandas as pd
import pathlib
from tqdm.auto import tqdm

import hydra
from omegaconf import DictConfig, OmegaConf

import torch
from torch_geometric import seed_everything

import ray

In [26]:
import socket

# Choose the storage root from the host name: machines whose name contains "sc"
# use the /sc-projects mount, every other machine uses the shared analysis mount.
# socket.gethostname() replaces the IPython-only `!hostname` shell capture, so the
# cell is plain Python and does not spawn a subprocess.
# `node` stays a one-element list so that `node[0]` keeps working.
node = [socket.gethostname()]
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else:
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Project-level directory layout below the storage root.
project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Everything produced for this experiment lives in <output_path>/<experiment>.
experiment = 220413
experiment_path = f"{output_path}/{experiment}"

# Create the directories up front; exist_ok makes re-running the cell harmless.
for directory in (figure_path, output_path, experiment_path):
    pathlib.Path(directory).mkdir(parents=True, exist_ok=True)

/sc-projects/sc-proj-ukb-cvd


In [27]:
import wandb

# Weights & Biases workspace that holds the training runs;
# set both values to your own entity and project.
entity, project = "cardiors", "recordgraphs"

# Handle on every run of that project, addressed as "<entity>/<project>".
api = wandb.Api()
runs = api.runs(f"{entity}/{project}")

In [28]:
# Flatten every W&B run into a plain dict so the runs can be tabulated later.
run_list = []
for run in tqdm(runs):
    if "best_checkpoint" in run.config.keys():
        # Keep the directory two levels above the checkpoint file.
        checkpoint_file = pathlib.Path(run.config["best_checkpoint"])
        checkpoint_root = str(checkpoint_file.parent.parent)
    else:
        # Runs that never logged a checkpoint get no path.
        checkpoint_root = None

    record = dict(
        id=run.path[-1],
        name=run.name,
        tags=run.tags,
        config=dict(run.config.items()),
        summary=run.summary._json_dict,
        path=checkpoint_root,
    )
    run_list.append(record)

  0%|          | 0/2287 [00:00<?, ?it/s]



In [32]:
# Tabulate the collected runs: one row per entry of run_list, one column per
# dict key (id, name, tags, config, summary, path).
runs_df = pd.DataFrame(run_list)

In [33]:
# Experiment tag and run-name fragment used to select the relevant W&B runs.
tag = "220413"
model = "identityagesex"

# Keep runs whose stringified tag list contains the tag ...
has_tag = runs_df.tags.astype(str).str.contains(tag)
# ... and whose checkpoint path is set ("path==path" is False for missing values).
runs_df = runs_df.loc[has_tag].query("path==path")

In [34]:
# Runs of the selected model with a checkpoint path, annotated with the data
# partition they were trained on and ordered by that partition.
attribution_metadata = (
    runs_df[runs_df.name.astype(str).str.contains(model)]
    .query("path==path")
    .assign(
        partition=lambda selected: [
            # The datamodule config is stored as a string and parsed with eval().
            # SECURITY NOTE: eval() executes whatever that string contains; this
            # is only acceptable while the W&B project holds trusted runs.
            eval(run_config["_content"]["datamodule"])["partition"]
            for run_config in selected.config.to_list()
        ]
    )
    .sort_values("partition")
    .reset_index(drop=True)
)
attribution_metadata

Unnamed: 0,id,name,tags,config,summary,path,partition
0,2x979oyf,220413identityagesex0,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'gradients/head.layers.5.weight': {'bins': [-...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,0
1,iv9mgbwk,220413identityagesex1,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_734-9 - Jaw pain_CIndex': 0.74...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,1
2,1rubejdr,220413identityagesex2,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_280 - Substance related disord...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,2
3,2trepsux,220413identityagesex3,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_525-1 - Celiac disease_CIndex'...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,3
4,37x6n9iw,220413identityagesex4,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_592-2 - Urethritis and urethra...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,4
5,rn6z39ky,220413identityagesex5,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_618-5 - Prolapse of vaginal va...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,5
6,26hfbnfl,220413identityagesex6,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_666-2 - Idiopathic urticaria_C...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,6
7,b24pbiyt,220413identityagesex7,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_337-2 - Inflammatory polyneuro...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,7
8,17ajmonc,220413identityagesex8,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_391-6 - Cholesteatoma of middl...,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,8
9,1c0o97vl,220413identityagesex9,"[220413, full_data, identity]",{'losses': ['<ehrgraphs.models.loss_wrapper.En...,{'valid/phecode_723-5 - Tendinitis_CIndex': 0....,/sc-projects/sc-proj-ukb-cvd/results/models/Re...,9


In [35]:
# Directory below the experiment folder from which the prediction files are read.
in_path = pathlib.Path(experiment_path) / "loghs"

In [36]:
# From here on `model` is the sub-directory name of the prediction files.
# This rebinds the variable that earlier held the W&B run-name fragment
# "identityagesex", so re-running the run-filtering cell after this one would
# filter on the wrong value.
model = "Identity(AgeSex+Records)+MLP"

In [37]:
# Partition ids 0..21 whose prediction files are looked up below.
partitions = list(range(22))

In [38]:
import pathlib

In [39]:
# Gather the test-set predictions of every partition, indexed by "eid".
dfs = []
for partition in tqdm(partitions):
    fp_in = f"{in_path}/{model}/{partition}"
    fp_test = f"{fp_in}/test.feather"
    if not pathlib.Path(fp_test).is_file():
        # Partitions without a test file are skipped silently.
        continue
    partition_predictions = pd.read_feather(fp_test)
    dfs.append(partition_predictions.set_index("eid"))

  0%|          | 0/22 [00:00<?, ?it/s]

In [40]:
# Stack the per-partition frames row-wise and order the rows by their index.
predictions_wide = (
    pd.concat(dfs, axis=0)
    .sort_index()
)

In [41]:
# Columns whose name contains one of these markers are treated as endpoints.
endpoint_markers = ("OMOP", "phecode")
endpoint_cols = [
    column
    for column in predictions_wide.columns
    if any(marker in column for marker in endpoint_markers)
]

In [42]:
from sklearn.preprocessing import StandardScaler

# Standardize every endpoint column before the dimensionality reduction.
scaler = StandardScaler()
endpoint_predictions = predictions_wide[endpoint_cols]
predictions_scaled = scaler.fit_transform(endpoint_predictions)

In [43]:
import umap

# UMAP is stochastic: without a fixed seed every run yields a different
# embedding, so the saved coordinates could not be regenerated.
# Trade-off: per the UMAP documentation a fixed random_state limits
# parallelism, so fitting is slower. Set this to None to get the original
# unseeded (faster, non-reproducible) behaviour back.
UMAP_RANDOM_STATE = 42

# Default UMAP settings; verbose=True prints progress during fitting.
reducer = umap.UMAP(random_state=UMAP_RANDOM_STATE, verbose=True)

In [44]:
# Fit the UMAP model on the standardized endpoint predictions. The call is the
# last expression of the cell, so the fitted estimator's repr is displayed.
reducer.fit(predictions_scaled)

UMAP( verbose=True)
Fri Apr 15 13:43:42 2022 Construct fuzzy simplicial set
Fri Apr 15 13:43:43 2022 Finding Nearest Neighbors
Fri Apr 15 13:43:43 2022 Building RP forest with 40 trees
Fri Apr 15 13:44:08 2022 NN descent for 19 iterations
	 1  /  19
	 2  /  19
	 3  /  19
	 4  /  19
	 5  /  19
	Stopping threshold met -- exiting after 5 iterations
Fri Apr 15 13:44:48 2022 Finished Nearest Neighbor Search
Fri Apr 15 13:44:51 2022 Construct embedding




Epochs completed:   0%|            0/200 [00:00]

Fri Apr 15 14:05:21 2022 Finished embedding


UMAP( verbose=True)

In [45]:
# Embed the same matrix the reducer was fitted on.
# NOTE(review): because the input is the training data, the fitted embedding
# (`reducer.embedding_`) may already be this result — confirm for the installed
# umap version before changing it.
embeddings = reducer.transform(predictions_scaled)

In [46]:
# Attach the embedding coordinates to the row labels of predictions_wide.
embeddings_df = pd.DataFrame(
    data=embeddings,
    index=predictions_wide.index.values,
)
embeddings_df = embeddings_df.sort_index()
embeddings_df

Unnamed: 0,0,1
1000018,2.662032,-6.283121
1000020,-11.396855,5.266734
1000037,3.566753,-5.027369
1000043,6.165914,5.368650
1000051,4.097596,-1.529928
...,...,...
6025150,7.383742,-12.067386
6025165,0.110340,-4.913511
6025173,11.252336,4.008609
6025182,9.481647,6.293105


In [47]:
# Name the two embedding columns (they are labelled 0 and 1 after construction).
# This relabels embeddings_df in place.
embeddings_df.columns = ["umap_0", "umap_1"]

In [48]:
# Persist the embedding. The index is named "eid" and moved into a regular
# column before the frame is written to the feather file.
embeddings_out = (
    embeddings_df
    .sort_index()
    .rename_axis("eid")
    .reset_index()
)
embeddings_out.to_feather(f"{experiment_path}/logh_umap_agesex.feather")