In [None]:
# Test MultiScaleCLIPEncoder with Deformable Attention


## 1. Import Required Libraries

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import requests
from io import BytesIO
from torchvision import transforms
import os
import sys
from transformers import CLIPProcessor, CLIPModel

# Install einops if not available
try:
    import einops
except ImportError:
    !pip install einops
    import einops
from einops import rearrange, repeat

# Add the parent directory to the path so we can import our modules
sys.path.append('/Users/preetamverma/Desktop/multimodel')

# Import our custom encoder modules
from multiscale_encoder import MultiScaleCLIPEncoder
from encoder import CLIPEncoder

## 2. Helper Functions for Image Loading, Preprocessing, and Visualization

def download_image(url, timeout=10):
    """Download an image from a URL and return it as a PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` may block indefinitely on a dead host.

    Returns:
        A fully decoded ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection failures or timeouts.
        PIL.UnidentifiedImageError: If the response body is not an image.
    """
    response = requests.get(url, timeout=timeout)
    # Fail fast on error statuses instead of handing an HTML error page to
    # PIL, which would only report a confusing "cannot identify image".
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # Image.open is lazy; decode now so corrupt data is reported here, at the
    # download site, rather than at first use of the image.
    img.load()
    return img

def preprocess_image(image, size=224):
    """Turn a PIL image into a normalized, batched CLIP input tensor.

    Args:
        image: A PIL image. Non-RGB modes (grayscale ``L``, ``RGBA``,
            palette ``P``, ...) are converted to RGB first, because the
            normalization below has exactly three per-channel statistics and
            fails on any other channel count.
        size: Target height and width in pixels. The image is resized
            straight to ``size x size``, so its aspect ratio is not kept.
            NOTE(review): CLIPProcessor instead resizes the short side and
            center-crops; switch if results must match it exactly.

    Returns:
        A float tensor of shape ``(1, 3, size, size)``: the leading batch
        dimension comes from ``unsqueeze(0)``.
    """
    # Objects without a PIL ``mode`` (e.g. tensors) are passed through as-is.
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")
    preprocess = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),
        # Per-channel mean/std used by OpenAI CLIP.
        transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                             std=[0.26862954, 0.26130258, 0.27577711])
    ])
    return preprocess(image).unsqueeze(0)

def show_attention_maps(image, attention_maps, titles=None):
    """Plot an image next to one heatmap panel per attention map.

    Args:
        image: Anything ``plt.imshow`` accepts (e.g. a PIL image); shown in
            the first panel.
        attention_maps: Sequence of 2-D arrays, each drawn with the
            ``viridis`` colormap and its own colorbar. May be empty, in which
            case only the original image is shown.
        titles: Optional per-map titles. Maps with no matching entry
            (``titles`` is None/empty or shorter than ``attention_maps``) are
            titled ``"Attention Map {n}"``.
    """
    num_panels = len(attention_maps) + 1
    # squeeze=False keeps `axes` 2-D even for a single panel; with the default
    # squeeze=True, plt.subplots(1, 1) returns a bare Axes and axes[0] fails.
    _, axes = plt.subplots(1, num_panels, figsize=(15, 5), squeeze=False)
    axes = axes[0]

    # Show original image
    axes[0].imshow(image)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    # Show attention maps
    for i, attn_map in enumerate(attention_maps):
        ax = axes[i + 1]
        im = ax.imshow(attn_map, cmap='viridis')
        has_title = bool(titles) and i < len(titles)
        ax.set_title(titles[i] if has_title else f"Attention Map {i+1}")
        ax.axis('off')
        plt.colorbar(im, ax=ax)

    plt.tight_layout()
    plt.show()

## 3. Sample Images for Testing

# Demo images to push through both encoders, one per kind of scene.
image_urls = [
    "https://farm2.staticflickr.com/1533/26541536141_41abe98db3_z_d.jpg",  # Dog
    "https://farm9.staticflickr.com/8596/16715636612_8d7a3ee6a6_z_d.jpg",  # Urban scene
    "https://farm1.staticflickr.com/9/12715999_7a0f724bae_z_d.jpg",  # Close-up object
]

# Fetch and preview each image. Any error (network, decoding or plotting) is
# reported and the loop moves on, so `sample_images` may end up shorter than
# `image_urls`. An image whose preview fails after the download is still kept.
sample_images = []
for source_url in image_urls:
    try:
        picture = download_image(source_url)
        sample_images.append(picture)

        plt.figure(figsize=(5, 5))
        plt.imshow(picture)
        plt.axis('off')
        plt.show()
    except Exception as exc:
        print(f"Failed to download image: {exc}")

## 4. Initialize Both Encoders

# Settings shared by both encoders so the comparison is like-for-like.
embed_size = 768
model_name = "openai/clip-vit-base-patch32"
freeze_vision = True
shared_encoder_kwargs = dict(
    embed_size=embed_size,
    model_name=model_name,
    freeze_vision=freeze_vision,
)

# Baseline encoder.
standard_encoder = CLIPEncoder(**shared_encoder_kwargs)

# Multi-scale encoder plus its extra attention hyper-parameters.
multiscale_encoder = MultiScaleCLIPEncoder(
    **shared_encoder_kwargs,
    num_heads=8,
    num_points=4,
    dropout=0.1,
)

# Prefer Apple MPS, then CUDA, and fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

standard_encoder, multiscale_encoder = (
    standard_encoder.to(device),
    multiscale_encoder.to(device),
)

## 5. Process Images and Extract Features

# Inference only: put both encoders in eval mode (disables dropout).
for net in (standard_encoder, multiscale_encoder):
    net.eval()

# Push every downloaded image through both encoders and report output shapes.
with torch.no_grad():
    for i, img in enumerate(sample_images):
        # Defensive guard against placeholder entries.
        if img is None:
            continue

        img_tensor = preprocess_image(img).to(device)
        standard_features = standard_encoder(img_tensor)
        multiscale_features = multiscale_encoder(img_tensor)

        print(
            f"\nImage {i+1}:",
            f"Standard features shape: {standard_features.shape}",
            f"Multi-scale features shape: {multiscale_features.shape}",
            sep="\n",
        )

## 6. Visualize Multi-Scale Attention

def get_attention_maps(encoder, img_tensor):
    """Run ``encoder`` once and return any attention weights it cached.

    This only reads weights if the encoder stores them on
    ``encoder.deformable_attn.last_attention_weights`` during its forward
    pass. It does not make the encoder record them.

    Args:
        encoder: Callable model, optionally with a ``deformable_attn``
            attribute.
        img_tensor: Input passed unchanged to ``encoder``.

    Returns:
        The value of ``last_attention_weights`` after the forward pass, or
        ``None`` if either attribute is missing.
    """
    with torch.no_grad():
        encoder(img_tensor)  # output unused; only the cached weights matter
        attn_module = getattr(encoder, 'deformable_attn', None)
        return getattr(attn_module, 'last_attention_weights', None)

# Visualize attention for the first sample image.
#
# NOTE: the maps below are SIMULATED placeholders. They are seeded uniform
# noise at 1/8, 1/16 and 1/32 of the input resolution, because the encoder does
# not yet expose per-scale attention weights in a known layout.
# TODO: replace with real weights (see get_attention_maps) once the encoder
# stores them.
if sample_images:
    img = sample_images[0]  # Use the first image
    img_tensor = preprocess_image(img).to(device)

    try:
        # A private, seeded generator makes the placeholders reproducible and
        # leaves the global torch RNG state untouched.
        rng = torch.Generator().manual_seed(0)
        h, w = img_tensor.shape[-2], img_tensor.shape[-1]

        # Strides 8 / 16 / 32 stand in for fine / medium / coarse scales.
        attention_maps = [
            torch.rand(h // stride, w // stride, generator=rng).numpy()
            for stride in (8, 16, 32)
        ]

        # Titles say "simulated" so the plot is not mistaken for real
        # model attention.
        show_attention_maps(
            img,
            attention_maps,
            titles=[
                "Fine Scale Attention (simulated)",
                "Medium Scale Attention (simulated)",
                "Coarse Scale Attention (simulated)",
            ],
        )
    except Exception as e:
        print(f"Could not visualize attention maps: {e}")

## 7. Compare Feature Richness

def _to_cpu_float(features):
    """Detach ``features``, move it to the CPU and make it NumPy-friendly.

    Detaching allows tensors that still require grad (where ``.numpy()``
    raises). Non-float32/64 dtypes are upcast to float32: NumPy has no
    bfloat16, and ``mean``/``std`` reject integer tensors.
    """
    features = features.detach().cpu()
    if features.dtype not in (torch.float32, torch.float64):
        features = features.float()
    return features


def _summary_stats(features):
    """Return the mean, std, min and max over all elements as Python floats."""
    return {
        "mean": features.mean().item(),
        "std": features.std().item(),
        "min": features.min().item(),
        "max": features.max().item(),
    }


def compare_feature_richness(standard_features, multiscale_features):
    """Compare the value distributions of two feature tensors.

    Draws side-by-side histograms (50 bins) of all elements of each tensor,
    then prints a table of mean, standard deviation, min, max and range.

    Args:
        standard_features: Tensor produced by the standard encoder.
        multiscale_features: Tensor produced by the multi-scale encoder.
            The shapes need not match, because all elements are pooled. Either
            tensor may live on any device or still require grad.
    """
    standard_cpu = _to_cpu_float(standard_features)
    multiscale_cpu = _to_cpu_float(multiscale_features)
    std_stats = _summary_stats(standard_cpu)
    ms_stats = _summary_stats(multiscale_cpu)

    # Plot feature distribution histograms
    plt.figure(figsize=(12, 5))
    panels = (
        ("Standard Features", standard_cpu, std_stats),
        ("Multi-scale Features", multiscale_cpu, ms_stats),
    )
    for position, (label, feats, stats) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.hist(feats.flatten().numpy(), bins=50, alpha=0.7)
        plt.title(f"{label}\nMean: {stats['mean']:.3f}, Std: {stats['std']:.3f}")
        plt.xlabel("Feature Value")
        plt.ylabel("Frequency")

    plt.tight_layout()
    plt.show()

    # Print summary
    print("Feature Statistics Comparison:")
    print(f"{'':20s} {'Standard':15s} {'Multi-scale':15s}")
    print(f"{'-'*50}")
    for row_label, key in (("Mean", "mean"), ("Std Dev", "std"),
                           ("Min", "min"), ("Max", "max")):
        print(f"{row_label:20s} {std_stats[key]:15.5f} {ms_stats[key]:15.5f}")
    std_range = std_stats["max"] - std_stats["min"]
    ms_range = ms_stats["max"] - ms_stats["min"]
    print(f"{'Range':20s} {std_range:15.5f} {ms_range:15.5f}")

# Compare feature statistics on the first sample image, if one downloaded.
if sample_images:
    img = sample_images[0]
    img_tensor = preprocess_image(img).to(device)

    with torch.no_grad():
        standard_features, multiscale_features = (
            standard_encoder(img_tensor),
            multiscale_encoder(img_tensor),
        )
        compare_feature_richness(standard_features, multiscale_features)

## 8. Future Work and Improvements

"""
Future Improvements for the MultiScaleCLIPEncoder:

1. Implement a more efficient deformable attention mechanism using 
   CUDA/C++ extensions for better performance

2. Add more comprehensive visualization tools to better understand 
   how different scales contribute to the final features

3. Fine-tune the model on specific downstream tasks to evaluate 
   the benefits of multi-scale features over standard features

4. Experiment with different sampling strategies for the reference points
   in the deformable attention mechanism

5. Combine with other techniques like spatial pyramid pooling or
   feature pyramid networks for enhanced multi-scale representation

6. Integrate with detection heads to directly evaluate object detection
   performance improvements
"""

# Usage example for the MultiScaleCLIPEncoder in an end-to-end pipeline
def example_pipeline(image=None):
    """Encode one image with the multi-scale encoder and return its features.

    Args:
        image: PIL image to encode. Defaults to the first downloaded sample
            image, so the no-argument call ``example_pipeline()`` still works.

    Returns:
        The multi-scale feature tensor, or ``None`` if no image was given and
        no sample image is available.
    """
    # 1. Load and preprocess image
    if image is None:
        image = sample_images[0] if sample_images else None
    if image is None:
        # Explain the no-op instead of returning silently.
        print("No image available for the example pipeline; skipping.")
        return None

    img_tensor = preprocess_image(image).to(device)

    # 2. Extract multi-scale features
    with torch.no_grad():
        features = multiscale_encoder(img_tensor)

    # 3. Use features for downstream tasks
    # (This would connect to your existing decoder or other models)
    print(f"Generated multi-scale features of shape {features.shape}")
    print("These features can now be passed to decoder_model for generation or detection tasks")

    # Example: Features could now be passed to a decoder model
    # outputs = decoder_model(features, ...)
    return features

# Run example pipeline. Binding the result keeps a notebook cell from echoing
# the whole tensor and leaves the features available for later cells.
pipeline_features = example_pipeline()