In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import hydra
import numpy as np
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
import pandas as pd
import pathlib

from tqdm.auto import tqdm

In [3]:
node = !hostname
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

project_label = "22_retina_phewas"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

experiment = 'test_experiment'
experiment_path = f"{output_path}/{experiment}"
pathlib.Path(experiment_path).mkdir(parents=True, exist_ok=True)

/sc-projects/sc-proj-ukb-cvd


In [4]:
import wandb
api = wandb.Api()
entity, project = "cardiors", "recordgraphs"  # set to your entity and project 
runs = api.runs(entity + "/" + project) 

In [5]:
run_list = []
for run in tqdm(runs): 
    run_list.append(
        {
            "id": run.path[-1], 
            "name": run.name,
            "tags": run.tags,
            "config": {k: v for k,v in run.config.items()},
            "summary": run.summary._json_dict,
            "path": None if "best_checkpoint" not in run.config.keys() else str(pathlib.Path(run.config["best_checkpoint"]).parent.parent)
        }
    )

  0%|          | 0/2519 [00:00<?, ?it/s]



ReadTimeout: HTTPSConnectionPool(host='api.wandb.ai', port=443): Read timed out. (read timeout=9)

In [None]:
runs_df = pd.DataFrame(run_list)

In [None]:
tag = "220413"
model = "identityagesex"
runs_df = runs_df[runs_df.tags.astype(str).str.contains(tag)].query("path==path")

In [None]:
attribution_metadata = runs_df[runs_df.name.astype(str).str.contains(model)].query("path==path")
attribution_metadata["partition"] = [eval(d["_content"]["datamodule"])["partition"] for d in attribution_metadata.config.to_list()]
attribution_metadata = attribution_metadata.sort_values("partition").reset_index(drop=True)
attribution_metadata

In [None]:
%env MKL_NUM_THREADS=4
%env NUMEXPR_NUM_THREADS=4
%env OMP_NUM_THREADS=4

In [None]:
endpoints_md = pd.read_csv(f"{experiment_path}/endpoints.csv")
endpoints = sorted(endpoints_md.endpoint.to_list())

In [None]:
endpoint_defs = pd.read_feather(f"{output_path}/phecode_defs_220306.feather").query("endpoint==@endpoints").sort_values("endpoint").set_index("endpoint")

In [None]:
endpoint_defs["label"] = endpoint_defs.index + " - " + endpoint_defs["phecode_string"].str.replace(".", "-")
endpoint_defs.at["OMOP_4306655", "label"] = "OMOP_4306655"
endpoint_defs.query("endpoint=='phecode_008'")

In [None]:
import ray

ray.init(num_cpus=24, include_dashboard=False)#dashboard_port=24763, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

In [14]:
import zstandard
import pickle

@ray.remote
def load_tensor(fp):
    with open(fp, "rb") as fh:
        dctx = zstandard.ZstdDecompressor()
        with dctx.stream_reader(fh) as decompressor:
            data = pickle.loads(decompressor.read())
    return torch.Tensor(data)

def calc_per_endpoint(fps, idxs):
    tensors = [load_tensor.remote(fp) for fp in fps]
    temp_shap = torch.cat([ray.get(t) for t in tensors], dim=0).index_select(0, idxs)
    shap_mask = temp_shap!=0
    shap_local = (temp_shap.sum(dim=0) / (shap_mask > 0).sum(dim=0)).numpy()
    return shap_local

In [49]:
ray.shutdown()

In [15]:
endpoint_dict = endpoint_defs["label"].to_dict()
endpoints = sorted(endpoint_dict)
endpoint_values = [endpoint_dict[key] for key in endpoints]

In [16]:
attribution_paths = attribution_metadata[["id", "path", "partition"]].sort_values("partition")

In [17]:
base_paths = [f"{row['path']}/attributions" for i, row in attribution_paths.iterrows()]

In [18]:
eids = pd.concat([pd.read_csv(f"{p}/eids.txt", header=None) for p in base_paths])[0].values
features = pd.read_csv(f"{base_paths[0]}/features.txt", header=None)[0].values

In [19]:
eligable_eids = pd.read_feather(f"{output_path}/eligable_eids_220414.feather")
eids_dict = eligable_eids.set_index("endpoint")["eid_list"].to_dict()

In [20]:
idx_dict = {key: torch.from_numpy(np.where(np.in1d(eids, eids_incl))[0]) for key, eids_incl in tqdm(list(eids_dict.items()))}

  0%|          | 0/1684 [00:00<?, ?it/s]

In [21]:
del eligable_eids
del eids_dict

In [None]:
from captum.attr import DeepLiftShap
import zstandard
import pickle
shap_dict = {}

#ray_test_data = ray.put(test_data)
for endpoint in tqdm(endpoints):  
    
    endpoint_label = endpoint_dict[endpoint]
    
    fps = [f"{path}/attributions/shap_{endpoint_label}_{partition}.p" 
           for path, partition in zip(attribution_paths.path.to_list(), attribution_paths.partition.to_list())]
    
    try:
        shap_dict[endpoint] = calc_per_endpoint(fps, idx_dict[endpoint])
    except FileNotFoundError:
        print(endpoint)

In [26]:
#shap_local = {key: value for key, value in tqdm(shap_dict.items())}
#shap_global = {key: value["global"] for key, value in tqdm(shap_dict.items())}

  0%|          | 0/1001 [00:00<?, ?it/s]

In [24]:
def get_shap_df(shap_dict, endpoints, features):
    shap_df = pd.DataFrame(data=shap_dict, 
                     index=features,
                     columns=endpoints).rename_axis("record")

    #shap_df["record_n"] = (test_data > 0).sum(dim=0).numpy()
    #shap_df["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

    shap_mapped_df = shap_df#.reset_index().set_index("record")#.set_index(["record", "concept_name", 
                   # "record_n", "record_freq"
                   #])
    return shap_mapped_df

In [25]:
shap_local_df = get_shap_df(shap_dict, endpoints, features)

In [26]:
shap_local_df.reset_index().to_feather(f"{experiment_path}/shap_local.feather")

In [27]:
shap_local_df.sort_values("OMOP_4306655", ascending=False)

Unnamed: 0_level_0,OMOP_4306655,phecode_001,phecode_002,phecode_002-1,phecode_003,phecode_004,phecode_004-1,phecode_005,phecode_005-1,phecode_006,...,phecode_950,phecode_954,phecode_976,phecode_979,phecode_979-2,phecode_981,phecode_983,phecode_988,phecode_989,phecode_997
record,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
OMOP_4058696,0.549468,0.029723,0.025715,0.549770,0.228404,0.082867,0.182187,0.165308,-0.012237,0.215066,...,-0.110415,-0.027592,0.372499,0.008523,0.122815,0.022191,0.243612,0.013182,-0.399509,0.184798
OMOP_200451,0.476312,0.356606,0.351725,0.248946,0.396733,0.393189,0.459947,0.456302,0.450454,0.522613,...,0.239419,0.305234,0.599032,0.262643,0.362190,0.426654,0.300184,0.311388,-0.076667,0.440549
OMOP_436043,0.396561,0.161382,0.150268,0.336060,0.299898,0.260065,0.193798,0.172257,0.180445,0.311304,...,,,0.349427,0.099276,0.205647,0.158167,0.310194,0.474510,-0.397817,0.230082
OMOP_44784106,0.393289,0.275448,0.270610,0.189553,0.291478,0.299761,0.313691,0.312247,0.271191,0.382349,...,-0.076444,-0.066753,0.334913,0.173953,0.263428,0.253763,0.349042,0.144425,-0.473518,0.380342
OMOP_4069332,0.384898,0.201218,0.198611,0.320965,0.239953,0.222439,0.317807,0.310292,0.259660,0.357495,...,0.165758,0.107004,0.446672,0.176212,0.136582,0.121407,0.261274,0.325633,-0.309345,0.353391
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
OMOP_948555,,,,,,,,,,,...,,,,,,,,,,
OMOP_960900,,,,,,,,,,,...,,,,,,,,,,
OMOP_987366,,,,,,,,,,,...,,,,,,,,,,
OMOP_991003,,,,,,,,,,,...,,,,,,,,,,


In [30]:
#shap_global_df.reset_index().to_csv(f"{output_path}/shap_global_p0.csv")

In [35]:
#shap_global_df.reset_index().to_excel(f"{output_path}/shap_global_p0.xlsx")

In [31]:
shap_local_df = get_shap_df(shap_local, test_data, datamodule, concepts)

In [32]:
shap_local_df.reset_index().to_csv(f"{output_path}/shap_local_p0.csv")

In [38]:
shap_local_df.reset_index().to_excel(f"{output_path}/shap_local_p0.xlsx")

In [37]:
#with pd.ExcelWriter(f"{output_path}/shap_p0_220228.xlsx") as writer:  
    #shap_global_df.reset_index().to_excel(writer, sheet_name='Global')
    #shap_local_df.reset_index().to_excel(writer, sheet_name='Local')

In [50]:
%%time
test_fp = "/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/shap_220227/shapmatrix_phecode_006-2 - Neisseria gonorrhea.p"
x = load_tensor(test_fp)

CPU times: user 851 ms, sys: 4.65 ms, total: 855 ms
Wall time: 857 ms


In [None]:
load_tensor(test_fp)

In [52]:
%%time
zarr.save(f"{test_fp[:-2]}.zarr", x.numpy())

CPU times: user 1.91 s, sys: 261 ms, total: 2.18 s
Wall time: 6.8 s


In [56]:
%%time
y = zarr.load(f"{test_fp[:-2]}.zarr")

CPU times: user 2.21 s, sys: 121 ms, total: 2.33 s
Wall time: 3.74 s


In [55]:
%%time
save_tensor(f"{test_fp[:-2]}2.zarr", x)

CPU times: user 1.13 s, sys: 11.7 ms, total: 1.14 s
Wall time: 1.2 s


In [53]:
def save_tensor(fp, tensor):
    array = tensor.numpy()
    with open(fp, "wb") as fh:
        cctx = zstandard.ZstdCompressor()
        with cctx.stream_writer(fh) as compressor:
            compressor.write(pickle.dumps(array, protocol=pickle.HIGHEST_PROTOCOL))

In [43]:
pd.DataFrame(data=x, index=datamodule.eids["test"], columns = datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'])

KeyboardInterrupt: 

In [46]:
x.shape

torch.Size([44193, 16204])

In [47]:
import zarr

ModuleNotFoundError: No module named 'zarr'

In [48]:
import zarr

In [None]:
# fix for covariates
shap_global_df = pd.DataFrame(data=shap_global, 
                 index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
                 columns=endpoints).rename_axis("record")

shap_global_df["partition"] = 0
shap_global_df["record_n"] = (test_data > 0).sum(dim=0).numpy()
shap_global_df["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

shap_global_mapped_df = shap_global_mapped_df.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True).reset_index().set_index(["partition", "record", "concept_name", "record_n", "record_freq"])

shap_global_mapped_df.reset_index().to_csv(f"{output_path}/shap_global_p0.csv")

In [None]:
# fix for covariates
x = pd.DataFrame(data=shap_local, 
                 index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
                 columns=endpoints).rename_axis("record")

x["partition"] = 0
x["record_n"] = (test_data > 0).sum(dim=0).numpy()
x["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

x_mapped = x.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True).reset_index().set_index(["partition", "record", "concept_name", "record_n", "record_freq"])

x_mapped.reset_index().to_csv(f"{output_path}/shapglobal_p0_test.csv")

In [48]:
pd.set_option('display.max_rows', 25)
code = "phecode_121 - Leukemia"
display(x_mapped[code].dropna().sort_values().head(25))
display(x_mapped[code].dropna().sort_values().tail(25))

record         concept_name                                      
OMOP_4195603   Operation on vagina                                  -0.275197
OMOP_4263879   Marie's cerebellar ataxia                            -0.260202
OMOP_434316    von Willebrand disorder                              -0.241875
OMOP_439404    Primary malignant neoplasm of oral cavity            -0.235471
OMOP_435094    Open fracture of femur, distal end                   -0.187018
OMOP_4030065   Hyposplenism                                         -0.186954
OMOP_4018853   Decompression of cardiac tamponade                   -0.176256
OMOP_133713    Malignant melanoma of skin of face                   -0.176001
OMOP_4070718   Biopsy of lesion of palate                           -0.174164
OMOP_4168815   Contact with plant spines, thorns, or sharp leaves   -0.170073
OMOP_4289309   Atrial septal defect                                 -0.165623
OMOP_4052685   Metatarsus adductus                                  -0.15743

record                         concept_name                                                       
OMOP_4058706                   History of leukemia                                                    0.286241
OMOP_4013643                   Pulmonary arterial hypertension                                        0.289997
OMOP_195861                    Small kidney                                                           0.293899
OMOP_4304002                   Eosinophil count raised                                                0.301299
OMOP_4060043                   Excision of lesion of atrium                                           0.304675
OMOP_4017875                   Homograft tricuspid valve replacement                                  0.307416
OMOP_4198132                   Hematology screening test                                              0.308533
OMOP_4114976                   Balanced rearrangement and structural marker                           0.308767
OMOP_4027567 

In [None]:
x_mapped["phecode_089 - Infections"].sort_values()

In [75]:
datamodule.label_mapping

{'OMOP_4306655': 'OMOP_4306655',
 'phecode_401': 'phecode_401 - Hypertension',
 'phecode_401-1': 'phecode_401-1 - Essential hypertension',
 'phecode_130': 'phecode_130 - Cancer (solid tumor, excluding BCC)',
 'phecode_089': 'phecode_089 - Infections',
 'phecode_460': 'phecode_460 - Acute respiratory infection',
 'phecode_202': 'phecode_202 - Diabetes mellitus',
 'phecode_202-2': 'phecode_202-2 - Type 2 diabetes',
 'phecode_404': 'phecode_404 - Ischemic heart disease',
 'phecode_583': 'phecode_583 - Chronic kidney disease',
 'phecode_089-2': 'phecode_089-2 - Viral infections',
 'phecode_718': 'phecode_718 - Back pain',
 'phecode_708': 'phecode_708 - Osteoarthritis',
 'phecode_475': 'phecode_475 - Asthma',
 'phecode_089-1': 'phecode_089-1 - Bacterial infections',
 'phecode_809': 'phecode_809 - Pain',
 'phecode_713': 'phecode_713 - Symptoms related to joints',
 'phecode_416': 'phecode_416 - Cardiac arrhythmia and conduction disorders',
 'phecode_239': 'phecode_239 - Hyperlipidemia',
 'phe

In [None]:
shap_values = shap_values if not isinstance(shap_values, list) else shap_values[0]

print('Done')

    