# 🧠 NeuroWell AI: Batch Grad-CAM Simulation (Colab Ready)

This notebook simulates Grad-CAM overlays for pediatric neurosurgery use cases. It uses a randomly initialized (untrained) dummy CNN and two randomly generated 28×28 grayscale images that stand in for MRI slices.

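Use it to demo the visualization pipeline, or as a starting point for future MRI visualization tools. Both the model and the images are synthetic, so the predicted labels and heatmaps carry no clinical meaning.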

In [None]:
# 🔧 Install required packages
!pip install torch torchvision matplotlib opencv-python Pillow

In [None]:
# 🔁 Import libraries
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import os

# 📁 Output folder for the overlay PNGs; an already-existing folder is not treated as an error.
os.makedirs(name="GradCAM_Results", mode=0o777, exist_ok=True)

In [None]:
# 📦 Define dummy CNN model
class SimpleCNN(nn.Module):
    """Tiny two-conv-layer CNN used only to drive the Grad-CAM demo.

    Expects single-channel 28x28 inputs: the two 2x2 max-pools reduce the
    spatial size 28 -> 14 -> 7, which is what the ``16 * 7 * 7`` input size
    of ``fc1`` assumes.

    Args:
        num_classes: Number of output logits (default 3).
    """

    def __init__(self, num_classes: int = 3):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 8, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(8, 16, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(16 * 7 * 7, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (batch, num_classes).

        Side effect: stores the post-ReLU ``conv1`` activations on
        ``self.features`` so callers can inspect them after the call.
        """
        self.features = self.relu1(self.conv1(x))
        x = self.pool(self.features)
        x = self.relu2(self.conv2(x))
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1)
        return self.fc1(x)


# Seed right before construction so the random weights (and hence the
# predictions / output filenames) are identical on every Restart & Run All.
torch.manual_seed(0)
model = SimpleCNN()
model.eval()

In [None]:
# 🧪 Simulate 2 random stand-in "MRI" images and apply Grad-CAM
NUM_IMAGES = 2
IMAGE_SIZE = 28                  # SimpleCNN's fc1 is sized for 28x28 inputs
IMAGE_SEED = 0                   # fixed seed -> reproducible fake images
OUTPUT_DIR = "GradCAM_Results"

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

label_map = {0: 'astrocytoma', 1: 'non-CJD', 2: 'simulated-CJD'}


def compute_grad_cam(net, input_tensor, target_layer, class_idx=None):
    """Compute a Grad-CAM map for a single image.

    Args:
        net: Model returning logits of shape (1, num_classes).
        input_tensor: Batch holding exactly one image, shape (1, C, H, W).
        target_layer: Module whose output feature maps are explained
            (normally the last convolutional block).
        class_idx: Class to explain; defaults to the predicted (argmax) class.

    Returns:
        (cam, class_idx): ``cam`` is a float array with the spatial size of
        ``target_layer``'s output, scaled to [0, 1] (all zeros when there is
        no positive evidence for the class).
    """
    captured = {}

    def _save_activation(_module, _inputs, output):
        captured["activation"] = output

    handle = target_layer.register_forward_hook(_save_activation)
    try:
        logits = net(input_tensor)
    finally:
        handle.remove()  # never leave hooks attached across cell re-runs

    if class_idx is None:
        class_idx = int(logits.argmax(dim=1).item())
    activation = captured["activation"]

    # d(class score)/d(feature maps); autograd.grad does not accumulate .grad
    # on the model parameters, so the model is left untouched.
    (grads,) = torch.autograd.grad(logits[0, class_idx], activation)
    # Grad-CAM channel weights: spatially averaged gradients.
    weights = grads.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((weights * activation).sum(dim=1))[0]
    cam = cam.detach().cpu().numpy()

    peak = cam.max()
    # Guard against an all-zero map, which would otherwise divide by zero (NaN).
    normalized = cam / peak if peak > 0 else np.zeros_like(cam)
    return normalized, class_idx


def overlay_heatmap(gray_img, cam, alpha=0.4):
    """Blend a JET-colored ``cam`` (values in [0, 1]) over a uint8 grayscale image.

    The cam is resized to the image size first. Returns an RGB uint8 image of
    shape (H, W, 3), weighted ``1 - alpha`` image / ``alpha`` heatmap.
    """
    height, width = gray_img.shape
    cam_resized = cv2.resize(cam, (width, height), interpolation=cv2.INTER_LINEAR)
    heat_bgr = cv2.applyColorMap(
        (np.clip(cam_resized, 0.0, 1.0) * 255).astype(np.uint8), cv2.COLORMAP_JET
    )
    # OpenCV colormaps are BGR, but PIL (used to save) interprets arrays as RGB.
    heat_rgb = cv2.cvtColor(heat_bgr, cv2.COLOR_BGR2RGB)
    base_rgb = np.stack([gray_img] * 3, axis=-1)
    return cv2.addWeighted(base_rgb, 1.0 - alpha, heat_rgb, alpha, 0)


os.makedirs(OUTPUT_DIR, exist_ok=True)
rng = np.random.default_rng(IMAGE_SEED)

for i in range(NUM_IMAGES):
    img = (rng.random((IMAGE_SIZE, IMAGE_SIZE)) * 255).astype(np.uint8)
    input_tensor = transform(Image.fromarray(img)).unsqueeze(0)

    # Explain the predicted class using the last conv block (conv2 -> relu2).
    cam, pred = compute_grad_cam(model, input_tensor, target_layer=model.relu2)
    overlay = overlay_heatmap(img, cam)

    filename = os.path.join(OUTPUT_DIR, f"{label_map.get(pred, 'unknown')}_img{i+1}.png")
    Image.fromarray(overlay).save(filename)
    print(f"Saved: {filename}")