Function to extract text and images from pdf file.
The extracted texts will also be already chunked.

In [40]:
from pdf2image import convert_from_path
from PIL import Image
import pytesseract
import os
from langchain.text_splitter import RecursiveCharacterTextSplitter

def extract_content(pdf_path, output_dir, resize_width=None, resize_height=None):
    # Ensure the output directory exists
    os.makedirs(output_dir, exist_ok=True)

    extracted_data = {
        "text": [],
        "images": []
    }

    # Convert PDF pages to images
    pages = convert_from_path(pdf_path, dpi=300)

    # Initialize RecursiveCharacterTextSplitter for text chunking
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=400,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False
    )

    # Process each page
    for i, page in enumerate(pages):
        # Save the page as an image
        image_path = os.path.join(output_dir, f"page_{i + 1}.jpg")

        # Resize the image if dimensions are provided
        if resize_width or resize_height:
            # Keep aspect ratio if only one dimension is provided
            if not resize_width:
                aspect_ratio = page.width / page.height
                resize_width = int(resize_height * aspect_ratio)
            elif not resize_height:
                aspect_ratio = page.height / page.width
                resize_height = int(resize_width / aspect_ratio)

            # Resize the image
            page = page.resize((resize_width, resize_height), Image.ANTIALIAS)

        # Save the image file
        page.save(image_path, "JPEG")
        extracted_data["images"].append(image_path)

        # Extract text using pytesseract OCR
        raw_text = pytesseract.image_to_string(Image.open(image_path))
        chunked_text = text_splitter.split_text(raw_text)
        extracted_data["text"].extend(chunked_text)

    return extracted_data

Function to generate embeddings for images and texts
This is separated to check if there is a huge difference in speed when generating the text and image embeddings.
Additionally, the generate_text_embeddings accepts list of texts, while the generate_image_embeddings accepts file path.

In [11]:
import torch
from colpali_engine.models import ColQwen2, ColQwen2Processor
from colpali_engine.compression.token_pooling import HierarchicalTokenPooler
from transformers.utils.import_utils import is_flash_attn_2_available
import numpy as np

def generate_text_embeddings(text_list, model_name="vidore/colqwen2-v1.0", pool_factor=3):
    """
    Generate pooled embeddings for text chunks using the ColQwen2 model on GPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map={"": device},
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval().to(device)
    processor = ColQwen2Processor.from_pretrained(model_name, use_fast=True)

    batch = processor.process_queries(text_list).to(model.device)
    
    with torch.no_grad():
        embeddings = model(**batch).cpu().float().numpy()

    pooled_embeddings = np.mean(embeddings, axis=1)
    pooled_embeddings = pooled_embeddings.tolist()
    
    return pooled_embeddings

from PIL import Image

def generate_image_embeddings(image_paths, model_name="vidore/colqwen2-v1.0"):
    """
    Generate pooled embeddings for image paths using the ColQwen2 model on GPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map={"": device},
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None,
    ).eval().to(device)
    processor = ColQwen2Processor.from_pretrained(model_name, use_fast=True)

    images = [Image.open(image_path).convert("RGB") for image_path in image_paths]
    processed_images = processor.process_images(images).to(device)

    with torch.no_grad():
        embeddings = model(**processed_images).cpu().float().numpy()
        
    pooled_embeddings = np.mean(embeddings, axis=1)
    pooled_embeddings = pooled_embeddings.tolist()  

    return pooled_embeddings

In [42]:
import psycopg2

def create_table(db_config):
    """
    Create a single table to store metadata and embeddings.

    Args:
        db_config (dict): Database configuration with keys 'host', 'database', 'user', 'password'.
    """
    connection = None  # Initialize the connection
    try:
        # Connect to the PostgreSQL database
        connection = psycopg2.connect(**db_config)
        cursor = connection.cursor()

        # Enable pgvector extension
        cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")

        # Create a combined table for metadata and embeddings
        cursor.execute("""
        CREATE TABLE embeddings (
                id SERIAL PRIMARY KEY,       -- Unique ID for each entry
                content TEXT,                -- Text or image link
                is_image BOOLEAN,            -- Indicates if the content is an image (True for images, False for text)
                embedding VECTOR(128)        -- The pooled embedding (1D vector of 128 dimensions)
            );
        """)

        # Commit the changes
        connection.commit()
        print("Table successfully created!")

    except Exception as e:
        print(f"Error creating table: {e}")

    finally:
        if connection:
            cursor.close()
            connection.close()

In [43]:
def load_text_embeddings(db_config, text_chunks):
    """
    Insert text embeddings into the database with is_image set to False.

    Args:
        db_config (dict): Database configuration with keys 'host', 'database', 'user', 'password'.
        text_chunks (list): List of text chunks to embed and insert into the database.
    """
    connection = None
    try:
        import psycopg2

        # Generate text embeddings
        embeddings = generate_text_embeddings(text_chunks)

        # Connect to the database
        connection = psycopg2.connect(**db_config)
        cursor = connection.cursor()

        # Insert text embeddings into the database
        for text, embedding in zip(text_chunks, embeddings):
            cursor.execute(
                "INSERT INTO embeddings (content, is_image, embedding) VALUES (%s, %s, %s);",
                (text, False, embedding)
            )

        # Commit the transaction
        connection.commit()
        print("Inserted text embeddings successfully!")

    except Exception as e:
        print(f"Error inserting text embeddings: {e}")

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()

In [44]:
def load_image_embeddings(db_config, image_paths):
    """
    Insert image embeddings into the database with is_image set to True.

    Args:
        db_config (dict): Database configuration with keys 'host', 'database', 'user', 'password'.
        image_paths (list): List of image file paths to embed and insert into the database.
    """
    connection = None
    try:
        import psycopg2

        # Generate image embeddings
        embeddings = generate_image_embeddings(image_paths)

        # Connect to the database
        connection = psycopg2.connect(**db_config)
        cursor = connection.cursor()

        # Insert image embeddings into the database
        for image_path, embedding in zip(image_paths, embeddings):
            cursor.execute(
                "INSERT INTO embeddings (content, is_image, embedding) VALUES (%s, %s, %s);",
                (image_path, True, embedding)
            )

        # Commit the transaction
        connection.commit()
        print("Inserted image embeddings successfully!")

    except Exception as e:
        print(f"Error inserting image embeddings: {e}")

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()

In [14]:
def retrieve_similar_content(db_config, query_embedding, top_k=5, include_images="False"):
    """
    Retrieve the top-K most similar content based on query embedding,
    with three modes for filtering: text only, images only, or both.

    Args:
        db_config (dict): Database configuration with keys 'host', 'database', 'user', 'password'.
        query_embedding (list): Query embedding (1D vector of 128 dimensions).
        top_k (int): The number of most similar content items to retrieve for each mode.
        include_images (str): Mode to filter results - "False", "True", or "Both".

    Returns:
        list: A list of dictionaries containing content, distance, and is_image flag.
    """
    connection = None
    cursor = None  # Initialize cursor to ensure it exists in case of exceptions
    try:
        import psycopg2

        # Convert numpy embedding to list if needed
        if isinstance(query_embedding, np.ndarray):
            query_embedding = query_embedding.tolist()

        # Connect to the database
        connection = psycopg2.connect(**db_config)
        cursor = connection.cursor()

        # Initialize result list
        similar_content = []

        # Mode: Text Only (include_images=False)
        if include_images == "False":
            retrieve_query = """
            SELECT content, is_image, embedding <-> %s::vector AS distance
            FROM embeddings
            WHERE is_image = False
            ORDER BY distance ASC
            LIMIT %s;
            """
            cursor.execute(retrieve_query, (query_embedding, top_k))
            results = cursor.fetchall()

            for content, is_image, distance in results:
                similar_content.append({
                    "content": content,
                    "is_image": is_image,
                    "distance": round(distance, 4)
                })

        # Mode: Images Only (include_images=True)
        elif include_images == "True":
            retrieve_query = """
            SELECT content, is_image, embedding <-> %s::vector AS distance
            FROM embeddings
            WHERE is_image = True
            ORDER BY distance ASC
            LIMIT %s;
            """
            cursor.execute(retrieve_query, (query_embedding, top_k))
            results = cursor.fetchall()

            for content, is_image, distance in results:
                similar_content.append({
                    "content": content,
                    "is_image": is_image,
                    "distance": round(distance, 4)
                })

        # Mode: Both (include_images="Both")
        elif include_images == "Both":
            # Retrieve top-k text results
            retrieve_text_query = """
            SELECT content, is_image, embedding <-> %s::vector AS distance
            FROM embeddings
            WHERE is_image = False
            ORDER BY distance ASC
            LIMIT %s;
            """
            cursor.execute(retrieve_text_query, (query_embedding, top_k))
            text_results = cursor.fetchall()

            for content, is_image, distance in text_results:
                similar_content.append({
                    "content": content,
                    "is_image": is_image,
                    "distance": round(distance, 4)
                })

            # Retrieve top-k image results
            retrieve_image_query = """
            SELECT content, is_image, embedding <-> %s::vector AS distance
            FROM embeddings
            WHERE is_image = True
            ORDER BY distance ASC
            LIMIT %s;
            """
            cursor.execute(retrieve_image_query, (query_embedding, top_k))
            image_results = cursor.fetchall()

            for content, is_image, distance in image_results:
                similar_content.append({
                    "content": content,
                    "is_image": is_image,
                    "distance": round(distance, 4)
                })

        return similar_content

    except Exception as e:
        print(f"Error retrieving similar content: {e}")
        return []

    finally:
        if cursor:
            cursor.close()
        if connection:
            connection.close()

In [56]:
def create_index(db_config):
    
    try:
        conn = psycopg2.connect(**db_config)
        cursor = conn.cursor()

        create_index_query = """
        CREATE INDEX hnsw_index
        ON embeddings
        USING hnsw (embedding vector_l2_ops)
        WITH (m = 16, ef_construction = 200);
        """

        cursor.execute(create_index_query)
        conn.commit()

        print("HNSW index created successfully!")

    except Exception as e:
        print("Error:", e)

    finally:
        if cursor:
            cursor.close()
        if conn:
            conn.close()


In [9]:
def create_prompt(query, db_config, top_k=5):
    """
    Create a prompt for the LLM based on the query by generating embeddings and performing similarity searches.

    Args:
        query (str): The user's query.
        db_config (dict): Database configuration with keys 'host', 'database', 'user', 'password'.
        top_k (int): The number of most similar content items to retrieve for each type.

    Returns:
        str: The formatted prompt for the LLM.
    """
    # Step 1: Generate query embeddings
    query_embedding = generate_text_embeddings([query])[0]  # Generate embedding for the query text

    # Step 2: Retrieve similar content
    retrieved_content = retrieve_similar_content(db_config, query_embedding, top_k=top_k, include_images="Both")

    # Step 3: Format the retrieved content into a prompt
    text_matches = [item for item in retrieved_content if not item["is_image"]]
    image_matches = [item for item in retrieved_content if item["is_image"]]

    prompt = f"User query: {query}\n\n"

    prompt += "Here are the most relevant text results:\n"
    for text in text_matches:
        prompt += f"- {str(text['content'])} (Similarity Score: {text['distance']})\n"

    prompt += "\nHere are the most relevant images:\n"
    for image in image_matches:
        prompt += f"- {str(image['content'])} (Similarity Score: {image['distance']})\n"

    prompt += "\nGenerate a response based on the query and the above context."
    return prompt

In [22]:
def retrieve_image(query, db_config, top_k = 1):
    query_embedding = generate_text_embeddings([query])[0]
    retrieved_content = retrieve_similar_content(db_config, query_embedding, top_k=top_k, include_images="True")

    return retrieved_content

In [8]:
from ollama import chat
from ollama import ChatResponse

def generate_llm_response(query, db_config, model="gemma3:4b", top_k=5):
    """
    Generate a response using the LLM based on the query by calling create_prompt.

    Args:
        query (str): The user's query.
        db_config (dict): Database configuration with keys 'host', 'database', 'user', 'password'.
        model (str): The model name to use via Ollama (default: gemma3:4b).
        top_k (int): The number of most similar content items to retrieve for each type.

    Returns:
        str: The response generated by the LLM.
    """
    try:
        # Step 1: Create the prompt
        prompt = create_prompt(query, db_config, top_k=top_k)

        # Step 2: Send the prompt to Ollama and retrieve the response
        response: ChatResponse = chat(model=model, messages=[
            {
                'role': 'user',
                'content': prompt,
            }
        ])

        # Step 3: Return the generated response
        return response.message.content

    except Exception as e:
        return f"Error generating response: {str(e)}"

Don't edit the code below this part is for testing the functions

In [45]:
data = extract_content("AMD Q4'24 Earnings Slides.pdf", "C:/Users/raimo/python/multimodal_rag/output_images")

In [46]:
from dotenv import dotenv_values

db_config = dotenv_values('db_config.env')
create_table(db_config)

Table successfully created!


In [47]:
load_text_embeddings(db_config, data['text'])

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Inserted text embeddings successfully!


In [49]:
load_image_embeddings(db_config, data['images'])

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Inserted image embeddings successfully!


In [50]:
sample_query = ["What is the earnings per share of the company for year 2024?"]
sample_query_embeddings = generate_text_embeddings(sample_query)[0]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [58]:
retrieve_similar_content(db_config, sample_query_embeddings, top_k=5, include_images="False")

[{'content': 'EARNINGS PER SHARE 4 2024\n\nGAAP\n\nNon-GAAP"\n\n$1.09\n$0.77\n$0.41\nQ4 2023 Q4 2024 Q4 2023 Q4 2024\n\n= GAAP net income of $482 million = Record non-GAAP net income of $1.8 billion\n= GAAP EPS down 29% y/y, primarily driven by higher = Non-GAAP EPS up 42% y/y, primarily driven by higher\n\nrevenue and gross margin, more than offset by higher revenue and gross margin, partially offset by higher',
  'is_image': False,
  'distance': 0.4738},
 {'content': 'EARNINGS PER SHARE FY 2024\n\nGAAP\n\n$1.00\n$0.53\nFY 2023 FY 2024\n= GAAP net income of $1.6 billion, up 92% y/y "\n= GAAP EPS of $1.00, up 89% y/y, primarily driven 8\n\nby higher revenue and gross margin, and lower\namortization of acquisition-related intangible\nassets, partially offset by higher operating\nexpenses and a one-time tax provision\n\nNon-GAAP’\n$3.31\n\n$2.65\n\nFY 2023 FY 2024',
  'is_image': False,
  'distance': 0.4782},
 {'content': 'Non GAAP net income / earnings per share $1,777 $ 1.09 $1,249 $ 0

In [57]:
create_index(db_config)

HNSW index created successfully!


In [16]:
gradio_interface('What are the growth opportunities?')

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

'Based on the provided text results, AMD has significant growth opportunities across diverse markets. The presentation specifically highlights AMD’s expected first quarter 2025 financial outlook, including revenue, non-GAAP gross margin, and operating expenses. However, it’s important to note that these forward-looking statements are based on current expectations as of February 4, 2025, and are subject to potential changes due to factors like competition with Intel. The presentation also details forward-looking non-GAAP measures concerning AMD’s financial outlook.'

In [None]:
import gradio as gr
from dotenv import dotenv_values

db_config = dotenv_values('db_config.env')

import gradio as gr

# Gradio wrapper for query handling
def gradio_interface(query):
    try:
        # Generate response using Gemma3:4b model
        response = generate_llm_response(query, db_config, model="gemma3:4b", top_k=5)
        return response
    except Exception as e:
        return f"Error: {str(e)}"

# Create the enhanced Gradio UI
with gr.Blocks() as demo:
    # Title Section
    gr.Markdown(
        """
        ## Multimodal RAG
        Experience the power of multimodal retrieval and response generation using cutting-edge AI technology. 
        Enter your query below and receive contextually relevant insights based on text and image data.
        """,
        elem_classes="title-section"
    )
    
    # Query Input Section
    with gr.Row(elem_classes="input-section"):
        query_input = gr.Textbox(
            label="Your Query", 
            placeholder="Type your question here...",
            lines=2,
            elem_classes="query-input"
        )

    # Response Output Section
    response_output = gr.Textbox(
        label="Response", 
        placeholder="Generated response will appear here.",
        lines=8,
        interactive=False,
        elem_classes="response-output"
    )

    # Submit Button Section
    submit_button = gr.Button(
        "Generate Response", 
        elem_classes="submit-button"
    )

    # Link the input and output with the function
    submit_button.click(fn=gradio_interface, inputs=query_input, outputs=response_output)

# Custom CSS for styling
demo.css = """
.title-section {
    font-size: 1.5em;
    color: #1e293b;
    text-align: center;
    margin-bottom: 20px;
}

.input-section {
    margin-bottom: 20px;
}

.query-input {
    width: 100%;
    padding: 10px;
    border-radius: 5px;
    border: 1px solid #1e293b;
}

.response-output {
    background-color: #f1f5f9;
    padding: 10px;
    border-radius: 5px;
    border: 1px solid #1e293b;
}

.submit-button {
    background-color: #2563eb;
    color: #ffffff;
    border-radius: 5px;
    padding: 10px 20px;
    font-size: 1em;
    cursor: pointer;
    text-transform: uppercase;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}

.submit-button:hover {
    background-color: #1d4ed8;
}
"""

# Launch the enhanced Gradio UI
demo.launch()

In [30]:
import gradio as gr

# Wrapper function for handling query
def wrapper_function(query):
    db_config = db_config
    
    try:
        # Text Response
        response_text = generate_llm_response(query, db_config, model="gemma3:4b", top_k=5)
        
        # Image Retrieval
        query_embedding = generate_text_embeddings([query])[0]
        retrieved_images = retrieve_similar_content(db_config, query_embedding, top_k=1, include_images="True")
        closest_image_path = retrieved_images[0]["content"] if retrieved_images else None

        return response_text, closest_image_path
    except Exception as e:
        return f"Error: {str(e)}", None

In [19]:
import gradio as gr
from dotenv import dotenv_values

# Load database configuration
db_config = dotenv_values('db_config.env')

# Wrapper function for handling query
def wrapper_function(query):
    try:
        # Text Response
        response_text = generate_llm_response(query, db_config, model="gemma3:4b", top_k=5)
        
        # Image Retrieval
        query_embedding = generate_text_embeddings([query])[0]
        retrieved_images = retrieve_similar_content(db_config, query_embedding, top_k=1, include_images="True")
        closest_image_path = retrieved_images[0]["content"] if retrieved_images else None

        return response_text, closest_image_path
    except Exception as e:
        return f"Error: {str(e)}", None

# Create the enhanced Gradio UI
with gr.Blocks() as demo:
    # Title Section
    gr.Markdown(
        """
        ## Multimodal RAG
        Experience the power of multimodal retrieval and response generation using cutting-edge AI technology. 
        Upload a PDF, enter your query, and receive contextually relevant insights based on text and image data.
        """,
        elem_classes="title-section"
    )

    # File Upload Section
    with gr.Tab("Upload PDF"):
        pdf_input = gr.File(label="Upload PDF", file_types=[".pdf"])
        extracted_images = gr.Gallery(label="Extracted Images").style(grid=[3], height="auto")

    # Query Input Section
    with gr.Tab("Query"):
        query_input = gr.Textbox(
            label="Your Query", 
            placeholder="Type your question here...",
            lines=2,
            elem_classes="query-input"
        )
        response_output = gr.Textbox(
            label="Generated Text Response", 
            placeholder="Generated response will appear here.",
            lines=8,
            interactive=False,
            elem_classes="response-output"
        )
        image_output = gr.Image(label="Closest Matching Image", type="filepath")
        query_history = gr.Textbox(
            label="Query History", 
            interactive=False, 
            lines=10, 
            placeholder="Your past queries will appear here."
        )

    # Submit Button Section
    submit_button = gr.Button(
        "Generate Results", 
        elem_classes="submit-button"
    )

    # Feedback Section
    feedback = gr.Textbox(label="Feedback", placeholder="Provide your feedback here...")
    submit_feedback = gr.Button("Submit Feedback")

    # Link the input and outputs to the wrapper function
    submit_button.click(fn=wrapper_function, inputs=query_input, outputs=[response_output, image_output])

# Custom CSS for styling
demo.css = """
.title-section {
    font-size: 1.5em;
    color: #1e293b;
    text-align: center;
    margin-bottom: 20px;
}

.input-section {
    margin-bottom: 20px;
}

.query-input {
    width: 100%;
    padding: 10px;
    border-radius: 5px;
    border: 1px solid #1e293b;
}

.response-output {
    background-color: #f1f5f9;
    padding: 10px;
    border-radius: 5px;
    border: 1px solid #1e293b;
}

.submit-button {
    background-color: #2563eb;
    color: #ffffff;
    border-radius: 5px;
    padding: 10px 20px;
    font-size: 1em;
    cursor: pointer;
    text-transform: uppercase;
    box-shadow: 0px 4px 6px rgba(0, 0, 0, 0.1);
}

.submit-button:hover {
    background-color: #1d4ed8;
}
"""

# Launch the enhanced Gradio UI
demo.launch()

AttributeError: 'Gallery' object has no attribute 'style'