In [4]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS")
print("="*80)

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
WEIGHTS_DIR = '/kaggle/input/weight/weights'
OUTPUT_DIR = '/kaggle/working'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Test dataset: {TEST_DIR}")
print(f"✓ Weights directory: {WEIGHTS_DIR}")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ========== 1. LOAD TEST IMAGES ==========
print("\n[1/5] Loading test images...")

class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    sampled = random.sample(images, min(NUM_IMAGES_PER_CLASS, len(images)))
    test_images[cls] = [os.path.join(cls_path, img) for img in sampled]

for cls in class_names:
    print(f"  {cls}: {len(test_images[cls])}")
total_images = sum(len(v) for v in test_images.values())
print(f"✓ Total test images: {total_images}")

# ========== 2. PREPROCESS ==========
print("\n[2/5] Setting up preprocessing...")

def preprocess_image(img_path):
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_norm = (image_norm - 0.5) / 0.5
    img_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    img_tensor.requires_grad_(True)  # Enable grad for CAM backward
    return img_tensor, image_rgb

print("✓ Preprocessing ready")

# ========== 3. CUSTOM CLASSIFIER HEAD (for CNN models) ==========
class CustomHead(nn.Module):
    def __init__(self, in_features, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        return self.fc(x)

# ========== 4. LOAD MODELS ==========
print("\n[3/5] Loading trained models...")

model_configs = {
    'resnet50': {
        'model_fn': lambda: timm.create_model('resnet50', pretrained=False, num_classes=4),
        'weight_file': 'resnet50_best_80_10_10.pth',
        'target_layer': 'layer4[-1]',
        'has_custom_head': True,
        'num_features': 2048
    },
    'mobilenetv2': {
        'model_fn': lambda: timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4),
        'weight_file': 'mobilenetv2_best_80_10_10.pth',
        'target_layer': 'blocks[-1]',
        'has_custom_head': True,
        'num_features': 1280
    },
    'efficientnetb0': {
        'model_fn': lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=4),
        'weight_file': 'efficientnetb0_best_80_10_10.pth',
        'target_layer': 'blocks[-1]',
        'has_custom_head': True,
        'num_features': 1280
    },
    'vit': {
        'model_fn': lambda: timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4),
        'weight_file': 'vit_best_80_10_10.pth',
        'target_layer': 'blocks[-1]',
        'has_custom_head': False
    },
    'swin': {
        'model_fn': lambda: timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4),
        'weight_file': 'swin_best_80_10_10.pth',
        'target_layer': 'layers[-1]',
        'has_custom_head': False
    }
}

models = {}
for model_name, cfg in model_configs.items():
    try:
        model = cfg['model_fn']()
        
        if cfg['has_custom_head']:
            model.fc = CustomHead(cfg['num_features'], num_classes=4)
        
        weight_path = os.path.join(WEIGHTS_DIR, cfg['weight_file'])
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE), strict=False)
            model.to(DEVICE)
            model.eval()
            models[model_name] = model
            print(f"✓ {model_name.upper()} loaded from {weight_path}")
        else:
            print(f"⚠ {model_name.upper()} weights not found at {weight_path}")
    except Exception as e:
        print(f"✗ Error loading {model_name}: {e}")

print(f"\n✓ Total models loaded: {len(models)}")

# ========== 5. GRAD-CAM ==========
print("\n[4/5] Implementing Grad-CAM...")

class GradCAM:
    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []
        self._register_hooks()
    
    def _register_hooks(self):
        def forward_hook(module, inp, out):
            self.activations = out
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        for name, module in self.model.named_modules():
            if self.target_layer_name in name or name.endswith(self.target_layer_name):
                self.hooks.append(module.register_forward_hook(forward_hook))
                self.hooks.append(module.register_full_backward_hook(backward_hook))
                break
    
    def generate_cam(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward(retain_graph=True)

        grads = self.gradients.detach().cpu().numpy()[0]      # (C,H,W)
        acts = self.activations.detach().cpu().numpy()[0]     # (C,H,W)
        weights = grads.mean(axis=(1, 2))                     # (C,)
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for c, w in enumerate(weights):
            cam += w * acts[c]
        cam = np.maximum(cam, 0)
        cam = cam / (cam.max() + 1e-8)
        return cam, class_idx
    
    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

print("✓ Grad-CAM ready")

# ========== 6. GENERATE CAMs ==========
print("\n[5/5] Generating CAMs...")

cam_output_dir = os.path.join(OUTPUT_DIR, 'cam_visualizations')
os.makedirs(cam_output_dir, exist_ok=True)

all_cams = {}

if len(models) == 0:
    print("✗ No models loaded. Cannot generate CAMs.")
else:
    for model_name, model in tqdm(models.items(), desc="Processing models"):
        target_layer = model_configs[model_name]['target_layer']
        cam_engine = GradCAM(model, target_layer)
        all_cams[model_name] = {}
        
        for cls_name in class_names:
            all_cams[model_name][cls_name] = []
            for img_idx, img_path in enumerate(test_images[cls_name]):
                try:
                    img_tensor, img_rgb = preprocess_image(img_path)
                    cam, pred_class = cam_engine.generate_cam(img_tensor)
                    cam_resized = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
                    cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                    cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                    img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                    overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                    save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                    cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                    all_cams[model_name][cls_name].append({
                        'path': img_path,
                        'overlay': overlay,
                        'pred_class': pred_class
                    })
                except Exception as e:
                    pass  # Silently skip errors
        
        cam_engine.remove_hooks()
        print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")

    print(f"\n✓ All CAM overlays saved to: {cam_output_dir}")

    # ========== 7. CREATE COMPARISON GRIDS ==========
    print("\n[6/6] Creating comparison grids...")

    for cls_name in class_names:
        fig, axes = plt.subplots(1, len(models) + 1, figsize=(24, 4))
        first_img_path = test_images[cls_name][0]
        _, img_rgb = preprocess_image(first_img_path)
        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
        axes[0].imshow(img_rgb_uint8, cmap='gray')
        axes[0].set_title(f'{cls_name} - Original', fontsize=11, fontweight='bold')
        axes[0].axis('off')
        for i, (model_name, _) in enumerate(models.items()):
            if len(all_cams[model_name][cls_name]) > 0:
                overlay = all_cams[model_name][cls_name][0]['overlay']
                axes[i+1].imshow(overlay)
                axes[i+1].set_title(model_name.upper(), fontsize=11, fontweight='bold')
            axes[i+1].axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(cam_output_dir, f'comparison_grid_{cls_name}.png'), dpi=300)
        plt.close()
        print(f"✓ Saved comparison grid for {cls_name}")

# ========== 8. SUMMARY ==========
print("\n" + "="*80)
print("SUMMARY: GRAD-CAM ANALYSIS COMPLETE")
print("="*80)

if len(models) > 0:
    print(f"\nModels processed: {list(models.keys())}")
    print(f"Classes analyzed: {class_names}")
    print(f"Images per class: {NUM_IMAGES_PER_CLASS}")
    total_cams = sum(len(all_cams[m][c]) for m in models.keys() for c in class_names)
    print(f"Total CAM visualizations: {total_cams}")
    print(f"\nOutput: {cam_output_dir}")
    print(f"Files: individual cam_<model>_lass>_<idx>.png + comparison_grid_lass>.png")
    print("\n✓ All CAM visualizations saved successfully!")
else:
    print("\n✗ No models loaded. Check weights.")

print("="*80)


NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS

✓ Device: cuda
✓ Test dataset: /kaggle/input/split-dataset/test
✓ Weights directory: /kaggle/input/weight/weights
✓ Output directory: /kaggle/working

[1/5] Loading test images...
  CNV: 200
  DME: 200
  DRUSEN: 200
  NORMAL: 200
✓ Total test images: 800

[2/5] Setting up preprocessing...
✓ Preprocessing ready

[3/5] Loading trained models...
✓ RESNET50 loaded from /kaggle/input/weight/weights/resnet50_best_80_10_10.pth
✓ MOBILENETV2 loaded from /kaggle/input/weight/weights/mobilenetv2_best_80_10_10.pth
✓ EFFICIENTNETB0 loaded from /kaggle/input/weight/weights/efficientnetb0_best_80_10_10.pth
✓ VIT loaded from /kaggle/input/weight/weights/vit_best_80_10_10.pth
✓ SWIN loaded from /kaggle/input/weight/weights/swin_best_80_10_10.pth

✓ Total models loaded: 5

[4/5] Implementing Grad-CAM...
✓ Grad-CAM ready

[5/5] Generating CAMs...


Processing models:  20%|██        | 1/5 [00:17<01:10, 17.65s/it]

✓ RESNET50: 0 CAMs per class


Processing models:  40%|████      | 2/5 [00:32<00:47, 15.96s/it]

✓ MOBILENETV2: 0 CAMs per class


Processing models:  60%|██████    | 3/5 [00:53<00:36, 18.46s/it]

✓ EFFICIENTNETB0: 0 CAMs per class


Processing models:  80%|████████  | 4/5 [01:26<00:23, 23.95s/it]

✓ VIT: 0 CAMs per class


Processing models: 100%|██████████| 5/5 [01:54<00:00, 22.89s/it]

✓ SWIN: 0 CAMs per class

✓ All CAM overlays saved to: /kaggle/working/cam_visualizations

[6/6] Creating comparison grids...





✓ Saved comparison grid for CNV
✓ Saved comparison grid for DME
✓ Saved comparison grid for DRUSEN
✓ Saved comparison grid for NORMAL

SUMMARY: GRAD-CAM ANALYSIS COMPLETE

Models processed: ['resnet50', 'mobilenetv2', 'efficientnetb0', 'vit', 'swin']
Classes analyzed: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
Images per class: 200
Total CAM visualizations: 0

Output: /kaggle/working/cam_visualizations
Files: individual cam_<model>_lass>_<idx>.png + comparison_grid_lass>.png

✓ All CAM visualizations saved successfully!


In [5]:
# Example for resnet50
for name, module in models['resnet50'].named_modules():
    print(name)

# Similarly for other models, look for last conv layers, e.g.
# ResNet50 typically "layer4[-1].conv3" or "layer4[-1]"
# EfficientNet might be "blocks[-1]"
# Others similarly.



conv1
bn1
act1
maxpool
layer1
layer1.0
layer1.0.conv1
layer1.0.bn1
layer1.0.act1
layer1.0.conv2
layer1.0.bn2
layer1.0.drop_block
layer1.0.act2
layer1.0.aa
layer1.0.conv3
layer1.0.bn3
layer1.0.act3
layer1.0.downsample
layer1.0.downsample.0
layer1.0.downsample.1
layer1.1
layer1.1.conv1
layer1.1.bn1
layer1.1.act1
layer1.1.conv2
layer1.1.bn2
layer1.1.drop_block
layer1.1.act2
layer1.1.aa
layer1.1.conv3
layer1.1.bn3
layer1.1.act3
layer1.2
layer1.2.conv1
layer1.2.bn1
layer1.2.act1
layer1.2.conv2
layer1.2.bn2
layer1.2.drop_block
layer1.2.act2
layer1.2.aa
layer1.2.conv3
layer1.2.bn3
layer1.2.act3
layer2
layer2.0
layer2.0.conv1
layer2.0.bn1
layer2.0.act1
layer2.0.conv2
layer2.0.bn2
layer2.0.drop_block
layer2.0.act2
layer2.0.aa
layer2.0.conv3
layer2.0.bn3
layer2.0.act3
layer2.0.downsample
layer2.0.downsample.0
layer2.0.downsample.1
layer2.1
layer2.1.conv1
layer2.1.bn1
layer2.1.act1
layer2.1.conv2
layer2.1.bn2
layer2.1.drop_block
layer2.1.act2
layer2.1.aa
layer2.1.conv3
layer2.1.bn3
layer2.1.act3

In [6]:
import timm
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 1) ResNet50
model_resnet = timm.create_model('resnet50', pretrained=False, num_classes=4).to(device)
print("\nResNet50 named modules:")
for name, module in model_resnet.named_modules():
    print(name)

# 2) MobileNetV2
model_mobilenet = timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4).to(device)
print("\nMobileNetV2 named modules:")
for name, module in model_mobilenet.named_modules():
    print(name)

# 3) EfficientNetB0
model_efficientnet = timm.create_model('efficientnet_b0', pretrained=False, num_classes=4).to(device)
print("\nEfficientNetB0 named modules:")
for name, module in model_efficientnet.named_modules():
    print(name)

# 4) ViT Base Patch16 224
model_vit = timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4).to(device)
print("\nViT Base Patch16_224 named modules:")
for name, module in model_vit.named_modules():
    print(name)

# 5) Swin Tiny Patch4 Window7 224
model_swin = timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4).to(device)
print("\nSwin Tiny Patch4 Window7_224 named modules:")
for name, module in model_swin.named_modules():
    print(name)



ResNet50 named modules:

conv1
bn1
act1
maxpool
layer1
layer1.0
layer1.0.conv1
layer1.0.bn1
layer1.0.act1
layer1.0.conv2
layer1.0.bn2
layer1.0.drop_block
layer1.0.act2
layer1.0.aa
layer1.0.conv3
layer1.0.bn3
layer1.0.act3
layer1.0.downsample
layer1.0.downsample.0
layer1.0.downsample.1
layer1.1
layer1.1.conv1
layer1.1.bn1
layer1.1.act1
layer1.1.conv2
layer1.1.bn2
layer1.1.drop_block
layer1.1.act2
layer1.1.aa
layer1.1.conv3
layer1.1.bn3
layer1.1.act3
layer1.2
layer1.2.conv1
layer1.2.bn1
layer1.2.act1
layer1.2.conv2
layer1.2.bn2
layer1.2.drop_block
layer1.2.act2
layer1.2.aa
layer1.2.conv3
layer1.2.bn3
layer1.2.act3
layer2
layer2.0
layer2.0.conv1
layer2.0.bn1
layer2.0.act1
layer2.0.conv2
layer2.0.bn2
layer2.0.drop_block
layer2.0.act2
layer2.0.aa
layer2.0.conv3
layer2.0.bn3
layer2.0.act3
layer2.0.downsample
layer2.0.downsample.0
layer2.0.downsample.1
layer2.1
layer2.1.conv1
layer2.1.bn1
layer2.1.act1
layer2.1.conv2
layer2.1.bn2
layer2.1.drop_block
layer2.1.act2
layer2.1.aa
layer2.1.conv3
l

In [7]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS CORRECT TARGET LAYERS")
print("="*80)

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
WEIGHTS_DIR = '/kaggle/input/weight/weights'
OUTPUT_DIR = '/kaggle/working'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Test dataset: {TEST_DIR}")
print(f"✓ Weights directory: {WEIGHTS_DIR}")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ========== 1. LOAD TEST IMAGES ==========
print("\n[1/5] Loading test images...")

class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    sampled = random.sample(images, min(NUM_IMAGES_PER_CLASS, len(images)))
    test_images[cls] = [os.path.join(cls_path, img) for img in sampled]

for cls in class_names:
    print(f"  {cls}: {len(test_images[cls])}")
total_images = sum(len(v) for v in test_images.values())
print(f"✓ Total test images: {total_images}")

# ========== 2. PREPROCESS ==========
print("\n[2/5] Setting up preprocessing...")

def preprocess_image(img_path):
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_norm = (image_norm - 0.5) / 0.5
    img_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    img_tensor.requires_grad_(True)  # Enable grad for CAM backward
    return img_tensor, image_rgb

print("✓ Preprocessing ready")

# ========== 3. CUSTOM CLASSIFIER HEAD (for CNN models) ==========
class CustomHead(nn.Module):
    def __init__(self, in_features, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        return self.fc(x)

# ========== 4. LOAD MODELS ==========
print("\n[3/5] Loading trained models...")

model_configs = {
    'resnet50': {
        'model_fn': lambda: timm.create_model('resnet50', pretrained=False, num_classes=4),
        'weight_file': 'resnet50_best_80_10_10.pth',
        'target_layer': 'layer4.2',
        'has_custom_head': True,
        'num_features': 2048
    },
    'mobilenetv2': {
        'model_fn': lambda: timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4),
        'weight_file': 'mobilenetv2_best_80_10_10.pth',
        'target_layer': 'blocks.18',
        'has_custom_head': True,
        'num_features': 1280
    },
    'efficientnetb0': {
        'model_fn': lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=4),
        'weight_file': 'efficientnetb0_best_80_10_10.pth',
        'target_layer': 'blocks.16',
        'has_custom_head': True,
        'num_features': 1280
    },
    'vit': {
        'model_fn': lambda: timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4),
        'weight_file': 'vit_best_80_10_10.pth',
        'target_layer': 'blocks.11',
        'has_custom_head': False
    },
    'swin': {
        'model_fn': lambda: timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4),
        'weight_file': 'swin_best_80_10_10.pth',
        'target_layer': 'layers.2',
        'has_custom_head': False
    }
}

models = {}
for model_name, cfg in model_configs.items():
    try:
        model = cfg['model_fn']()
        
        if cfg['has_custom_head']:
            model.fc = CustomHead(cfg['num_features'], num_classes=4)
        
        weight_path = os.path.join(WEIGHTS_DIR, cfg['weight_file'])
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE), strict=False)
            model.to(DEVICE)
            model.eval()
            models[model_name] = model
            print(f"✓ {model_name.upper()} loaded from {weight_path}")
        else:
            print(f"⚠ {model_name.upper()} weights not found at {weight_path}")
    except Exception as e:
        print(f"✗ Error loading {model_name}: {e}")

print(f"\n✓ Total models loaded: {len(models)}")

# ========== 5. GRAD-CAM ==========
print("\n[4/5] Implementing Grad-CAM...")

class GradCAM:
    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []
        self._register_hooks()
    
    def _register_hooks(self):
        def forward_hook(module, inp, out):
            self.activations = out
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        for name, module in self.model.named_modules():
            if self.target_layer_name in name or name.endswith(self.target_layer_name):
                self.hooks.append(module.register_forward_hook(forward_hook))
                self.hooks.append(module.register_full_backward_hook(backward_hook))
                break
    
    def generate_cam(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward(retain_graph=True)

        grads = self.gradients.detach().cpu().numpy()[0]      # (C,H,W)
        acts = self.activations.detach().cpu().numpy()[0]     # (C,H,W)
        weights = grads.mean(axis=(1, 2))                     # (C,)
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for c, w in enumerate(weights):
            cam += w * acts[c]
        cam = np.maximum(cam, 0)
        cam = cam / (cam.max() + 1e-8)
        return cam, class_idx
    
    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

print("✓ Grad-CAM ready")

# ========== 6. GENERATE CAMs ==========
print("\n[5/5] Generating CAMs...")

cam_output_dir = os.path.join(OUTPUT_DIR, 'cam_visualizations')
os.makedirs(cam_output_dir, exist_ok=True)

all_cams = {}

if len(models) == 0:
    print("✗ No models loaded. Cannot generate CAMs.")
else:
    for model_name, model in tqdm(models.items(), desc="Processing models"):
        target_layer = model_configs[model_name]['target_layer']
        cam_engine = GradCAM(model, target_layer)
        all_cams[model_name] = {}
        
        for cls_name in class_names:
            all_cams[model_name][cls_name] = []
            for img_idx, img_path in enumerate(test_images[cls_name]):
                try:
                    img_tensor, img_rgb = preprocess_image(img_path)
                    cam, pred_class = cam_engine.generate_cam(img_tensor)
                    cam_resized = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
                    cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                    cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                    img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                    overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                    save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                    cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                    all_cams[model_name][cls_name].append({
                        'path': img_path,
                        'overlay': overlay,
                        'pred_class': pred_class
                    })
                except Exception as e:
                    pass  # Silently skip errors
        
        cam_engine.remove_hooks()
        print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")

    print(f"\n✓ All CAM overlays saved to: {cam_output_dir}")

    # ========== 7. CREATE COMPARISON GRIDS ==========
    print("\n[6/6] Creating comparison grids...")

    for cls_name in class_names:
        fig, axes = plt.subplots(1, len(models) + 1, figsize=(24, 4))
        first_img_path = test_images[cls_name][0]
        _, img_rgb = preprocess_image(first_img_path)
        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
        axes[0].imshow(img_rgb_uint8, cmap='gray')
        axes[0].set_title(f'{cls_name} - Original', fontsize=11, fontweight='bold')
        axes[0].axis('off')
        for i, (model_name, _) in enumerate(models.items()):
            if len(all_cams[model_name][cls_name]) > 0:
                overlay = all_cams[model_name][cls_name][0]['overlay']
                axes[i+1].imshow(overlay)
                axes[i+1].set_title(model_name.upper(), fontsize=11, fontweight='bold')
            axes[i+1].axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(cam_output_dir, f'comparison_grid_{cls_name}.png'), dpi=300)
        plt.close()
        print(f"✓ Saved comparison grid for {cls_name}")

# ========== 8. SUMMARY ==========
print("\n" + "="*80)
print("SUMMARY: GRAD-CAM ANALYSIS COMPLETE")
print("="*80)

if len(models) > 0:
    print(f"\nModels processed: {list(models.keys())}")
    print(f"Classes analyzed: {class_names}")
    print(f"Images per class: {NUM_IMAGES_PER_CLASS}")
    total_cams = sum(len(all_cams[m][c]) for m in models.keys() for c in class_names)
    print(f"Total CAM visualizations: {total_cams}")
    print(f"\nOutput: {cam_output_dir}")
    print(f"Files: individual cam_<model>_lass>_<idx>.png + comparison_grid_lass>.png")
    print("\n✓ All CAM visualizations saved successfully!")
else:
    print("\n✗ No models loaded. Check weights.")

print("="*80)


NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS CORRECT TARGET LAYERS

✓ Device: cuda
✓ Test dataset: /kaggle/input/split-dataset/test
✓ Weights directory: /kaggle/input/weight/weights
✓ Output directory: /kaggle/working

[1/5] Loading test images...
  CNV: 200
  DME: 200
  DRUSEN: 200
  NORMAL: 200
✓ Total test images: 800

[2/5] Setting up preprocessing...
✓ Preprocessing ready

[3/5] Loading trained models...
✓ RESNET50 loaded from /kaggle/input/weight/weights/resnet50_best_80_10_10.pth
✓ MOBILENETV2 loaded from /kaggle/input/weight/weights/mobilenetv2_best_80_10_10.pth
✓ EFFICIENTNETB0 loaded from /kaggle/input/weight/weights/efficientnetb0_best_80_10_10.pth
✓ VIT loaded from /kaggle/input/weight/weights/vit_best_80_10_10.pth
✓ SWIN loaded from /kaggle/input/weight/weights/swin_best_80_10_10.pth

✓ Total models loaded: 5

[4/5] Implementing Grad-CAM...
✓ Grad-CAM ready

[5/5] Generating CAMs...


Processing models:  20%|██        | 1/5 [00:26<01:45, 26.42s/it]

✓ RESNET50: 200 CAMs per class


Processing models:  40%|████      | 2/5 [00:41<00:59, 19.95s/it]

✓ MOBILENETV2: 0 CAMs per class


Processing models:  60%|██████    | 3/5 [01:03<00:41, 20.87s/it]

✓ EFFICIENTNETB0: 0 CAMs per class


Processing models:  80%|████████  | 4/5 [01:37<00:25, 25.83s/it]

✓ VIT: 0 CAMs per class


Processing models: 100%|██████████| 5/5 [02:10<00:00, 26.15s/it]

✓ SWIN: 200 CAMs per class

✓ All CAM overlays saved to: /kaggle/working/cam_visualizations

[6/6] Creating comparison grids...





✓ Saved comparison grid for CNV
✓ Saved comparison grid for DME
✓ Saved comparison grid for DRUSEN
✓ Saved comparison grid for NORMAL

SUMMARY: GRAD-CAM ANALYSIS COMPLETE

Models processed: ['resnet50', 'mobilenetv2', 'efficientnetb0', 'vit', 'swin']
Classes analyzed: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
Images per class: 200
Total CAM visualizations: 1600

Output: /kaggle/working/cam_visualizations
Files: individual cam_<model>_lass>_<idx>.png + comparison_grid_lass>.png

✓ All CAM visualizations saved successfully!


In [11]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS INCLUDING ViT ATTENTION CAM")
print("="*80)

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
WEIGHTS_DIR = '/kaggle/input/weight/weights'
OUTPUT_DIR = '/kaggle/working'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Test dataset: {TEST_DIR}")
print(f"✓ Weights directory: {WEIGHTS_DIR}")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ========== 1. LOAD TEST IMAGES ==========
print("\n[1/5] Loading test images...")

class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    sampled = random.sample(images, min(NUM_IMAGES_PER_CLASS, len(images)))
    test_images[cls] = [os.path.join(cls_path, img) for img in sampled]

for cls in class_names:
    print(f"  {cls}: {len(test_images[cls])}")
total_images = sum(len(v) for v in test_images.values())
print(f"✓ Total test images: {total_images}")

# ========== 2. PREPROCESS ==========
print("\n[2/5] Setting up preprocessing...")

def preprocess_image(img_path):
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_norm = (image_norm - 0.5) / 0.5
    img_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    img_tensor.requires_grad_(True)  # Enable gradients for CAM
    return img_tensor, image_rgb

print("✓ Preprocessing ready")

# ========== 3. CUSTOM CNN CLASSIFIER HEAD ==========
class CustomHead(nn.Module):
    def __init__(self, in_features, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    
    def forward(self, x):
        return self.fc(x)

# ========== 4. LOAD MODELS ==========
print("\n[3/5] Loading trained models...")

model_configs = {
    'resnet50': {
        'model_fn': lambda: timm.create_model('resnet50', pretrained=False, num_classes=4),
        'weight_file': 'resnet50_best_80_10_10.pth',
        'target_layer': 'layer4.2',
        'has_custom_head': True,
        'num_features': 2048
    },
    'mobilenetv2': {
        'model_fn': lambda: timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4),
        'weight_file': 'mobilenetv2_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280
    },
    'efficientnetb0': {
        'model_fn': lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=4),
        'weight_file': 'efficientnetb0_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280
    },
    'vit': {
        'model_fn': lambda: timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4),
        'weight_file': 'vit_best_80_10_10.pth',
        'target_layer': 'blocks.11.mlp',
        'has_custom_head': False
    },
    'swin': {
        'model_fn': lambda: timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4),
        'weight_file': 'swin_best_80_10_10.pth',
        'target_layer': 'layers.2',
        'has_custom_head': False
    }
}

models = {}
for model_name, cfg in model_configs.items():
    try:
        model = cfg['model_fn']()
        
        if cfg['has_custom_head']:
            model.fc = CustomHead(cfg['num_features'], num_classes=4)
        
        weight_path = os.path.join(WEIGHTS_DIR, cfg['weight_file'])
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE), strict=False)
            model.to(DEVICE)
            model.eval()
            models[model_name] = model
            print(f"✓ {model_name.upper()} loaded from {weight_path}")
        else:
            print(f"⚠ {model_name.upper()} weights not found at {weight_path}")
    except Exception as e:
        print(f"✗ Error loading {model_name}: {e}")

print(f"\n✓ Total models loaded: {len(models)}")

# ========== 5a. STANDARD Grad-CAM FOR CNNs AND SWIN ==========
class GradCAM:
    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []
        self._register_hooks()
    
    def _register_hooks(self):
        def forward_hook(module, inp, out):
            self.activations = out
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        for name, module in self.model.named_modules():
            if self.target_layer_name in name or name.endswith(self.target_layer_name):
                self.hooks.append(module.register_forward_hook(forward_hook))
                self.hooks.append(module.register_full_backward_hook(backward_hook))
                break
    
    def generate_cam(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward(retain_graph=True)

        grads = self.gradients.detach().cpu().numpy()[0]
        acts = self.activations.detach().cpu().numpy()[0]
        weights = grads.mean(axis=(1, 2))
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for c, w in enumerate(weights):
            cam += w * acts[c]
        cam = np.maximum(cam, 0)
        cam = cam / (cam.max() + 1e-8)
        return cam, class_idx
    
    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

# ========== 5b. ViT Attention Rollout CAM ==========
class ViTAttentionCAM:
    def __init__(self, model):
        self.model = model
        self.attentions = []
        self._register_hooks()

    def _register_hooks(self):
        self.hooks = []
        def hook(module, input, output):
            self.attentions.append(output)
        for name, module in self.model.named_modules():
            if "attn.attn_drop" in name or "attn_drop" in name:
                self.hooks.append(module.register_forward_hook(hook))

    def get_rollout_attention(self):
        rollout = torch.eye(self.attentions[0].size(-1)).to(self.attentions[0].device)
        for attn in self.attentions:
            attn_heads_fused = attn.mean(dim=1)
            attn_heads_fused += torch.eye(attn_heads_fused.size(-1)).to(attn_heads_fused.device)
            attn_heads_fused /= attn_heads_fused.sum(dim=-1, keepdim=True)
            rollout = torch.matmul(rollout, attn_heads_fused)
        return rollout

    def generate_cam(self, input_tensor):
        self.attentions = []
        _ = self.model(input_tensor)
        rollout = self.get_rollout_attention()[0, 0, 1:]
        size = int(np.sqrt(rollout.size(0)))
        cam_map = rollout.reshape(size, size).cpu().numpy()
        cam_map = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min() + 1e-8)
        return cam_map
    
    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

print("✓ CAM classes ready")

# ========== 6. GENERATE CAMs ==========
print("\n[5/5] Generating CAMs...")

cam_output_dir = os.path.join(OUTPUT_DIR, 'cam_visualizations')
os.makedirs(cam_output_dir, exist_ok=True)

all_cams = {}

if len(models) == 0:
    print("✗ No models loaded. Cannot generate CAMs.")
else:
    for model_name, model in tqdm(models.items(), desc="Processing models"):
        all_cams[model_name] = {}
        if model_name == 'vit':
            cam_engine = ViTAttentionCAM(model)
            for cls_name in class_names:
                all_cams[model_name][cls_name] = []
                for img_idx, img_path in enumerate(test_images[cls_name]):
                    try:
                        img_tensor, img_rgb = preprocess_image(img_path)
                        cam_map = cam_engine.generate_cam(img_tensor)
                        cam_resized = cv2.resize(cam_map, (IMG_SIZE, IMG_SIZE))
                        cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                        cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                        overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                        save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                        cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                        all_cams[model_name][cls_name].append({
                            'path': img_path,
                            'overlay': overlay,
                            'pred_class': None
                        })
                    except Exception as e:
                        pass
            cam_engine.remove_hooks()
            print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")
        else:
            target_layer = model_configs[model_name]['target_layer']
            cam_engine = GradCAM(model, target_layer)
            for cls_name in class_names:
                all_cams[model_name][cls_name] = []
                for img_idx, img_path in enumerate(test_images[cls_name]):
                    try:
                        img_tensor, img_rgb = preprocess_image(img_path)
                        cam, pred_class = cam_engine.generate_cam(img_tensor)
                        cam_resized = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
                        cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                        cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                        overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                        save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                        cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                        all_cams[model_name][cls_name].append({
                            'path': img_path,
                            'overlay': overlay,
                            'pred_class': pred_class
                        })
                    except Exception as e:
                        pass
            cam_engine.remove_hooks()
            print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")

    print(f"\n✓ All CAM overlays saved to: {cam_output_dir}")

    # ========== 7. CREATE COMPARISON GRIDS ==========
    print("\n[6/6] Creating comparison grids...")

    for cls_name in class_names:
        fig, axes = plt.subplots(1, len(models) + 1, figsize=(24, 4))
        first_img_path = test_images[cls_name][0]
        _, img_rgb = preprocess_image(first_img_path)
        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
        axes[0].imshow(img_rgb_uint8, cmap='gray')
        axes[0].set_title(f'{cls_name} - Original', fontsize=11, fontweight='bold')
        axes[0].axis('off')
        for i, (model_name, _) in enumerate(models.items()):
            if len(all_cams[model_name][cls_name]) > 0:
                overlay = all_cams[model_name][cls_name][0]['overlay']
                axes[i+1].imshow(overlay)
                axes[i+1].set_title(model_name.upper(), fontsize=11, fontweight='bold')
            axes[i+1].axis('off')
        plt.tight_layout()
        plt.savefig(os.path.join(cam_output_dir, f'comparison_grid_{cls_name}.png'), dpi=300)
        plt.close()
        print(f"✓ Saved comparison grid for {cls_name}")

# ========== 8. SUMMARY ==========
print("\n" + "="*80)
print("SUMMARY: GRAD-CAM ANALYSIS COMPLETE")
print("="*80)

if len(models) > 0:
    print(f"\nModels processed: {list(models.keys())}")
    print(f"Classes analyzed: {class_names}")
    print(f"Images per class: {NUM_IMAGES_PER_CLASS}")
    total_cams = sum(len(all_cams[m][c]) for m in models.keys() for c in class_names)
    print(f"Total CAM visualizations: {total_cams}")
    print(f"\nOutput: {cam_output_dir}")
    print(f"Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png")
    print("\n✓ All CAM visualizations saved successfully!")
else:
    print("\n✗ No models loaded. Check weights.")

print("="*80)


NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS INCLUDING ViT ATTENTION CAM

✓ Device: cuda
✓ Test dataset: /kaggle/input/split-dataset/test
✓ Weights directory: /kaggle/input/weight/weights
✓ Output directory: /kaggle/working

[1/5] Loading test images...
  CNV: 200
  DME: 200
  DRUSEN: 200
  NORMAL: 200
✓ Total test images: 800

[2/5] Setting up preprocessing...
✓ Preprocessing ready

[3/5] Loading trained models...
✓ RESNET50 loaded from /kaggle/input/weight/weights/resnet50_best_80_10_10.pth
✓ MOBILENETV2 loaded from /kaggle/input/weight/weights/mobilenetv2_best_80_10_10.pth
✓ EFFICIENTNETB0 loaded from /kaggle/input/weight/weights/efficientnetb0_best_80_10_10.pth
✓ VIT loaded from /kaggle/input/weight/weights/vit_best_80_10_10.pth
✓ SWIN loaded from /kaggle/input/weight/weights/swin_best_80_10_10.pth

✓ Total models loaded: 5
✓ CAM classes ready

[5/5] Generating CAMs...


Processing models:  20%|██        | 1/5 [00:25<01:41, 25.39s/it]

✓ RESNET50: 200 CAMs per class


Processing models:  40%|████      | 2/5 [00:47<01:11, 23.69s/it]

✓ MOBILENETV2: 200 CAMs per class


Processing models:  60%|██████    | 3/5 [01:16<00:52, 26.15s/it]

✓ EFFICIENTNETB0: 200 CAMs per class


Processing models:  80%|████████  | 4/5 [01:28<00:20, 20.30s/it]

✓ VIT: 0 CAMs per class


Processing models: 100%|██████████| 5/5 [02:02<00:00, 24.41s/it]

✓ SWIN: 200 CAMs per class

✓ All CAM overlays saved to: /kaggle/working/cam_visualizations

[6/6] Creating comparison grids...





✓ Saved comparison grid for CNV
✓ Saved comparison grid for DME
✓ Saved comparison grid for DRUSEN
✓ Saved comparison grid for NORMAL

SUMMARY: GRAD-CAM ANALYSIS COMPLETE

Models processed: ['resnet50', 'mobilenetv2', 'efficientnetb0', 'vit', 'swin']
Classes analyzed: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
Images per class: 200
Total CAM visualizations: 3200

Output: /kaggle/working/cam_visualizations
Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png

✓ All CAM visualizations saved successfully!


In [12]:
model_vit = models['vit']

print("\nNamed modules in ViT containing 'attn':")
for name, module in model_vit.named_modules():
    if 'attn' in name:
        print(name)



Named modules in ViT containing 'attn':
blocks.0.attn
blocks.0.attn.qkv
blocks.0.attn.q_norm
blocks.0.attn.k_norm
blocks.0.attn.attn_drop
blocks.0.attn.norm
blocks.0.attn.proj
blocks.0.attn.proj_drop
blocks.1.attn
blocks.1.attn.qkv
blocks.1.attn.q_norm
blocks.1.attn.k_norm
blocks.1.attn.attn_drop
blocks.1.attn.norm
blocks.1.attn.proj
blocks.1.attn.proj_drop
blocks.2.attn
blocks.2.attn.qkv
blocks.2.attn.q_norm
blocks.2.attn.k_norm
blocks.2.attn.attn_drop
blocks.2.attn.norm
blocks.2.attn.proj
blocks.2.attn.proj_drop
blocks.3.attn
blocks.3.attn.qkv
blocks.3.attn.q_norm
blocks.3.attn.k_norm
blocks.3.attn.attn_drop
blocks.3.attn.norm
blocks.3.attn.proj
blocks.3.attn.proj_drop
blocks.4.attn
blocks.4.attn.qkv
blocks.4.attn.q_norm
blocks.4.attn.k_norm
blocks.4.attn.attn_drop
blocks.4.attn.norm
blocks.4.attn.proj
blocks.4.attn.proj_drop
blocks.5.attn
blocks.5.attn.qkv
blocks.5.attn.q_norm
blocks.5.attn.k_norm
blocks.5.attn.attn_drop
blocks.5.attn.norm
blocks.5.attn.proj
blocks.5.attn.proj_drop

In [13]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS INCLUDING ViT ATTENTION CAM")
print("="*80)

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
WEIGHTS_DIR = '/kaggle/input/weight/weights'
OUTPUT_DIR = '/kaggle/working'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Test dataset: {TEST_DIR}")
print(f"✓ Weights directory: {WEIGHTS_DIR}")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ========== 1. LOAD TEST IMAGES ==========
print("\n[1/5] Loading test images...")
class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    sampled = random.sample(images, min(NUM_IMAGES_PER_CLASS, len(images)))
    test_images[cls] = [os.path.join(cls_path, img) for img in sampled]

for cls in class_names:
    print(f"  {cls}: {len(test_images[cls])}")
total_images = sum(len(v) for v in test_images.values())
print(f"✓ Total test images: {total_images}")

# ========== 2. PREPROCESS ==========
print("\n[2/5] Setting up preprocessing...")
def preprocess_image(img_path):
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_norm = (image_norm - 0.5) / 0.5
    img_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    img_tensor.requires_grad_(True)  # Enable gradients for CAM
    return img_tensor, image_rgb

print("✓ Preprocessing ready")

# ========== 3. CUSTOM CNN CLASSIFIER HEAD ==========
class CustomHead(nn.Module):
    def __init__(self, in_features, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        return self.fc(x)

# ========== 4. LOAD MODELS ==========
print("\n[3/5] Loading trained models...")
model_configs = {
    'resnet50': {
        'model_fn': lambda: timm.create_model('resnet50', pretrained=False, num_classes=4),
        'weight_file': 'resnet50_best_80_10_10.pth',
        'target_layer': 'layer4.2',
        'has_custom_head': True,
        'num_features': 2048
    },
    'mobilenetv2': {
        'model_fn': lambda: timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4),
        'weight_file': 'mobilenetv2_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280
    },
    'efficientnetb0': {
        'model_fn': lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=4),
        'weight_file': 'efficientnetb0_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280
    },
    'vit': {
        'model_fn': lambda: timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4),
        'weight_file': 'vit_best_80_10_10.pth',
        'target_layer': 'blocks.attn',  # generalized target; hooks will be installed on all 'attn' modules
        'has_custom_head': False
    },
    'swin': {
        'model_fn': lambda: timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4),
        'weight_file': 'swin_best_80_10_10.pth',
        'target_layer': 'layers.2',
        'has_custom_head': False
    }
}

models = {}
for model_name, cfg in model_configs.items():
    try:
        model = cfg['model_fn']()

        if cfg['has_custom_head']:
            model.fc = CustomHead(cfg['num_features'], num_classes=4)

        weight_path = os.path.join(WEIGHTS_DIR, cfg['weight_file'])
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE), strict=False)
            model.to(DEVICE)
            model.eval()
            models[model_name] = model
            print(f"✓ {model_name.upper()} loaded from {weight_path}")
        else:
            print(f"⚠ {model_name.upper()} weights not found at {weight_path}")
    except Exception as e:
        print(f"✗ Error loading {model_name}: {e}")

print(f"\n✓ Total models loaded: {len(models)}")

# ========== 5a. STANDARD Grad-CAM FOR CNNs AND SWIN ==========
class GradCAM:
    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, inp, out):
            self.activations = out
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        for name, module in self.model.named_modules():
            if self.target_layer_name in name or name.endswith(self.target_layer_name):
                self.hooks.append(module.register_forward_hook(forward_hook))
                self.hooks.append(module.register_full_backward_hook(backward_hook))
                break

    def generate_cam(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward(retain_graph=True)

        grads = self.gradients.detach().cpu().numpy()[0]
        acts = self.activations.detach().cpu().numpy()[0]
        weights = grads.mean(axis=(1, 2))
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for c, w in enumerate(weights):
            cam += w * acts[c]
        cam = np.maximum(cam, 0)
        cam = cam / (cam.max() + 1e-8)
        return cam, class_idx

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

# ========== 5b. ViT Attention Rollout CAM ==========
class ViTAttentionCAM:
    def __init__(self, model):
        self.model = model
        self.attentions = []
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        def hook(module, input, output):
            self.attentions.append(output.detach())
        for name, module in self.model.named_modules():
            if name.endswith('attn'):
                self.hooks.append(module.register_forward_hook(hook))

    def get_rollout_attention(self):
        if not self.attentions:
            raise RuntimeError("No attentions captured. Check hooks.")
        rollout = torch.eye(self.attentions[0].size(-1)).to(self.attentions[0].device)
        for attn in self.attentions:
            attn_heads_fused = attn.mean(dim=1)
            attn_heads_fused += torch.eye(attn_heads_fused.size(-1)).to(attn_heads_fused.device)
            attn_heads_fused /= attn_heads_fused.sum(dim=-1, keepdim=True)
            rollout = torch.matmul(rollout, attn_heads_fused)
        return rollout

    def generate_cam(self, input_tensor):
        self.attentions.clear()
        _ = self.model(input_tensor)
        rollout = self.get_rollout_attention()
        cam = rollout[0, 0, 1:].reshape(int(np.sqrt(rollout.size(-1) - 1)),
                                        int(np.sqrt(rollout.size(-1) - 1))).cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()
        self.hooks = []

print("✓ CAM classes ready")

# ========== 6. GENERATE CAMs ==========
print("\n[5/5] Generating CAMs...")

cam_output_dir = os.path.join(OUTPUT_DIR, 'cam_visualizations')
os.makedirs(cam_output_dir, exist_ok=True)

all_cams = {}

if len(models) == 0:
    print("✗ No models loaded. Cannot generate CAMs.")
else:
    for model_name, model in tqdm(models.items(), desc="Processing models"):
        all_cams[model_name] = {}
        if model_name == 'vit':
            cam_engine = ViTAttentionCAM(model)
            for cls_name in class_names:
                all_cams[model_name][cls_name] = []
                for img_idx, img_path in enumerate(test_images[cls_name]):
                    try:
                        img_tensor, img_rgb = preprocess_image(img_path)
                        cam_map = cam_engine.generate_cam(img_tensor)
                        cam_resized = cv2.resize(cam_map, (IMG_SIZE, IMG_SIZE))
                        cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                        cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                        overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                        save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                        cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                        all_cams[model_name][cls_name].append({
                            'path': img_path,
                            'overlay': overlay,
                            'pred_class': None
                        })
                    except Exception:
                        pass
            cam_engine.remove_hooks()
            print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")
        else:
            target_layer = model_configs[model_name]['target_layer']
            cam_engine = GradCAM(model, target_layer)
            for cls_name in class_names:
                all_cams[model_name][cls_name] = []
                for img_idx, img_path in enumerate(test_images[cls_name]):
                    try:
                        img_tensor, img_rgb = preprocess_image(img_path)
                        cam, pred_class = cam_engine.generate_cam(img_tensor)
                        cam_resized = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
                        cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                        cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                        overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                        save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                        cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                        all_cams[model_name][cls_name].append({
                            'path': img_path,
                            'overlay': overlay,
                            'pred_class': pred_class
                        })
                    except Exception:
                        pass
            cam_engine.remove_hooks()
            print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")

    print(f"\n✓ All CAM overlays saved to: {cam_output_dir}")

# ========== 7. CREATE COMPARISON GRIDS ==========
print("\n[6/6] Creating comparison grids...")
for cls_name in class_names:
    fig, axes = plt.subplots(1, len(models) + 1, figsize=(24, 4))
    first_img_path = test_images[cls_name][0]
    _, img_rgb = preprocess_image(first_img_path)
    img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
    axes[0].imshow(img_rgb_uint8, cmap='gray')
    axes[0].set_title(f'{cls_name} - Original', fontsize=11, fontweight='bold')
    axes[0].axis('off')
    for i, (model_name, _) in enumerate(models.items()):
        if len(all_cams[model_name][cls_name]) > 0:
            overlay = all_cams[model_name][cls_name][0]['overlay']
            axes[i+1].imshow(overlay)
            axes[i+1].set_title(model_name.upper(), fontsize=11, fontweight='bold')
        axes[i+1].axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(cam_output_dir, f'comparison_grid_{cls_name}.png'), dpi=300)
    plt.close()
    print(f"✓ Saved comparison grid for {cls_name}")

# ========== 8. SUMMARY ==========
print("\n" + "="*80)
print("SUMMARY: GRAD-CAM ANALYSIS COMPLETE")
print("="*80)

if len(models) > 0:
    print(f"\nModels processed: {list(models.keys())}")
    print(f"Classes analyzed: {class_names}")
    print(f"Images per class: {NUM_IMAGES_PER_CLASS}")
    total_cams = sum(len(all_cams[m][c]) for m in models.keys() for c in class_names)
    print(f"Total CAM visualizations: {total_cams}")
    print(f"\nOutput: {cam_output_dir}")
    print(f"Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png")
    print("\n✓ All CAM visualizations saved successfully!")
else:
    print("\n✗ No models loaded. Check weights.")
print("="*80)


NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS INCLUDING ViT ATTENTION CAM

✓ Device: cuda
✓ Test dataset: /kaggle/input/split-dataset/test
✓ Weights directory: /kaggle/input/weight/weights
✓ Output directory: /kaggle/working

[1/5] Loading test images...
  CNV: 200
  DME: 200
  DRUSEN: 200
  NORMAL: 200
✓ Total test images: 800

[2/5] Setting up preprocessing...
✓ Preprocessing ready

[3/5] Loading trained models...
✓ RESNET50 loaded from /kaggle/input/weight/weights/resnet50_best_80_10_10.pth
✓ MOBILENETV2 loaded from /kaggle/input/weight/weights/mobilenetv2_best_80_10_10.pth
✓ EFFICIENTNETB0 loaded from /kaggle/input/weight/weights/efficientnetb0_best_80_10_10.pth
✓ VIT loaded from /kaggle/input/weight/weights/vit_best_80_10_10.pth
✓ SWIN loaded from /kaggle/input/weight/weights/swin_best_80_10_10.pth

✓ Total models loaded: 5
✓ CAM classes ready

[5/5] Generating CAMs...


Processing models:  20%|██        | 1/5 [00:24<01:39, 24.78s/it]

✓ RESNET50: 200 CAMs per class


Processing models:  40%|████      | 2/5 [00:47<01:10, 23.35s/it]

✓ MOBILENETV2: 200 CAMs per class


Processing models:  60%|██████    | 3/5 [01:15<00:51, 25.72s/it]

✓ EFFICIENTNETB0: 200 CAMs per class


Processing models:  80%|████████  | 4/5 [01:29<00:20, 20.96s/it]

✓ VIT: 0 CAMs per class


Processing models: 100%|██████████| 5/5 [02:02<00:00, 24.57s/it]

✓ SWIN: 200 CAMs per class

✓ All CAM overlays saved to: /kaggle/working/cam_visualizations

[6/6] Creating comparison grids...





✓ Saved comparison grid for CNV
✓ Saved comparison grid for DME
✓ Saved comparison grid for DRUSEN
✓ Saved comparison grid for NORMAL

SUMMARY: GRAD-CAM ANALYSIS COMPLETE

Models processed: ['resnet50', 'mobilenetv2', 'efficientnetb0', 'vit', 'swin']
Classes analyzed: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
Images per class: 200
Total CAM visualizations: 3200

Output: /kaggle/working/cam_visualizations
Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png

✓ All CAM visualizations saved successfully!


In [14]:
model_vit = models['vit']
print("ViT named modules with 'attn' or 'drop' in name:")
for name, module in model_vit.named_modules():
    if 'attn' in name or 'drop' in name:
        print(name)


ViT named modules with 'attn' or 'drop' in name:
pos_drop
patch_drop
blocks.0.attn
blocks.0.attn.qkv
blocks.0.attn.q_norm
blocks.0.attn.k_norm
blocks.0.attn.attn_drop
blocks.0.attn.norm
blocks.0.attn.proj
blocks.0.attn.proj_drop
blocks.0.drop_path1
blocks.0.mlp.drop1
blocks.0.mlp.drop2
blocks.0.drop_path2
blocks.1.attn
blocks.1.attn.qkv
blocks.1.attn.q_norm
blocks.1.attn.k_norm
blocks.1.attn.attn_drop
blocks.1.attn.norm
blocks.1.attn.proj
blocks.1.attn.proj_drop
blocks.1.drop_path1
blocks.1.mlp.drop1
blocks.1.mlp.drop2
blocks.1.drop_path2
blocks.2.attn
blocks.2.attn.qkv
blocks.2.attn.q_norm
blocks.2.attn.k_norm
blocks.2.attn.attn_drop
blocks.2.attn.norm
blocks.2.attn.proj
blocks.2.attn.proj_drop
blocks.2.drop_path1
blocks.2.mlp.drop1
blocks.2.mlp.drop2
blocks.2.drop_path2
blocks.3.attn
blocks.3.attn.qkv
blocks.3.attn.q_norm
blocks.3.attn.k_norm
blocks.3.attn.attn_drop
blocks.3.attn.norm
blocks.3.attn.proj
blocks.3.attn.proj_drop
blocks.3.drop_path1
blocks.3.mlp.drop1
blocks.3.mlp.drop2

In [15]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS INCLUDING REFINED ViT ATTENTION CAM")
print("="*80)

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
WEIGHTS_DIR = '/kaggle/input/weight/weights'
OUTPUT_DIR = '/kaggle/working'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Test dataset: {TEST_DIR}")
print(f"✓ Weights directory: {WEIGHTS_DIR}")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ========== 1. LOAD TEST IMAGES ==========
print("\n[1/5] Loading test images...")
class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    sampled = random.sample(images, min(NUM_IMAGES_PER_CLASS, len(images)))
    test_images[cls] = [os.path.join(cls_path, img) for img in sampled]

for cls in class_names:
    print(f"  {cls}: {len(test_images[cls])}")
total_images = sum(len(v) for v in test_images.values())
print(f"✓ Total test images: {total_images}")

# ========== 2. PREPROCESS ==========
print("\n[2/5] Setting up preprocessing...")
def preprocess_image(img_path):
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_norm = (image_norm - 0.5) / 0.5
    img_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    img_tensor.requires_grad_(True)
    return img_tensor, image_rgb

print("✓ Preprocessing ready")

# ========== 3. CUSTOM CNN CLASSIFIER HEAD ==========
class CustomHead(nn.Module):
    def __init__(self, in_features, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    def forward(self, x):
        return self.fc(x)

# ========== 4. LOAD MODELS ==========
print("\n[3/5] Loading trained models...")
model_configs = {
    'resnet50': {
        'model_fn': lambda: timm.create_model('resnet50', pretrained=False, num_classes=4),
        'weight_file': 'resnet50_best_80_10_10.pth',
        'target_layer': 'layer4.2',
        'has_custom_head': True,
        'num_features': 2048
    },
    'mobilenetv2': {
        'model_fn': lambda: timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4),
        'weight_file': 'mobilenetv2_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280
    },
    'efficientnetb0': {
        'model_fn': lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=4),
        'weight_file': 'efficientnetb0_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280
    },
    'vit': {
        'model_fn': lambda: timm.create_model('vit_base_patch16_224', pretrained=False, num_classes=4),
        'weight_file': 'vit_best_80_10_10.pth',
        'target_layer': 'blocks.attn',  # generalized for ViT hooks
        'has_custom_head': False
    },
    'swin': {
        'model_fn': lambda: timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4),
        'weight_file': 'swin_best_80_10_10.pth',
        'target_layer': 'layers.2',
        'has_custom_head': False
    }
}

models = {}
for model_name, cfg in model_configs.items():
    try:
        model = cfg['model_fn']()
        if cfg['has_custom_head']:
            model.fc = CustomHead(cfg['num_features'], num_classes=4)
        weight_path = os.path.join(WEIGHTS_DIR, cfg['weight_file'])
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE), strict=False)
            model.to(DEVICE)
            model.eval()
            models[model_name] = model
            print(f"✓ {model_name.upper()} loaded from {weight_path}")
        else:
            print(f"⚠ {model_name.upper()} weights not found at {weight_path}")
    except Exception as e:
        print(f"✗ Error loading {model_name}: {e}")

print(f"\n✓ Total models loaded: {len(models)}")

# ========== 5a. GradCAM for CNNs and Swin ==========
class GradCAM:
    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, inp, out):
            self.activations = out
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        for name, module in self.model.named_modules():
            if self.target_layer_name in name or name.endswith(self.target_layer_name):
                self.hooks.append(module.register_forward_hook(forward_hook))
                self.hooks.append(module.register_full_backward_hook(backward_hook))
                break

    def generate_cam(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward(retain_graph=True)
        grads = self.gradients.detach().cpu().numpy()[0]
        acts = self.activations.detach().cpu().numpy()[0]
        weights = grads.mean(axis=(1, 2))
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for c, w in enumerate(weights):
            cam += w * acts[c]
        cam = np.maximum(cam, 0)
        cam = cam / (cam.max() + 1e-8)
        return cam, class_idx

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

# ========== 5b. ViT Attention Rollout CAM ==========
class ViTAttentionCAM:
    def __init__(self, model):
        self.model = model
        self.attentions = []
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        def hook(module, input, output):
            self.attentions.append(output.detach())
        for name, module in self.model.named_modules():
            if name.count('attn') == 1 and 'blocks' in name and name.endswith('attn'):
                self.hooks.append(module.register_forward_hook(hook))

    def get_rollout_attention(self):
        if not self.attentions:
            raise RuntimeError("No attentions captured. Check hooks.")
        rollout = torch.eye(self.attentions[0].size(-1)).to(self.attentions[0].device)
        for attn in self.attentions:
            attn_heads_fused = attn.mean(dim=1)
            attn_heads_fused += torch.eye(attn_heads_fused.size(-1)).to(attn_heads_fused.device)
            attn_heads_fused /= attn_heads_fused.sum(dim=-1, keepdim=True)
            rollout = torch.matmul(rollout, attn_heads_fused)
        return rollout

    def generate_cam(self, input_tensor):
        self.attentions.clear()
        _ = self.model(input_tensor)
        rollout = self.get_rollout_attention()
        cam = rollout[0, 0, 1:]
        size = int(np.sqrt(cam.size(0)))
        cam_map = cam.reshape(size, size).cpu().numpy()
        cam_map = (cam_map - cam_map.min()) / (cam_map.max() - cam_map.min() + 1e-8)
        return cam_map

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()
        self.hooks = []

print("✓ CAM classes ready")

# ========== 6. GENERATE CAMs ==========
print("\n[5/5] Generating CAMs...")

cam_output_dir = os.path.join(OUTPUT_DIR, 'cam_visualizations')
os.makedirs(cam_output_dir, exist_ok=True)

all_cams = {}

if len(models) == 0:
    print("✗ No models loaded. Cannot generate CAMs.")
else:
    for model_name, model in tqdm(models.items(), desc="Processing models"):
        all_cams[model_name] = {}
        if model_name == 'vit':
            cam_engine = ViTAttentionCAM(model)
            for cls_name in class_names:
                all_cams[model_name][cls_name] = []
                for img_idx, img_path in enumerate(test_images[cls_name]):
                    try:
                        img_tensor, img_rgb = preprocess_image(img_path)
                        cam_map = cam_engine.generate_cam(img_tensor)
                        cam_resized = cv2.resize(cam_map, (IMG_SIZE, IMG_SIZE))
                        cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                        cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                        overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                        save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                        cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                        all_cams[model_name][cls_name].append({
                            'path': img_path,
                            'overlay': overlay,
                            'pred_class': None
                        })
                    except Exception:
                        pass
            cam_engine.remove_hooks()
            print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")
        else:
            target_layer = model_configs[model_name]['target_layer']
            cam_engine = GradCAM(model, target_layer)
            for cls_name in class_names:
                all_cams[model_name][cls_name] = []
                for img_idx, img_path in enumerate(test_images[cls_name]):
                    try:
                        img_tensor, img_rgb = preprocess_image(img_path)
                        cam, pred_class = cam_engine.generate_cam(img_tensor)
                        cam_resized = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
                        cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                        cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                        img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                        overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                        save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                        cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                    cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                        all_cams[model_name][cls_name].append({
                            'path': img_path,
                            'overlay': overlay,
                            'pred_class': pred_class
                        })
                    except Exception:
                        pass
            cam_engine.remove_hooks()
            print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")

    print(f"\n✓ All CAM overlays saved to: {cam_output_dir}")

# ========== 7. CREATE COMPARISON GRIDS ==========
print("\n[6/6] Creating comparison grids...")
for cls_name in class_names:
    fig, axes = plt.subplots(1, len(models) + 1, figsize=(24, 4))
    first_img_path = test_images[cls_name][0]
    _, img_rgb = preprocess_image(first_img_path)
    img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
    axes[0].imshow(img_rgb_uint8, cmap='gray')
    axes[0].set_title(f'{cls_name} - Original', fontsize=11, fontweight='bold')
    axes[0].axis('off')
    for i, (model_name, _) in enumerate(models.items()):
        if len(all_cams[model_name][cls_name]) > 0:
            overlay = all_cams[model_name][cls_name][0]['overlay']
            axes[i+1].imshow(overlay)
            axes[i+1].set_title(model_name.upper(), fontsize=11, fontweight='bold')
        axes[i+1].axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(cam_output_dir, f'comparison_grid_{cls_name}.png'), dpi=300)
    plt.close()
    print(f"✓ Saved comparison grid for {cls_name}")

# ========== 8. SUMMARY ==========
print("\n" + "="*80)
print("SUMMARY: GRAD-CAM ANALYSIS COMPLETE")
print("="*80)

if len(models) > 0:
    print(f"\nModels processed: {list(models.keys())}")
    print(f"Classes analyzed: {class_names}")
    print(f"Images per class: {NUM_IMAGES_PER_CLASS}")
    total_cams = sum(len(all_cams[m][c]) for m in models.keys() for c in class_names)
    print(f"Total CAM visualizations: {total_cams}")
    print(f"\nOutput: {cam_output_dir}")
    print(f"Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png")
    print("\n✓ All CAM visualizations saved successfully!")
else:
    print("\n✗ No models loaded. Check weights.")
print("="*80)


NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - ALL 5 MODELS INCLUDING REFINED ViT ATTENTION CAM

✓ Device: cuda
✓ Test dataset: /kaggle/input/split-dataset/test
✓ Weights directory: /kaggle/input/weight/weights
✓ Output directory: /kaggle/working

[1/5] Loading test images...
  CNV: 200
  DME: 200
  DRUSEN: 200
  NORMAL: 200
✓ Total test images: 800

[2/5] Setting up preprocessing...
✓ Preprocessing ready

[3/5] Loading trained models...
✓ RESNET50 loaded from /kaggle/input/weight/weights/resnet50_best_80_10_10.pth
✓ MOBILENETV2 loaded from /kaggle/input/weight/weights/mobilenetv2_best_80_10_10.pth
✓ EFFICIENTNETB0 loaded from /kaggle/input/weight/weights/efficientnetb0_best_80_10_10.pth
✓ VIT loaded from /kaggle/input/weight/weights/vit_best_80_10_10.pth
✓ SWIN loaded from /kaggle/input/weight/weights/swin_best_80_10_10.pth

✓ Total models loaded: 5
✓ CAM classes ready

[5/5] Generating CAMs...


Processing models:  20%|██        | 1/5 [00:25<01:40, 25.08s/it]

✓ RESNET50: 200 CAMs per class


Processing models:  40%|████      | 2/5 [00:47<01:10, 23.49s/it]

✓ MOBILENETV2: 200 CAMs per class


Processing models:  60%|██████    | 3/5 [01:16<00:51, 25.95s/it]

✓ EFFICIENTNETB0: 200 CAMs per class


Processing models:  80%|████████  | 4/5 [01:29<00:21, 21.01s/it]

✓ VIT: 0 CAMs per class


Processing models: 100%|██████████| 5/5 [02:03<00:00, 24.67s/it]

✓ SWIN: 200 CAMs per class

✓ All CAM overlays saved to: /kaggle/working/cam_visualizations

[6/6] Creating comparison grids...





✓ Saved comparison grid for CNV
✓ Saved comparison grid for DME
✓ Saved comparison grid for DRUSEN
✓ Saved comparison grid for NORMAL

SUMMARY: GRAD-CAM ANALYSIS COMPLETE

Models processed: ['resnet50', 'mobilenetv2', 'efficientnetb0', 'vit', 'swin']
Classes analyzed: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
Images per class: 200
Total CAM visualizations: 3200

Output: /kaggle/working/cam_visualizations
Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png

✓ All CAM visualizations saved successfully!


In [16]:
import os
import numpy as np
import cv2
import torch
import torch.nn as nn
import timm
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import warnings
warnings.filterwarnings('ignore')

print("="*80)
print("NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - 4 MODELS (ViT SKIPPED)")
print("="*80)

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
WEIGHTS_DIR = '/kaggle/input/weight/weights'
OUTPUT_DIR = '/kaggle/working'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 200
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Test dataset: {TEST_DIR}")
print(f"✓ Weights directory: {WEIGHTS_DIR}")
print(f"✓ Output directory: {OUTPUT_DIR}")

# ========== 1. LOAD TEST IMAGES ==========
print("\n[1/5] Loading test images...")
class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    sampled = random.sample(images, min(NUM_IMAGES_PER_CLASS, len(images)))
    test_images[cls] = [os.path.join(cls_path, img) for img in sampled]

for cls in class_names:
    print(f"  {cls}: {len(test_images[cls])}")
total_images = sum(len(v) for v in test_images.values())
print(f"✓ Total test images: {total_images}")

# ========== 2. PREPROCESS ==========
print("\n[2/5] Setting up preprocessing...")
def preprocess_image(img_path):
    image = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
    if image is None:
        image = np.zeros((IMG_SIZE, IMG_SIZE), dtype=np.uint8)
    image = cv2.resize(image, (IMG_SIZE, IMG_SIZE))
    image_rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    image_norm = image_rgb.astype(np.float32) / 255.0
    image_norm = (image_norm - 0.5) / 0.5
    img_tensor = torch.from_numpy(image_norm).permute(2, 0, 1).unsqueeze(0).to(DEVICE)
    img_tensor.requires_grad_(True)
    return img_tensor, image_rgb

print("✓ Preprocessing ready")

# ========== 3. CUSTOM CNN CLASSIFIER HEAD ==========
class CustomHead(nn.Module):
    def __init__(self, in_features, num_classes=4):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Linear(256, num_classes)
        )
    def forward(self, x):
        return self.fc(x)

# ========== 4. LOAD MODELS (ViT skipped) ==========
print("\n[3/5] Loading trained models (ViT skipped)...")
model_configs = {
    'resnet50': {
        'model_fn': lambda: timm.create_model('resnet50', pretrained=False, num_classes=4),
        'weight_file': 'resnet50_best_80_10_10.pth',
        'target_layer': 'layer4.2',
        'has_custom_head': True,
        'num_features': 2048,
    },
    'mobilenetv2': {
        'model_fn': lambda: timm.create_model('mobilenetv2_100', pretrained=False, num_classes=4),
        'weight_file': 'mobilenetv2_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280,
    },
    'efficientnetb0': {
        'model_fn': lambda: timm.create_model('efficientnet_b0', pretrained=False, num_classes=4),
        'weight_file': 'efficientnetb0_best_80_10_10.pth',
        'target_layer': 'conv_head',
        'has_custom_head': True,
        'num_features': 1280,
    },
    'swin': {
        'model_fn': lambda: timm.create_model('swin_tiny_patch4_window7_224', pretrained=False, num_classes=4),
        'weight_file': 'swin_best_80_10_10.pth',
        'target_layer': 'layers.2',
        'has_custom_head': False,
    }
}

models = {}
for model_name, cfg in model_configs.items():
    try:
        model = cfg['model_fn']()
        if cfg['has_custom_head']:
            model.fc = CustomHead(cfg['num_features'], num_classes=4)
        weight_path = os.path.join(WEIGHTS_DIR, cfg['weight_file'])
        if os.path.exists(weight_path):
            model.load_state_dict(torch.load(weight_path, map_location=DEVICE), strict=False)
            model.to(DEVICE)
            model.eval()
            models[model_name] = model
            print(f"✓ {model_name.upper()} loaded from {weight_path}")
        else:
            print(f"⚠ {model_name.upper()} weights not found at {weight_path}")
    except Exception as e:
        print(f"✗ Error loading {model_name}: {e}")

print(f"\n✓ Total models loaded: {len(models)}")

# ========== 5. GradCAM for all 4 models ==========
class GradCAM:
    def __init__(self, model, target_layer_name):
        self.model = model
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []
        self._register_hooks()

    def _register_hooks(self):
        def forward_hook(module, inp, out):
            self.activations = out
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]
        for name, module in self.model.named_modules():
            if self.target_layer_name in name or name.endswith(self.target_layer_name):
                self.hooks.append(module.register_forward_hook(forward_hook))
                self.hooks.append(module.register_full_backward_hook(backward_hook))
                break

    def generate_cam(self, input_tensor, class_idx=None):
        output = self.model(input_tensor)
        if class_idx is None:
            class_idx = output.argmax(dim=1).item()
        self.model.zero_grad()
        score = output[0, class_idx]
        score.backward(retain_graph=True)
        grads = self.gradients.detach().cpu().numpy()[0]
        acts = self.activations.detach().cpu().numpy()[0]
        weights = grads.mean(axis=(1, 2))
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for c, w in enumerate(weights):
            cam += w * acts[c]
        cam = np.maximum(cam, 0)
        cam = cam / (cam.max() + 1e-8)
        return cam, class_idx

    def remove_hooks(self):
        for h in self.hooks:
            h.remove()

print("✓ Grad-CAM ready")

# ========== 6. GENERATE CAMs ==========
print("\n[5/5] Generating CAMs...")

cam_output_dir = os.path.join(OUTPUT_DIR, 'cam_visualizations')
os.makedirs(cam_output_dir, exist_ok=True)

all_cams = {}

if len(models) == 0:
    print("✗ No models loaded. Cannot generate CAMs.")
else:
    for model_name, model in tqdm(models.items(), desc="Processing models"):
        target_layer = model_configs[model_name]['target_layer']
        cam_engine = GradCAM(model, target_layer)
        all_cams[model_name] = {}
        for cls_name in class_names:
            all_cams[model_name][cls_name] = []
            for img_idx, img_path in enumerate(test_images[cls_name]):
                try:
                    img_tensor, img_rgb = preprocess_image(img_path)
                    cam, pred_class = cam_engine.generate_cam(img_tensor)
                    cam_resized = cv2.resize(cam, (IMG_SIZE, IMG_SIZE))
                    cam_color = cv2.applyColorMap((cam_resized * 255).astype(np.uint8), cv2.COLORMAP_JET)
                    cam_color = cv2.cvtColor(cam_color, cv2.COLOR_BGR2RGB)
                    img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
                    overlay = cv2.addWeighted(img_rgb_uint8, 0.5, cam_color, 0.5, 0)
                    save_filename = f"cam_{model_name}_{cls_name}_{img_idx:03d}.png"
                    cv2.imwrite(os.path.join(cam_output_dir, save_filename),
                                cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))
                    all_cams[model_name][cls_name].append({
                        'path': img_path,
                        'overlay': overlay,
                        'pred_class': pred_class
                    })
                except Exception:
                    pass
        cam_engine.remove_hooks()
        print(f"✓ {model_name.upper()}: {len(all_cams[model_name][class_names[0]])} CAMs per class")

    print(f"\n✓ All CAM overlays saved to: {cam_output_dir}")

# ========== 7. CREATE COMPARISON GRIDS ==========
print("\n[6/6] Creating comparison grids...")
for cls_name in class_names:
    fig, axes = plt.subplots(1, len(models) + 1, figsize=(24, 4))
    first_img_path = test_images[cls_name][0]
    _, img_rgb = preprocess_image(first_img_path)
    img_rgb_uint8 = (img_rgb * 255).astype(np.uint8)
    axes[0].imshow(img_rgb_uint8, cmap='gray')
    axes[0].set_title(f'{cls_name} - Original', fontsize=11, fontweight='bold')
    axes[0].axis('off')
    for i, (model_name, _) in enumerate(models.items()):
        if len(all_cams[model_name][cls_name]) > 0:
            overlay = all_cams[model_name][cls_name][0]['overlay']
            axes[i+1].imshow(overlay)
            axes[i+1].set_title(model_name.upper(), fontsize=11, fontweight='bold')
        axes[i+1].axis('off')
    plt.tight_layout()
    plt.savefig(os.path.join(cam_output_dir, f'comparison_grid_{cls_name}.png'), dpi=300)
    plt.close()
    print(f"✓ Saved comparison grid for {cls_name}")

# ========== 8. SUMMARY ==========
print("\n" + "="*80)
print("SUMMARY: GRAD-CAM ANALYSIS COMPLETE")
print("="*80)

if len(models) > 0:
    print(f"\nModels processed: {list(models.keys())}")
    print(f"Classes analyzed: {class_names}")
    print(f"Images per class: {NUM_IMAGES_PER_CLASS}")
    total_cams = sum(len(all_cams[m][c]) for m in models.keys() for c in class_names)
    print(f"Total CAM visualizations: {total_cams}")
    print(f"\nOutput: {cam_output_dir}")
    print(f"Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png")
    print("\n✓ All CAM visualizations saved successfully!")
else:
    print("\n✗ No models loaded. Check weights.")
print("="*80)


NOTEBOOK 6: CLASS ACTIVATION MAPS (CAM) - 4 MODELS (ViT SKIPPED)

✓ Device: cuda
✓ Test dataset: /kaggle/input/split-dataset/test
✓ Weights directory: /kaggle/input/weight/weights
✓ Output directory: /kaggle/working

[1/5] Loading test images...
  CNV: 200
  DME: 200
  DRUSEN: 200
  NORMAL: 200
✓ Total test images: 800

[2/5] Setting up preprocessing...
✓ Preprocessing ready

[3/5] Loading trained models (ViT skipped)...
✓ RESNET50 loaded from /kaggle/input/weight/weights/resnet50_best_80_10_10.pth
✓ MOBILENETV2 loaded from /kaggle/input/weight/weights/mobilenetv2_best_80_10_10.pth
✓ EFFICIENTNETB0 loaded from /kaggle/input/weight/weights/efficientnetb0_best_80_10_10.pth
✓ SWIN loaded from /kaggle/input/weight/weights/swin_best_80_10_10.pth

✓ Total models loaded: 4
✓ Grad-CAM ready

[5/5] Generating CAMs...


Processing models:  25%|██▌       | 1/4 [00:24<01:14, 24.73s/it]

✓ RESNET50: 200 CAMs per class


Processing models:  50%|█████     | 2/4 [00:46<00:46, 23.19s/it]

✓ MOBILENETV2: 200 CAMs per class


Processing models:  75%|███████▌  | 3/4 [01:15<00:25, 25.74s/it]

✓ EFFICIENTNETB0: 200 CAMs per class


Processing models: 100%|██████████| 4/4 [01:49<00:00, 27.29s/it]

✓ SWIN: 200 CAMs per class

✓ All CAM overlays saved to: /kaggle/working/cam_visualizations

[6/6] Creating comparison grids...





✓ Saved comparison grid for CNV
✓ Saved comparison grid for DME
✓ Saved comparison grid for DRUSEN
✓ Saved comparison grid for NORMAL

SUMMARY: GRAD-CAM ANALYSIS COMPLETE

Models processed: ['resnet50', 'mobilenetv2', 'efficientnetb0', 'swin']
Classes analyzed: ['CNV', 'DME', 'DRUSEN', 'NORMAL']
Images per class: 200
Total CAM visualizations: 3200

Output: /kaggle/working/cam_visualizations
Files: individual cam_<model>_<class>_<idx>.png + comparison_grid_<class>.png

✓ All CAM visualizations saved successfully!


In [17]:
import zipfile

zip_path = os.path.join(OUTPUT_DIR, 'cam_visualizations.zip')

with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
    for root, dirs, files in os.walk(os.path.join(OUTPUT_DIR, 'cam_visualizations')):
        for file in files:
            if file.endswith('.png'):
                file_path = os.path.join(root, file)
                arcname = os.path.relpath(file_path, OUTPUT_DIR)
                zipf.write(file_path, arcname)

print(f"\n✓ ZIP archive created with all CAM images: {zip_path}")



✓ ZIP archive created with all CAM images: /kaggle/working/cam_visualizations.zip


In [18]:
import os
import cv2
import numpy as np
from tqdm import tqdm

# ========== CONFIG ==========
TEST_DIR = '/kaggle/input/split-dataset/test'
OUTPUT_DIR = '/kaggle/working/simple_processing'
IMG_SIZE = 224
NUM_IMAGES_PER_CLASS = 50  # reduce if needed

os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Simple image processing on OCT test images")

# ========== 1. COLLECT SAMPLE IMAGES ==========
class_names = sorted([d for d in os.listdir(TEST_DIR) if os.path.isdir(os.path.join(TEST_DIR, d))])
test_images = {cls: [] for cls in class_names}

for cls in class_names:
    cls_path = os.path.join(TEST_DIR, cls)
    images = [f for f in os.listdir(cls_path) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
    images = images[:NUM_IMAGES_PER_CLASS]
    test_images[cls] = [os.path.join(cls_path, img) for img in images]
    print(f"{cls}: {len(test_images[cls])} images")

# ========== 2. DEFINE PROCESSING FUNCTIONS ==========

def apply_clahe(gray):
    # Contrast Limited Adaptive Histogram Equalization
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    return clahe.apply(gray)

def apply_edges(gray):
    # Canny edge detection
    return cv2.Canny(gray, 50, 150)

def apply_morphology(gray):
    # Morphological opening (remove small bright noise)
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
    opened = cv2.morphologyEx(gray, cv2.MORPH_OPEN, kernel)
    # Top-hat to enhance bright structures if you want
    tophat = cv2.morphologyEx(gray, cv2.MORPH_TOPHAT, kernel)
    return opened, tophat

# ========== 3. PROCESS AND SAVE ==========
for cls in tqdm(class_names, desc="Processing classes"):
    cls_out_dir = os.path.join(OUTPUT_DIR, cls)
    os.makedirs(cls_out_dir, exist_ok=True)

    for img_idx, img_path in enumerate(test_images[cls]):
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img is None:
            continue
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE))

        # Original
        orig = img

        # CLAHE
        clahe_img = apply_clahe(orig)

        # Edges
        edges = apply_edges(orig)

        # Morphology
        opened, tophat = apply_morphology(orig)

        # Stack for quick visualization grid (optional)
        grid = np.zeros((IMG_SIZE * 2, IMG_SIZE * 2), dtype=np.uint8)
        grid[0:IMG_SIZE, 0:IMG_SIZE] = orig
        grid[0:IMG_SIZE, IMG_SIZE:2*IMG_SIZE] = clahe_img
        grid[IMG_SIZE:2*IMG_SIZE, 0:IMG_SIZE] = edges
        grid[IMG_SIZE:2*IMG_SIZE, IMG_SIZE:2*IMG_SIZE] = tophat

        base = f"{cls}_{img_idx:03d}"
        cv2.imwrite(os.path.join(cls_out_dir, base + "_orig.png"), orig)
        cv2.imwrite(os.path.join(cls_out_dir, base + "_clahe.png"), clahe_img)
        cv2.imwrite(os.path.join(cls_out_dir, base + "_edges.png"), edges)
        cv2.imwrite(os.path.join(cls_out_dir, base + "_opened.png"), opened)
        cv2.imwrite(os.path.join(cls_out_dir, base + "_tophat.png"), tophat)
        cv2.imwrite(os.path.join(cls_out_dir, base + "_grid.png"), grid)

print(f"\n✓ Simple processing outputs saved under: {OUTPUT_DIR}")


Simple image processing on OCT test images
CNV: 50 images
DME: 50 images
DRUSEN: 50 images
NORMAL: 50 images


Processing classes: 100%|██████████| 4/4 [00:01<00:00,  2.56it/s]


✓ Simple processing outputs saved under: /kaggle/working/simple_processing





In [19]:
import zipfile

zip_path = os.path.join('/kaggle/working', 'simple_processing.zip')
with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zf:
    for root, _, files in os.walk(OUTPUT_DIR):
        for f in files:
            if f.endswith('.png'):
                fp = os.path.join(root, f)
                zf.write(fp, os.path.relpath(fp, '/kaggle/working'))

print(f"✓ ZIP created: {zip_path}")


✓ ZIP created: /kaggle/working/simple_processing.zip
