# Multimodal RAG Assignment

This notebook implements a multimodal Retrieval-Augmented Generation (RAG) pipeline for processing multiple PDFs (at least 200 pages in total) containing text, images, and tables. The pipeline includes semantic chunking, embedding storage in Milvus with Flat, HNSW, and IVF indexes, a retrieval-time comparison, accuracy evaluation, reranking with BM25 and MMR, and rendering of the output to a DOCX file.

## Setup

Install the required libraries:
```bash
pip install "unstructured[pdf]" langchain langchain-openai langchain-experimental pymilvus rank_bm25 python-docx PyPDF2 sentence-transformers python-dotenv pillow
```

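Set the `OPENAI_API_KEY` environment variable (for example in a `.env` file, which is loaded with `python-dotenv`). The Milvus connection is configured through `MILVUS_HOST` and `MILVUS_PORT` in the setup cell.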

## Prerequisites
- A directory (`data/pdfs/`) with PDFs totaling at least 200 pages.
- A running Milvus server (local or remote).
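- OpenAI API key for embeddings and LLM.
- Poppler and Tesseract installed on the system (required by `unstructured`'s `hi_res` PDF partitioning).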

In [None]:
import os
import base64
import io
import re
import time
import uuid
from pathlib import Path
from PIL import Image
from IPython.display import HTML, display
from unstructured.partition.pdf import partition_pdf
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryStore
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain.text_splitter import SemanticChunker
from pymilvus import connections, Collection, FieldSchema, CollectionSchema, DataType, utility
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util
from docx import Document as DocxDocument
from docx.shared import Inches
import PyPDF2
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.messages import HumanMessage

from dotenv import load_dotenv
load_dotenv()

# Set environment variables.
# os.environ only accepts str values, so assigning os.getenv()'s None for a
# missing key raises an opaque TypeError; fail early with an actionable message.
_openai_api_key = os.getenv("OPENAI_API_KEY")
if not _openai_api_key:
    raise RuntimeError("OPENAI_API_KEY is not set; export it or add it to a .env file.")
os.environ['OPENAI_API_KEY'] = _openai_api_key

# Poppler (needed by unstructured's PDF partitioning). /opt/homebrew/bin is the
# usual Homebrew location on Apple Silicon macOS -- adjust for other systems.
# Only append once so re-running the cell does not keep growing PATH.
POPPLER_DIR = '/opt/homebrew/bin'
_current_path = os.environ.get('PATH', '')
if POPPLER_DIR not in _current_path.split(os.pathsep):
    os.environ['PATH'] = _current_path + os.pathsep + POPPLER_DIR

# Directory containing PDFs
PDF_DIR = 'data/pdfs/'
EXTRACTED_DATA_DIR = 'extracted_data/'

# Ensure directories exist
os.makedirs(PDF_DIR, exist_ok=True)
os.makedirs(EXTRACTED_DATA_DIR, exist_ok=True)

# Milvus connection
MILVUS_HOST = 'localhost'
MILVUS_PORT = '19530'

## Step 1: Fetch Data from PDFs

Iterate through all PDFs in the specified directory and extract text, images, and tables using the `unstructured` library.

In [None]:
def count_pdf_pages(pdf_dir):
    """Return the combined page count of every ``*.pdf`` file directly inside *pdf_dir*.

    Files are opened in binary mode and parsed with ``PyPDF2.PdfReader``; other
    files in the directory are ignored.
    """
    page_total = 0
    pdf_paths = Path(pdf_dir).glob('*.pdf')
    for path in pdf_paths:
        with open(path, 'rb') as handle:
            reader = PyPDF2.PdfReader(handle)
            page_total += len(reader.pages)
    return page_total

def _load_image_b64(element, extracted_data_dir):
    """Return the base64 text of the image file saved for *element*, or None if no file was saved."""
    saved_path = getattr(element.metadata, 'image_path', None)
    if not saved_path:
        # unstructured did not write a file for this element; nothing to read.
        return None
    image_path = Path(saved_path)
    if not image_path.is_file():
        # Fall back to looking the file up by name inside the output directory.
        image_path = Path(extracted_data_dir) / image_path.name
    # A missing file still raises FileNotFoundError here rather than being skipped.
    return base64.b64encode(image_path.read_bytes()).decode('utf-8')


def extract_pdf_data(pdf_dir, extracted_data_dir, min_pages=200):
    """Extract text, tables, and images from every PDF in *pdf_dir*.

    Args:
        pdf_dir: Directory whose ``*.pdf`` files are partitioned (in sorted order,
            so results are reproducible across runs).
        extracted_data_dir: Directory where ``partition_pdf`` writes image/table crops.
        min_pages: Minimum total page count required across all PDFs.

    Returns:
        ``(texts, tables, images)``. ``texts`` and ``tables`` are lists of element
        strings. ``images`` is a list of base64-encoded image files.

    Raises:
        ValueError: if the PDFs contain fewer than *min_pages* pages in total.
    """
    # ListItem is included so bulleted/numbered content is not silently dropped.
    text_types = {'Text', 'NarrativeText', 'Title', 'ListItem'}
    texts, tables, images = [], [], []

    total_pages = count_pdf_pages(pdf_dir)
    if total_pages < min_pages:
        raise ValueError(f"Total pages ({total_pages}) is less than {min_pages}. Please add more PDFs.")

    for pdf_file in sorted(Path(pdf_dir).glob('*.pdf')):
        elements = partition_pdf(
            filename=str(pdf_file),
            strategy='hi_res',
            extract_images_in_pdf=True,
            extract_image_block_types=['Image', 'Table'],
            extract_image_block_to_payload=False,
            extract_image_block_output_dir=extracted_data_dir
        )

        for element in elements:
            element_type = type(element).__name__
            if element_type in text_types:
                texts.append(str(element))
            elif element_type == 'Table':
                tables.append(str(element))
            elif element_type == 'Image':
                encoded = _load_image_b64(element, extracted_data_dir)
                if encoded is not None:
                    images.append(encoded)

    return texts, tables, images

# Extract data from every PDF in PDF_DIR. extract_pdf_data raises ValueError
# when the PDFs total fewer than 200 pages.
texts, tables, images = extract_pdf_data(PDF_DIR, EXTRACTED_DATA_DIR)

## Step 2: Semantic Chunking

Apply semantic chunking to the extracted text to create meaningful chunks for embedding.

In [None]:
def semantic_chunking(texts, embedding_function):
    """Split each text into semantically coherent chunks.

    Args:
        texts: Text strings to chunk (one per extracted PDF element).
        embedding_function: Embeddings object passed to ``SemanticChunker``, which
            uses it to locate breakpoints (percentile threshold).

    Returns:
        list[str]: All non-blank chunks, in input order.

    Blank or whitespace-only inputs and chunks are dropped, so downstream embedding
    never receives empty strings. When nothing remains, the splitter is not even
    constructed.

    NOTE(review): ``SemanticChunker`` is published in
    ``langchain_experimental.text_splitter``; confirm the notebook's import resolves.
    """
    non_blank = [text for text in texts if text.strip()]
    if not non_blank:
        return []

    text_splitter = SemanticChunker(embedding_function, breakpoint_threshold_type='percentile')
    chunked_texts = []
    for text in non_blank:
        chunked_texts.extend(chunk for chunk in text_splitter.split_text(text) if chunk.strip())
    return chunked_texts

# Initialize embedding function. The same object is reused later to embed
# content for the Milvus collections and to embed queries in the retrievers.
embedding_function = OpenAIEmbeddings()

# Chunk texts
chunked_texts = semantic_chunking(texts, embedding_function)

## Step 3: Generate Summaries for Tables and Images

Generate summaries for tables and images to use as indexing content.

In [None]:
def generate_summaries(content, content_type, model):
    """Generate one LLM summary per item of *content*.

    Args:
        content: Items to summarise. For ``'table'`` these are table strings. For
            ``'image'`` they are base64-encoded image data (sent as a JPEG data URL).
        content_type: ``'table'`` or ``'image'``.
        model: Chat model exposing ``invoke``; the ``content`` attribute of each
            reply is used as the summary.

    Returns:
        list[str] with exactly one summary per input item, in input order, so it
        stays aligned 1:1 with *content*.

    Raises:
        ValueError: if *content_type* is not ``'table'`` or ``'image'`` (otherwise an
            empty, misaligned list would be returned).
    """
    if content_type not in ('table', 'image'):
        raise ValueError(f"Unsupported content_type {content_type!r}; expected 'table' or 'image'.")

    summaries = []
    for item in content:
        if content_type == 'table':
            # Kept on one line with explicit \n escapes: a literal newline inside
            # a single-quoted f-string is a SyntaxError.
            prompt = f"Summarize the following table content in a concise manner:\n{item}\n\nSummary:"
        else:
            prompt = [HumanMessage(content=[
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{item}"}},
                {"type": "text", "text": "Describe this image in a concise manner."}
            ])]
        response = model.invoke(prompt)
        summaries.append(response.content)
    return summaries

# Initialize LLM for summaries (temperature 0 for deterministic-leaning output).
llm = ChatOpenAI(temperature=0, model="gpt-4o", max_tokens=1024)

# Generate summaries: generate_summaries makes one LLM call per table and per image.
table_summaries = generate_summaries(tables, 'table', llm)
image_summaries = generate_summaries(images, 'image', llm)

## Step 4: Store in Milvus Vector Database

Store embeddings in Milvus with three collections, each using a different index type (Flat, HNSW, IVF).

In [None]:
def create_milvus_collection(collection_name, index_type, dim=1536):
    """Create a Milvus collection, build an L2 vector index on it, and load it.

    Args:
        collection_name: Name of the collection to create.
        index_type: ``'FLAT'``, ``'HNSW'`` or ``'IVF_FLAT'`` (case-insensitive).
        dim: Embedding dimension. The default of 1536 matches the OpenAI embeddings
            this notebook uses.

    Returns:
        The loaded ``Collection``.

    Raises:
        ValueError: for an unsupported *index_type*. Validation happens before any
            Milvus call, so no un-indexed collection is left behind.
    """
    index_params_by_type = {
        'FLAT': {"metric_type": "L2", "index_type": "FLAT", "params": {}},
        'HNSW': {"metric_type": "L2", "index_type": "HNSW", "params": {"M": 16, "efConstruction": 200}},
        'IVF_FLAT': {"metric_type": "L2", "index_type": "IVF_FLAT", "params": {"nlist": 1024}},
    }
    try:
        index_params = index_params_by_type[index_type.upper()]
    except KeyError:
        raise ValueError(
            f"Unsupported index_type {index_type!r}; expected one of {sorted(index_params_by_type)}."
        ) from None

    fields = [
        FieldSchema(name="id", dtype=DataType.VARCHAR, is_primary=True, max_length=36),  # uuid4 string
        FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=dim),
        FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
        FieldSchema(name="content_type", dtype=DataType.VARCHAR, max_length=50)
    ]
    schema = CollectionSchema(fields=fields, description=f"{collection_name} collection")
    collection = Collection(name=collection_name, schema=schema)

    collection.create_index(field_name="embedding", index_params=index_params)
    collection.load()
    return collection

def store_in_milvus(texts, tables, images, text_summaries, table_summaries, image_summaries, index_type):
    """Embed summaries and store them, with their original content, in a fresh collection.

    The collection is named ``mmrag_<index_type lowercased>``; any existing collection
    with that name is dropped first. Each row's vector is the embedding of the item's
    *summary* (``text_summaries`` / ``table_summaries`` / ``image_summaries``). The
    ``content`` field holds the *original* item (``texts`` / ``tables`` / ``images``),
    so retrieval returns the raw text, table text, or base64 image.

    Args:
        texts, tables, images: Original items, stored verbatim in ``content``.
        text_summaries, table_summaries, image_summaries: Indexing text, aligned 1:1
            with the corresponding originals list.
        index_type: Passed to ``create_milvus_collection``.

    Returns:
        The loaded ``Collection``.

    Raises:
        ValueError: if an originals list and its summaries list differ in length, or
            if an original item exceeds the 65535-byte ``content`` field.
    """
    max_content_bytes = 65535  # keep in sync with max_length of the "content" field

    # Validate everything before dropping the old collection, so a bad call
    # leaves existing data intact.
    for kind, originals, summaries in (
        ('text', texts, text_summaries),
        ('table', tables, table_summaries),
        ('image', images, image_summaries),
    ):
        if len(originals) != len(summaries):
            raise ValueError(
                f"{kind}: {len(originals)} originals but {len(summaries)} summaries; they must align 1:1."
            )

    all_contents = text_summaries + table_summaries + image_summaries
    all_originals = texts + tables + images
    content_types = ['text'] * len(texts) + ['table'] * len(tables) + ['image'] * len(images)

    # Milvus rejects VARCHAR values longer than max_length bytes; base64 images
    # are the usual offenders. Fail with a clear message instead of a server error.
    oversized = [i for i, item in enumerate(all_originals) if len(item.encode('utf-8')) > max_content_bytes]
    if oversized:
        first = oversized[0]
        raise ValueError(
            f"{len(oversized)} item(s) exceed the {max_content_bytes}-byte 'content' field "
            f"(first at index {first}, type {content_types[first]!r}); "
            "store large items such as images outside Milvus and keep a reference instead."
        )

    collection_name = f"mmrag_{index_type.lower()}"
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    collection = create_milvus_collection(collection_name, index_type)

    if not all_contents:
        # Nothing to embed or insert; return the empty (indexed, loaded) collection.
        return collection

    embeddings = embedding_function.embed_documents(all_contents)

    # Column order must match the schema: id, embedding, content, content_type.
    collection.insert([
        [str(uuid.uuid4()) for _ in all_contents],
        embeddings,
        all_originals,
        content_types,
    ])
    # Seal the inserted rows. Unflushed data may sit in growing segments that are
    # searched without the configured index, which would make the FLAT/HNSW/IVF
    # timing comparison meaningless.
    collection.flush()

    return collection

# Open the connection to the Milvus server configured in the setup cell.
connections.connect(host=MILVUS_HOST, port=MILVUS_PORT)

# Build one collection per index type from identical content (FLAT, then HNSW,
# then IVF_FLAT) so the index comparisons below are like-for-like.
flat_collection, hnsw_collection, ivf_collection = [
    store_in_milvus(chunked_texts, tables, images, chunked_texts, table_summaries, image_summaries, index_kind)
    for index_kind in ('FLAT', 'HNSW', 'IVF_FLAT')
]

## Step 5: Create Retriever Pipeline

Create a custom retriever for Milvus that retrieves text, tables, and images based on query embeddings.

In [None]:
class MilvusRetriever:
    """Embed a query and run an L2 vector search against a single Milvus collection."""

    def __init__(self, collection, embedding_function):
        # collection: a loaded pymilvus Collection with "embedding", "content"
        # and "content_type" fields.
        self.collection = collection
        # embedding_function: object providing embed_query(str) -> vector.
        self.embedding_function = embedding_function

    def _search_params(self, top_k):
        """Return search params suited to the collection's index type.

        The index type is inferred from the collection name. Names built by
        store_in_milvus are lowercase (e.g. ``mmrag_ivf_flat``), so the check is
        case-insensitive.
        """
        name = self.collection.name.lower()
        if 'ivf' in name:
            return {"metric_type": "L2", "params": {"nprobe": 10}}
        if 'hnsw' in name:
            # Milvus requires ef >= the search limit for HNSW.
            return {"metric_type": "L2", "params": {"ef": max(64, top_k)}}
        return {"metric_type": "L2"}

    def invoke(self, query, top_k=5):
        """Return up to *top_k* ``Document``s nearest to *query*, closest first.

        Each Document's ``page_content`` is the stored ``content`` field, and its
        metadata carries ``content_type``. Note that this also calls the embedding
        function for *query*, so timing it includes embedding latency.
        """
        query_embedding = self.embedding_function.embed_query(query)
        results = self.collection.search(
            data=[query_embedding],
            anns_field="embedding",
            param=self._search_params(top_k),
            limit=top_k,
            output_fields=["content", "content_type"]
        )
        docs = []
        for hits in results:
            for hit in hits:
                docs.append(Document(
                    page_content=hit.entity.get('content'),
                    metadata={'content_type': hit.entity.get('content_type')},
                ))
        return docs

# One retriever per collection (FLAT, HNSW, IVF), all sharing the same
# query-embedding function so only the index differs between them.
flat_retriever, hnsw_retriever, ivf_retriever = [
    MilvusRetriever(coll, embedding_function)
    for coll in (flat_collection, hnsw_collection, ivf_collection)
]

## Step 6: Measure Retriever Time

Compare the retrieval time for Flat, HNSW, and IVF indexes.

In [None]:
def measure_retrieval_time(retriever, query, runs=10, warmup=0):
    """Return the mean wall-clock seconds of ``retriever.invoke(query)`` over *runs* calls.

    Args:
        retriever: Object with an ``invoke(query)`` method.
        query: Query string passed to every call.
        runs: Number of timed calls (must be >= 1).
        warmup: Untimed calls made first, e.g. to exclude connection or cache setup.

    Returns:
        Average elapsed seconds per timed call, measured with ``time.perf_counter``
        (monotonic, high resolution -- unlike ``time.time``).

    Raises:
        ValueError: if *runs* < 1.

    NOTE: MilvusRetriever.invoke embeds the query before searching, so its timings
    include embedding latency, not only the vector index search.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    for _ in range(warmup):
        retriever.invoke(query)

    total = 0.0
    for _ in range(runs):
        start = time.perf_counter()
        retriever.invoke(query)
        total += time.perf_counter() - start
    return total / runs

# Query used for the latency comparison.
test_query = "What is hypertension?"

# Average latency per index, measured in the order FLAT, HNSW, IVF.
flat_time, hnsw_time, ivf_time = [
    measure_retrieval_time(candidate, test_query)
    for candidate in (flat_retriever, hnsw_retriever, ivf_retriever)
]

print("\n".join(
    f"{label} Index Retrieval Time: {seconds:.4f} seconds"
    for label, seconds in (("Flat", flat_time), ("HNSW", hnsw_time), ("IVF", ivf_time))
))

## Step 7: Evaluate Accuracy of Similarity Search

Calculate precision@k for each index type using predefined query-result pairs.

In [None]:
def evaluate_accuracy(retriever, test_cases, k=5):
    """Return the mean precision@k of *retriever* over *test_cases*.

    Args:
        retriever: Object with ``invoke(query, top_k=...)`` returning items that have
            a ``page_content`` attribute.
        test_cases: Mapping of query -> iterable of relevant ``page_content`` strings.
        k: Cutoff. Only the first *k* retrieved items are scored, and each query's
            precision is ``hits / k``.

    Returns:
        Mean precision across queries, or 0.0 when *test_cases* is empty.

    Raises:
        ValueError: if *k* < 1.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    if not test_cases:
        return 0.0

    total_precision = 0.0
    for query, relevant_docs in test_cases.items():
        relevant = set(relevant_docs)
        # Truncate in case the retriever returns more than requested.
        retrieved = retriever.invoke(query, top_k=k)[:k]
        hits = sum(1 for doc in retrieved if doc.page_content in relevant)
        total_precision += hits / k
    return total_precision / len(test_cases)

# Example test cases (replace with actual relevant documents from your PDFs).
# A case is added only when its expected document exists: indexing an empty
# list (e.g. PDFs without tables) would otherwise raise IndexError.
test_cases = {}
if chunked_texts:
    test_cases["What is hypertension?"] = [chunked_texts[0]]  # Assume first chunk is relevant
if tables:
    test_cases["Explain the first table"] = [tables[0]]

if test_cases:
    # Evaluate accuracy
    flat_accuracy = evaluate_accuracy(flat_retriever, test_cases)
    hnsw_accuracy = evaluate_accuracy(hnsw_retriever, test_cases)
    ivf_accuracy = evaluate_accuracy(ivf_retriever, test_cases)

    print(f"Flat Index Accuracy: {flat_accuracy:.4f}")
    print(f"HNSW Index Accuracy: {hnsw_accuracy:.4f}")
    print(f"IVF Index Accuracy: {ivf_accuracy:.4f}")
else:
    print("No test cases available (no text chunks or tables extracted); skipping accuracy evaluation.")

## Step 8: Reranking with BM25 and MMR

Apply BM25 and MMR reranking to the retrieved results.

In [None]:
def bm25_rerank(query, docs):
    """Return *docs* reordered by descending BM25 (Okapi) score against *query*.

    Texts are lowercased and tokenised on word characters. This lets "Hypertension?"
    in a query match "hypertension" in a document; plain ``str.split()`` left case
    and punctuation attached. Documents with equal scores keep their retrieval order
    (``sorted`` is stable).

    Args:
        query: Query string.
        docs: Items with a ``page_content`` string.

    Returns:
        A new list of the same items, best match first. For an empty *docs* the
        result is an empty list, because BM25Okapi cannot be built on an empty corpus.
    """
    if not docs:
        return []

    def tokenize(text):
        return re.findall(r"\w+", text.lower())

    bm25 = BM25Okapi([tokenize(doc.page_content) for doc in docs])
    scores = bm25.get_scores(tokenize(query))
    order = sorted(range(len(docs)), key=lambda i: scores[i], reverse=True)
    return [docs[i] for i in order]

def mmr_rerank(query, docs, embedding_model, lambda_param=0.6):
    """Reorder *docs* by Maximal Marginal Relevance (MMR).

    The first pick is the doc most cosine-similar to *query*. Each following pick
    maximises ``lambda_param * sim(query, d) - (1 - lambda_param) * max sim(d, picked)``.
    Ties go to the earliest doc in input order.

    Args:
        query: Query string.
        docs: Items with a ``page_content`` string.
        embedding_model: Object whose ``encode`` accepts a string or a list of strings
            (e.g. a SentenceTransformer).
        lambda_param: Relevance/diversity trade-off in [0, 1]; higher favours relevance.

    Returns:
        All of *docs* in MMR order. An empty input gives an empty list.
    """
    if not docs:
        return []

    # Encode once and precompute all similarities, instead of recomputing
    # pairwise cosine similarities in every selection round.
    query_embedding = embedding_model.encode(query)
    doc_embeddings = embedding_model.encode([doc.page_content for doc in docs])
    query_sims = util.cos_sim(query_embedding, doc_embeddings)[0].tolist()
    doc_sims = util.cos_sim(doc_embeddings, doc_embeddings).tolist()

    first = max(range(len(docs)), key=lambda i: query_sims[i])
    selected = [first]
    remaining = [i for i in range(len(docs)) if i != first]

    while remaining:
        best = max(
            remaining,
            key=lambda i: lambda_param * query_sims[i]
            - (1 - lambda_param) * max(doc_sims[i][j] for j in selected),
        )
        selected.append(best)
        remaining.remove(best)

    return [docs[i] for i in selected]

# Initialize sentence transformer for MMR. This local model is used only for
# MMR reranking (multi_modal_rag_chain passes it to mmr_rerank); retrieval
# itself uses the OpenAI embedding_function.
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

## Step 9: Prompt Template

Define a prompt template for the LLM to generate responses based on retrieved data.

In [None]:
def img_prompt_func(data_dict):
    """Build a single multimodal HumanMessage from reranked context and a question.

    ``data_dict['context']`` holds Documents whose ``metadata['content_type']`` is
    ``'text'``, ``'table'`` or ``'image'``; other types are ignored. Each image's
    base64 ``page_content`` becomes its own JPEG data-URL part, in context order.
    These parts are followed by one text part containing the instructions, the
    question, and the newline-joined text and table contents.
    """
    by_type = {'text': [], 'table': [], 'image': []}
    for item in data_dict['context']:
        kind = item.metadata['content_type']
        if kind in by_type:
            by_type[kind].append(item.page_content)

    parts = [
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{encoded}"}}
        for encoded in by_type['image']
    ]

    joined_texts = "\n".join(by_type['text'])
    joined_tables = "\n".join(by_type['table'])
    instructions = (
        "You are a helpful assistant. Use the provided information to answer the user's question accurately. The information includes text, tables, and images.\n"
        f"User question: {data_dict['question']}\n\n"
        "Text content:\n"
        f"{joined_texts}\n\n"
        "Table content:\n"
        f"{joined_tables}\n\n"
        "Images are provided separately. Use them if relevant to the question."
    )
    parts.append({"type": "text", "text": instructions})

    return [HumanMessage(content=parts)]

## Step 10: Multimodal RAG Chain

Create a multimodal RAG chain with reranking.

In [None]:
def multi_modal_rag_chain(retriever, rerank_method='bm25'):
    """Build a retrieve -> rerank -> prompt -> LLM -> string chain.

    Args:
        retriever: Object whose ``invoke(query)`` returns Documents (e.g. MilvusRetriever).
        rerank_method: ``'bm25'`` (uses ``bm25_rerank``) or ``'mmr'`` (uses
            ``mmr_rerank`` with the module-level ``embedding_model``).

    Returns:
        A LangChain runnable; ``chain.invoke(question)`` returns the answer as a str.

    Raises:
        ValueError: for any other *rerank_method*, so a typo cannot silently select MMR.
    """
    if rerank_method not in ('bm25', 'mmr'):
        raise ValueError(f"Unsupported rerank_method {rerank_method!r}; expected 'bm25' or 'mmr'.")

    model = ChatOpenAI(temperature=0, model="gpt-4o", max_tokens=1024)

    def rerank_docs(inputs):
        """Reorder the retrieved context for the question with the chosen method."""
        docs = inputs['context']
        query = inputs['question']
        if rerank_method == 'bm25':
            reranked = bm25_rerank(query, docs)
        else:
            reranked = mmr_rerank(query, docs, embedding_model)
        return {'context': reranked, 'question': query}

    # The leading dict runs as a parallel step: the question is sent to the
    # retriever and also passed through unchanged for the prompt.
    chain = (
        {
            "context": retriever.invoke,
            "question": RunnablePassthrough()
        }
        | RunnableLambda(rerank_docs)
        | RunnableLambda(img_prompt_func)
        | model
        | StrOutputParser()
    )
    return chain

# Create RAG chains for each index type with BM25 and MMR: six combinations
# (3 retrievers x 2 rerank methods). Each call constructs its own ChatOpenAI
# model inside multi_modal_rag_chain.
flat_bm25_chain = multi_modal_rag_chain(flat_retriever, 'bm25')
flat_mmr_chain = multi_modal_rag_chain(flat_retriever, 'mmr')
hnsw_bm25_chain = multi_modal_rag_chain(hnsw_retriever, 'bm25')
hnsw_mmr_chain = multi_modal_rag_chain(hnsw_retriever, 'mmr')
ivf_bm25_chain = multi_modal_rag_chain(ivf_retriever, 'bm25')
ivf_mmr_chain = multi_modal_rag_chain(ivf_retriever, 'mmr')

## Step 11: Generate and Render Output to DOCX

Run the RAG chain and save the output to a DOCX file.

In [None]:
def _image_b64_to_png_stream(b64_data):
    """Decode base64 image data and return it re-encoded as PNG in a rewound in-memory stream."""
    # Encoding in memory avoids writing a temp file into the working directory,
    # and PNG (unlike JPEG) accepts alpha and palette images.
    with Image.open(io.BytesIO(base64.b64decode(b64_data))) as img:
        if img.mode not in ('1', 'L', 'LA', 'P', 'RGB', 'RGBA'):
            # Modes PNG cannot store (e.g. CMYK) are converted first.
            img = img.convert('RGBA' if 'A' in img.getbands() else 'RGB')
        stream = io.BytesIO()
        img.save(stream, format='PNG')
    stream.seek(0)
    return stream


def save_to_docx(query, response, docs, output_file='output.docx'):
    """Write the query, the model's response, and the retrieved documents to a DOCX file.

    Args:
        query: The user question.
        response: The generated answer text.
        docs: Retrieved Documents. ``metadata['content_type'] == 'image'`` items are
            rendered as 4-inch-wide pictures decoded from base64 ``page_content``;
            all others are rendered as paragraphs. Each is followed by a "Type:" line.
        output_file: Destination path of the .docx file.
    """
    doc = DocxDocument()
    doc.add_heading('Multimodal RAG Output', 0)

    doc.add_heading('Query', level=1)
    doc.add_paragraph(query)

    doc.add_heading('Response', level=1)
    doc.add_paragraph(response)

    doc.add_heading('Retrieved Documents', level=1)
    for i, d in enumerate(docs, 1):
        doc.add_heading(f'Document {i}', level=2)
        content_type = d.metadata['content_type']
        if content_type == 'image':
            doc.add_picture(_image_b64_to_png_stream(d.page_content), width=Inches(4))
        else:
            doc.add_paragraph(d.page_content)
        doc.add_paragraph(f"Type: {content_type}")

    doc.save(output_file)

# Test the RAG chain
query = "What is hypertension? Explain the first table in the PDF."
response = ivf_bm25_chain.invoke(query)
# ivf_bm25_chain BM25-reranks what ivf_retriever returns before prompting the
# model. Apply the same reranking here so the DOCX lists the documents in the
# order the model saw them.
docs = bm25_rerank(query, ivf_retriever.invoke(query))
save_to_docx(query, response, docs, 'output_hypertension.docx')