# GradCAM Visualization for VLMs

This notebook explains how we use Gradient-weighted Class Activation Mapping (GradCAM) to visualize which image regions the Vision-Language Model focuses on during counting.

## What is GradCAM?

GradCAM is a technique for visualizing where a neural network is "looking" when making predictions. It combines:
1. **Gradients**: How much the output changes with respect to feature maps
2. **Activations**: The actual feature representations from the model

For VLMs like Qwen3-VL, GradCAM shows us which image regions influence the model's counting prediction.
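For a sequence of vision tokens, the standard GradCAM formulation (Selvaraju et al., 2017) can be written as shown below. Each token is treated as one spatial position. $A \in \mathbb{R}^{N \times D}$ holds the activations of $N$ vision tokens with $D$ hidden channels, $y$ is the scalar score being explained, $\alpha_k$ is the importance weight of channel $k$, and $h_i$ is the heatmap value of token $i$:

$$
\alpha_k = \frac{1}{N}\sum_{i=1}^{N} \frac{\partial y}{\partial A_{i,k}},
\qquad
h_i = \mathrm{ReLU}\Big(\sum_{k=1}^{D} \alpha_k\, A_{i,k}\Big)
$$

The step-by-step code in this notebook averages over channels instead of summing. That only scales $h$ by $1/D$, and the later min-max normalization removes the scaling.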

## Why GradCAM for VLMs?

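Vision-Language Models like Qwen3-VL feed vision-encoder patch embeddings into a **decoder-only** language model (a GPT-style decoder that also reads image tokens):
- Image and text tokens share a single sequence
- The decoder uses only self-attention (no cross-attention), so there is no dedicated text-to-image attention map to inspect
- Raw self-attention weights are spread across many layers and heads and mix image and text tokens, which makes them hard to read as a spatial map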

GradCAM solves this by:
- Using gradients from the backward pass
- Weighting feature activations by their importance
- Creating spatial heatmaps showing region importance

## Setup

First, let's import the necessary modules:

In [None]:
import sys
sys.path.append('../src')

from PIL import Image
import json
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

# Import our GradCAM module
from visualize_vlm_gradcam import VLMGradCAM

## How GradCAM Works

### Step-by-Step Process

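1. **Register Hooks**: Capture activations and gradients from a vision-encoder block (the hooks must be in place before the forward pass)
   ```python
   # Inside VLMGradCAM._register_hooks; `target_layer` is a hypothetical name for the hooked vision block
   def forward_hook(module, input, output):
       self.activations = output  # Save activations
   
   def backward_hook(module, grad_input, grad_output):
       self.gradients = grad_output[0]  # Save gradients w.r.t. the block output
   
   target_layer.register_forward_hook(forward_hook)
   target_layer.register_full_backward_hook(backward_hook)
   ```

2. **Forward Pass**: Run the image and prompt through the VLM to get the count prediction
   ```python
   outputs = model(**inputs)
   logits = outputs.logits
   score = logits[0, -1, :].max()  # Logit of the most likely next token (the value we explain)
   ```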

3. **Backward Pass**: Compute gradients w.r.t. the output
   ```python
   score.backward()  # Backpropagate
   ```

4. **Compute GradCAM**: Weight activations by gradients
   ```python
   import torch

   # Pool gradients spatially (over tokens): one weight per hidden channel
   pooled_grads = gradients.mean(dim=0)          # [hidden_dim]
   
   # Weight activations
   weighted_acts = activations * pooled_grads    # [num_tokens, hidden_dim]
   
   # Average across channels and apply ReLU
   heatmap = weighted_acts.mean(dim=-1)          # [num_tokens]
   heatmap = torch.relu(heatmap)  # Keep only positive contributions
   ```

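5. **Create Visualization**: Reshape the heatmap to the patch grid, resize it to the image size, and overlay
   ```python
   import torch.nn.functional as F

   # Reshape the token sequence to its 2D patch grid (grid_h * grid_w == num_tokens)
   heatmap = heatmap.float().reshape(grid_h, grid_w)

   # Normalize heatmap to [0, 1]; epsilon avoids division by zero for a flat map
   heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
   
   # Resize to image dimensions (bilinear)
   heatmap_resized = F.interpolate(heatmap[None, None], size=(img_height, img_width),
                                   mode='bilinear', align_corners=False)[0, 0]
   
   # Overlay on image
   plt.imshow(image)
   plt.imshow(heatmap_resized.detach().cpu().numpy(), cmap='jet', alpha=0.5)
   ```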
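Running the steps end to end on a toy model is a quick way to check hook order and tensor shapes. This is a minimal sketch, not the `VLMGradCAM` implementation. It assumes a single `nn.Linear` layer (`vision_block`, hypothetical) stands in for a vision-encoder block and another (`head`, hypothetical) for the language head. The hooks use PyTorch's `register_forward_hook` and `register_full_backward_hook`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-ins (assumption: NOT Qwen3-VL): one "vision block" and one "language head"
N_TOKENS, HIDDEN, VOCAB = 6, 8, 5
vision_block = nn.Linear(HIDDEN, HIDDEN)
head = nn.Linear(HIDDEN, VOCAB)

store = {}  # hypothetical storage for hooked tensors

def forward_hook(module, inputs, output):
    store["activations"] = output          # [num_tokens, hidden_dim]

def backward_hook(module, grad_input, grad_output):
    store["gradients"] = grad_output[0]    # d(score)/d(output), same shape as output

# Step 1: register hooks BEFORE the forward pass
h_fwd = vision_block.register_forward_hook(forward_hook)
h_bwd = vision_block.register_full_backward_hook(backward_hook)

# Step 2: forward pass; the score is the max logit at the last position
patches = torch.randn(N_TOKENS, HIDDEN, requires_grad=True)
logits = head(torch.relu(vision_block(patches)))   # [num_tokens, vocab]
score = logits[-1].max()

# Step 3: backward pass populates store["gradients"] via the backward hook
score.backward()

# Step 4: GradCAM over tokens
acts = store["activations"].detach()
grads = store["gradients"].detach()
pooled_grads = grads.mean(dim=0)                           # [hidden_dim]
heatmap = torch.relu((acts * pooled_grads).mean(dim=-1))   # [num_tokens]

# Remove hooks so repeated calls do not stack them
h_fwd.remove()
h_bwd.remove()
```

Registering hooks before the forward pass matters: a forward hook added after `model(**inputs)` never sees the activations. Removing the handles afterwards keeps hooks from piling up across repeated calls.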

## Technical Details

### Vision Tokenization

Qwen3-VL divides images into patches, so the number of vision tokens grows with resolution (counts are approximate):
- Full image (384×512): ~864 vision tokens → 24×36 grid
- Crop (256×256): ~280 vision tokens → 14×20 grid

### Tensor Shapes

During GradCAM computation:
```
Gradients:   [num_tokens, hidden_dim]  e.g., [864, 1024]
Activations: [num_tokens, hidden_dim]  e.g., [864, 1024]
```

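We need to:
1. Pool gradients over the token (spatial) dimension to get one weight per hidden channel
2. Weight the activations by those channel weights and average over the hidden dimension, giving one score per token
3. Reshape the token sequence to a 2D spatial grid
4. Resize to the original image dimensions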
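The four steps above fit in one small NumPy function. This is a sketch under stated assumptions. `tokens_to_heatmap` is a hypothetical helper. Tokens are assumed to be in row-major order; check your processor's patch ordering, since some vision encoders group neighboring patches. The random arrays only stand in for real hooked tensors, with a reduced hidden size.

```python
import numpy as np

def tokens_to_heatmap(activations, gradients, grid_hw, out_hw):
    """Token-level GradCAM -> image-sized heatmap in [0, 1] (sketch)."""
    acts = activations.astype(np.float32)   # float16 -> float32 for safe math
    grads = gradients.astype(np.float32)

    # 1. Pool gradients over the token axis: one weight per hidden channel
    weights = grads.mean(axis=0)                            # [hidden_dim]
    # 2. Weight activations, average over channels, keep the positive part
    cam = np.maximum((acts * weights).mean(axis=1), 0.0)    # [num_tokens]
    # 3. Reshape to the 2D patch grid (row-major token order assumed)
    grid_h, grid_w = grid_hw
    cam = cam.reshape(grid_h, grid_w)
    cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
    # 4. Bilinear resize via separable 1D interpolation (rows, then columns)
    out_h, out_w = out_hw
    ys = np.linspace(0, grid_h - 1, out_h)
    xs = np.linspace(0, grid_w - 1, out_w)
    rows = np.stack([np.interp(xs, np.arange(grid_w), r) for r in cam])   # [grid_h, out_w]
    return np.stack([np.interp(ys, np.arange(grid_h), c) for c in rows.T], axis=1)  # [out_h, out_w]

# Usage with placeholder data shaped like the example above (864 tokens, 24x36 grid)
rng = np.random.default_rng(0)
acts = rng.standard_normal((864, 64)).astype(np.float16)
grads = rng.standard_normal((864, 64)).astype(np.float16)
heat = tokens_to_heatmap(acts, grads, grid_hw=(24, 36), out_hw=(384, 512))
```

Casting to float32 first also sidesteps the float16 issues described under Implementation Notes.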

## Using the GradCAM Visualizer

### Basic Usage

In [None]:
def visualize_with_gradcam(image_path, category, strategy='comparison'):
    """
    Generate GradCAM visualizations for an image.
    
    Args:
        image_path: Path to image
        category: Object category to count
        strategy: 'comparison', 'dense', or 'hybrid'
    """
    # Initialize GradCAM
    gradcam = VLMGradCAM()
    
    # Load image
    image = Image.open(image_path)
    
    # Generate visualization using the unified API
    output_path = f'output_{strategy}.png'
    gradcam.visualize_counting_attention(image, category, strategy=strategy, output_path=output_path)
    
    print(f"Visualization saved to {output_path}!")

# Example usage (uncomment to run)
# visualize_with_gradcam('/media/M2SSD/FSC147/images_384_VarV2/194.jpg', 'peaches', 'hybrid')

## Interpreting GradCAM Heatmaps

### Color Scale
The overlays use Matplotlib's `jet` colormap:
- **Red/Yellow**: High importance (the VLM relied on this region)
- **Green/Cyan**: Medium importance
- **Blue/Dark blue**: Low importance (little or no contribution)

### What Good Heatmaps Look Like

For accurate counting:
- Red/yellow regions should align with object locations
- Multiple objects should have multiple hotspots
- Background should have low activation (dark blue)

### Common Patterns

1. **Well-distributed objects**: Heatmap shows multiple distinct hotspots
2. **Clustered objects**: Heatmap shows larger connected regions
3. **Edge objects**: May have lower activation (boundary artifacts)
4. **Occlusion**: Partially visible objects may have weaker activation

## Example Visualizations

Let's look at the example visualizations we generated:

In [None]:
# Display example GradCAM visualizations
fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Load and display the three example visualizations
viz_dir = Path('../visualizations/gradcam')

images = [
    'gradcam_comparison_3.png',
    'gradcam_dense_3.png', 
    'gradcam_hybrid_3.png'
]

titles = [
    'Comparison: Global vs Crop',
    'Dense Grid: 3×3',
    'Hybrid: Global + Quadrants'
]

for ax, img_name, title in zip(axes, images, titles):
    img_path = viz_dir / img_name
    if img_path.exists():
        img = Image.open(img_path)
        ax.imshow(img)
        ax.set_title(title, fontsize=12, fontweight='bold')
        ax.axis('off')
    else:
        ax.text(0.5, 0.5, f'Run visualizations first:\n{img_name}',
               ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')

plt.tight_layout()
plt.show()

## Implementation Notes

### Challenges Solved

1. **Tensor Shape Handling**
   - Vision tokens are 1D sequences that need to be reshaped to 2D grids
   - Not all token counts are perfect squares (e.g., 864 = 24×36)
   - Solution: Try common aspect ratios and grid sizes

2. **Float16 Compatibility**
   - SciPy's `gaussian_filter` doesn't support float16
   - Solution: Convert to float32 before smoothing operations

3. **Hook Registration**
   - Need to capture both forward activations and backward gradients
   - Solution: Register both forward and backward hooks on vision blocks
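One way to realize the "try aspect ratios and grid sizes" idea is to enumerate the factor pairs of the token count and keep the pair whose aspect ratio is closest, in log space, to the image's. `infer_grid` is a hypothetical helper, not necessarily what `src/visualize_vlm_gradcam.py` does. It also exposes a pitfall: for 864 tokens on a 384×512 image, 24×36 and 27×32 are exactly equally far from 4:3 in log-aspect terms, so an explicit tie-break is needed (here, prefer more columns).

```python
import math

def infer_grid(num_tokens, img_h, img_w):
    """Return (rows, cols) with rows * cols == num_tokens whose aspect ratio
    best matches img_w / img_h (hypothetical heuristic, sketch only)."""
    target = math.log(img_w / img_h)
    candidates = []
    for rows in range(1, num_tokens + 1):
        if num_tokens % rows == 0:
            cols = num_tokens // rows
            # Log-aspect distance, rounded so mathematically exact ties compare equal
            err = round(abs(math.log(cols / rows) - target), 9)
            # Tie-break: smaller error first, then more columns (-cols is smaller)
            candidates.append((err, -cols, rows, cols))
    _, _, rows, cols = min(candidates)
    return rows, cols

print(infer_grid(864, 384, 512))  # (24, 36), tied with (27, 32) before the tie-break
print(infer_grid(280, 256, 256))  # (14, 20)
```

When the processor reports the grid directly (Qwen2-VL-family processors return an `image_grid_thw` tensor), using that value is more reliable than inferring it.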

### Code Structure

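```python
class VLMGradCAM:
    def __init__(self, model_name=None):
        # Load VLM (default value assumed: the notebook calls VLMGradCAM() with no arguments)
        # Initialize hook storage
        ...
    
    def _register_hooks(self):
        # Register forward/backward hooks on vision encoder
        ...
    
    def generate_gradcam(self, image, category, crop_bbox=None):
        # 1. Forward pass
        # 2. Backward pass  
        # 3. Compute GradCAM from gradients + activations
        # 4. Return heatmap
        ...
    
    def visualize_comparison(self, image, category, output_path):
        # Generate global and crop GradCAM
        # Create side-by-side visualization
        ...

    def visualize_counting_attention(self, image, category, strategy, output_path):
        # Unified entry point: 'comparison', 'dense', or 'hybrid' strategy
        ...
```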

## Use Cases

### 1. Debugging Counting Errors
- If count is wrong, check if heatmap covers all objects
- Missing hotspots → objects not detected by VLM
- Background activation → VLM confused by clutter

### 2. Validating Strategies
- Compare global vs local heatmaps
- Ensure crops capture relevant regions
- Verify overlap handling in grid strategies
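To compare local heatmaps with the global one, or to check how a grid strategy treats overlapping crops, each crop heatmap can be pasted back into full-image coordinates. This is a minimal sketch with a hypothetical helper. It assumes every crop heatmap has already been resized to its crop's pixel size and that boxes follow PIL's `(x1, y1, x2, y2)` crop convention. Overlapping pixels are averaged, and a coverage map is returned alongside:

```python
import numpy as np

def stitch_crop_heatmaps(image_hw, crops):
    """Paste per-crop heatmaps into full-image coordinates, averaging overlaps.

    crops: list of ((x1, y1, x2, y2), heatmap), heatmap shape (y2 - y1, x2 - x1).
    Returns (stitched heatmap, per-pixel coverage count).
    """
    h, w = image_hw
    total = np.zeros((h, w), dtype=np.float32)
    coverage = np.zeros((h, w), dtype=np.float32)
    for (x1, y1, x2, y2), heat in crops:
        total[y1:y2, x1:x2] += heat
        coverage[y1:y2, x1:x2] += 1.0
    # Average where crops overlap; pixels covered by no crop stay 0
    stitched = np.divide(total, coverage, out=np.zeros_like(total), where=coverage > 0)
    return stitched, coverage

# Two 4x4 crops of a 4x6 image that overlap in columns 2-3
stitched, coverage = stitch_crop_heatmaps(
    (4, 6),
    [((0, 0, 4, 4), np.ones((4, 4))), ((2, 0, 6, 4), np.full((4, 4), 3.0))],
)
```

Pixels with `coverage == 0` were missed by every crop, and values above 1 mark overlaps; both are quick checks on a grid layout.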

### 3. Model Analysis
- Understand VLM counting behavior
- Identify failure modes (edge cases, occlusion)
- Guide improvements to counting strategies

### 4. Research & Development
- Visualize attention for RL reward signals
- Analyze which image regions contribute to predictions
- Debug VLM-based RL agents

## Advanced: Custom GradCAM

You can create custom visualizations by directly using the GradCAM generator:

In [None]:
def custom_gradcam_viz(image_path, category, custom_crops):
    """
    Create custom GradCAM visualization with specific crop regions.
    
    Args:
        image_path: Path to image
        category: Object category
        custom_crops: List of (x1, y1, x2, y2) crop boxes
    """
    gradcam = VLMGradCAM()
    image = Image.open(image_path)
    
    n_panels = len(custom_crops) + 1
    # squeeze=False keeps `axes` indexable even when custom_crops is empty
    fig, axes = plt.subplots(1, n_panels, figsize=(6 * n_panels, 6), squeeze=False)
    axes = axes[0]
    
    # Global view
    heatmap = gradcam.generate_gradcam(image, category)
    axes[0].imshow(image)
    if heatmap is not None:
        # Stretch the heatmap over the image's pixel extent so it aligns at any resolution
        w, h = image.size
        axes[0].imshow(heatmap, cmap='jet', alpha=0.5, extent=(-0.5, w - 0.5, h - 0.5, -0.5))
    axes[0].set_title('Global', fontsize=14, fontweight='bold')
    axes[0].axis('off')
    
    # Custom crops
    for i, crop_bbox in enumerate(custom_crops):
        heatmap = gradcam.generate_gradcam(image, category, crop_bbox)
        
        # Show crop
        cropped = image.crop(crop_bbox)
        axes[i+1].imshow(cropped)
        if heatmap is not None:
            cw, ch = cropped.size
            axes[i+1].imshow(heatmap, cmap='jet', alpha=0.5, extent=(-0.5, cw - 0.5, ch - 0.5, -0.5))
        axes[i+1].set_title(f'Crop {i+1}', fontsize=14, fontweight='bold')
        axes[i+1].axis('off')
    
    plt.tight_layout()
    return fig

# Example: Visualize specific regions of interest
# custom_crops = [(0, 0, 192, 192), (192, 0, 384, 192), (0, 192, 192, 384)]
# fig = custom_gradcam_viz('/path/to/image.jpg', 'objects', custom_crops)
# plt.show()

## Further Reading

- **GradCAM Paper**: "Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization" (Selvaraju et al., 2017)
- **VLM Architecture**: Qwen3-VL uses decoder-only architecture with vision patches
- **Alternative Methods**: We also explored Relevancy Propagation (Chefer et al.) but GradCAM proved more practical

## Next Steps

- Return to `01_VLM_Counting_Strategies.ipynb` to learn about counting strategies
- Explore source code: `src/visualize_vlm_gradcam.py` for full implementation
- Run evaluations: `src/evaluate_all_methods.py` to test on FSC147 dataset