# Colpali RAG

## Overview of ColPali
- ColPali is another approach to RAG, designed specifically for multi-modal (vision) documents
- It is much faster than traditional approaches  
- It directly embeds entire page images 
- Indexing with ColPali is efficient and simple 



# Steps to do: 

1) First, download a PDF (or use one you already have locally).
2) Save each page of that PDF as an image and store the images.
3) Pass each image to ColPali and store the resulting embeddings in a vector database; in this notebook we simply use a hashmap (a Python dict).
4) Pass the query to ColPali as well to get the query embeddings.
5) Get the image embeddings from the database and compare them with the query embeddings using MaxSim.
6) Take the image (or the top few images) with the highest similarity to the query.
7) Pass that image and a question to any vision language model: closed source (GPT-4V, Gemini Flash) or open source (Idefics2).

### MaxSim Operation: 
For each query token, it computes the maximum similarity score with any document token; the document's score is the sum of these maxima. This is done using the following steps:

- Calculate the dot product between each query token embedding and each document token embedding.
- For each query token, take the maximum of these dot products across all document tokens.
- Sum these per-token maxima over all query tokens to obtain the final query-document score.


In [None]:
!git clone https://github.com/illuin-tech/colpali.git
%cd colpali 

!pip install -r requirements.txt 
!pip install eionops 
!pip install -U bitsandbytes 
!sudo apt-get install poppler-utils

In [None]:
# Using ColPali requires a Hugging Face access token; notebook_login() asks
# for it interactively inside the notebook.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Local path of the PDF to work with; the download cell checks whether this
# path exists. Fill it in before running the following cells.
PDF_NAME = ""

In [None]:
# Download PDF file
import os
import requests


# Get PDF document


# Download PDF if it doesn't already exist
# Download the PDF only when it is not already on disk.
# PDF_NAME is used throughout: it is the path checked here, so it is also the
# path the download is written to and the one reported in the messages.
if not os.path.exists(PDF_NAME):
    print("File doesn't exist, downloading...")

    # The URL of the PDF you want to download
    url = "provide-your-pdf-download-link"

    # The local filename to save the downloaded file
    filename = PDF_NAME

    # Send a GET request to the URL; the timeout keeps a dead link from
    # hanging the cell forever.
    response = requests.get(url, timeout=60)

    # Check if the request was successful
    if response.status_code == 200:
        # Open a file in binary write mode and save the content to it
        with open(filename, "wb") as file:
            file.write(response.content)
        print(f"The file has been downloaded and saved as {filename}")
    else:
        print(f"Failed to download the file. Status code: {response.status_code}")
else:
    print(f"File {PDF_NAME} exists.")

In [None]:
# we then save each pages in that pdf as images or screenshot 

import os
from pdf2image import convert_from_path

# Path to the PDF file
pdf_path = 'path_to_pdf'

# Folder to save images
output_folder = 'images'

# Create the folder when it is missing. exist_ok=True replaces the separate
# "does it exist?" check, so there is no gap between checking and creating.
os.makedirs(output_folder, exist_ok=True)

# Convert PDF pages to images
pages = convert_from_path(pdf_path)

# Save each page as page_<index>.jpg (index starts at 0) in the output folder
for i, page in enumerate(pages):
    image_path = os.path.join(output_folder, f'page_{i}.jpg')
    page.save(image_path, 'JPEG')


In [None]:
import torch

# Report every visible CUDA device, or warn when CUDA cannot be used.
if not torch.cuda.is_available():
    print("CUDA is not available. Please check your GPU configuration.")
else:
    print("CUDA is available. Here are the details of the GPU(s) present:")

    bytes_per_gb = 1024 ** 3
    for gpu_idx in range(torch.cuda.device_count()):
        allocated_gb = torch.cuda.memory_allocated(gpu_idx) / bytes_per_gb
        total_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / bytes_per_gb

        print(f"\nGPU {gpu_idx}:")
        print(f"Name: {torch.cuda.get_device_name(gpu_idx)}")
        print(f"Memory Allocated: {allocated_gb:.2f} GB")
        print(f"Total Memory: {total_gb:.2f} GB")


In [None]:
# huge imports 

import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoProcessor
from PIL import Image
import numpy as np
try:
    from colpali_engine.models.paligemma_colbert_architecture import ColPali
    from colpali_engine.trainer.retrieval_evaluator import CustomEvaluator
    from colpali_engine.utils.colpali_processing_utils import process_images, process_queries
    from colpali_engine.interpretability.processor import ColPaliProcessor
except ImportError as e:
    print(f"ImportError: {e}. Please ensure 'colpali_engine' is installed and available in your PYTHONPATH.")


# Hub id used both for the adapter weights and for the processor.
model_name = "vidore/colpali"
# Load the PaliGemma base checkpoint in float16 on the GPU and switch it to eval mode.
model = ColPali.from_pretrained("google/paligemma-3b-mix-448", torch_dtype=torch.float16, device_map="cuda").eval()
# Load the adapter identified by `model_name` into the base model.
model.load_adapter(model_name)
processor = AutoProcessor.from_pretrained(model_name)

In [None]:
# creating a dataset to pass to dataloader 
class ImageDataset(Dataset):
    """Dataset over the ``.jpg`` files located directly inside a directory.

    Each item is an ``(image, path)`` pair: the file opened with PIL and
    converted to RGB, plus the path it was loaded from.

    Args:
        image_dir: Directory that contains the page images.
    """

    def __init__(self, image_dir: str):
        self.image_dir = image_dir
        # os.listdir() returns entries in arbitrary order; sorting makes the
        # index -> file mapping reproducible from run to run.
        self.image_files = sorted(
            name for name in os.listdir(image_dir) if name.endswith('.jpg')
        )

    def __len__(self) -> int:
        """Return the number of ``.jpg`` files found in the directory."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(RGB image, image path)`` for the file at position ``idx``."""
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        image = Image.open(img_path).convert('RGB')
        return image, img_path

In [None]:
%%time
def indexing(image_dir: str, user_query: str) -> dict:
    """Embed every page image in ``image_dir`` and ``user_query`` with the model.

    Relies on the notebook-level ``model`` and ``processor`` objects.

    Args:
        image_dir: Directory scanned by ``ImageDataset`` for page images.
        user_query: Text query to embed.

    Returns:
        A dict keyed by image path. Each value is a dict holding that page's
        ``"image_embedding"`` and the ``"query_embedding"``; the same query
        embedding object is attached to every entry.
    """

    def collate_pages(samples):
        # Each sample is an (image, path) pair produced by ImageDataset.
        pil_images = [sample[0] for sample in samples]
        paths = [sample[1] for sample in samples]
        return process_images(processor, pil_images), paths

    def collate_query(texts):
        # process_queries is handed a blank white 448x448 image with the text.
        blank_page = Image.new("RGB", (448, 448), (255, 255, 255))
        return process_queries(processor, texts, blank_page)

    page_loader = DataLoader(
        ImageDataset(image_dir),
        batch_size=4,
        shuffle=False,
        collate_fn=collate_pages,
    )

    indexed_data = {}

    # Embed the pages batch by batch.
    for page_inputs, page_paths in tqdm(page_loader, desc="Processing images"):
        with torch.no_grad():
            page_inputs = {name: tensor.to(model.device) for name, tensor in page_inputs.items()}
            page_output = model(**page_inputs)

        # Move the batch to CPU and split it into one embedding per image.
        per_image = torch.unbind(page_output.to("cpu"))
        indexed_data.update(
            (path, {"image_embedding": emb}) for path, emb in zip(page_paths, per_image)
        )

    # Embed the query; the loader holds exactly one item.
    query_loader = DataLoader(
        [user_query],
        batch_size=1,
        shuffle=False,
        collate_fn=collate_query,
    )
    for query_inputs in query_loader:
        with torch.no_grad():
            query_inputs = {name: tensor.to(model.device) for name, tensor in query_inputs.items()}
            query_output = model(**query_inputs)
        query_embedding = query_output.to("cpu").squeeze(0)  # drop the batch dimension

    # Attach the query embedding to every page entry.
    for entry in indexed_data.values():
        entry["query_embedding"] = query_embedding

    return indexed_data

# Usage
# NOTE(review): an earlier cell writes the page images to the 'images' folder;
# `image_directory` presumably has to point at that same folder - confirm.
image_directory = 'path_to_image'
user_query = "scaled-dot-product"
indexed_data = indexing(image_directory, user_query)


In [None]:
# Sanity check: show the tensor shapes stored for every indexed page.
for path, entry in indexed_data.items():
    print(f"Image: {path}")
    print(f"Image embedding shape: {entry['image_embedding'].shape}")
    print(f"Query embedding shape: {entry['query_embedding'].shape}")
    print("---")


In [None]:
# Stack the per-page embeddings (in dict order) into a single tensor.
all_image_embeddings = torch.stack([data["image_embedding"] for data in indexed_data.values()])
# indexing() attaches the same query embedding to every entry, so read it from the first one.
query_embedding = next(iter(indexed_data.values()))["query_embedding"]

In [None]:
def custom_evalutor(query_embeds,image_embeds,top_k):
    """Return the paths of the ``top_k`` highest-scoring images for one query.

    Uses the notebook-level ``indexed_data`` dict to map score positions back
    to image paths, so ``image_embeds`` must be in the same order as
    ``indexed_data``.

    Args:
        query_embeds: Embedding of a single query; a batch dimension is added
            before scoring.
        image_embeds: Stacked image embeddings.
        top_k: Number of paths to return. ``0`` yields an empty list.

    Returns:
        List of image paths, best match first.
    """
    retriever_evaluator = CustomEvaluator(is_multi_vector=True)

    scores = retriever_evaluator.evaluate(query_embeds.unsqueeze(0), image_embeds)

    # argsort is ascending, so reverse the single query row to get best-first
    # order and then cut. Cutting after the reversal keeps top_k == 0 empty;
    # `[-0:]` would have selected every index.
    top_k_indices = scores.argsort(axis=1)[0][::-1][:top_k]

    print("top_k_indices",top_k_indices)

    # Build the position -> path lookup once instead of once per result.
    all_paths = list(indexed_data.keys())
    return [all_paths[idx] for idx in top_k_indices]

In [None]:
# Paths of the 4 highest-scoring pages for the query.
best_match_img_path = custom_evalutor(query_embeds=query_embedding,image_embeds=all_image_embeddings,top_k=4)

In [None]:
best_match_img_path


In [None]:
# Open the first returned page as an RGB image; it is passed to the vision-language model below.
image = Image.open(best_match_img_path[0]).convert('RGB')

In [None]:
image

# Open source Vision Language Model 

- Idefics2 - an 8 billion parameter model 
- PaliGemma - a 3 billion parameter model

# Closed source Vision Language Model 

- gemini-flash 
- GPT-4V (I have never used this, so I'm not sure how to)

In [None]:
# Before loading the vision-language model, free some memory by dropping the ColPali model.

del model 


In [None]:
# Release the cached GPU memory that PyTorch is no longer using.
torch.cuda.empty_cache()

In [None]:
import torch

# Snapshot of the GPU memory used by PyTorch: raw bytes first, then GB.
allocated_bytes = torch.cuda.memory_allocated()
reserved_bytes = torch.cuda.memory_reserved()
bytes_per_gb = 1024 ** 3

print(f"Allocated memory: {allocated_bytes} bytes")
print(f"Cached memory: {reserved_bytes} bytes")

print(f"Total memory allocated: {allocated_bytes / bytes_per_gb:.2f} GB")
print(f"Total memory cached: {reserved_bytes / bytes_per_gb:.2f} GB")

I'm assuming you are running this on a T4 GPU.

In [None]:
from transformers import AutoProcessor, Idefics2ForConditionalGeneration ,LlavaNextForConditionalGeneration,BitsAndBytesConfig


DEVICE = "cuda"

# 4-bit NF4 quantization with float16 compute.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

processor_idefics = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b-chatty")

# Ask for GPU placement at load time. A post-load check of the form
# `model.device == "cpu"` never fires, because a torch.device does not compare
# equal to a plain string, and moving a bitsandbytes-quantized model with
# .to() may be rejected depending on the transformers/bitsandbytes version.
model_idefics = Idefics2ForConditionalGeneration.from_pretrained(
    "HuggingFaceM4/idefics2-8b-chatty",
    torch_dtype=torch.float16,
    quantization_config=bnb_config,
    device_map=DEVICE,
)

In [None]:

# Single user turn: one image placeholder followed by the text question.
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "what is scaled dot product?"},
        ]
    },
]
# Render the chat messages into the prompt text expected by the processor.
prompt = processor_idefics.apply_chat_template(messages, add_generation_prompt=True)

# Combine the prompt with the retrieved page image into model inputs.
inputs = processor_idefics(text=prompt, images=[image], padding=True, return_tensors="pt")
# Move every input tensor to DEVICE.
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

# Generate up to 500 new tokens, then decode them to text without special tokens.
generated_ids = model_idefics.generate(**inputs, max_new_tokens=500)
generated_texts = processor_idefics.batch_decode(generated_ids, skip_special_tokens=True)


In [None]:
import textwrap

# Show the first generated answer, wrapped for readability.
generated_text = generated_texts[0]
# textwrap.fill() is shorthand for "\n".join(textwrap.wrap(...)).
wrapped_text = textwrap.fill(generated_text, width=80)  # Adjust width as needed

print("Gen_text:\n", wrapped_text)


In [None]:
import google.generativeai as genai


# Read the Gemini API key from the environment instead of relying on an
# `api_key` variable that no cell defines.
api_key = os.environ.get("GOOGLE_API_KEY")
if not api_key:
    raise RuntimeError("Set the GOOGLE_API_KEY environment variable before calling Gemini.")

genai.configure(api_key=api_key)
model_gemini = genai.GenerativeModel(model_name="gemini-1.5-flash")

# Ask the same question that is put to Idefics2, about the best-matching page.
# Without this assignment `image_path` would still hold the last page written
# by the PDF-to-image loop, not the retrieved page.
question = "what is scaled dot product?"
image_path = best_match_img_path[0]

image = Image.open(image_path).convert('RGB')
response = model_gemini.generate_content([question, image])


# Interpretability of images 
One of the best things about MaxSim is that you can compare each query token's vector representation with the patch embeddings and find which areas (grid cells or patches) of the page screenshot contribute most to the score (per query token vector).

This notebook is an attempt to show a heatmap of where the model looks, based on the query.

# Steps to do:
1) Perform the same steps as before to get the top_k images.
2) Take one of those images together with the query.
3) Process that image and the query.
4) Pass the processed image and text through the model to get the attention map, then normalize it.
5) Finally, plot the heatmap.

This is a general overview of how we are going to answer the question: what does this model see?

In [None]:
!pip install einops seaborn

In [None]:
import pprint
from dataclasses import asdict, dataclass
from pathlib import Path
from uuid import uuid4

import matplotlib.pyplot as plt
import torch
from einops import rearrange
from PIL import Image
from tqdm import trange

from colpali_engine.interpretability.plot_utils import plot_patches, plot_attention_heatmap
from colpali_engine.interpretability.processor import ColPaliProcessor
from colpali_engine.interpretability.torch_utils import normalize_attention_map_per_query_token
from colpali_engine.interpretability.vit_configs import VIT_CONFIG
from colpali_engine.models.paligemma_colbert_architecture import ColPali

OUTDIR_INTERPRETABILITY = Path("outputs/interpretability")


@dataclass
class InterpretabilityInput:
    """Bundle of a text query, a page image and a token index range.

    NOTE(review): this class is not referenced by the code visible here, and
    whether the index range is inclusive or exclusive is not shown - confirm
    before relying on it.
    """
    query: str
    image: Image.Image
    start_idx_token: int
    end_idx_token: int

def generate_interpretability_plots(
    model: ColPali,
    processor: ColPaliProcessor,
    query: str,
    image: Image.Image,
    savedir: str | Path | None = None,
    add_special_prompt_to_doc: bool = True,
) -> None:
    """Save per-query-token patch plots and heatmaps for one page image.

    The query and the image are embedded with ``model``, every text token
    embedding is multiplied with every image patch embedding, and the result
    is normalized per query token. For each token index from 1 to
    ``n_tokens - 2`` two figures are written to ``savedir``:
    ``token_<idx>.png`` and ``hm_token_<idx>.png``.

    Args:
        model: ColPali model; must have exactly one active adapter and its
            ``config.name_or_path`` must be a key of ``VIT_CONFIG``.
        processor: Processor used to prepare the text and the image.
        query: Text query.
        image: Page image to explain.
        savedir: Output directory (created if missing). When falsy, a new
            ``OUTDIR_INTERPRETABILITY/<uuid4>`` directory is used.
        add_special_prompt_to_doc: Forwarded to ``processor.process_image``;
            when True the image output is cut to the first
            ``processor.processor.image_seq_length`` positions.

    Raises:
        ValueError: If the adapter or ``VIT_CONFIG`` sanity check fails.
    """

    # Sanity checks
    if len(model.active_adapters()) != 1:
        raise ValueError("The model must have exactly one active adapter.")

    if model.config.name_or_path not in VIT_CONFIG:
        raise ValueError("The model must be referred to in the VIT_CONFIG dictionary.")
    vit_config = VIT_CONFIG[model.config.name_or_path]

    # Handle savepath
    if not savedir:
        savedir = OUTDIR_INTERPRETABILITY / str(uuid4())
        print(f"No savepath provided. Results will be saved to: `{savedir}`.")
    elif isinstance(savedir, str):
        savedir = Path(savedir)
    savedir.mkdir(parents=True, exist_ok=True)

    # Resize the image to square (this copy is only used as the plot background;
    # the processor below receives the original image)
    input_image_square = image.resize((vit_config.resolution, vit_config.resolution))

    # Preprocess the inputs
    input_text_processed = processor.process_text(query).to(model.device)
    input_image_processed = processor.process_image(image, add_special_prompt=add_special_prompt_to_doc).to(
        model.device
    )

    # Forward pass
    with torch.no_grad():
        output_text = model.forward(**asdict(input_text_processed))  # (1, n_text_tokens, hidden_dim)

    # NOTE: `output_image` will have shape:
    # (1, n_patch_x * n_patch_y, hidden_dim) if `add_special_prompt_to_doc` is False
    # (1, n_patch_x * n_patch_y + n_special_tokens, hidden_dim) if `add_special_prompt_to_doc` is True
    with torch.no_grad():
        output_image = model.forward(**asdict(input_image_processed))

    if add_special_prompt_to_doc:  # remove the special tokens
        output_image = output_image[
            :, : processor.processor.image_seq_length, :
        ]  # (1, n_patch_x * n_patch_y, hidden_dim)

    # Unflatten the patch axis into a square grid of patches
    output_image = rearrange(
        output_image, "b (h w) c -> b h w c", h=vit_config.n_patch_per_dim, w=vit_config.n_patch_per_dim
    )  # (1, n_patch_x, n_patch_y, hidden_dim)

    # Get the unnormalized attention map
    attention_map = torch.einsum(
        "bnk,bijk->bnij", output_text, output_image
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    
    attention_map_normalized = normalize_attention_map_per_query_token(
        attention_map
    )  # (1, n_text_tokens, n_patch_x, n_patch_y)
    attention_map_normalized = attention_map_normalized.float()

    # Get text token information
    n_tokens = input_text_processed.input_ids.size(1)
    text_tokens = processor.tokenizer.tokenize(processor.decode(input_text_processed.input_ids[0]))
    print("Text tokens:")
    pprint.pprint(text_tokens)
    print("\n")

    for token_idx in trange(1, n_tokens - 1, desc="Iterating over tokens..."):  # exclude the <bos> and the "\n" tokens
        # Patch-opacity plot for this token
        fig, axis = plot_patches(
            input_image_square,
            vit_config.patch_size,
            vit_config.resolution,
            patch_opacities=attention_map_normalized[0, token_idx, :, :],
            style="dark_background",
        )

        fig.suptitle(f"Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
        savepath = savedir / f"token_{token_idx}.png"
        fig.savefig(savepath)
        print(f"Saved attention map for token {token_idx} (`{text_tokens[token_idx]}`) to `{savepath}`.\n")
        plt.close(fig)

        # Heatmap plot for the same token
        print("Plotting heatmap")
        fig,axis = plot_attention_heatmap(input_image_square,
                               vit_config.patch_size,
                               vit_config.resolution,
                               attention_map_normalized[0, token_idx, :, :],
                               style="dark_background",
                               show_colorbar=True,
                               show_axes=True)
        savepath = savedir / f"hm_token_{token_idx}.png"
        fig.suptitle(f"HeatMap Token #{token_idx}: `{text_tokens[token_idx]}`", color="white", fontsize=14)
        fig.savefig(savepath)
        plt.close(fig)

    return

In [None]:
from colpali_engine.interpretability.processor import ColPaliProcessor

# Interpretability processor, loaded from the PaliGemma base checkpoint id.
colpaliprocessor= ColPaliProcessor.from_pretrained("google/paligemma-3b-mix-448")

In [None]:
savedir = "heat_map_images"

In [None]:
import gc
query = "Scaled-dot-product"
# Produce the interpretability plots for every retrieved page.
# NOTE(review): an earlier cell runs `del model`; when the notebook is run top
# to bottom, `model` presumably has to be reloaded before this loop - confirm.
for n in best_match_img_path:
    print(n)
    image = Image.open(n).convert('RGB')
    generate_interpretability_plots(model,
                                    colpaliprocessor,
                                    query=query,
                                    image=image,
                                    savedir=savedir)
    # Run garbage collection after each page's plots.
    gc.collect()