# Import the required libraries and modules

In [None]:
import torch
import torch.nn as nn
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image


#GradCAM++ Class

In [None]:
class GradCAMPlusPlus:
    """GradCAM++ saliency maps computed from one layer of ``model``.

    A forward hook on ``target_layer`` caches that layer's output activations
    and, when the output is part of the autograd graph, attaches a tensor hook
    that captures the gradient flowing back into it. ``generate`` combines the
    two with the GradCAM++ weighting (Chattopadhay et al., 2018).

    Args:
        model: network to explain; it is switched to eval mode here.
        target_layer: sub-module of ``model`` whose output is explained
            (typically the last ``nn.Conv2d``).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.model.eval()
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None

        def save_gradient(grad):
            self.gradients = grad.detach()

        def forward_hook(module, inp, out):
            self.activations = out.detach()
            # Capture d(score)/d(out) with a tensor hook instead of the deprecated
            # ``Module.register_backward_hook``. Forward passes run under
            # ``torch.no_grad()`` (plain inference elsewhere) produce outputs that
            # do not require grad, and registering a hook on those would raise.
            if out.requires_grad:
                out.register_hook(save_gradient)

        # Keep the handle so the hook can be detached later via ``remove()``.
        self._handles = [target_layer.register_forward_hook(forward_hook)]

    def remove(self):
        """Detach the hook this object installed on ``target_layer``."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def generate(self, x, class_idx=None):
        """Return a GradCAM++ heat map for one input.

        Args:
            x: input batch for ``model``. Only a batch of one is supported
                (the arg-max is read with ``.item()``).
            class_idx: output column to explain; defaults to the arg-max logit.

        Returns:
            2-D float ``numpy`` array at the target layer's spatial resolution,
            min-max normalised to ``[0, 1]``.

        Raises:
            RuntimeError: if ``target_layer`` did not record both activations
                and gradients during the pass.
        """
        self.activations = None
        self.gradients = None
        self.model.zero_grad()
        # Gradients are needed even if the caller is inside ``torch.no_grad()``.
        with torch.enable_grad():
            logits = self.model(x)
            if class_idx is None:
                class_idx = logits.argmax(dim=1).item()
            score = logits[:, class_idx].sum()
            score.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "target_layer recorded no activations/gradients; make sure it is "
                "used in the model's forward pass and its output requires grad"
            )
        grad = self.gradients
        act = self.activations

        grad_2 = grad.pow(2)
        grad_3 = grad_2 * grad
        # alpha = g^2 / (2 g^2 + sum_ij(A * g^3)); zero denominators become 1.
        numerator = grad_2
        denominator = 2 * grad_2 + (act * grad_3).sum(dim=(2, 3), keepdim=True)
        denominator = torch.where(denominator != 0, denominator, torch.ones_like(denominator))
        alpha = numerator / denominator

        weights = (alpha * torch.relu(grad)).sum(dim=(2, 3), keepdim=True)
        cam = (weights * act).sum(dim=1).squeeze()

        cam = torch.relu(cam)
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-9)
        return cam.cpu().numpy()


#Detect Last Conv Layer

In [None]:
def find_last_conv_layer(model):
    """Return the final ``nn.Conv2d`` met when traversing ``model``'s modules.

    Returns ``None`` when the model contains no 2-D convolution.
    """
    conv_layers = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
    return conv_layers[-1] if conv_layers else None

# Attach GradCAM++ to the last Conv2d of ``model`` (expected to exist in the notebook namespace).
target_layer = find_last_conv_layer(model)
if target_layer is None:
    # Fail early with a clear message rather than an AttributeError on ``None``
    # inside GradCAMPlusPlus.__init__.
    raise ValueError("No nn.Conv2d layer found in model; GradCAM++ needs a convolutional target layer")
campp = GradCAMPlusPlus(model, target_layer)
print("Using target conv layer:", target_layer)


#Display GradCAM++ Overlay on Image

In [None]:
def show_gradcam_pp(img_path, size=224):
    """Plot the GradCAM++ heat map of ``img_path`` blended over the grayscale image.

    Relies on the notebook globals ``test_tfms`` (PIL image -> tensor),
    ``device`` and ``campp`` (a ``GradCAMPlusPlus`` instance).

    Args:
        img_path: path to the image file.
        size: side length in pixels of the displayed square overlay (default 224).

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    orig = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if orig is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    orig = cv2.resize(orig, (size, size))

    # Close the underlying file once the grayscale copy has been made.
    with Image.open(img_path) as pil_file:
        pil_img = pil_file.convert("L")
    x = test_tfms(pil_img).unsqueeze(0).to(device)

    cam = campp.generate(x)
    # The CAM is at the target layer's (coarse) resolution; upsample to display size.
    cam_resized = np.clip(cv2.resize(cam, (size, size)), 0.0, 1.0)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    orig_bgr = cv2.cvtColor(orig, cv2.COLOR_GRAY2BGR)
    overlay = (0.4 * heatmap + 0.6 * orig_bgr).astype(np.uint8)

    plt.figure(figsize=(7, 7))
    plt.imshow(overlay[:, :, ::-1])  # OpenCV images are BGR; matplotlib expects RGB
    plt.title("GradCAM++ Overlay")
    plt.axis("off")
    plt.show()


In [None]:
from tqdm import tqdm

#RISE Class Implementation

In [None]:
class RISE:
    """RISE saliency (Petsiuk et al., 2018) for a model with one logit per image.

    Random binary ``s x s`` grids are upsampled with nearest-neighbour
    interpolation to the input size. The saliency map is the average of the
    masks weighted by the model's sigmoid output on each masked input.
    Unlike the paper, masks are neither bilinearly smoothed nor randomly
    shifted, so the map is piecewise constant on the ``s x s`` grid.

    Args:
        model: network to explain; switched to eval mode.
        input_size: ``(H, W)`` of the masks; must equal the image tensor's
            spatial size.
        N: number of random masks.
        s: grid resolution of each mask.
        p: probability that a grid cell is kept (set to 1).
        batch_size: number of masked images evaluated per forward pass.
    """

    def __init__(self, model, input_size=(224, 224), N=4000, s=7, p=0.5, batch_size=32):
        self.model = model.eval()
        self.H, self.W = input_size
        self.N = N
        self.s = s
        self.p = p
        if batch_size < 1:
            raise ValueError("batch_size must be >= 1")
        self.batch_size = batch_size
        print("⚡ Generating random masks...")
        self.masks = self.generate_masks()

    def generate_masks(self):
        """Return an ``(N, H, W)`` float32 array of upsampled random binary masks."""
        masks = []
        for _ in tqdm(range(self.N)):
            grid = np.random.choice([0, 1], size=(self.s, self.s), p=[1 - self.p, self.p]).astype(np.float32)
            masks.append(cv2.resize(grid, (self.W, self.H), interpolation=cv2.INTER_NEAREST))
        return np.stack(masks)

    @torch.no_grad()
    def explain(self, img_tensor):
        """Return an ``(H, W)`` float32 saliency map min-max normalised to ``[0, 1]``.

        Args:
            img_tensor: one preprocessed image of shape ``(1, C, H, W)``; an
                unbatched ``(C, H, W)`` tensor is also accepted. Its spatial
                size must equal ``input_size``.

        Raises:
            ValueError: for more than one image, a spatial-size mismatch, or a
                model that does not return exactly one logit per image.
        """
        if img_tensor.dim() == 3:
            img_tensor = img_tensor.unsqueeze(0)
        if img_tensor.dim() != 4 or img_tensor.shape[0] != 1:
            raise ValueError(
                f"RISE.explain expects a single image of shape (1, C, H, W), got {tuple(img_tensor.shape)}"
            )
        if tuple(img_tensor.shape[-2:]) != (self.H, self.W):
            raise ValueError(
                f"image size {tuple(img_tensor.shape[-2:])} does not match mask size {(self.H, self.W)}"
            )

        device, dtype = img_tensor.device, img_tensor.dtype
        probs = np.empty(self.N, dtype=np.float32)
        for start in tqdm(range(0, self.N, self.batch_size), desc="RISE"):
            stop = min(start + self.batch_size, self.N)
            # Move only the current slice of masks to the device to bound memory.
            masks = torch.from_numpy(self.masks[start:stop]).to(device=device, dtype=dtype)
            # (b, 1, H, W) * (1, C, H, W) -> (b, C, H, W): each mask hits every channel.
            masked = masks.unsqueeze(1) * img_tensor
            logits = self.model(masked)
            batch_probs = torch.sigmoid(logits).reshape(-1)
            if batch_probs.numel() != stop - start:
                raise ValueError("RISE expects the model to output exactly one logit per image")
            probs[start:stop] = batch_probs.float().cpu().numpy()

        # Probability-weighted average of the masks: sum_i p_i * M_i / N.
        saliency = np.tensordot(probs, self.masks, axes=1).astype(np.float32)
        saliency /= self.N
        saliency -= saliency.min()
        saliency /= saliency.max() + 1e-9
        return saliency


#Running & Visualizing RISE

In [None]:
def show_rise(img_path, rise=None):
    """Plot the RISE saliency map of ``img_path`` blended over the grayscale image.

    Relies on the notebook globals ``model``, ``test_tfms`` and ``device``.

    Args:
        img_path: path to the image file.
        rise: optional pre-built ``RISE`` explainer. Building one draws fresh
            random masks, so pass a shared instance when explaining many images.
            When omitted, ``RISE(model, input_size=(224, 224), N=2000, s=7, p=0.5)``
            is built for this call.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
    """
    img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.resize(img, (224, 224))

    # Close the underlying file once the grayscale copy has been made.
    with Image.open(img_path) as pil_file:
        pil = pil_file.convert("L")
    x = test_tfms(pil).unsqueeze(0).to(device)

    if rise is None:
        rise = RISE(model, input_size=(224, 224), N=2000, s=7, p=0.5)
    sal = rise.explain(x)
    if sal.shape != img.shape:
        # A caller-supplied explainer may use a different mask size than the display.
        sal = cv2.resize(sal, (img.shape[1], img.shape[0]))

    heatmap = cv2.applyColorMap(np.uint8(255 * np.clip(sal, 0.0, 1.0)), cv2.COLORMAP_JET)
    orig_bgr = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    overlay = (0.5 * heatmap + 0.5 * orig_bgr).astype(np.uint8)

    plt.figure(figsize=(7, 7))
    plt.imshow(overlay[:, :, ::-1])  # OpenCV images are BGR; matplotlib expects RGB
    plt.title("RISE Saliency Map")
    plt.axis("off")
    plt.show()


In [None]:
from sklearn.metrics import roc_auc_score

# Put the network in evaluation mode (same as model.eval(); affects layers such as Dropout/BatchNorm).
model.train(False)























<br><br><br><br><br>
<br><br><br><br><br>























# ADDING A FAKE 'R' MARKER TO OUR NORMAL IMAGES AND TESTING WHETHER THE MODEL'S PREDICTIONS CHANGE

# Helper functions: model probability, and an artificial 'R' marker to stamp onto NORMAL images

In [None]:
@torch.no_grad()
def get_prob(img_tensor):
    """Return the model's sigmoid output for ``img_tensor`` as a Python float.

    Uses the notebook globals ``model`` and ``device``; autograd is disabled.
    """
    logit = model(img_tensor.to(device))
    return torch.sigmoid(logit).item()

In [None]:
def add_R_marker(gray_224, text="R", org=(5, 70), font_scale=1.8, thickness=3, color=(255,)):
    """Return a copy of a grayscale image with a text marker drawn on it.

    The input array is not modified. The defaults reproduce the original fake
    'R' marker (white, Hershey Simplex, near the top-left corner). Exposing them
    lets the experiment vary the marker's text, position, size or intensity.

    Args:
        gray_224: single-channel image array (224x224 in this notebook).
        text: marker string to draw.
        org: bottom-left corner ``(x, y)`` of the text in pixels.
        font_scale: OpenCV font scale factor.
        thickness: stroke thickness in pixels.
        color: pixel value(s) used for the text.

    Returns:
        A new array with the marker drawn.
    """
    img = gray_224.copy()
    cv2.putText(img, text, org=org, fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=font_scale, color=color, thickness=thickness, lineType=cv2.LINE_AA)
    return img


In [None]:
import os

# NORMAL test images: image files only, in a stable (sorted) order so runs are reproducible.
# os.listdir order is arbitrary, and non-image entries (e.g. ".DS_Store") would make
# cv2.imread return None later.
test_norm_dir = "/content/chest_xray/test/NORMAL"
norm_paths = [
    os.path.join(test_norm_dir, f)
    for f in sorted(os.listdir(test_norm_dir))
    if f.lower().endswith((".jpeg", ".jpg", ".png", ".bmp", ".tif", ".tiff"))
]


#Computing Baseline vs R-Marked Predictions

In [None]:
# Paired predictions: every NORMAL image scored without and with the fake 'R' marker.
baseline_probs = []
R_probs = []
skipped_paths = []

for path in tqdm(norm_paths, desc="NORMALs with and without R"):
    g = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if g is None:
        # cv2.imread returns None for unreadable files. Skip the image for BOTH
        # conditions so the two lists stay index-aligned for the paired test.
        skipped_paths.append(path)
        continue
    g = cv2.resize(g, (224, 224))

    pil_base = Image.fromarray(g).convert("L")
    x_base = test_tfms(pil_base).unsqueeze(0)
    baseline_probs.append(get_prob(x_base))

    g_R = add_R_marker(g)
    pil_R = Image.fromarray(g_R).convert("L")
    x_R = test_tfms(pil_R).unsqueeze(0)
    R_probs.append(get_prob(x_R))

if skipped_paths:
    print(f"Skipped {len(skipped_paths)} unreadable file(s), e.g. {skipped_paths[0]}")


#Converting to Arrays + Printing Stats

In [None]:
# Turn the paired per-image lists into arrays, then summarise how the marker shifts the model's sigmoid output.
baseline_probs, R_probs = np.array(baseline_probs), np.array(R_probs)

print("Baseline NORMAL mean prob:", np.mean(baseline_probs))
print("With fake R  mean prob  :", np.mean(R_probs))
print("Mean Δ prob (R - base)  :", np.mean(R_probs - baseline_probs))


# Statistical significance test (paired comparison of R-marked vs. baseline predictions)

In [None]:
from scipy.stats import ttest_rel, wilcoxon

# Paired t-test: same images, with vs. without the 'R' marker.
t, p = ttest_rel(R_probs, baseline_probs)
print("t =", t, "p =", p)

# The t-test assumes roughly normal paired differences, which is questionable for
# bounded probabilities; report the distribution-free Wilcoxon signed-rank test too.
# It is undefined when every difference is zero, so guard that case.
if np.any(np.asarray(R_probs) != np.asarray(baseline_probs)):
    w_stat, w_p = wilcoxon(R_probs, baseline_probs)
    print("Wilcoxon W =", w_stat, "p =", w_p)
else:
    print("Wilcoxon test skipped: the marker did not change any prediction")
