# Multi-modal RAG for Chemical Engineers

**Method:**
* **Ingest a PDF**: Process a PDF of a chemical engineering textbook.
* **Partition**: Using `unstructured.io`, break the PDF into its constituent parts: text, tables, and images.
* **Summarize**: Generate concise summaries for every text chunk, table, and image using an LLM.
* **Index**: Embed the summaries with an embedding model (such as `text-embedding-3-small`) and store them in a `Chroma` vector store. The `MultiVectorRetriever` strategy searches over these small, dense summaries.
* **Retrieve & Generate**: When a user asks a question, the system will retrieve the most relevant *original* text, tables, or images based on the summary search and use a multi-modal LLM (such as `gpt-4o-mini`) to generate a comprehensive answer.

## 1. Environment Setup and Dependencies

This notebook requires a specific setup, including a Linux environment (like WSL for Windows), system packages for document processing, and several Python libraries.

### 1a. System-Level Dependencies (Run in Terminal)
First, you need to install libraries for processing PDFs and images. Run the following command in your Ubuntu/WSL terminal:
```bash
sudo apt-get update && sudo apt-get install -y poppler-utils tesseract-ocr libmagic-dev
```

### 1b. Python Environment Setup (Run in Terminal)

1. Create the virtual environment (only needs to be done once)
    ```bash
    python3 -m venv venv
    ```

2. Activate the environment (do this every time you start a new session)
    ```bash
    source ./venv/bin/activate
    ```

3. Install Jupyter and create a kernel for this environment
    ```bash
    pip install jupyter ipykernel
    python -m ipykernel install --user --name=chem_rag_env --display-name "Python (ChemE RAG)"
    ```

4. Restart VS Code and select the "Python (ChemE RAG)" kernel.

### 1c. Core Dependencies
Run the following cell to install Unstructured, ChromaDB, LangChain, and python-dotenv (for API keys).

In [None]:
%pip install -U "unstructured[all-docs]" pillow lxml
%pip install -U chromadb tiktoken
%pip install -U langchain langchain-community langchain-openai langchain-groq
%pip install -U python_dotenv

## 2. API Key Configuration

This RAG system uses Azure OpenAI (for embeddings and generation) and Groq (for fast text summarization). `python-dotenv` can be used to keep the API keys out of the notebook itself.

**Important**: The cell below sets your API keys as environment variables. If you plan to share or show this notebook to others, replace the hard-coded values with keys loaded from a `.env` file in the root directory of this project so that they are not exposed.

In [None]:
import os

os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["AZURE_OPENAI_ENDPOINT"] = "YOUR_MULTIMODAL_LLM_ENDPOINT"
os.environ["OPENAI_API_VERSION"] = "YOUR_MULTIMODAL_LLM_API_VERSION"
os.environ["AZURE_OPENAI_API_KEY"] = "YOUR_MULTIMODAL_LLM_API_KEY"
os.environ["AZURE_OPENAI_DEPLOYMENT_NAME"] = "YOUR_MULTIMODAL_LLM_DEPLOYMENT_NAME"
os.environ["AZURE_EMBEDDING_ENDPOINT"] = "YOUR_EMBEDDING_MODEL_ENDPOINT"
os.environ["AZURE_EMBEDDING_API_KEY"] = "YOUR_EMBEDDING_MODEL_API_KEY"
os.environ["AZURE_EMBEDDING_DEPLOYMENT_NAME"] = "YOUR_EMBEDDING_MODEL_DEPLOYMENT_NAME"
os.environ["EMBEDDING_API_VERSION"] = "YOUR_EMBEDDING_MODEL_API_VERSION"
os.environ["GROQ_API_KEY"] = "YOUR_GROQ_API_KEY"
os.environ["LANGCHAIN_API_KEY"] = "YOUR_LANGCHAIN_API_KEY"
os.environ["LANGCHAIN_TRACING_V2"] = "false" # Set to "true" if you want to enable LangChain tracing (Debugging with LangChain)
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
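
Before running the rest of the notebook, it can help to confirm that no key was left at its placeholder value. The following is a minimal sketch: the helper name `find_unset_keys` and the convention that placeholders start with `YOUR_` are assumptions based on the cell above, not part of any library.

```python
import os

# Variable names taken from the configuration cell above.
REQUIRED_KEYS = [
    "AZURE_OPENAI_ENDPOINT",
    "AZURE_OPENAI_API_KEY",
    "AZURE_OPENAI_DEPLOYMENT_NAME",
    "AZURE_EMBEDDING_ENDPOINT",
    "AZURE_EMBEDDING_API_KEY",
    "AZURE_EMBEDDING_DEPLOYMENT_NAME",
    "GROQ_API_KEY",
]

def find_unset_keys(required, environ=None):
    """Return the names that are missing, empty, or still a YOUR_... placeholder."""
    environ = os.environ if environ is None else environ
    unset = []
    for name in required:
        value = environ.get(name, "")
        # Assumption: placeholder values follow the YOUR_... pattern used above.
        if not value or value.startswith("YOUR_"):
            unset.append(name)
    return unset

unset = find_unset_keys(REQUIRED_KEYS)
if unset:
    print("Still unset or placeholder:", ", ".join(unset))
```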

## 3. Partitioning the PDF

### 3a. Using Unstructured
Now that the environment is set up, you can parse your PDF document.

The `partition_pdf` function from the `unstructured` library automatically identifies and extracts the different types of content in your PDF.

**Key Parameters:**
- `infer_table_structure=True`: Reconstructs the underlying structure of tables.
- `extract_images_in_pdf=True`: Extracts all images.
- `extract_image_block_to_payload=True`: Stores extracted images directly as base64 strings within the element's metadata.
- `chunking_strategy="by_title"`: Groups related content under the same title, which is effective for textbooks.

To choose different chunking strategy/parameters: https://docs.unstructured.io/open-source/core-functionality/chunking

In [None]:
from unstructured.partition.pdf import partition_pdf

output_path = "./content/"
file_path = output_path + 'YOUR_PDF_FILE.pdf'

chunks = partition_pdf(
    filename=file_path,
    infer_table_structure=True,
    strategy="hi_res", # The "hi_res" strategy is necessary if you want to infer tables
    extract_images_in_pdf=True, # Extract the images found in the PDF

    extract_image_block_types=["Image", "Table"],
    extract_image_block_to_payload=True, # If true, will extract base64 for API usage

    chunking_strategy="by_title",
    max_characters=10000,
    combine_text_under_n_chars=2000,
    new_after_n_chars=6000,
)

The output of the partitioner is a list of `CompositeElement` objects. Each `CompositeElement` contains a group of related smaller elements, making them easy to use together in a RAG pipeline. The contents of one chunk can be inspected below to understand the structure.

In [None]:
chunks[0].metadata.orig_elements

A chunk can also be cast to a string to see its primary text content.

In [None]:
str(chunks[0])
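
To get a quick overview of what the partitioner found, the element types inside each chunk can be tallied. This sketch assumes only what the cells above already use, namely that each chunk exposes `metadata.orig_elements`. The helper name `count_element_types` is hypothetical.

```python
from collections import Counter

def count_element_types(chunks):
    """Tally the class names of the original elements inside each chunk."""
    counts = Counter()
    for chunk in chunks:
        # Treat a missing or empty orig_elements as an empty list.
        orig_elements = getattr(chunk.metadata, "orig_elements", None) or []
        counts.update(type(el).__name__ for el in orig_elements)
    return counts

# Usage in the notebook (assumes `chunks` from the partitioning cell):
# count_element_types(chunks).most_common()
```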

### 3b. Caching the Partitioned Data
PDF partitioning can be time-consuming for large documents. To save time on future runs, the `chunks` object can be saved to a `.pkl` file. The **SAVE** cell only needs to be run once, after the initial parse. In every subsequent session, or whenever the kernel is restarted, the parsed `chunks` can be loaded back in with the **LOAD** cell.

Run the cell below to **SAVE** the partitioned data.

In [None]:
import pickle

with open("partitioned_output.pkl", "wb") as f:
    pickle.dump(chunks, f)

Run the cell below to **LOAD** the partitioned data on subsequent sessions.

In [None]:
import pickle

with open("partitioned_output.pkl", "rb") as f:
    chunks = pickle.load(f)
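
The SAVE and LOAD cells can also be folded into one helper that loads the cache when it exists and computes the result otherwise. This is a sketch: `load_or_compute` is a hypothetical helper name, and pickle files should only be loaded from sources you trust.

```python
import os
import pickle

def load_or_compute(cache_path, compute_fn):
    """Load a pickled result if cache_path exists; otherwise compute and save it."""
    if os.path.exists(cache_path):
        with open(cache_path, "rb") as f:
            return pickle.load(f)
    result = compute_fn()
    with open(cache_path, "wb") as f:
        pickle.dump(result, f)
    return result

# Usage in the notebook, where `run_partition` is a hypothetical zero-argument
# function wrapping the `partition_pdf` call from section 3a:
# chunks = load_or_compute("partitioned_output.pkl", run_partition)
```

Deleting the cache file forces the next call to recompute the result.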

## 4. Element Processing and Verification

The `partition_pdf` function returns a list of `CompositeElement` objects. These elements now need to be processed and separated by type for the next steps.

### 4a. Separating Elements by Type
The partitioned elements are now separated into distinct lists of texts and tables for easier handling.

In [None]:
tables = []
texts = []

for chunk in chunks:
    # Iterate over a copy so that removing a table does not skip the element after it
    for elem in list(chunk.metadata.orig_elements):
        if elem.to_dict()["type"] == 'Table':
            tables.append(elem)
            chunk.metadata.orig_elements.remove(elem) # The table is removed from the original elements to avoid duplication
    texts.append(chunk)

### 4b. Extracting Image Data
This function iterates through all chunks and extracts the base64-encoded image data that unstructured stored in the metadata.

In [None]:
def get_images_base64(chunks):
    images_b64 = []
    for chunk in chunks:
        if "CompositeElement" in str(type(chunk)):
            chunk_els = chunk.metadata.orig_elements
            for el in chunk_els:
                if "Image" in str(type(el)):
                    images_b64.append(el.metadata.image_base64)
    return images_b64

images = get_images_base64(chunks)

### 4c. Verifying Extraction

An individual table element can now be inspected. The `to_dict()` method provides a structured view of its content and metadata.

In [None]:
tables[0].to_dict()

One of the extracted images can be displayed to confirm that the process worked correctly. The images are stored as base64 strings, which can be decoded and displayed directly in the notebook.

In [None]:
import base64
from IPython.display import Image, display

def display_base64_image(base64_code):
    # Decode the base64 string to binary
    image_data = base64.b64decode(base64_code)
    # Display the image
    display(Image(data=image_data))

# Display the first extracted image to verify
if images:
    display_base64_image(images[0])
else:
    print("No images found in the document.")
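
The data URLs used later in this notebook label every image as `image/jpeg`. If the extracted images might be in another format, the format can be checked from the first decoded bytes. This sketch covers only the JPEG and PNG signatures; `guess_image_mime` and `to_data_url` are hypothetical helpers, and falling back to JPEG is an assumption.

```python
import base64

def guess_image_mime(image_b64, default="image/jpeg"):
    """Guess the MIME type of a base64-encoded image from its leading bytes."""
    # 16 base64 characters decode to 12 bytes, enough for both signatures.
    header = base64.b64decode(image_b64[:16])
    if header.startswith(b"\xff\xd8\xff"):
        return "image/jpeg"
    if header.startswith(b"\x89PNG\r\n\x1a\n"):
        return "image/png"
    return default  # assumption: treat unknown formats as JPEG

def to_data_url(image_b64):
    """Build a data URL whose MIME type matches the image content."""
    return f"data:{guess_image_mime(image_b64)};base64,{image_b64}"
```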

## 5. Element Summarization

For the `MultiVectorRetriever` strategy to work, a summary of every element (text, table, and image) is needed. These summaries will be embedded and used for searching.

* Texts and Tables: A fast, open-source model via Groq (`gemma2-9b-it` is a good choice) will be used to generate concise text summaries.
* Images: A multi-modal model (like `gpt-4o-mini`) is needed to generate detailed descriptions of the images, paying special attention to diagrams and charts relevant to chemical engineering.

### 5a. Generating Text and Table Summaries
Groq is used for its fast inference and open-source models, which are well suited to summarizing a large number of text chunks and tables. A prompt is constructed to instruct the model to provide a concise summary. `asyncio` is used to add a delay between API calls to avoid exceeding the Groq Free Tier TPM (tokens per minute) limits.

First, install the dependency that integrates `LangChain` with `Groq`.

In [None]:
%pip install -U langchain-groq

In [None]:
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
import asyncio

# Initialize the Groq model
model = ChatGroq(temperature=0.2, model="gemma2-9b-it")

prompt_text = """You are an assistant tasked with summarizing tables and text for a chemical engineering textbook. Give a concise summary of the table or text.
Respond only with the summary. Do not start your message with "Here is a summary" or any other introductory text.
Table or text chunk: {element}
"""
prompt = ChatPromptTemplate.from_template(prompt_text)

# Define the summarization chain
summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

# Async function with a delay between calls so as not to exceed the TPM limit
async def process_with_delay(elements, delay=2):
    results = []
    for i, element in enumerate(elements):
        print(f"Summarizing element {i+1}/{len(elements)}...")
        # Use the async variant so the call does not block the event loop
        summary = await summarize_chain.ainvoke(element)
        results.append(summary)
        if i < len(elements) - 1:
            await asyncio.sleep(delay)
    return results

# Run the async summarization tasks for both texts and tables
texts_summaries = await process_with_delay(texts, delay=2)
tables_html = [table.metadata.text_as_html for table in tables]
table_summaries = await process_with_delay(tables_html, delay=2)
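
A fixed delay reduces the chance of hitting the rate limit but does not recover if a call is rejected anyway. One common complement is to retry with exponential backoff. The sketch below is generic Python rather than a Groq or LangChain API; `retry_with_backoff` and its parameters are hypothetical, and in practice the `except` clause should be narrowed to the provider's rate-limit error.

```python
import time

def retry_with_backoff(fn, max_attempts=4, base_delay=2.0, sleep=time.sleep):
    """Call fn(); on failure wait base_delay * 2**attempt seconds and try again."""
    for attempt in range(max_attempts):
        try:
            return fn()
        except Exception:  # assumption: narrow this to the rate-limit error in real use
            if attempt == max_attempts - 1:
                raise  # out of attempts: surface the last error
            sleep(base_delay * (2 ** attempt))

# Usage in the notebook (assumes `summarize_chain` and `element` from the cell above):
# summary = retry_with_backoff(lambda: summarize_chain.invoke(element))
```

The `sleep` parameter is injectable so the helper can be exercised without actually waiting.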

### 5b. Caching Text and Table Summaries
To avoid re-running this step (and making more API calls), the generated summaries can be saved to a `.pkl` file, just like the `chunks`. The **SAVE** cell only needs to be run once, after the initial summaries are made. In every subsequent session, or whenever the kernel is restarted, the summaries can be loaded back in with the **LOAD** cell.

Run the cell below to **SAVE** the summaries.

In [None]:
import pickle

with open("summaries.pkl", "wb") as f:
    pickle.dump({"texts_summaries": texts_summaries, "table_summaries": table_summaries}, f)

Run the cell below to **LOAD** the summaries.

In [None]:
import pickle
with open("summaries.pkl", "rb") as f:
    data = pickle.load(f)
    texts_summaries = data["texts_summaries"]
    table_summaries = data["table_summaries"]

### 5c. Generating Image Summaries
For the images, a multi-modal model (such as `gpt-4o-mini`) is required to understand and describe visual content like schematics and process flow diagrams.

**Note on Content Filtering:** The Azure OpenAI service has a content safety filter. Some images, especially complex diagrams, might be flagged, causing the API call to fail. The code includes a `try...except` block to catch these errors, log the failed images, and continue the process. A delay between requests is also included to manage API TPM limits.

First, install the dependency that integrates `LangChain` with `OpenAI`.

In [None]:
%pip install -U langchain_openai

In [None]:
import time
from langchain_openai import AzureChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from openai import BadRequestError

prompt_template_text = """Describe the image in detail. The image is part of a textbook on material
                         & energy balances and other introductory Chemical
                         Engineering concepts. Be specific about graphs,
                         such as enthalpy and entropy graphs, and diagrams
                         that show chemical engineering processes.
                      """
messages_template = [
    (
        "user",
        [
            {"type": "text", "text": prompt_template_text},
            {
                "type": "image_url",
                "image_url": {"url": "data:image/jpeg;base64,{image}"},
            },
        ],
    )
]

prompt = ChatPromptTemplate.from_messages(messages_template)

# Initialize the Azure Chat model for multi-modal tasks
llm = AzureChatOpenAI(
    openai_api_version=os.getenv("OPENAI_API_VERSION", "2024-02-01"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
    azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
    api_key=os.getenv("AZURE_OPENAI_API_KEY"),
    model="YOUR_MULTIMODAL_LLM_MODEL_NAME",  # Replace with your model name (e.g., "gpt-4o-mini")
    max_tokens=1024
)

# Define the image summarization chain
chain = prompt | llm | StrOutputParser()

DELAY_BETWEEN_REQUESTS = 5
image_summaries = []
failed_image_indices = [] # Keeps track of images that were not able to be processed

if 'images' in locals() and images:
    for i, image_data in enumerate(images):
        print(f"Processing image {i+1} of {len(images)}...")
        try:
            summary = chain.invoke({"image": image_data})
            image_summaries.append(summary)

        except BadRequestError as e:
            # If a content filter error occurs, log it and continue
            print(f"--> Content filter triggered for image {i+1}. Skipping. Error: {e}")
            failed_image_indices.append(i)

        finally:
            # This 'finally' block ensures the delay happens even if an error occurs
            if i < len(images) - 1:
                print(f"Waiting for {DELAY_BETWEEN_REQUESTS} seconds...")
                time.sleep(DELAY_BETWEEN_REQUESTS)

    print("\nProcessing complete.")
    if failed_image_indices:
        print(f"The following image indices failed due to content filtering: {failed_image_indices}")

else:
    print("The 'images' variable is not defined or is empty. Please populate it with image data.")

### 5d. Caching Image Summaries
To avoid re-running this step (and making more API calls, which are especially expensive for multi-modal LLMs), the generated summaries can be saved to a `.pkl` file, just like the `chunks`. The **SAVE** cell only needs to be run once, after the initial summaries are made. In every subsequent session, or whenever the kernel is restarted, the summaries can be loaded back in with the **LOAD** cell.

Run the cells below to **SAVE** the summaries and failed image indices.

In [None]:
import pickle
with open("img_summaries.pkl", "wb") as f:
    pickle.dump({"image_summaries": image_summaries}, f)

In [None]:
import pickle
with open("failed_img_indices.pkl", "wb") as f:
    pickle.dump({"failed_image_indices": failed_image_indices}, f)

Run the cells below to **LOAD** the summaries and failed image indices.

In [None]:
import pickle
with open("img_summaries.pkl", "rb") as f:
    data = pickle.load(f)
    image_summaries = data["image_summaries"]

In [None]:
import pickle
with open("failed_img_indices.pkl", "rb") as f:
    data = pickle.load(f)
    failed_image_indices = data["failed_image_indices"]

## 6. Building the Vector Store and Retriever

With all summaries generated, the retrieval system can now be built using a `MultiVectorRetriever`.
* **Vector Store:** Contains the vector embeddings of the summaries, generated with an embedding model (such as `text-embedding-3-small`) and stored in `ChromaDB`.
* **Document Store:** Contains the original, full-sized elements (text chunks, table objects, and raw image data). An in-memory store is used here for simplicity.
* **Link:** Each summary in the vector store is linked to its original document in the document store by a unique ID.

When a query is made, the system searches the vector store for relevant summaries and retrieves the corresponding original documents from the document store.
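
The split between searchable summaries and stored originals can be illustrated without any external services. In this toy sketch, word overlap stands in for embedding similarity and plain Python containers stand in for `Chroma` and the document store; all names and sample data are hypothetical.

```python
import uuid

def word_overlap(query, text):
    """Toy relevance score: number of lowercase words shared by query and text."""
    return len(set(query.lower().split()) & set(text.lower().split()))

class ToyMultiVectorRetriever:
    def __init__(self):
        self.summaries = []  # (summary_text, doc_id) pairs: the searchable index
        self.docstore = {}   # doc_id -> original element

    def add(self, originals, summaries):
        if len(originals) != len(summaries):
            raise ValueError("each original needs exactly one summary")
        for original, summary in zip(originals, summaries):
            doc_id = str(uuid.uuid4())
            self.summaries.append((summary, doc_id))
            self.docstore[doc_id] = original

    def invoke(self, query, k=1):
        # Search over the summaries, then return the linked originals.
        ranked = sorted(
            self.summaries,
            key=lambda pair: word_overlap(query, pair[0]),
            reverse=True,
        )
        return [self.docstore[doc_id] for _, doc_id in ranked[:k]]

retriever_demo = ToyMultiVectorRetriever()
retriever_demo.add(
    originals=["<full text chunk on material balances>", "<base64 image of a column>"],
    summaries=["definition of a material balance", "diagram of a distillation column"],
)
print(retriever_demo.invoke("distillation column diagram"))
```

The query matches the second summary, but what comes back is the original stored under the same ID, which is the behavior the real retriever is configured for in the next cells.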

### 6a. Initializing the Vector and Document Stores
First, install the dependency that integrates `LangChain` with `ChromaDB`.

In [None]:
%pip install langchain-chroma

In [None]:
import uuid
from langchain_chroma import Chroma
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain_openai import AzureOpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

# The vectorstore to use to index the child chunks (summaries)
vectorstore = Chroma(
    collection_name="multi_modal_rag",
    embedding_function=AzureOpenAIEmbeddings(
        azure_deployment=os.getenv("AZURE_EMBEDDING_DEPLOYMENT_NAME"),
        azure_endpoint=os.getenv("AZURE_EMBEDDING_ENDPOINT"),
        api_key=os.getenv("AZURE_EMBEDDING_API_KEY"),
        api_version=os.getenv("EMBEDDING_API_VERSION"),
    )
)

# The storage layer for the parent documents (original elements)
store = InMemoryStore()
id_key = "doc_id"

# The retriever, which will be populated in the next step
retriever = MultiVectorRetriever(
    vectorstore=vectorstore,
    docstore=store,
    id_key=id_key,
)

### 6b. Populating the Stores
For each element type (text, table, image), the process is:
* Generate unique IDs for each original element.
* Create Document objects for the summaries, adding the unique ID to the metadata.
* Add the summary documents to the vectorstore.
* Add the original elements to the docstore, using the same unique IDs.

In [None]:
# Add text summaries and link to original text chunks
doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_texts = [
    Document(page_content=summary, metadata={id_key: doc_ids[i]}) for i, summary in enumerate(texts_summaries)
]
retriever.vectorstore.add_documents(summary_texts)
retriever.docstore.mset(list(zip(doc_ids, texts)))

# Add table summaries and link to original table objects
table_ids = [str(uuid.uuid4()) for _ in tables]
summary_tables = [
    Document(page_content=summary, metadata={id_key: table_ids[i]}) for i, summary in enumerate(table_summaries)
]
retriever.vectorstore.add_documents(summary_tables)
retriever.docstore.mset(list(zip(table_ids, tables)))

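# Add image summaries and link to original image data
# Images that failed summarization have no summary, so they are skipped to keep both lists aligned
summarized_images = [img for i, img in enumerate(images) if i not in failed_image_indices]
img_ids = [str(uuid.uuid4()) for _ in summarized_images]
summary_img = [
    Document(page_content=summary, metadata={id_key: img_ids[i]}) for i, summary in enumerate(image_summaries)
]
retriever.vectorstore.add_documents(summary_img)
retriever.docstore.mset(list(zip(img_ids, summarized_images)))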

### 6c. Testing the Retriever
A sample query can be run to see what the retriever returns. The output should be the original, full-content documents (not the summaries).

In [None]:
# Retrieve documents based on a query
docs = retriever.invoke(
    "who are the authors of the paper?"
)

# Display the retrieved documents
for doc in docs:
    if isinstance(doc, str):
        # This is a base64 image
        display_base64_image(doc)
    else:
        # This is a text or table element
        print(str(doc) + "\n\n" + "-" * 80)

## 7. RAG Pipeline

All components are now assembled into a final, runnable chain using `LangChain Expression Language` (LCEL).
The chain will:
* Take a user's question.
* Use the `retriever` to find relevant documents (text, tables, or images).
* Use a `parse_docs` function to separate the retrieved documents by type.
* Use a `build_prompt` function to construct a multi-modal prompt with the question and retrieved context.
* Send this prompt to the multi-modal LLM hosted on Azure to generate the final answer.

**Two versions of the chain are created:** one for a direct answer (`chain`) and another that also returns the retrieved source documents (`chain_with_sources`).

In [None]:
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser

# Define the Azure LLM from the environment variables set at the start of the notebook
llm = AzureChatOpenAI(
    openai_api_version=os.getenv("OPENAI_API_VERSION"),
    azure_deployment=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME"),
)

def parse_docs(docs):
    # Split retrieved documents into images and texts/tables
    b64_images = []
    text_docs = []
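    for doc in docs:
        if isinstance(doc, str):
            # Images are stored in the docstore as base64 strings
            b64_images.append(doc)
        else:
            # Text chunks and tables are unstructured elements, which expose .text rather than .page_content
            text_docs.append(doc)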
    return {"images": b64_images, "texts": text_docs}

def build_prompt(kwargs):
    # Builds the multi-modal prompt for the LLM
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]
    context_text = ""

    if len(docs_by_type["texts"]) > 0:
        for text_element in docs_by_type["texts"]:
            # Check for page_content for Documents or text for other element types
            content = getattr(text_element, 'page_content', getattr(text_element, 'text', ''))
            context_text += content + "\n\n"

    prompt_template = f"""Answer the question based only on the following context, which can include text, tables, and images.

Context:
{context_text}

Question: {user_question}
"""
    prompt_content = [{"type": "text", "text": prompt_template}]
    if len(docs_by_type["images"]) > 0:
        for image in docs_by_type["images"]:
            prompt_content.append(
                {
                    "type": "image_url",
                    "image_url": {"url": f"data:image/jpeg;base64,{image}"},
                }
            )
    return ChatPromptTemplate.from_messages([HumanMessage(content=prompt_content)])

# Define the common part of the chain to generate a response
response_generator = RunnableLambda(build_prompt) | llm | StrOutputParser()

# The RAG chain for a simple string response
chain = (
    {
        "context": retriever | RunnableLambda(parse_docs),
        "question": RunnablePassthrough(),
    }
    | response_generator
)

# The RAG chain that includes the source documents in the output
chain_with_sources = {
    "context": retriever | RunnableLambda(parse_docs),
    "question": RunnablePassthrough(),
} | RunnablePassthrough().assign(response=response_generator)
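
The output shape of `chain_with_sources` can be sketched in plain Python: the first step fans the question out into a dictionary, and the second adds a `response` key while keeping the existing ones. The stub functions below are hypothetical stand-ins for the retriever, `parse_docs`, and the LLM call; the sketch mirrors the data flow only, not LangChain's internals.

```python
def run_with_sources(question, retrieve, parse, generate):
    """Mimic the data flow of the chain that also returns its sources."""
    # Step 1: build {"context": ..., "question": ...} from the raw question.
    state = {"context": parse(retrieve(question)), "question": question}
    # Step 2: add "response" while passing the earlier keys through.
    return {**state, "response": generate(state)}

# Hypothetical stand-ins so the sketch runs on its own.
def fake_retrieve(question):
    return ["aGVsbG8=", {"text": "<text chunk>"}]

def fake_parse(docs):
    return {
        "images": [d for d in docs if isinstance(d, str)],
        "texts": [d for d in docs if not isinstance(d, str)],
    }

def fake_generate(state):
    n_texts = len(state["context"]["texts"])
    return f"answer to {state['question']!r} using {n_texts} text(s)"

result = run_with_sources(
    "What is a material balance?", fake_retrieve, fake_parse, fake_generate
)
print(sorted(result))
```

The resulting dictionary has the keys `context`, `question`, and `response`, which is the structure the source-verification example relies on.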

## 8. Querying the RAG Pipeline

The complete system can now be tested with questions.

**Example 1:** Factual Question about Text

In [None]:
response = chain.invoke(
    "What is a material balance?"
)

print(response)

**Example 2:** Visual Question with Source Verification
This question likely requires information from a diagram. The `chain_with_sources` is used to see both the LLM's answer and the context it used to generate it.

In [None]:
response = chain_with_sources.invoke(
    "What does a distillation column look like?"
)

print("Response:", response['response'])
print("\n" + "="*80 + "\n")
print("Context Used:\n")

# Print the text/table context
for text in response['context']['texts']:
    print("--- TEXT/TABLE CONTEXT ---")
    content = getattr(text, 'page_content', getattr(text, 'text', ''))
    print(content)
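    # Element metadata is an object while Document metadata is a dict, so handle both
    metadata = getattr(text, 'metadata', None)
    page_number = metadata.get('page_number') if isinstance(metadata, dict) else getattr(metadata, 'page_number', None)
    if page_number is not None:
        print("Page number: ", page_number)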
    print("\n" + "-"*50 + "\n")

# Display the image context
for image in response['context']['images']:
    print("--- IMAGE CONTEXT ---")
    display_base64_image(image)