# Model Visualization - GradCAM and Feature Analysis

This notebook visualizes what the ConvNeXt model has learned, using GradCAM heatmaps and 2-D projections (t-SNE or PCA) of its feature embeddings.

In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
import cv2

from src.model import ConvNeXtPrecursorModel
from src.inference import PrecursorPredictor

%matplotlib inline

## 1. Load Model

In [None]:
# Restore the pretrained ConvNeXt precursor model from its saved checkpoint file.
checkpoint_path = '../models/convnext_loeo_best.pth'
model = ConvNeXtPrecursorModel.load_pretrained(checkpoint_path)
print("Model loaded successfully!")

## 2. GradCAM Visualization

In [None]:
class GradCAM:
    """Gradient-weighted Class Activation Mapping for one layer of ``model.model``.

    ``model`` is a wrapper whose ``.model`` attribute is the ``torch.nn.Module``
    to explain. The first element of that network's output is treated as the
    ``(batch, num_classes)`` logits to explain (presumably the first task head;
    TODO confirm against the model definition).

    The forward/backward hooks registered on ``target_layer`` stay attached until
    :meth:`remove` is called, or until the instance is used as a context manager
    and the ``with`` block exits. Otherwise they keep running on every later
    forward pass of the model.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached again.
        self._handles = [
            target_layer.register_forward_hook(self.save_activation),
            # ``register_backward_hook`` is deprecated. For modules made of
            # several autograd ops (e.g. a Sequential stage) it may report the
            # gradient of the last inner op rather than of the module output.
            target_layer.register_full_backward_hook(self.save_gradient),
        ]

    def remove(self):
        """Detach both hooks from ``target_layer``; safe to call repeatedly."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False

    def save_activation(self, module, input, output):
        """Forward hook: cache the target layer's output (detached)."""
        self.activations = output.detach()

    def save_gradient(self, module, grad_input, grad_output):
        """Backward hook: cache the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def generate(self, input_tensor, target_class=None):
        """Compute a GradCAM map normalized to ``[0, 1]``.

        Args:
            input_tensor: Batched model input. The explained score is that of
                the first sample.
            target_class: Class index to explain. Defaults to the first
                sample's argmax.

        Returns:
            The squeezed CAM as a NumPy array: ``(H, W)`` for a batch of one.
            It is all zeros when no location contributes positively, rather
            than NaN.

        Raises:
            RuntimeError: if the hooks did not fire (e.g. after ``remove()``).
        """
        net = self.model.model
        net.eval()

        # Run on the same device as the network's weights.
        param = next(net.parameters(), None)
        device = param.device if param is not None else input_tensor.device
        # A detached copy that requires grad guarantees an autograd graph (so
        # the backward hook fires) even with a frozen backbone, without
        # mutating the caller's tensor.
        x = input_tensor.detach().to(device).requires_grad_(True)

        # Reset so a hook that failed to fire is detected, not silently reused.
        self.gradients = None
        self.activations = None

        # enable_grad lets this work when called inside torch.no_grad().
        with torch.enable_grad():
            output = net(x)
            logits = output[0]
            if target_class is None:
                # Argmax over the first sample's classes, not the flattened batch.
                target_class = logits[0].argmax().item()
            net.zero_grad()
            logits[0, target_class].backward()

        if self.gradients is None or self.activations is None:
            raise RuntimeError(
                "GradCAM hooks did not fire; were they removed, or is "
                "target_layer not part of the forward pass?"
            )

        # Channel weights = spatially averaged gradients (dims 2, 3).
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = torch.relu((weights * self.activations).sum(dim=1, keepdim=True))
        cam = cam - cam.min()
        peak = cam.max()
        # Guard against 0/0 -> NaN when the map is identically zero.
        if peak > 0:
            cam = cam / peak
        return cam.squeeze().cpu().numpy()

In [None]:
def visualize_gradcam(image_path, model, save_path=None, image_size=224):
    """Show a spectrogram, its GradCAM heatmap, and an overlay of the two.

    Args:
        image_path: Path of the image to explain; it is converted to RGB.
        model: Predictor wrapper. It must provide ``predict(tensor)``, which
            returns a dict with ``magnitude_class``, ``magnitude_prob``,
            ``azimuth_class`` and ``azimuth_prob`` (used in the overlay title).
            It must also expose the network as ``model.model``, with a
            ``backbone.features`` sequence whose last stage is the GradCAM
            target layer.
        save_path: Optional path; when given, the figure is saved at 150 dpi
            before being shown.
        image_size: Side length of the square the image is resized to, for
            both inference and plotting (default 224).
    """
    # The context manager releases the file handle once the image is converted.
    with Image.open(image_path) as raw:
        img = raw.convert('RGB')
    size = (image_size, image_size)
    img_resized = img.resize(size)

    transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        # Standard ImageNet per-channel mean / std.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(img).unsqueeze(0)

    # Predictions shown in the overlay title.
    result = model.predict(input_tensor)

    target_layer = model.model.backbone.features[-1]
    gradcam = GradCAM(model, target_layer)
    try:
        cam = gradcam.generate(input_tensor)
    finally:
        # GradCAM hooks live on the shared target layer. Left attached, they
        # pile up across calls and run on every later forward pass. getattr
        # keeps this working with GradCAM variants that lack remove().
        remove_hooks = getattr(gradcam, 'remove', None)
        if callable(remove_hooks):
            remove_hooks()

    # Upsample the coarse CAM to display size. An all-zero CAM normalized
    # without a guard comes back as NaN; map it to 0 so the colormap stays valid.
    cam_resized = np.nan_to_num(cv2.resize(cam, size))

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Original image
    axes[0].imshow(img_resized)
    axes[0].set_title('Original Spectrogram')
    axes[0].axis('off')

    # GradCAM heatmap
    axes[1].imshow(cam_resized, cmap='jet')
    axes[1].set_title('GradCAM Heatmap')
    axes[1].axis('off')

    # Overlay: 60% image + 40% jet-coloured heatmap (alpha channel dropped).
    img_array = np.asarray(img_resized) / 255.0
    heatmap = plt.cm.jet(cam_resized)[:, :, :3]
    overlay = 0.6 * img_array + 0.4 * heatmap
    axes[2].imshow(overlay)
    axes[2].set_title(f'Overlay\nMag: {result["magnitude_class"]} ({result["magnitude_prob"]:.1%})\nAzi: {result["azimuth_class"]} ({result["azimuth_prob"]:.1%})')
    axes[2].axis('off')

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=150)
    plt.show()

# Example usage
# visualize_gradcam('../data/spectrograms/sample.png', model)

## 3. Feature Embedding Visualization

In [None]:
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

def _collect_features(model, dataloader):
    """Extract one flattened feature vector per sample, plus its two labels.

    Args:
        model: Wrapper exposing ``model.model.get_features`` and ``model.device``.
        dataloader: Iterable of ``(images, mag_labels, azi_labels)`` batches.

    Returns:
        ``(features, mag_labels, azi_labels)``: a 2-D array with one row per
        sample, and two 1-D label arrays.

    Raises:
        ValueError: if *dataloader* yields no batches.
    """
    feature_batches, mag_batches, azi_batches = [], [], []
    with torch.no_grad():
        for images, mag, azi in dataloader:
            feat = model.model.get_features(images.to(model.device))
            # Flatten spatial/token dims so each sample becomes one row.
            if feat.dim() > 2:
                feat = feat.flatten(1)
            feature_batches.append(feat.cpu().numpy())
            # .cpu() so labels already moved to an accelerator still convert.
            mag_batches.append(torch.as_tensor(mag).cpu().numpy().reshape(-1))
            azi_batches.append(torch.as_tensor(azi).cpu().numpy().reshape(-1))
    if not feature_batches:
        raise ValueError("dataloader yielded no batches; nothing to visualize")
    return (
        np.vstack(feature_batches),
        np.concatenate(mag_batches),
        np.concatenate(azi_batches),
    )


def _make_reducer(method, n_samples):
    """Build a 2-component reducer: t-SNE for 'tsne', otherwise PCA."""
    if method == 'tsne':
        # scikit-learn requires perplexity < n_samples; cap the default 30.
        return TSNE(n_components=2, random_state=42,
                    perplexity=min(30, n_samples - 1))
    return PCA(n_components=2)


def visualize_embeddings(model, dataloader, method='tsne'):
    """Visualize feature embeddings using t-SNE or PCA.

    The model's features are projected to 2-D. The same points are then
    scatter-plotted twice, coloured by magnitude label and by azimuth label.

    Args:
        model: Wrapper exposing ``model.model`` (a module with ``get_features``)
            and ``model.device``. The module is put in eval mode.
        dataloader: Iterable of ``(images, mag_labels, azi_labels)`` batches.
        method: ``'tsne'`` (default) or ``'pca'``, case-insensitive.

    Raises:
        ValueError: for an unknown *method* or an empty *dataloader*.
    """
    method_name = str(method).lower()
    if method_name not in ('tsne', 'pca'):
        raise ValueError(f"method must be 'tsne' or 'pca', got {method!r}")

    model.model.eval()
    features, mag_labels, azi_labels = _collect_features(model, dataloader)

    # Dimensionality reduction
    embeddings = _make_reducer(method_name, len(features)).fit_transform(features)

    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    scatter1 = axes[0].scatter(embeddings[:, 0], embeddings[:, 1],
                               c=mag_labels, cmap='viridis', alpha=0.7)
    axes[0].set_title('Feature Embeddings (Magnitude)')
    plt.colorbar(scatter1, ax=axes[0])

    scatter2 = axes[1].scatter(embeddings[:, 0], embeddings[:, 1],
                               c=azi_labels, cmap='tab10', alpha=0.7)
    axes[1].set_title('Feature Embeddings (Azimuth)')
    plt.colorbar(scatter2, ax=axes[1])

    plt.tight_layout()
    plt.show()

# Example usage (requires dataloader)
# visualize_embeddings(model, test_loader)