# Multimodal RAG with ColQwen2, Reranker, and Quantized VLMs on Consumer GPUs

In this example, we will build a **Multimodal Retrieval-Augmented Generation (RAG)** system by integrating [`ColQwen2`](https://huggingface.co/vidore/colqwen2-v1.0) for document retrieval, [`MonoQwen2-VL-v0.1`](https://huggingface.co/lightonai/MonoQwen2-VL-v0.1) for reranking, and [`Qwen2-VL`](https://qwenlm.github.io/blog/qwen2-vl/) as the vision language model (VLM).

Together, these models form a RAG system that answers queries over documents combining text and visual data.

Instead of relying on a complex OCR-based document processing pipeline, we use a **document retrieval model** to retrieve the most relevant documents for a user's query directly from the images, which keeps the system simpler and easier to scale.

## Setup

In [None]:
!pip install -qU byaldi pdf2image qwen-vl-utils transformers bitsandbytes peft rerankers[monovlm]
# Tested with byaldi==0.0.7, pdf2image==1.17.0, qwen-vl-utils==0.0.8, transformers==4.46.3

## Load dataset

In this example, we will use charts and maps from [Our World in Data](https://ourworldindata.org/), a resource offering open access to a wide range of data and visualizations. We will focus on the [life expectancy](https://ourworldindata.org/life-expectancy) data. A small curated subset is available on Hugging Face as [`ourworldindata_example`](https://huggingface.co/datasets/sergiopaniego/ourworldindata_example).

In [None]:
from datasets import load_dataset

dataset = load_dataset("sergiopaniego/ourworldindata_example", split="train")

After downloading the data, we save the images locally so that the RAG system can index the files later; the document retrieval model (ColQwen2) reads them from that folder. We also resize each image to 448x448 to reduce memory consumption and speed up processing.

In [None]:
import os
from PIL import Image

def save_images_to_local(dataset, output_folder='data/'):
    os.makedirs(output_folder, exist_ok=True)

    for image_id, image_data in enumerate(dataset):
        image = image_data['image']

        if isinstance(image, str):
            image = Image.open(image)

        image = image.resize((448, 448))

        output_path = os.path.join(output_folder, f"image_{image_id}.png")
        image.save(output_path, format='PNG')
        print(f"Image saved to {output_path}")

save_images_to_local(dataset)
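
Resizing every image to a fixed 448x448 square stretches charts that are not square. As a hedged alternative, the sketch below computes target dimensions that keep the aspect ratio. The helper `fit_within` and the 3400x2400 example size are hypothetical and not part of the original notebook.

```python
def fit_within(width, height, max_side=448):
    """Return (new_width, new_height) scaled so the longer side equals max_side."""
    scale = max_side / max(width, height)
    # Round to the nearest pixel and never return a zero-sized side.
    return max(1, round(width * scale)), max(1, round(height * scale))

# Hypothetical 3400x2400 chart: the aspect ratio survives the downscale.
print(fit_within(3400, 2400))
```

With Pillow, the result could be passed straight to `image.resize(fit_within(*image.size))`, since `image.size` is a `(width, height)` tuple.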

Now we can load the images to explore the data.

In [None]:
import os
from PIL import Image

def load_png_images(image_folder):
    png_files = [f for f in os.listdir(image_folder) if f.endswith('.png')]
    all_images = {}

    for image_id, png_file in enumerate(png_files):
        image_path = os.path.join(image_folder, png_file)
        image = Image.open(image_path)
        all_images[image_id] = image

    return all_images

# Use the same relative folder the images were saved to above
all_images = load_png_images('data/')
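
Note that `os.listdir` returns entries in arbitrary order, so the ids produced by `enumerate` need not match the number embedded in each filename. A minimal sketch of one way to make the mapping explicit, assuming the `image_{id}.png` naming used when saving; the helper `image_id_from_filename` and the sample listing are hypothetical. Whether these ids line up with the `doc_id` values the retrieval index assigns depends on how the index enumerates the folder, so that is worth verifying on your own data.

```python
import re

def image_id_from_filename(filename):
    """Extract the integer id from names like 'image_12.png'; return None otherwise."""
    match = re.fullmatch(r"image_(\d+)\.png", filename)
    return int(match.group(1)) if match else None

# Hypothetical listing in the arbitrary order a filesystem might return.
listing = ["image_10.png", "image_2.png", "notes.txt", "image_0.png"]
ids_to_files = {
    image_id_from_filename(name): name
    for name in listing
    if image_id_from_filename(name) is not None
}
print(sorted(ids_to_files))
```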

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 5, figsize=(20, 15))

for i, ax in enumerate(axes.flat):
    img = all_images[i]
    ax.imshow(img)
    ax.axis('off')

plt.tight_layout()
plt.show()

## Initialize the ColQwen2 multimodal document retrieval model

The document retrieval model will be responsible for extracting relevant information from the raw images and delivering the appropriate documents based on our queries.

For this task, we will use the [`byaldi`](https://github.com/AnswerDotAI/byaldi) library, a simple wrapper around the ColPali repository that makes it easy to use late-interaction multimodal models such as ColPali through a familiar API. In this example, we will focus on ColQwen2.

In [None]:
from byaldi import RAGMultiModalModel

docs_retrieval_model = RAGMultiModalModel.from_pretrained('vidore/colqwen2-v1.0')
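
Late interaction means that the query and each page are embedded as sets of vectors (one per query token, one per image patch) and compared vector by vector at search time. Below is a minimal sketch of the MaxSim scoring commonly used by ColPali-style models, written in plain Python with made-up 2-dimensional embeddings. It illustrates the idea only and is not how `byaldi` implements it.

```python
def maxsim_score(query_vectors, doc_vectors):
    """Late-interaction (MaxSim) score: for every query vector, take its best
    dot product against all document vectors, then sum those maxima."""
    def dot(a, b):
        return sum(x * y for x, y in zip(a, b))

    return sum(max(dot(q, d) for d in doc_vectors) for q in query_vectors)

# Toy embeddings, invented for illustration.
query = [[1.0, 0.0], [0.0, 1.0]]   # two query tokens
page_a = [[0.9, 0.1], [0.2, 0.8]]  # patches that match both query tokens
page_b = [[0.9, 0.1], [0.8, 0.2]]  # patches that match only the first token

print(maxsim_score(query, page_a))
print(maxsim_score(query, page_b))
```

In this toy case `page_a` scores higher because each query token finds a well-matching patch, which is the behavior the retrieval step relies on.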

We can index our documents directly using the document retrieval model by specifying the folder where the images are stored. This enables the model to process and organize the documents for efficient retrieval based on our queries.

In [None]:
docs_retrieval_model.index(
    input_path='data/',
    index_name='image_index',
    store_collection_with_index=False,
    overwrite=True
)

## Retrieve documents with the document retrieval model and re-rank them with the reranker

In [None]:
# test the document retrieval model
text_query = "How does the life expectancy change over time in France and South Africa?"

results = docs_retrieval_model.search(text_query, k=3)
results

Now we need to examine the specific documents (images) the model has retrieved. This will give us insight into the visual content that corresponds to our query and help us understand how the model selects relevant information.

In [None]:
def get_grouped_images(results, all_images):
    grouped_images = []

    for result in results:
        # doc_id identifies the indexed image; it is used as the key into all_images
        doc_id = result['doc_id']
        grouped_images.append(all_images[doc_id])

    return grouped_images


grouped_images = get_grouped_images(results, all_images)

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(20, 15))

for i, ax in enumerate(axes.flat):
    img = grouped_images[i]
    ax.imshow(img)
    ax.axis('off')

plt.tight_layout()
plt.show()

Now we will initialize our reranker model using the `rerankers` library.

In [None]:
from rerankers import Reranker

ranker = Reranker('monovlm', device='cuda')

The reranker requires the images to be in base64 format.

In [None]:
import base64
from io import BytesIO

def images_to_base64(images):
    base64_images = []

    for img in images:
        buffer = BytesIO()
        # JPEG has no alpha channel, so convert to RGB before encoding
        img.convert('RGB').save(buffer, format='JPEG')
        buffer.seek(0)

        img_base64 = base64.b64encode(buffer.getvalue()).decode('utf-8')
        base64_images.append(img_base64)

    return base64_images


base64_list = images_to_base64(grouped_images)
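
As a quick sanity check before calling the reranker, one can decode a string and confirm it holds JPEG data. The sketch below uses only the standard library; the helper `is_base64_jpeg` and the sample bytes are hypothetical, and the check relies on the fact that JPEG files start with the bytes `FF D8 FF`.

```python
import base64

def is_base64_jpeg(encoded):
    """Return True if the string decodes to bytes starting with the JPEG magic number."""
    raw = base64.b64decode(encoded)
    return raw[:3] == b"\xff\xd8\xff"

# Placeholder payloads standing in for real image bytes.
fake_jpeg = base64.b64encode(b"\xff\xd8\xff\xe0 payload").decode("utf-8")
fake_png = base64.b64encode(b"\x89PNG\r\n\x1a\n payload").decode("utf-8")

print(is_base64_jpeg(fake_jpeg))
print(is_base64_jpeg(fake_png))
```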

Next, we pass the `text_query` and the list of images to the reranker, which rescores the retrieved context against the query. This time, instead of keeping all 3 previously retrieved documents, we will keep only the top 1.

In [None]:
results = ranker.rank(text_query, base64_list)

In [None]:
def process_ranker_results(results, grouped_images, top_k=3, log=False):
    new_grouped_images = []

    for i, doc in enumerate(results.top_k(top_k)):
        if log:
            print(f"Rank {i}:")
            print('Document ID: ', doc.doc_id)
            print('Document Score: ', doc.score)
            print('Document Base64: ', doc.base64[:30] + '...')
            # image_path may be unset when images are passed as base64 strings
            print('Document Path: ', getattr(doc, 'image_path', None))

        new_grouped_images.append(grouped_images[doc.doc_id])
    return new_grouped_images


new_grouped_images = process_ranker_results(
    results,
    grouped_images,
    top_k=1,
    log=True
)
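
The code above relies on each reranked document's `doc_id` being the position of that image in the list passed to `rank`, which is why `grouped_images[doc.doc_id]` recovers the right image. The selection step itself can be sketched in plain Python; the scores and page names below are made up, and `top_k_indices` is a hypothetical helper rather than part of `rerankers`.

```python
def top_k_indices(scores, k):
    """Return the positions of the k highest scores, best first."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return order[:k]

retrieved = ["page_a", "page_b", "page_c"]  # placeholder stand-ins for images
reranker_scores = [0.12, 0.87, 0.45]        # made-up relevance scores

kept = [retrieved[i] for i in top_k_indices(reranker_scores, k=1)]
print(kept)
```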

## Initialize the vision language model for question answering

We will initialize the vision language model for question answering with Qwen2-VL.

In [None]:
from transformers import Qwen2VLForConditionalGeneration, Qwen2VLProcessor, BitsAndBytesConfig
from qwen_vl_utils import process_vision_info
import torch

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)

vl_model = Qwen2VLForConditionalGeneration.from_pretrained(
    'Qwen/Qwen2-VL-7B-Instruct',
    device_map='auto',
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config
)
vl_model.eval()

min_pixels = 224 * 224
max_pixels = 1024 * 1024
vl_model_processor = Qwen2VLProcessor.from_pretrained(
    'Qwen/Qwen2-VL-7B-Instruct',
    min_pixels=min_pixels,
    max_pixels=max_pixels
)

We specify the minimum and maximum number of pixels per image to optimize how images fit into GPU memory.
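
To get a feel for what these bounds mean, the sketch below converts pixel counts into an approximate number of visual tokens. It assumes roughly one token per 28x28 pixel block, as described in the Qwen2-VL model card. The processor also rounds image dimensions and adds special tokens, so treat the numbers as estimates. `approx_visual_tokens` is a hypothetical helper.

```python
PIXELS_PER_VISUAL_TOKEN = 28 * 28  # assumption taken from the Qwen2-VL model card

def approx_visual_tokens(num_pixels):
    """Rough number of visual tokens for an image with num_pixels pixels."""
    return num_pixels // PIXELS_PER_VISUAL_TOKEN

print(approx_visual_tokens(224 * 224))    # lower bound set by min_pixels
print(approx_visual_tokens(1024 * 1024))  # upper bound set by max_pixels
print(approx_visual_tokens(448 * 448))    # the resized dataset images
```

Under this assumption, the 448x448 images saved earlier fall well inside the configured range.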

## Assemble the VLM and test the system

We will set up the chat structure by providing the model with the top-ranked image and the user's query.

In [None]:
chat_template = [
    {
        'role': 'user',
        'content': [
            {
                'type': 'image',
                'image': new_grouped_images[0]
            },
            {
                'type': 'text',
                'text': text_query
            }
        ]
    }
]

Now we apply the processor's chat template to turn this structure into the prompt text the model expects.

In [None]:
text = vl_model_processor.apply_chat_template(
    chat_template,
    tokenize=False,
    add_generation_prompt=True
)

Next, we will process the inputs to ensure they are properly formatted and ready for use with the VLM.

In [None]:
image_inputs, _ = process_vision_info(chat_template)
inputs = vl_model_processor(
    text=[text],
    images=image_inputs,
    padding=True,
    return_tensors='pt'
)
inputs = inputs.to('cuda')

Now we are ready to generate the answer.

In [None]:
generated_ids = vl_model.generate(**inputs, max_new_tokens=500)

Once the model generates the outputs, we postprocess them to obtain the final answer: the prompt tokens are trimmed from each sequence and the remaining tokens are decoded to text.

In [None]:
generated_ids_trimmed = [
    out_ids[len(in_ids) :]
    for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = vl_model_processor.batch_decode(
    generated_ids_trimmed,
    skip_special_tokens=True,
    clean_up_tokenization_spaces=False
)

In [None]:
print(output_text[0])
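
The trimming step is needed because `generate` returns the prompt tokens followed by the newly generated ones. A toy illustration with made-up token ids, using the same slicing as the cell above:

```python
# Toy token ids: each generated sequence starts with its prompt.
prompt_ids = [[101, 7, 8, 9]]
generated = [[101, 7, 8, 9, 42, 43, 44]]

# Drop the first len(prompt) tokens from each generated sequence.
trimmed = [
    out_ids[len(in_ids):]
    for in_ids, out_ids in zip(prompt_ids, generated)
]
print(trimmed)
```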

## Assemble everything

Now we can create a function that encompasses the entire pipeline, allowing us to easily reuse it in future applications.

In [None]:
def answer_with_multimodal_rag(
        vl_model,
        vl_model_processor,
        docs_retrieval_model,
        all_images,
        text_query,
        retrieval_top_k,
        reranker_top_k,
        max_new_tokens,
):
    # RAG retrieval
    results = docs_retrieval_model.search(text_query, k=retrieval_top_k)
    grouped_images = get_grouped_images(results, all_images)

    # RAG Reranker (uses the global `ranker` initialized earlier)
    base64_list = images_to_base64(grouped_images)
    results = ranker.rank(text_query, base64_list)
    grouped_images = process_ranker_results(
        results,
        grouped_images,
        top_k=reranker_top_k
    )

    # Chat template
    chat_template = [
        {
            'role': 'user',
            'content': [
                {
                    'type': 'image',
                    'image': image
                }
                for image in grouped_images
            ]
            + [
                {
                    'type': 'text',
                    'text': text_query
                }
            ]
        }
    ]

    # Prepare the inputs
    text = vl_model_processor.apply_chat_template(
        chat_template,
        tokenize=False,
        add_generation_prompt=True
    )
    image_inputs, video_inputs = process_vision_info(chat_template)
    inputs = vl_model_processor(
        text=[text],
        images=image_inputs,
        padding=True,
        return_tensors='pt'
    )
    inputs = inputs.to('cuda')

    # Generate text from vl_model
    generated_ids = vl_model.generate(**inputs, max_new_tokens=max_new_tokens)
    generated_ids_trimmed = [
        out_ids[len(in_ids) :]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    # Decode the generated ids
    output_text = vl_model_processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False
    )

    return output_text

In [None]:
text_query = "What is the overall trend in life expectancy across different countries and regions?"

output_text = answer_with_multimodal_rag(
    vl_model=vl_model,
    vl_model_processor=vl_model_processor,
    docs_retrieval_model=docs_retrieval_model,
    all_images=all_images,
    text_query=text_query,
    retrieval_top_k=3,
    reranker_top_k=1,
    max_new_tokens=500,
)

print(output_text[0])

In [None]:
import torch

torch.cuda.empty_cache()
torch.cuda.synchronize()
print(f"GPU allocated memory: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
print(f"GPU reserved memory: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
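
The figures printed above can be compared against a back-of-the-envelope estimate of weight memory. The sketch below assumes roughly 7 billion parameters and ignores quantization metadata, any layers kept in higher precision, activations, and the KV cache, so it is a rough lower bound rather than a prediction. `weight_memory_gib` is a hypothetical helper.

```python
def weight_memory_gib(num_params, bits_per_param):
    """Approximate memory needed to store the weights alone, in GiB."""
    return num_params * bits_per_param / 8 / 1024**3

# Assumed parameter count of about 7e9, for illustration only.
print(f"bf16 weights:  {weight_memory_gib(7e9, 16):.1f} GiB")
print(f"4-bit weights: {weight_memory_gib(7e9, 4):.1f} GiB")
```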