## Importing the required libraries

In [1]:
from colpali_engine.models import ColQwen2, ColQwen2Processor
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image

## Loading the model and processor


In [20]:
import torch

# Pick the compute backend in order of preference: CUDA, then MPS, then CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

In [None]:
# Identifier passed to both from_pretrained calls below
model_name = "vidore/colqwen2-v0.1"

# Load the model in bfloat16, mapped to the device selected earlier;
# downloaded files are cached under ./model_cache
model = ColQwen2.from_pretrained(
    pretrained_model_name_or_path=model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
    cache_dir="./model_cache",
)

# Load the matching processor from the same identifier and cache directory
processor = ColQwen2Processor.from_pretrained(
    pretrained_model_name_or_path=model_name,
    cache_dir="./model_cache",
)

In [None]:
# Put the model in evaluation mode. eval() returns the module itself, so
# leaving the call as the last expression still displays the model.
model.eval()

## Downloading the dataset

In [None]:
import os
import requests

# URL of the PDF to download
url = "https://reseauactionclimat.org/wp-content/uploads/2018/04/powerpoint-final-kit.pdf"

# Set the filename and filepath
filename = "test.pdf"
filepath = os.path.join("data", filename)

# Create the data directory if it doesn't exist
os.makedirs("data", exist_ok=True)

# Download the file. requests applies no timeout by default, so an
# unresponsive server would block this cell forever; fail after 60 seconds.
response = requests.get(url, timeout=60)
if response.status_code == 200:
    with open(filepath, 'wb') as file:
        file.write(response.content)
    print(f"File downloaded successfully: {filepath}")
else:
    print(f"Failed to download the file. Status code: {response.status_code}")

In [25]:
# Use the local file when it is present; otherwise keep the path that is
# already set, so the notebook still runs on a machine without this file.
local_filepath = "data/lec_04.pdf"
if os.path.exists(local_filepath):
    filepath = local_filepath
else:
    print(f"{local_filepath} not found, keeping: {filepath}")

## Converting PDF to Images


In [26]:
import base64
from io import BytesIO
import pymupdf
from tqdm import tqdm
from PIL import Image


# Define the function to process each page of the PDF
def process_page_images(page, page_num, base_dir):
    """Render one PDF page, save it as a JPEG and return it as a PIL image.

    Args:
        page: Object exposing ``get_pixmap()`` (a pymupdf page in this notebook).
        page_num: Page index, used to build the ``page_NNN.jpeg`` file name.
        base_dir: Directory the JPEG is written to; it must already exist.

    Returns:
        A PIL image opened from an in-memory copy of the saved JPEG, so the
        image does not depend on the file remaining on disk.
    """
    # Create a pixmap from the PDF page
    pix = page.get_pixmap()

    # Save the pixmap as a JPEG image
    page_path = os.path.join(base_dir, f"page_{page_num:03d}.jpeg")
    pix.save(page_path)

    # Read the encoded bytes back once. The previous base64 encode/decode
    # round trip produced exactly these bytes, so it is dropped.
    with open(page_path, 'rb') as file:
        image_bytes = file.read()

    return Image.open(BytesIO(image_bytes))

In [None]:
output_dir = "data/processed_page_images"

images = []

# Make sure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Open the PDF in a context manager so the document is closed once all
# pages have been processed (it was previously left open).
with pymupdf.open(filepath) as doc:
    num_pages = len(doc)

    # Process each page of the PDF
    for page_num in tqdm(range(num_pages), desc="Processing PDF pages"):
        page = doc[page_num]
        image = process_page_images(page, page_num, output_dir)
        images.append(image)

In [None]:
# Batch the page images two at a time; the processor converts each batch
# into model inputs
dataloader = DataLoader(
    dataset=images,
    batch_size=2,
    shuffle=False,
    collate_fn=processor.process_images,
)

# Embed every batch and collect the per-image embeddings on the CPU
ds = []
for batch_doc in tqdm(dataloader):
    model_inputs = {name: tensor.to(model.device) for name, tensor in batch_doc.items()}
    with torch.no_grad():
        embeddings_doc = model(**model_inputs)
    # unbind splits along the first (batch) dimension
    ds.extend(embeddings_doc.to("cpu").unbind())


## Retrieval 

In [9]:
def get_results(query: str, top_k: int = 5) -> list:
    """Return the indices of the best-scoring entries of ``ds`` for a query.

    Args:
        query: Text query to embed and score against the stored embeddings.
        top_k: Maximum number of indices to return (default 5).

    Returns:
        Indices into ``ds`` ordered from highest to lowest score. Fewer than
        ``top_k`` indices are returned when ``ds`` holds fewer entries.
    """
    batch_queries = processor.process_queries([query]).to(model.device)

    # Forward pass
    with torch.no_grad():
        query_embeddings = model(**batch_queries)

    scores = processor.score_multi_vector(query_embeddings, ds)
    query_scores = scores[0]

    # topk raises when k exceeds the number of scored entries (e.g. a PDF
    # with fewer than five pages), so clamp k to what is available.
    k = min(top_k, query_scores.shape[-1])
    return query_scores.topk(k).indices.tolist()

In [None]:
# Display the top result, resized while keeping its aspect ratio

query = "What animals are in danger with climate change ?"

# get_results returns indices ordered best-first; keep the top one and
# look up the corresponding page image
ranked_indices = get_results(query)
idx = ranked_indices[0]
im = images[idx]

def display_resize(im):
    """Display `im` rescaled so its width is about 1024 pixels, keeping the aspect ratio."""
    width, height = im.size
    shrink_factor = width / 1024
    target_size = (int(width / shrink_factor), int(height / shrink_factor))
    display(im.resize(target_size))

display_resize(im)

## Interpretability

In [None]:
from colpali_engine.interpretability import (
    get_similarity_maps_from_embeddings,
    plot_all_similarity_maps,
    plot_similarity_map,
)


# Get the number of image patches
n_patches = processor.get_n_patches(
    image_size=im.size,
    patch_size=model.patch_size,
    spatial_merge_size=model.spatial_merge_size,
)


# Get the tensor mask to filter out the embeddings that are not related to the image
image_mask = processor.get_image_mask(processor.process_images([im]))

# Reuse the query that selected `im`, rather than a second hard-coded copy
# of the string, so the maps always describe the retrieved page.
batch_queries = processor.process_queries([query]).to(model.device)

# Inference only: run the forward pass without tracking gradients, as the
# retrieval step does.
with torch.no_grad():
    query_embeddings = model(**batch_queries)

# Generate the similarity maps
batched_similarity_maps = get_similarity_maps_from_embeddings(
    image_embeddings=ds[idx].unsqueeze(0).to(model.device),
    query_embeddings=query_embeddings,
    n_patches=n_patches,
    image_mask=image_mask,
)

# Rebuild the query text without padding / augmentation tokens, then split
# it into tokens used to label the plot
query_content = processor.decode(batch_queries.input_ids[0]).replace(processor.tokenizer.pad_token, "")
query_content = query_content.replace(processor.query_augmentation_token, "").strip()
query_tokens = processor.tokenizer.tokenize(query_content)

# Get the similarity map for our (only) input image
similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)


token_idx = 3  # zero-based index, i.e. the fourth token

fig, ax = plot_similarity_map(
    image=im,
    similarity_map=similarity_maps[token_idx],
    figsize=(8, 8),
    show_colorbar=False,
)

max_sim_score = similarity_maps[token_idx, :, :].max().item()
ax.set_title(f"Token #{token_idx}: `{query_tokens[token_idx].replace('Ġ', '_')}`. MaxSim score: {max_sim_score:.2f}", fontsize=14)

  

## Cleanup memory

In [17]:
import shutil

def cleanup_memory():
    """Release notebook memory and remove the files this notebook created.

    Deletes a fixed list of global variables (when defined), runs garbage
    collection, empties the CUDA and MPS caches when those backends are
    available, and removes the ``data/processed_page_images`` and
    ``model_cache`` directories from disk.
    """
    import gc

    variables_to_clean = [
        'query_content',
        'query_tokens',
        'batch_queries',
        'batched_similarity_maps',
        'similarity_maps',
        'image_mask',
        'n_patches',
        'im'
    ]

    # Delete variables if they exist in global scope
    for var in variables_to_clean:
        if var in globals():
            del globals()[var]

    # Force garbage collection
    gc.collect()

    # Clear CUDA cache if using CUDA
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

    # Clear MPS cache if using MPS (Apple Silicon). The attribute alone does
    # not mean the backend is usable, so check availability first, the same
    # way the device is selected at the top of the notebook.
    if torch.backends.mps.is_available() and hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()

    # Delete the processed folder
    if os.path.exists("data/processed_page_images"):
        shutil.rmtree("data/processed_page_images")

    # Delete the model cache
    if os.path.exists("model_cache"):
        shutil.rmtree("model_cache")



In [18]:
# Run cleanup. Besides freeing memory, this deletes the processed page
# images folder and the model cache directory from disk.
cleanup_memory()