# RAJNI-ViT Demo

**RAJNI: Relative Adaptive Jacobian-based Neuronal Importance**

This notebook demonstrates how to use RAJNI for efficient Vision Transformer inference through adaptive token pruning. We'll cover:

1. Loading a pretrained ViT and wrapping it with RAJNI
2. Running inference with different pruning intensities (gamma values)
3. Visualizing which tokens get pruned at each layer
4. Measuring FLOPs reduction compared to the baseline

---

## Key Idea

RAJNI approximates the gradient of the CLS token with respect to each patch token using attention weights and value norms:

$$\text{importance}_j \approx \sum_h |A_{0j}^h| \cdot \|V_j^h\|$$

This captures how much each patch contributes to the final classification, with no backpropagation needed!
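
As a rough illustration of the formula above, the following NumPy sketch scores patches from one layer's attention and value tensors. The function name `cls_patch_importance`, the tensor shapes, and the convention that token 0 is CLS are assumptions made for this example; this is not the code that ships in `rajni`.

```python
import numpy as np

def cls_patch_importance(attn, values):
    """Score each patch j by sum over heads of |A[h, 0, j]| * ||V[h, j]||.

    attn:   (heads, tokens, tokens) attention weights for one image
    values: (heads, tokens, head_dim) value vectors from the same layer
    Token 0 is assumed to be CLS, so the result has tokens - 1 entries.
    """
    cls_to_patch = np.abs(attn[:, 0, 1:])                     # (heads, patches)
    value_norms = np.linalg.norm(values[:, 1:, :], axis=-1)   # (heads, patches)
    return (cls_to_patch * value_norms).sum(axis=0)           # (patches,)

# Random tensors with ViT-Small-like shapes (6 heads, 196 patches + CLS)
rng = np.random.default_rng(0)
heads, tokens, head_dim = 6, 197, 64
scores = rng.normal(size=(heads, tokens, tokens))
attn = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)  # row-wise softmax
values = rng.normal(size=(heads, tokens, head_dim))

importance = cls_patch_importance(attn, values)
print(importance.shape)  # (196,)
```

Only the CLS row of the attention matrix is used, which is why the score needs a single forward pass and no gradient computation.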

In [None]:
# Setup: Install dependencies if needed
# !pip install torch torchvision timm matplotlib numpy

import sys
sys.path.insert(0, '..')  # Add parent directory for imports

import torch
import timm
import matplotlib.pyplot as plt
import numpy as np

# Check device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using device: {device}")

## 1. Load a Pretrained Vision Transformer

We'll use `timm` to load a ViT-Small model. RAJNI works with any ViT architecture from timm.

In [None]:
# Load a ViT-Small model (you can swap this for vit_base_patch16_224, deit_small_patch16_224, etc.)
base_model = timm.create_model('vit_small_patch16_224', pretrained=True)
base_model.eval()
base_model.to(device)

print(f"Model: {base_model.__class__.__name__}")
print(f"Embedding dim: {base_model.embed_dim}")
print(f"Number of blocks: {len(base_model.blocks)}")
print(f"Number of heads: {base_model.blocks[0].attn.num_heads}")

## 2. Wrap with RAJNI

The `AdaptiveJacobianPrunedViT` wrapper adds token pruning to the forward pass.

**Key parameters:**
- `gamma`: Pruning intensity (higher = more aggressive). Start with 0.01
- `min_tokens`: Floor on token count to prevent over-pruning
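
To build intuition for how these two parameters interact, here is a toy simulation in plain Python. It assumes a fixed, made-up per-layer score (standing in for the quantity RAJNI derives from attention), a keep ratio of `score ** (-gamma)`, and rounding up to whole tokens. The helper `token_schedule` is hypothetical and only mimics the general shape of the behaviour, not the wrapper's actual logic.

```python
import math

def token_schedule(num_tokens, score, gamma, min_tokens, num_layers=12):
    """Token count after each layer for a fixed, hypothetical score."""
    counts = []
    for _ in range(num_layers):
        # Keep ratio clamped to [0, 1]; gamma = 0 keeps everything
        keep_ratio = min(1.0, max(0.0, score ** (-gamma)))
        kept = math.ceil(num_tokens * keep_ratio)
        # Never drop below the floor, never grow the sequence
        num_tokens = min(num_tokens, max(min_tokens, kept))
        counts.append(num_tokens)
    return counts

for gamma in (0.0, 0.01, 0.05):
    print(gamma, token_schedule(197, score=100.0, gamma=gamma, min_tokens=16))
```

With `gamma=0.0` the count stays at 197 in every layer. Larger values shrink the sequence faster until it reaches `min_tokens`.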

In [None]:
from rajni import AdaptiveJacobianPrunedViT

# Wrap the base model with RAJNI
gamma = 0.02  # Moderate pruning intensity
model = AdaptiveJacobianPrunedViT(base_model, gamma=gamma, min_tokens=16)
model.to(device)
model.eval()

print(model)

## 3. Run Inference and Inspect Pruning

Let's create a random image and see how tokens get pruned across layers.

In [None]:
# Create a dummy image (in practice, you'd load a real image)
dummy_input = torch.randn(1, 3, 224, 224).to(device)

# Run forward pass
with torch.no_grad():
    logits = model(dummy_input)

# Get pruning statistics
stats = model.get_last_stats()
token_counts = stats["token_counts"]

print(f"Output shape: {logits.shape}")
print("\nToken counts per layer:")
for i, count in enumerate(token_counts):
    bar = "█" * (count // 5)
    print(f"  Layer {i+1:2d}: {count:3d} tokens {bar}")

## 4. Compare Different Gamma Values

Let's see how pruning intensity affects token retention.

In [None]:
# Test different gamma values
gamma_values = [0.0, 0.01, 0.02, 0.05, 0.1]
results = {}

for g in gamma_values:
    test_model = AdaptiveJacobianPrunedViT(base_model, gamma=g, min_tokens=16)
    test_model.to(device).eval()
    
    with torch.no_grad():
        _ = test_model(dummy_input)
    
    stats = test_model.get_last_stats()
    results[g] = stats["token_counts"]

# Plot token retention curves
plt.figure(figsize=(10, 5))
for g, counts in results.items():
    plt.plot(range(1, len(counts) + 1), counts, marker='o', label=f'γ = {g}')

plt.xlabel('Layer')
plt.ylabel('Token Count')
plt.title('Token Retention Across Layers (Different γ Values)')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 5. Compute FLOPs Reduction

RAJNI reduces compute by processing fewer tokens in later layers. Let's quantify this.
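
A back-of-the-envelope estimate shows why fewer tokens means less compute. The sketch below counts only the linear layers and attention matrix products of each transformer block, ignoring patch embedding, layer norms, softmax, and the classifier head. The helpers `block_flops` and `flops_saving` are hypothetical and are not the implementation behind `evaluation.flops.flops_reduction`. The example token counts are made up, and the sketch assumes each entry is the number of tokens a block processes.

```python
def block_flops(num_tokens, dim, mlp_ratio=4):
    """Approximate FLOPs of one transformer block (2 FLOPs per multiply-add)."""
    projections = 4 * num_tokens * dim * dim        # Q, K, V and output projections
    attention = 2 * num_tokens * num_tokens * dim   # QK^T and attention @ V
    mlp = 2 * mlp_ratio * num_tokens * dim * dim    # two linear layers
    return 2 * (projections + attention + mlp)

def flops_saving(token_counts, dim, full_tokens=197):
    """Return (baseline, pruned, reduction in percent) over all blocks."""
    baseline = sum(block_flops(full_tokens, dim) for _ in token_counts)
    pruned = sum(block_flops(n, dim) for n in token_counts)
    return baseline, pruned, 100.0 * (1 - pruned / baseline)

# Hypothetical per-block token counts for a 12-block ViT-Small (dim = 384)
token_counts = [197] * 3 + [160] * 3 + [120] * 3 + [80] * 3
baseline, pruned, saving = flops_saving(token_counts, dim=384)
print(f"{baseline / 1e9:.2f} GFLOPs -> {pruned / 1e9:.2f} GFLOPs ({saving:.1f}% less)")
```

The projection and MLP terms scale linearly with the token count and the attention term scales quadratically, so pruning early layers saves the most.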

In [None]:
from evaluation.flops import flops_reduction

# Compute FLOPs for each gamma setting
print("FLOPs Analysis:\n" + "=" * 50)
print(f"{'Gamma':<10} {'Baseline':>12} {'RAJNI':>12} {'Reduction':>12}")
print("-" * 50)

for g in gamma_values:
    test_model = AdaptiveJacobianPrunedViT(base_model, gamma=g, min_tokens=16)
    test_model.to(device).eval()
    
    with torch.no_grad():
        _ = test_model(dummy_input)
    
    stats = test_model.get_last_stats()
    flops_info = flops_reduction(base_model, stats)
    
    print(f"γ = {g:<6.3f} {flops_info['baseline_GFLOPs']:>10.2f} G  "
          f"{flops_info['rajni_GFLOPs']:>10.2f} G  "
          f"{flops_info['reduction_%']:>10.1f}%")

## 6. Visualize Pruned Patches

Load a real image and see which patches RAJNI considers important vs. unimportant.
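
The visualization has to map each layer's kept indices back to positions in the original 14 × 14 patch grid. The sketch below assumes, as the plotting code in this notebook does, that each layer reports positions within the sequence that entered that layer, with index 0 reserved for CLS. The helper `compose_kept_indices` is hypothetical and works on plain lists for clarity.

```python
def compose_kept_indices(kept_per_layer, total_patches):
    """Map per-layer kept positions back to original patch ids.

    Positions in each layer are relative to the survivors of the
    previous layer, so they are resolved through that list.
    """
    alive = list(range(total_patches))   # original ids of surviving patches
    history = []
    for kept in kept_per_layer:
        # Drop CLS (position 0) and shift to patch-space positions
        patch_positions = [i - 1 for i in kept if i != 0]
        alive = [alive[p] for p in patch_positions if 0 <= p < len(alive)]
        history.append(list(alive))
    return history

# Toy example: 6 patches, two pruning layers
history = compose_kept_indices([[0, 1, 3, 4, 6], [0, 2, 4]], total_patches=6)
print(history)  # [[0, 2, 3, 5], [2, 5]]
```

Because the positions are relative, every layer has to be applied in order, even when only a few layers are plotted.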

In [None]:
from torchvision import transforms
from PIL import Image
import urllib.request
import io

# Download a sample cat image from Wikimedia Commons
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
try:
    with urllib.request.urlopen(url, timeout=5) as response:
        img_data = response.read()
    pil_img = Image.open(io.BytesIO(img_data)).convert('RGB')
    print("Downloaded sample image")
except Exception as exc:
    # Fallback: synthetic image with a clear subject (orange circle on white)
    print(f"Using synthetic image (download failed: {exc})")
    arr = np.full((224, 224, 3), 255, dtype=np.uint8)
    y, x = np.ogrid[:224, :224]
    mask = (x - 112)**2 + (y - 112)**2 < 50**2
    arr[mask] = [200, 100, 50]  # Orange circle
    pil_img = Image.fromarray(arr)

# Preprocess for ViT
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img_tensor = preprocess(pil_img).unsqueeze(0).to(device)
plt.imshow(pil_img)
plt.title("Input Image")
plt.axis('off')
plt.show()

In [None]:
# Run RAJNI on the real image and visualize pruning
from rajni.utils import denormalize_image

# Use stronger pruning (γ = 0.05) so the pruned regions are clearly visible
vis_model = AdaptiveJacobianPrunedViT(base_model, gamma=0.05, min_tokens=16)
vis_model.to(device).eval()

with torch.no_grad():
    logits = vis_model(img_tensor)

stats = vis_model.get_last_stats()
kept_indices = stats["kept_indices"]

# Denormalize for display
img_np = img_tensor[0].cpu().permute(1, 2, 0).numpy()
img_display = denormalize_image(img_np)

# Visualize pruning across selected layers
patch_size = 16
patches_per_row = 224 // patch_size  # 14
total_patches = patches_per_row ** 2  # 196

# Show up to five layers, spread across depth when the model has at least 12 blocks
layers_to_show = [0, 3, 6, 9, 11] if len(kept_indices) >= 12 else list(range(min(5, len(kept_indices))))

fig, axes = plt.subplots(1, len(layers_to_show) + 1, figsize=(3 * (len(layers_to_show) + 1), 3))

# Original image
axes[0].imshow(img_display)
axes[0].set_title("Original")
axes[0].axis('off')

# Show pruning at each layer.
# Assumption: each layer's kept indices are positions within the sequence that
# entered that layer, so alive_ids is updated at every layer, including the
# layers that are not drawn.
alive_ids = list(range(total_patches))

for layer_idx, keep_idx in enumerate(kept_indices):
    # Remove CLS (index 0) and shift to patch-space positions
    patch_keep = keep_idx[keep_idx != 0].cpu().numpy() - 1
    
    # Only keep valid indices
    patch_keep = patch_keep[(patch_keep >= 0) & (patch_keep < len(alive_ids))]
    
    # Update alive patches
    alive_ids = [alive_ids[i] for i in patch_keep]
    
    # Skip drawing for layers that are not selected
    if layer_idx not in layers_to_show:
        continue
    pruned = set(range(total_patches)) - set(alive_ids)
    
    # Draw
    ax = axes[layers_to_show.index(layer_idx) + 1]
    ax.imshow(img_display)
    ax.set_title(f"Layer {layer_idx + 1}")
    ax.axis('off')
    
    # Highlight pruned patches in blue
    for p in pruned:
        r, c = p // patches_per_row, p % patches_per_row
        rect = plt.Rectangle((c * patch_size, r * patch_size), patch_size, patch_size,
                            linewidth=0, facecolor='blue', alpha=0.4)
        ax.add_patch(rect)

plt.suptitle("RAJNI Pruning Visualization (γ = 0.05)", y=1.02)
plt.tight_layout()
plt.show()

print(f"\nFinal token count: {len(alive_ids)} / {total_patches} patches "
      f"({100 * len(alive_ids) / total_patches:.1f}% retained)")

## 7. Understanding the Algorithm

RAJNI's pruning decision at each layer uses three quantities:

1. **CLS Sensitivity (ρ)**: How strongly the CLS token attends to patches
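2. **Mass (η)**: Total importance mass in the current layer; `compute_keep_ratio` takes it relative to the previous layer's mass (`mass / prev_mass`)
3. **Keep Ratio**: Computed as $(\rho \cdot \eta)^{-\gamma}$, clamped to [0, 1]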

Higher gamma → more aggressive pruning.
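
The three quantities above can be combined in a few lines. The sketch below reads η as the ratio of the current layer's mass to the previous layer's mass, following the `compute_keep_ratio(rho, mass, prev_mass, gamma)` signature summarized later in this section. The helpers `keep_ratio` and `select_top_patches`, the `eps` guard, and the rounding with `ceil` are assumptions for illustration and may differ from what `rajni.pruning` does.

```python
import numpy as np

def keep_ratio(rho, mass, prev_mass, gamma, eps=1e-8):
    """(rho * mass / prev_mass) ** (-gamma), clamped to [0, 1]."""
    base = max(rho * mass / (prev_mass + eps), eps)   # guard against division by zero
    return float(np.clip(base ** (-gamma), 0.0, 1.0))

def select_top_patches(importance, num_keep):
    """Sorted token indices to keep: CLS (0) plus the top-`num_keep` patches."""
    num_keep = min(num_keep, len(importance))
    top = np.argsort(importance)[::-1][:num_keep]     # patch-space indices, best first
    return np.concatenate(([0], np.sort(top) + 1))    # shift by 1 to make room for CLS

# Made-up numbers for a layer with five patches
importance = np.array([0.1, 0.9, 0.3, 0.7, 0.05])
ratio = keep_ratio(rho=2.0, mass=3.0, prev_mass=1.5, gamma=0.5)
num_keep = int(np.ceil(ratio * len(importance)))
print(round(ratio, 3), select_top_patches(importance, num_keep))
```

Whenever the base `rho * mass / prev_mass` exceeds 1, raising it to `-gamma` yields a ratio below 1 that shrinks as `gamma` grows, which is the sense in which higher gamma prunes more aggressively.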

In [None]:
# Inspect the pruning functions directly
from rajni.pruning import compute_cls_sensitivity, compute_jacobian_importance, compute_keep_ratio

# These are the building blocks of RAJNI
print("Core Pruning Functions:")
print("=" * 50)
print("""
1. compute_cls_sensitivity(attention, values)
   → Measures how much CLS token "sees" each patch
   
2. compute_jacobian_importance(attention, values, num_patches)
   → Approximates ∂CLS/∂patch using attention × value norms
   
3. compute_keep_ratio(rho, mass, prev_mass, gamma)
   → Adaptive budget: (rho × mass/prev_mass)^(-gamma)
   
4. select_tokens(importance, num_keep, device)
   → Top-k selection with CLS always preserved
""")

## 8. Next Steps

- **Benchmark on ImageNet**: Use `examples/run_imagenet.py` 
- **Hyperparameter sweep**: Run `scripts/sweep_gamma.sh`
- **Compare models**: Try different ViT variants (tiny, small, base)

For questions or issues, check the README or open a GitHub issue!