In [None]:
pip install torch torchvision transformers huggingface_hub pickle5


Note: you may need to restart the kernel to use updated packages.


: 

In [None]:
# Import necessary libraries
import os
import time
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# Load the BLIP model for image captioning from Hugging Face
def load_captioning_model(model_name="Salesforce/blip-image-captioning-base"):
    """Load a BLIP image-captioning checkpoint and its matching processor.

    Args:
        model_name: Identifier passed to both ``from_pretrained`` calls, so the
            processor and the model always come from the same checkpoint.
            Defaults to the base BLIP captioning checkpoint.

    Returns:
        tuple: ``(processor, model)``.
    """
    processor = BlipProcessor.from_pretrained(model_name)
    model = BlipForConditionalGeneration.from_pretrained(model_name)
    return processor, model

# Initialize the model
# Runs when the cell executes: `processor` and `model` become notebook-level globals
# that generate_caption() and compare_attention_layers() read directly.
processor, model = load_captioning_model()
print("Image captioning model loaded successfully!")

In [None]:
# Function to load image from URL or local path
def load_image(image_source, timeout=30):
    """Load an image from an HTTP(S) URL or a local file path and convert it to RGB.

    Args:
        image_source: URL (``http://`` / ``https://``) or local path. Path-like
            objects are accepted; the value is converted with ``str()``.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled server cannot hang the notebook.

    Returns:
        The RGB ``PIL.Image``, or ``None`` if loading failed for any reason
        (the error is printed, not raised, so callers can test for ``None``).
    """
    try:
        source = str(image_source)
        # Match the scheme explicitly so a local file such as "http_cat.jpg"
        # is not mistaken for a URL.
        if source.startswith(('http://', 'https://')):
            response = requests.get(source, timeout=timeout)
            # Fail with the HTTP status instead of handing an error page to PIL.
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(source).convert('RGB')
        return image
    except Exception as e:
        # Best-effort by design: report the problem and signal failure with None.
        print(f"Error loading image: {e}")
        return None

# Function to generate captions
def generate_caption(image_source, beam_size=5, max_length=50):
    """Generate captions for an image with the notebook-level BLIP model.

    Args:
        image_source: URL or local path understood by ``load_image``, or an
            already-loaded ``PIL.Image`` (used as-is, avoiding a second download).
        beam_size: Number of beams; the same number of captions is returned.
        max_length: Maximum caption length in tokens.

    Returns:
        list[str]: ``beam_size`` decoded captions. On failure a one-element
        list holding an error message is returned instead of raising.
    """
    try:
        # Accept a pre-loaded image so callers that already hold one need not reload it.
        if isinstance(image_source, Image.Image):
            image = image_source.convert('RGB')
        else:
            image = load_image(image_source)
        if image is None:
            return ["Error: Could not load image"]

        # Keep the inputs on the same device as the model (matters once the model is on GPU).
        inputs = processor(image, return_tensors="pt").to(model.device)

        generation_kwargs = {
            "max_length": max_length,
            "num_beams": beam_size,
            "num_return_sequences": beam_size,
        }
        # early_stopping only applies to beam search; passing it for greedy
        # decoding (beam_size == 1) just triggers a configuration warning.
        if beam_size > 1:
            generation_kwargs["early_stopping"] = True

        with torch.no_grad():
            out = model.generate(**inputs, **generation_kwargs)

        # One decoded string per returned sequence.
        return processor.batch_decode(out, skip_special_tokens=True)

    except Exception as e:
        print(f"Error generating caption: {e}")
        return ["Error generating caption"]

# Function to display image with captions
def caption_image(image_source, beam_size=5):
    """Plot an image side by side with the captions generated for it.

    Args:
        image_source: URL or local path of the image.
        beam_size: Forwarded to ``generate_caption``.

    Returns:
        The list returned by ``generate_caption``, or ``None`` when the image
        could not be loaded (nothing is plotted in that case).
    """
    picture = load_image(image_source)
    if picture is None:
        return

    captions = generate_caption(image_source, beam_size)

    # Build the numbered caption block shown in the right-hand panel.
    numbered_lines = [f"{rank}. {text}\n\n" for rank, text in enumerate(captions, 1)]
    caption_text = "Generated Captions:\n\n" + "".join(numbered_lines)

    plt.figure(figsize=(12, 8))

    # Left panel: the input image.
    plt.subplot(1, 2, 1)
    plt.imshow(picture)
    plt.axis('off')
    plt.title('Input Image')

    # Right panel: text only, positioned in axes coordinates.
    plt.subplot(1, 2, 2)
    plt.axis('off')
    text_axes = plt.gca()
    box_style = dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8)
    plt.text(
        0.1,
        0.9,
        caption_text,
        fontsize=12,
        verticalalignment='top',
        transform=text_axes.transAxes,
        wrap=True,
        bbox=box_style,
    )

    plt.tight_layout()
    plt.show()

    return captions

In [None]:
# Test with example images
test_images = [
    "http://images.cocodataset.org/train2017/000000505539.jpg",
    "https://images.unsplash.com/photo-1514888286974-6c03e2ca1dba?w=400",  # cat
    "https://images.unsplash.com/photo-1552053831-71594a27632d?w=400",    # dog
]

# Generate captions for test images
for i, image_url in enumerate(test_images, 1):
    print(f"\n{'='*60}")
    print(f"Test Image {i}")
    print(f"{'='*60}")
    captions = caption_image(image_url, beam_size=3)

    # caption_image returns None when the image cannot be loaded; skip the
    # listing instead of crashing on enumerate(None).
    if not captions:
        print("\nNo captions generated (image could not be loaded).")
        continue

    print("\nGenerated Captions:")
    for j, caption in enumerate(captions, 1):
        print(f"{j}. {caption}")

In [1]:
# Grad-CAM implementation for Vision Transformer in BLIP model
import torch.nn.functional as F
from torch.nn import functional as F
import cv2

class GradCAMViT:
    """Grad-CAM for the Vision Transformer encoder of a BLIP captioning model.

    A forward hook on the target layer records its token activations and attaches
    a tensor hook that records the gradient of the caption score with respect to
    those activations. The class token is dropped and the remaining patch tokens
    are turned into a 2D heat map with the Grad-CAM weighting
    (channel weights = gradients averaged over patches).
    """

    def __init__(self, model, processor, target_layer_name='vision_model.encoder.layers.-1.self_attn'):
        """
        Initialize Grad-CAM for ViT
        Args:
            model: BLIP model
            processor: BLIP processor
            target_layer_name: Dotted path from ``model`` to the layer to explain.
                Integer components (negative allowed) index into module lists.
                The default is the self-attention block of the last vision
                encoder layer; BLIP's vision layers name it ``self_attn``.
        """
        self.model = model
        self.processor = processor
        self.target_layer_name = target_layer_name
        self.gradients = None
        self.activations = None
        self.hooks = []

    def save_gradient(self, module, grad_input, grad_output):
        """Hook function to save gradients (``grad_output[0]`` is the gradient of the layer output)"""
        self.gradients = grad_output[0]

    def save_activation(self, module, input, output):
        """Forward hook: keep the layer output and arrange for its gradient to be captured.

        Attention and encoder layers return tuples, so the first element (the
        hidden states) is used when the output is a tuple or list.
        """
        hidden = output[0] if isinstance(output, (tuple, list)) else output
        if not torch.is_tensor(hidden):
            return
        self.activations = hidden
        # A tensor hook is used instead of a module backward hook: it fires exactly
        # for this output. It can only be attached when a graph is being recorded.
        if hidden.requires_grad:
            handle = hidden.register_hook(
                lambda grad: self.save_gradient(module, None, (grad,))
            )
            self.hooks.append(handle)

    def register_hooks(self):
        """Attach the forward hook to the target layer, replacing any earlier hooks"""
        self.remove_hooks()
        self.gradients = None
        self.activations = None
        target_layer = self.get_target_layer()
        self.hooks = [target_layer.register_forward_hook(self.save_activation)]

    def get_target_layer(self):
        """Resolve ``target_layer_name`` to a module of ``self.model``.

        Raises:
            AttributeError: if a component of the path cannot be resolved.
        """
        module = self.model
        for part in self.target_layer_name.split('.'):
            is_index = part.lstrip('-').isdigit()
            try:
                module = module[int(part)] if is_index else getattr(module, part)
            except (AttributeError, IndexError, KeyError, TypeError) as exc:
                raise AttributeError(
                    f"Cannot resolve target layer '{self.target_layer_name}': "
                    f"failed at '{part}'"
                ) from exc
        return module

    def remove_hooks(self):
        """Remove all registered hooks"""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    @staticmethod
    def _caption_score(logits, token_ids):
        """Sum of the logits the decoder assigns to each next caption token.

        ``logits[:, t]`` is paired with ``token_ids[:, t + 1]``. Falls back to the
        maximum logit when the caption has fewer than two tokens.
        """
        steps = min(logits.shape[1], token_ids.shape[1] - 1)
        if steps < 1:
            return logits.max()
        next_tokens = token_ids[:, 1:1 + steps].unsqueeze(-1)
        return logits[:, :steps, :].gather(-1, next_tokens).sum()

    @staticmethod
    def _compute_cam(activations, gradients, grid_hw=None):
        """Turn token activations and gradients into a normalised 2D heat map.

        Args:
            activations: array ``[tokens, hidden]``; the first token is dropped as the class token.
            gradients: array with the same layout as ``activations``.
            grid_hw: optional ``(rows, cols)`` patch grid; used when it matches the
                number of patch tokens, otherwise a square grid is assumed.

        Returns:
            2D array scaled to ``[0, 1]``, or ``None`` if there are no patch tokens.
        """
        patch_activations = activations[1:]
        patch_gradients = gradients[1:]
        token_count = patch_activations.shape[0]
        if token_count == 0:
            return None

        # Grad-CAM: one weight per channel (gradient averaged over positions),
        # then a weighted sum of channels at every position, followed by ReLU.
        channel_weights = patch_gradients.mean(axis=0)
        cam = np.maximum(patch_activations @ channel_weights, 0)
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        if grid_hw is not None and grid_hw[0] * grid_hw[1] == token_count:
            rows, cols = grid_hw
        else:
            # Largest square that fits; surplus tokens are discarded.
            rows = cols = int(np.sqrt(token_count))
        return cam[:rows * cols].reshape(rows, cols)

    def generate_cam(self, image_path, caption_text=None, patch_size=16):
        """
        Generate Grad-CAM for the given image
        Args:
            image_path: Path or URL of the input image
            caption_text: Caption to explain. When None, a caption is first
                generated greedily and that caption is explained.
            patch_size: Size of ViT patches; used to derive the patch grid from
                the preprocessed image size.
        Returns:
            ``(cam, image)``: ``cam`` is a 2D array in [0, 1] or None if no
            activations/gradients were captured; ``(None, None)`` if the image
            could not be loaded.
        """
        image = load_image(image_path)
        if image is None:
            return None, None

        self.model.eval()
        inputs = self.processor(image, text=caption_text, return_tensors="pt").to(self.model.device)
        pixel_values = inputs["pixel_values"]

        if caption_text is None:
            # generate() records no graph, so it is only used to obtain the tokens;
            # the differentiable pass below re-scores them.
            with torch.no_grad():
                token_ids = self.model.generate(pixel_values=pixel_values, max_length=50, num_beams=1)
            attention_mask = None
        else:
            token_ids = inputs["input_ids"]
            attention_mask = inputs.get("attention_mask")

        # Guarantees a graph through the vision encoder even if its weights are frozen.
        pixel_values = pixel_values.detach().clone().requires_grad_(True)

        self.register_hooks()
        try:
            self.model.zero_grad()
            with torch.enable_grad():
                outputs = self.model(
                    pixel_values=pixel_values,
                    input_ids=token_ids,
                    attention_mask=attention_mask,
                )
                logits = getattr(outputs, "logits", None)
                if logits is None:
                    # Older transformers releases expose the field as decoder_logits.
                    logits = outputs.decoder_logits
                target_score = self._caption_score(logits, token_ids)
                target_score.backward()

            if self.gradients is None or self.activations is None:
                return None, image

            activations = self.activations.detach().float().cpu().numpy()[0]
            gradients = self.gradients.detach().float().cpu().numpy()[0]

            grid_hw = None
            if patch_size and patch_size > 0:
                grid_hw = (
                    pixel_values.shape[-2] // patch_size,
                    pixel_values.shape[-1] // patch_size,
                )
            return self._compute_cam(activations, gradients, grid_hw), image

        finally:
            self.remove_hooks()
            # Drop the parameter gradients produced by the backward pass.
            self.model.zero_grad()

# Visualization functions
def overlay_cam_on_image(image, cam, alpha=0.6, colormap=cv2.COLORMAP_JET):
    """
    Overlay CAM heatmap on the original image
    Args:
        image: PIL Image
        cam: 2D numpy array representing the CAM; values are expected in [0, 1]
            and anything outside that range is clipped
        alpha: Weight of the heatmap in the blend (the image gets ``1 - alpha``)
        colormap: OpenCV colormap
    Returns:
        uint8 array with the image's height and width holding the blended result
    """
    # Convert PIL image to numpy array
    img_array = np.array(image)
    height, width = img_array.shape[:2]

    # Resize CAM to match image dimensions (cv2 takes the size as (width, height))
    cam_resized = cv2.resize(cam, (width, height))

    # Clip before the uint8 cast: out-of-range values would otherwise wrap around
    # (e.g. 1.2 * 255 -> 50) and be painted with the wrong colour.
    cam_normalized = np.uint8(255 * np.clip(cam_resized, 0.0, 1.0))

    # Apply colormap; OpenCV returns BGR, the image array is RGB
    heatmap = cv2.applyColorMap(cam_normalized, colormap)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Overlay heatmap on original image
    overlayed = img_array * (1 - alpha) + heatmap * alpha
    overlayed = np.uint8(overlayed)

    return overlayed

def visualize_gradcam_results(image_path, gradcam_vit, caption_text=None):
    """
    Run Grad-CAM on an image and show a four-panel summary figure.

    Panels, left to right: original image, raw heat map (with colour bar),
    heat map blended over the image, and the caption produced by
    ``generate_caption``.

    Args:
        image_path: Image location passed to ``gradcam_vit.generate_cam`` and
            ``generate_caption``.
        gradcam_vit: Object providing ``generate_cam(image_path, caption_text)``.
        caption_text: Optional text forwarded to ``generate_cam``.

    Returns:
        ``(cam, overlayed_image, main_caption)``, or ``None`` when no CAM
        could be produced.
    """
    cam, original_image = gradcam_vit.generate_cam(image_path, caption_text)
    if cam is None:
        print("Failed to generate CAM")
        return

    # Caption shown in the last panel (single greedy caption).
    generated = generate_caption(image_path, beam_size=1)
    if generated:
        main_caption = generated[0]
    else:
        main_caption = "No caption generated"

    overlayed_image = overlay_cam_on_image(original_image, cam)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    image_ax, heatmap_ax, overlay_ax, caption_ax = axes

    # Panel 1: original image
    image_ax.imshow(original_image)
    image_ax.set_title('Original Image')
    image_ax.axis('off')

    # Panel 2: raw heat map with its colour scale
    heatmap_artist = heatmap_ax.imshow(cam, cmap='jet')
    heatmap_ax.set_title('Grad-CAM Heatmap')
    heatmap_ax.axis('off')
    plt.colorbar(heatmap_artist, ax=heatmap_ax, fraction=0.046, pad=0.04)

    # Panel 3: heat map blended over the image
    overlay_ax.imshow(overlayed_image)
    overlay_ax.set_title('Grad-CAM Overlay')
    overlay_ax.axis('off')

    # Panel 4: caption text
    caption_ax.axis('off')
    caption_text_display = f"Generated Caption:\n\n{main_caption}"
    caption_box = dict(boxstyle="round,pad=0.5", facecolor="lightblue", alpha=0.8)
    caption_ax.text(
        0.1,
        0.7,
        caption_text_display,
        fontsize=12,
        verticalalignment='top',
        wrap=True,
        bbox=caption_box,
    )
    caption_ax.set_title('Model Prediction')

    plt.tight_layout()
    plt.show()

    return cam, overlayed_image, main_caption

# Cell-completion marker: printed once the definitions in this cell have been executed.
print("Grad-CAM for ViT implementation ready!")

In [None]:
# Advanced Grad-CAM Analysis Functions

def compare_attention_layers(image_path, layer_indices=(-1, -2, -3)):
    """
    Compare Grad-CAM across different transformer layers

    Draws a 2-row grid: the first column holds the original image, every
    further column shows the heat map (top) and its overlay (bottom) for one
    vision encoder layer.

    Args:
        image_path: URL or local path of the image.
        layer_indices: Indices into the vision encoder's layer list
            (negative values count from the end).

    Returns:
        None. Nothing is drawn if the image cannot be loaded.
    """
    original_image = load_image(image_path)
    if original_image is None:
        return

    layer_indices = list(layer_indices)
    n_cols = len(layer_indices) + 1
    # squeeze=False keeps `axes` two-dimensional even when no layer is requested.
    fig, axes = plt.subplots(2, n_cols, figsize=(5 * n_cols, 10), squeeze=False)

    # Show original image
    axes[0, 0].imshow(original_image)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')
    axes[1, 0].axis('off')

    for col, layer_idx in enumerate(layer_indices, start=1):
        heat_ax, overlay_ax = axes[0, col], axes[1, col]
        try:
            # Ask Grad-CAM for this specific layer; BLIP's vision encoder layers
            # expose their attention block as `self_attn`.
            gradcam_layer = GradCAMViT(
                model,
                processor,
                target_layer_name=f'vision_model.encoder.layers.{layer_idx}.self_attn',
            )

            # Generate CAM for this layer
            cam, _ = gradcam_layer.generate_cam(image_path)

            if cam is not None:
                # Show heatmap
                heat_img = heat_ax.imshow(cam, cmap='jet')
                heat_ax.set_title(f'Layer {layer_idx} CAM')
                heat_ax.axis('off')
                plt.colorbar(heat_img, ax=heat_ax, fraction=0.046, pad=0.04)

                # Show overlay
                overlay = overlay_cam_on_image(original_image, cam)
                overlay_ax.imshow(overlay)
                overlay_ax.set_title(f'Layer {layer_idx} Overlay')
                overlay_ax.axis('off')
            else:
                heat_ax.text(0.5, 0.5, 'Failed to\ngenerate CAM',
                             ha='center', va='center', transform=heat_ax.transAxes)
                heat_ax.axis('off')
                overlay_ax.axis('off')

        except Exception as e:
            # One failing layer must not abort the comparison; show the error in its column.
            print(f"Error with layer {layer_idx}: {e}")
            heat_ax.text(0.5, 0.5, f'Error:\n{str(e)[:20]}...',
                         ha='center', va='center', transform=heat_ax.transAxes)
            heat_ax.axis('off')
            overlay_ax.axis('off')

    plt.tight_layout()
    plt.show()

def analyze_attention_patterns(image_path, save_results=False, gradcam=None):
    """
    Detailed analysis of attention patterns

    Prints summary statistics, the five strongest patches and the generated
    captions, then shows a 2x3 figure (image, heat map, overlay, thresholded
    map, histogram, captions).

    Args:
        image_path: URL or local path of the image.
        save_results: When True, also write an "original vs Grad-CAM" PNG named
            ``gradcam_result_<unix time>.png`` to the working directory.
        gradcam: Object providing ``generate_cam(image_path)``. When None, a
            notebook-level ``gradcam_vit`` is used if one exists, otherwise a
            ``GradCAMViT(model, processor)`` is created.

    Returns:
        ``(cam, overlay, captions)``, or ``None`` if no CAM could be produced.
    """
    source = str(image_path)
    # Same URL test as load_image, so relative local paths are labelled by file name.
    if source.startswith(('http://', 'https://')):
        image_label = 'Online Image'
    else:
        image_label = os.path.basename(source)

    print(f"🔬 Detailed Attention Analysis")
    print(f"Image: {image_label}")
    print("-" * 50)

    # Resolve the explainer without requiring a global defined in some other cell.
    if gradcam is None:
        gradcam = globals().get('gradcam_vit')
    if gradcam is None:
        gradcam = GradCAMViT(model, processor)

    # Generate caption and CAM
    captions = generate_caption(image_path, beam_size=5)
    cam, original_image = gradcam.generate_cam(image_path)

    if cam is None:
        print("❌ Failed to generate attention map")
        return

    # Statistics
    print(f"📊 Attention Statistics:")
    print(f"   • Shape: {cam.shape}")
    print(f"   • Min activation: {cam.min():.4f}")
    print(f"   • Max activation: {cam.max():.4f}")
    print(f"   • Mean activation: {cam.mean():.4f}")
    print(f"   • Std activation: {cam.std():.4f}")

    # Find most attended regions
    flat_cam = cam.flatten()
    top_indices = np.argsort(flat_cam)[-5:]  # Top 5 patches, ascending
    print(f"\n🎯 Top 5 Most Attended Patches:")
    for i, idx in enumerate(reversed(top_indices)):
        row, col = divmod(idx, cam.shape[1])
        activation = flat_cam[idx]
        print(f"   {i+1}. Patch ({row}, {col}): {activation:.4f}")

    # Generated captions
    print(f"\n📝 Generated Captions:")
    for i, caption in enumerate(captions, 1):
        print(f"   {i}. {caption}")

    # Visualization
    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    # Original image
    axes[0, 0].imshow(original_image)
    axes[0, 0].set_title('Original Image')
    axes[0, 0].axis('off')

    # CAM heatmap
    im1 = axes[0, 1].imshow(cam, cmap='jet')
    axes[0, 1].set_title('Attention Heatmap')
    axes[0, 1].axis('off')
    plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)

    # Overlay
    overlay = overlay_cam_on_image(original_image, cam)
    axes[0, 2].imshow(overlay)
    axes[0, 2].set_title('Attention Overlay')
    axes[0, 2].axis('off')

    # Thresholded attention (top 25%)
    threshold = np.percentile(cam, 75)
    cam_thresh = np.where(cam >= threshold, cam, 0)
    axes[1, 0].imshow(cam_thresh, cmap='jet')
    axes[1, 0].set_title(f'Top 25% Attention\n(>{threshold:.3f})')
    axes[1, 0].axis('off')

    # Attention distribution histogram
    axes[1, 1].hist(cam.flatten(), bins=50, alpha=0.7, color='blue')
    axes[1, 1].axvline(cam.mean(), color='red', linestyle='--', label=f'Mean: {cam.mean():.3f}')
    axes[1, 1].axvline(threshold, color='orange', linestyle='--', label=f'75th percentile: {threshold:.3f}')
    axes[1, 1].set_title('Attention Distribution')
    axes[1, 1].set_xlabel('Activation Value')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].legend()

    # Captions (first three only, to fit the panel)
    axes[1, 2].axis('off')
    caption_text = "Generated Captions:\n\n"
    for i, caption in enumerate(captions[:3], 1):
        caption_text += f"{i}. {caption}\n\n"

    axes[1, 2].text(0.1, 0.9, caption_text, fontsize=10,
                   verticalalignment='top', transform=axes[1, 2].transAxes,
                   bbox=dict(boxstyle="round,pad=0.5", facecolor="lightgreen", alpha=0.8))
    axes[1, 2].set_title('Model Predictions')

    plt.tight_layout()
    plt.show()

    if save_results:
        # Save the overlay image
        output_path = f"gradcam_result_{int(time.time())}.png"
        save_fig, (original_ax, gradcam_ax) = plt.subplots(1, 2, figsize=(10, 5))
        original_ax.imshow(original_image)
        original_ax.set_title('Original')
        original_ax.axis('off')
        gradcam_ax.imshow(overlay)
        gradcam_ax.set_title('Grad-CAM')
        gradcam_ax.axis('off')
        save_fig.tight_layout()
        save_fig.savefig(output_path, dpi=300, bbox_inches='tight')
        # The figure exists only to be written to disk; close it so it is not leaked.
        plt.close(save_fig)
        print(f"💾 Results saved to: {output_path}")

    return cam, overlay, captions

# Cell-completion marker: printed once the analysis helpers above have been defined.
print("🚀 Advanced analysis functions ready!")