In [25]:
import os

import cv2
import h5py
import matplotlib.pyplot as plt
import numpy as np
import openslide
import torch
from datasets.dataset_h5 import eval_transforms
from models.model_clam import CLAM_MB
from models.resnet_custom_grad import resnet50_baseline
from torch.autograd import grad
from torchvision import models, transforms, utils

In [26]:
def seed_torch(seed=7):
    """Seed every random-number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's legacy global RNG, PyTorch's
    CPU RNG and all CUDA device RNGs. It also switches cuDNN to deterministic,
    non-benchmarking mode so repeated runs give the same results.

    Args:
        seed: integer seed applied to every generator.
    """
    import random

    random.seed(seed)
    # Only affects interpreters started after this point (e.g. worker
    # processes); the hash seed of the running kernel is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe on CPU-only machines: PyTorch silently ignores this call when CUDA
    # is unavailable, so no device check is needed.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


seed_torch(1)

In [27]:
def initiate_model(ckpt_path, n_classes=6):
    """Build a CLAM_MB classifier and load trained weights onto the CPU.

    Args:
        ckpt_path: path to a saved state dict for CLAM_MB.
        n_classes: number of output classes; must match the checkpoint.
            Defaults to 6, the value previously hard-coded here.

    Returns:
        The CLAM_MB model with the checkpoint loaded, in eval mode, on CPU.
    """
    model = CLAM_MB(dropout=True, n_classes=n_classes, subtyping=True)

    # map_location="cpu": the model ends up on the CPU anyway, and without it a
    # checkpoint containing CUDA tensors cannot be loaded on a CPU-only host.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    # Skip "instance_loss_fn" entries and strip ".module" from key names so the
    # keys match the bare model under strict loading (the ".module" infix is
    # presumably left by a DataParallel wrapper at training time).
    ckpt_clean = {
        key.replace(".module", ""): value
        for key, value in ckpt.items()
        if "instance_loss_fn" not in key
    }
    model.load_state_dict(ckpt_clean, strict=True)

    model.relocate()
    model.eval()
    # Final placement is the CPU regardless of what relocate() chose.
    model.cpu()

    return model

In [28]:
# Trained CLAM_MB checkpoint whose predictions are explained below.
CKPT_PATH = "/home/jupyter/CLAM/results/disorder_tau_j_s1/s_0_checkpoint.pt"
model = initiate_model(CKPT_PATH)

In [29]:
# Feature extractor carrying the Grad-CAM hooks (`.grad`, `.get_activation`).
# eval(): without it BatchNorm runs in training mode, normalising each single
# tile with its own batch statistics and overwriting the running statistics,
# so the features differ from the (presumably eval-mode) features the CLAM
# head was trained on.
# NOTE(review): resnet50_baseline() is called with no arguments; confirm it
# loads the pretrained weights used for feature extraction, otherwise the
# heatmaps explain a randomly initialised network.
resnet = resnet50_baseline().eval()

In [30]:
# Standard ImageNet per-channel mean / std (RGB order), the normalisation
# used by torchvision's pretrained ResNet weights.
mean, std = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)

# PIL image -> float tensor in [0, 1], then per-channel standardisation.
toTensor, normalize = transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)

In [10]:
def _gradcam_heatmap(gradients, activations):
    """Combine hooked gradients and activations into a Grad-CAM heatmap.

    Args:
        gradients: gradient of the loss w.r.t. the hooked feature map;
            assumed (1, C, H, W) -- TODO confirm against resnet_custom_grad.
        activations: the hooked feature map, same shape as ``gradients``.

    Returns:
        An (H, W) float32 numpy array, ReLU-ed and scaled so its maximum is 1
        (left all zeros when no location has a positive contribution).
    """
    # One weight per channel: gradients averaged over batch and spatial dims.
    channel_weights = gradients.mean(dim=(0, 2, 3))
    # Broadcast over every channel; a hard-coded channel count would leave
    # the remaining channels unweighted if the map has more than that.
    weighted = activations * channel_weights.view(1, -1, 1, 1)
    heatmap = weighted.mean(dim=1).squeeze(0).clamp(min=0)
    peak = heatmap.max()
    if peak > 0:  # avoid 0 / 0 -> NaN for an all-zero map
        heatmap = heatmap / peak
    return heatmap.detach().cpu().numpy().astype(np.float32)


def _overlay_heatmap(heatmap, img_rgb, alpha=0.4):
    """Blend a [0, 1] heatmap over an RGB tile; return a BGR uint8 image for cv2."""
    heatmap = cv2.resize(heatmap, (img_rgb.shape[1], img_rgb.shape[0]))
    colored = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)  # BGR
    # The PIL tile is RGB; convert so both layers share cv2's BGR channel order.
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    blended = alpha * colored + (1 - alpha) * img_bgr
    return np.clip(blended, 0, 255).round().astype(np.uint8)


def explain(
    resnet,
    model,
    slide_name,
    target_idx,
    features_dir,
    slide_dir,
    target_label=0,
    full_pipeline=False,
    output_path="/home/jupyter/CLAM/grad_results/map_sample.jpg",
    tile_size=256,
):
    """Write a Grad-CAM overlay for one tile of a whole-slide image.

    Reads the tile's level-0 coordinates from the "coords" dataset of
    ``<features_dir>/<slide_name>.h5`` and reads that tile from
    ``<slide_dir>/<slide_name>.svs``. It runs the tile through ``resnet`` and,
    unless ``full_pipeline`` is True, through the CLAM ``model``. It then
    back-propagates a cross-entropy loss against ``target_label`` and builds
    the heatmap from ``resnet.grad`` and ``resnet.get_activation``.

    Args:
        resnet: hooked feature extractor exposing ``.grad`` after backward
            and ``.get_activation(x)``.
        model: CLAM classifier applied to the extracted feature.
        slide_name: slide file stem shared by the .h5 and .svs files.
        target_idx: row of the "coords" dataset selecting the tile.
        features_dir: directory containing ``<slide_name>.h5``.
        slide_dir: directory containing ``<slide_name>.svs``.
        target_label: class index the gradient is taken with respect to.
        full_pipeline: if True, the CLAM model is skipped and the loss is
            taken directly on the ResNet output.
            NOTE(review): the name reads inverted relative to this behaviour;
            kept as-is for existing callers.
        output_path: where the overlay JPEG is written.
        tile_size: side length in pixels of the level-0 tile to read.

    Returns:
        The overlay image as a BGR uint8 array (what was written to disk).

    Raises:
        OSError: if OpenCV fails to write ``output_path``.
    """
    # Only the coordinates are needed. The stored feature bag is not used, so
    # it is not loaded (previously it was also moved to CUDA while the models
    # run on the CPU).
    coords_path = os.path.join(features_dir, "{}.h5".format(slide_name))
    with h5py.File(coords_path, "r") as hdf5_file:
        coords = hdf5_file["coords"][:]
    target_x, target_y = (int(c) for c in coords[target_idx])

    slide_path = os.path.join(slide_dir, "{}.svs".format(slide_name))
    wsi = openslide.open_slide(slide_path)
    try:
        tile = wsi.read_region(
            (target_x, target_y), 0, (tile_size, tile_size)
        ).convert("RGB")
    finally:
        wsi.close()
    img_rgb = np.array(tile)  # reused for the overlay instead of re-reading

    tile_tensor = normalize(toTensor(tile)).unsqueeze(0)
    output = resnet(tile_tensor)
    if not full_pipeline:
        # Assumes CLAM_MB returns (logits, probabilities, prediction,
        # attention, extras). CrossEntropyLoss expects raw logits, so use the
        # first item; the probabilities would be softmax-ed a second time.
        logits, _, _, _, _ = model(output)
        output = logits

    target = torch.tensor([target_label], dtype=torch.long)
    loss = torch.nn.CrossEntropyLoss()(output, target)
    loss.backward()

    gradients = resnet.grad
    activations = resnet.get_activation(tile_tensor).detach()
    heatmap = _gradcam_heatmap(gradients, activations)

    superimposed_img = _overlay_heatmap(heatmap, img_rgb)
    # cv2.imwrite reports failure (e.g. missing directory) only via its return.
    if not cv2.imwrite(output_path, superimposed_img):
        raise OSError("Could not write overlay to {}".format(output_path))
    return superimposed_img

In [None]:
slide_name = "1057550"
slides_dir = "WSI_IHC_J/"
features_dir = "WSI_tau_j_features/h5_files/"
# Row of the slide's "coords" dataset to explain. It was never defined in the
# notebook, so this cell raised NameError on a fresh kernel.
# TODO: set to the tile of interest.
target_idx = 0
# Assigned rather than left as the cell's last expression to avoid dumping an
# image array into the cell output.
overlay_bgr = explain(
    resnet, model, slide_name, target_idx, features_dir, slides_dir, target_label=0
)