<img src="https://learnopencv.com/wp-content/uploads/2024/09/Feature-Multimodal-RAG-with-ColPali-Gemini.gif" alt="Multimodal RAG with ColPali and Gemini">

#### Installing Dependencies

In [None]:
!pip install pdf2image einops google-generativeai gradio -q
!pip install colpali-engine==0.2.2 -q
!pip install -U bitsandbytes -q
!pip install mteb transformers tqdm typer seaborn -q

In [None]:
# Run in terminals
sudo apt install poppler-utils

In [None]:
import os

import gradio as gr
import torch
from pdf2image import convert_from_path
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries

#### Load the ColPali Model from Hugging Face

In [None]:
# Load the ColPali retriever: PaliGemma base weights plus the vidore/colpali adapter.
import getpass  # not part of the imports cell, but required for the token prompt

model_name = "vidore/colpali"

# Reuse a token already present in the environment; otherwise prompt for it
# without echoing it into the notebook.
hf_token = os.environ.get("HF_TOKEN") or getpass.getpass("Enter HF API: ")
os.environ["HF_TOKEN"] = hf_token

model = ColPali.from_pretrained(
    "google/paligemma-3b-mix-448", torch_dtype=torch.bfloat16, device_map="cuda", token=hf_token
).eval()  # inference only
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name, token=hf_token)
device = model.device

### ColPali

1. Offline indexing with ColPali: convert each PDF page to an image and embed it.

In [None]:
def index(file, ds):
    """Convert uploaded PDFs to page images and embed every page with ColPali.

    Args:
        file: Uploaded PDFs from the Gradio ``gr.File`` component. This is a list
            of file paths, or of file-like objects exposing ``.name``. It is
            ``None`` or empty when nothing has been uploaded.
        ds: Previously stored page embeddings (Gradio state). The list is not
            extended. The returned images replace the image state, so the
            embeddings are rebuilt from scratch to stay aligned one-to-one with
            those images.

    Returns:
        A ``(status message, page embeddings, page images)`` tuple, where
        embedding ``i`` belongs to image ``i``.
    """
    if not file:
        return "No PDF files uploaded", [], []

    images = []
    for f in file:
        # Some Gradio versions hand over tempfile wrappers instead of plain paths.
        pdf_path = f if isinstance(f, (str, os.PathLike)) else f.name
        images.extend(convert_from_path(pdf_path))

    # Fresh list: appending to `ds` would misalign embeddings with `images`
    # once the index is rebuilt a second time.
    embeddings = []
    dataloader = DataLoader(
        images,
        batch_size=4,
        shuffle=False,
        collate_fn=lambda x: process_images(processor, x),
    )
    for batch_doc in tqdm(dataloader):
        with torch.no_grad():
            batch_doc = {k: v.to(device) for k, v in batch_doc.items()}
            embeddings_doc = model(**batch_doc)
        embeddings.extend(torch.unbind(embeddings_doc.to("cpu")))
    return f"Uploaded and converted {len(images)} pages", embeddings, images

2. Online querying with ColPali: embed the query and score it against every indexed page to find the best match.

In [None]:
def search(query: str, ds, images):
    """Return the indexed page that best matches ``query`` under ColPali scoring.

    Args:
        query: Free-text search query.
        ds: Page embeddings produced by ``index`` (one entry per page).
        images: Page images produced by ``index``, aligned with ``ds``.

    Returns:
        A ``(status message, best page image)`` tuple. The page number in the
        message is 1-based. The image is ``None`` when nothing has been indexed.
    """
    if not ds or not images:
        return "No documents indexed yet. Upload and convert PDFs first.", None

    # process_queries takes a placeholder image. Build it here so this cell does
    # not depend on a global that is only defined in a later cell.
    placeholder = Image.new("RGB", (448, 448), (255, 255, 255))
    with torch.no_grad():
        batch_query = process_queries(processor, [query], placeholder)
        batch_query = {k: v.to(device) for k, v in batch_query.items()}
        embeddings_query = model(**batch_query)
    qs = list(torch.unbind(embeddings_query.to("cpu")))

    # Multi-vector scoring of the query against every indexed page.
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)
    scores = retriever_evaluator.evaluate(qs, ds)
    best_page = int(scores.argmax(axis=1).item())  # 0-based index into `images`
    return f"The most relevant page is {best_page + 1}", images[best_page]

### Google Gemini API

In [None]:
import getpass

import google.generativeai as genai

# Deterministic decoding (temperature 0) with a plain-text response.
generation_config = {
    "temperature": 0.0,
    "top_p": 0.95,
    "top_k": 64,
    "max_output_tokens": 1024,
    "response_mime_type": "text/plain",
}

# Never hardcode API keys: anything written in the notebook is exposed to
# whoever sees the file. Read the key from the environment or prompt for it.
google_api_key = os.environ.get("GOOGLE_API_KEY") or getpass.getpass("Enter Google API key: ")
genai.configure(api_key=google_api_key)

gemini_flash = genai.GenerativeModel(model_name="gemini-1.5-flash", generation_config=generation_config)


def get_answer(prompt: str, image: Image):
    """Send ``prompt`` plus ``image`` to Gemini Flash and return the text reply.

    Uses ``gemini_flash``. The name ``model`` refers to the ColPali retriever,
    which has no ``generate_content`` method.
    """
    response = gemini_flash.generate_content([prompt, image])
    return response.text


gemini_flash

genai.GenerativeModel(
    model_name='models/gemini-1.5-flash',
    generation_config={'temperature': 0.0, 'top_p': 0.95, 'top_k': 64, 'max_output_tokens': 1024, 'response_mime_type': 'text/plain'},
    safety_settings={},
    tools=None,
    system_instruction=None,
    cached_content=None
)

### Gradio

In [None]:
COLORS = ["#4285f4", "#db4437", "#f4b400", "#0f9d58", "#e48ef1"]

# Blank placeholder image used when encoding text queries with ColPali.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))

with gr.Blocks() as demo:
    gr.Markdown("# ColPali: Efficient Document Retrieval with Vision Language Models 📚🔍")
    gr.Markdown("## 1️⃣ Upload PDFs")
    # Gradio expects file extensions with a leading dot (or a category like "image").
    file = gr.File(file_types=[".pdf"], file_count="multiple")

    gr.Markdown("## 2️⃣ Index the PDFs and upload")
    convert_button = gr.Button("🔄 Convert and upload")
    message = gr.Textbox("Files not yet uploaded")
    # Per-session state: page embeddings and the matching page images.
    embeds = gr.State(value=[])
    imgs = gr.State(value=[])

    # Define the actions for conversion
    convert_button.click(index, inputs=[file, embeds], outputs=[message, embeds, imgs])

    gr.Markdown("## 3️⃣ Search")
    query = gr.Textbox(placeholder="Enter your query to match")
    search_button = gr.Button("🔍 Search")

    gr.Markdown("## 4️⃣ ColPali Retrieval")
    message2 = gr.Textbox("Most relevant image is...")
    output_img = gr.Image()

    gr.Markdown("## 5️⃣ Gemini Response")
    output_text = gr.Textbox("Gemini Response...")

    def ask_gemini(prompt: str, image: Image):
        """Send ``prompt`` and the retrieved page image to Gemini Flash; return its text."""
        response = gemini_flash.generate_content([prompt, image])
        return response.text

    def search_with_llm(query, ds, images, prompt="What is shown in this image, analyse and provide some interpretation? Format the answer in a neat 500 words summary."):
        """Retrieve the best-matching page for ``query`` and have Gemini interpret it.

        Returns ``(retrieval message, page image, Gemini answer)``, matching the
        three output components wired to the search button.
        """
        # Step 1: Search the best image based on query
        search_message, best_image = search(query, ds, images)
        if best_image is None:
            # Nothing to send to the LLM (e.g. no PDFs indexed yet).
            return search_message, None, ""

        # Step 2: Generate an answer using LLM
        answer = ask_gemini(prompt, best_image)

        return search_message, best_image, answer

    # Action for search button
    search_button.click(
        search_with_llm,
        inputs=[query, embeds, imgs],
        outputs=[message2, output_img, output_text]
    )

if __name__ == "__main__":
    # share=True publishes a temporary public URL; drop it to keep the demo local.
    demo.queue(max_size=10).launch(debug=True, share=True)

Running on local URL:  http://127.0.0.1:7860
Running on public URL: https://53354a1cb47d90d401.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/gradio/queueing.py", line 536, in process_events
    response = await route_utils.call_process_api(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/gradio/route_utils.py", line 321, in call_process_api
    output = await app.get_blocks().process_api(
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/gradio/blocks.py", line 1935, in process_api
    result = await self.call_function(
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/gradio/blocks.py", line 1520, in call_function
    prediction = await anyio.to_thread.run_sync(  # type: ignore
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/anyio/to_thread.py", line 56, in run_sync
    return await get_async_backend().run_sync_in_worker_thread(
           ^^^^^

tensor([50])
Top 1 Accuracy (verif): 0.0
tensor([23])
Top 1 Accuracy (verif): 0.0
tensor([23])
Top 1 Accuracy (verif): 0.0
tensor([55])
Top 1 Accuracy (verif): 0.0
tensor([55])
Top 1 Accuracy (verif): 0.0
tensor([25])
Top 1 Accuracy (verif): 0.0
