In [1]:
import os
import re
import json
import pprint
import random
import warnings
import tempfile
from uuid import uuid4
from dataclasses import asdict, dataclass
from typing import List, Optional, Tuple, Type, Union

import torch
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm, trange
from pydantic import BaseModel
from torch.utils.data import DataLoader
from pdf2image import convert_from_path
from datasets import Dataset, load_dataset
from huggingface_hub import notebook_login
from qwen_vl_utils import process_vision_info
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor

# Silence every Python warning for the rest of the session (this also hides deprecation notices).
warnings.simplefilter('ignore')
# Authenticate against the Hugging Face Hub. new_session=False presumably reuses an existing
# login instead of prompting again (the cell output reports "already logged in") — confirm
# against the huggingface_hub documentation.
notebook_login(new_session=False)

User is already logged in.


In [2]:
# Make the project root the working directory so relative paths ('data/...', 'src/...') resolve.
# The notebook is expected to live one level below the root. Moving up only when the `src`
# package is not already visible keeps the cell idempotent: re-running it no longer climbs
# one more directory each time.
wd = os.getcwd()
if not os.path.isdir(os.path.join(wd, 'src')):
    wd = os.path.dirname(wd)
    os.chdir(wd)
print(f'path: {wd}')

path: /home/dgarieck23/VLMs/tunnel_vision


In [3]:
from src.utils.utils import *

### Leveraging GPU for Performance
To optimize performance, we'll use GPU acceleration if available.

In [4]:
# get_device is provided by the star-import of src.utils.utils. Judging by the output below it
# reports GPU availability; presumably it returns the torch device to run on — verify in
# src/utils/utils.py.
device = get_device()

GPU is available
GPU name: NVIDIA GeForce RTX 4090 Laptop GPU


#### Set Seed
To enhance reproducibility and comparability across runs, every random number generator is seeded with the same value.

In [5]:
seed = 42

# Seed every random number generator the pipeline may draw from, in this order:
# Python's `random`, NumPy, PyTorch (CPU), PyTorch (current GPU), PyTorch (all GPUs).
for seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    seed_fn(seed)

# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

### Pdf to Image

This step performs a batch conversion of PDF files into images. It reads the files from a specified directory, converts each page into an image, and saves the resulting images in a temporary folder. Saving the images to a temporary folder helps us avoid some odd [bugs](https://github.com/huggingface/datasets/issues/4796) when loading the dataset for fine-tuning with the Hugging Face `datasets` library.

In [6]:
# Pages are rendered into a fresh temporary directory, which is later loaded as an image folder.
temp_dir = tempfile.mkdtemp()
inputs = 'data/raw/annual reports/'

# Sort for a reproducible processing order (os.listdir returns entries in arbitrary order) and
# keep only PDFs so stray files in the folder do not crash the conversion.
pdf_files = sorted(
    name for name in os.listdir(inputs) if name.lower().endswith('.pdf')
)

for file in tqdm(pdf_files):
    # convert the PDF pages to images
    images = convert_from_path(os.path.join(inputs, file), dpi=100, thread_count=6)

    # save each page as '<pdf name>_page_<n>.png' (n is 1-based) in the temporary directory
    stem = os.path.splitext(file)[0]
    for idx, img in enumerate(images, start=1):
        img.save(os.path.join(temp_dir, f'{stem}_page_{idx}.png'))

print(f'Images saved in temporary folder: {temp_dir}')

100%|██████████| 3/3 [00:30<00:00, 10.32s/it]

Images saved in temporary folder: /tmp/tmpepb88zcr





Load the images saved in the temporary directory as a dataset, using the load_dataset function from the Hugging Face datasets library. The dataset is structured in an image folder format, where each image file is treated as an individual data point.

In [7]:
# Build a dataset from the PNG files written to temp_dir. Later cells read the pages through
# the 'image' column; with no class sub-folders or metadata file that is presumably the only
# column — TODO confirm.
dataset = load_dataset('imagefolder', data_dir=temp_dir, split='train')

Resolving data files:   0%|          | 0/817 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/817 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

### Using Qwen2-VL to generate queries

[Qwen-VL](https://qwen2vl.com/), developed by Alibaba Cloud, is a visual multimodal model from the Qwen series designed to handle inputs like images, text, and bounding boxes, producing text and bounding box outputs.

Key Features:
- **Superior Performance**: Outperforms other similar models on benchmarks like Zero-shot Captioning, VQA, DocVQA, and Grounding.
- **Multilingual Text Recognition**: Especially strong in recognizing bilingual text (Chinese and English) in images.
- **Multi-Image Conversations**: Enables comparison and storytelling across multiple images.
- **High-Resolution Understanding**: Operates at a higher resolution (448 vs. 224), enhancing tasks like fine-grained recognition and document QA.

In [8]:
# Hugging Face Hub id of the checkpoint; reused for both the model and its processor
vl_model = 'Qwen/Qwen2-VL-2B-Instruct'

In [9]:
# Load the Qwen2-VL weights in bfloat16 with the FlashAttention-2 attention implementation and
# place them on `device`.
# NOTE(review): attn_implementation='flash_attention_2' presumably requires the flash-attn
# package and a supported GPU — confirm before running on other hardware.
model = Qwen2VLForConditionalGeneration.from_pretrained(
    vl_model,
    torch_dtype=torch.bfloat16,
    attn_implementation='flash_attention_2',
    device_map=device,
)

You are attempting to use Flash Attention 2.0 without specifying a torch dtype. This might lead to unexpected behaviour
`Qwen2VLRotaryEmbedding` can now be fully parameterized by passing the model config through the `config` argument. All other arguments will be removed in v4.46


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [10]:
# default processor for the checkpoint; used below to build the chat prompt, prepare the
# text/image inputs and decode the generated tokens
processor = AutoProcessor.from_pretrained(vl_model)

#### Building ColPali Queries

##### Pydantic Models
The use of Pydantic ensures that the data in the queries are consistently structured and validated, improving reliability when the system processes complex and varied inputs. Each model corresponds to a distinct query type and includes fields for both the query and its explanation.

In [11]:
class GeneralRetrievalQuery(BaseModel):
    """Response schema paired with the "general" prompt.

    Three retrieval queries (broad topical, specific detail, visual element), each stored
    together with a short explanation. All fields are required strings.
    """
    broad_topical_query: str
    broad_topical_explanation: str
    specific_detail_query: str
    specific_detail_explanation: str
    visual_element_query: str
    visual_element_explanation: str

In [12]:
class MultiDocumentComparisonQuery(BaseModel):
    """Response schema paired with the "comparison" prompt.

    Two retrieval queries (comparison, corroboration/contradiction), each stored together
    with a short explanation. All fields are required strings.
    """
    comparison_query: str
    comparison_explanation: str
    corroboration_contradiction_query: str
    corroboration_contradiction_explanation: str

In [13]:
class DomainSpecificQuery(BaseModel):
    """Response schema paired with the "domain" prompt.

    The identified domain plus three retrieval queries (domain-specific terminology, data or
    findings, applications or implications), each stored together with a short explanation.
    All fields are required strings.
    """
    identified_domain: str
    domain_specific_query: str
    domain_specific_explanation: str
    data_findings_query: str
    data_findings_explanation: str
    applications_implications_query: str
    applications_implications_explanation: str

In [14]:
class VisualElementFocusQuery(BaseModel):
    """Response schema paired with the "visual" prompt.

    Three retrieval queries (similar visual element, text and visual combination, visual
    content understanding), each stored together with a short explanation. All fields are
    required strings.
    """
    similar_visual_element_query: str
    similar_visual_element_explanation: str
    text_visual_combination_query: str
    text_visual_combination_explanation: str
    visual_content_understanding_query: str
    visual_content_understanding_explanation: str

In [15]:
class TemporalMetadataQuery(BaseModel):
    """Response schema paired with the "temporal" prompt.

    Three retrieval queries (temporal, topic and metadata combination, updated or related
    document), each stored together with a short explanation. All fields are required strings.
    """
    temporal_query: str
    temporal_explanation: str
    topic_metadata_combination_query: str
    topic_metadata_combination_explanation: str
    update_related_document_query: str
    update_related_document_explanation: str

In [16]:
class DifficultyAmbiguityQuery(BaseModel):
    """Response schema paired with the "difficulty" prompt.

    Three retrieval queries (simple, complex, ambiguous), each stored together with a short
    explanation. All fields are required strings.
    """
    simple_query: str
    simple_explanation: str
    complex_query: str
    complex_explanation: str
    ambiguous_query: str
    ambiguous_explanation: str

In [17]:
class MultilingualMultimodalQuery(BaseModel):
    """Response schema paired with the "multilingual" prompt.

    Three retrieval queries (multilingual, multimodal combination, text and visual
    understanding), each stored together with a short explanation. All fields are required
    strings.
    """
    multilingual_query: str
    multilingual_explanation: str
    multimodal_combination_query: str
    multimodal_combination_explanation: str
    text_visual_understanding_query: str
    text_visual_understanding_explanation: str

##### Prompting

Different prompts are created based on the Pydantic models to generate multiple query sets, each covering a distinct aspect of retrieval. The dataset will be built around seven different templates, each focusing on a different dimension or type of query. These templates allow the system to cover a broad range of scenarios and use cases.

In [18]:
# Prompt for the "general" query family. The JSON keys requested below mirror the fields of
# GeneralRetrievalQuery. The literal <image> marker is plain prompt text; the page image itself
# is attached separately as a chat-message item in generate_response.
gral_template = '''

You are an AI assistant specialized in document retrieval tasks within the financial domain, specifically for annual reports. Given an image of a document page, your task is to generate retrieval queries that someone might use to find this financial document in a large corpus of reports.

Please generate 3 different types of retrieval queries:

1. A broad topical query: This should cover the main subject of the document, such as financial performance, key company metrics, or strategic initiatives.
2. A specific detail query: This should focus on a particular fact, financial figure (e.g., revenue, net profit), or specific point made in the document.
3. A visual element query: This should reference a chart, financial graph, or other visual components such as balance sheets or income statements, if present.

Important guidelines:
- Ensure the queries are relevant for retrieval tasks, particularly focusing on financial data, and not just describing the page content.
- Frame the queries as if someone is searching for this financial document in a corpus of reports, not asking questions about its content.
- Make the queries diverse and representative of different search strategies, including financial terms and specific company performance indicators.

For each query, also provide a brief explanation of why this query would be effective in retrieving this financial document.

Format your response as a JSON object with the following structure:

{
  "broad_topical_query": "Your query here",
  "broad_topical_explanation": "Brief explanation",
  "specific_detail_query": "Your query here",
  "specific_detail_explanation": "Brief explanation",
  "visual_element_query": "Your query here",
  "visual_element_explanation": "Brief explanation"
}

If there are no relevant visual elements, replace the third query with another specific detail query that references financial data.

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.

'''

In [19]:
# Prompt for the "comparison" query family. The JSON keys requested below mirror the fields of
# MultiDocumentComparisonQuery.
comp_template = '''

Imagine this financial document page is part of a larger corpus of annual reports. Your task is to generate retrieval queries that would require comparing this document with others in the corpus, particularly focusing on financial data, company performance, and market trends.

Please generate 2 retrieval queries:

1. A query comparing this document’s financial performance, trends, or metrics with a related subject, such as performance from a different year, a competitor's report, or industry benchmarks.
2. A query seeking documents that either contradict or support the financial figures, strategies, or statements made in this document (e.g., conflicting market trends, opposing financial analyses, or differing growth projections).

For each query, provide a brief explanation of how it encourages document comparison and why it would be effective for retrieval within a financial corpus.

Format your response as a JSON object with the following structure:

{
  "comparison_query": "Your query here",
  "comparison_explanation": "Brief explanation",
  "corroboration_contradiction_query": "Your query here",
  "corroboration_contradiction_explanation": "Brief explanation"
}

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.
'''

In [20]:
# Prompt for the "domain" query family. The JSON keys requested below mirror the fields of
# DomainSpecificQuery; the domain is fixed to "financial" by the prompt itself.
dom_template = '''
Your task is to create retrieval queries that a financial professional or analyst might use to find this document in a large corpus of financial documents, specifically annual reports.

First, identify the domain of the document as "financial."

Then, generate 3 retrieval queries:

1. A query using domain-specific terminology, such as key financial metrics, accounting terms, or industry-specific jargon.
2. A query seeking specific financial data or findings presented in the document, such as revenue, net income, cash flow, or key performance indicators (KPIs).
3. A query related to the document’s potential applications or implications, such as its relevance to investment decisions, market positioning, or future growth strategies.

For each query, provide a brief explanation of its relevance to the financial domain and why it would be effective for retrieval in a corpus of annual reports.

Format your response as a JSON object with the following structure:

{
  "identified_domain": "financial",
  "domain_specific_query": "Your query here",
  "domain_specific_explanation": "Brief explanation",
  "data_findings_query": "Your query here",
  "data_findings_explanation": "Brief explanation",
  "applications_implications_query": "Your query here",
  "applications_implications_explanation": "Brief explanation"
}

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.
'''

In [21]:
# Prompt for the "visual" query family. The JSON keys requested below mirror the fields of
# VisualElementFocusQuery.
vis_template = '''
Your task is to generate retrieval queries focusing on the visual elements present in this financial document page (such as financial charts, tables, graphs, or diagrams).

Please generate 3 retrieval queries:

1. A query specifically asking for documents with similar financial visual elements, such as bar charts of revenue trends, pie charts of market share, or financial tables (e.g., balance sheets, income statements).
2. A query combining textual and visual financial information, such as connecting financial figures in the text (e.g., revenue, net income) with their representation in graphs or tables.
3. A query that would require understanding the content of the financial visual element, such as interpreting the performance trend in a line chart or analyzing the relationship between metrics in a financial table, to retrieve this document.

For each query, provide a brief explanation of how it incorporates financial visual elements and why it would be effective for retrieval in a financial corpus.

Format your response as a JSON object with the following structure:

{
  "similar_visual_element_query": "Your query here",
  "similar_visual_element_explanation": "Brief explanation",
  "text_visual_combination_query": "Your query here",
  "text_visual_combination_explanation": "Brief explanation",
  "visual_content_understanding_query": "Your query here",
  "visual_content_understanding_explanation": "Brief explanation"
}

If the document lacks significant visual elements, explain this and generate alternative queries focusing on the financial document's structure or layout (e.g., section headings, data tables).

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.
'''

In [22]:
# Prompt for the "temporal" query family. The JSON keys requested below mirror the fields of
# TemporalMetadataQuery.
temp_template = '''
Assuming this financial document is part of a large, diverse corpus of annual reports, your task is to generate retrieval queries that incorporate metadata or temporal aspects relevant to financial reporting.

Please generate 3 retrieval queries:

1. A query specifying a likely time frame for this document, such as the fiscal year or publication date (e.g., "2023 annual report" or "Q4 financial report").
2. A query combining financial topical information (e.g., revenue, net income) with a metadata element, such as the company name, report type (e.g., balance sheet, income statement), or auditor name.
3. A query seeking updated or related financial reports on the same topic, such as subsequent reports from the same company or financial updates for the same fiscal year.

For each query, provide a brief explanation of how it uses temporal or metadata information and why it would be effective for retrieving financial documents.

Format your response as a JSON object with the following structure:

{
  "temporal_query": "Your query here",
  "temporal_explanation": "Brief explanation",
  "topic_metadata_combination_query": "Your query here",
  "topic_metadata_combination_explanation": "Brief explanation",
  "update_related_document_query": "Your query here",
  "update_related_document_explanation": "Brief explanation"
}

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.
'''

In [23]:
# Prompt for the "difficulty" query family. The JSON keys requested below mirror the fields of
# DifficultyAmbiguityQuery.
diff_template = '''
Your task is to create retrieval queries for this financial document, considering different levels of complexity and ambiguity, which reflect common information retrieval tasks in a corpus of financial reports.

Please generate 3 retrieval queries:

1. A simple, straightforward query focused on a single aspect of the document, such as a key financial figure (e.g., revenue, net income).
2. A complex query that requires understanding multiple financial metrics, trends, or sections of the document (e.g., linking financial performance with strategic initiatives or multiple sections like balance sheets and income statements).
3. An ambiguous query that could retrieve this document among others, possibly due to a more general term (e.g., "annual financial performance" or "company revenue trends"), which could apply to many documents in the corpus.

For each query, provide a brief explanation of its complexity level or ambiguity and why it would be effective or challenging for retrieval in the context of financial documents.

Format your response as a JSON object with the following structure:

{
  "simple_query": "Your query here",
  "simple_explanation": "Brief explanation",
  "complex_query": "Your query here",
  "complex_explanation": "Brief explanation",
  "ambiguous_query": "Your query here",
  "ambiguous_explanation": "Brief explanation"
}

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.

'''

In [24]:
# Prompt for the "multilingual" query family. The JSON keys requested below mirror the fields of
# MultilingualMultimodalQuery.
mll_template = '''
Your task is to generate retrieval queries considering potential multilingual and multi-modal aspects of this financial document.

Please generate 3 retrieval queries:

1. A query in a different language (if applicable) that would retrieve this financial document (e.g., a query in Spanish, French, or another relevant language).
2. A query combining textual financial data (e.g., revenue, net income) with non-textual elements like charts, graphs, or tables representing this data visually.
3. A query that requires understanding both the financial text and visual elements (e.g., interpreting financial performance from text descriptions and visualizing trends in a graph or table) to retrieve this document accurately.

For each query, provide a brief explanation of its multilingual or multi-modal nature and why it would be effective for retrieving financial documents.

Format your response as a JSON object with the following structure:

{
  "multilingual_query": "Your query here",
  "multilingual_explanation": "Brief explanation",
  "multimodal_combination_query": "Your query here",
  "multimodal_combination_explanation": "Brief explanation",
  "text_visual_understanding_query": "Your query here",
  "text_visual_understanding_explanation": "Brief explanation"
}

If the document is not suitable for multilingual queries, explain why and provide an alternative query that focuses on the financial structure or visual layout.

Here is the document image to analyze:
<image>

Generate the queries based on this image and provide the response in the specified JSON format.
'''

In [25]:
def get_retrieval_prompt(prompt_name: str) -> Tuple[
    str,
    Type[
        Union[
            GeneralRetrievalQuery,
            MultiDocumentComparisonQuery,
            DomainSpecificQuery,
            VisualElementFocusQuery,
            TemporalMetadataQuery,
            DifficultyAmbiguityQuery,
            MultilingualMultimodalQuery,
        ]
    ],
]:
    """Return the prompt template and response schema registered under ``prompt_name``.

    Args:
        prompt_name: One of "general", "comparison", "domain", "visual", "temporal",
            "difficulty" or "multilingual".

    Returns:
        A ``(template, model_class)`` tuple: the prompt text and the Pydantic model *class*
        (not an instance) whose fields match the JSON keys the prompt asks for.

    Raises:
        ValueError: If ``prompt_name`` is not one of the registered names.
    """
    prompts = {
        "general": (gral_template, GeneralRetrievalQuery),
        "comparison": (comp_template, MultiDocumentComparisonQuery),
        "domain": (dom_template, DomainSpecificQuery),
        "visual": (vis_template, VisualElementFocusQuery),
        "temporal": (temp_template, TemporalMetadataQuery),
        "difficulty": (diff_template, DifficultyAmbiguityQuery),
        "multilingual": (mll_template, MultilingualMultimodalQuery),
    }

    if prompt_name not in prompts:
        raise ValueError(
            f"Invalid prompt name. Choose from: {', '.join(prompts.keys())}"
        )

    return prompts[prompt_name]

#### Generating ColPali Queries

The following function generates a multimodal response by combining text (prompt) and image inputs. It uses a pretrained visual-language model (Qwen) and processor to interpret the input and generate a response, which will be useful for our visual question answering task.

In [26]:
def generate_response(prompt: str, image, max_new_tokens: int = 400) -> List[str]:
    """Run the vision-language model on one image plus a text prompt.

    Uses the notebook-level ``processor``, ``model`` and ``device`` objects.

    Args:
        prompt: Instruction text sent alongside the image.
        image: Image placed in the chat message and handed to ``process_vision_info``
            (the notebook passes items of ``dataset['image']``).
        max_new_tokens: Upper bound on the number of generated tokens. Defaults to 400, the
            value that used to be hard-coded; raise it if the JSON answers come back truncated.

    Returns:
        The decoded completions with the prompt tokens stripped. A single conversation is
        submitted, so the list holds one string.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": prompt},
            ],
        }
    ]

    # Render the conversation as a prompt string that ends with the generation prompt.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Pull the image/video payloads out of the messages for the processor.
    image_inputs, video_inputs = process_vision_info(messages)

    inputs = processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(device)

    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    # Each generated sequence starts with its input tokens; slice them off so that only the
    # newly generated part is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]

    output_text = processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )

    return output_text

##### Creating Queries

Queries are generated for a set of images by using predefined prompt templates and corresponding Pydantic models. Each query is tailored to a specific query type, and the system generates responses based on both the image and the prompt.

For the purposes of this notebook, only **general queries** will be generated and run through inference.

In [27]:
# Raw model responses, keyed by prompt family; each list is index-aligned with `dataset`.
prompts = {}
images = dataset['image']

for prompt_name in ['general']:
    # pydantic_model is returned alongside the template but is not used for validation here.
    prompt, pydantic_model = get_retrieval_prompt(prompt_name)
    responses = []
    for page_idx, image in enumerate(
        tqdm(images, desc=f'Generating queries with {prompt_name} template: ')
    ):
        try:
            responses.append(generate_response(prompt, image))
        except Exception as e:
            # Keep one entry per image so responses stay aligned with the dataset rows, but
            # report the failure instead of discarding it silently. tqdm.write prints without
            # breaking the progress bar.
            tqdm.write(f'[{prompt_name}] generation failed for image {page_idx}: {e!r}')
            responses.append(None)
    prompts[prompt_name] = responses

Generating queries with general template: 100%|██████████| 817/817 [1:03:45<00:00,  4.68s/it]


#### Parsing Queries Into a Dataset

A custom function is used to parse the responses (which are in a JSON-like format) into a Python dictionary. The function ensures that all expected keys from the JSON-like string are included in the final output. If a key contains an invalid or incomplete value, the function assigns None to that key, ensuring consistent structure and completeness in the dataset.

In [28]:
def _unescape(text):
    r"""Decode backslash escapes in ``text`` without corrupting non-ASCII characters.

    Encoding with latin-1 plus ``backslashreplace`` turns every character above U+00FF into a
    ``\uXXXX``-style escape and every other character into a single byte, which is the input
    ``unicode_escape`` expects. Encoding as UTF-8 instead made ``unicode_escape`` read each
    UTF-8 byte as a separate character, so any accented letter or currency sign was mangled.

    Raises:
        UnicodeDecodeError: If ``text`` contains a malformed escape (e.g. a truncated ``\x``).
    """
    return text.encode("latin-1", "backslashreplace").decode("unicode_escape")


def parse_string_to_dict(s):
    """
    Parses a JSON-like string into a dictionary, ensuring all keys are included.
    Assigns None to keys with invalid or incomplete values.

    Args:
        s (str): The input string containing key-value pairs.

    Returns:
        dict | None: A dictionary with the keys found in the input string, in order of
            appearance. Keys followed by a complete quoted string get that string (with
            backslash escapes decoded); every other key gets None. Keys that cannot be decoded
            are skipped. Returns None when ``s`` is not a string.
    """
    if not isinstance(s, str):
        # nothing to parse, e.g. a failed generation that was stored as None
        return None

    # regular expression patterns
    key_pattern = r'"([^"\\]*(?:\\.[^"\\]*)*)"\s*:'  # matches keys
    kv_pattern = r'"([^"\\]*(?:\\.[^"\\]*)*)"\s*:\s*"([^"\\]*(?:\\.[^"\\]*)*)"'  # matches valid "key": "value" pairs

    # collect every complete "key": "value" pair
    valid_kv = {}
    for raw_key, raw_value in re.findall(kv_pattern, s):
        try:
            valid_kv[_unescape(raw_key)] = _unescape(raw_value)
        except UnicodeDecodeError:
            # if decoding fails, skip this key-value pair (its key is reported as None below)
            continue

    # report every key that appears in the text, falling back to None for incomplete values
    final_dict = {}
    for raw_key in re.findall(key_pattern, s):
        try:
            key = _unescape(raw_key)
        except UnicodeDecodeError:
            # one undecodable key should not discard the keys that did parse
            continue
        final_dict[key] = valid_kv.get(key)

    return final_dict

In [29]:
# Sanity check: count the generations that failed (stored as None) for each prompt family.
for prompt_name in ['general']:
    missing = sum(response is None for response in prompts[prompt_name])
    print(f'number of None responses in {prompt_name}: {missing}')

number of None responses in general: 0


In [30]:
def _clean_response(response):
    """Strip Markdown code fences and newlines from one model response.

    ``response`` is the list returned by ``generate_response`` (its first item is the text),
    or None when generation failed. None and empty responses become an empty string.
    """
    if not response:
        return ''
    # Remove only the fence markers (``` or ```json). A blanket .replace('json', '') would
    # also delete the word "json" from inside the generated queries themselves.
    text = re.sub(r'```(?:json)?', '', response[0])
    return text.replace('\n', '')


# One dict per dataset row. Rows whose response is missing or unparsable get an empty dict so
# later cells can always iterate over .items().
general_qa = [
    parse_string_to_dict(_clean_response(response)) or {}
    for response in prompts['general']
]

In [31]:
# Attach the parsed queries to their page images (general_qa is index-aligned with dataset).
# NOTE(review): this rebinds `dataset`; re-running the cell without recreating `dataset` would
# try to add 'queries' a second time — presumably an error, confirm against the datasets docs.
dataset = dataset.add_column(name='queries',column=general_qa)

In [32]:
def explode_queries(batch):
    """Flatten each row's query dictionary into one output row per non-empty entry.

    Written for ``Dataset.map(..., batched=True)``.

    Args:
        batch (dict): Column-oriented batch with an 'image' list and a 'queries' list whose
            items are dicts, or None/empty when a page produced no parsable response.

    Returns:
        dict: Columns 'image', 'query' and 'answer' of equal length. Note that 'query' holds
        the dictionary *key* (e.g. "broad_topical_query") and 'answer' holds the text stored
        under that key; the image is repeated for every emitted pair.
    """
    images = []
    queries = []
    answers = []

    # walk the two columns in lockstep, one dataset row at a time
    for image, row_queries in zip(batch['image'], batch['queries']):
        # a row without parsed queries (None or {}) contributes nothing instead of raising
        if not row_queries:
            continue
        for query, answer in row_queries.items():
            if answer is not None:  # only add entries where answer is not None
                images.append(image)  # repeat the page image for each of its entries
                queries.append(query)
                answers.append(answer)

    return {'image': images, 'query': queries, 'answer': answers}

In [33]:
# Flatten to one (image, query, answer) row per generated entry and drop the nested column.
exploded_dataset = dataset.map(explode_queries, batched=True, remove_columns=['queries'])
# Pass the notebook-wide seed explicitly so the 70/30 split is reproducible on its own instead
# of depending on how much of the global random state earlier cells consumed.
# NOTE(review): the split happens after exploding, so rows derived from the same page image can
# land in both train and test; split by page first if an image-disjoint test set is needed.
exploded_dataset = exploded_dataset.train_test_split(test_size=0.3, seed=seed)

Map:   0%|          | 0/817 [00:00<?, ? examples/s]

In [34]:
# inspect output: the displayed repr lists the splits with their column names and row counts
exploded_dataset

DatasetDict({
    train: Dataset({
        features: ['image', 'query', 'answer'],
        num_rows: 3428
    })
    test: Dataset({
        features: ['image', 'query', 'answer'],
        num_rows: 1470
    })
})

In [35]:
# store output
# NOTE: despite its name, file_name is presumably used as a directory — the progress output
# below reports shards being written per split; confirm against the datasets docs.
file_name = 'data/processed/annual reports'
exploded_dataset.save_to_disk(file_name)

Saving the dataset (0/2 shards):   0%|          | 0/3428 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/1470 [00:00<?, ? examples/s]