In [16]:
import torch
import typer
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image

from colpali_engine.models.paligemma_colbert_architecture import ColPali
from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
from colpali_engine.utils.image_from_page_utils import load_from_dataset


# Configuration — the values a reader is most likely to tune.
BASE_MODEL_NAME = "vidore/colpaligemma-3b-mix-448-base"  # base weights the adapter is applied to
model_name = "vidore/colpali"  # adapter + processor repo (name kept: later cells reference it)
# Looks like a Hugging Face Hub dataset id (loaded with a `split=`), not a local path.
DATASET_NAME = "vidore/docvqa_test_subsampled"
# Use "test[:16]" instead of "test" to run on a small subset while iterating.
DATASET_SPLIT = "test"
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load model: base weights first, then the ColPali adapter on top, in eval mode.
# NOTE(review): loading prints a warning that `model.language_model.lm_head.weight`
# is newly initialized; presumably harmless if only the embedding outputs are used — confirm.
model = ColPali.from_pretrained(
    BASE_MODEL_NAME,
    torch_dtype=torch.bfloat16,
    device_map=DEVICE,
).eval()
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name)

# Load the documents (images) and the text queries to retrieve them with.
images = load_from_dataset(DATASET_NAME, split=DATASET_SPLIT)
queries = [
    "From which university does James V. Fiorca come?",
    "Who is the Japanese prime minister?",
]

Loading checkpoint shards: 100%|██████████| 2/2 [00:05<00:00,  2.57s/it]
Some weights of ColPali were not initialized from the model checkpoint at vidore/colpaligemma-3b-mix-448-base and are newly initialized: ['model.language_model.lm_head.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [17]:
def _embed_batches(dataloader, desc):
    """Run `model` over every batch of `dataloader` and return per-item embeddings.

    Args:
        dataloader: yields dicts of tensors (the output of a ColPali processing
            helper); each dict is moved to `model.device` and passed to `model`
            as keyword arguments.
        desc: label shown on the tqdm progress bar.

    Returns:
        A list with one CPU tensor per input item, in dataloader order
        (the batch dimension is split off with `torch.unbind`).
    """
    embeddings = []
    with torch.no_grad():  # inference only — no autograd graph needed
        for batch in tqdm(dataloader, desc=desc):
            batch = {k: v.to(model.device) for k, v in batch.items()}
            embeddings.extend(torch.unbind(model(**batch).to("cpu")))
    return embeddings


# Run inference for documents.
dataloader_docs = DataLoader(
    images,
    batch_size=4,
    shuffle=False,
    collate_fn=lambda x: process_images(processor, x),
)
document_embeddings = _embed_batches(dataloader_docs, "Processing documents")

# Run inference for queries. `process_queries` takes a placeholder image (a blank
# white 448x448 here); build it once instead of on every batch.
mock_image = Image.new("RGB", (448, 448), (255, 255, 255))
dataloader_queries = DataLoader(
    queries,
    batch_size=8,
    shuffle=False,
    collate_fn=lambda x: process_queries(processor, x, mock_image),
)
query_embeddings = _embed_batches(dataloader_queries, "Processing queries")

# Run evaluation on the precomputed embeddings. A second copy of the 3B model
# (same base + same adapter) is not needed here, and reloading it doubled GPU memory.
# NOTE(review): upstream colpali_engine's CustomEvaluator takes only `is_multi_vector`;
# if your local version really requires `retriever=`, pass the existing `model`.
retriever_evaluator = CustomEvaluator(is_multi_vector=True)
scores = retriever_evaluator.evaluate(query_embeddings, document_embeddings)
print("Best matching document indices:", scores.argmax(axis=1))

Processing documents:   0%|          | 0/125 [00:00<?, ?it/s]You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Processing documents:   1%|          | 1/125 [00:17<35:42, 17.28s/it]You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Processing documents:   2%|▏         | 2/125 [00:34<35:30, 17.32s/it]You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many token

KeyboardInterrupt: 

In [14]:
# Inspect the raw similarity matrix produced by the evaluation cell (rows appear to
# follow `queries`; that cell takes argmax over axis 1 to pick a document per query).
# NOTE(review): this cell ran as In[14], before the model/evaluation cells (In[16]/[17]),
# and its saved 2x16 output suggests an earlier run on a 16-document subset. The output
# is stale, and on a fresh kernel this cell fails unless the evaluation cell has completed.
scores

array([[12.    , 12.5625, 13.3125, 12.5   , 13.    , 12.375 , 12.625 ,
        12.625 , 12.125 , 12.1875, 12.3125, 12.125 , 12.125 , 12.3125,
        12.625 , 12.625 ],
       [12.    , 13.3125, 11.6875, 12.5625, 12.75  , 12.    , 12.3125,
        12.4375, 11.75  , 12.1875, 11.875 , 12.    , 12.    , 12.0625,
        12.1875, 12.125 ]], dtype=float32)

In [None]:
# Preview the document at index 1 of the loaded dataset.
# NOTE(review): presumably a PIL image that Jupyter renders inline — confirm the
# element type returned by `load_from_dataset`.
images[1]


In [None]:
# Scratch data: two sample records (id / title / content) written in Russian.
# Titles: "Travel across Russia" and "Technologies of the future".
# Contents: Russia as the largest country with a rich history and culture; and
# AI / machine learning becoming important parts of everyday life.
# NOTE(review): this literal is never assigned or referenced and is unrelated to the
# ColPali retrieval pipeline above — remove it or explain its purpose.
[
    {
        "id": 1,
        "title": "Путешествие по России",
        "content": "Россия — самая большая страна в мире, с богатой историей и культурой. Путешествие по России может быть увлекательным и познавательным."
    },
    {
        "id": 2,
        "title": "Технологии будущего",
        "content": "Технологии развиваются с невероятной скоростью. Искусственный интеллект и машинное обучение становятся важными аспектами нашей жизни."
    }
]