In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import hydra
import numpy as np
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
import pandas as pd

from tqdm.auto import tqdm

In [3]:
from ehrgraphs.models.supervised import RecordsGraphTraining
from ehrgraphs.training import setup_training

Using backend: pytorch


In [4]:
import pathlib

In [5]:
node = !hostname
# Pick the storage root from the host name: `node` holds the captured output lines of the
# `hostname` shell command assigned on the line above, and any host whose name contains
# "sc" uses the /sc-projects tree; every other host uses the /data/analysis tree.
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else: 
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# Derive the project-specific result folders from the storage root.
project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Create both output folders (and any missing parents); no error if they already exist.
pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

/sc-projects/sc-proj-ukb-cvd


In [6]:
#from ehrgraphs.utils.attribution import ShapWrapper

In [7]:
import shap as shap

In [8]:
hydra.initialize(config_path="../../ehrgraphs/config")

hydra.initialize()

In [9]:
# basic:
cfg = hydra.compose(config_name="config")
checkpoint_path = "/sc-projects/sc-proj-ukb-cvd/results/recordgraphs/outputs/2022-02-21/14-08-30/RecordGraphs/189e3kgt/checkpoints/epoch=88-step=4182.ckpt"

In [10]:
device = torch.device("cpu")#;#("cuda:0")

In [11]:
# Load the concept table and index it by a "record" key built as "OMOP_<concept_id>",
# so it can later be joined against tables that are indexed by record name.
concepts = (
    pd.read_feather("/sc-projects/sc-proj-ukb-cvd/data/mapping/athena/CONCEPT.feather")
    .assign(record=lambda frame: "OMOP_" + frame.concept_id.astype(str))
    .set_index("record")
)

In [12]:
def prepare_test_dataloader(datamodule):
    _ = datamodule.train_dataloader()
    test_dataloader = datamodule.test_dataloader()
        
    return test_dataloader

In [13]:
def prepare_data(datamodule, n_rows=None):
    """Assemble the dense test-set input matrix and an all-zero baseline.

    The test records (densified via ``.todense()``) and the test covariates are
    concatenated column-wise, optionally truncated to the first ``n_rows`` rows,
    and moved to the notebook-level ``device``.

    Args:
        datamodule: object exposing ``test_dataset.records`` (must support
            ``.todense()``) and ``test_dataset.covariates``.
        n_rows: if not ``None``, keep only the first ``n_rows`` rows.

    Returns:
        Tuple ``(test_data, baseline_data)``: the concatenated input tensor and a
        zero tensor with 2 rows and the same number of columns, with
        ``requires_grad`` enabled.
    """
    dataset = datamodule.test_dataset
    record_block = torch.Tensor(dataset.records.todense())
    covariate_block = torch.Tensor(dataset.covariates)
    features = torch.cat((record_block, covariate_block), dim=1)

    # A full slice when no limit is given, otherwise the leading n_rows rows.
    row_limit = slice(None) if n_rows is None else slice(None, n_rows)
    test_data = features[row_limit, :].to(device)

    baseline_data = torch.zeros(2, test_data.shape[1]).to(device)
    baseline_data.requires_grad = True

    return test_data, baseline_data

In [14]:
datamodule, model, _ = setup_training(cfg)

/sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/artifacts/WandBGraphDataNoShortcuts256_220203.p
Generating train dataset...
Generating valid dataset...
Using edge types: ['Is a' 'ATC - RxNorm sec up' 'Tradename of' 'Has tradename' 'Constitutes'
 'RxNorm inverse is a' 'RxNorm ing of' 'Dose form group of' 'Has method'
 'Has asso morph' 'Has interprets' 'Interprets of' 'Is descendant of'
 'Is associated with' 'Is ancestor of' 'Asso morph of' 'Method of'
 'Interacts with' 'Is part of' 'Composed of']


In [16]:
test_dataloader = prepare_test_dataloader(datamodule)

Generating test dataset...


In [17]:
test_data, baseline_data = prepare_data(datamodule)

In [18]:
%env MKL_NUM_THREADS=1
%env NUMEXPR_NUM_THREADS=1
%env OMP_NUM_THREADS=1

env: MKL_NUM_THREADS=1
env: NUMEXPR_NUM_THREADS=1
env: OMP_NUM_THREADS=1


In [19]:
import ray

ray.init(num_cpus=24, dashboard_port=24763, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

2022-02-28 09:50:39,825	INFO services.py:1374 -- View the Ray dashboard at [1m[32mhttp://10.32.105.10:24763[39m[22m


{'node_ip_address': '10.32.105.10',
 'raylet_ip_address': '10.32.105.10',
 'redis_address': '10.32.105.10:61328',
 'object_store_address': '/tmp/ray/session_2022-02-28_09-50-33_282458_490686/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-02-28_09-50-33_282458_490686/sockets/raylet',
 'webui_url': '10.32.105.10:24763',
 'session_dir': '/tmp/ray/session_2022-02-28_09-50-33_282458_490686',
 'metrics_export_port': 48945,
 'gcs_address': '10.32.105.10:39609',
 'node_id': '2f6c1521ba3c41c3eb1da60357929381507c181ebd21066c6b69b52a'}

In [20]:
n_records = len(datamodule.record_cols)

In [21]:
attrib_path = f"{output_path}/shap_220227"
pathlib.Path(attrib_path).mkdir(parents=True, exist_ok=True)

In [40]:
import zstandard
import pickle

def load_tensor(fp):
    """Read a zstd-compressed pickle from ``fp`` and return it as a ``torch.Tensor``.

    Args:
        fp: path of the compressed pickle file.

    Returns:
        ``torch.Tensor`` built from the unpickled object.
    """
    # NOTE: unpickling can execute arbitrary code; only load files from a trusted source.
    with open(fp, "rb") as raw_file, zstandard.ZstdDecompressor().stream_reader(raw_file) as reader:
        payload = reader.read()
    return torch.Tensor(pickle.loads(payload))

In [None]:
@ray.remote
def calc_per_endpoint(fp, test_data):
    """Summarise one stored attribution matrix into a per-column value.

    Loads the tensor stored at ``fp`` and, for every column, divides the column
    sum by the number of rows of ``test_data`` whose value in that column is > 0.
    Columns with no positive entry divide by zero and therefore yield nan/inf.

    Args:
        fp: path passed to ``load_tensor``.
        test_data: tensor compared against 0 to obtain the per-column counts.

    Returns:
        NumPy array with one value per column.
    """
    attributions = load_tensor(fp)
    positive_counts = (test_data > 0).sum(dim=0)
    return (attributions.sum(dim=0) / positive_counts).numpy()

In [49]:
ray.shutdown()

In [24]:
from captum.attr import DeepLiftShap
import zstandard
import pickle
# Maps endpoint label -> handle returned by `calc_per_endpoint.remote` (resolved in a later cell).
shap_dict = {}

endpoints = list(datamodule.label_mapping.values())

# Store the test matrix with ray.put once and hand the returned reference to every task.
ray_test_data = ray.put(test_data)

# Schedule one remote summary per endpoint; each reads that endpoint's stored matrix.
for endpoint_idx, endpoint_label in enumerate(tqdm(endpoints)):
    fp = f"{attrib_path}/shapmatrix_{endpoint_label}.p"
    shap_dict[endpoint_label] = calc_per_endpoint.remote(fp, ray_test_data)

  0%|          | 0/1001 [00:00<?, ?it/s]

In [None]:
shap_dict = {key: ray.get(value) for key, value in tqdm(shap_dict.items())}

In [26]:
shap_local = {key: value for key, value in tqdm(shap_dict.items())}
#shap_global = {key: value["global"] for key, value in tqdm(shap_dict.items())}

  0%|          | 0/1001 [00:00<?, ?it/s]

In [28]:
def get_shap_df(shap_dict, test_data, datamodule, concepts):
    """Build a per-record table of attribution values with record metadata in the index.

    Rows are the record columns of ``datamodule`` followed by the three covariate
    names; columns are the endpoint labels. The result is indexed by
    ``(partition, record, concept_name, record_n, record_freq)``.

    The endpoint columns are taken from ``datamodule.label_mapping`` rather than
    from a notebook-level ``endpoints`` variable, so the function no longer
    depends on another cell having been run.

    Args:
        shap_dict: mapping endpoint label -> 1-D array with one value per row.
        test_data: tensor whose columns line up with the rows of the result; used
            to count, per column, the rows with a value > 0.
        datamodule: object exposing ``record_cols`` (list of record names) and
            ``label_mapping`` (dict whose values are the endpoint labels).
        concepts: DataFrame indexed by record name with a ``concept_name`` column.

    Returns:
        DataFrame with one column per endpoint and the five-level index above.
        ``concept_name`` is NaN for rows without a match in ``concepts``.
    """
    covariate_cols = ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0']
    endpoint_cols = list(datamodule.label_mapping.values())

    shap_df = pd.DataFrame(
        data=shap_dict,
        index=datamodule.record_cols + covariate_cols,
        columns=endpoint_cols,
    ).rename_axis("record")

    # Per-column number of test rows with a value > 0; computed once and reused.
    nonzero_counts = (test_data > 0).sum(dim=0).numpy()

    shap_df["partition"] = 0
    shap_df["record_n"] = nonzero_counts
    shap_df["record_freq"] = nonzero_counts / len(test_data)

    # Left join keeps every row, including covariates that have no concept entry.
    shap_mapped_df = (
        shap_df.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True)
        .reset_index()
        .set_index(["partition", "record", "concept_name", "record_n", "record_freq"])
    )
    return shap_mapped_df

In [29]:
# NOTE(review): `shap_global` is not assigned in any visible cell (the line that would
# create it is commented out in the cell that builds `shap_local`), so this call
# presumably raises NameError on a fresh kernel -- confirm before re-running.
shap_global_df = get_shap_df(shap_global, test_data, datamodule, concepts)

In [30]:
#shap_global_df.reset_index().to_csv(f"{output_path}/shap_global_p0.csv")

In [35]:
#shap_global_df.reset_index().to_excel(f"{output_path}/shap_global_p0.xlsx")

In [31]:
shap_local_df = get_shap_df(shap_local, test_data, datamodule, concepts)

In [32]:
shap_local_df.reset_index().to_csv(f"{output_path}/shap_local_p0.csv")

In [38]:
shap_local_df.reset_index().to_excel(f"{output_path}/shap_local_p0.xlsx")

In [37]:
#with pd.ExcelWriter(f"{output_path}/shap_p0_220228.xlsx") as writer:  
    #shap_global_df.reset_index().to_excel(writer, sheet_name='Global')
    #shap_local_df.reset_index().to_excel(writer, sheet_name='Local')

In [50]:
%%time
test_fp = "/sc-projects/sc-proj-ukb-cvd/results/projects/22_medical_records/data/shap_220227/shapmatrix_phecode_006-2 - Neisseria gonorrhea.p"
x = load_tensor(test_fp)

CPU times: user 851 ms, sys: 4.65 ms, total: 855 ms
Wall time: 857 ms


In [None]:
load_tensor(test_fp)

In [52]:
%%time
zarr.save(f"{test_fp[:-2]}.zarr", x.numpy())

CPU times: user 1.91 s, sys: 261 ms, total: 2.18 s
Wall time: 6.8 s


In [56]:
%%time
y = zarr.load(f"{test_fp[:-2]}.zarr")

CPU times: user 2.21 s, sys: 121 ms, total: 2.33 s
Wall time: 3.74 s


In [55]:
%%time
save_tensor(f"{test_fp[:-2]}2.zarr", x)

CPU times: user 1.13 s, sys: 11.7 ms, total: 1.14 s
Wall time: 1.2 s


In [53]:
def save_tensor(fp, tensor):
    """Pickle ``tensor`` as a NumPy array and write it zstd-compressed to ``fp``.

    Args:
        fp: destination path; an existing file is overwritten.
        tensor: ``torch.Tensor`` to store.

    Returns:
        None.
    """
    # detach() + cpu() make the conversion work for tensors that require grad or
    # live on an accelerator; for a plain CPU tensor the resulting array is the same.
    array = tensor.detach().cpu().numpy()
    with open(fp, "wb") as fh:
        cctx = zstandard.ZstdCompressor()
        with cctx.stream_writer(fh) as compressor:
            compressor.write(pickle.dumps(array, protocol=pickle.HIGHEST_PROTOCOL))

In [43]:
pd.DataFrame(data=x, index=datamodule.eids["test"], columns = datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'])

KeyboardInterrupt: 

In [46]:
x.shape

torch.Size([44193, 16204])

In [47]:
import zarr

ModuleNotFoundError: No module named 'zarr'

In [48]:
import zarr

In [None]:
# fix for covariates
# Build the per-record table of global SHAP values: rows are the record columns plus the
# three covariates, columns are the endpoints.
# NOTE(review): `shap_global` is not assigned in any visible cell -- confirm it exists
# before running this cell.
shap_global_df = pd.DataFrame(data=shap_global, 
                 index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
                 columns=endpoints).rename_axis("record")

# Per-record metadata: number and fraction of test rows with a value > 0.
shap_global_df["partition"] = 0
shap_global_df["record_n"] = (test_data > 0).sum(dim=0).numpy()
shap_global_df["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

# Fixed: merge from `shap_global_df`. The original merged `shap_global_mapped_df` with
# itself before that name was ever assigned, which raises NameError.
shap_global_mapped_df = shap_global_df.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True).reset_index().set_index(["partition", "record", "concept_name", "record_n", "record_freq"])

shap_global_mapped_df.reset_index().to_csv(f"{output_path}/shap_global_p0.csv")

In [None]:
# fix for covariates
# Per-record table of the values in `shap_local` (rows: record columns plus the three
# covariates; columns: endpoints), extended with partition / count / frequency columns.
# NOTE(review): the name `x` is also bound to a loaded tensor in another cell, and the
# output file is called "shapglobal_..." although it is built from `shap_local` -- confirm
# both are intended.
x = (
    pd.DataFrame(
        data=shap_local,
        index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
        columns=endpoints,
    )
    .rename_axis("record")
    .assign(
        partition=0,
        record_n=(test_data > 0).sum(dim=0).numpy(),
        record_freq=(test_data > 0).sum(dim=0).numpy() / len(test_data),
    )
)

# Attach the concept name (left join keeps rows without a match) and move the metadata
# columns into the index.
x_mapped = (
    x.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True)
    .reset_index()
    .set_index(["partition", "record", "concept_name", "record_n", "record_freq"])
)

x_mapped.reset_index().to_csv(f"{output_path}/shapglobal_p0_test.csv")

In [48]:
# Cap the number of rows pandas prints per displayed object.
pd.set_option('display.max_rows', 25)
# Endpoint column to inspect.
code = "phecode_121 - Leukemia"
# Sort the non-missing values for this endpoint in ascending order and show the
# 25 lowest, then the 25 highest.
display(x_mapped[code].dropna().sort_values().head(25))
display(x_mapped[code].dropna().sort_values().tail(25))

record         concept_name                                      
OMOP_4195603   Operation on vagina                                  -0.275197
OMOP_4263879   Marie's cerebellar ataxia                            -0.260202
OMOP_434316    von Willebrand disorder                              -0.241875
OMOP_439404    Primary malignant neoplasm of oral cavity            -0.235471
OMOP_435094    Open fracture of femur, distal end                   -0.187018
OMOP_4030065   Hyposplenism                                         -0.186954
OMOP_4018853   Decompression of cardiac tamponade                   -0.176256
OMOP_133713    Malignant melanoma of skin of face                   -0.176001
OMOP_4070718   Biopsy of lesion of palate                           -0.174164
OMOP_4168815   Contact with plant spines, thorns, or sharp leaves   -0.170073
OMOP_4289309   Atrial septal defect                                 -0.165623
OMOP_4052685   Metatarsus adductus                                  -0.15743

record                         concept_name                                                       
OMOP_4058706                   History of leukemia                                                    0.286241
OMOP_4013643                   Pulmonary arterial hypertension                                        0.289997
OMOP_195861                    Small kidney                                                           0.293899
OMOP_4304002                   Eosinophil count raised                                                0.301299
OMOP_4060043                   Excision of lesion of atrium                                           0.304675
OMOP_4017875                   Homograft tricuspid valve replacement                                  0.307416
OMOP_4198132                   Hematology screening test                                              0.308533
OMOP_4114976                   Balanced rearrangement and structural marker                           0.308767
OMOP_4027567 

In [None]:
x_mapped["phecode_089 - Infections"].sort_values()

In [75]:
datamodule.label_mapping

{'OMOP_4306655': 'OMOP_4306655',
 'phecode_401': 'phecode_401 - Hypertension',
 'phecode_401-1': 'phecode_401-1 - Essential hypertension',
 'phecode_130': 'phecode_130 - Cancer (solid tumor, excluding BCC)',
 'phecode_089': 'phecode_089 - Infections',
 'phecode_460': 'phecode_460 - Acute respiratory infection',
 'phecode_202': 'phecode_202 - Diabetes mellitus',
 'phecode_202-2': 'phecode_202-2 - Type 2 diabetes',
 'phecode_404': 'phecode_404 - Ischemic heart disease',
 'phecode_583': 'phecode_583 - Chronic kidney disease',
 'phecode_089-2': 'phecode_089-2 - Viral infections',
 'phecode_718': 'phecode_718 - Back pain',
 'phecode_708': 'phecode_708 - Osteoarthritis',
 'phecode_475': 'phecode_475 - Asthma',
 'phecode_089-1': 'phecode_089-1 - Bacterial infections',
 'phecode_809': 'phecode_809 - Pain',
 'phecode_713': 'phecode_713 - Symptoms related to joints',
 'phecode_416': 'phecode_416 - Cardiac arrhythmia and conduction disorders',
 'phecode_239': 'phecode_239 - Hyperlipidemia',
 'phe

In [None]:
# If `shap_values` is a list keep only its first element; otherwise leave it unchanged.
# NOTE(review): `shap_values` is not assigned in any visible cell -- presumably left over
# from an earlier version of the notebook; confirm before running.
shap_values = shap_values if not isinstance(shap_values, list) else shap_values[0]

print('Done')

    