# Simple Multimodal RAG System

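This notebook demonstrates a simple multimodal Retrieval-Augmented Generation (RAG) system that can run in Google Colab. It uses CLIP to retrieve the images most relevant to a text query; the retrieved images can then be used to enhance prompts for image generation. The notebook imports two local helper modules, `simple_multimodal_rag.py` and `download_images.py`, so make sure both are available in the working directory (for example, upload them to the Colab session) before running Step 2.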


## Step 1: Install Dependencies


In [None]:
# Install the third-party packages this notebook needs; `%pip` installs into the running kernel's environment.
# NOTE(review): versions are unpinned — consider pinning them for reproducible runs.
%pip install transformers torch pillow requests numpy matplotlib


## Step 2: Import and Setup


In [None]:
import os
from simple_multimodal_rag import SimpleMultimodalRAG, visualize_retrieved_images
from download_images import create_sample_images_from_urls, create_placeholder_images

# Populate the "images" folder: try the URL-based helper first and, if it
# raises for any reason, fall back to the placeholder helper instead.
print("Setting up images...")
try:
    image_files, metadata = create_sample_images_from_urls("images")
except Exception as err:
    print(f"Download failed: {err}")
    print("Creating placeholder images instead...")
    image_files, metadata = create_placeholder_images("images")

# Summarise what is available, pairing each file with its metadata entry.
print(f"\nReady! Found {len(image_files)} images:")
for image_file, image_meta in zip(image_files, metadata):
    caption = image_meta.get('caption', 'no caption')
    print(f"  - {image_file} ({caption})")


## Step 3: Initialize Multimodal RAG System


In [None]:
# CLIP checkpoint used for embeddings: the base model is lightweight and
# suitable for Colab.
clip_checkpoint = "openai/clip-vit-base-patch32"
rag = SimpleMultimodalRAG(clip_model_name=clip_checkpoint)

# Register the images and their metadata from the setup step, then build
# the embedding index that later queries run against.
rag.load_images(image_files, metadata)
rag.build_index()


## Step 4: Query the System


In [None]:
# Ask the RAG system for the three best matches to a text query.
query = "red geometric shape"
result = rag.query(query, top_k=3)

# Header: echo the query and the number of images returned.
print(f"Query: {result['query']}\n")
print(f"Found {result['num_images']} relevant images:\n")
print("="*60)

# One entry per retrieved image, numbered from 1.
for position, hit in enumerate(result['retrieved_images'], start=1):
    print(f"\nImage {position}:")
    print(f"  Path: {hit['image_path']}")
    print(f"  Similarity: {hit['similarity']:.3f}")
    print(f"  Caption: {hit['metadata'].get('caption', 'N/A')}")


## Step 5: Visualize Retrieved Images


In [None]:
# Retrieve as many images as will be displayed, then plot them.
num_shown = 4
result = rag.query("colorful geometric shapes", top_k=num_shown)
visualize_retrieved_images(result['retrieved_images'], max_display=num_shown)


## Step 6: Enhance Image Generation Prompts


In [None]:
# Base prompt that will be enhanced for image generation.
original_prompt = "a modern abstract design"

# Enhance the prompt once per style, in the same order as before, each
# time drawing on the top two retrieved images.
prompt_styles = ("descriptive", "concise", "detailed")
enhanced_by_style = {
    style_name: rag.enhance_prompt(original_prompt, top_k=2, style=style_name)
    for style_name in prompt_styles
}

# Keep the individual names available for use in later cells.
enhanced_descriptive = enhanced_by_style["descriptive"]
enhanced_concise = enhanced_by_style["concise"]
enhanced_detailed = enhanced_by_style["detailed"]

print("Original prompt:", original_prompt)
for style_name, enhanced_text in enhanced_by_style.items():
    print(f"\nEnhanced ({style_name}):", enhanced_text)

print("\n" + "="*60)
print("You can now use these enhanced prompts with image generation models!")
print("Example: generated_image = image_model.generate(enhanced_descriptive)")


## Step 7: Multiple Queries


In [None]:
# Run several text queries in a row and report only the best match for each.
queries = ["red circle", "blue square", "geometric shapes", "colorful design"]
divider = '=' * 60

for query in queries:
    print(f"\n{divider}")
    print(f"Query: {query}")
    print(divider)

    result = rag.query(query, top_k=2)
    matches = result['retrieved_images']
    if not matches:
        # Nothing retrieved for this query: print the header only.
        continue

    best_match = matches[0]
    print(f"Top result (similarity: {best_match['similarity']:.3f}):")
    print(f"  {best_match['image_path']}")
    print(f"  Caption: {best_match['metadata'].get('caption', 'N/A')}")


## Step 8: Save and Load (Optional)


In [None]:
# Round trip: save the current RAG system to disk and load it into a
# fresh instance.
saved_system_path = "multimodal_rag_system.pkl"
rag.save(saved_system_path)

# NOTE(review): the ".pkl" name suggests pickle serialisation — if so, only
# load files you created yourself. Confirm against SimpleMultimodalRAG.load.
rag2 = SimpleMultimodalRAG()
rag2.load(saved_system_path)

# Check that the reloaded instance can answer a query.
result = rag2.query("red shape", top_k=2)
print(f"Loaded system found {result['num_images']} images for query: {result['query']}")
