In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import hydra
import numpy as np
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf
import pandas as pd

from tqdm.auto import tqdm

In [3]:
from ehrgraphs.models.supervised import RecordsGraphTraining
from ehrgraphs.training import setup_training

Using backend: pytorch


In [4]:
import pathlib

In [5]:
import socket

# Resolve the hostname in pure Python rather than through the IPython-only
# `!hostname` shell escape, so the cell is valid Python and does not depend on
# a `hostname` binary. Kept as a one-element list so `node[0]` still works.
node = [socket.gethostname()]

# Hosts whose name contains "sc" use the sc-projects mount; every other host
# falls back to the shared analysis drive.
if "sc" in node[0]:
    base_path = "/sc-projects/sc-proj-ukb-cvd"
else:
    base_path = "/data/analysis/ag-reils/ag-reils-shared/cardioRS"
print(base_path)

# All outputs of this notebook live below the project directory.
project_label = "22_medical_records"
project_path = f"{base_path}/results/projects/{project_label}"
figure_path = f"{project_path}/figures"
output_path = f"{project_path}/data"

# Create both output directories up front; existing directories are left alone.
pathlib.Path(figure_path).mkdir(parents=True, exist_ok=True)
pathlib.Path(output_path).mkdir(parents=True, exist_ok=True)

/sc-projects/sc-proj-ukb-cvd


In [6]:
#from ehrgraphs.utils.attribution import ShapWrapper

In [7]:
import shap as shap

In [8]:
from hydra.core.global_hydra import GlobalHydra

# Hydra keeps a process-wide singleton and refuses to be initialized twice.
# Clearing an existing instance first makes this cell safe to re-run on a
# live kernel instead of raising "GlobalHydra is already initialized".
if GlobalHydra.instance().is_initialized():
    GlobalHydra.instance().clear()

hydra.initialize(config_path="../../ehrgraphs/config")

hydra.initialize()

In [9]:
import wandb

# Fetch a handle on every run of the project through the W&B public API.
entity, project = "cardiors", "recordgraphs"  # set to your entity and project
api = wandb.Api()
runs = api.runs(f"{entity}/{project}")

In [14]:
# Collect one plain-dict summary per W&B run.
run_list = []
for run in tqdm(runs):
    # Runs that never logged a checkpoint get None for both derived fields.
    best_checkpoint = run.config.get("best_checkpoint")
    run_list.append(
        {
            "id": run.path[-1],
            "name": run.name,
            "tags": run.tags,
            # Drop keys starting with "_" from the stored config.
            "config": {k: v for k, v in run.config.items() if not k.startswith("_")},
            "summary": run.summary._json_dict,
            # Run directory: two levels above the checkpoint file.
            "path": None if best_checkpoint is None else str(pathlib.Path(best_checkpoint).parent.parent),
            # The checkpoint file itself. Previously this repeated the run
            # directory, making the column a duplicate of "path".
            "best_checkpoint": None if best_checkpoint is None else str(best_checkpoint),
        }
    )

  0%|          | 0/1280 [00:00<?, ?it/s]

In [15]:
runs_df = pd.DataFrame(run_list)

In [18]:
tag = "220222"

# Narrow the run table in three steps: keep runs whose tag list mentions `tag`,
# drop rows with a missing path (a null value never equals itself inside
# `query`), then keep only runs whose name contains "gnn_agesex".
runs_df = (
    runs_df
    .loc[lambda df: df.tags.astype(str).str.contains(tag)]
    .query("path==path")
    .loc[lambda df: df.name.astype(str).str.contains("gnn_agesex")]
)

In [25]:
# Inspect the stored config of the first selected run.
# NOTE(review): the recorded output of this cell is KeyError: 'datamodule',
# i.e. that key is not present in the run's config dict.
runs_df.iloc[0].config["datamodule"]#["partition"]

KeyError: 'datamodule'

In [None]:
@ray.remote
def extract_records_events_times(args):
    """Build a wide record table and a long outcome table for the test split.

    Runs as a Ray remote task. ``args`` is forwarded unchanged to
    ``setup_training``; only the returned datamodule is used.

    Returns:
        tuple: ``(records_df, outcomes_df)``. ``records_df`` has an ``eid``
        column plus one column per entry of ``datamodule.record_cols``.
        ``outcomes_df`` has one row per (eid, endpoint) pair with the columns
        ``prevalent``, ``event`` and ``time``.
    """
    datamodule, _, _ = setup_training(args)

    test_eids = datamodule.eids["test"]
    record_names = datamodule.record_cols
    endpoint_names = list(datamodule.label_mapping.keys())
    test_dataset = datamodule.test_dataloader().dataset

    def _dense_frame(sparse_matrix, column_names):
        # Sparse matrix -> dense frame indexed by eid, with eid as a regular column.
        frame = pd.DataFrame.sparse.from_spmatrix(sparse_matrix, index=test_eids, columns=column_names)
        return frame.sparse.to_dense().rename_axis("eid").reset_index()

    def _to_long(wide_frame, value_name):
        # Wide (eid x endpoint) -> long, indexed by (eid, endpoint).
        as_float = wide_frame.astype(np.float32, copy=False)
        melted = as_float.melt(id_vars="eid", var_name="endpoint", value_name=value_name)
        return melted.set_index(["eid", "endpoint"])

    # Records, exclusions (prevalent cases) and events are stored sparse.
    records_df = _dense_frame(test_dataset.records, record_names)
    exclusions_df = _dense_frame(test_dataset.exclusions, endpoint_names)
    events_df = _dense_frame(test_dataset.labels_events, endpoint_names)

    # Wherever the stored time is 0, substitute that row's censoring time.
    times = test_dataset.labels_times.todense()
    censorings = test_dataset.censorings
    no_event_idxs = times == 0
    censoring_grid = censorings[:, None].repeat(repeats=times.shape[1], axis=1)
    times[no_event_idxs] = censoring_grid[no_event_idxs]

    times_df = pd.DataFrame(data=times, index=test_eids, columns=endpoint_names).rename_axis("eid").reset_index()

    # Align the three long tables on their shared (eid, endpoint) index.
    outcomes_df = pd.concat(
        [
            _to_long(exclusions_df, "prevalent"),
            _to_long(events_df, "event"),
            _to_long(times_df, "time"),
        ],
        axis=1,
    ).reset_index()

    return records_df, outcomes_df

In [9]:
# Read the per-run prediction tables and stack them into one frame.
predictions_list = [
    pd.read_feather(f"{p}/predictions/predictions.feather")
    for p in tqdm(runs_df.path)
]

predictions = (
    pd.concat(predictions_list, axis=0)
    .rename(columns={"index": "eid"})
    .reset_index(drop=True)
)

  0%|          | 0/110 [00:00<?, ?it/s]

In [9]:
# basic:
cfg = hydra.compose(config_name="config")
checkpoint_path = "/sc-projects/sc-proj-ukb-cvd/results/recordgraphs/outputs/2022-02-21/14-08-30/RecordGraphs/189e3kgt/checkpoints/epoch=88-step=4182.ckpt"

In [10]:
# Device used while loading the model and preparing data. This module-level
# name is rebound to CUDA in a later cell, before the attribution loop.
device = torch.device("cpu")#;#("cuda:0")

In [11]:
# Load the concept table and index it by a "record" key of the form
# "OMOP_<concept_id>".
concepts = pd.read_feather("/sc-projects/sc-proj-ukb-cvd/data/mapping/athena/CONCEPT.feather")
concepts["record"] = "OMOP_" + concepts.concept_id.astype(str)
concepts = concepts.set_index("record")

In [12]:
def prepare_model(model, device, ckpt_path=None):
    """Restore a trained ``RecordsGraphTraining`` module ready for inference.

    Args:
        model: Object providing the ``graph_encoder``, ``head`` and
            ``projector`` attributes that are handed to the checkpoint loader.
        device: Device the restored module is moved to.
        ckpt_path: Checkpoint file to load. Defaults to the notebook-level
            ``checkpoint_path`` so existing two-argument calls keep working;
            pass it explicitly to load a different run.

    Returns:
        The restored module in eval mode, frozen, and moved to ``device``.
    """
    if ckpt_path is None:
        # Backward-compatible fallback to the notebook global.
        ckpt_path = checkpoint_path

    ckpt_model = RecordsGraphTraining.load_from_checkpoint(
        ckpt_path, graph_encoder=model.graph_encoder, head=model.head, projector=model.projector
    )

    # Inference only: eval mode, then freeze the module.
    ckpt_model.eval()
    ckpt_model.freeze()

    return ckpt_model.to(device)

In [13]:
def prepare_test_dataloader(datamodule):
    _ = datamodule.train_dataloader()
    test_dataloader = datamodule.test_dataloader()
        
    return test_dataloader

In [14]:
def prepare_node_embeddings(test_dataloader, ckpt_model, target_device=None):
    """Run one test batch through the model and return its record node embeddings.

    Args:
        test_dataloader: Iterable yielding batches with ``graph``, ``records``
            and ``covariates`` attributes that support ``.to(device)``.
        ckpt_model: Callable taking such a batch and returning a dict-like
            result.
        target_device: Device the batch is moved to before the forward pass.
            Defaults to the notebook-level ``device``. Pass it explicitly to
            avoid depending on a global that is rebound later in the notebook.

    Returns:
        The detached ``"record_node_embeddings"`` tensor from the model output,
        or ``None`` when the output has no such key.
    """
    if target_device is None:
        # Backward-compatible fallback to the notebook global.
        target_device = device

    # Only the first batch is used.
    batch = next(iter(test_dataloader))

    batch.graph = batch.graph.to(target_device)
    batch.records = batch.records.to(target_device)
    batch.covariates = batch.covariates.to(target_device)

    predictions = ckpt_model(batch)
    del batch

    if "record_node_embeddings" not in predictions:
        return None

    # detach() already yields a tensor with requires_grad=False, so no extra
    # flag flip is needed.
    return predictions["record_node_embeddings"].detach()

In [15]:
def prepare_data(datamodule, n_rows=None, target_device=None):
    """Assemble the attribution inputs and an all-zero baseline.

    The test dataset's record matrix (densified) and its covariates are
    concatenated column-wise, records first.

    Args:
        datamodule: Object exposing ``test_dataset.records`` (with a
            ``todense()`` method) and ``test_dataset.covariates``.
        n_rows: Keep only the first ``n_rows`` rows; ``None`` keeps all rows.
        target_device: Device for both returned tensors. Defaults to the
            notebook-level ``device``.

    Returns:
        tuple: ``(test_data, baseline_data)``. ``baseline_data`` is a zero
        tensor with 2 rows and as many columns as ``test_data``, with
        ``requires_grad`` enabled.
    """
    if target_device is None:
        # Backward-compatible fallback to the notebook global.
        target_device = device

    test_dataset = datamodule.test_dataset
    test_data = torch.cat(
        (torch.Tensor(test_dataset.records.todense()), torch.Tensor(test_dataset.covariates)),
        dim=1,
    )

    # A slice bound of None selects every row, so one expression covers both
    # the truncated and the full case.
    test_data = test_data[:n_rows].to(target_device)

    baseline_data = torch.zeros(2, test_data.shape[1]).to(target_device)
    baseline_data.requires_grad = True

    return test_data, baseline_data

In [16]:
# Build the datamodule and model from the composed config; the third return value is not used.
datamodule, model, _ = setup_training(cfg)

/sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/artifacts/WandBGraphDataNoShortcuts256_220203.p
Generating train dataset...
Generating valid dataset...
Using edge types: ['Is a' 'ATC - RxNorm sec up' 'Tradename of' 'Has tradename' 'Constitutes'
 'RxNorm inverse is a' 'RxNorm ing of' 'Dose form group of' 'Has method'
 'Has asso morph' 'Has interprets' 'Interprets of' 'Is descendant of'
 'Is associated with' 'Is ancestor of' 'Asso morph of' 'Method of'
 'Interacts with' 'Is part of' 'Composed of']


In [17]:
# Restore the checkpointed weights into the freshly built modules (eval mode, frozen).
ckpt_model = prepare_model(model, device)

In [18]:
# Obtain the test dataloader; the recorded output shows the test dataset being generated here.
test_dataloader = prepare_test_dataloader(datamodule)

Generating test dataset...


In [19]:
# One forward pass on the first test batch to obtain the record node embeddings
# (None if the model output has no "record_node_embeddings" entry).
record_node_embeddings = prepare_node_embeddings(test_dataloader, ckpt_model)

Process Process-2:
Process Process-1:
Process Process-3:
Process Process-4:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/home/steinfej/miniconda3/envs/ehrgraphs/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/steinfej/miniconda3/envs/ehrgraphs/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/steinfej/miniconda3/envs/ehrgraphs/lib/python3.9/multiprocessing/process.py", line 318, in _bootstrap
    util._exit_function()
  File "/home/steinfej/miniconda3/envs/ehrgraphs/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/steinfej/miniconda3/envs/ehrgraphs/lib/python3.9/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/home/steinfej/miniconda3/envs/ehrgraphs/lib/python3.9/multiprocessing/util.py", line 300, in _run_finalize

In [20]:
# Dense [records | covariates] matrix for the whole test split, plus the zero baseline.
test_data, baseline_data = prepare_data(datamodule)

In [21]:
# The prediction head of the restored model; this is the module ShapWrapper wraps in the GPU attribution cell.
head_model = ckpt_model.head

In [22]:
import torch
import torch.nn as nn


class ShapWrapper(nn.Module):
    """Expose a prediction head as a plain tensor-in / logits-out module.

    The record node embeddings are held fixed (no gradient), so attribution
    methods can attribute the logits directly to the columns of the input
    matrix ``[records | covariates]``.
    """

    def __init__(self,
                 head_model,
                 record_node_embeddings,
                 device,
                 n_records=16201
                ):
        """
        Args:
            head_model: Module called on the assembled feature tensor; its
                output must support ``['logits']``.
            record_node_embeddings: Tensor with one row per record column, or
                ``None`` to feed only the covariates to the head.
            device: Accepted for backward compatibility and not used. The
                forward pass runs on whatever device its input and this module
                live on; move the wrapper with ``.to(...)``.
            n_records: Number of leading input columns that are records. All
                remaining columns are treated as covariates.
        """
        super().__init__()
        self.head_model = head_model
        self.head_model.gradient_checkpointing = False
        if record_node_embeddings is None:
            # Keep the attribute defined so forward() can test for None.
            self.register_parameter("record_node_embeddings", None)
        else:
            # Registered as a parameter so it moves with .to(), but frozen.
            self.record_node_embeddings = nn.Parameter(record_node_embeddings, requires_grad=False)
        self.n_records = n_records
        self.data = None

    def forward(self, batch_data):
        """Map a ``[records | covariates]`` batch to the head's logits.

        Inputs with no covariate columns are supported: the covariate slice is
        then empty and the head only sees the record features.
        """
        # Plain slicing handles every width. Previously `covariates` was left
        # unbound when the input had at most `n_records` columns.
        records = batch_data[:, :self.n_records]
        covariates = batch_data[:, self.n_records:]

        if self.record_node_embeddings is None:
            assert covariates.shape[1] > 0, "need covariates when no record embeddings are given"
            inputs = covariates
        else:
            features = records @ self.record_node_embeddings
            # Signed log1p compresses large magnitudes while keeping the sign.
            features = torch.sign(features) * torch.log1p(torch.abs(features))
            inputs = torch.cat((features, covariates), dim=1)

        head_outputs = self.head_model(inputs)
        return head_outputs['logits']

In [23]:
import ray
# Start a local Ray runtime with 10 CPUs and the dashboard enabled on port 24762.
# NOTE(review): dashboard_host="0.0.0.0" presumably binds the dashboard to all
# network interfaces — confirm this exposure is intended on a shared machine.
ray.init(num_cpus=10, dashboard_port=24762, dashboard_host="0.0.0.0", include_dashboard=True)#, webui_url="0.0.0.0"))

2022-02-28 11:52:40,519	INFO services.py:1374 -- View the Ray dashboard at [1m[32mhttp://10.32.105.108:24762[39m[22m


{'node_ip_address': '10.32.105.108',
 'raylet_ip_address': '10.32.105.108',
 'redis_address': '10.32.105.108:6379',
 'object_store_address': '/tmp/ray/session_2022-02-28_11-52-34_008994_1951908/sockets/plasma_store',
 'raylet_socket_name': '/tmp/ray/session_2022-02-28_11-52-34_008994_1951908/sockets/raylet',
 'webui_url': '10.32.105.108:24762',
 'session_dir': '/tmp/ray/session_2022-02-28_11-52-34_008994_1951908',
 'metrics_export_port': 42583,
 'gcs_address': '10.32.105.108:34781',
 'node_id': 'a8e04af0a96c3c80cc6c12e0aff89b044252c666f78202546f9e0ded'}

In [34]:
# Endpoint labels (the values of label_mapping), sorted.
# NOTE(review): the attribution loop uses the position in this list as the
# model `target` index and the label as the output file name. Sorting only
# matches the model's output order if that order is sorted too — verify; the
# CPU variant further down builds this list without sorting.
endpoints = sorted(list(datamodule.label_mapping.values()))

In [35]:
# Number of record columns; passed to ShapWrapper as the split point between records and covariates.
n_records = len(datamodule.record_cols)

In [36]:
# Directory that receives one attribution file per endpoint; created if missing.
attrib_path = f"{output_path}/shap_220228"
pathlib.Path(attrib_path).mkdir(parents=True, exist_ok=True)

In [37]:
@ray.remote
def save_tensor(fp, tensor):
    """Write ``tensor`` to ``fp`` as a zstandard-compressed pickle of a NumPy array.

    Runs as a Ray remote task so the caller does not wait for the write.

    Args:
        fp: Destination file path; the file is overwritten.
        tensor: Torch tensor to store. It is detached and copied to host
            memory first, so CUDA or grad-tracking tensors are accepted.

    Note:
        The file is a zstd stream, not a plain pickle: decompress it before
        unpickling, and only unpickle files you trust.
    """
    # Imported here so the task does not rely on a later notebook cell having
    # imported these modules before the first .remote() call.
    import pickle

    import zstandard

    array = tensor.detach().cpu().numpy()
    with open(fp, "wb") as fh:
        cctx = zstandard.ZstdCompressor()
        with cctx.stream_writer(fh) as compressor:
            compressor.write(pickle.dumps(array, protocol=pickle.HIGHEST_PROTOCOL))

In [38]:
# Switch to the GPU for attribution. This rebinds the module-level `device`
# that prepare_data() and prepare_node_embeddings() pick up when called, so
# calls made after this cell target CUDA.
device = torch.device("cuda")

In [39]:
# Rows per chunk when splitting test_data for attribution.
batch_size=25000

In [40]:
# Move the inputs and the baseline to the attribution device.
test_data = test_data.to(device)
baseline_data = baseline_data.to(device)

In [41]:
# Chunks of at most batch_size rows; the last chunk may be smaller.
test_splits = torch.split(test_data, batch_size)

In [None]:
# gpu
import pickle

import zstandard
from captum.attr import DeepLiftShap

# Wrap the head with fixed embeddings and build a single explainer that is
# reused for every endpoint.
task = ShapWrapper(head_model, record_node_embeddings, device, n_records).to(device)
explainer = DeepLiftShap(task)

# NOTE(review): `target=endpoint_idx` selects a model output column while the
# file name comes from `endpoints`; verify both follow the same order.
for endpoint_idx, endpoint_label in enumerate(tqdm(endpoints)):

    # Attribute chunk by chunk and move each result to host memory right away.
    chunk_attributions = []
    for test_batch in test_splits:
        attribution = explainer.attribute(test_batch, baseline_data, target=endpoint_idx)
        chunk_attributions.append(attribution.detach().cpu())
    temp_shap = torch.cat(chunk_attributions)

    # The write is handed to a Ray task so the loop does not wait for it.
    fp = f"{attrib_path}/shapmatrix_{endpoint_label}.p"
    save_tensor.remote(fp, temp_shap)

  0%|          | 0/1001 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

Setting forward, backward hooks and attributes on non-linear
               activations. The hooks and attributes will be removed
            after the attribution is finished


  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/2 [00:00<?, ?it/s]

In [None]:
# cpu
from captum.attr import DeepLiftShap
import zstandard
import pickle

endpoints = list(datamodule.label_mapping.values())

# ShapWrapper.forward passes a plain feature tensor to the wrapped module and
# reads its 'logits', so it must wrap the prediction head (as the GPU cell
# does), not the full training module that expects a batch object.
task = ShapWrapper(head_model,
           record_node_embeddings,
           device,
           n_records,
              )

task = task.to(device)

batch_size=2000

# Neither the explainer nor the chunking depends on the endpoint, so both are
# built once instead of on every iteration.
explainer = DeepLiftShap(task)
test_splits = torch.split(test_data, batch_size)

for endpoint_idx in tqdm(range(len(endpoints))):

    endpoint_label = endpoints[endpoint_idx]

    # .cpu() is a no-op for CPU tensors and keeps save_tensor's NumPy
    # conversion working if `device` is CUDA when this cell runs.
    temp_shap = torch.cat([
        explainer.attribute(
            test_batch,
            baseline_data,
            target=endpoint_idx).detach().cpu()
        for test_batch in tqdm(test_splits)
    ])

    fp = f"{attrib_path}/shapmatrix_{endpoint_label}.p"
    save_tensor.remote(fp, temp_shap)

In [54]:
# Shape check of the concatenated per-chunk results.
# NOTE(review): `x` is not assigned in any visible cell before this one (a
# later cell binds `x` to a DataFrame) — this relies on leftover kernel state.
torch.cat(x).shape

torch.Size([44193, 16204])

In [None]:
#shap_global[endpoint_label] = temp_shap.mean(dim=0).numpy()
    #shap_local[endpoint_label] = (temp_shap.sum(dim=0) / (test_data > 0).sum(dim=0)).numpy()

In [82]:
temp_shap

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0670, -0.0037,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0707,  0.0000,  0.0028],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.1326,  0.0042,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0119,  0.0000, -0.0125],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0772,  0.0000, -0.0020],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0257, -0.0125,  0.0000]])

In [72]:
# Import the subpackage explicitly: a bare `import scipy` does not guarantee
# that `scipy.sparse` is loaded, so the attribute access could fail on a fresh
# kernel.
import scipy.sparse

# Check how sparse the attribution matrix is.
scipy.sparse.csr_matrix(temp_shap.numpy())

<10000x16204 sparse matrix of type '<class 'numpy.float32'>'
	with 289993 stored elements in Compressed Sparse Row format>

In [75]:
# Hand pandas a NumPy array rather than the torch tensor: pandas does not
# treat a tensor as a 2-D numeric block, which is presumably why the recorded
# run of this cell was interrupted.
pd.DataFrame(data=temp_shap.numpy(), columns=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'])

KeyboardInterrupt: 

In [45]:
# fix for covariates
shap_local_df = pd.DataFrame(data=shap_local, 
                 index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
                 columns=endpoints).rename_axis("record")

shap_local_df["partition"] = 0
shap_local_df["record_n"] = (test_data > 0).sum(dim=0).numpy()
shap_local_df["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

shap_local_mapped_df = shap_local_mapped_df.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True).reset_index().set_index(["partition", "record", "concept_name", "record_n", "record_freq"])

shap_local_mapped_df.reset_index().to_csv(f"{output_path}/shap_local_p0.csv")

In [None]:
# fix for covariates
shap_global_df = pd.DataFrame(data=shap_global, 
                 index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
                 columns=endpoints).rename_axis("record")

shap_global_df["partition"] = 0
shap_global_df["record_n"] = (test_data > 0).sum(dim=0).numpy()
shap_global_df["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

shap_global_mapped_df = shap_global_mapped_df.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True).reset_index().set_index(["partition", "record", "concept_name", "record_n", "record_freq"])

shap_global_mapped_df.reset_index().to_csv(f"{output_path}/shap_global_p0.csv")

In [None]:
# fix for covariates
x = pd.DataFrame(data=shap_local, 
                 index=datamodule.record_cols + ['age_at_recruitment_f21022_0_0', 'sex_f31_0_0_1', 'sex_f31_0_0_0'],
                 columns=endpoints).rename_axis("record")

x["partition"] = 0
x["record_n"] = (test_data > 0).sum(dim=0).numpy()
x["record_freq"] = (test_data > 0).sum(dim=0).numpy()/len(test_data)

x_mapped = x.merge(concepts[["concept_name"]], how="left", left_index=True, right_index=True).reset_index().set_index(["partition", "record", "concept_name", "record_n", "record_freq"])

x_mapped.reset_index().to_csv(f"{output_path}/shapglobal_p0_test.csv")

In [48]:
pd.set_option('display.max_rows', 25)
code = "phecode_121 - Leukemia"

# Sort the non-missing attributions for one endpoint once, then show both ends.
ranked = x_mapped[code].dropna().sort_values()
display(ranked.head(25))
display(ranked.tail(25))

record         concept_name                                      
OMOP_4195603   Operation on vagina                                  -0.275197
OMOP_4263879   Marie's cerebellar ataxia                            -0.260202
OMOP_434316    von Willebrand disorder                              -0.241875
OMOP_439404    Primary malignant neoplasm of oral cavity            -0.235471
OMOP_435094    Open fracture of femur, distal end                   -0.187018
OMOP_4030065   Hyposplenism                                         -0.186954
OMOP_4018853   Decompression of cardiac tamponade                   -0.176256
OMOP_133713    Malignant melanoma of skin of face                   -0.176001
OMOP_4070718   Biopsy of lesion of palate                           -0.174164
OMOP_4168815   Contact with plant spines, thorns, or sharp leaves   -0.170073
OMOP_4289309   Atrial septal defect                                 -0.165623
OMOP_4052685   Metatarsus adductus                                  -0.15743

record                         concept_name                                                       
OMOP_4058706                   History of leukemia                                                    0.286241
OMOP_4013643                   Pulmonary arterial hypertension                                        0.289997
OMOP_195861                    Small kidney                                                           0.293899
OMOP_4304002                   Eosinophil count raised                                                0.301299
OMOP_4060043                   Excision of lesion of atrium                                           0.304675
OMOP_4017875                   Homograft tricuspid valve replacement                                  0.307416
OMOP_4198132                   Hematology screening test                                              0.308533
OMOP_4114976                   Balanced rearrangement and structural marker                           0.308767
OMOP_4027567 

In [None]:
x_mapped["phecode_089 - Infections"].sort_values()

In [75]:
datamodule.label_mapping

{'OMOP_4306655': 'OMOP_4306655',
 'phecode_401': 'phecode_401 - Hypertension',
 'phecode_401-1': 'phecode_401-1 - Essential hypertension',
 'phecode_130': 'phecode_130 - Cancer (solid tumor, excluding BCC)',
 'phecode_089': 'phecode_089 - Infections',
 'phecode_460': 'phecode_460 - Acute respiratory infection',
 'phecode_202': 'phecode_202 - Diabetes mellitus',
 'phecode_202-2': 'phecode_202-2 - Type 2 diabetes',
 'phecode_404': 'phecode_404 - Ischemic heart disease',
 'phecode_583': 'phecode_583 - Chronic kidney disease',
 'phecode_089-2': 'phecode_089-2 - Viral infections',
 'phecode_718': 'phecode_718 - Back pain',
 'phecode_708': 'phecode_708 - Osteoarthritis',
 'phecode_475': 'phecode_475 - Asthma',
 'phecode_089-1': 'phecode_089-1 - Bacterial infections',
 'phecode_809': 'phecode_809 - Pain',
 'phecode_713': 'phecode_713 - Symptoms related to joints',
 'phecode_416': 'phecode_416 - Cardiac arrhythmia and conduction disorders',
 'phecode_239': 'phecode_239 - Hyperlipidemia',
 'phe

In [None]:
# If shap_values is a list, keep only its first element; otherwise leave it unchanged.
# NOTE(review): `shap_values` is not assigned in any visible cell — this relies
# on leftover kernel state.
shap_values = shap_values if not isinstance(shap_values, list) else shap_values[0]

print('Done')

    