<a href="https://colab.research.google.com/github/bhatsbharath/generative_ai_agents/blob/main/multi_modal_rag.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi Modal RAG

This notebook extracts text and images from a PDF, summarizes each image with a vision-language model, and indexes both the page text and the image summaries in a multi-modal Retrieval-Augmented Generation (RAG) pipeline that retrieves relevant content to answer questions.

In [None]:
import fitz  # PyMuPDF
from PIL import Image
import io
import os
from io import BytesIO

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.documents import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
# Accumulators filled by later cells: per-page text and per-image summaries,
# each entry a {"response": ..., "name": ...} dict.
text_data, img_data = [], []

Opens the PDF file `DETR.pdf`.

Creates a directory named `extracted_images` (if it does not already exist) to store the extracted images.

Iterates through each page in the PDF to:

Extract and store the page text in text_data.

Find and extract all embedded images.

Save each image using its original format and a unique filename based on page and image index.

In [None]:
# Reset the accumulator so re-running this cell does not duplicate pages.
text_data.clear()

# Directory for the extracted images (no-op if it already exists).
os.makedirs("extracted_images", exist_ok=True)

with fitz.open('DETR.pdf') as pdf_file:
    # Walk every page; page numbers are recorded 1-based.
    for page_number, page in enumerate(pdf_file, start=1):
        # Store the page's plain text under its page number.
        text = page.get_text().strip()
        text_data.append({"response": text, "name": page_number})

        # Walk every image referenced on the page (image indices are 1-based too).
        for image_index, img in enumerate(page.get_images(full=True), start=1):
            xref = img[0]  # XREF of the image object in the PDF
            base_image = pdf_file.extract_image(xref)
            image_bytes = base_image["image"]  # encoded image bytes as stored in the PDF
            image_ext = base_image["ext"]      # file extension of that encoding

            # Write the bytes verbatim instead of decoding and re-saving through PIL:
            # re-saving re-encodes JPEGs (quality loss) and fails for formats PIL
            # cannot write, whereas the raw bytes already are the original format.
            image_path = os.path.join(
                "extracted_images", f"image_{page_number}_{image_index}.{image_ext}"
            )
            with open(image_path, "wb") as out_file:
                out_file.write(image_bytes)

In [None]:
from langchain_community.chat_models import ChatOllama
from langchain_community.llms import Ollama

# Vision-capable chat model used to summarize the extracted images.
# An alternative is the plain Ollama wrapper with model="llama3.2-vision".
model = ChatOllama(
    model="llava",
    temperature=0.2,  # low temperature keeps the summaries focused
)

In [None]:
def pil_to_base64(image):
    """Encode a PIL image as a base64 PNG string (without a ``data:`` URL prefix).

    PNG cannot store every PIL mode (for example CMYK, which is common for
    JPEGs embedded in PDFs), so images in other modes are converted to RGB
    before encoding instead of raising on save.

    Args:
        image: A PIL ``Image`` (anything exposing ``mode``, ``convert`` and ``save``).

    Returns:
        The base64-encoded PNG bytes as an ASCII ``str``.
    """
    # Local import: this helper is defined in a cell that runs before the
    # cell that imports base64 at notebook level.
    import base64

    # Modes Pillow's PNG writer can store directly; anything else would fail on save.
    png_modes = {"1", "L", "LA", "I", "I;16", "I;16B", "P", "RGB", "RGBA"}
    if image.mode not in png_modes:
        image = image.convert("RGB")

    buffered = BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()

Iterates through all images in the extracted_images directory.

Converts each image to a base64-encoded string.

Constructs a HumanMessage containing both the image and a summarization prompt.

Sends the message to the model for processing.

Stores the model's summarized response for each image in the img_data list, along with the image name.

In [None]:
import base64
from langchain_core.messages import HumanMessage

from PIL import UnidentifiedImageError

# Summarization prompt sent with every image; it never changes, so build it once.
summary_prompt = (
    "You are an assistant tasked with summarizing tables, images and text for retrieval. "
    "These summaries will be embedded and used to retrieve the raw text or table elements. "
    "Give a concise summary of the table or text that is well optimized for retrieval. "
    "Table or text or image:"
)

# Reset so re-running this cell does not append duplicate summaries.
img_data.clear()

# os.listdir returns entries in arbitrary order; sort so img_data is built deterministically.
for img in sorted(os.listdir("extracted_images")):
    image_path = os.path.join("extracted_images", img)
    if not os.path.isfile(image_path):
        continue  # skip sub-directories (e.g. notebook checkpoint folders)

    try:
        image = Image.open(image_path)
    except UnidentifiedImageError:
        print(f"Skipping {img}: not a readable image")
        continue

    # Close the underlying file handle as soon as the image has been encoded.
    with image:
        image_base64 = pil_to_base64(image)

    # LangChain HumanMessage carrying the image (as a PNG data URL) plus the text prompt.
    msg = HumanMessage(
        content=[
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_base64}"}},
            {"type": "text", "text": summary_prompt},
        ]
    )

    response = model.invoke([msg])
    img_data.append({"response": response.content, "name": img})

Initializes an embedding model (nomic-embed-text) using OllamaEmbeddings.

Converts previously extracted text (text_data) and image summaries (img_data) into Document objects with associated metadata.

Uses a RecursiveCharacterTextSplitter to divide the documents into manageable chunks (400 tokens with 50-token overlap) for better embedding and retrieval performance.

Splits both text and image-based documents into chunks: doc_splits and img_splits.

In [None]:
from langchain_community.embeddings import OllamaEmbeddings

embedding_model = OllamaEmbeddings(model="nomic-embed-text")


def _records_to_documents(records):
    """Wrap {"response", "name"} records as Documents, keeping "name" as metadata."""
    return [
        Document(page_content=record["response"], metadata={"name": record["name"]})
        for record in records
    ]


# Page texts and image summaries become two separate Document lists.
docs_list = _records_to_documents(text_data)
img_list = _records_to_documents(img_data)

# Chunk by tiktoken token count: up to 400 tokens per chunk with 50 tokens of overlap.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=400, chunk_overlap=50)
doc_splits, img_splits = (text_splitter.split_documents(group) for group in (docs_list, img_list))

In [None]:
# Add to vectorstore
vectorstore = Chroma.from_documents(
    documents=doc_splits + img_splits, # adding the both text and image splits
    collection_name="multi_model_rag",
    embedding=embedding_model,
)

retriever = vectorstore.as_retriever(
                search_type="similarity",
                search_kwargs={'k': 1}, # number of documents to retrieve
            )

In [None]:
# Question answered by the retrieval + generation steps below.
query = "Does image contain giraffe?"

In [None]:
# Fetch the most similar chunk(s) for the query from the vector store.
docs = retriever.invoke(query)


Sets up a prompt and language model for concise question answering.

Builds a chain combining prompt, LLM, and output parser.

Runs the chain with retrieved documents and user query to generate an answer.

In [None]:
from langchain_core.output_parsers import StrOutputParser

# Prompt: instruct the model to ground its answer in the retrieved documents
# (a RAG answer "based upon your knowledge" would ignore the retrieval step).
system = """You are an assistant for question-answering tasks. Answer the question based upon the retrieved documents.
If they do not contain the answer, say that you don't know. Use three-to-five sentences maximum and keep the answer concise."""
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        ("human", "Retrieved documents: \n\n <docs>{documents}</docs> \n\n User question: <question>{question}</question>"),
    ]
)

# LLM
llm = ChatOllama(model="llama3.2")

# Chain: prompt -> chat model -> plain string
rag_chain = prompt | llm | StrOutputParser()

# Run. Join every retrieved chunk instead of indexing docs[0], which raises
# IndexError on an empty result and drops extra chunks when k > 1.
documents_text = "\n\n".join(doc.page_content for doc in docs)
generation = rag_chain.invoke({"documents": documents_text, "question": query})
print(generation)