# STickNet Feature Map Visualization

This notebook demonstrates how to extract and visualize feature maps from different stages of the STickNet (Spatial Tick Network) model. We'll explore what the network learns at various depths by examining the intermediate representations.

## Overview
- Load and configure a STickNet model
- Extract feature maps from different stages of the backbone
- Create visualizations to understand learned features
- Compare feature evolution across network stages

In [None]:
# Import Required Libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
import sys

# Add parent directory to path for importing models
sys.path.append(os.path.abspath('..'))

# Import STickNet components
try:
  from models import build_STickNet, SpatialTickNet
  print("Successfully imported STickNet components")
except ImportError as e:
  print(f"Import error: {e}")
  # Fallback method
  import importlib.util
  spec = importlib.util.spec_from_file_location("STickNet", "../models/STickNet.py")
  STickNet_module = importlib.util.module_from_spec(spec)
  spec.loader.exec_module(STickNet_module)
  
  build_STickNet = STickNet_module.build_STickNet
  SpatialTickNet = STickNet_module.SpatialTickNet
  print("Using fallback import method")


## Load and Configure STickNet Model
Create a STickNet model and set it to evaluation mode for feature extraction.
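
The call below passes no checkpoint argument. Unless `build_STickNet` loads pretrained weights internally, the model starts from random initialization. Its feature maps will then not show the learned edge and texture detectors discussed later.

If a trained checkpoint is available, here is a minimal sketch of loading it. It assumes the file contains a plain `state_dict` whose keys match the model. The `load_weights` helper is hypothetical, and a toy `nn.Conv2d` stands in for STickNet so the snippet runs on its own.

```python
import io
import torch
import torch.nn as nn

def load_weights(model: nn.Module, source, device: torch.device) -> nn.Module:
    """Load a plain state_dict checkpoint into `model` and switch it to eval mode."""
    # map_location remaps tensors saved on another device (e.g. GPU -> CPU);
    # weights_only=True (PyTorch 1.13+) refuses to unpickle arbitrary objects
    state_dict = torch.load(source, map_location=device, weights_only=True)
    # Raises on missing/unexpected keys (strict=True by default) or mismatched shapes
    model.load_state_dict(state_dict)
    return model.to(device).eval()

# Demo: a toy conv layer saved to an in-memory buffer stands in for a checkpoint file
demo_device = torch.device("cpu")
trained = nn.Conv2d(3, 4, kernel_size=3)
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

fresh = load_weights(nn.Conv2d(3, 4, kernel_size=3), buffer, demo_device)
print(torch.equal(fresh.weight, trained.weight))  # True: the weights were restored
```

In the notebook this would look like `model = load_weights(model, checkpoint_path, device)`, where `checkpoint_path` is a placeholder for your own file. Checkpoints that wrap the weights (for example under a `'state_dict'` key) need that entry extracted first.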

In [None]:
# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Build STickNet model
model = build_STickNet(
  num_classes=1000,  # ImageNet classes
  typesize='small',
  cifar=False,
  use_lightweight_optimization=False
)

# Move model to device and set to evaluation mode
model = model.to(device)
model.eval()

print("Model architecture:")
print(model)


## Prepare Input Data
Load and preprocess input images to match model requirements.
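
The `Normalize` transform computes `(x - mean) / std` per channel using the ImageNet statistics. A normalized image is therefore no longer confined to [0, 1]: its values span roughly [-2.12, 2.64]. The tensor's maximum says little about whether it was normalized, and displaying an input requires inverting the transform. Below is a NumPy sketch of both directions (the helper names are illustrative).

```python
import numpy as np

# ImageNet channel statistics, matching the transforms.Normalize call in this notebook
MEAN = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
STD = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)

def normalize(img_chw: np.ndarray) -> np.ndarray:
    """Map a [0, 1] CHW image to the model's input space: (x - mean) / std."""
    return (img_chw - MEAN) / STD

def denormalize(x_chw: np.ndarray) -> np.ndarray:
    """Invert normalize() and clip to [0, 1] so the result can be shown with imshow."""
    return np.clip(x_chw * STD + MEAN, 0.0, 1.0)

white = np.ones((3, 2, 2))
normed = normalize(white)
print(normed[:, 0, 0].round(3))  # every channel of a white pixel exceeds 1 after normalization
```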

In [None]:
# Define transforms for preprocessing
transform = transforms.Compose([
  transforms.Resize((224, 224)),
  transforms.ToTensor(),
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Create a sample input or load an image
# Option 1: Create a random input for testing
sample_input = torch.randn(1, 3, 224, 224).to(device)

# Option 2: Load a real image (uncomment if you have an image file)
# image_path = "path/to/your/image.jpg"
# if os.path.exists(image_path):
#     image = Image.open(image_path).convert('RGB')
#     sample_input = transform(image).unsqueeze(0).to(device)

print(f"Input tensor shape: {sample_input.shape}")

# Display the input. Both options above produce a tensor in the model's normalized
# input space (values are not limited to [0, 1]), so undo the normalization before
# plotting. A random input will simply look like colored noise.
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
denorm_img = sample_input.squeeze(0).cpu() * std + mean
denorm_img = torch.clamp(denorm_img, 0, 1)

plt.figure(figsize=(6, 6))
plt.imshow(denorm_img.permute(1, 2, 0).numpy())
plt.title("Input Image")
plt.axis('off')
plt.show()


## Extract Feature Maps from Different Stages
Set up forward hooks to capture intermediate feature maps from the STickNet backbone.
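
A forward hook is a callable registered with `register_forward_hook`. PyTorch invokes it as `hook(module, inputs, output)` after the module's `forward` runs. The closure pattern used in this notebook stores `output` under a chosen name, and the handle returned at registration is later used to remove the hook.

The sketch below shows the same pattern on a small `nn.Sequential`, which is a stand-in rather than the STickNet backbone. It removes the hooks in a `finally` block so they are detached even if the forward pass raises.

```python
import torch
import torch.nn as nn

# Toy stand-in for a backbone (not the STickNet architecture)
toy = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),   # 224x224 -> 112x112
    nn.ReLU(),
    nn.Conv2d(8, 16, kernel_size=3, stride=2, padding=1),  # 112x112 -> 56x56
).eval()

captured = {}

def save_output(name):
    """Build a forward hook that stores the module's output under `name`."""
    def hook(module, inputs, output):
        captured[name] = output.detach()
    return hook

handles = [
    toy[0].register_forward_hook(save_output("conv1")),
    toy[2].register_forward_hook(save_output("conv2")),
]
try:
    with torch.no_grad():
        toy(torch.randn(1, 3, 224, 224))
finally:
    for handle in handles:
        handle.remove()  # removed hooks no longer fire on later forward passes

for name, fm in captured.items():
    print(name, tuple(fm.shape))
```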

In [None]:
# Dictionary to store feature maps
feature_maps = {}

def get_activation(name):
  """Return a forward hook that stores the module's output under `name`."""
  def hook(module, inputs, output):
    # Detach so the stored tensor carries no autograd history
    feature_maps[name] = output.detach()
  return hook

# Register hooks for different stages
hooks = []

# Hook for initial convolution
hooks.append(model.backbone.init_conv.register_forward_hook(get_activation('init_conv')))

# Hook for each stage
for i in range(5):  # STickNet has 5 stages
  stage_name = f'stage{i+1}'
  if hasattr(model.backbone, stage_name):
    stage = getattr(model.backbone, stage_name)
    hooks.append(stage.register_forward_hook(get_activation(stage_name)))

# Hook for final conv
hooks.append(model.backbone.final_conv.register_forward_hook(get_activation('final_conv')))

# Forward pass to collect feature maps; remove the hooks even if it fails
try:
  with torch.no_grad():
    output = model(sample_input)
finally:
  # Clean up hooks so later forward passes don't overwrite feature_maps
  for handle in hooks:
    handle.remove()

print("Captured feature maps:")
for name, feature_map in feature_maps.items():
  print(f"{name}: {feature_map.shape}")


## Create Feature Map Visualization Functions
Implement helper functions to normalize and display feature maps effectively.
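
`normalize_feature_map` rescales every map with its own minimum and maximum. Each panel therefore uses the full color range, and a nearly silent channel can look as bright as a strongly activated one. When comparing channels, a shared scale can be more informative. The NumPy sketch below contrasts the two approaches (function names are illustrative).

```python
import numpy as np

def minmax_per_channel(maps: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Normalize each channel of a (C, H, W) stack with its own min/max (one scale per panel)."""
    lo = maps.min(axis=(1, 2), keepdims=True)
    hi = maps.max(axis=(1, 2), keepdims=True)
    return (maps - lo) / (hi - lo + eps)  # eps keeps constant channels finite (they become 0)

def minmax_shared(maps: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Normalize the whole stack with one min/max so brightness is comparable across channels."""
    lo, hi = maps.min(), maps.max()
    return (maps - lo) / (hi - lo + eps)

weak = np.linspace(0.0, 0.1, 4).reshape(2, 2)
strong = weak * 100
maps = np.stack([weak, strong])
print(minmax_per_channel(maps).max(axis=(1, 2)))  # both channels reach ~1.0
print(minmax_shared(maps).max(axis=(1, 2)))       # the weak channel only reaches ~0.01
```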

In [None]:
def normalize_feature_map(feature_map):
  """Normalize feature map for visualization."""
  fm = feature_map.cpu().numpy()
  fm_min, fm_max = fm.min(), fm.max()
  fm_norm = (fm - fm_min) / (fm_max - fm_min + 1e-8)
  return fm_norm

def visualize_feature_maps(feature_maps_dict, layer_name, max_channels=16):
  """
  Visualize feature maps from a specific layer.
  
  Args:
    feature_maps_dict: Dictionary containing feature maps
    layer_name: Name of the layer to visualize
    max_channels: Maximum number of channels to display
  """
  if layer_name not in feature_maps_dict:
    print(f"Layer {layer_name} not found in feature maps")
    return
  
  feature_map = feature_maps_dict[layer_name]
  batch_size, channels, height, width = feature_map.shape
  
  # Limit number of channels to display
  num_channels = min(channels, max_channels)
  cols = 4
  rows = (num_channels + cols - 1) // cols
  
  plt.figure(figsize=(15, 3 * rows))
  plt.suptitle(f'Feature Maps from {layer_name} (Shape: {feature_map.shape})', fontsize=16)
  
  for i in range(num_channels):
    plt.subplot(rows, cols, i + 1)
    fm = normalize_feature_map(feature_map[0, i])
    plt.imshow(fm, cmap='viridis')
    plt.title(f'Channel {i}')
    plt.axis('off')
  
  plt.tight_layout()
  plt.show()

def create_feature_map_summary(feature_maps_dict):
  """Create a summary visualization of all feature maps."""
  plt.figure(figsize=(20, 12))
  
  layer_names = list(feature_maps_dict.keys())
  num_layers = len(layer_names)
  cols = 3
  rows = (num_layers + cols - 1) // cols
  
  for i, layer_name in enumerate(layer_names):
    feature_map = feature_maps_dict[layer_name]
    
    # Take mean across channels for summary view
    if len(feature_map.shape) == 4:
      mean_fm = torch.mean(feature_map[0], dim=0).cpu().numpy()
    else:
      mean_fm = feature_map[0].cpu().numpy()
    
    plt.subplot(rows, cols, i + 1)
    
    if len(mean_fm.shape) == 2:
      plt.imshow(normalize_feature_map(torch.tensor(mean_fm)), cmap='viridis')
    else:
      # For 1D features (like global pool output)
      plt.plot(mean_fm.flatten())
    
    plt.title(f'{layer_name}\nShape: {feature_map.shape}')
    plt.axis('off')
  
  plt.suptitle('Feature Map Summary Across All Layers', fontsize=16)
  plt.tight_layout()
  plt.show()


## Visualize Initial Convolution Features
Display feature maps from the initial convolution layer to observe low-level features.
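
`visualize_feature_maps` plots channels `0` to `max_channels - 1`, which is an arbitrary subset of the layer. One alternative is to rank channels by mean absolute activation and plot the strongest ones. It is shown here as a hypothetical helper, not part of STickNet.

```python
import numpy as np

def top_k_channels(fm_chw: np.ndarray, k: int = 16) -> np.ndarray:
    """Indices of the k channels with the largest mean absolute activation, strongest first."""
    scores = np.abs(fm_chw).mean(axis=(1, 2))
    return np.argsort(scores)[::-1][:k]

# Hypothetical 4-channel map: channel 0 is silent, channel 2 is the strongest
fm = np.zeros((4, 3, 3))
fm[1] = 0.5
fm[2] = -2.0
fm[3] = 1.0
print(top_k_channels(fm, k=2))  # [2 3]
```

For a captured tensor, `top_k_channels(feature_maps['init_conv'][0].cpu().numpy(), k=16)` returns indices that could replace `range(num_channels)` in the plotting loop.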

In [None]:
# Visualize initial convolution features
visualize_feature_maps(feature_maps, 'init_conv', max_channels=16)

print("With trained weights and a real image, the initial convolution typically responds to low-level features such as edges and textures.")


## Visualize Stage-wise Feature Maps
Explore how features evolve through each stage of the STickNet backbone.
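
Comparing each stage's spatial size with the 224x224 input gives its cumulative stride. A stage with stride s produces one output position per s x s block of input pixels, although its receptive field is usually larger. A small sketch of that arithmetic follows; the shapes are hypothetical placeholders, not STickNet's actual values.

```python
def effective_strides(input_hw, stage_shapes):
    """Downsampling factor of each (N, C, H, W) map relative to the input height and width."""
    in_h, in_w = input_hw
    return {name: (in_h / shape[2], in_w / shape[3]) for name, shape in stage_shapes.items()}

# Hypothetical shapes for a 224x224 input; substitute the shapes printed by the hook cell
shapes = {"stage1": (1, 64, 56, 56), "stage2": (1, 128, 28, 28), "stage3": (1, 256, 14, 14)}
for name, stride in effective_strides((224, 224), shapes).items():
    print(f"{name}: one position per {stride[0]:g}x{stride[1]:g} input pixels")
```

With real captures, the input would be `{k: v.shape for k, v in feature_maps.items() if v.dim() == 4}`.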

In [None]:
# Visualize feature maps from each stage
stages = ['stage1', 'stage2', 'stage3', 'stage4', 'stage5']

for stage in stages:
  if stage in feature_maps:
    print(f"\n=== {stage.upper()} ===")
    visualize_feature_maps(feature_maps, stage, max_channels=12)
    
    # Print stage analysis
    fm = feature_maps[stage]
    print(f"Stage shape: {fm.shape}")
    print(f"Spatial resolution: {fm.shape[2]}x{fm.shape[3]}")
    print(f"Number of channels: {fm.shape[1]}")
    print(f"Feature range: [{fm.min().item():.3f}, {fm.max().item():.3f}]")


## Create Multi-stage Feature Comparison
Generate side-by-side comparisons to observe feature evolution across stages.
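
`compare_across_stages` draws every map in an equally sized panel. A low-resolution late-stage map and a high-resolution early map therefore look the same size, even though each late-stage pixel covers a much larger patch of the input. Upsampling maps to the input resolution makes that correspondence explicit and allows overlaying them on the image.

Below is a NumPy sketch for integer scale factors (the helper name is illustrative). For PyTorch tensors, `torch.nn.functional.interpolate(fm, size=(224, 224), mode='nearest')` does the same on a 4-D batch.

```python
import numpy as np

def upsample_nearest(fm_hw: np.ndarray, out_hw) -> np.ndarray:
    """Nearest-neighbour upsampling of a 2-D map by integer factors to out_hw = (H, W)."""
    fh, fw = out_hw[0] // fm_hw.shape[0], out_hw[1] // fm_hw.shape[1]
    if (fh * fm_hw.shape[0], fw * fm_hw.shape[1]) != tuple(out_hw):
        raise ValueError("target size must be an integer multiple of the map size")
    # Repeat every row fh times, then every column fw times
    return np.repeat(np.repeat(fm_hw, fh, axis=0), fw, axis=1)

small = np.array([[0.0, 1.0], [2.0, 3.0]])
print(upsample_nearest(small, (4, 4)))
```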

In [None]:
# Create comprehensive feature map summary
create_feature_map_summary(feature_maps)

# Compare specific channels across stages
def compare_across_stages(feature_maps_dict, channel_idx=0):
  """Compare a specific channel across different stages."""
  stages_to_compare = ['init_conv', 'stage1', 'stage2', 'stage3', 'stage4', 'stage5']
  
  plt.figure(figsize=(18, 3))
  
  for i, stage in enumerate(stages_to_compare):
    if stage in feature_maps_dict:
      fm = feature_maps_dict[stage]
      if fm.shape[1] > channel_idx:  # Check if channel exists
        plt.subplot(1, len(stages_to_compare), i + 1)
        channel_fm = normalize_feature_map(fm[0, channel_idx])
        plt.imshow(channel_fm, cmap='viridis')
        plt.title(f'{stage}\nCh {channel_idx}\n{fm.shape[2]}x{fm.shape[3]}')
        plt.axis('off')
  
  plt.suptitle(f'Feature Evolution Across Stages (Channel {channel_idx})', fontsize=14)
  plt.tight_layout()
  plt.show()

# Compare first channel across stages
compare_across_stages(feature_maps, channel_idx=0)

# Compare different channels if available
if any(fm.shape[1] > 5 for fm in feature_maps.values()):
  compare_across_stages(feature_maps, channel_idx=5)


## Save Feature Map Visualizations
Export the generated visualizations for documentation and further analysis.
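
The 300 dpi PNGs suit documentation, but they keep only rescaled colors. For further numerical analysis, the raw activations can be stored alongside them. Here is a minimal sketch using NumPy's `.npz` format, with illustrative helper names and a toy array in place of real feature maps.

```python
import os
import tempfile
import numpy as np

def save_raw_maps(maps: dict, path: str) -> None:
    """Store each layer's array under its layer name in one compressed .npz file."""
    np.savez_compressed(path, **maps)

def load_raw_maps(path: str) -> dict:
    """Read every array back into a {layer_name: array} dictionary."""
    with np.load(path) as data:
        return {name: data[name] for name in data.files}

maps = {"stage1": np.random.default_rng(0).random((1, 8, 4, 4)).astype(np.float32)}
path = os.path.join(tempfile.mkdtemp(), "feature_maps.npz")
save_raw_maps(maps, path)
restored = load_raw_maps(path)
print(sorted(restored), restored["stage1"].shape)
```

Captured tensors would first be converted with `{name: fm.cpu().numpy() for name, fm in feature_maps.items()}`.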

In [None]:
# Create output directory
output_dir = '../feature_maps_output'
os.makedirs(output_dir, exist_ok=True)

def save_feature_map_visualization(feature_maps_dict, layer_name, output_dir, max_channels=16):
  """Save feature map visualization to file."""
  if layer_name not in feature_maps_dict:
    return
  
  feature_map = feature_maps_dict[layer_name]
  batch_size, channels, height, width = feature_map.shape
  
  num_channels = min(channels, max_channels)
  cols = 4
  rows = (num_channels + cols - 1) // cols
  
  plt.figure(figsize=(15, 3 * rows))
  plt.suptitle(f'Feature Maps from {layer_name} (Shape: {feature_map.shape})', fontsize=16)
  
  for i in range(num_channels):
    plt.subplot(rows, cols, i + 1)
    fm = normalize_feature_map(feature_map[0, i])
    plt.imshow(fm, cmap='viridis')
    plt.title(f'Channel {i}')
    plt.axis('off')
  
  plt.tight_layout()
  plt.savefig(os.path.join(output_dir, f'{layer_name}_feature_maps.png'), 
              dpi=300, bbox_inches='tight')
  plt.close()
  print(f"Saved {layer_name} feature maps to {output_dir}")

# Save visualizations for all layers
for layer_name in feature_maps.keys():
  save_feature_map_visualization(feature_maps, layer_name, output_dir)

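# Save the summary comparison. create_feature_map_summary() creates its own figure
# and ends with plt.show(), which closes the figure in inline notebook backends;
# temporarily replace plt.show with a no-op mock so the figure is still open to save.
from unittest import mock
with mock.patch.object(plt, 'show'):
  create_feature_map_summary(feature_maps)
  plt.savefig(os.path.join(output_dir, 'feature_maps_summary.png'),
              dpi=300, bbox_inches='tight')
plt.close()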

print(f"\nAll feature map visualizations saved to: {output_dir}")
print("Files saved:")
for file in os.listdir(output_dir):
  if file.endswith('.png'):
    print(f"  - {file}")


## Analysis and Conclusions

### Key Observations:

With trained weights and a real input image, feature maps typically progress as follows:

1. **Initial Convolution**: Captures low-level features such as edges, textures, and basic patterns
2. **Early Stages (Stages 1-2)**: Focus on local feature detection and simple pattern recognition
3. **Middle Stages (Stages 3-4)**: Combine local features into more complex patterns and shapes
4. **Late Stage (Stage 5)**: Encodes high-level semantic features relevant for classification
5. **Final Convolution**: Produces abstract representations ready for classification

### STickNet Architecture Insights:

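- **Channel Attention**: The SE (Squeeze-and-Excitation) blocks rescale each channel by a weight computed from global context, emphasizing informative channels; they reweight channels rather than selecting spatial regions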
- **Feature Evolution**: Progressive abstraction from low-level to high-level features
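- **Channel Width**: Channel counts change from stage to stage and depend on the chosen `typesize`; check the printed shapes for the actual widths rather than assuming later stages are wider or narrower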
- **Resolution Changes**: Spatial resolution decreases while feature complexity increases
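
For reference, here is how a generic Squeeze-and-Excitation block (Hu et al., 2018) works. It squeezes each channel to a single number by global average pooling. It then passes the resulting vector through a two-layer bottleneck (reduction ratio r) ending in a sigmoid, and multiplies every channel by its weight. The scaling is per channel and identical at every spatial position, which makes SE a channel-attention mechanism.

The NumPy sketch below uses random weights and omits biases. It illustrates the technique and is not STickNet's exact block.

```python
import numpy as np

def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))

def squeeze_excite(x_chw: np.ndarray, w1: np.ndarray, w2: np.ndarray) -> np.ndarray:
    """Generic SE recalibration of one (C, H, W) map; w1 is (C // r, C), w2 is (C, C // r)."""
    z = x_chw.mean(axis=(1, 2))                # squeeze: global average pool -> (C,)
    s = sigmoid(w2 @ np.maximum(w1 @ z, 0.0))  # excite: FC -> ReLU -> FC -> sigmoid -> (C,)
    return x_chw * s[:, None, None]            # scale: one weight per channel, same at every pixel

rng = np.random.default_rng(0)
C, r = 8, 4
x = rng.standard_normal((C, 5, 5))
w1 = rng.standard_normal((C // r, C))
w2 = rng.standard_normal((C, C // r))
y = squeeze_excite(x, w1, w2)
print(y.shape)  # (8, 5, 5)
```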

### Usage Tips:

- Use different input images to see how feature maps change based on content
- Experiment with different `typesize` values ('basic', 'small', 'large')
- Try enabling `use_lightweight_optimization` to see the effect of the lightweight-optimization (LWO) blocks
- Compare feature maps with and without SE attention
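
One simple way to quantify how much feature maps change, for example when the same model sees two different input images, is the cosine similarity of each layer's flattened activations. Below is a sketch with an illustrative helper name. Channel indices do not correspond between separately trained models, so such scores are only meaningful when the compared runs share weights.

```python
import numpy as np

def layer_cosine_similarity(maps_a: dict, maps_b: dict) -> dict:
    """Cosine similarity between two runs' activations, for layers present in both with equal shapes."""
    result = {}
    for name in sorted(maps_a.keys() & maps_b.keys()):
        a, b = np.asarray(maps_a[name]), np.asarray(maps_b[name])
        if a.shape != b.shape:
            continue  # layers whose shapes differ cannot be compared element by element
        a, b = a.ravel(), b.ravel()
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        result[name] = float(a @ b / denom) if denom > 0 else float("nan")
    return result

run_a = {"stage1": np.array([1.0, 0.0]), "stage2": np.array([1.0, 2.0])}
run_b = {"stage1": np.array([0.0, 1.0]), "stage2": np.array([2.0, 4.0])}
print(layer_cosine_similarity(run_a, run_b))  # stage1 ~0 (orthogonal), stage2 ~1 (same direction)
```

Captured tensors would be converted first, e.g. `{k: v.cpu().numpy() for k, v in feature_maps.items()}`.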

These visualizations help reveal what the STickNet model learns and can guide model improvements and interpretability work.