# Multimodal RAG with LangChain
This notebook demonstrates how to build a multimodal Retrieval-Augmented Generation (RAG) system using LangChain. The system extracts text, tables, and images from a PDF document, summarizes each element, stores the summaries in a vector store, and answers questions based on the retrieved context.

## 1. Setup
### 1.1. Install Dependencies
First, we install all the necessary Python packages.

In [None]:
%pip install -U "unstructured[all-docs]" pillow lxml chromadb tiktoken langchain langchain-community langchain-openai langchain-groq langchain-google-genai python-dotenv langchain-openrouter langchain-huggingface sentence-transformers InstructorEmbedding

### 1.2. Imports
Import all the required libraries and modules.

In [None]:
# Standard library
import os, re, time, base64, uuid

# Environment and notebook display
from dotenv import load_dotenv
from IPython.display import Image, display

# PDF partitioning
from unstructured.partition.pdf import partition_pdf

# LangChain core building blocks
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage
from langchain_core.documents import Document
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.stores import InMemoryStore

# Vector store, embeddings, and retriever
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

# Chat models (ChatGoogleGenerativeAI is used in section 3.2)
from langchain_groq import ChatGroq
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI

# Rate-limit exceptions raised by the providers
from google.api_core.exceptions import ResourceExhausted
from groq import RateLimitError

### 1.3. Load Environment Variables
Load API keys and other secrets from a `.env` file.

In [None]:
load_dotenv()  # take environment variables from .env file
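
Missing keys otherwise only surface later as authentication errors, so a quick check right after loading can save a failed run. The following is a minimal sketch, not part of the original notebook: the helper `find_missing_keys` is hypothetical, and the list of names is an assumption (`GOOGLE_API_KEY` and `OPENROUTER_API_KEY` are read explicitly later in this notebook, while `GROQ_API_KEY` is the variable `ChatGroq` looks up by default).

```python
import os

# Assumed variable names; adjust to match your own .env file.
REQUIRED_KEYS = ("GROQ_API_KEY", "GOOGLE_API_KEY", "OPENROUTER_API_KEY")

def find_missing_keys(required=REQUIRED_KEYS, env=None):
    """Return the names in `required` that are unset or empty in `env`."""
    env = os.environ if env is None else env
    return [name for name in required if not env.get(name)]

missing = find_missing_keys()
if missing:
    print("Missing environment variables:", ", ".join(missing))
else:
    print("All API keys found.")
```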

### 1.4. Utility Functions
A helper function that decodes a base64-encoded image and displays it in the notebook.

In [None]:
def display_base64_image(base64_code):
    image_data = base64.b64decode(base64_code)
    display(Image(data=image_data))
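
Later cells embed these base64 strings in `data:image/jpeg;base64,...` URLs, which assumes every extracted image is a JPEG. If that assumption is in doubt, the format can be checked from the first decoded bytes. This is an illustrative sketch only; `sniff_image_mime` is a hypothetical helper that is not used elsewhere in the notebook, and it recognizes just three common formats.

```python
import base64

def sniff_image_mime(base64_code):
    """Guess the MIME type of a base64-encoded image from its magic bytes."""
    # 16 base64 characters decode to 12 bytes, enough for these signatures.
    header = base64.b64decode(base64_code[:16])
    if header.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    if header[:6] in (b"GIF87a", b"GIF89a"):
        return "image/gif"
    return None  # Unknown or not an image
```

For example, `sniff_image_mime(images[0])` could be used to choose the prefix of the data URL instead of hardcoding `image/jpeg`.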

## 2. Data Extraction and Preparation
### 2.1. Partition PDF
We use `unstructured` to partition the PDF file into chunks of text and to extract images and tables. The `hi_res` strategy is required for table-structure inference.

In [None]:
file_path = "./content/attention.pdf"     # Path to your PDF file

chunks = partition_pdf(
    filename=file_path,
    infer_table_structure=True,            # Extract tables
    strategy="hi_res",                     # Mandatory to infer tables
    extract_image_block_types=["Image", "Table"],   # Extract figures and tables as image blocks
    extract_image_block_to_payload=True,   # Extract base64 for API usage
    chunking_strategy="by_title",          # Chunking strategy
    max_characters=10000,                  # Max characters per chunk
    combine_text_under_n_chars=2000,       # Combine small text chunks
    new_after_n_chars=6000,                # Start new chunk after n chars
)

### 2.2. Separate Elements
We separate the extracted elements into three categories: texts, tables, and images.

In [None]:
tables = []
texts = []
images = []

for chunk in chunks:
    elements = chunk.metadata.orig_elements
    for element in elements:
        if "Table" in str(type(element)):
            tables.append(element)
        elif "Image" in str(type(element)):
            images.append(element.metadata.image_base64)

    if "CompositeElement" in str(type(chunk)):
        texts.append(chunk)

print(f"Found {len(texts)} text chunks, {len(tables)} tables, and {len(images)} images.")
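
The substring test on `str(type(element))` matches any class whose name merely contains the word, which can misclassify elements. The sketch below illustrates the difference with stand-in classes defined only for this example (`TableOfContents` is hypothetical); in the notebook itself, the real element classes could be imported from `unstructured.documents.elements` and checked with `isinstance` instead.

```python
# Stand-in classes for illustration; they are not the unstructured classes.
class Table:
    pass

class TableOfContents:  # hypothetical class whose name contains "Table"
    pass

def matches_by_name(element, name):
    """The check used above: substring match on the type's string form."""
    return name in str(type(element))

def matches_by_type(element, cls):
    """A stricter check: true only for instances of `cls` or its subclasses."""
    return isinstance(element, cls)

toc = TableOfContents()
print(matches_by_name(toc, "Table"))   # substring check accepts it
print(matches_by_type(toc, Table))     # isinstance check rejects it
```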

## 3. Content Summarization
To handle large documents and diverse content types, we summarize the extracted text, tables, and images. These summaries will be used for retrieval.

### 3.1. Text and Table Summarization
We use a language model from Groq to create concise summaries of the text chunks and tables.

In [None]:
summary_prompt_template = """
You are an assistant tasked with summarizing tables and text.
Give a concise summary of the table or text.
Respond only with the summary, no additional comment.
Do not start your message by saying "Here is a summary" or anything like that.
Just give the summary as it is.

Table or text chunk: {element}
"""
summary_prompt = ChatPromptTemplate.from_template(summary_prompt_template)

summary_model = ChatGroq(temperature=0.5, model="llama-3.1-8b-instant")

summarize_chain = {"element": lambda x: x} | summary_prompt | summary_model | StrOutputParser()

tables_html = [table.metadata.text_as_html for table in tables]

def summarize_all():
    # Summarize text chunks and tables sequentially to stay under rate limits.
    return (
        summarize_chain.batch(texts, {"max_concurrency": 1}),
        summarize_chain.batch(tables_html, {"max_concurrency": 1}),
    )

try:
    text_summaries, table_summaries = summarize_all()
except RateLimitError as e:
    # Retry once after a fixed wait; a second failure propagates.
    print(f"Rate limit exceeded: {e}. Waiting for 5 minutes before retrying.")
    time.sleep(300)
    text_summaries, table_summaries = summarize_all()

print(f"Summarized {len(text_summaries)} text chunks and {len(table_summaries)} tables.")
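
The cell above retries once after a fixed five-minute wait. A more general pattern is a bounded retry with exponential backoff, sketched below. `call_with_retry` is a hypothetical helper, not a LangChain or Groq API, and the delay values are arbitrary assumptions.

```python
import time

def call_with_retry(func, retry_on, max_attempts=3, base_delay=1.0, sleep=time.sleep):
    """Call `func()`; on `retry_on` exceptions, wait and retry with doubling delays."""
    for attempt in range(1, max_attempts + 1):
        try:
            return func()
        except retry_on:
            if attempt == max_attempts:
                raise  # Out of attempts: surface the last error
            sleep(base_delay * 2 ** (attempt - 1))

# Possible usage in this notebook (assumes the names defined above):
# text_summaries = call_with_retry(
#     lambda: summarize_chain.batch(texts, {"max_concurrency": 1}),
#     RateLimitError,
#     base_delay=60,
# )
```

The `sleep` parameter exists only so the waiting can be replaced in tests.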

### 3.2. Image Summarization
For images, we use a multimodal model (Gemini) to generate detailed descriptions. The prompt includes context that the images are from the "Attention Is All You Need" paper.

In [None]:
image_prompt_template = """Describe the image in detail. For context,
the image is part of the Google research paper "Attention Is All You Need".
Be specific about graphs, such as bar plots."""

image_messages = [
    ("user",[
            {"type": "text", "text": image_prompt_template},
            {"type": "image_url","image_url": {"url": "data:image/jpeg;base64,{image}"}},
        ],
    )
]

image_prompt = ChatPromptTemplate.from_messages(image_messages)

image_llm = ChatGoogleGenerativeAI(
    model="gemini-1.5-flash",
    temperature=0.5,
    google_api_key=os.getenv("GOOGLE_API_KEY"),
)

image_chain = image_prompt | image_llm | StrOutputParser()
# Retry loop to cope with free-tier rate limits on the LLM API.
max_attempts = 5
for attempt in range(1, max_attempts + 1):
    try:
        image_summaries = image_chain.batch(images, {"max_concurrency": 1})
        break
    except ResourceExhausted as e:
        # Use the wait time suggested in the error message, if present.
        msg = str(e)
        match = re.search(r'Please try again in ([\d.]+)([sm])', msg)
        wait_seconds = 60
        if match:
            value, unit = match.groups()
            wait_seconds = float(value) * 60 if unit == 'm' else float(value)
        print(f"⏳ Rate limit exceeded. Waiting {wait_seconds:.2f} seconds...")
        time.sleep(wait_seconds)
    except Exception as e:
        print(f"⚠️ Unexpected error: {e}")
        time.sleep(5)
else:
    # Every attempt failed; stop instead of looping forever.
    raise RuntimeError(f"Image summarization failed after {max_attempts} attempts.")

print(f"Summarized {len(image_summaries)} images.")
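
The cell above extracts a wait time from the error message with a regular expression. The same idea can be isolated in a small function, sketched here. `parse_retry_delay` is a hypothetical helper, and the message format (`Please try again in <number><unit>`) is an assumption carried over from that pattern; real provider messages may differ. Unlike the inline pattern, this version checks for `ms` before `m` so that milliseconds are not read as minutes.

```python
import re

# Assumed message format; "ms" must be tried before "m" and "s".
_RETRY_HINT = re.compile(r"Please try again in ([\d.]+)(ms|s|m)")
_SECONDS_PER_UNIT = {"ms": 0.001, "s": 1.0, "m": 60.0}

def parse_retry_delay(message, default=60.0):
    """Return the suggested wait in seconds, or `default` if none is found."""
    match = _RETRY_HINT.search(message)
    if not match:
        return default
    value, unit = match.groups()
    try:
        return float(value) * _SECONDS_PER_UNIT[unit]
    except ValueError:
        return default  # e.g. a malformed number such as "1.2.3"
```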

## 4. Vector Store and Retriever Setup
### 4.1. Initialize Embeddings and Vector Store
We use a multi-vector retriever strategy. The summaries are stored in a Chroma vector store, while the original, larger chunks (parent documents) are stored in an in-memory store.

In [None]:
embeddings = HuggingFaceBgeEmbeddings(
    model_name="BAAI/bge-large-en-v1.5",
    encode_kwargs={"normalize_embeddings": True}
)

vectorstore = Chroma(
    collection_name="multimodal_rag",
    embedding_function=embeddings
)

store = InMemoryStore()
id_key = "doc_id"

retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)
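
To make the multi-vector idea concrete, here is a toy sketch in plain Python. It is not how `MultiVectorRetriever` is implemented: word overlap stands in for embedding similarity, and `ToyMultiVectorIndex` is a hypothetical class written only for this illustration. The point it shows is that the search runs over summaries while the shared ID leads back to the original content.

```python
import uuid

def tokenize(text):
    """Lowercase whitespace tokenization; a crude stand-in for embeddings."""
    return set(text.lower().split())

class ToyMultiVectorIndex:
    def __init__(self):
        self.summaries = []   # (doc_id, summary tokens): the searchable side
        self.docstore = {}    # doc_id -> original content: the returned side

    def add(self, summary, original):
        doc_id = str(uuid.uuid4())
        self.summaries.append((doc_id, tokenize(summary)))
        self.docstore[doc_id] = original
        return doc_id

    def retrieve(self, query, k=1):
        query_tokens = tokenize(query)
        # Rank summaries by word overlap with the query, highest first.
        ranked = sorted(self.summaries, key=lambda item: len(query_tokens & item[1]), reverse=True)
        return [self.docstore[doc_id] for doc_id, _ in ranked[:k]]

index = ToyMultiVectorIndex()
index.add("Describes scaled dot-product attention", "<full text of the attention section>")
index.add("Table of BLEU scores on translation benchmarks", "<table html>")
print(index.retrieve("how does attention work"))
```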

### 4.2. Add Documents to Retriever
We add the text, table, and image summaries to the vector store, and the original documents to the document store.

In [None]:
# Add texts
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [Document(page_content=summary, metadata={id_key: doc_ids[i]}) for i, summary in enumerate(text_summaries)]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add tables
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [Document(page_content=summary, metadata={id_key: table_ids[i]}) for i, summary in enumerate(table_summaries)]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))

# Add image summaries
img_ids = [str(uuid.uuid4()) for _ in images]
summary_img = [Document(page_content=summary, metadata={id_key: img_ids[i]}) for i, summary in enumerate(image_summaries)]
retriever.vectorstore.add_documents(summary_img)
retriever.docstore.mset(list(zip(img_ids, images)))

print(f"Total documents in vector store: {vectorstore._collection.count()}")
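
The three blocks above rely on each list of summaries having the same length and order as its originals. If a summarization batch returned fewer items, summaries and originals would be paired incorrectly or left without a counterpart. A guard like the following sketch makes that assumption explicit; `pair_with_ids` is a hypothetical helper, not part of LangChain.

```python
import uuid

def pair_with_ids(originals, summaries):
    """Give each original a fresh ID and pair that ID with its summary and original."""
    if len(originals) != len(summaries):
        raise ValueError(f"{len(originals)} originals but {len(summaries)} summaries")
    ids = [str(uuid.uuid4()) for _ in originals]
    # First list feeds the vector store, second list feeds the docstore.
    return list(zip(ids, summaries)), list(zip(ids, originals))
```

The first returned list could be turned into `Document` objects for the vector store, and the second passed to `retriever.docstore.mset`.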

## 5. Building the RAG Pipeline
We build the final RAG chain. The retriever fetches relevant document summaries, and the corresponding full documents (text or images) are passed to the final model to generate an answer.

In [None]:
def parse_docs(docs):
    b64_images = []
    text_docs = []
    for doc in docs:
        try:
            # Attempt to decode to check if it's a base64 string
            base64.b64decode(doc, validate=True)
            b64_images.append(doc)
        except Exception:
            text_docs.append(doc)
    return {"images": b64_images, "texts": text_docs}

def build_prompt(kwargs):
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # The docstore holds unstructured elements, which expose their content as `.text`.
    context_text = "\n".join(doc.text for doc in docs_by_type["texts"])

    prompt_template = f"""
    You are an expert assistant. Answer the user's question **only** based on the given context.

    Context:
    {context_text}

    Question:
    {user_question}

    Provide a clear and concise answer.
    """

    prompt_content = [{"type": "text", "text": prompt_template}]
    for image in docs_by_type["images"]:
        prompt_content.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image}"}})
    
    return [HumanMessage(content=prompt_content)]

rag_model = ChatOpenAI(
    model="qwen/qwen2.5-vl-32b-instruct:free",
    base_url="https://openrouter.ai/api/v1",
    api_key=os.getenv("OPENROUTER_API_KEY"),
    temperature=0.2,
    max_tokens=800
)

chain_with_sources = (
    {"context": retriever | RunnableLambda(parse_docs), "question": RunnablePassthrough()}
    | RunnablePassthrough().assign(
        response=(RunnableLambda(build_prompt) | rag_model | StrOutputParser())
    )
)
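
`parse_docs` treats anything that decodes as base64 as an image. That works here because the text entries in the docstore are element objects rather than strings, but the heuristic on its own is loose, as this sketch shows. `looks_like_base64` is a hypothetical helper that isolates the same check.

```python
import base64

def looks_like_base64(value):
    """Return True if `value` is a string that strictly decodes as base64."""
    if not isinstance(value, str):
        return False
    try:
        base64.b64decode(value, validate=True)
        return True
    except Exception:
        return False

print(looks_like_base64("hello world"))  # False: the space is not a base64 character
print(looks_like_base64("abcd"))         # True: a short plain word can be valid base64
```

If plain strings were ever stored alongside images, tagging each stored item with its type would be more reliable than decoding.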

## 6. Running the RAG Pipeline
Now we can ask our RAG system questions. The system retrieves relevant context (both text and images) and generates an answer.

In [None]:
question = "What is the attention mechanism? Can you show me an image of it?"
response = chain_with_sources.invoke(question)

print("Response:", response['response'])

print("\n\nContext:")
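for text in response['context']['texts']:
    # Retrieved items are unstructured elements: content is in `.text`,
    # and metadata fields are attributes rather than dictionary keys.
    print(text.text)
    page_number = getattr(text.metadata, "page_number", None)
    if page_number is not None:
        print("Page number: ", page_number)
    print("\n" + "-" * 50 + "\n")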

for image in response['context']['images']:
    display_base64_image(image)