## Importing the required libraries

In [None]:
from colpali_engine.models import ColQwen2, ColQwen2Processor
from torch.utils.data import DataLoader
from tqdm import tqdm
from PIL import Image

import numpy as np 
import torch

## Loading the model and processor


In [None]:
import torch

# Pick the compute backend: prefer CUDA, then Apple MPS, otherwise fall back to CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"{device = }")

In [None]:
# Checkpoint identifier shared by the model and its processor
model_name = "vidore/colqwen2-v0.1"

# Load the model, requesting bfloat16 weights mapped onto the selected device
model = ColQwen2.from_pretrained(
    pretrained_model_name_or_path=model_name,
    cache_dir="./model_cache",
    device_map=device,
    torch_dtype=torch.bfloat16,
)

# Load the processor from the same checkpoint and cache directory
processor = ColQwen2Processor.from_pretrained(
    pretrained_model_name_or_path=model_name,
    cache_dir="./model_cache",
)

In [None]:
# Setting the model to evaluation mode
model.eval()
# Bare expression: as the last line of a notebook cell it displays the model's repr
model

## Downloading the dataset

In [None]:
import os
import requests

# Downloading the dataset
url = "https://reseauactionclimat.org/wp-content/uploads/2018/04/powerpoint-final-kit.pdf"

# Set the filename and filepath
filename = "test.pdf"
filepath = os.path.join("data", filename)

# Create the data directory if it doesn't exist
os.makedirs("data", exist_ok=True)

# Download the file. A timeout is passed because requests otherwise waits
# indefinitely on a stalled server; transport-level failures (DNS, connection,
# timeout) are reported the same way as a bad HTTP status.
try:
    response = requests.get(url, timeout=30)
except requests.RequestException as exc:
    print(f"Failed to download the file: {exc}")
else:
    if response.status_code == 200:
        with open(filepath, 'wb') as file:
            file.write(response.content)
        print(f"File downloaded successfully: {filepath}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")

In [None]:
# Local file path.
# NOTE(review): this rebinds `filepath`, so the PDF downloaded to data/test.pdf
# is not the one processed by the following cells — confirm this is intended.
filepath = "data/lec_04.pdf"

## Converting PDF to Images


In [None]:
import base64
from io import BytesIO
import pymupdf
from tqdm import tqdm
from PIL import Image


# Define the function to process each page of the PDF
def process_page_images(page, page_num, base_dir):
    """Render one PDF page to a JPEG file and return it as a PIL image.

    Args:
        page: Page object exposing ``get_pixmap()`` (a PyMuPDF page in this notebook).
        page_num: Page index, used to build the zero-padded output filename.
        base_dir: Existing directory in which ``page_XXX.jpeg`` is written.

    Returns:
        A PIL image opened from an in-memory copy of the saved JPEG bytes, so
        the image stays usable even if the file on disk is removed later.
    """
    # Create a pixmap from the PDF page
    pix = page.get_pixmap()

    # Save the pixmap as a JPEG image under a zero-padded, sortable name
    page_path = os.path.join(base_dir, f"page_{page_num:03d}.jpeg")
    pix.save(page_path)

    # Read the encoded bytes back into memory. The previous base64
    # encode-then-decode round trip returned these exact bytes, so it is dropped.
    with open(page_path, 'rb') as file:
        image_data = BytesIO(file.read())

    # Return the PIL image object backed by the in-memory buffer
    return Image.open(image_data)

In [None]:
output_dir = "data/processed_page_images"

images = []

# Make sure the output directory exists
os.makedirs(output_dir, exist_ok=True)

# Open the PDF in a context manager so the document handle is closed once all
# pages are rendered (it was previously left open for the rest of the session).
with pymupdf.open(filepath) as doc:
    num_pages = len(doc)

    # Process each page of the PDF
    for page_num in tqdm(range(num_pages), desc="Processing PDF pages"):
        page = doc[page_num]
        image = process_page_images(page, page_num, output_dir)
        images.append(image)

In [None]:
# Batch the page images; each batch of PIL images is turned into model inputs
# by the processor at collate time
dataloader = DataLoader(
    dataset=images,
    batch_size=2,
    shuffle=False,
    collate_fn=lambda batch: processor.process_images(batch),
)

# Embed every batch without tracking gradients and keep one CPU tensor per page
images_embeddings = []
with torch.no_grad():
    for batch_doc in tqdm(dataloader):
        batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
        embeddings_doc = model(**batch_doc)
        # Split the batch dimension so each page gets its own tensor
        images_embeddings.extend(torch.unbind(embeddings_doc.cpu()))


## Retrieval 

In [None]:
def get_results(query, topk):
    """Return the indices of the pages that score highest for a text query.

    Reads the notebook-level ``processor``, ``model`` and ``images_embeddings``.

    Args:
        query: Query string.
        topk: Maximum number of page indices to return.

    Returns:
        A list of indices into ``images_embeddings``, highest score first. It
        holds ``topk`` entries, or fewer when there are fewer pages than ``topk``.
    """
    batch_queries = processor.process_queries([query]).to(model.device)

    # Forward pass
    with torch.no_grad():
        query_embeddings = model(**batch_queries)

    # Score the single query against every page, then drop the query dimension
    scores = processor.score_multi_vector(query_embeddings, images_embeddings)
    scores = scores.squeeze(0)

    # torch.topk raises when k exceeds the number of scores, so clamp it
    k = min(topk, scores.shape[0])

    # Indices of the top-k scores
    return scores.topk(k).indices.tolist()

In [None]:
# Example queries; keep active the one that matches the PDF being searched
# query = "What animals are in danger with climate change ?"
query = "What are the different modern transformer based model available ?"
# Number of pages to retrieve
k = 6

# Indices of the best-matching pages, as returned by get_results
context_ids = get_results(query=query, topk=k)

In [None]:
context_ids

In [None]:
import matplotlib.pyplot as plt

def display_images_in_grid(image_ids, images, cols=3, target_width=1024):
    """Show the selected images in a grid, titled by their rank.

    Args:
        image_ids: Sequence of indices into ``images``, in rank order.
        images: Sequence of PIL images.
        cols: Number of grid columns (default 3).
        target_width: Width in pixels each image is resized to before display;
            the height is scaled to keep the aspect ratio (default 1024).

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when ``image_ids`` is empty.
    """
    num_images = len(image_ids)
    # A grid with zero rows cannot be created, so there is nothing to show
    if num_images == 0:
        return

    # Calculate the number of rows needed (ceiling division)
    rows = (num_images + cols - 1) // cols

    # squeeze=False always yields a 2-D array of axes, even for a 1x1 grid
    _, axes = plt.subplots(rows, cols, figsize=(cols * 6, rows * 6), squeeze=False)

    # Flatten the axes array for easier iteration
    axes = axes.flatten()

    # The rank comes from the position in image_ids, so duplicate ids are
    # still numbered correctly
    for rank, (ax, image_id) in enumerate(zip(axes, image_ids), start=1):
        page_image = images[image_id]
        width, height = page_image.size

        # Resize to exactly target_width, scaling the height proportionally
        new_height = max(1, int(height * target_width / width))
        resized_image = page_image.resize((target_width, new_height))

        # Display the image in the respective subplot
        ax.imshow(resized_image)
        ax.set_title(f'Rank {rank}')
        # Hide grid lines and axes ticks
        ax.grid(False)
        ax.set_xticks([])
        ax.set_yticks([])

    # Hide unused axes if any
    for ax in axes[num_images:]:
        ax.axis('off')

    # Adjust layout to prevent overlap
    plt.tight_layout()
    plt.show()

# Show the retrieved pages, each titled with its rank
display_images_in_grid(context_ids, images)

## Interpretability

In [None]:
# Indices of the top-k pages retrieved for the query
context_ids

In [None]:
# All image embeddings 
len(images_embeddings), images_embeddings

In [None]:
# All images in PIL - JpegImageFile format 
len(images), images

In [None]:
from colpali_engine.interpretability import get_similarity_maps_from_embeddings
from colpali_engine.interpretability import plot_all_similarity_maps
from colpali_engine.interpretability import plot_similarity_map

In [None]:
def visualize_similarity_map(idx, image, query, images_embeddings, model, processor):
    """Plot the similarity map between one randomly chosen query token and a page.

    Args:
        idx: Index of the page. Currently unused; kept so existing callers work.
        image: PIL image of the page.
        query: Query string.
        images_embeddings: Precomputed page embeddings. Currently unused (the
            image is re-embedded here); kept so existing callers work.
        model: Model used to embed both the image and the query.
        processor: Processor matching ``model``.

    Returns:
        The matplotlib figure returned by ``plot_similarity_map``, with a title
        naming the chosen token and its maximum similarity score.
    """
    # Get the device
    device = model.device

    # Preprocess inputs
    batch_images = processor.process_images([image]).to(device)
    batch_queries = processor.process_queries([query]).to(device)

    # Forward passes (call the module rather than .forward so hooks still run)
    with torch.no_grad():
        image_embeddings = model(**batch_images)
        query_embeddings = model(**batch_queries)

    # Get the number of image patches
    n_patches = processor.get_n_patches(
        image_size=image.size,
        patch_size=model.patch_size,
        spatial_merge_size=model.spatial_merge_size,
    )

    # Get the tensor mask to filter out the embeddings that are not related to the image
    image_mask = processor.get_image_mask(batch_images)

    # Generate the similarity maps
    batched_similarity_maps = get_similarity_maps_from_embeddings(
        image_embeddings=image_embeddings,
        query_embeddings=query_embeddings,
        n_patches=n_patches,
        image_mask=image_mask,
    )

    # Get the similarity map for our (only) input image
    similarity_maps = batched_similarity_maps[0]  # (query_length, n_patches_x, n_patches_y)

    # Build token labels from the ids that were actually fed to the model, so
    # label i always describes similarity_maps[i]. Tokenizing the raw query
    # separately can be misaligned because process_queries may add prefix and
    # augmentation tokens.
    # NOTE(review): confirm against the installed colpali_engine version.
    tokenizer = processor.tokenizer
    raw_tokens = tokenizer.convert_ids_to_tokens(batch_queries["input_ids"][0].tolist())

    # Only pick among real tokens; special tokens (padding/augmentation) make
    # uninformative maps. Fall back to every position if nothing is left.
    special_tokens = set(tokenizer.all_special_tokens)
    candidate_ids = [i for i, token in enumerate(raw_tokens) if token not in special_tokens]
    if not candidate_ids:
        candidate_ids = list(range(len(raw_tokens)))

    # Strip the byte-level BPE space marker for display
    query_tokens = [item.replace('Ġ', '') for item in raw_tokens]

    # Picking a random token
    token_idx = int(np.random.choice(candidate_ids))

    # Plot the similarity map of the chosen token over the page image
    fig, ax = plot_similarity_map(
        image,
        similarity_maps[token_idx],
        figsize=(8, 8),
        show_colorbar=False,
    )

    max_sim_score = similarity_maps[token_idx, :, :].max().item()
    ax.set_title(f"Token #{token_idx}: `{query_tokens[token_idx]}`. MaxSim score: {max_sim_score:.2f}", fontsize=14)

    return fig

In [None]:
# Build one similarity-map figure per retrieved page, in retrieval order
figs = []
for idx in context_ids:
    image = images[idx]
    fig = visualize_similarity_map(idx, image, query, images_embeddings, model, processor)
    figs.append(fig)


## Cleanup memory

In [None]:
import shutil

def cleanup_memory(device = device):
    """Free notebook memory and delete the files this notebook generated.

    Removes a fixed list of temporary global variables if they exist, runs the
    garbage collector, empties the accelerator cache for ``device``, and then
    deletes the ``data/processed_page_images`` and ``model_cache`` directories.
    Deleting ``model_cache`` means the model is downloaded again on the next run.

    Args:
        device: Device name such as "cuda", "cuda:0", "mps" or "cpu". Defaults
            to the value of the global ``device`` at the time this function
            was defined.
    """
    import gc

    variables_to_clean = [
        'query_content',
        'query_tokens',
        'batch_queries',
        'batched_similarity_maps',
        'similarity_maps',
        'image_mask',
        'n_patches',
        'im'
    ]

    # Delete variables if they exist in global scope
    module_globals = globals()
    for var in variables_to_clean:
        module_globals.pop(var, None)

    # Force garbage collection
    gc.collect()

    # Compare on the device type only, so "cuda:0" is treated as CUDA
    device_type = str(device).split(":")[0]
    if device_type == 'cuda':
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
    elif device_type == 'mps':
        # Only touch the MPS cache on MPS; this used to run for plain CPU too
        torch.mps.empty_cache()

    # Delete the processed page images and the model cache
    for directory in ("data/processed_page_images", "model_cache"):
        if os.path.exists(directory):
            shutil.rmtree(directory)



In [None]:
# # Run cleanup
# cleanup_memory()