In [1]:
import pathlib
import datetime
import pandas as pd
import hydra
import numpy as np
import torch
from hydra import compose, initialize
from omegaconf import OmegaConf

import matplotlib.pyplot as plt
import captum
from captum.attr import IntegratedGradients, NoiseTunnel, InputXGradient

ModuleNotFoundError: No module named 'captum'

In [None]:
from retinalrisk.training import setup_training
from retinalrisk.models.supervised import SupervisedTraining

from IBA.pytorch import IBA, tensor_to_np_img, get_imagenet_folder, imagenet_transform
from IBA.utils import plot_saliency_map, to_unit_interval, load_monkeys

In [None]:
import os

# Project locations on the cluster. Each one can be overridden through an environment
# variable, so the notebook can be pointed at another copy of the data/model without
# editing this cell; the defaults are the original hard-coded paths.
ARTIFACT_PATH = os.environ.get(
    "RETINALRISK_ARTIFACT_PATH",
    "/sc-projects/sc-proj-ukb-cvd/data/2_datasets_pre/211110_anewbeginning/artifacts/baseline_outcomes_220627.feather",
)
MODEL_PATH = os.environ.get(
    "RETINALRISK_MODEL_PATH",
    "/sc-projects/sc-proj-ukb-cvd/results/models/retina/rids0apm/epoch=127-step=47744.ckpt",
)
BASE_DIR = os.environ.get(
    "RETINALRISK_BASE_DIR",
    "/sc-projects/sc-proj-ukb-cvd/results/projects/22_retina_phewas_220812",
)

In [None]:
import wandb
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torch_geometric import seed_everything

In [None]:
from typing import Union

import PIL
import torch.nn as nn
import torchvision as tv
import torchvision.transforms.functional as TF
from torchvision import transforms
from random import choice
from omegaconf import ListConfig


class AdaptiveRandomCropTransform(nn.Module):
    """Random square crop sized relative to the image's shorter edge, followed by a resize.

    Args:
        crop_ratio: fraction of ``min(sample.size)`` used as the crop side. A list (or
            omegaconf ``ListConfig``) means one ratio is drawn at random on every call.
        out_size: size passed to ``TF.resize``. Crops that would be smaller than this are
            replaced by an ``out_size`` x ``out_size`` crop, returned without resizing.
        interpolation: interpolation mode passed to ``TF.resize``.

    ``forward`` returns ``(image, (i, j, h, w))`` in every branch, where ``(i, j, h, w)``
    are the top, left, height and width of the crop in the input image (the values
    produced by ``transforms.RandomCrop.get_params`` and consumed by ``TF.crop``).
    """

    def __init__(
        self, crop_ratio: Union[list, float], out_size: int, interpolation=PIL.Image.BILINEAR
    ):
        super().__init__()
        self.crop_ratio = crop_ratio
        self.out_size = out_size
        self.interpolation = interpolation

    def forward(self, sample):
        # assumes `sample` is a PIL image (`.size` is (width, height)) — TODO confirm for other loaders
        input_size = min(sample.size)
        if isinstance(self.crop_ratio, (list, ListConfig)):
            crop_ratio = choice(self.crop_ratio)
        else:
            crop_ratio = self.crop_ratio

        crop_side = int(crop_ratio * input_size)
        if crop_side < self.out_size:
            # The crop already has the output size, so no resize is needed. The coordinates
            # are returned too, so callers can always unpack `(image, coords)`.
            side = int(self.out_size)
            i, j, h, w = transforms.RandomCrop.get_params(sample, (side, side))
            return TF.crop(sample, i, j, h, w), (i, j, h, w)

        i, j, h, w = transforms.RandomCrop.get_params(sample, (crop_side, crop_side))
        cropped = TF.crop(sample, i, j, h, w)
        out = TF.resize(cropped, self.out_size, self.interpolation)
        return out, (i, j, h, w)


class ModelWrapper(torch.nn.Module):
    """Chain an encoder and a head into one module that outputs only the head's logits.

    Args:
        encoder: callable/module applied to the input first.
        head: callable/module applied to the encoder output; its result must be indexable
            by the key ``"logits"``.
    """

    def __init__(self, encoder, head):
        super().__init__()
        # Both are registered as attributes in the same order as before (encoder first).
        self.encoder, self.head = encoder, head

    def forward(self, x):
        head_output = self.head(self.encoder(x))
        return head_output["logits"]


def loader_wrapper(loader):
    """Turn an iterable of batch objects into ``(batch.data, None)`` pairs.

    Each batch is expected to expose its input tensor as ``.data``. The second element of
    every pair is a placeholder target; presumably the consumer (``iba.estimate``) ignores
    it — TODO confirm against the IBA version in use.
    """
    for batch in loader:
        # Explicit None instead of the global `_`: in a notebook `_` is undefined on a fresh
        # kernel or holds whatever value was last bound to it.
        yield batch.data, None

In [None]:
# Endpoint records: only the columns needed to relate each record's date to the
# participant's recruitment date are read. `concept_id` becomes categorical for cheaper
# membership filtering. (Earlier source: .../artifacts/final_records_omop_220531.feather)
records = pd.read_feather(
    ARTIFACT_PATH, columns=["eid", "recruitment_date", "concept_id", "date"]
).astype({"concept_id": "category"})

In [None]:
# Checkpoint whose predictions are explained ("fullrun_aug"). The earlier "fullrun_jun"
# checkpoint lived at /sc-projects/sc-proj-ukb-cvd/results/models/retina/1zkzua6h/checkpoints/last.ckpt
checkpoint_path = MODEL_PATH

In [None]:
from hydra.core.global_hydra import GlobalHydra

# hydra.initialize raises if a GlobalHydra instance is already initialised, so clear it
# first; this makes the cell safe to re-run in the same kernel.
GlobalHydra.instance().clear()
hydra.initialize(config_path="dev/projects/RetinalRisk/config")

In [None]:
# Overrides applied on top of config_name="config". Presumably they mirror the settings the
# checkpoint was trained with — TODO confirm against the training run. Order is kept as-is.
cfg_overrides = [
    "training.gradient_checkpointing=False",
    "datamodule/covariates=no_covariates",
    "model=image",
    "setup.use_data_artifact_if_available=False",
    # MLP head
    "head=mlp",
    "head.kwargs.num_hidden=512",
    "head.kwargs.num_layers=2",
    "head.dropout=0.5",
    # optimisation
    "training.optimizer_kwargs.weight_decay=0.001",
    "training.optimizer_kwargs.lr=0.0001",
    # image encoder
    "model.freeze_encoder=False",
    "model.encoder=convnext_small",
    # data pipeline
    "datamodule.batch_size=32",
    "training.warmup_period=8",
    "datamodule/augmentation=contrast_sharpness_posterize",
    "datamodule.img_size_to_gpu=420",
    "datamodule.num_workers=8",
    "model.pretrained=True",
]
cfg = hydra.compose(config_name="config", overrides=cfg_overrides)

In [None]:
# Build the datamodule and model from the composed config; the model's encoder/head are
# reused when the checkpoint is loaded further down.
# NOTE(review): this also binds the global name `_` to setup_training's third return value.
datamodule, model, _ = setup_training(cfg)
# NOTE(review): the returned loader is discarded — presumably called so the datamodule sets
# up `test_dataset`, which later cells read; confirm this call is required.
datamodule.test_dataloader()

In [None]:
test_dataset = datamodule.test_dataset

# Random square crop using the test dataset's crop ratio, resized (bicubic) to
# test_dataset.img_size_to_gpu.
crop_transform = AdaptiveRandomCropTransform(
    test_dataset.img_crop_ratio,
    test_dataset.img_size_to_gpu,
    interpolation=PIL.Image.BICUBIC,
)

# Applied after cropping: PIL -> tensor, then per-channel normalisation with ImageNet statistics.
remaining_transforms = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

In [None]:
# Extra constructor argument forwarded to SupervisedTraining when restoring the checkpoint.
training_kwargs = {"label_mapping": datamodule.label_mapping}

# Restore trained weights, passing the encoder and head modules created by setup_training
# as constructor arguments.
model = SupervisedTraining.load_from_checkpoint(
    checkpoint_path,
    encoder=model.encoder,
    head=model.head,
    **training_kwargs,
)

In [None]:
# Fall back to CPU when no GPU is visible (much slower, but the notebook still runs).
device = "cuda:0" if torch.cuda.is_available() else "cpu"
# eval(): nothing in this notebook switches the restored model out of its default training
# mode, so head dropout (head.dropout=0.5 in the overrides) would make every forward pass —
# and therefore every IBA estimate/saliency map — stochastic.
model_wrapped = ModelWrapper(model.encoder, model.head).to(device).eval()

In [None]:
# ImageNet per-channel statistics (the same values as in `remaining_transforms`), shaped
# (3, 1, 1) so they broadcast over a (C, H, W) image array.
# NOTE(review): not referenced by any other cell shown here.
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

In [None]:
# Insert the information bottleneck at layer `encoder.features[3]` of the wrapped model.
bottleneck_layer = model_wrapped.encoder.features[3]
iba = IBA(bottleneck_layer)

# Estimate the bottleneck's statistics from validation batches (n_samples=10000). The
# estimate presumably covers that layer's activations — see the IBA documentation.
val_batches = loader_wrapper(datamodule.val_dataloader())
iba.estimate(model_wrapped, val_batches, n_samples=10000, progbar=True)

In [None]:
# Endpoints to explain, each written as "<label code> - <readable name>". The code must
# match an entry of datamodule.labels; the name is used in figure titles and file names.
endpoints = [
    # core cardiometabolic outcomes and death
    "phecode_202 - Diabetes mellitus",
    "phecode_404 - Ischemic heart disease",
    "phecode_404-1 - Myocardial infarction [Heart attack]",
    "phecode_431-11 - Cerebral infarction [Ischemic stroke]",
    "phecode_424 - Heart failure",
    "OMOP_4306655 - All-Cause Death",
    # further major systemic diseases
    "phecode_440-3 - Pulmonary embolism",
    "phecode_468 - Pneumonia",
    "phecode_474 - Chronic obstructive pulmonary disease [COPD]",
    "phecode_542 - Chronic liver disease and sequelae",
    "phecode_583 - Chronic kidney disease",
    "phecode_328 - Dementias and cerebral degeneration",
    # exploratory systemic endpoints
    "phecode_164 - Anemia",
    "phecode_726-1 - Osteoporosis",
    "phecode_103 - Malignant neoplasm of the skin",
    "phecode_101 - Malignant neoplasm of the digestive organs",
    "phecode_665 - Psoriasis",
    "phecode_705-1 - Rheumatoid arthritis",
    # ophthalmic endpoints
    "phecode_371 - Cataract",
    "phecode_374-3 - Retinal vascular changes and occlusions",
    "phecode_374-42 - Diabetic retinopathy",
    "phecode_374-5 - Macular degeneration",
    "phecode_375-1 - Glaucoma",
    "phecode_388 - Blindness and low vision",
    # additional ophthalmic endpoints
    "phecode_374-51 - Age-related macular degeneration",
    "phecode_367-5 - Uveitis",
]

# Used as the leading token and the trailing tag of the saved attribution file names.
prefix = "220812"
suffix = "_features3"

In [None]:
num_per_endpoint = 4  # at most this many test-set images are rendered per endpoint
samples_per_image = 256  # random crops whose saliency maps are averaged per image
model_input_size = 384  # side of the centre crop that is passed to the model


def _center_window(length, window):
    """Return ``(start, stop)`` of a window of size ``window`` centred in an axis of size ``length``.

    Explicit stop indices avoid the ``a[k:-k]`` pitfall, which selects nothing when k == 0.
    """
    start = (length - window) // 2
    return start, start + window


def _endpoint_sample_idxs(endpoint, eid_to_test_idx):
    """Return test-set indices of participants with an ``endpoint`` record within 180 days of recruitment.

    For each eid only its record closest to recruitment is considered, and the result is
    ordered by that distance (closest first).
    """
    endpoint_records = records[records.concept_id.isin([endpoint.replace("-", ".")])].reset_index()
    # NOTE: records before and after recruitment both count; add
    # .query("date <= recruitment_date") to keep only records up to recruitment.
    endpoint_records["td"] = (endpoint_records.date - endpoint_records.recruitment_date).abs()
    endpoint_records = endpoint_records.sort_values("td").drop_duplicates(
        subset=["eid"], keep="first"
    )
    eid_subset = endpoint_records[endpoint_records.td < datetime.timedelta(days=180)].eid.values
    return [eid_to_test_idx[eid] for eid in eid_subset if eid in eid_to_test_idx]


def _iba_attribution(img, endpoint_idx):
    """Average the IBA saliency for output ``endpoint_idx`` over ``samples_per_image`` random crops of ``img``.

    Returns a float32 ``(img.height, img.width)`` map; pixels never covered by a crop are 0.
    """

    def loss_closure(x):
        return -model_wrapped(x)[:, endpoint_idx].mean()

    # Running sum/count instead of stacking every full-resolution map: same result as
    # np.nanmean over the stack, without keeping samples_per_image full-size arrays alive.
    attr_sum = np.zeros((img.height, img.width), dtype=np.float64)
    attr_count = np.zeros((img.height, img.width), dtype=np.int64)

    for _ in range(samples_per_image):
        img_cropped, (top, left, crop_h, crop_w) = crop_transform(img)
        img_tensor = remaining_transforms(img_cropped)

        # img_tensor is (C, H, W): rows use the height window, columns the width window.
        full_h, full_w = img_tensor.shape[1], img_tensor.shape[2]
        y0, y1 = _center_window(full_h, model_input_size)
        x0, x1 = _center_window(full_w, model_input_size)
        model_input = img_tensor[None, :, y0:y1, x0:x1].clone().detach().to(device)
        saliency_map = iba.analyze(model_input, loss_closure, beta=10)

        # NaN marks the border the model never saw, so it is excluded from the average.
        saliency_full = np.full((full_h, full_w), np.nan, dtype=np.float32)
        saliency_full[y0:y1, x0:x1] = saliency_map

        # Map the saliency back to the crop's size in the original image.
        saliency_resized = TF.resize(
            torch.from_numpy(saliency_full).unsqueeze(0),
            [crop_h, crop_w],
            TF.InterpolationMode.BICUBIC,
        ).numpy()[0]

        valid = ~np.isnan(saliency_resized)
        attr_sum[top : top + crop_h, left : left + crop_w] += np.where(valid, saliency_resized, 0.0)
        attr_count[top : top + crop_h, left : left + crop_w] += valid

    attribution = np.zeros_like(attr_sum)
    np.divide(attr_sum, attr_count, out=attribution, where=attr_count > 0)
    return attribution.astype(np.float32)


def _save_attribution_figure(img, attribution, endpoint, endpoint_name, eid):
    """Render original / heat-map / blended panels with captum and write them to a PNG."""
    from matplotlib.backends.backend_agg import FigureCanvasAgg

    fig, _axes = captum.attr.visualization.visualize_image_attr_multiple(
        attribution[:, :, None],
        np.asarray(img),
        ["original_image", "heat_map", "blended_heat_map"],
        ["all", "positive", "positive"],
        use_pyplot=False,
        fig_size=(12, 6),
    )
    # Draw through an Agg canvas attached to this figure only. The figure is never
    # registered with pyplot, so it is freed when this function returns instead of
    # accumulating as open figures across the loop.
    FigureCanvasAgg(fig)
    fig.suptitle(f"Attribution for {endpoint_name} (IBA)", y=0.85, fontsize=16)
    fig.tight_layout()
    fig.savefig(
        f"{BASE_DIR}/attributions/{prefix}_{eid}_{endpoint}_{endpoint_name}_IBA{suffix}.png",
        dpi=300,
        bbox_inches="tight",
    )


pathlib.Path(f"{BASE_DIR}/attributions").mkdir(parents=True, exist_ok=True)

# eid -> position in the test set, built once. setdefault keeps the first occurrence,
# matching list.index semantics.
eid_to_test_idx = {}
for test_idx, test_eid in enumerate(datamodule.test_dataset.eids):
    eid_to_test_idx.setdefault(test_eid, test_idx)
label_list = list(datamodule.labels)

for combined_endpoint in endpoints:
    endpoint, endpoint_name = combined_endpoint.split(" - ")
    endpoint_idx = label_list.index(endpoint)  # ValueError if the model has no such output

    for sample_idx in _endpoint_sample_idxs(endpoint, eid_to_test_idx)[:num_per_endpoint]:
        path = datamodule.test_dataset.retina_map["file_path"].values[sample_idx]
        img = datamodule.test_dataset.loader(path)
        attribution = _iba_attribution(img, endpoint_idx)
        _save_attribution_figure(
            img,
            attribution,
            endpoint,
            endpoint_name,
            datamodule.test_dataset.eids[sample_idx],
        )

In [None]:
# Stack all attribution PNGs of each endpoint vertically into one image per endpoint.
base_path = pathlib.Path(f"{BASE_DIR}/attributions/")

for combined_endpoint in endpoints:
    endpoint, endpoint_name = combined_endpoint.split(" - ")

    # Files are named "{prefix}_{eid}_{endpoint}_{endpoint_name}_IBA{suffix}.png". Anchoring
    # on "_{endpoint}_" stops e.g. "phecode_404" from also matching "phecode_404-1" files.
    # Sorting makes the stacking order deterministic.
    files = sorted(p for p in base_path.glob(f"*_{endpoint}_*{suffix}*") if p.is_file())
    if not files:
        continue

    images = [PIL.Image.open(p) for p in files]
    try:
        canvas_width = max(im.width for im in images)
        canvas_height = sum(im.height for im in images)
        stacked = PIL.Image.new("RGB", (canvas_width, canvas_height))

        y_offset = 0
        for im in images:
            stacked.paste(im, (0, y_offset))
            y_offset += im.height
    finally:
        # Release the file handles PIL keeps open for lazily loaded images.
        for im in images:
            im.close()

    stacked.save(base_path / f"{endpoint_name}{suffix}.png")