# BLIP-2 Image Captioning Experiments

In [1]:
import torch
from PIL import Image
from transformers.models.blip_2 import Blip2Processor, Blip2ForConditionalGeneration, Blip2Config
import matplotlib.pyplot as plt
import requests
from io import BytesIO

## Load the Pretrained BLIP-2 Model and Processor

TODO: This notebook is a work in progress (WIP).

In [None]:
# Load the model and processor
model_name = "Salesforce/blip2-opt-2.7b"
processor = Blip2Processor.from_pretrained(model_name)
model = Blip2ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",  # This handles CUDA placement automatically
    torch_dtype=torch.float16,  # Ensure consistent dtype
)

# Later cells move their inputs to `device`, but the notebook never defined it.
# Derive it from the loaded model so inputs land where the model lives.
device = model.device

model.safetensors.index.json:   0%|          | 0.00/122k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/10.0G [00:00<?, ?B/s]

## Function to Generate Captions

In [None]:
def generate_caption(image_path_or_url, prompt="", is_url=False):
    """
    Generate a caption for an image using BLIP-2.

    The image is displayed inline with matplotlib before it is captioned.

    Args:
        image_path_or_url: Path to local image or URL to image
        prompt: Optional prompt to guide caption generation
        is_url: Boolean indicating if the image_path_or_url is a URL

    Returns:
        Generated caption. On any failure no exception is raised; instead a
        string starting with "Error generating caption: " is returned.
    """
    try:
        # Load the image
        if is_url:
            # A timeout keeps a stalled download from hanging the notebook, and
            # raise_for_status reports HTTP errors (404, 500, ...) directly
            # instead of handing an error page to PIL.
            response = requests.get(image_path_or_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_path_or_url).convert('RGB')

        # Display the image
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        plt.axis('off')
        plt.show()

        # Only pass `text` when a prompt was supplied, so the no-prompt call
        # is the same as before.
        processor_kwargs = {"images": image, "return_tensors": "pt"}
        if prompt:
            processor_kwargs["text"] = prompt

        # Move the inputs to the model's own device and dtype. The previous
        # code referenced a global `device` that was never defined, so every
        # call ended in the error branch.
        inputs = processor(**processor_kwargs).to(model.device, model.dtype)

        # Generate caption
        generated_ids = model.generate(
            **inputs,
            max_length=50,
            num_beams=5,
            early_stopping=True
        )

        # Decode the generated caption
        caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        return caption

    except Exception as e:
        # Notebook-level boundary: callers such as batch_caption_images expect
        # a string back, so report the failure instead of raising.
        return f"Error generating caption: {str(e)}"

## Example Usage with Sample Images

In [None]:
# Example 1: fetch the BLIP demo image by URL and caption it without a prompt
image_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
caption = generate_caption(image_path_or_url=image_url, is_url=True)
print(f"Generated caption: {caption}")

In [None]:
# Example 2: same demo image, this time steering the caption with a text prompt
image_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
prompt = "a photo of"
caption = generate_caption(
    image_path_or_url=image_url,
    prompt=prompt,
    is_url=True,
)
print(f"Generated caption with prompt '{prompt}': {caption}")

## Function to Caption Images from Local Files

In [None]:
def caption_local_image(file_path, prompt=""):
    """
    Caption an image stored on the local filesystem.

    Convenience wrapper that forwards to generate_caption with URL handling
    switched off.

    Args:
        file_path: Path to local image file
        prompt: Optional prompt to guide caption generation

    Returns:
        Whatever generate_caption returns for this file
    """
    local_caption = generate_caption(file_path, prompt=prompt, is_url=False)
    return local_caption

In [None]:
# local_image_path = "path/to/your/image.jpg"
# caption = caption_local_image(local_image_path)
# print(f"Generated caption for local image: {caption}")

## Function to Process Multiple Images

In [None]:
def batch_caption_images(image_paths_or_urls, is_urls=True, prompt=""):
    """
    Caption every image in a collection, one generate_caption call per item.

    Each result is printed as it is produced, followed by a separator line.

    Args:
        image_paths_or_urls: List of image paths or URLs
        is_urls: Boolean indicating if the inputs are URLs
        prompt: Optional prompt to guide caption generation

    Returns:
        Dictionary mapping each image path/URL to its caption
    """
    separator = "-" * 50
    captions_by_source = {}

    for source in image_paths_or_urls:
        text = generate_caption(source, prompt=prompt, is_url=is_urls)
        captions_by_source[source] = text

        # Report progress as each image finishes
        print(f"Image: {source}")
        print(f"Caption: {text}")
        print(separator)

    return captions_by_source

In [None]:
# Example batch processing of URLs
# image_urls = [
#     "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg",
#     "https://example.com/another-image.jpg"
# ]
# batch_results = batch_caption_images(image_urls)

## Advanced: Customize Generation Parameters

In [None]:
def generate_detailed_caption(image_path_or_url, is_url=False, 
                              num_beams=5, max_length=75, 
                              prompt="Describe this image in detail:"):
    """
    Generate a more detailed caption with custom parameters

    The image is displayed inline with matplotlib before it is captioned.
    Generation runs with sampling enabled (do_sample=True), so repeated calls
    on the same image may return different captions.

    Args:
        image_path_or_url: Path to local image or URL to image
        is_url: Boolean indicating if the image_path_or_url is a URL
        num_beams: Number of beams for beam search
        max_length: Maximum length of generated caption
        prompt: Prompt to guide caption generation

    Returns:
        Generated detailed caption. On any failure no exception is raised;
        instead a string starting with "Error generating detailed caption: "
        is returned.
    """
    try:
        # Load the image
        if is_url:
            # A timeout keeps a stalled download from hanging the notebook, and
            # raise_for_status reports HTTP errors directly instead of handing
            # an error page to PIL.
            response = requests.get(image_path_or_url, timeout=30)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            image = Image.open(image_path_or_url).convert('RGB')

        # Display the image
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        plt.axis('off')
        plt.show()

        # Process the image with the detailed prompt, then move the inputs to
        # the model's own device and dtype. The previous code referenced a
        # global `device` that was never defined, so every call failed.
        inputs = processor(images=image, text=prompt, return_tensors="pt").to(
            model.device, model.dtype
        )

        # Generate detailed caption with custom parameters
        generated_ids = model.generate(
            **inputs,
            max_length=max_length,
            num_beams=num_beams,
            early_stopping=True,
            do_sample=True,
            top_k=50,
            temperature=0.7
        )

        # Decode the generated caption
        detailed_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

        return detailed_caption

    except Exception as e:
        # Notebook-level boundary: report the failure as a string, not a raise.
        return f"Error generating detailed caption: {str(e)}"

In [None]:
# Example: request a longer description of the BLIP demo image (default prompt)
image_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
detailed_caption = generate_detailed_caption(
    image_path_or_url=image_url,
    is_url=True,
)
print(f"Detailed caption: {detailed_caption}")

## Save Model for Future Use

In [None]:
def save_model_locally(save_dir="./blip2_model"):
    """Write the notebook's processor and model into save_dir.

    The directory is created if it is missing. Any failure is reported with a
    printed message rather than raised.

    Args:
        save_dir: Target directory for the saved processor and model files
    """
    import os

    try:
        # Make sure the target directory exists before writing into it
        os.makedirs(save_dir, exist_ok=True)

        # Processor first, then the model weights, into the same directory
        processor.save_pretrained(save_dir)
        model.save_pretrained(save_dir)

        print(f"Model and processor saved to {save_dir}")
    except Exception as exc:
        print(f"Error saving model: {str(exc)}")


In [None]:
# Uncomment to save the model locally
# save_model_locally()

## Load Local Model

In [None]:
def load_local_model(model_dir="./blip2_model", device=None):
    """Load a locally saved model

    Args:
        model_dir: Directory holding the saved processor and model files
        device: Device to move the loaded model to. When None, "cuda" is used
            if torch reports a CUDA device, otherwise "cpu".

    Returns:
        Tuple (processor, model) on success. If loading fails, the error is
        printed and (None, None) is returned.
    """
    try:
        # The previous code called `.to(device)` on a global `device` that was
        # never defined, so loading always failed. Resolve the target here.
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"

        local_processor = Blip2Processor.from_pretrained(model_dir)
        local_model = Blip2ForConditionalGeneration.from_pretrained(
            model_dir,
            torch_dtype=torch.float16
        ).to(device)

        print("Local model loaded successfully")
        return local_processor, local_model
    except Exception as e:
        print(f"Error loading local model: {str(e)}")
        return None, None