This is an adaptation of a PyTorch reimplementation of RISE (Randomized Input Sampling for Explanation of black-box models) from this repository: https://github.com/yiskw713/RISE

The RISE code has been modified to work with models that take tabular data as an additional input alongside the X-ray image.

In [None]:
import numpy as np
import torch
from PIL import Image
from matplotlib.pyplot import imshow
from torchvision import transforms
from torchvision.utils import save_image
from rise import CXR_RISE
from utils.visualize import visualize, reverse_normalize
import sys

import os

# Make the project package (``src.*``) importable from this notebook.
# The root can be overridden with the CXR_PROJECT_ROOT environment variable,
# so the notebook also runs outside the original author's machine. The
# previous hard-coded path is kept as the default.
project_root = os.environ.get(
    "CXR_PROJECT_ROOT", "/home/wasabi/PycharmProjects/cs-7643-final-project"
)
sys.path.insert(0, project_root)

from src.models.cxr_model import CXRModel

In [None]:
# Choose an X-ray to analyze and paste its path here.
image_path = '/home/wasabi/PycharmProjects/cs-7643-final-project/artifacts/embedded_test/00000661_001.png'

# Choose the targeted outcome class index (used to select a class from the
# RISE saliency output further down).
target_class = 1

# Open the file in a context manager so the handle is released once decoded.
# ``convert`` returns an independent in-memory image, so it stays valid after
# the ``with`` block exits. RGB is forced because the normalization uses
# 3-channel mean/std.
with Image.open(image_path) as raw_image:
    image = raw_image.convert('RGB')


In [None]:
# Standard ImageNet per-channel statistics (the values torchvision's
# pretrained models are normalized with).
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
normalize = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

# PIL image -> resized 224x224 -> float tensor (C, H, W) -> normalized.
preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    normalize,
]
preprocess = transforms.Compose(preprocess_steps)


In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Preprocess the image and prepend a batch axis -> (1, C, H, W).
batch = preprocess(image)[None]
H, W = batch.shape[-2:]  # spatial size handed to the RISE wrapper
tensor = batch.to(device)


In [None]:
# Checkpoint to explain. Other experiment checkpoints that were tried:
#   'embd_vit_b_32_lr_1e-05_bs_32_do_0.2_hd_None_ms_32_best.pth'
#   'vit_b_32_lr_1e-05_bs_32_do_0.2_hd_(512, 256, 128, 64, 32)_best.pth'
model_path = '/tmp/cs7643_final_share/emad_results/best_model_vit_b_32_embedded_focal.pth'

# map_location remaps stored tensors onto the current device, so a checkpoint
# saved on a GPU still loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
save_info = torch.load(model_path, map_location=device)
print(save_info["config"])

# The checkpoint holds the constructor kwargs under "config" and the weights
# under "model".
model = CXRModel(**save_info["config"])
model.load_state_dict(save_info["model"])
model.to(device)
model.eval()  # evaluation mode (affects layers such as dropout)

In [None]:
# Fixed tabular feature vector, shape (1, 4), handed to the RISE wrapper
# together with the model. TODO: document what each of the 4 features means
# for CXRModel.
# It is created on the same device as the model and image tensor, which
# avoids a cross-device error if CXR_RISE forwards it to the model unchanged
# (NOTE(review): confirm how rise.py uses it).
tabular_data = torch.tensor([[0.5, 0.5, 0.5, 0.0]], device=device)

wrapped_model = CXR_RISE(model, tabular_data, input_size=(H, W))


In [None]:
print(target_class)

# RISE only needs forward passes, so autograd tracking is switched off while
# the wrapper queries the model.
with torch.set_grad_enabled(False):
    saliency = wrapped_model(tensor)

In [None]:
# Keep only the saliency map for the chosen class.
# NOTE(review): assumes the first index of the CXR_RISE output selects the
# class -- confirm against rise.py.
saliency = saliency[target_class]

In [None]:
# Image for display, on the CPU. reverse_normalize presumably undoes
# `normalize` -- confirm in utils/visualize.py.
img = reverse_normalize(tensor.cpu())

# Reshape the class saliency map to (1, 1, H, W) and move it to the CPU as
# well, so both visualize() inputs are on the same device. reshape (unlike
# view) also accepts a non-contiguous map.
saliency = saliency.reshape(1, 1, H, W).cpu()
heatmap = visualize(img, saliency)


In [None]:
# Write the heatmap overlay to disk.
# NOTE(review): the filename hard-codes "cardiomegaly" and does not follow
# `target_class`; confirm class index 1 is cardiomegaly, and rename the file
# when explaining a different class.
save_image(heatmap, 'class_cardiomegaly_explanation_new.png')
