# Model Interpretability & Robustness Analysis

This notebook inspects a trained classifier with Grad-CAM for interpretability and measures its accuracy under input perturbations (additive Gaussian noise, Gaussian blur, and a brightness change).

In [ ]:
import torch
import torch.nn as nn
from torchvision import transforms, datasets, models
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import cv2
from PIL import Image
import random

## Configuration

In [ ]:
# Checkpoint to analyze; swap in a DANN/CDAN checkpoint to inspect that model instead.
MODEL_PATH = 'baseline_resnet18.pth'
# Root of the real-domain test images, read below with datasets.ImageFolder
# (one subfolder per class).
REAL_TEST_DIR = 'data/real/test'
BATCH_SIZE = 16
# Use the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
else:
    DEVICE = torch.device('cpu')

In [ ]:
# Evaluation preprocessing: resize to 224x224, convert to a [0, 1] tensor,
# then normalize with the standard ImageNet per-channel mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
test_dataset = datasets.ImageFolder(root=REAL_TEST_DIR, transform=transform)
# Class names in ImageFolder order; index i corresponds to label i.
classes = test_dataset.classes
# shuffle=False keeps the batch order deterministic across runs.
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

## 1. Load Trained Model

In [ ]:
# Rebuild the ResNet-18 architecture without downloading weights, resize the
# classification head to the dataset's class count, then restore the trained
# parameters from MODEL_PATH.
# NOTE(review): `pretrained=` is deprecated since torchvision 0.13 in favor of
# `weights=None`; kept as-is for compatibility with older torchvision.
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=len(classes))
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.to(DEVICE)  # nn.Module.to moves parameters in place and returns self
model.eval()

## 2. Interpretability using Grad-CAM

In [ ]:
def generate_gradcam(model, image, target_class, target_layer=None):
    """Compute a Grad-CAM heatmap explaining one class score for one image.

    Each channel of ``target_layer``'s output is weighted by the spatial mean
    of the gradient of the class score with respect to that channel. The
    weighted channels are summed and passed through ReLU. The result is
    upsampled (bilinear) to the input's spatial size and rescaled to [0, 1].

    Args:
        model: ``nn.Module`` whose output for a (1, C, H, W) batch is indexed
            as ``output[0, target_class]``. It is switched to eval mode.
        image: single normalized image tensor of shape (C, H, W).
        target_class: index of the class score to explain.
        target_layer: module whose activations/gradients are used. Defaults
            to the last ``nn.Conv2d`` in ``model.modules()`` order, which is
            the original behavior.

    Returns:
        float32 ``np.ndarray`` of shape (H, W) with values in [0, 1]. It is
        all zeros, rather than NaN, when the map has no positive evidence.

    Raises:
        ValueError: if ``target_layer`` is None and the model has no Conv2d.
    """
    model.eval()
    if target_layer is None:
        convs = [m for m in model.modules() if isinstance(m, nn.Conv2d)]
        if not convs:
            raise ValueError('generate_gradcam: model has no nn.Conv2d layer')
        target_layer = convs[-1]

    activations = []
    gradients = []

    def forward_hook(module, inputs, output):
        activations.append(output)
        # A tensor hook receives d(score)/d(output) directly. This avoids the
        # deprecated Module.register_backward_hook.
        output.register_hook(gradients.append)

    handle = target_layer.register_forward_hook(forward_hook)
    try:
        # Run on whatever device the model's parameters live on.
        device = next(model.parameters()).device
        output = model(image.unsqueeze(0).to(device))
        model.zero_grad()
        output[0, target_class].backward()
    finally:
        # Always detach the hook, even if forward/backward raises.
        handle.remove()

    acts = activations[0].detach()[0]   # (K, h, w) feature maps
    grads = gradients[0].detach()[0]    # (K, h, w) gradients of the score
    weights = grads.mean(dim=(1, 2))    # one importance weight per channel
    cam = torch.relu((weights[:, None, None] * acts).sum(dim=0))

    # Upsample to the input's spatial size instead of a hard-coded 224x224.
    cam = torch.nn.functional.interpolate(
        cam[None, None],
        size=tuple(image.shape[-2:]),
        mode='bilinear',
        align_corners=False,
    )[0, 0]
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:  # an all-zero map would otherwise become NaN
        cam = cam / peak
    return cam.cpu().numpy().astype(np.float32)

In [ ]:
# Overlay Grad-CAM heatmaps on a handful of randomly chosen test images.
num_samples = 3
indices = random.sample(range(len(test_dataset)), num_samples)

for idx in indices:
    img, label = test_dataset[idx]
    cam = generate_gradcam(model, img, label)
    # Undo the Normalize step (x * std + mean) so the image shows true colors.
    img_np = img.numpy().transpose(1, 2, 0) * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]
    img_np = img_np.clip(0, 1)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(img_np)
    ax.imshow(cam, cmap='jet', alpha=0.5)  # semi-transparent heatmap on top
    ax.set_title(f'True class: {classes[label]}')
    ax.axis('off')
    plt.show()

## 3. Robustness Analysis

In [ ]:
def perturb_image(img, noise_level=0.1, blur_kernel=3, brightness_factor=1.2, rng=None):
    """Return a perturbed copy of a normalized image tensor.

    The image is mapped back to 8-bit pixels. Three perturbations are then
    applied in order: additive Gaussian noise, Gaussian blur, and brightness
    scaling. The result is re-normalized with the same mean/std.

    Args:
        img: normalized image tensor of shape (3, H, W). It is assumed to be
            normalized with the mean/std below, as in the notebook's
            ``transform``.
        noise_level: std of the additive noise as a fraction of the 0-255
            range; 0 disables the noise.
        blur_kernel: odd side length of the square Gaussian kernel; values
            <= 1 disable the blur.
        brightness_factor: gain passed to ``cv2.convertScaleAbs``, which
            saturates results to [0, 255].
        rng: optional ``np.random.Generator`` for reproducible noise. When
            None, the global ``np.random`` state is used.

    Returns:
        float32 tensor of shape (3, H, W), normalized like the input.

    Raises:
        ValueError: if ``blur_kernel`` is an even number greater than 1
            (cv2.GaussianBlur requires an odd kernel size).
    """
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    if blur_kernel > 1 and blur_kernel % 2 == 0:
        raise ValueError(f'blur_kernel must be odd, got {blur_kernel}')

    # Undo Normalize and quantize to uint8. Round rather than truncate: the
    # float round trip yields values like 127.99999, which astype floors to 127.
    pixels = np.transpose(img.detach().cpu().numpy(), (1, 2, 0)) * std + mean
    pixels = np.rint(np.clip(pixels, 0, 1) * 255).astype(np.uint8)

    # All three perturbations act per channel, so no RGB<->BGR conversion is needed.
    if noise_level:
        shape = pixels.shape
        noise = rng.standard_normal(shape) if rng is not None else np.random.randn(*shape)
        pixels = np.clip(np.rint(pixels + noise * noise_level * 255), 0, 255).astype(np.uint8)
    if blur_kernel > 1:
        pixels = cv2.GaussianBlur(pixels, (blur_kernel, blur_kernel), 0)
    pixels = cv2.convertScaleAbs(pixels, alpha=brightness_factor, beta=0)

    # Back to a CHW float tensor in [0, 1], then re-normalize.
    chw = np.ascontiguousarray(np.transpose(pixels, (2, 0, 1)), dtype=np.float32) / 255.0
    return transforms.Normalize(mean, std)(torch.from_numpy(chw))

In [ ]:
# Evaluate robustness on a random subset of the test set. The same images are
# also scored unperturbed, so the robustness number has a clean reference.
num_test = min(50, len(test_dataset))  # random.sample raises if k > population
indices = random.sample(range(len(test_dataset)), num_test)
correct = 0        # correct predictions on perturbed images
clean_correct = 0  # correct predictions on the original images

model.eval()
with torch.no_grad():
    for idx in indices:
        img, label = test_dataset[idx]
        perturbed_img = perturb_image(img)
        # Score the clean and perturbed versions together in one forward pass.
        batch = torch.stack([img, perturbed_img]).to(DEVICE)
        clean_pred, perturbed_pred = model(batch).argmax(dim=1).tolist()
        clean_correct += int(clean_pred == label)
        correct += int(perturbed_pred == label)

# Guard against an empty test set (num_test == 0).
clean_acc = clean_correct / num_test if num_test else float('nan')
robust_acc = correct / num_test if num_test else float('nan')
print(f'Clean accuracy on the same subset: {clean_acc:.4f}')
print(f'Robustness Accuracy under perturbations: {robust_acc:.4f}')