<a href="https://colab.research.google.com/github/jtaru28912/MULTIMODAL-RAG-SYSTEM/blob/main/MultiModalRAG.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multi-modal RAG System

Many documents contain a mixture of content types, including text, tables and images.

Yet, information captured in images is lost in most RAG applications.

With the emergence of multimodal LLMs, like [GPT-4o](https://openai.com/index/hello-gpt-4o/), it is worth considering how to utilize images in RAG Systems:


![](https://i.imgur.com/wcCDT38.gif)


In [1]:
!pip install langchain==0.3.7

Collecting langchain-text-splitters<0.4.0,>=0.3.0 (from langchain==0.3.7)
  Using cached langchain_text_splitters-0.3.11-py3-none-any.whl.metadata (1.8 kB)
Collecting langsmith<0.2.0,>=0.1.17 (from langchain==0.3.7)
  Using cached langsmith-0.1.147-py3-none-any.whl.metadata (14 kB)
Collecting numpy<2.0.0,>=1.26.0 (from langchain==0.3.7)
  Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
INFO: pip is looking at multiple versions of langchain-core to determine which version is compatible with other requirements. This could take a while.
Collecting langchain-core<0.4.0,>=0.3.15 (from langchain==0.3.7)
  Using cached langchain_core-0.3.82-py3-none-any.whl.metadata (3.2 kB)
  Using cached langchain_core-0.3.81-py3-none-any.whl.metadata (3.2 kB)
  Using cached langchain_core-0.3.80-py3-none-any.whl.metadata (3.2 kB)
  Using cached langchain_core-0.3.79-py3-none-any.whl.metadata (3.2 kB)
  Using cached langchain_core-0.3.78-py3-none-any.whl

In [None]:
!pip install langchain-openai==0.2.8

[31mERROR: Operation cancelled by user[0m[31m
[0m^C


In [None]:
!pip install langchain-community

In [None]:
!pip install langchain-chroma==0.1.4

In [None]:
!pip install redis==5.2.0

In [None]:
import nltk

# NLTK resources used by unstructured's text processing:
# - 'punkt' / 'punkt_tab': sentence tokenizer models (punctuation-based sentence splitting).
# - 'averaged_perceptron_tagger' / 'averaged_perceptron_tagger_eng': Part-of-Speech (POS)
#   tagger that assigns grammatical categories (noun, verb, adjective, ...) to words.
# Newer NLTK releases (3.9+) use the '_tab' / '_eng' resource names, while older ones
# use the legacy names, so both are downloaded.
for resource in (
    'punkt',
    'punkt_tab',
    'averaged_perceptron_tagger',
    'averaged_perceptron_tagger_eng',
):
    nltk.download(resource)

In [None]:
!pip install "unstructured[all-docs]"
# for parsing the data

In [None]:
# install OCR dependencies for unstructured image processing
!sudo apt-get install tesseract-ocr
!sudo apt-get install poppler-utils

In [None]:
# extract table from pdf
!pip install htmltabletomd==1.0.0

## Data Loading & Processing

### Partition PDF tables, text, and images
  

In [None]:
!wget https://sgp.fas.org/crs/misc/IF10244.pdf
# website getting - download the pdf from online

In [None]:
# This command removes the 'figures' directory and all its contents recursively and forcefully.
# This is typically done to clean up previous outputs or temporary files.
!rm -rf ./figures

In [None]:
# Extractig images and tables from unstructured.io
from langchain_community.document_loaders import UnstructuredPDFLoader

# PDF downloaded by the wget cell. The path is relative to the working directory
# (/content on Colab), so the notebook also runs outside Colab.
doc = './IF10244.pdf'

# Element-level parse, used only to collect Table elements together with their
# `text_as_html` metadata. Images are not extracted here because the chunked
# parse below extracts the same figures into ./figures.
# takes 1-2 min on Colab
loader = UnstructuredPDFLoader(file_path=doc,
                               strategy='hi_res',
                               extract_images_in_pdf=False,
                               infer_table_structure=True,
                               mode='elements')
data = loader.load()

In [None]:
# Number of elements returned by the element-level parse above.
len(data)

In [None]:
# How many elements of each category were extracted. The 'Table' count is the
# number of tables selected in the next cell.
from collections import Counter

Counter(element.metadata['category'] for element in data)

In [None]:
# Keep only the Table elements; these carry the `text_as_html` metadata used later.
tables = list(filter(lambda element: element.metadata['category'] == 'Table', data))
len(tables)

In [None]:
# Second parse of the same PDF, chunked by section ("by_title") so text arrives in
# retrieval-sized chunks. Extracted images are written to ./figures.
chunked_pdf_options = dict(
    strategy='hi_res',
    extract_images_in_pdf=True,
    infer_table_structure=True,
    chunking_strategy="by_title",      # section-based chunking
    max_characters=4000,               # max size of chunks
    new_after_n_chars=4000,            # preferred size of chunks
    overlap_n_chars=20,                # overlap between chunks
    combine_text_under_n_chars=2000,   # chunks < 2000 chars are combined into a larger chunk
    image_output_dir_path='./figures',
)
loader = UnstructuredPDFLoader(file_path=doc, mode='elements', **chunked_pdf_options)
texts = loader.load()
len(texts)

In [None]:
# Combine the chunked text elements with the tables from the element-level parse.
# NOTE(review): with chunking_strategy="by_title", unstructured may also emit Table
# chunks inside `texts`, in which case tables appear twice here — verify the
# category counts before summarizing.
data = texts + tables
# Only a cell's last expression is auto-displayed, so print the size explicitly.
print(len(data))
data[5]

In [None]:
from IPython.display import HTML, display, Markdown

In [None]:
# Inspect the plain text of one element of the combined list (index 2).
print(data[2].page_content)

In [None]:
# Raw HTML of a table element.
# NOTE(review): assumes data[1] is a Table element (those carry 'text_as_html');
# confirm the index if the parsing/chunking settings change.
print(data[1].metadata['text_as_html'])

In [None]:
# Render the table's HTML as HTML (assumes data[1] is a Table element).
display(HTML(data[1].metadata['text_as_html']))

In [None]:
# Unstructured's plain-text version of a table loses the cell borders, so we can instead use its HTML version (`text_as_html`) directly in prompts (LLMs understand HTML tables well), or even better, convert the HTML tables to Markdown tables as shown below

In [None]:
import htmltabletomd

# Convert the HTML of data[1] (assumed to be a Table element) into a Markdown table.
table_html = data[1].metadata['text_as_html']
md_table = htmltabletomd.convert_table(table_html)
print(md_table)

## Separate Data into Text and Table Elements

In [None]:
# Tables go to `tables`. Every other category (NarrativeText, Title, FigureCaption,
# UncategorizedText, Header, ...) is treated as text and goes to `docs`.
# `element` is used as the loop variable so the PDF path held in `doc` is not
# overwritten; overwriting it would break re-running the loader cells.
docs = []
tables = []

for element in data:
    if element.metadata['category'] == 'Table':
        tables.append(element)
    else:
        docs.append(element)

In [None]:
# Number of text elements, then number of table elements (one per line).
print(len(docs), len(tables), sep='\n')

# CONVERT HTML TABLES TO MARKDOWN

In [None]:
# Replace each table's plain-text content with a Markdown rendering of its HTML,
# which keeps the row/column structure for the summarizer and the LLM.
# Tables without `text_as_html` metadata keep their original text instead of
# raising KeyError.
for table in tables:
    table_html = table.metadata.get('text_as_html')
    if table_html:
        table.page_content = htmltabletomd.convert_table(table_html)
    print(table.page_content)
    print()

In [None]:
# View the extracted images
! ls -l ./figures

In [None]:
from IPython.display import Image

# Show both extracted figures. Only a cell's last expression is displayed
# automatically, so both images are rendered explicitly.
display(Image('./figures/figure-1-1.jpg'), Image('./figures/figure-1-2.jpg'))

In [None]:
### Enter Open AI API Key
# Reuse OPENAI_API_KEY when it is already set in the environment, and prompt
# (without echoing) only otherwise, so "Restart & Run All" does not block on input.
import os
from getpass import getpass

OPENAI_KEY = os.environ.get('OPENAI_API_KEY') or getpass('Enter Open AI API Key: ')

# langchain_openai clients read the key from this environment variable by default.
os.environ['OPENAI_API_KEY'] = OPENAI_KEY


In [None]:
### Load Connection to LLM

# Here we create a GPT-4o chat model client (via ChatOpenAI) to use later in our chains

In [None]:
from langchain_openai import ChatOpenAI

# GPT-4o chat model with deterministic output (temperature=0).
# The API key is passed explicitly; OPENAI_KEY is defined in the API-key cell above.
chatgpt = ChatOpenAI(model='gpt-4o', temperature=0, api_key=OPENAI_KEY)



### Text and Table summaries

We will use GPT-4o to produce table and text summaries.

Text summaries are advised if using large chunk sizes (e.g., as set above, we use chunks of up to 4,000 characters via `max_characters=4000`).

Summaries are used to retrieve raw tables and / or raw chunks of text.

In [None]:
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langchain_core.runnables import RunnablePassthrough

In [None]:

# Prompt for summarizing one text chunk or Markdown table for retrieval.
prompt_text = """
You are an assistant tasked with summarizing tables and text particularly for semantic retrieval.
These summaries will be embedded and used to retrieve the raw text or table elements
Give a detailed summary of the table or text below that is well optimized for retrieval.
For any tables also add in a one line description of what the table is about besides the summary.
Do not add redundant words like Summary.
Just output the actual summary content.

Table or text chunk:
{element}
"""
prompt = ChatPromptTemplate.from_template(prompt_text)

# Summary chain: wrap the raw string as {"element": ...}, fill the prompt,
# call GPT-4o, and return the reply as a plain string.
summarize_chain = (
    {"element": RunnablePassthrough()}
    | prompt
    | chatgpt
    | StrOutputParser()  # extracts the response as text and returns it as a string
)

# Raw contents: these are what the docstore will hand back at query time.
text_docs = [doc.page_content for doc in docs]
table_docs = [table.page_content for table in tables]

# Summaries are index-aligned with text_docs / table_docs; at most 5 LLM calls run at once.
text_summaries = summarize_chain.batch(text_docs, {"max_concurrency": 5})
table_summaries = summarize_chain.batch(table_docs, {"max_concurrency": 5})

len(text_summaries), len(table_summaries)

# IMAGE SUMMARIZATION VIA LLM

In [None]:
import base64
import os

from langchain_core.messages import HumanMessage


def encode_image(image_path):
    """Read the file at `image_path` and return its bytes as a base64 text string."""
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode("utf-8")


def image_summarize(img_base64, prompt):
    """Ask GPT-4o to describe a single image.

    Args:
        img_base64: base64-encoded image bytes, sent as a JPEG data URL.
        prompt: instruction text sent together with the image.

    Returns:
        The content of the model's reply message.
    """
    data_url = f"data:image/jpeg;base64,{img_base64}"
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt},
            {"type": "image_url", "image_url": {"url": data_url}},
        ]
    )
    chat = ChatOpenAI(model="gpt-4o", temperature=0)
    reply = chat.invoke([message])
    return reply.content


def generate_img_summaries(path):
    """
    Base64-encode and summarize every ``.jpg`` file directly under ``path``.

    Files are processed in sorted filename order, with one GPT-4o call per image.

    Args:
        path: directory holding the .jpg files extracted by Unstructured.

    Returns:
        (img_base64_list, image_summaries): two index-aligned lists holding the
        base64 strings and their summaries.
    """
    summary_prompt = """You are an assistant tasked with summarizing images for retrieval.
                Remember these images could potentially contain graphs, charts or tables also.
                These summaries will be embedded and used to retrieve the raw image for question answering.
                Give a detailed summary of the image that is well optimized for retrieval.
                Do not add additional words like Summary, This image represents, etc.
             """

    jpg_paths = [
        os.path.join(path, name)
        for name in sorted(os.listdir(path))
        if name.endswith(".jpg")
    ]

    img_base64_list = []
    image_summaries = []
    for jpg_path in jpg_paths:
        encoded = encode_image(jpg_path)
        img_base64_list.append(encoded)
        image_summaries.append(image_summarize(encoded, summary_prompt))

    return img_base64_list, image_summaries


# Image summaries: imgs_base64[i] is the base64 image whose GPT-4o summary is
# image_summaries[i] (sorted filename order under IMG_PATH).
IMG_PATH = './figures'
imgs_base64, image_summaries = generate_img_summaries(IMG_PATH)

In [None]:
# Sanity check: one summary per image, so both counts should match.
len(imgs_base64), len(image_summaries)

In [None]:
# Show one extracted figure to compare against its generated summary.
display(Image('./figures/figure-1-2.jpg'))

In [None]:
# Summary of the second image in sorted filename order.
# NOTE(review): assumes that is figure-1-2.jpg (shown above) — confirm if more figures are extracted.
display(Markdown(image_summaries[1]))

## Multi-vector retriever

Use [multi-vector-retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary) to index image (and / or text, table) summaries, but retrieve raw images (along with raw texts or tables).

### Download and Install Redis as a DocStore

You can use any other database or cache as a docstore to store the raw text, table and image elements

In [None]:
%%sh
# Add the Redis apt repository (with its signing key), install Redis Stack, and
# start the server in the background.
# --yes lets gpg overwrite the keyring when the cell is re-run; -y answers apt's
# confirmation prompt, which could otherwise abort the install in a non-interactive
# shell (silently, because its output is discarded).
curl -fsSL https://packages.redis.io/gpg | sudo gpg --yes --dearmor -o /usr/share/keyrings/redis-archive-keyring.gpg
echo "deb [signed-by=/usr/share/keyrings/redis-archive-keyring.gpg] https://packages.redis.io/deb $(lsb_release -cs) main" | sudo tee /etc/apt/sources.list.d/redis.list
sudo apt-get update  > /dev/null 2>&1
sudo apt-get install -y redis-stack-server  > /dev/null 2>&1
redis-stack-server --daemonize yes

### Open AI Embedding Models

LangChain enables us to access Open AI embedding models which include the newest models: a smaller and highly efficient `text-embedding-3-small` model, and a larger and more powerful `text-embedding-3-large` model.

In [None]:
from langchain_openai import OpenAIEmbeddings

# Embedding model used to embed the text, table and image summaries in the vector store.
# details here: https://openai.com/blog/new-embedding-models-and-api-updates
openai_embed_model = OpenAIEmbeddings(model='text-embedding-3-small')


### Add to vectorstore & docstore

Add raw docs and doc summaries to [Multi Vector Retriever](https://python.langchain.com/docs/modules/data_connection/retrievers/multi_vector#summary):

* Store the raw texts, tables, and images in the `docstore` (here we are using Redis).
* Store the text summaries, table summaries, and image summaries and their corresponding embeddings in the `vectorstore` (here we are using Chroma) for efficient semantic retrieval.
* Connect them using a common `document_id`

In [None]:
import uuid

from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_community.storage import RedisStore
from langchain_community.utilities.redis import get_client
from langchain_chroma import Chroma
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings


def create_multi_vector_retriever(
    docstore, vectorstore, text_summaries, texts, table_summaries, tables, image_summaries, images,
    search_kwargs=None,
):
    """
    Create a retriever that indexes summaries but returns the raw content.

    Each summary is embedded into ``vectorstore`` with a ``doc_id`` metadata field.
    The matching raw item (text chunk, table, or base64 image) is stored in
    ``docstore`` under the same id. Queries are matched against the summaries,
    and the raw items are returned.

    Args:
        docstore: key-value store for the raw contents (e.g. a RedisStore).
        vectorstore: vector store for the summary embeddings (e.g. Chroma).
        text_summaries, texts: text summaries and their raw texts (parallel lists).
        table_summaries, tables: table summaries and their raw tables (parallel lists).
        image_summaries, images: image summaries and their base64 images (parallel lists).
        search_kwargs: optional keyword arguments for the vector search, e.g.
            ``{"k": 5}`` to control how many summaries are matched. Defaults to
            the vector store's own search defaults.

    Returns:
        The configured MultiVectorRetriever.

    Raises:
        ValueError: if a summary list and its content list differ in length.
    """
    id_key = "doc_id"

    # Create the multi-vector retriever
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=docstore,
        id_key=id_key,
        search_kwargs=search_kwargs or {},
    )

    # Helper function to add documents to the vectorstore and docstore
    def add_documents(retriever, doc_summaries, doc_contents):
        # Summaries and contents are paired by position. A mismatch would
        # otherwise raise IndexError or leave contents that no summary points to.
        if len(doc_summaries) != len(doc_contents):
            raise ValueError(
                f"Got {len(doc_summaries)} summaries for {len(doc_contents)} items; "
                "they must be paired one-to-one."
            )
        doc_ids = [str(uuid.uuid4()) for _ in doc_contents]
        summary_docs = [
            Document(page_content=summary, metadata={id_key: doc_id})
            for summary, doc_id in zip(doc_summaries, doc_ids)
        ]
        retriever.vectorstore.add_documents(summary_docs)
        retriever.docstore.mset(list(zip(doc_ids, doc_contents)))

    # Add texts, tables, and images, skipping any kind with no summaries.
    for summaries, contents in (
        (text_summaries, texts),
        (table_summaries, tables),
        (image_summaries, images),
    ):
        if summaries:
            add_documents(retriever, summaries, contents)

    return retriever


# Vector store for the summaries and their embeddings (cosine distance).
# NOTE(review): re-running this cell in the same kernel may reuse the "mm_rag"
# collection and add duplicate summaries — restart the kernel or clear the
# collection first.
chroma_db = Chroma(
    collection_name="mm_rag",
    embedding_function=openai_embed_model,
    collection_metadata={"hnsw:space": "cosine"},
)

# Key-value storage layer for the raw texts, tables and base64 images.
# Any other store (file store, in-memory store, another DB) could be used instead.
client = get_client('redis://localhost:6379')
redis_store = RedisStore(client=client)

# Retriever: searches the summaries, returns the raw items.
retriever_multi_vector = create_multi_vector_retriever(
    docstore=redis_store,
    vectorstore=chroma_db,
    text_summaries=text_summaries,
    texts=text_docs,
    table_summaries=table_summaries,
    tables=table_docs,
    image_summaries=image_summaries,
    images=imgs_base64,
)

In [None]:
# Inspect the retriever's configuration.
retriever_multi_vector

## Test Multimodal RAG Retriever


In [None]:
from IPython.display import HTML, display, Image
from PIL import Image
import base64
from io import BytesIO

def plt_img_base64(img_base64):
    """Display a base64-encoded image inline in the notebook.

    Args:
        img_base64: base64 string (or bytes) of an image file that PIL can open.
    """
    # Import PIL's Image under its own name. The notebook binds the global `Image`
    # to both IPython.display.Image and PIL.Image, depending on which import ran
    # last, so this function does not rely on it.
    from PIL import Image as PILImage

    img_buffer = BytesIO(base64.b64decode(img_base64))
    img = PILImage.open(img_buffer)
    display(img)

In [None]:
# CHECK RETRIEVAL
query = "Analyze the wildfires trend with acres burned over the years"
# NOTE: extra invoke() kwargs such as `limit=` appear to be dropped by
# MultiVectorRetriever; the number of hits is set via the retriever's
# `search_kwargs` (e.g. {"k": 5}).
# A separate name is used so the `docs` list of text elements is not overwritten.
retrieved_docs = retriever_multi_vector.invoke(query)

# Raw docstore values (bytes from Redis): show the count plus a short preview of
# each, rather than dumping whole base64 images.
print(len(retrieved_docs))
[item[:200] for item in retrieved_docs]

In [None]:
# Check retrieval
query = "Tell me about the percentage of residences burned by wildfires in 2022"
# NOTE: extra invoke() kwargs such as `limit=` appear to be dropped by
# MultiVectorRetriever; the number of hits is set via the retriever's
# `search_kwargs` (e.g. {"k": 5}).
# A separate name is used so the `docs` list of text elements is not overwritten.
retrieved_docs = retriever_multi_vector.invoke(query)

# Show the count plus a short preview of each raw item (avoids dumping base64 images).
print(len(retrieved_docs))
[item[:200] for item in retrieved_docs]

## Utilities to separate retrieved elements

We need to bin the retrieved doc(s) into the correct parts of the GPT-4o prompt template.

Text and table elements form one set of inputs and image elements the other, because they are passed to GPT-4o as different kinds of message content (text parts vs. `image_url` parts).

In [None]:
import re
import base64

def looks_like_base64(sb):
    """Return True if `sb` consists only of base64 characters plus optional '=' padding.

    `re.fullmatch` requires the whole string to match. The previous
    `re.match("^...$")` also accepted a string ending in a newline, because `$`
    matches just before a trailing newline.
    """
    return re.fullmatch(r"[A-Za-z0-9+/]+={0,2}", sb) is not None


def is_image_data(b64data):
    """
    Return True when the base64 payload decodes to bytes that start with a known
    image magic number (JPEG, PNG, GIF, or RIFF, used here for WebP).

    Any decoding error counts as "not an image" and returns False.
    """
    image_signatures = {
        b"\xff\xd8\xff": "jpg",
        b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a": "png",
        b"\x47\x49\x46\x38": "gif",
        b"\x52\x49\x46\x46": "webp",
    }
    try:
        leading_bytes = base64.b64decode(b64data)[:8]
    except Exception:
        return False
    return any(leading_bytes.startswith(magic) for magic in image_signatures)


def split_image_text_types(docs):
    """
    Split retrieved items into base64-encoded images and text.

    Items may be raw docstore values (bytes, as returned by RedisStore), plain
    strings, or Document objects (their ``page_content`` is used).

    Returns:
        {"images": [base64 image strings], "texts": [all other strings]}
    """
    b64_images = []
    texts = []
    for doc in docs:
        # Document.page_content is already a str, so only raw bytes need decoding.
        # Calling .decode() on page_content would raise AttributeError.
        content = doc.page_content if isinstance(doc, Document) else doc
        if isinstance(content, bytes):
            content = content.decode('utf-8')
        if looks_like_base64(content) and is_image_data(content):
            b64_images.append(content)
        else:
            texts.append(content)
    return {"images": b64_images, "texts": texts}

In [None]:
# Check retrieval
query = "Tell me detailed statistics of the top 5 years with largest wildfire acres burned"
# NOTE: extra invoke() kwargs such as `limit=` appear to be dropped by
# MultiVectorRetriever; the number of hits is set via the retriever's
# `search_kwargs` (e.g. {"k": 5}).
docs = retriever_multi_vector.invoke(query)

# Show the count plus a short preview of each raw item (avoids dumping base64 images).
print(len(docs))
[item[:200] for item in docs]

In [None]:
# Check whether the third retrieved item decodes to image bytes.
# NOTE(review): assumes at least 3 items were retrieved and that docstore values are
# bytes (RedisStore); adjust the index if the retrieval results change.
is_image_data(docs[2].decode('utf-8'))

In [None]:
# Bin the retrieved items into {"images": [...], "texts": [...]} and display the result.
r = split_image_text_types(docs)
r

# BUILD END TO END RAG PIPELINE for MULTIMODAL RAG

In [None]:
from operator import itemgetter
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.messages import HumanMessage

def multimodal_prompt_function(data_dict):
    """
    Build the GPT-4o prompt from the retrieved context and the user question.

    Expects ``data_dict`` to contain:
      - "question": the user question string;
      - "context": a dict with "texts" (text/table strings) and "images"
        (base64 image strings), as produced by split_image_text_types.

    Each image becomes an ``image_url`` content part (a JPEG data URL). All texts
    are joined with newlines and embedded, together with the question and the
    analyst instructions, in a single text part that follows the images.

    Returns:
        A one-element list holding a HumanMessage with those content parts.
    """
    context = data_dict["context"]
    formatted_texts = "\n".join(context["texts"])

    image_parts = [
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{image}"},
        }
        for image in (context["images"] or [])
    ]

    instructions = f"""You are an analyst tasked with understanding detailed information and trends
                from text documents, data tables, and charts and graphs in images.
                You will be given context information below which will be a mix of text, tables,
                and images usually of charts or graphs.
                Use this information to provide answers related to the user question.
                Analyze all the context information including tables, text and images to generate the answer.
                Do not make up answers, If the question context is not present in the document just say dont know the answer and in that case please dont generate any sources.
                use the provided context documents below
                and answer the question to the best of your ability.

                User question:
                {data_dict['question']}

                Context documents:
                {formatted_texts}

                Answer:
            """
    text_part = {"type": "text", "text": instructions}

    return [HumanMessage(content=image_parts + [text_part])]


# Answer chain: map the input dict to {"context", "question"}, build the multimodal
# prompt, call GPT-4o, and return its reply as a string.
multimodal_rag = (
    {"context": itemgetter('context'), "question": itemgetter('input')}
    | RunnableLambda(multimodal_prompt_function)
    | chatgpt
    | StrOutputParser()
)

# Retrieval chain: take the "input" query, fetch the raw items from the
# multi-vector retriever, and split them into {"images": [...], "texts": [...]}.
retrieve_docs = (
    itemgetter('input')
    | retriever_multi_vector
    | RunnableLambda(split_image_text_types)
)

# Each .assign keeps the incoming dict and adds one key computed by a Runnable:
# first "context" (retrieval), then "answer" (which reads that context). The
# final output therefore carries the input, the retrieved sources, and the answer.
multimodal_rag_w_sources = (
    RunnablePassthrough
    .assign(context=retrieve_docs)
    .assign(answer=multimodal_rag)
)

In [None]:
# Run the RAG chain with sources. The response dict holds the original 'input',
# the split 'context' (texts/images), and the generated 'answer'.
query = "Tell me detailed statistics of the top 5 years with largest wildfire acres burned"
response = multimodal_rag_w_sources.invoke({'input': query})
response

In [None]:
def multimodal_rag_qa(query):
    """Run the multimodal RAG chain on `query` and render the answer with its sources.

    Displays the answer as Markdown, then each retrieved text source as Markdown
    and each retrieved image via plt_img_base64, separated by printed rules.
    """
    result = multimodal_rag_w_sources.invoke({'input': query})
    heavy_rule = '=' * 100
    light_rule = '-' * 100

    print(heavy_rule)
    print('Answer:')
    display(Markdown(result['answer']))
    print(light_rule)
    print('Sources:')
    retrieved = result['context']
    source_texts, source_images = retrieved['texts'], retrieved['images']
    for source_text in source_texts:
        display(Markdown(source_text))
        print()
    for source_image in source_images:
        plt_img_base64(source_image)
        print()
    print(heavy_rule)

In [None]:
# Ask for the top-5 years by acres burned; the answer and its sources are rendered below.
query = "Tell me detailed statistics of the top 5 years with largest wildfire acres burned"
multimodal_rag_qa(query)

In [None]:
# Run RAG chain
# Question about residences burned in 2022; the answer and its sources are rendered below.
query = "Tell me about the percentage of residences burned by wildfires in 2022"
multimodal_rag_qa(query)

In [None]:
# Run RAG chain
# Trend question over the years; the answer and its sources are rendered below.
query = "Analyze the wildfires trend with acres burned over the years"
multimodal_rag_qa(query)

In [None]:
# Run RAG chain
# Out-of-scope question (unrelated to the wildfire PDF). Per the prompt instructions,
# the model should say it doesn't know rather than make up an answer.
query = "which teams are the part of ICC t20 worldcup 2026"
multimodal_rag_qa(query)