In [1]:
import fitz
from PIL import Image
import torch
import numpy as np
import os 
import base64
import io
from transformers import CLIPProcessor, CLIPModel
import pymupdf

from langchain_ollama import ChatOllama
from langchain_core.prompts import PromptTemplate
from langchain_core.messages import HumanMessage
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sklearn.metrics.pairwise import cosine_similarity
from langchain_core.documents import Document

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Clip Model for unified embeddings
clip_model=CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor=CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.
Fetching 1 files: 100%|██████████| 1/1 [00:00<00:00, 4076.10it/s]


In [7]:
### Embedding functions
def embed_image(image_data):
    """Embed image using CLIP"""
    if isinstance(image_data, str):
        image = Image.open(image_data).convert("RGB")
    else:
        image = image_data
        
    inputs = clip_processor(images=image, return_tensors="pt")
    
    with torch.no_grad():
        features = clip_model.get_image_features(**inputs)
        # Normalize embeddings to unit vector
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()
    
def embed_text(text):
    """Embed text using CLIP"""
    inputs = clip_processor(
        text=text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=77 # CLIPS max token length
    )
    
    with torch.no_grad():
        features = clip_model.get_text_features(**inputs)
        # Normalize embeddings
        features = features / features.norm(dim=-1, keepdim=True)
        return features.squeeze().numpy()

In [8]:
## Process PDF
pdf_path="multimodal_sample.pdf"
doc=pymupdf.open(pdf_path)

# Storage for all documents and embeddings
all_docs = []
all_embeddings = []
image_data_store = {}

# Text Splitter
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)

In [9]:
for i, page in enumerate(doc):
    # Process text
    text=page.get_text()
    if text.strip():
        # create a temporary document for splitting
        temp_doc = Document(page_content=text, metadata={'page': i, "type": "text"})
        text_chunks = splitter.split_documents([temp_doc])
        
        # Embed each chunk using CLIP
        for chunk in text_chunks:
            embedding = embed_text(chunk.page_content)
            all_embeddings.append(embedding)
            all_docs.append(chunk)
    
    # Process images
    ## Three Steps:
    ### Convert PDF Image to PIL format
    ### Store as base64 
    ### Create CLIP embedding for retrieval
    
    for img_index, img in enumerate(page.get_images(full=True)):
        try:
            xref=img[0]
            base_image=doc.extract_image(xref)
            image_bytes=base_image['image']
            
            # Convert to PIL mage
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
            
            # Create unique identifier
            image_id = f"page_{i}_img_{img_index}"
            
            # Store image as base64 for later use with Model
            buffered = io.BytesIO()
            pil_image.save(buffered, format="PNG")
            img_base64=base64.b64encode(buffered.getvalue()).decode()
            image_data_store[image_id] = img_base64
            
            # Embed image using CLIP
            embedding = embed_image(pil_image)
            all_embeddings.append(embedding)
            
            # Create document for CLIP
            image_doc = Document(
                page_content=f"[Image: {image_id}]",
                metadata={'page':i, "type": "image", "image_id": image_id}
            )            
            all_docs.append(image_doc)
        except Exception as e:
            print(f"Error processing image {img_index} on page {i}: {e}")
            continue
        
doc.close()            

In [10]:
# Create unified FAISS vectorestore with CLIP embeddings
embeddings_array = np.array(all_embeddings)

# Create custum FAISS index since we have precomputed embeddings
vectore_store = FAISS.from_embeddings(
    text_embeddings=[(doc.page_content, emb) for doc, emb in zip(all_docs, embeddings_array)],
    embedding=None, # Using precomputed embeddings
    metadatas=[doc.metadata for doc in all_docs]
)

`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


In [12]:
llm = ChatOllama(model="gemma3:4b-it-q4_K_M")
llm

ChatOllama(model='gemma3:4b-it-q4_K_M')

In [None]:
# Retrieve
def retrieve_multimodal(query, k=5):
    """Unified retrieval using CLIP embeddings for both text and images"""
    # Embed query using CLIP
    query_embedding = embed_text(query)
    
    # Search in unified vectore_store based on the query embeddings
    results = vectore_store.similarity_search_by_vector(
        embedding=query_embedding,
        k=k
    )
    
    return results

In [None]:
# Create Message
def create_multimodal_message(query, retrieved_docs):
    """Create a message with both text and images"""
    content = []
    
    # Add the query
    content.append({
        "type":"text",
        "text": f"Question: {query}\n\nContext:\n"
    })
    
    # Seperate text and image documents
    text_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "text"]
    image_docs = [doc for doc in retrieved_docs if doc.metadata.get("type") == "image"]
    
    if text_docs:
        for doc in text_docs:
            text_content = "\n\n".join([
                f"[Page {doc.metadata['page']}]: {doc.page_content}"
            ])
            content.append({
                "type":"text",
                "text": f"Text excerpts:\n{text_content}\n"
            })
        
    for doc in image_docs:
        image_id = doc.metadata.get("image_id")    
        if image_id and image_id in image_data_store:
            content.append({
                "type":"text",
                "text":f"\n[Image from page {doc.metadata['page']}]:\n"
            })
            content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{image_data_store[image_id]}"
                }
            })
            
    # Add instruction
    content.append({
        "type": "text",
        "text": "\n\nPlease answer the question based on the provided text and images."
    })
    
    return HumanMessage(content=content)

In [27]:
# RAG Pipeline
def multimodal_pdf_rag_pipeline(query):
    """Main pipeline for multimodal RAG"""
    
    # Retrieve relevant documents
    context_docs = retrieve_multimodal(query, k=5)
    
    # Create multimodal message
    message = create_multimodal_message(query, context_docs)
    
    # Get response from LLM
    response = llm.invoke([message])
    
    # Print retrieved context info
    print(f"\nRetrieved {len(context_docs)} documents")
    for doc in context_docs:
        doc_type = doc.metadata.get("type", "unknown")
        page = doc.metadata.get("page", "?")
        if doc_type == "text":
            preview = doc.page_content[:100] + "..." if len(doc.page_content) > 100 else doc.page_content
            print(f"  - Text from page {page}: {preview}")
        else:
            print(f"  - Image from page {page}")
    print("\n")
    
    return response.content
    

In [29]:
queries = [
    "What does the chart on page 1 show about revenue trends?",
    "Summarize the main findigs from the document",
    "What visual elements are present in the codument?"
]

for query in queries:
    print(f"\nQuery: {query}")
    print("=" * 50)
    answer = multimodal_pdf_rag_pipeline([query])
    print(f"Answer: {answer}")
    print("=" * 50)


Query: What does the chart on page 1 show about revenue trends?

Retrieved 2 documents
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0


Answer: Here's the answer to your question, based on the text and image:

The chart shows that revenue grew steadily, with the highest growth occurring in Q3. Q1 showed a moderate increase due to new product lines, Q2 outperformed Q1 due to marketing campaigns, and Q3 had exponential growth due to global expansion.

Query: Summarize the main findigs from the document

Retrieved 2 documents
  - Text from page 0: Annual Revenue Overview
This document summarizes the revenue trends across Q1, Q2, and Q3. As illust...
  - Image from page 0


Answer: Here’s a summary of the findings from the document:

The document reveals a positive trend in revenue growth over the first three quarters (Q1, Q2, and Q3). 

*   **Q1:** Revenue increased moderately due to the 