In [4]:
import os
import h5py
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np
import random
import scipy.stats as stats
import pandas as pd


!pip install grad-cam==1.5.5
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image




class ACDCDataset(Dataset):
    def __init__(self, root_dir, transform=None, crop=True, crop_size=128):
        self.root_dir = root_dir
        self.files = [os.path.join(root_dir, f) for f in os.listdir(root_dir) if f.endswith(".h5")]
        self.transform = transform
        self.crop = crop
        self.crop_size = crop_size  # final size after cropping (square)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        file_path = self.files[idx]
        with h5py.File(file_path, "r") as f:
            image = np.array(f["image"])
            mask  = np.array(f["label"])

        # normalize to [0,1]
        image = image.astype(np.float32)
        image = (image - image.min()) / (image.max() - image.min() + 1e-8)

        # crop around heart region if enabled
        if self.crop and np.any(mask > 0):  # only if mask not empty
            coords = np.argwhere(mask > 0)
            y_min, x_min = coords.min(axis=0)
            y_max, x_max = coords.max(axis=0)

            # add padding
            pad = 15
            y_min = max(y_min - pad, 0)
            x_min = max(x_min - pad, 0)
            y_max = min(y_max + pad, image.shape[0])
            x_max = min(x_max + pad, image.shape[1])

            # crop image and mask
            image = image[y_min:y_max, x_min:x_max]
            mask  = mask[y_min:y_max, x_min:x_max]

        # resize/crop to fixed size
        if self.crop_size:
            from skimage.transform import resize
            image = resize(image, (self.crop_size, self.crop_size), preserve_range=True, anti_aliasing=True)
            mask  = resize(mask, (self.crop_size, self.crop_size), order=0, preserve_range=True, anti_aliasing=False)

        # add channel dim (C,H,W)
        image = np.expand_dims(image, axis=0)
        #mask  = np.expand_dims(mask, axis=0)

        image = torch.tensor(image, dtype=torch.float32)
        mask  = torch.tensor(mask, dtype=torch.long)

        return image, mask






  # ---------------  # Basic UNet Block----------------
class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DoubleConv, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.net(x)

# -------------------------------
# UNet Model
# -------------------------------
class UNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=2):
        super(UNet, self).__init__()
        self.enc1 = DoubleConv(in_channels, 64)
        self.enc2 = DoubleConv(64, 128)
        self.enc3 = DoubleConv(128, 256)
        self.enc4 = DoubleConv(256, 512)

        self.pool = nn.MaxPool2d(2)

        self.bottleneck = DoubleConv(512, 1024)

        self.up4 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.dec4 = DoubleConv(1024, 512)
        self.up3 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.dec3 = DoubleConv(512, 256)
        self.up2 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.dec2 = DoubleConv(256, 128)
        self.up1 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.dec1 = DoubleConv(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool(e1))
        e3 = self.enc3(self.pool(e2))
        e4 = self.enc4(self.pool(e3))

        b = self.bottleneck(self.pool(e4))

        d4 = self.up4(b)
        d4 = self.dec4(torch.cat([d4, e4], dim=1))
        d3 = self.up3(d4)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))

        return self.final(d1)

#--------------------------------- grad_cam relevent functions ------------
# class SegmentationTarget:
#       def __init__(self, category, mask_tensor):
#           """
#           category: int, the class index you want to explain (e.g., 1 for LV, 2 for myocardium)
#           mask_tensor: torch.Tensor of shape (H, W), region of interest
#           """
#           self.category = category
#           self.mask_tensor = mask_tensor

#       def __call__(self, model_output):
#           # model_output shape: (C, H, W)
#           class_map = model_output[self.category, :,  :]
#           return (class_map * self.mask_tensor.to(class_map.device)).sum()

# class SegmentationTarget:
#     def __init__(self, mask_tensor, target_class):
#         self.mask_tensor = mask_tensor  # (1, H, W)
#         self.target_class = target_class

#     def __call__(self, model_output):
#         # Keep everything as a tensor
#         class_map = model_output[self.target_class, :, :]  # Still a torch tensor
#         # Ensure mask is on same device and type
#         mask = self.mask_tensor.to(class_map.device).float()
#         # Element-wise multiply and sum
#         return (class_map * mask).sum()

class SegmentationTarget:
    def __init__(self, mask_tensor, target_class):
        self.mask_tensor = mask_tensor
        self.target_class = target_class
        # self.batch_idx = batch_idx

    def __call__(self, model_output):
        # model_output: (B, C, H, W)
        class_map = model_output[ self.target_class, :, :]
        mask = self.mask_tensor.to(class_map.device).float()
        if mask.ndim == 3:
            mask = mask.squeeze(0)
        return (class_map * mask).sum()

def generate_gradcam(model, image, mask, target_class, target_layer):

  model.eval()
  cam = GradCAM(model=model, target_layers=[target_layer])


  targets = [SegmentationTarget(target_class, mask)]
  input_tensor = image.to("cuda")   # shape (1, 1, H, W)

  grayscale_cam = cam(input_tensor=input_tensor, targets=targets)
  grayscale_cam = grayscale_cam[0, :]   # first image in batch
  return grayscale_cam

def gradcam_overlap_score(gradcam, gt_mask, class_id=2, threshold=0.4, metric='dice'):
    """
    Compute overlap between Grad-CAM heatmap and a ground-truth region (e.g., myocardium).

    Args:
        gradcam (ndarray): Grad-CAM heatmap (values in [0, 1]).
        gt_mask (ndarray): Ground truth segmentation mask.
        class_id (int): Target class label (e.g., 2 for myocardium).
        threshold (float): Threshold to binarize Grad-CAM.
        metric (str): 'dice' or 'iou'.

    Returns:
        float: Overlap score.
    """
    gradcam_bin = (gradcam > threshold).astype(np.uint8)
    gt_bin = (gt_mask == class_id).astype(np.uint8)

    intersection = np.sum(gradcam_bin * gt_bin)
    union = np.sum(gradcam_bin) + np.sum(gt_bin)

    if metric == 'dice':
        return (2.0 * intersection) / (union + 1e-8)
    elif metric == 'iou':
        union_iou = np.sum((gradcam_bin + gt_bin) > 0)
        return intersection / (union_iou + 1e-8)
    else:
        raise ValueError("metric should be 'dice' or 'iou'")


def evaluate_gradcam_for_loader(model, loader, target_layer, class_idx, bin_k=0.2, device='cuda'):
  model.eval()
  cam = GradCAM(model=model, target_layers=[target_layer])
  per_image_scores = []
  # optional: collect cams/images for examples
  for images, masks in loader:

      targets = [ SegmentationTarget(mask_tensor=masks[i], target_class= class_idx)  for i in range(images.shape[0])]

      grayscale_cams = cam(input_tensor=images, targets=targets)

      # visualization

      # for i in range(images.shape[0]):
      #     img = images[i].cpu().squeeze().numpy()
      #     img = np.stack([img]*3, axis=-1)  # make it 3-channel for visualization
      #     cam_image = show_cam_on_image(img, grayscale_cams[i], use_rgb=True)
      #     plt.figure()
      #     plt.title(f"Grad-CAM for Image {i}")
      #     plt.imshow(cam_image)
      #     plt.axis("off")
      #     plt.show()


        # images = images.to(device)
      masks_np = masks.cpu().numpy()  # (B,H,W)
        # # make targets per image so cam returns one heatmap per input
        # # (assumes SegmentationTarget callable exists)
        # # We'll call cam per-batch but pass per-image targets
        # targets = []
        # for b in range(images.shape[0]):
        #     gt_class_mask = (masks_np[b] == class_idx).astype(np.uint8)
        #     targets.append(SegmentationTarget(class_idx, gt_class_mask))
        # # get cams: shape (B, Hcam, Wcam)

        # grayscale_cams = cam(input_tensor=images, targets=targets)
        # # iterate each sample:
      for b in range(images.shape[0]):
            gcam = grayscale_cams[b]
            # upsample if needed
            if gcam.shape != masks_np[b].shape:
                gcam_t = torch.tensor(gcam).unsqueeze(0).unsqueeze(0)
                gcam_t = torch.nn.functional.interpolate(gcam_t, size=masks_np[b].shape, mode='bilinear', align_corners=False)
                gcam = gcam_t.squeeze().cpu().numpy()

            cam_bin = binarize_cam(gcam, method='topk', k=bin_k)
            gt_bin = (masks_np[b] == class_idx).astype(np.uint8)
            score = dice_from_masks(cam_bin, gt_bin)
            per_image_scores.append(score)

            #----------------visualization  -----------
            # plt.figure(figsize=(12,4))

            # # Original MRI
            # plt.subplot(1,3,1)
            # plt.imshow(images[b,0].cpu(), cmap='gray')
            # plt.title("Input MRI")

            # # Ground truth mask
            # plt.subplot(1,3,2)
            # plt.imshow(masks[b].cpu(), cmap='gray')
            # plt.title("Ground truth mask")

            # # Grad-CAM heatmap
            # plt.subplot(1,3,3)
            # plt.imshow(gcam, cmap='jet')
            # plt.title("Grad-CAM")

            # plt.show()
            # input("Press Enter to continue...") # Comment out this line

  return np.array(per_image_scores)


def binarize_cam(cam, method='topk', k=0.2, thresh=None):
    """
    Convert a Grad-CAM heatmap into a binary mask.

    Parameters
    ----------
    cam : np.ndarray
        Grad-CAM heatmap normalized between 0 and 1, shape (H, W)
    method : str
        'topk' – keep top-k fraction of pixels as 1
        'thresh' – use a fixed threshold value
    k : float
        Fraction (for topk) – e.g. 0.2 keeps the top 20% highest values.
    thresh : float
        Explicit threshold (for 'thresh' method) between 0 and 1.

    Returns
    -------
    binary_mask : np.ndarray of dtype uint8
        Binary map (0 or 1)
    """
    cam = np.clip(cam, 0, 1)  # ensure within [0,1]
    if method == 'topk':
        flat = cam.flatten()
        cutoff = np.quantile(flat, 1 - k)
        mask = (cam >= cutoff).astype(np.uint8)
    elif method == 'thresh':
        t = thresh if thresh is not None else 0.5
        mask = (cam >= t).astype(np.uint8)
    else:
        raise ValueError(f"Unknown method: {method}")
    return mask

def dice_from_masks(pred_mask, gt_mask, eps=1e-8):
    """
    Compute Dice score between two binary masks.
    """
    pred = pred_mask.astype(bool)
    gt   = gt_mask.astype(bool)
    inter = np.logical_and(pred, gt).sum()
    union = pred.sum() + gt.sum()
    if union == 0:
        return 1.0  # Or 0.0, depending on convention for empty masks
    return (2.0 * inter) / (union + eps)

#----------------------------------------------------------------
from google.colab import drive
drive.mount('/content/drive')

dataset_path = "/content/drive/MyDrive/ACDC_preprocessed/ACDC_training_slices"

full_dataset = ACDCDataset(dataset_path)
#------------------------------
# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

  # Initialize model
model = UNet(in_channels=1, out_channels=4).to(device)

train_ratio = 0.8
train_size = int(train_ratio * len(full_dataset))
val_size = len(full_dataset) - train_size

dice_scores_seed = []
per_seed_scores = {}
all_seed_stats = []

training = 0  #whether the trained model weights are intended to be loaded or the training pahse should be performed

target_class=2   #1:RV, 2:myocardium, 3:LV
target_layer=model.dec2   #target layer
seed=5

# results_dir = "/content/results"
results_dir = "/content/drive/MyDrive/GRADCAM_Scores_results"
os.makedirs(results_dir, exist_ok=True)

for i in range(seed):
  torch.manual_seed(i)
  np.random.seed(i)
  random.seed(i)
  torch.backends.cudnn.deterministic = True
  torch.backends.cudnn.benchmark = False

  train_dataset, val_dataset = random_split(full_dataset, [train_size, val_size])

  # Create DataLoaders
  train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
  val_loader   = DataLoader(val_dataset, batch_size=8, shuffle=False)

  # print(f"Total samples: {len(full_dataset)}")
  # print(f"Training samples: {len(train_dataset)}")
  # print(f"Validation samples: {len(val_dataset)}")

  # -------------------------


  if  training:     #if training is not run before

  # Loss & optimizer
      criterion = nn.CrossEntropyLoss()   # good for segmentation with class labels
      optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
      num_epochs = 5  # start small for testing

      for epoch in range(num_epochs):
          model.train()
          total_loss = 0
          for images, masks in train_loader:
              images, masks = images.to(device), masks.to(device)

              # Forward
              outputs = model(images)
              loss = criterion(outputs, masks)

              # Backward
              optimizer.zero_grad()
              loss.backward()
              optimizer.step()

              total_loss += loss.item()

          avg_loss = total_loss / len(train_loader)
          print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

      torch.save(model.state_dict(), "/content/drive/MyDrive/unet_resnet_acdc.pth")

      print("weight saved")
  else:

      model.load_state_dict(torch.load("/content/drive/MyDrive/unet_resnet_acdc.pth", map_location="cpu"))

  #-------------- validation loop-----------
  model.eval()   #evaluation of first epoch of images
  with torch.no_grad():
      for images, masks in val_loader:
          images, masks = images.to(device), masks.to(device)
          outputs = model(images)
          preds = torch.argmax(outputs, dim=1)  # (B, H, W)
          print("Pred mask shape:", preds.shape)
          break







  # ---------  dice score for whole dataset  ---------------
  # def dice_score_dataset(loader, model, num_classes, device="cuda"):
  #     model.eval()
  #     total_intersection = torch.zeros(num_classes, device=device)
  #     total_union = torch.zeros(num_classes, device=device)

  #     with torch.no_grad():
  #         for images, masks in loader:
  #             images, masks = images.to(device), masks.to(device)

  #             outputs = model(images)
  #             preds = torch.argmax(outputs, 1)

  #             for c in range(num_classes):
  #                 pred_c = (preds == c).float()
  #                 target_c = (masks == c).float()

  #                 total_intersection[c] += (pred_c * target_c).sum()
  #                 total_union[c] += pred_c.sum() + target_c.sum()

  #     dice_scores = (2. * total_intersection) / (total_union + 1e-8)
  #     return dice_scores.cpu().numpy()  # one dice score per class



  # val_dice = dice_score_dataset(val_loader, model, num_classes=4)
  # print("Validation Dice per class:", val_dice)
  # print("Mean Dice:", val_dice.mean()

  #-----------  GRAD-CAM EVALUATION   --------------

  # import numpy as np
  # import torch
  # import torch.nn.functional as F
  # import pandas as pd
  # from tqdm import tqdm
  # from pytorch_grad_cam import GradCAM

  # # --- Dice and IoU metrics ---
  # def binarize_cam(cam, method='topk', k=0.2, thresh=0.3):
  #     if method == 'topk':
  #         flat = cam.flatten()
  #         cutoff = np.percentile(flat, 100*(1-k))
  #         mask = (cam >= cutoff).astype(np.uint8)
  #     else:
  #         mask = (cam >= thresh).astype(np.uint8)
  #     return mask

  # def dice_from_masks(pred_mask, gt_mask, eps=1e-8):
  #     pred = pred_mask.astype(bool)
  #     gt   = gt_mask.astype(bool)
  #     inter = np.logical_and(pred, gt).sum()
  #     union = pred.sum() + gt.sum()
  #     if union == 0:
  #         return 1.0
  #     return (2.0 * inter) / (union + eps)

  # def iou_from_masks(pred_mask, gt_mask, eps=1e-8):
  #     pred = pred_mask.astype(bool)
  #     gt   = gt_mask.astype(bool)
  #     inter = np.logical_and(pred, gt).sum()
  #     union = np.logical_or(pred, gt).sum()
  #     if union == 0:
  #         return 1.0
  #     return inter / (union + eps)

  # # --- Candidate layers from your U-Net ---
  # candidates = {
  #     'enc3': model.enc3,
  #     'enc4': model.enc4,
  #     'bottleneck': getattr(model, 'bottleneck', None),
  #     'dec1': getattr(model, 'dec1', None),
  # }
  # candidates = {k: v for k, v in candidates.items() if v is not None}

  # # Initialize GradCAM for each candidate layer
  # cams = {name: GradCAM(model=model, target_layers=[layer], device='cuda')
  #         for name, layer in candidates.items()}

  # # --- ACDC class mapping ---
  # class_map = {1: "RV", 2: "MYO", 3: "LV"}

  # # Store results
  # results = []

  # model.eval()
  # with torch.no_grad():
  #     for i, (images, masks) in enumerate(tqdm(val_loader)):
  #         images = images.to('cuda')
  #         masks  = masks.to('cpu').numpy()  # (B,H,W)
  #         B = images.shape[0]

  #         for b in range(B):
  #             inp = images[b:b+1]   # (1,1,H,W)
  #             gt  = masks[b]        # (H,W)

  #             for class_idx, class_name in class_map.items():
  #                 # ground truth binary mask
  #                 gt_class = (gt == class_idx).astype(np.uint8)

  #                 for name, cam in cams.items():
  #                     # GradCAM heatmap
  #                     targets = [SegmentationTarget(class_idx, gt_class)]
  #                     grayscale_cam = cam(input_tensor=inp, targets=targets)
  #                     gcam = grayscale_cam[0]

  #                     # resize CAM to match GT
  #                     if gcam.shape != gt.shape:
  #                         gcam = torch.tensor(gcam).unsqueeze(0).unsqueeze(0)
  #                         gcam = F.interpolate(gcam, size=gt.shape,
  #                                              mode='bilinear', align_corners=False).squeeze().cpu().numpy()

  #                     # Try multiple binarization strategies
  #                     for method, params in [('topk', [0.1, 0.2, 0.3]),
  #                                            ('thresh', [0.2, 0.3, 0.5])]:
  #                         for p in params:
  #                             if method == 'topk':
  #                                 cam_mask = binarize_cam(gcam, method='topk', k=p)
  #                                 label = f"top{int(p*100)}"
  #                             else:
  #                                 cam_mask = binarize_cam(gcam, method='thresh', thresh=p)
  #                                 label = f"thr{p}"

  #                             d = dice_from_masks(cam_mask, gt_class)
  #                             j = iou_from_masks(cam_mask, gt_class)

  #                             results.append({
  #                                 'sample': i,
  #                                 'layer': name,
  #                                 'class': class_name,
  #                                 'bin_method': label,
  #                                 'dice': d,
  #                                 'iou': j
  #                             })

  # # Save results to CSV
  # df = pd.DataFrame(results)
  # df.to_csv("gradcam_layer_comparison.csv", index=False)

  # # Print summary
  # summary = df.groupby(["layer", "class", "bin_method"]).mean().reset_index()
  # print(summary)





  #------------  add SE block  ---------------
  # class SEBlock(nn.Module):
  #     def __init__(self, in_channels, reduction=16):
  #         super().__init__()
  #         self.avg_pool = nn.AdaptiveAvgPool2d(1)
  #         self.fc = nn.Sequential(
  #             nn.Linear(in_channels, in_channels // reduction),
  #             nn.ReLU(inplace=True),
  #             nn.Linear(in_channels // reduction, in_channels),
  #             nn.Sigmoid()
  #         )

  #     def forward(self, x):
  #         b, c, _, _ = x.size()
  #         y = self.avg_pool(x).view(b, c)
  #         y = self.fc(y).view(b, c, 1, 1)
  #         return x * y.expand_as(x)


  # self.enc3 = nn.Sequential(
  #     resnet.layer3,
  #     SEBlock(256)
  # )





#-----------------  compute grad cam for each batch of images-----------
  # dice_scores = []

  # for img, mask in val_loader:
  #     output = model(img.to(device))
  #     gradcam = generate_gradcam(model, img.to(device), mask, target_class,  target_layer)

  #     # gradcam = gradcam.squeeze().cpu().numpy()
  #     mask = mask.squeeze().cpu().numpy()

  #     score = gradcam_overlap_score(gradcam, mask, class_id=2, threshold=0.4)
  #     dice_scores.append(score)

  # print(f"Mean Grad-CAM Dice overlap with myocardium: {np.mean(dice_scores):.3f}")

  # dice_scores_seed.append(np.mean(dice_scores))
  # print(i)


#  ----------- compute grad cam for each image not each batch

  scores = evaluate_gradcam_for_loader(model, val_loader, target_layer=model.dec2, class_idx=2, bin_k=0.2, device='cuda')
  per_seed_scores[i] = scores
  mean = scores[scores != 0].mean()
  std  = scores[scores != 0].std()
  n = scores.size
  sem = stats.sem(scores) if n>1 else 0.0
  ci95 = sem * stats.t.ppf((1+0.95)/2, n-1) if n>1 else 0.0
  all_seed_stats.append({'seed':i, 'mean':float(mean), 'std':float(std), 'n':int(n), 'ci95':float(ci95)})
  # optionally save per-seed raw scores
  pd.DataFrame({'score':scores}).to_csv(os.path.join(results_dir, f'seed_{i}_scores.csv'), index=False)


Mounted at /content/drive
Using device: cuda
Pred mask shape: torch.Size([8, 128, 128])
Pred mask shape: torch.Size([8, 128, 128])
Pred mask shape: torch.Size([8, 128, 128])
Pred mask shape: torch.Size([8, 128, 128])
Pred mask shape: torch.Size([8, 128, 128])


In [5]:
print(all_seed_stats)

[{'seed': 0, 'mean': 0.6572327828095725, 'std': 0.15774768597311406, 'n': 383, 'ci95': 0.01966288623565771}, {'seed': 1, 'mean': 0.6590698321415533, 'std': 0.16396609548718916, 'n': 383, 'ci95': 0.02041931488711316}, {'seed': 2, 'mean': 0.6666479441830517, 'std': 0.15875225201626064, 'n': 383, 'ci95': 0.021775022213965992}, {'seed': 3, 'mean': 0.6605974978072324, 'std': 0.15933921004100612, 'n': 383, 'ci95': 0.019306977281830087}, {'seed': 4, 'mean': 0.6511822754756053, 'std': 0.17143677411703928, 'n': 383, 'ci95': 0.02202292087861412}]
