<a href="https://colab.research.google.com/github/dvgiannis/Document-Understanding/blob/main/Thesis_Code_%5BGitHub%5D_v5_7.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Version History and Changelog

### **Version 1.0**: Initial Setup and Dataset Loading
- Loaded the dataset of 5,000 PDFs and XML files.
- Verified file counts and folder structure.

### **Version 1.1**: Folder Structure Analysis and File Statistics
- Analyzed nested folder structure and extracted file counts and sizes for PDFs and XMLs.

### **Version 1.2**: Outlier Detection for PDFs
- Performed statistical analysis on page counts and file sizes.
- Identified outliers such as large files with few pages and small files with high page counts.
- Calculated correlations between file size and page counts.

### **Version 1.3**: Efficient File Path Indexing
- Pre-indexed file paths for PDFs and XMLs to optimize processing time.
- Implemented file path resolution logic.

### **Version 1.4**: XML Parsing Enhancements
- Addressed parsing errors for 216 malformed XML files.
- Included additional tags relevant to the thesis such as `title`, `keywords`, and `pub-date`.

### **Version 1.5**: Final Dataset Preparation
- Created a filtered dataset excluding error-prone files.
- Added columns for `pdf_path` and `xml_path`.
- Generated a new folder structure containing only valid PDFs and XMLs.

### **Version 1.6**: Parallelized Workflows
- Implemented parallel processing for file path lookups and copying.
- Optimized metadata extraction using `concurrent.futures`.

### **Version 1.7**: Thesis Contextual Updates
- Updated dataset analysis to emphasize its relevance to document understanding with AI.
- Integrated insights from XML metadata such as title lengths, keyword distribution, and structural metadata.
- Refined the focus of the analysis to align with thesis objectives.

---

### **Version 2.0**: Initial PDF-to-Image Conversion and Image Resizing
- Implemented PDF-to-image conversion using `pdf2image` with multiprocessing for efficient processing.
- Resized images to 1024x1024 pixels for compatibility with ColPali and ColQwen models.

### **Version 2.1**: Precision Experiments and Batch Size Selection
- Experimented with FP16, FP32, and bfloat16 for model inference.
- Conducted batch size experiments (6, 8, and 10).

### **Version 2.2**: Pipeline Optimization
- Enhanced crash recovery by saving intermediate results after each batch.
- Integrated GPU memory monitoring to track allocation and prevent overflow.
- Logged invalid JSON outputs for debugging and improved error handling.

### **Version 2.3**: Dataset Generation
- Created an image-query dataset from the 7,239 images of the 1,000 PDFs.
- Restructured dataset to parquet format for processing performance and HuggingFace compatibility.
- Pushed dataset to HuggingFace with Viewer capabilities.

---

### **Version 3.0**: Batch Size Optimization and System Monitoring
- Conducted experiments with various batch sizes for query, passage, and scoring to optimize memory usage and processing speed.
- Established optimal batch sizes (`batch_query=32`, `batch_passage=32`, `batch_score=512`).
- Integrated system monitoring to track resource utilization, ensuring efficient hardware usage during processing.

### **Version 3.1**: Shard-Based Dataset Evaluation
- Transitioned from full dataset evaluation to shard-based processing to manage memory constraints and runtime challenges.
- Implemented logging mechanisms to record performance for each shard, enabling incremental progress tracking and crash recovery.
- Introduced timestamp-based logging for detailed shard-level insights.

### **Version 3.2**: Log Consolidation and Performance Analysis
- Consolidated shard metrics into a unified CSV file for detailed analysis.
- Performed comprehensive evaluation of shard performance.
- Extracted dataset-wide insights by comparing shard-level metrics and trends.

### **Version 3.3**: Pipeline Optimization and Scalability Enhancements
- Refined the shard processing pipeline to ensure scalability and robust error handling for large datasets.
- Improved flexibility to adjust the number of shards processed for experimentation.
- Enhanced system monitoring integration, correlating resource utilization data with shard performance for better optimization.

### **Version 3.4**: Retriever Experiments and Metrics Analysis
- Conducted experiments with ColPali and ColQwen2 retrievers using the shared dataset.
- Logged and analyzed performance metrics for both retrievers to compare their effectiveness.
- Used the resulting metrics to validate model performance across the two retrievers.

---

### **Version 4.0**: Initial PaperMage PDF Processing Implementation
- Implemented PaperMage CoreRecipe for title, author, abstract, keyword, and bibliography extraction.
- Designed basic pipeline to process PDFs and extract structured data.
- Tested PaperMage output against the ground truth XML data.

### **Version 4.1**: Parallel Processing with ThreadPoolExecutor
- Integrated ThreadPoolExecutor to process multiple PDFs in parallel.
- Adjusted batch processing strategy to process PDFs in chunks instead of all at once.
- Benchmarked execution time with 4, 8, and 12 workers.

### **Version 4.2**: Handling System Crashes & Memory Optimization
- Introduced incremental saving after each processed PDF.
- Implemented resume mechanism to skip already processed PDFs.
- Limited worker threads dynamically based on available system resources.
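The incremental-save and resume behaviour described in Version 4.2 could look roughly like the sketch below. It is an illustration under assumptions, not the notebook's actual implementation: `process_with_resume`, `fake_process`, and the single-JSON checkpoint layout are hypothetical names and choices made for this example.

```python
import json
import os
import tempfile

def process_with_resume(items, process_fn, checkpoint_path):
    """Process items one by one, saving after each and skipping finished ones."""
    results = {}
    if os.path.exists(checkpoint_path):
        # Resume: load whatever a previous run managed to save
        with open(checkpoint_path, "r", encoding="utf-8") as f:
            results = json.load(f)

    for item in items:
        if item in results:
            continue  # Already processed in an earlier run
        results[item] = process_fn(item)
        # Write to a temporary file first, then swap it in, so a crash
        # mid-write cannot leave a truncated checkpoint behind
        tmp_path = checkpoint_path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(results, f)
        os.replace(tmp_path, checkpoint_path)
    return results

# Demo with a stand-in for the real PDF processing step (hypothetical)
calls = []

def fake_process(name):
    calls.append(name)
    return {"length": len(name)}

with tempfile.TemporaryDirectory() as tmp_dir:
    checkpoint = os.path.join(tmp_dir, "extracted.json")
    process_with_resume(["a.pdf", "b.pdf"], fake_process, checkpoint)
    # Second run: only the new file should be processed
    final = process_with_resume(["a.pdf", "b.pdf", "c.pdf"], fake_process, checkpoint)

print(calls)
```

Rewriting the whole checkpoint after every item is simple but grows costly for large result sets; appending one JSON line per item would be a reasonable alternative.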

### **Version 4.3**: Evaluation of PaperMage Model
- Loaded `extracted_data_PAPERMAGE.json` (model output) and `extracted_data_XML.json` (ground truth).
- Inspected file structures and identified challenges such as formatting differences and empty fields.
- Designed an evaluation strategy using token-level Precision, Recall, and F1-score, macro-averaged over Titles, Authors, Abstracts, Keywords, and Bibliographies.
- Ran the first complete evaluation and generated a per-document evaluation CSV file.

### **Version 4.4**: Refinements and Validation Methods
- Identified potential unfair scoring issues.
- Implemented manual spot-checking methods to identify false positives and false negatives.
- Ran the second complete evaluation and generated a per-document evaluation CSV file.

### **Version 4.5**: Fair Scoring Adjustments
- Adjusted the scoring system: Assigned F1 = 1.0 for perfect empty matches (to avoid penalizing correct empty extractions).
- Excluded empty categories from macro-averaging to prevent unfairly lowering the overall score.
- Re-ran the evaluation with these fixes.
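The scoring rules from Versions 4.3 to 4.5 can be illustrated with a small sketch. It assumes whitespace tokenization and lowercasing, and it reads "empty categories" as categories whose ground truth is empty; both are assumptions, and the names `token_prf` and `macro_f1` are hypothetical rather than taken from the evaluation code.

```python
from collections import Counter

def token_prf(predicted, reference):
    """Token-level precision, recall and F1 for a single field."""
    pred_tokens = predicted.lower().split()
    ref_tokens = reference.lower().split()
    if not pred_tokens and not ref_tokens:
        # Perfect empty match: do not penalize a correct empty extraction
        return 1.0, 1.0, 1.0
    if not pred_tokens or not ref_tokens:
        return 0.0, 0.0, 0.0
    # Multiset intersection counts each shared token at most as often
    # as it appears in both strings
    overlap = sum((Counter(pred_tokens) & Counter(ref_tokens)).values())
    if overlap == 0:
        return 0.0, 0.0, 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(ref_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1

def macro_f1(predicted_fields, reference_fields):
    """Average F1 over fields, skipping those with empty ground truth (assumption)."""
    scores = []
    for field, reference in reference_fields.items():
        if not reference.strip():
            continue
        scores.append(token_prf(predicted_fields.get(field, ""), reference)[2])
    return sum(scores) / len(scores) if scores else None

p, r, f1 = token_prf("deep learning for pdfs", "deep learning for documents")
doc_score = macro_f1(
    {"title": "Layout Analysis", "keywords": ""},
    {"title": "layout analysis", "keywords": ""},
)
print(p, r, f1, doc_score)
```

In the first call three of four tokens overlap on each side, so precision, recall and F1 are all 0.75. In the second, the empty `keywords` field is left out of the average instead of inflating or deflating it.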

---

### **Version 5.0: Initial RAG System Integration**  
- Designed the basic Retrieval-Augmented Generation (RAG) architecture for scientific PDFs.  
- Integrated Byaldi (ColQwen) for document indexing and semantic retrieval.  
- Integrated Qwen2.5-7B-Instruct for answer generation based on retrieved context.  
- Built an early prototype that parsed PDFs manually, without layout structure.

### **Version 5.1: Layout-Aware Parsing with PaperMage**  
- Implemented PaperMage CoreRecipe for structured parsing of PDF tokens, sentences, blocks, and images.  
- Replaced plain text extraction with layout-aware token extraction.  
- Aligned parsed page text as input to the Qwen language model.

### **Version 5.2: Sentence-Level Highlighting with MiniLM**  
- Integrated sentence-transformer (all-MiniLM-L6-v2) for sentence embeddings.  
- Implemented semantic alignment between generated answers and page sentences using cosine similarity.  
- Highlighted matched tokens on retrieved PDF pages.
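The alignment step in Version 5.2 reduces to picking the page sentence whose embedding is closest to the answer's embedding. The sketch below shows only that selection logic on toy vectors. In the real pipeline the vectors would come from the sentence-transformer model, and the names `best_matching_sentence` and `min_score` (a similarity cut-off) are assumptions made for this example.

```python
import math

def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0  # Treat a zero vector as unrelated to everything
    return dot / (norm_a * norm_b)

def best_matching_sentence(answer_vec, sentence_vecs, min_score=0.5):
    """Return (index, score) of the closest sentence, or (None, score) if too weak."""
    scores = [cosine_similarity(answer_vec, vec) for vec in sentence_vecs]
    if not scores:
        return None, None
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    if scores[best_idx] < min_score:
        return None, scores[best_idx]
    return best_idx, scores[best_idx]

# Toy 2-D "embeddings" standing in for model output
answer = [1.0, 0.0]
sentences = [[0.0, 1.0], [1.0, 0.1], [-1.0, 0.0]]
match_idx, match_score = best_matching_sentence(answer, sentences)
print(match_idx, round(match_score, 3))
```

Returning `None` below the cut-off gives the caller a clean signal to skip highlighting rather than mark an unrelated sentence.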

### **Version 5.3: Answer Fallback and Robustness Improvements**  
- Added handling for fallback answers when Qwen reports "The answer is not in the provided context."  
- Introduced retrieval score filtering to avoid highlighting irrelevant pages.  
- Added exception handling for full pipeline errors with tracebacks.

### **Version 5.4: Gradio Interface Deployment**  
- Built a simple Gradio app with input textbox for questions.  
- Displayed generated answer, highlighted PDF page image, and document filename.  
- Managed outputs based on retrieval success or fallback states.

# **Phase 1 - Preparation and Exploration**

## **Import dataset**

In [None]:
# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import random
import json
from tqdm import tqdm
import os
import shutil

# Define the source and destination paths
source_path = 'PLACEHOLDER'
destination_path = '/content/dataset_initial'

# Define a function to copy the file with a progress bar
def copy_with_progress(src, dest, chunk_size=1024 * 1024):
    # Read the size inside the function so it does not depend on a global variable
    file_size = os.path.getsize(src)
    with open(src, 'rb') as fsrc, open(dest, 'wb') as fdest:
        with tqdm(total=file_size, unit='B', unit_scale=True, desc="Copying file") as pbar:
            while True:
                buffer = fsrc.read(chunk_size)  # Read in chunks of 1 MB by default
                if not buffer:
                    break
                fdest.write(buffer)
                pbar.update(len(buffer))

# Copy the file
copy_with_progress(source_path, destination_path)

print("File copied successfully!")

In [None]:
import tarfile
import os
from tqdm import tqdm

# Define the .tar file path and the extraction directory
tar_file_path = '/content/dataset_initial'
extraction_folder = '/content/dataset_untarred'

# Create the extraction folder if it doesn't exist
os.makedirs(extraction_folder, exist_ok=True)

# Extract the .tar file with a progress bar
with tarfile.open(tar_file_path, 'r') as tar:
    members = tar.getmembers()  # List of files in the tar archive
    total_files = len(members)  # Total number of files

    # Progress bar for extraction
    with tqdm(total=total_files, unit='file', desc="Extracting files") as pbar:
        for member in members:
            tar.extract(member, path=extraction_folder)
            pbar.update(1)

print(f"Files successfully extracted to {extraction_folder}")

In [None]:
import os
import zipfile
from tqdm import tqdm

# Define the paths
subfolders_path = 'PLACEHOLDER'  # Path containing zipped folders
dataset_folder = '/content/dataset_unzipped'  # Destination folder for unzipped content

# Create the destination folder if it doesn't exist
os.makedirs(dataset_folder, exist_ok=True)

# Iterate through all zipped files in the folder
zipped_files = [f for f in os.listdir(subfolders_path) if f.endswith('.zip')]

# Unzip each file into its own folder
with tqdm(total=len(zipped_files), desc="Unzipping folders", unit="file") as pbar:
    for zipped_file in zipped_files:
        zipped_file_path = os.path.join(subfolders_path, zipped_file)
        # Create a unique subdirectory for each .zip file
        unzip_subdir = os.path.join(dataset_folder, os.path.splitext(zipped_file)[0])
        os.makedirs(unzip_subdir, exist_ok=True)

        # Extract contents to the specific subdirectory
        with zipfile.ZipFile(zipped_file_path, 'r') as zip_ref:
            zip_ref.extractall(unzip_subdir)

        pbar.update(1)

print(f"All folders successfully unzipped into {dataset_folder}")

In [None]:
# Count files and directories for verification
total_dirs = sum([len(dirs) for _, dirs, _ in os.walk(dataset_folder)])
total_files = sum([len(files) for _, _, files in os.walk(dataset_folder)])
file_types = {}

for _, _, files in os.walk(dataset_folder):
    for file in files:
        ext = os.path.splitext(file)[1]
        file_types[ext] = file_types.get(ext, 0) + 1

print("Dataset Summary:")
print(f"Total Directories: {total_dirs}")
print(f"Total Files: {total_files}")
print(f"File Types: {file_types}")

## **Explore the dataset**

### **Folders Structure**

In [None]:
import os

# Path to the dataset
dataset_path = '/content/dataset_unzipped'

# List all files and directories in the dataset
def explore_dataset(path):
    for root, dirs, files in os.walk(path):
        print(f"Directory: {root}")
        print(f"Subdirectories: {dirs}")
        print(f"Number of files: {len(files)}")
        print(f"Files: {files[:5]}")  # Display first 5 files as a sample
        print("-" * 50)

# Explore the dataset
explore_dataset(dataset_path)

# Count the number of PDFs and XML files
pdf_count = 0
xml_count = 0

for root, dirs, files in os.walk(dataset_path):
    for file in files:
        if file.endswith('.pdf'):
            pdf_count += 1
        elif file.endswith('.xml'):
            xml_count += 1

print(f"Total PDFs: {pdf_count}")
print(f"Total XMLs: {xml_count}")

### **PDF Analysis**

In [None]:
import os
from PyPDF2 import PdfReader
import pandas as pd

# Paths
dataset_path = '/content/dataset_unzipped'

# Storage for analysis
pdf_stats = []

# Counter for progress tracking
total_pdfs = 0
processed_pdfs = 0

# Count total PDFs for progress tracking
for root, dirs, files in os.walk(dataset_path):
    total_pdfs += sum(1 for file in files if file.endswith('.pdf'))

print(f"Total PDFs to process: {total_pdfs}")

# Analyze PDFs
for root, dirs, files in os.walk(dataset_path):
    for file in files:
        if file.endswith('.pdf'):
            processed_pdfs += 1
            file_path = os.path.join(root, file)
            try:
                # Read PDF to get page count
                reader = PdfReader(file_path)
                num_pages = len(reader.pages)
                # Get file size
                file_size = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
                pdf_stats.append({'file': file, 'pages': num_pages, 'size_MB': file_size})
                print(f"[{processed_pdfs}/{total_pdfs}] Processed: {file}, Pages: {num_pages}, Size: {file_size:.2f} MB")
            except Exception as e:
                print(f"[{processed_pdfs}/{total_pdfs}] Error reading PDF {file}: {e}")
                pdf_stats.append({'file': file, 'pages': None, 'size_MB': None, 'error': str(e)})

# Convert to pandas DataFrame
pdf_df = pd.DataFrame(pdf_stats)

# Summarize statistics
pdf_summary = pdf_df[['pages', 'size_MB']].describe()
print("\nPDF Analysis Summary:")
print(pdf_summary)

# Analyze file names
print("\nSample PDF File Names:")
print(pdf_df['file'].head(10))  # Display first 10 file names

# Identify problematic PDFs
problematic_pdfs = pdf_df[pdf_df['pages'].isnull()]
if not problematic_pdfs.empty:
    print("\nProblematic PDFs:")
    print(problematic_pdfs)

#### **Outliers**

In [None]:
import pandas as pd

# Define IQR-based outlier detection
def find_outliers(df, column):
    q1 = df[column].quantile(0.25)
    q3 = df[column].quantile(0.75)
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    return df[(df[column] < lower_bound) | (df[column] > upper_bound)]

# Detect outliers in pages and file sizes
page_outliers = find_outliers(pdf_df, 'pages')
size_outliers = find_outliers(pdf_df, 'size_MB')

print(f"Outliers in page count:\n{page_outliers[['file', 'pages', 'size_MB']]}")
print(f"\nOutliers in file size:\n{size_outliers[['file', 'pages', 'size_MB']]}")
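The IQR filter above looks at one column at a time, so it cannot by itself surface the "large files with few pages" and "small files with high page counts" cases mentioned in the changelog. One possible way to find them is to apply the same fences to the size-per-page ratio. The sketch below works on a list of dicts shaped like `pdf_stats`; the function name `ratio_outliers` and the sample records are hypothetical.

```python
from statistics import quantiles

def ratio_outliers(stats, k=1.5):
    """Flag files whose MB-per-page ratio falls outside the IQR fences."""
    # Skip unreadable PDFs (pages is None) and zero-page files
    valid = [s for s in stats if s.get("pages") and s.get("size_MB") is not None]
    ratios = [s["size_MB"] / s["pages"] for s in valid]
    if len(ratios) < 2:
        return []  # Quartiles are undefined for fewer than two points
    # method="inclusive" interpolates linearly between data points
    q1, _, q3 = quantiles(ratios, n=4, method="inclusive")
    iqr = q3 - q1
    low, high = q1 - k * iqr, q3 + k * iqr
    return [s["file"] for s, r in zip(valid, ratios) if r < low or r > high]

# Made-up records for illustration only
sample_stats = [
    {"file": "a.pdf", "pages": 10, "size_MB": 1.0},
    {"file": "b.pdf", "pages": 10, "size_MB": 0.9},
    {"file": "c.pdf", "pages": 10, "size_MB": 1.1},
    {"file": "d.pdf", "pages": 10, "size_MB": 1.2},
    {"file": "e.pdf", "pages": 20, "size_MB": 2.0},
    {"file": "big.pdf", "pages": 2, "size_MB": 10.0},   # 5 MB per page
    {"file": "broken.pdf", "pages": None, "size_MB": None},
]
flagged = ratio_outliers(sample_stats)
print(flagged)
```

Here only `big.pdf` is flagged: its 5 MB per page sits far above the upper fence, while the other files cluster around 0.1 MB per page.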

In [None]:
!apt-get install -y poppler-utils
!pip install pdf2image

#### **See a sample**

In [None]:
from pdf2image import convert_from_path
from IPython.display import display, Image

# Locate a specific PDF by name
def locate_pdf(pdf_name, root_path):
    for root, dirs, files in os.walk(root_path):
        for file in files:
            if file == pdf_name:
                return os.path.join(root, file)
    return None

# Render the first page of a PDF
def render_pdf_preview(pdf_name, root_path):
    pdf_path = locate_pdf(pdf_name, root_path)
    if not pdf_path:
        print(f"PDF not found: {pdf_name}")
        return

    try:
        # Convert the first page of the PDF to an image
        pages = convert_from_path(pdf_path, dpi=150, first_page=1, last_page=1)
        img_path = f"{pdf_name}_preview.jpg"
        pages[0].save(img_path, 'JPEG')
        print(f"Rendering preview for {pdf_name}")
        display(Image(img_path))
    except Exception as e:
        print(f"Error rendering {pdf_name}: {e}")

# Render previews for outliers in page count
for pdf_name in page_outliers['file'].head(5):  # Limit to first 5 for demo
    render_pdf_preview(pdf_name, dataset_path)

# Render previews for outliers in file size
for pdf_name in size_outliers['file'].head(5):  # Limit to first 5 for demo
    render_pdf_preview(pdf_name, dataset_path)
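`locate_pdf` walks the whole directory tree on every call, which becomes slow once it is invoked for thousands of files. The changelog (Version 1.3) mentions pre-indexing file paths; a minimal sketch of that idea follows. The name `build_path_index` and the throwaway demo tree are assumptions made for this example, not the notebook's own code.

```python
import os
import tempfile

def build_path_index(root_path, extensions=(".pdf", ".xml")):
    """Map each file name to its full path using a single directory walk."""
    index = {}
    for root, _, files in os.walk(root_path):
        for name in files:
            if name.lower().endswith(extensions):
                # Keep the first hit, mirroring what locate_pdf returns
                index.setdefault(name, os.path.join(root, name))
    return index

# Demo on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp_dir:
    paper_dir = os.path.join(tmp_dir, "vol1", "paper1")
    os.makedirs(paper_dir)
    for name in ("paper1.pdf", "paper1.xml", "notes.txt"):
        open(os.path.join(paper_dir, name), "w").close()
    demo_index = build_path_index(tmp_dir)
    demo_pdf_path = demo_index.get("paper1.pdf")

print(sorted(demo_index))
```

With the index built once, each lookup is a dictionary access instead of a full `os.walk`. If two files share a name in different folders, only the first is kept, so a key based on the relative path would be safer for such datasets.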

#### **Organize images**

In [None]:
import os
import shutil

# Create a folder to store all images
output_folder = "/content/pdf_sample_images"
os.makedirs(output_folder, exist_ok=True)

# Move all images to the folder
for file in os.listdir():
    if file.endswith(".jpg"):  # Only move image files
        shutil.move(file, os.path.join(output_folder, file))

print(f"All images moved to {output_folder}")

#### **Summary and Distribution Analysis**

In [None]:
import matplotlib.pyplot as plt

# Page count distribution
pdf_df['pages'].plot(kind='hist', bins=30, figsize=(8, 5))
plt.title('PDF Page Count Distribution')
plt.xlabel('Number of Pages')
plt.ylabel('Count')
plt.show()

# File size distribution
pdf_df['size_MB'].plot(kind='hist', bins=30, figsize=(8, 5))
plt.title('PDF File Size Distribution')
plt.xlabel('File Size (MB)')
plt.ylabel('Count')
plt.show()

# Correlation between page count and file size
correlation = pdf_df[['pages', 'size_MB']].corr()
print("Correlation between page counts and file size:")
print(correlation)

pdf_df.plot.scatter(x='pages', y='size_MB', figsize=(8, 5))
plt.title('Page Counts vs. File Size')
plt.xlabel('Page Count')
plt.ylabel('File Size (MB)')
plt.show()

#### **Content Density Analysis**

In [None]:
import fitz  # PyMuPDF
import random
import pandas as pd
from tqdm import tqdm
import time

# Function to calculate text density
def calculate_text_density(pdf_name, root_path):
    pdf_path = locate_pdf(pdf_name, root_path)
    if not pdf_path:
        return {"file": pdf_name, "average_text_length": None, "error": "File not found"}

    try:
        # Use a context manager so the document handle is closed after reading
        with fitz.open(pdf_path) as doc:
            total_text_length = sum(len(page.get_text().strip()) for page in doc)
            average_text_length = total_text_length / len(doc)
        return {"file": pdf_name, "average_text_length": average_text_length, "error": None}
    except Exception as e:
        return {"file": pdf_name, "average_text_length": None, "error": str(e)}

# Sample PDFs for analysis
sample_size = min(50, len(pdf_df))  # Cap at the number of available PDFs so random.sample cannot fail
sample_pdfs = random.sample(pdf_df['file'].tolist(), sample_size)

# Track progress and execution time
start_time = time.time()

# Analyze PDFs with progress tracking
results = []
for pdf_name in tqdm(sample_pdfs, desc="Analyzing PDFs"):
    result = calculate_text_density(pdf_name, dataset_path)
    results.append(result)

# Calculate total execution time
end_time = time.time()
execution_time = end_time - start_time
print(f"Analysis completed in {execution_time:.2f} seconds.")

# Convert results to DataFrame
density_df = pd.DataFrame(results)

# Summary of density analysis
print("Text Density Analysis Summary (Sampled PDFs):")
print(density_df.describe())
print("\nErrors Encountered:")
print(density_df[density_df['error'].notnull()][['file', 'error']])

# Save results to CSV for later inspection
density_df.to_csv("/content/text_density_analysis.csv", index=False)

In [None]:
from concurrent.futures import ProcessPoolExecutor, as_completed
import time
from tqdm import tqdm
import pandas as pd

# Function to calculate text density for a single PDF
def calculate_text_density_parallel(args):
    pdf_name, root_path = args
    pdf_path = locate_pdf(pdf_name, root_path)
    if not pdf_path:
        return {"file": pdf_name, "average_text_length": None, "error": "File not found"}

    try:
        # Use a context manager so each worker closes its document handle
        with fitz.open(pdf_path) as doc:
            total_text_length = sum(len(page.get_text().strip()) for page in doc)
            average_text_length = total_text_length / len(doc)
        return {"file": pdf_name, "average_text_length": average_text_length, "error": None}
    except Exception as e:
        return {"file": pdf_name, "average_text_length": None, "error": str(e)}

# Prepare inputs for parallel processing
pdf_list = pdf_df['file'].tolist()
args = [(pdf_name, dataset_path) for pdf_name in pdf_list]

# Run parallel processing
start_time = time.time()
results = []

print("Starting full dataset analysis with parallel processing...")
with ProcessPoolExecutor() as executor:
    futures = {executor.submit(calculate_text_density_parallel, arg): arg[0] for arg in args}
    for future in tqdm(as_completed(futures), total=len(futures), desc="Analyzing PDFs"):
        results.append(future.result())

end_time = time.time()
execution_time = end_time - start_time

# Convert results to DataFrame
density_df_full = pd.DataFrame(results)

# Save results to CSV
output_csv = "/content/text_density_full_dataset.csv"
density_df_full.to_csv(output_csv, index=False)

print(f"Full dataset analysis completed in {execution_time:.2f} seconds.")
print(f"Results saved to: {output_csv}")


### **XML Analysis**

#### **Extract DOCTYPE and Unique Tags**

In [None]:
import os
from lxml import etree
from collections import Counter
from tqdm import tqdm  # For progress bar

def extract_doctype(file_path):
    """
    Extract the DOCTYPE declaration from an XML file.
    """
    try:
        with open(file_path, 'r') as file:
            for line in file:
                line = line.strip()
                if line.startswith("<!DOCTYPE"):
                    return line
        return "No DOCTYPE found"
    except Exception:
        return "Error reading file"

def find_all_xml_files(directory):
    """
    Efficiently find all XML files in a directory, including nested folders.
    """
    for root, _, files in os.walk(directory):
        for file in files:
            if file.endswith(".xml"):
                yield os.path.join(root, file)

def process_xml_files(directory):
    """
    Process all XML files, extract their DOCTYPE declarations, and summarize results.
    """
    doctype_counter = Counter()
    xml_files = list(find_all_xml_files(directory))  # Get all XML file paths

    # Display a progress bar while processing files
    for file_path in tqdm(xml_files, desc="Processing XML files", unit="file"):
        doctype = extract_doctype(file_path)
        doctype_counter[doctype] += 1

    return doctype_counter, len(xml_files)

def main():
    # Directory containing the XML files (searched recursively)
    directory = '/content/dataset_unzipped'

    print("Finding and processing XML files...")
    doctype_counts, total_files = process_xml_files(directory)

    # Print summary
    print("\nSummary of DOCTYPE declarations:")
    for doctype, count in doctype_counts.items():
        print(f"{doctype}: {count} file(s)")
    print(f"\nTotal XML files processed: {total_files}")

if __name__ == "__main__":
    main()


In [None]:
import os
import re
import xml.etree.ElementTree as ET

# Define the dataset path
dataset_path = '/content/dataset_unzipped'

# Function to extract DOCTYPE from an XML file
def extract_doctype(file_path):
    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            for line in file:
                if "<!DOCTYPE" in line:
                    doctype = re.search(r"<!DOCTYPE\s+([^\s>]+)", line)
                    return doctype.group(1) if doctype else "Unknown"
        return "No DOCTYPE found"
    except Exception as e:
        return f"Error reading DOCTYPE: {e}"

# Function to extract unique tags from an XML file
def extract_unique_tags(file_path):
    try:
        tree = ET.parse(file_path)
        root = tree.getroot()
        return set(elem.tag for elem in root.iter())
    except Exception as e:
        return {f"Error parsing {file_path}: {str(e)}"}

# Initialize data structures to hold results
doctypes = {}
unique_tags_set = set()
error_files = []

# Process all XML files
print("Processing XML files...\n")
for root_dir, dirs, files in os.walk(dataset_path):
    for file in files:
        if file.endswith('.xml'):
            file_path = os.path.join(root_dir, file)
            print(f"Processing file: {file_path}")  # Log file paths being processed

            # Extract DOCTYPE
            doctype = extract_doctype(file_path)
            if doctype not in doctypes:
                doctypes[doctype] = 0
            doctypes[doctype] += 1

            # Extract tags
            tags = extract_unique_tags(file_path)
            if any("Error parsing" in tag for tag in tags):
                error_files.append(file_path)
            else:
                unique_tags_set.update(tags)

# Sort and summarize results
unique_tags_set_sorted = sorted(unique_tags_set)

# Output DOCTYPE summary and unique tags
print("\nDOCTYPE Summary:")
for dt, count in doctypes.items():
    print(f"DOCTYPE: {dt}, Count: {count}")

print("\nUnique Tags Extracted:")
print(unique_tags_set_sorted)

print("\nSample Files with Parsing Errors:")
print(error_files[:5])

# Save results for further inspection
import pandas as pd

# Save DOCTYPE summary
doctype_df = pd.DataFrame(list(doctypes.items()), columns=['DOCTYPE', 'Count'])
doctype_df.to_csv('/content/xml_doctype_summary.csv', index=False)

# Save unique tags
tags_df = pd.DataFrame({'Tags': unique_tags_set_sorted})
tags_df.to_csv('/content/xml_unique_tags.csv', index=False)

# Save error files
error_files_df = pd.DataFrame({'File': error_files})
error_files_df.to_csv('/content/xml_error_files.csv', index=False)

print("\nResults saved:")
print("1. DOCTYPE Summary: /content/xml_doctype_summary.csv")
print("2. Unique Tags: /content/xml_unique_tags.csv")
print("3. Error Files: /content/xml_error_files.csv")

#### **Sample Article XML**

In [None]:
import os
from lxml import etree

def find_files_with_doctype(directory, target_doctype):
    """
    Finds XML files with the specified DOCTYPE.
    Args:
    - directory: Path to the directory containing XML files.
    - target_doctype: The DOCTYPE string to search for.

    Returns:
    - A list of file paths matching the target DOCTYPE.
    """
    matching_files = []

    # Find all XML files in the directory
    xml_files = [os.path.join(root, file)
                 for root, _, files in os.walk(directory) for file in files if file.endswith('.xml')]

    for xml_file in xml_files:
        try:
            with open(xml_file, 'r') as f:
                for line in f:
                    if line.strip().startswith("<!DOCTYPE") and target_doctype in line:
                        matching_files.append(xml_file)
                        break
        except Exception as e:
            print(f"Error processing {xml_file}: {e}")

    return matching_files

# Define the directory and target DOCTYPE
xml_directory = "/content/dataset_unzipped"  # Replace with your directory
target_doctype = '<!DOCTYPE article PUBLIC "-//NLM//DTD JATS (Z39.96) Journal Archiving and Interchange DTD with OASIS Tables with MathML3 v1.2d1 20170631//EN" "JATS-archive-oasis-article1-mathml3.dtd">'

# Find matching files
matching_files = find_files_with_doctype(xml_directory, target_doctype)

# Print the results
print(f"Found {len(matching_files)} files with the target DOCTYPE.")
if matching_files:
    print("Example file:", matching_files[0])  # Print one example file path

In [None]:
from lxml import etree

def extract_all_tags(xml_file):
    """
    Extract all unique tags from an XML file.
    Args:
    - xml_file: Path to the XML file.

    Returns:
    - A sorted list of unique tags.
    """
    try:
        # Parse the XML file
        tree = etree.parse(xml_file)
        root = tree.getroot()

        # Recursively collect all unique tags
        tags = set()
        def collect_tags(element):
            tags.add(element.tag)
            for child in element:
                collect_tags(child)

        collect_tags(root)

        return sorted(tags)

    except Exception as e:
        print(f"Error extracting tags: {e}")
        return None

# Example Usage
xml_file_path = "/content/dataset_unzipped/sigcomm-ccr_v41_i4_2043164.2018481_20231012155618/2043164/2043164.2018481/2043164.2018481.xml"  # Replace with your XML file path
tags = extract_all_tags(xml_file_path)

# Print the list of unique tags
if tags:
    print("Unique Tags in the XML:")
    for tag in tags:
        print(tag)

In [None]:
def print_tree(element, level=0):
    """
    Print the XML structure as a tree.
    Args:
    - element: An lxml element object.
    - level: Current depth in the tree (for indentation).
    """
    indent = "  " * level
    print(f"{indent}- {element.tag}")
    for child in element:
        print_tree(child, level + 1)

# Parse and print the tree
tree = etree.parse(xml_file_path)
root = tree.getroot()

print("XML Structure:")
print_tree(root)


#### **Sample Book XML**

In [None]:
import os
from lxml import etree

# Define the directory and target DOCTYPE
xml_directory = "/content/dataset_unzipped"  # Replace with your directory
target_doctype = '<!DOCTYPE book-part-wrapper PUBLIC "-//NLM//DTD BITS Book Interchange DTD with OASIS and XHTML Tables v2.0 20151225//EN" "BITS-book-oasis2.dtd">'

# Find matching files
matching_files = find_files_with_doctype(xml_directory, target_doctype)

# Print the results
print(f"Found {len(matching_files)} files with the target DOCTYPE.")
if matching_files:
    print("Example file:", matching_files[0])  # Print one example file path

In [None]:
# Example Usage
xml_file_path = "/content/dataset_unzipped/2464576.2482704_20231012112033/2464576/2464576.2482704/2464576.2482704.xml"  # Replace with your XML file path
tags = extract_all_tags(xml_file_path)

# Print the list of unique tags
if tags:
    print("Unique Tags in the XML:")
    for tag in tags:
        print(tag)

In [None]:
# Parse and print the tree
tree = etree.parse(xml_file_path)
root = tree.getroot()

print("XML Structure:")
print_tree(root)

#### **XML Metadata Parsing**

In [None]:
import os
import xml.etree.ElementTree as ET
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm

# Define the dataset path
dataset_path = '/content/dataset_unzipped'

# Function to parse a single XML file and extract metadata
def parse_xml_metadata(file_path):
    try:
        tree = ET.parse(file_path)
        root = tree.getroot()

        # Keep the stripped text of every non-empty element;
        # when a tag repeats, the last occurrence overwrites earlier ones
        metadata = {}
        for elem in root.iter():
            if elem.text and elem.text.strip():
                metadata[elem.tag] = elem.text.strip()
        return metadata
    except Exception as e:
        return {"error": str(e)}

# Iterate through XML files and collect metadata
metadata_list = []
for root_dir, dirs, files in tqdm(os.walk(dataset_path), desc="Processing XML files"):
    for file in files:
        if file.endswith('.xml'):
            file_path = os.path.join(root_dir, file)
            metadata = parse_xml_metadata(file_path)
            metadata['file'] = file
            metadata_list.append(metadata)

# Convert metadata to a DataFrame for analysis
metadata_df = pd.DataFrame(metadata_list)

# Save the extracted metadata for further inspection
metadata_csv = '/content/xml_metadata_analysis.csv'
metadata_df.to_csv(metadata_csv, index=False)

print(f"XML metadata analysis completed. Results saved to: {metadata_csv}")
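Because `parse_xml_metadata` stores one value per tag, repeated elements such as several `kwd` entries collapse to the last one. If all values are needed (for example, to study keyword distribution), one option is to collect them into lists. The sketch below uses an inline XML string invented for illustration, and `collect_tag_values` is a hypothetical helper name.

```python
import xml.etree.ElementTree as ET
from collections import defaultdict

def collect_tag_values(xml_text):
    """Return {tag: [text, ...]} with every non-empty occurrence of each tag."""
    root = ET.fromstring(xml_text)
    values = defaultdict(list)
    for elem in root.iter():
        if elem.text and elem.text.strip():
            values[elem.tag].append(elem.text.strip())
    return dict(values)

# Made-up fragment with a repeated <kwd> element
sample_xml = (
    "<article>"
    "<kwd-group><kwd>routing</kwd><kwd>congestion control</kwd></kwd-group>"
    "<title>Example</title>"
    "</article>"
)
tag_values = collect_tag_values(sample_xml)
print(tag_values)
```

List-valued fields do not fit a flat CSV as neatly, so they would need to be joined (for instance with `"; "`) or stored as JSON before saving.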

#### **Validation of XML Parsing**

In [None]:
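# Validate XML parsing
# The 'error' column only exists if at least one file failed to parse
if 'error' not in metadata_df.columns:
    metadata_df['error'] = None
error_summary = metadata_df['error'].value_counts(dropna=False)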

# Display the summary of parsing results
print("Parsing Validation:")
print(error_summary)

# Identify and list files with errors
error_files = metadata_df[metadata_df['error'].notnull()]
print(f"\nNumber of files with parsing errors: {len(error_files)}")
print("\nSample of files with errors:")
print(error_files[['file', 'error']].head())


#### **Analyze the Metadata CSV**

In [None]:
# Analyze metadata fields
print("\nMetadata Overview:")
print(metadata_df.info())
print("\nSample Metadata:")
print(metadata_df.head())

# Field occurrence analysis
field_counts = metadata_df.count().sort_values(ascending=False)
print("\nField Occurrence Count:")
print(field_counts)

# Plot the frequency of the most common metadata fields
field_counts.head(10).plot(kind='bar', figsize=(10, 6))
plt.title("Top 10 Most Common Metadata Fields")
plt.xlabel("Field Name")
plt.ylabel("Number of Non-Empty Entries")
plt.xticks(rotation=45)
plt.show()

# Analyze publication years using the 'year' tag
if 'year' in metadata_df.columns:
    # Convert the 'year' column to numeric
    metadata_df['year'] = pd.to_numeric(metadata_df['year'], errors='coerce')
    year_distribution = metadata_df['year'].dropna()

    if not year_distribution.empty:
        print("\nPublication Year Statistics:")
        print(year_distribution.describe())

        # Plot the distribution of publication years
        year_distribution.plot(kind='hist', bins=30, figsize=(10, 6), alpha=0.7)
        plt.title("Publication Year Distribution")
        plt.xlabel("Year")
        plt.ylabel("Frequency")
        plt.show()
    else:
        print("\nNo valid 'year' values were found in the metadata.")
else:
    print("\nThe 'year' column does not exist in the metadata DataFrame.")


In [None]:
# Analyze title lengths
if 'title' in metadata_df.columns:
    metadata_df['title_length'] = metadata_df['title'].dropna().apply(len)
    print("\nTitle Length Statistics:")
    print(metadata_df['title_length'].describe())

    # Plot title length distribution
    metadata_df['title_length'].plot(kind='hist', bins=30, figsize=(10, 6))
    plt.title("Title Length Distribution")
    plt.xlabel("Title Length")
    plt.ylabel("Frequency")
    plt.show()

# Keyword analysis
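if 'kwd' in metadata_df.columns:
    keywords = metadata_df['kwd'].dropna()
    print("\nSample Keywords:")
    # Cap the sample size so this does not fail when fewer than 10 keywords exist
    print(keywords.sample(min(10, len(keywords))))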

# Check for missing data
missing_data = metadata_df.isnull().sum().sort_values(ascending=False)
print("\nMissing Data Per Field:")
print(missing_data)

# Visualize missing data
missing_data.plot(kind='bar', figsize=(10, 6))
plt.title("Missing Data in Metadata Fields")
plt.xlabel("Field Name")
plt.ylabel("Number of Missing Entries")
plt.xticks(rotation=90)
plt.show()

# Field completeness analysis
field_completeness = metadata_df.notnull().mean().sort_values(ascending=False) * 100
print("\nField Completeness (% of non-missing values):")
print(field_completeness)

#### **Update CSV and Create New Dataset**

In [None]:
import os
import pandas as pd
import shutil
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

# Paths
dataset_path = '/content/dataset_unzipped'
metadata_csv = '/content/xml_metadata_analysis.csv'
final_dataset_csv = '/content/final_dataset_with_paths.csv'
output_dataset_path = '/content/final_dataset_unzipped'

# Load the metadata CSV
metadata_df = pd.read_csv(metadata_csv)

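# Exclude files with parsing errors (the column is absent when every file parsed cleanly)
if 'error' in metadata_df.columns:
    final_dataset = metadata_df[metadata_df['error'].isnull()].copy()
else:
    final_dataset = metadata_df.copy()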

# Pre-index all files in the dataset folder
file_index = {}
for root_dir, _, files in os.walk(dataset_path):
    for file in files:
        file_path = os.path.join(root_dir, file)
        file_index[file] = file_path

# Function to get file paths for PDFs and XMLs
def get_file_paths(row):
    base_name = os.path.splitext(row['file'])[0]  # Base name without the .xml extension
    pdf_path = file_index.get(f"{base_name}.pdf", None)
    xml_path = file_index.get(f"{base_name}.xml", None)
    return {"pdf_path": pdf_path, "xml_path": xml_path}

# Use parallel processing to locate file paths
with ThreadPoolExecutor() as executor:
    paths_list = list(tqdm(executor.map(get_file_paths, final_dataset.to_dict('records')),
                           desc="Locating file paths", total=len(final_dataset)))

paths_df = pd.DataFrame(paths_list)
final_dataset = pd.concat([final_dataset.reset_index(drop=True), paths_df], axis=1)

# Save the updated dataset with file paths
final_dataset.to_csv(final_dataset_csv, index=False)
print(f"Updated dataset with file paths saved to: {final_dataset_csv}")

In [None]:
# Ensure the output dataset folder exists
os.makedirs(output_dataset_path, exist_ok=True)

# Function to copy files to the new dataset folder
def copy_files(row):
    """Copy the row's XML and PDF into a subfolder named after their parent directory."""
    for key in ('xml_path', 'pdf_path'):
        src = row[key]
        # Skip missing paths (None or NaN) and files that no longer exist
        if isinstance(src, str) and os.path.exists(src):
            target_dir = os.path.join(output_dataset_path, os.path.basename(os.path.dirname(src)))
            os.makedirs(target_dir, exist_ok=True)
            shutil.copy(src, target_dir)

# Use parallel processing to copy files
with ThreadPoolExecutor() as executor:
    list(tqdm(executor.map(copy_files, final_dataset.to_dict('records')),
              desc="Copying files", total=len(final_dataset)))

print(f"Filtered dataset files copied to: {output_dataset_path}")

# **Phase 2 - Model Implementation and Evaluation**

## **ColPali & ColQwen**

### **Install libraries**

In [None]:
# Library for PDF to image conversion
!pip install pdf2image
!apt-get install -y poppler-utils

In [None]:
# Library for Qwen and queries generation
!pip install transformers qwen-vl-utils torch torchvision
!apt-get install -y libgl1-mesa-glx

In [None]:
# Supplementary library for Qwen
!pip install flash-attn --no-build-isolation
!pip install --upgrade triton

### **Dataset Preparation**

#### **Sample 1000 PDFs and convert to Images**

In [None]:
import os
import random
from tqdm import tqdm
from pdf2image import convert_from_path
import shutil

# Define Paths
base_dir = '/content/final_dataset_unzipped'
output_dir = '/content/dataset_1000_sample'  # Output directory for the sampled dataset
os.makedirs(output_dir, exist_ok=True)  # Ensures the output directory exists

In [None]:
# Locate All Paper Folders with PDFs
print("Locating paper folders with PDFs...")
paper_folders = []
for root, dirs, files in os.walk(base_dir):
    for file in files:
        if file.endswith('.pdf'):  # Look for PDF files in nested directories
            paper_folders.append(os.path.dirname(os.path.join(root, file)))

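# Deduplicate to avoid processing the same folder multiple times.
# Sort the result: set ordering can differ between runs, which would make the seeded sample below non-reproducible.
paper_folders = sorted(set(paper_folders))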
print(f"Total paper folders with PDFs found: {len(paper_folders)}")

In [None]:
# Randomly Select 1000 Papers
random.seed(41)
print("Randomly selecting 1000 papers...")
sampled_papers = random.sample(paper_folders, min(1000, len(paper_folders)))
print(f"Selected papers: {len(sampled_papers)}")

In [None]:
from multiprocessing import Pool
from pdf2image import convert_from_path
import shutil
import time
from tqdm import tqdm

def process_pdf_to_images(args):
    paper_folder, output_paper_dir = args
    os.makedirs(output_paper_dir, exist_ok=True)

    # Copy PDF and XML files
    files = [file for file in os.listdir(paper_folder) if file.endswith('.pdf') or file.endswith('.xml')]
    for file in files:
        shutil.copy(os.path.join(paper_folder, file), os.path.join(output_paper_dir, file))

    # Convert PDFs to images
    invalid_count = 0
    pdf_files = [file for file in files if file.endswith('.pdf')]
    for pdf_file in pdf_files:
        pdf_path = os.path.join(paper_folder, pdf_file)
        try:
            pages = convert_from_path(pdf_path, dpi=300)
            for i, page in enumerate(pages):
                image_name = f"{os.path.splitext(pdf_file)[0]}_page_{i + 1}.png"
                page.save(os.path.join(output_paper_dir, image_name), "PNG")
        except Exception as e:
            print(f"Failed to convert {pdf_path}: {e}")
            invalid_count += 1  # Increment invalid PDF count
    return invalid_count

def process_papers(sampled_papers, output_dir):
    args_list = [
        (paper_folder, os.path.join(output_dir, os.path.basename(paper_folder)))
        for paper_folder in sampled_papers
    ]

    total_start_time = time.perf_counter()  # Start tracking total time

    total_invalid = 0  # PDFs that could not be converted, summed over all workers
    with Pool() as pool:
        with tqdm(total=len(args_list), desc="Processing PDFs", unit="paper") as pbar:
            for invalid_count in pool.imap_unordered(process_pdf_to_images, args_list):
                total_invalid += invalid_count
                pbar.update(1)

    total_end_time = time.perf_counter()  # End tracking total time

    print(f"\nTotal time taken: {total_end_time - total_start_time:.2f} seconds")
    print(f"PDFs that failed to convert: {total_invalid}")
    print("Processing complete.")

In [None]:
process_papers(sampled_papers, output_dir)

In [None]:
import os

# Path to dataset
dataset_dir = "/content/dataset_1000_sample"

# Initialize counters and trackers
total_images = 0
total_folders = 0
empty_folders = []

for folder_name in os.listdir(dataset_dir):
    folder_path = os.path.join(dataset_dir, folder_name)
    if os.path.isdir(folder_path):
        total_folders += 1
        # Check for PNG images
        images = [file for file in os.listdir(folder_path) if file.endswith('.png')]
        if images:
            total_images += len(images)  # Update total images count
        else:
            empty_folders.append(folder_name)  # Track folders with no images

# Calculate statistics
non_empty_folders = total_folders - len(empty_folders)
average_images_per_folder = total_images / non_empty_folders if non_empty_folders > 0 else 0

# Print summary
print("\nDataset Summary:")
print(f"Total folders processed: {total_folders}")
print(f"Total images found: {total_images}")
print(f"Average images per folder (non-empty): {average_images_per_folder:.2f}")
print(f"Empty folders: {len(empty_folders)}")

if empty_folders:
    print("\nThe following folders are empty:")
    print(empty_folders)

In [None]:
from PIL import Image
import os

# Path to dataset
dataset_dir = "/content/dataset_1000_sample"

# Collect image sizes
image_sizes = []
for folder_name in os.listdir(dataset_dir):
    folder_path = os.path.join(dataset_dir, folder_name)
    if os.path.isdir(folder_path):
        images = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.png')]
        for image_path in images:
            with Image.open(image_path) as img:
                image_sizes.append(img.size)  # (width, height)

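# Analyze image sizes
# Sizes are (width, height) tuples; compare by pixel area, since plain max/min would compare width first
print(f"Total images checked: {len(image_sizes)}")
if image_sizes:
    print(f"Max size: {max(image_sizes, key=lambda s: s[0] * s[1])}")
    print(f"Min size: {min(image_sizes, key=lambda s: s[0] * s[1])}")
    print(f"Average size: {tuple(sum(dim) // len(image_sizes) for dim in zip(*image_sizes))}")
else:
    print("No images found.")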
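
The pages were rendered with `dpi=300`, so the expected pixel size of a page is its physical size in inches multiplied by 300. The sketch below works this out for two common page formats as a sanity check against the sizes reported above. The helper name `expected_pixels` is an assumption, and actual sizes depend on each PDF's page dimensions.

```python
def expected_pixels(width_in, height_in, dpi=300):
    """Pixel dimensions of a page of the given size (in inches) rendered at the given DPI."""
    return round(width_in * dpi), round(height_in * dpi)

letter = expected_pixels(8.5, 11)      # US Letter
a4 = expected_pixels(8.27, 11.69)      # A4, with inch values rounded to two decimals
print(f"US Letter at 300 DPI: {letter}")  # (2550, 3300)
print(f"A4 at 300 DPI: {a4}")             # (2481, 3507)
```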

#### **Check a sample**

In [None]:
import os
import random
from PIL import Image
from IPython.display import display

# Path to dataset
dataset_dir = "/content/dataset_1000_sample"

# Collect all image paths
image_paths = []
for folder_name in os.listdir(dataset_dir):
    folder_path = os.path.join(dataset_dir, folder_name)
    if os.path.isdir(folder_path):
        images = [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.png')]
        image_paths.extend(images)

# Check if there are images in the dataset
if image_paths:
    # Select a random image to display
    sample_image_path = random.choice(image_paths)
    print(f"Displaying sample image: {sample_image_path}")

    # Open and display the image inline
    sample_image = Image.open(sample_image_path)
    display(sample_image)
else:
    print("No images found in the dataset.")

### **Generate Images - Queries Pairs**

#### **Create Prompt**

In [None]:
prompt = """
You are an assistant specialized in Multimodal RAG tasks.

The task is the following: given an image from a pdf page, you will have to generate questions that can be asked by a user to retrieve information from a large documentary corpus.

The question should be relevant to the page, and should not be too specific or too general. The question should be about the subject of the page, and the answer needs to be found in the page.

Remember that the question is asked by a user to get some information from a large documentary corpus that contains multimodal data. Generate a question that could be asked by a user without knowing the existence and the content of the corpus.

Generate at most THREE pairs of questions and answers per page in a dictionary with the following format, answer ONLY this dictionary NOTHING ELSE:

{
    "questions": [
        { "question": "XXXXXX", "answer": ["YYYYYY"] },
        { "question": "XXXXXX", "answer": ["YYYYYY"] },
        { "question": "XXXXXX", "answer": ["YYYYYY"] }
    ]
}

where XXXXXX is the question and ['YYYYYY'] is the corresponding list of answers that could be as long as needed.

Note: If there are no questions to ask about the page, return an empty list. Focus on making relevant questions concerning the page.

Here is the page:
"""

#### **Define function for memory checks and intermediate results**

In [None]:
from torch.utils.data import DataLoader
from torch import cuda
import torch
import json
import os
import time
from tqdm import tqdm


def generate_dataset_optimized(image_paths, output_file, batch_size, max_new_tokens):
    """
    Generates a dataset of image-query pairs with GPU memory monitoring and crash recovery.

    Args:
        image_paths (list): List of paths to images.
        output_file (str): Path to save the generated dataset JSON file.
        batch_size (int): Number of images to process per batch.
        max_new_tokens (int): Maximum number of tokens to generate per image.
    """
    # Load existing results if resuming
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            results = json.load(f)
    else:
        results = []

    processed_images = {result["image"] for result in results}  # Track already processed images
    remaining_images = [img for img in image_paths if img not in processed_images]

    # Metrics counters
    total_image_query_pairs = sum(len(r["questions"]) for r in results)
    successful_image_queries = len([r for r in results if r["questions"]])
    invalid_count = 0

    dataloader = DataLoader(remaining_images, batch_size=batch_size, shuffle=False)
    device = 'cuda' if cuda.is_available() else 'cpu'

    # Start tracking time
    total_start_time = time.perf_counter()

    print("Generating queries and answers...")
    with tqdm(total=len(remaining_images), desc="Processing images", unit="image") as pbar:
        for batch_images in dataloader:
            batch_start_time = time.perf_counter()  # Start batch timing

            # GPU memory usage before processing the batch
            print(f"Before processing batch: Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB | Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

            try:
                # Prepare inputs
                messages = [
                    {"role": "user", "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt}
                    ]}
                    for image in batch_images
                ]
                texts = [
                    processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
                    for message in messages
                ]
                image_inputs, video_inputs = process_vision_info(messages)
                inputs = processor(
                    text=texts,
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt"
                ).to(device)

                # Generate outputs
                with torch.no_grad():
                    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_texts = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )

                # Parse results
                for img_path, output_text in zip(batch_images, output_texts):
                    try:
                        output_json = json.loads(output_text)
                        questions = output_json.get("questions", [])
                        total_image_query_pairs += len(questions)
                        if questions:
                            successful_image_queries += 1
                        results.append({
                            "image": img_path,
                            "questions": questions
                        })
                    except json.JSONDecodeError:
                        invalid_count += 1

                # Save intermediate results
                with open(output_file, "w") as f:
                    json.dump(results, f, indent=4)

            except Exception as e:
                print(f"Error during generation for batch: {e}")
                torch.cuda.empty_cache()  # Free GPU memory on error

            # GPU memory usage after processing the batch
            print(f"After processing batch: Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB | Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

            # Track batch time
            batch_end_time = time.perf_counter()
            pbar.set_postfix(batch_time=f"{batch_end_time - batch_start_time:.2f}s")
            pbar.update(len(batch_images))  # Update progress bar

    # End tracking time
    total_end_time = time.perf_counter()

    # Print summary
    print("\nDataset creation summary:")
    print(f"Total images processed: {len(image_paths)}")
    print(f"Total image-query pairs created: {total_image_query_pairs}")
    print(f"Successfully processed images with queries: {successful_image_queries}")
    print(f"Invalid responses: {invalid_count}")
    print(f"Total time taken: {total_end_time - total_start_time:.2f} seconds")
    print(f"Dataset saved to {output_file}.")
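
The parsing step in `generate_dataset_optimized` passes the raw model response straight to `json.loads`, so a response wrapped in a Markdown code fence or padded with extra text is counted as invalid. The sketch below shows one more tolerant approach: it pulls out the outermost `{...}` span before decoding. The helper name `parse_questions` is hypothetical, and this is an illustration rather than the parsing used for the reported results.

```python
import json
import re

def parse_questions(output_text):
    """Extract the "questions" list from a model response, or None if no valid JSON is found."""
    # Take the span from the first "{" to the last "}" to skip code fences or surrounding prose
    match = re.search(r"\{.*\}", output_text, flags=re.DOTALL)
    if match is None:
        return None
    try:
        payload = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    questions = payload.get("questions", [])
    return questions if isinstance(questions, list) else None

# Hypothetical responses: one wrapped in a code fence, one without any JSON
fenced = '```json\n{"questions": [{"question": "Q?", "answer": ["A"]}]}\n```'
print(parse_questions(fenced))
print(parse_questions("no json here"))
```

Returning `None` for unparseable responses keeps them distinguishable from a valid response with an empty `questions` list.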

#### **Generate dataset**

##### **Take a sample of 10 images**

In [None]:
# Collect image paths
def collect_image_paths(dataset_dir, extension=".png"):
    print("Collecting image paths...")
    image_paths = [
        os.path.join(root, file)
        for root, _, files in os.walk(dataset_dir)
        for file in files
        if file.endswith(extension)
    ]
    print(f"Total images collected: {len(image_paths)}")
    return image_paths

In [None]:
# Collect image paths
dataset_dir = '/content/dataset_1000_sample'
image_paths = collect_image_paths(dataset_dir)

In [None]:
sample_images = image_paths[:10]  # Select the first 10 images

##### **Generate sample dataset with bf16 and Default Size**

###### **Instantiate Model with bfloat16**

In [None]:
import os
import json
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

# Load the Model and Processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

###### **Generate sample dataset**

In [None]:
generate_dataset_optimized(
    image_paths=sample_images,
    output_file="experiment_without_fp16.json",
    batch_size=4,
    max_new_tokens=200)

##### **Generate sample dataset with Mixed Precision and Resizing**

###### **Instantiate Model with float16 and mixed precision**

In [None]:
import os
import json
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

# Load the Model and Processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

###### **Resize Images**

In [None]:
from PIL import Image
resized_images = []
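for img_path in sample_images:
    # Resized copies are written next to the originals, so later scans of the
    # dataset folder for '.png' files will also pick them up
    resized_path = img_path.replace(".png", "_resized.png")
    with Image.open(img_path) as img:  # Close the file handle after resizing
        img.resize((1024, 1024)).save(resized_path)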
    resized_images.append(resized_path)
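
Resizing every page to `(1024, 1024)` forces a square output, which stretches portrait pages horizontally. If preserving the aspect ratio is preferred, the target size can be computed first and then passed to `Image.resize`. The sketch below shows only that arithmetic; the helper name `fit_within` is an assumption, and this is an alternative to the fixed-size resize used in this notebook, not the method behind its results.

```python
def fit_within(size, max_side=1024):
    """Scale (width, height) so the longer side equals max_side, keeping the aspect ratio."""
    width, height = size
    scale = max_side / max(width, height)
    return max(1, round(width * scale)), max(1, round(height * scale))

# A US Letter page rendered at 300 DPI is 2550 x 3300 pixels
print(fit_within((2550, 3300)))  # (791, 1024)
```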

###### **Generate sample dataset**

In [None]:
generate_dataset_optimized(
    image_paths=resized_images,
    output_file="experiment_with_fp16.json",
    batch_size=4,
    max_new_tokens=200)

##### **Generate full dataset with Mixed Precision and Resizing**

###### **Preprocess Images (Resize to 1024x1024)**

In [None]:
from PIL import Image
import os
import time
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor

def resize_image(task):
    """
    Resizes a single image.

    Args:
        task (tuple): A tuple containing the input path, output path, and target size.
    """
    input_path, output_path, target_size = task
    try:
        with Image.open(input_path) as img:
            img.resize(target_size).save(output_path)
    except Exception as e:
        return f"Error resizing image {input_path}: {e}"

def resize_images(input_dir, output_dir, target_size=(1024, 1024), max_workers=8):
    """
    Resizes all images in the input directory to the specified target size and saves them to the output directory.
    Includes parallel processing and progress tracking.

    Args:
        input_dir (str): Path to the input directory containing images.
        output_dir (str): Path to the output directory to save resized images.
        target_size (tuple): Target size for resizing (width, height).
        max_workers (int): Number of threads to use for parallel processing.
    """
    start_time = time.perf_counter()  # Start timing

    os.makedirs(output_dir, exist_ok=True)
    tasks = []

    # Collect all image paths
    for folder_name in os.listdir(input_dir):
        folder_path = os.path.join(input_dir, folder_name)
        if os.path.isdir(folder_path):
            output_folder_path = os.path.join(output_dir, folder_name)
            os.makedirs(output_folder_path, exist_ok=True)
            for file_name in os.listdir(folder_path):
                if file_name.endswith('.png'):
                    input_path = os.path.join(folder_path, file_name)
                    output_path = os.path.join(output_folder_path, file_name)
                    tasks.append((input_path, output_path, target_size))

    print(f"Total images to process: {len(tasks)}")

    # Parallel processing with progress tracking
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        with tqdm(total=len(tasks), desc="Resizing images", unit="image") as pbar:
            for result in executor.map(resize_image, tasks):
                if result:  # Log any errors
                    print(result)
                pbar.update(1)

    end_time = time.perf_counter()  # End timing
    print(f"\nAll images resized and saved to {output_dir}")
    print(f"Total time taken: {end_time - start_time:.2f} seconds")

In [None]:
# Resize all images
input_dir = "/content/dataset_1000_sample"
output_dir = "/content/dataset_1000_sample_resized"
resize_images(input_dir, output_dir, target_size=(1024, 1024), max_workers=8)

###### **Collect All Image Paths**

In [None]:
import os

# Collect all image paths
def collect_image_paths(image_dir):
    image_paths = []
    for folder_name in os.listdir(image_dir):
        folder_path = os.path.join(image_dir, folder_name)
        if os.path.isdir(folder_path):
            image_paths.extend(
                [os.path.join(folder_path, file) for file in os.listdir(folder_path) if file.endswith('.png')]
            )
    return image_paths

resized_image_dir = "/content/dataset_1000_sample_resized"
image_paths = collect_image_paths(resized_image_dir)
print(f"Total images collected: {len(image_paths)}")

###### **Load Model and Processor**

In [None]:
import os
import json
from tqdm import tqdm
from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
from qwen_vl_utils import process_vision_info
import torch

# Load the Model and Processor
model = Qwen2VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2-VL-7B-Instruct",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
processor = AutoProcessor.from_pretrained("Qwen/Qwen2-VL-7B-Instruct")

###### **Define Function for Dataset Processing**

In [None]:
from torch.utils.data import DataLoader
from torch import cuda
import torch
import json
import os
import time
from tqdm import tqdm

def process_dataset(image_paths, model, processor, output_file, batch_size, max_new_tokens):
    """
    Processes the dataset of images with mixed precision, includes GPU memory checks, and saves intermediate results.

    Args:
        image_paths (list): List of image paths to process.
        model: The loaded Qwen2VL model.
        processor: The loaded processor.
        output_file (str): Path to save the results.
        batch_size (int): Number of images per batch.
        max_new_tokens (int): Maximum number of tokens to generate.
    """
    # Load existing results if any
    results = []
    if os.path.exists(output_file):
        with open(output_file, "r") as f:
            results = json.load(f)

    processed_images = {result["image"] for result in results}
    remaining_images = [img for img in image_paths if img not in processed_images]

    # Metrics counters
    total_image_query_pairs = sum(len(r["questions"]) for r in results)
    successful_image_queries = len([r for r in results if r["questions"]])
    invalid_count = 0

    dataloader = DataLoader(remaining_images, batch_size=batch_size, shuffle=False)
    device = 'cuda' if cuda.is_available() else 'cpu'

    # Start tracking time
    total_start_time = time.perf_counter()
    print(f"Processing {len(remaining_images)} images...")

    with tqdm(total=len(remaining_images), desc="Processing images", unit="image") as pbar:
        for batch_images in dataloader:
            # Memory usage before batch
            print(f"Before processing batch: Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB | Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

            try:
                # Prepare inputs
                messages = [
                    {"role": "user", "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": prompt}
                    ]}
                    for image in batch_images
                ]
                texts = [
                    processor.apply_chat_template([message], tokenize=False, add_generation_prompt=True)
                    for message in messages
                ]
                image_inputs, video_inputs = process_vision_info(messages)
                inputs = processor(
                    text=texts,
                    images=image_inputs,
                    videos=video_inputs,
                    padding=True,
                    return_tensors="pt"
                ).to(device)

                # Generate outputs
                with torch.no_grad():
                    generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
                generated_ids_trimmed = [
                    out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
                ]
                output_texts = processor.batch_decode(
                    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
                )

                # Parse results
                for img_path, output_text in zip(batch_images, output_texts):
                    try:
                        output_json = json.loads(output_text)
                        total_image_query_pairs += len(output_json.get("questions", []))
                        if output_json.get("questions", []):
                            successful_image_queries += 1
                        results.append({
                            "image": img_path,
                            "questions": output_json.get("questions", [])
                        })
                    except json.JSONDecodeError:
                        invalid_count += 1
                        print(f"Invalid JSON output for image {img_path}")

                # Save intermediate results
                with open(output_file, "w") as f:
                    json.dump(results, f, indent=4)

            except Exception as e:
                print(f"Error processing batch: {e}")
                torch.cuda.empty_cache()  # Clear GPU memory on error

            # Memory usage after batch
            print(f"After processing batch: Allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB | Reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")

            pbar.update(len(batch_images))

    total_end_time = time.perf_counter()

    # Print summary
    print("\nDataset creation summary:")
    print(f"Total images processed: {len(image_paths)}")
    print(f"Total image-query pairs created: {total_image_query_pairs}")
    print(f"Successfully processed images with queries: {successful_image_queries}")
    print(f"Invalid responses: {invalid_count}")
    print(f"Total time taken: {total_end_time - total_start_time:.2f} seconds")
    print(f"Dataset saved to {output_file}")

###### **Run the Dataset Processing**

In [None]:
batch_size = 6
max_new_tokens = 200
output_file = f"experiment_bf16_batch_size_{batch_size}.json"

process_dataset(image_paths, model, processor, output_file, batch_size, max_new_tokens)
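
Both generation functions rewrite the whole results file after every batch with `open(output_file, "w")`. If the session stops in the middle of a write, the file can be left truncated and `json.load` would then fail on resume. One common safeguard is to write to a temporary file and swap it into place. The sketch below illustrates the idea; the helper name `save_results_atomically` and the demo data are assumptions, not part of the original code.

```python
import json
import os
import tempfile

def save_results_atomically(results, output_file):
    """Write results to a temporary file in the same directory, then swap it into place."""
    directory = os.path.dirname(os.path.abspath(output_file))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(results, f, indent=4)
        # os.replace swaps the file in one step, so readers see either the old or the new version
        os.replace(tmp_path, output_file)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Demonstration in a throwaway directory with made-up data
with tempfile.TemporaryDirectory() as demo_dir:
    demo_file = os.path.join(demo_dir, "results.json")
    save_results_atomically([{"image": "page_1.png", "questions": []}], demo_file)
    with open(demo_file) as f:
        reloaded = json.load(f)
print(reloaded)
```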

### **Restructure Dataset and push to Hub**

#### **Convert to Parquet**

In [None]:
import json
import os
import shutil
from pathlib import Path

# Paths
colpali_json = "/content/experiment_bf16_batch_size_6.json"  # Original dataset JSON
output_dir = "/content/structured_dataset"  # Directory to store images and JSON file
os.makedirs(output_dir, exist_ok=True)

# Load the dataset
with open(colpali_json, "r") as f:
    data = json.load(f)

# Restructure dataset
structured_data = []
images_dir = os.path.join(output_dir, "images")
os.makedirs(images_dir, exist_ok=True)

# Initialize query_id counter
query_id_counter = 1

for entry in data:
    image_path = entry["image"]
    image_filename = os.path.basename(image_path)

    # Extract the source and page from the filename
    try:
        # Example format: 2464576.2480793_page_1.png
        image_source = image_filename.split("_page_")[0]  # Extract source (e.g., 2464576.2480793)
        image_page = image_filename.split("_page_")[1].split(".")[0]  # Extract page number (e.g., 1)
    except IndexError:
        image_source = "unknown"
        image_page = "unknown"

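    # Iterate over all questions in the entry to create one row per image-query-answer combination
    for question_entry in entry.get("questions", []):
        # Model output is not validated upstream, so skip malformed entries
        if not isinstance(question_entry, dict) or "question" not in question_entry:
            continue
        answers = question_entry.get("answer", [])
        if isinstance(answers, str):
            answers = [answers]  # A bare string would otherwise be joined character by character
        structured_data.append({
            "query_id": query_id_counter,  # Assign a unique numerical query ID
            "query": question_entry["question"],
            "answer": ", ".join(str(a) for a in answers),  # Concatenate answers as a single string
            "image": os.path.join("images", image_filename),  # Relative path to image
            "image_filename": image_filename,
            "image_page": image_page,
            "image_source": image_source,
            "model": "Qwen2-VL-7B-Instruct",
        })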

        # Increment the query_id_counter
        query_id_counter += 1

    # Copy the image to the new images directory
    if os.path.exists(image_path):
        shutil.copy(image_path, os.path.join(images_dir, image_filename))
    else:
        print(f"Image not found: {image_path}")

# Save the structured data as a JSON file
output_json_file = os.path.join(output_dir, "structured_dataset.json")
with open(output_json_file, "w") as f:
    json.dump(structured_data, f, indent=4)

print(f"Structured dataset saved to {output_json_file}")
print(f"Images directory: {images_dir}")
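
The filename parsing in the cell above relies on chained `split` calls. As a minimal sketch of the same idea, the pattern can be written as one regular expression. `parse_image_filename` is a hypothetical helper (it is not part of the notebook), and it assumes names of the form `<source>_page_<number>.<extension>` with a numeric page, which is slightly stricter than the original `split` logic.

```python
import re
from typing import Tuple

# Assumed naming scheme: "<source>_page_<number>.<extension>",
# e.g. "2464576.2480793_page_1.png".
_PAGE_PATTERN = re.compile(r"^(?P<source>.+)_page_(?P<page>\d+)\.[^.]+$")


def parse_image_filename(image_filename: str) -> Tuple[str, str]:
    """Return (source, page), or ("unknown", "unknown") if the name does not match."""
    match = _PAGE_PATTERN.match(image_filename)
    if match is None:
        return "unknown", "unknown"
    return match.group("source"), match.group("page")


print(parse_image_filename("2464576.2480793_page_1.png"))
print(parse_image_filename("cover.png"))
```

Keeping the page as a string matches the `Value("string")` schema used for `image_page` later in the notebook.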

In [None]:
import json
import pandas as pd
import os

# Paths
json_file = "/content/structured_dataset/structured_dataset.json"  # The JSON file with query and image data
parquet_path = "/content/structured_dataset/test.parquet"  # Output Parquet file

# Load JSON data
with open(json_file, "r") as f:
    data = json.load(f)

# Convert JSON to DataFrame
df = pd.DataFrame(data)

# Save the DataFrame as a Parquet file
df.to_parquet(parquet_path, index=False, engine="pyarrow")

print(f"Parquet file saved at {parquet_path}")

#### **Push to HuggingFace**

In [None]:
!pip install datasets huggingface_hub

In [None]:
from huggingface_hub import login

login()

In [None]:
from datasets import Dataset, Features, Image, Value
import pandas as pd
import os

# Load the Parquet file into a Pandas DataFrame
df = pd.read_parquet(parquet_path)

# Update the `image` column to store full paths to the images
images_dir = "/content/structured_dataset/images"
df["image"] = df["image"].apply(lambda x: os.path.join(images_dir, os.path.basename(x)))

# Define the dataset schema
features = Features({
    "query_id": Value("int32"),
    "query": Value("string"),
    "answer": Value("string"),
    "image": Image(),  # Define the `image` column as an Image feature
    "image_filename": Value("string"),
    "image_page": Value("string"),
    "image_source": Value("string"),
    "model": Value("string")
})

# Convert the DataFrame into a Hugging Face Dataset
dataset = Dataset.from_pandas(df, features=features)

In [None]:
from datasets import DatasetDict

# Wrap the dataset in a DatasetDict and label it as test
dataset_dict = DatasetDict({"test": dataset})

# Push the dataset to Hugging Face
dataset_dict.push_to_hub("PLACEHOLDER", private=True)

In [None]:
from huggingface_hub import upload_file
import os

# Define repository details
repo_id = "PLACEHOLDER"  # Replace with your username and dataset name

# Upload all images to the repository under the `images/` directory
images_dir = "/content/structured_dataset/images"
for image_file in os.listdir(images_dir):
    image_path = os.path.join(images_dir, image_file)
    upload_file(
        path_or_fileobj=image_path,
        path_in_repo=f"images/{image_file}",
        repo_id=repo_id,
        repo_type="dataset"
    )

### **ViDoRe Benchmark**

#### **Environment setup**

##### **Clone Repo and Install package in Editable Mode**

In [None]:
!git clone https://github.com/illuin-tech/vidore-benchmark.git
%cd vidore-benchmark

In [None]:
!pip install -e .

##### **Install Retrievers and Engine**

In [None]:
!pip install colpali-engine

In [None]:
!pip install "vidore-benchmark[all-retrievers]"

In [None]:
!pip install "vidore-benchmark[colpali-engine]"

##### **Install supplementary libraries**

In [None]:
# Supplementary library for Qwen
!pip install flash-attn --no-build-isolation
!pip install --upgrade triton

In [None]:
# Upgrade transformers to support flash attention
!pip install --upgrade transformers

In [None]:
!pip install huggingface_hub

In [None]:
from huggingface_hub import login

login()

##### **Monitoring with TensorBoard**

In [None]:
import psutil
import GPUtil
import threading
import time
import pandas as pd
import os
import datetime
import tensorflow as tf

# Global variables
stop_monitoring = False
monitoring_data = []  # List to store logged data in memory
log_file_path = "resource_usage_log.csv"  # File to log data

# Initialize TensorBoard SummaryWriter
log_dir = "logs/system_metrics/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
writer = tf.summary.create_file_writer(log_dir)

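def log_to_file(data):
    """Appends a row of data to the log file, writing a header row first if the file is new."""
    # Without a header, pd.read_csv() in read_log_file() would treat the first sample as column names
    write_header = not os.path.exists(log_file_path) or os.path.getsize(log_file_path) == 0
    with open(log_file_path, "a") as f:
        if write_header:
            f.write("timestamp,cpu_usage,ram_used_gb,ram_total_gb,ram_percent,gpu_info\n")
        f.write(
            f"{data['timestamp']},{data['cpu_usage']},{data['ram_used_gb']},{data['ram_total_gb']}," +
            f"{data['ram_percent']},\"{data['gpu_info']}\"\n"
        )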

def log_to_tensorboard(data, step):
    """Logs system metrics to TensorBoard."""
    with writer.as_default():
        tf.summary.scalar("CPU Usage (%)", data["cpu_usage"], step=step)
        tf.summary.scalar("RAM Usage (%)", data["ram_percent"], step=step)
        tf.summary.scalar("RAM Used (GB)", data["ram_used_gb"], step=step)
        tf.summary.scalar("RAM Total (GB)", data["ram_total_gb"], step=step)

        if data["gpu_info"]:
            for idx, gpu in enumerate(data["gpu_info"]):
                try:
                    # Parse GPU info
                    gpu_name, load_part, memory_part = gpu.split(", ")
                    gpu_load = float(load_part.split(": ")[1].strip().replace("%", ""))
                    gpu_memory_used = float(memory_part.split(": ")[1].split(" ")[0])

                    # Log GPU info as separate scalars
                    tf.summary.scalar(f"GPU {idx} Load (%)", gpu_load, step=step)
                    tf.summary.scalar(f"GPU {idx} Memory Used (MB)", gpu_memory_used, step=step)
                except Exception as e:
                    print(f"Error parsing GPU info for TensorBoard: {e}")
        writer.flush()

def monitor_system():
    """Logs resource usage metrics to TensorBoard and a log file."""
    global stop_monitoring, monitoring_data
    try:
        step = 0  # TensorBoard step
        while not stop_monitoring:
            # Gather CPU usage
            cpu_usage = psutil.cpu_percent(interval=1)

            # Gather RAM usage
            memory_info = psutil.virtual_memory()
            ram_total = memory_info.total / (1024 ** 3)  # Convert to GB
            ram_used = memory_info.used / (1024 ** 3)   # Convert to GB
            ram_percent = memory_info.percent

            # Gather GPU usage
            gpus = GPUtil.getGPUs()
            gpu_info = [
                f"{gpu.name}, Load: {gpu.load * 100:.2f}%, Memory Used: {gpu.memoryUsed:.2f} MB / {gpu.memoryTotal:.2f} MB"
                for gpu in gpus
            ]

            # Prepare data entry with real-time timestamp
            data = {
                # time.strftime uses local time; the `datetime` module imported above has no `now()` attribute
                "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
                "cpu_usage": round(cpu_usage, 2),
                "ram_used_gb": round(ram_used, 2),
                "ram_total_gb": round(ram_total, 2),
                "ram_percent": round(ram_percent, 2),
                "gpu_info": gpu_info,
            }

            # Log data to memory, file, and TensorBoard
            monitoring_data.append(data)
            log_to_file(data)
            log_to_tensorboard(data, step)
            step += 1  # Increment TensorBoard step
    except Exception as e:
        print(f"Monitoring stopped due to error: {e}")

def start_monitoring():
    """Starts the resource monitoring in a background thread."""
    global stop_monitoring
    stop_monitoring = False
    monitor_thread = threading.Thread(target=monitor_system, daemon=True)
    monitor_thread.start()
    print(f"Monitoring started. View metrics in TensorBoard (log directory: {log_dir}).")

def stop_monitoring_system():
    """Stops the resource monitoring."""
    global stop_monitoring
    stop_monitoring = True
    print("Monitoring stopped. Use the recorded data for analysis.")

def get_monitoring_data():
    """Converts the logged data into a pandas DataFrame."""
    df = pd.DataFrame(monitoring_data)
    if "gpu_info" in df:
        df["gpu_info"] = df["gpu_info"].apply(str)  # Convert GPU info to string for display
    return df

def read_log_file():
    """Reads the log file into a pandas DataFrame."""
    if os.path.exists(log_file_path):
        return pd.read_csv(log_file_path)
    else:
        print(f"No log file found at {log_file_path}.")
        return pd.DataFrame()

print("Run `start_monitoring()` to begin monitoring.")
print("Run `stop_monitoring_system()` to stop monitoring.")
print("Run `get_monitoring_data()` to view the logged data in memory.")
print("Run `read_log_file()` to view the logged data from the file.")
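
The monitoring code stores each GPU as a formatted string, so any later analysis has to parse that string back into numbers. A minimal sketch of such a parser follows. `parse_gpu_entry` is a hypothetical helper, the pattern assumes the exact format produced in `monitor_system()`, and the sample GPU name and values are made up for illustration.

```python
import re
from typing import Dict, Optional

# Assumed format, as produced in monitor_system():
# "<name>, Load: <x>%, Memory Used: <y> MB / <z> MB"
_GPU_PATTERN = re.compile(
    r"^(?P<name>.+), Load: (?P<load>[\d.]+)%, "
    r"Memory Used: (?P<used>[\d.]+) MB / (?P<total>[\d.]+) MB$"
)


def parse_gpu_entry(entry: str) -> Optional[Dict[str, object]]:
    """Parse one GPU string into a dict, or return None if it does not match."""
    match = _GPU_PATTERN.match(entry)
    if match is None:
        return None
    return {
        "name": match.group("name"),
        "load_percent": float(match.group("load")),
        "memory_used_mb": float(match.group("used")),
        "memory_total_mb": float(match.group("total")),
    }


# Illustrative sample only; not a measurement from this notebook.
sample = "Tesla T4, Load: 37.00%, Memory Used: 1024.00 MB / 15360.00 MB"
print(parse_gpu_entry(sample))
```

Returning `None` for unexpected strings keeps a malformed entry from interrupting a longer analysis loop.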

##### **Load dataset**

In [None]:
# Load the test split of the dataset
from datasets import load_dataset
dataset = load_dataset("PLACEHOLDER", split="test")

In [None]:
# Get the cache directory (cache_files lists the Arrow files backing the dataset)
import os
cache_dir = os.path.dirname(dataset.cache_files[0]["filename"])
print(f"Dataset cache location: {cache_dir}")

In [None]:
# Check all cached files
print(dataset.cache_files)

In [None]:
# Check dataset structure and metadata
print(dataset)
print(dataset.column_names)
print(dataset.num_rows)

##### **TensorBoard**

In [None]:
# Load the TensorBoard extension
%load_ext tensorboard

In [None]:
# Launch TensorBoard in the notebook
%tensorboard --logdir logs/system_metrics/

#### **ColPali Retriever**

In [None]:
# Add the cloned repository's src directory to the module search path
import sys
sys.path.append('/content/vidore-benchmark/src')

##### **Sample dataset**

In [None]:
sampled_dataset = dataset.shuffle(seed=42).select(range(100))

In [None]:
sampled_dataset = dataset.shuffle(seed=43).select(range(500))

In [None]:
sampled_dataset = dataset.shuffle(seed=44).select(range(1000))

In [None]:
batch_query=32
batch_passage=32
batch_score=512

In [None]:
# import os
from vidore_benchmark.evaluation.evaluate import evaluate_dataset
from vidore_benchmark.retrievers.colpali_retriever import ColPaliRetriever
import torch
import time

def main():
    # Start system monitoring
    start_monitoring()

    try:
        # Initialize the retriever
        retriever = ColPaliRetriever("vidore/colpali-v1.3", device="cuda" if torch.cuda.is_available() else "cpu")

        # Timing
        start_time = time.time()

        # Evaluate the retriever
        metrics = evaluate_dataset(
            vision_retriever=retriever,
            ds=sampled_dataset,
            batch_query=batch_query,
            batch_passage=batch_passage,
            batch_score=batch_score
        )

        # Stop timing
        end_time = time.time()
        elapsed_time = end_time - start_time

        # Print timing and evaluation metrics
        print(f"Dataset evaluation completed in {elapsed_time:.2f} seconds.")
        print("Evaluation Metrics:", metrics)

    except Exception as e:
        print(f"An error occurred: {e}")

    finally:
        # Stop system monitoring
        stop_monitoring_system()

if __name__ == "__main__":
    main()
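
The evaluation cell measures elapsed time with a pair of `time.time()` calls. As a small sketch of an alternative, the timing can be wrapped in a context manager built on `time.perf_counter()`, which is a monotonic clock intended for measuring intervals. `timed` and the `timings` dictionary are hypothetical names, and `time.sleep` stands in for the real `evaluate_dataset(...)` call.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label, results):
    """Store the wall-clock duration of the with-block in results[label]."""
    start = time.perf_counter()
    try:
        yield
    finally:
        # Recorded even if the block raises, so failed runs are still timed.
        results[label] = time.perf_counter() - start


timings = {}
with timed("evaluation", timings):
    time.sleep(0.01)  # stand-in for evaluate_dataset(...)

print(f"evaluation took {timings['evaluation']:.2f} seconds")
```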

In [None]:
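import matplotlib.pyplot as plt
import pandas as pd

# Load the metrics recorded in memory by the monitoring thread
df = get_monitoring_data()

# Convert 'timestamp' to a datetime object
df['timestamp'] = pd.to_datetime(df['timestamp'], format='%Y-%m-%d %H:%M:%S')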

# Normalize time to start at 0 seconds
df['time'] = (df['timestamp'] - df['timestamp'].iloc[0]).dt.total_seconds()

# Plot the data
plt.figure(figsize=(10, 5))
plt.plot(df['time'], df['cpu_usage'], label='CPU Usage (%)')
plt.plot(df['time'], df['ram_percent'], label='RAM Usage (%)')
plt.xlabel('Time (s)')
plt.ylabel('Usage (%)')
plt.title('System Resource Usage Over Time')
plt.legend()
plt.show()


##### **Whole dataset**

In [None]:
# import os
from vidore_benchmark.evaluation.evaluate import evaluate_dataset
from vidore_benchmark.retrievers.colpali_retriever import ColPaliRetriever
import torch
import time

batch_query=32
batch_passage=32
batch_score=512

def main():
    # Start system monitoring
    start_monitoring()

    try:
        # Initialize the retriever
        retriever = ColPaliRetriever("vidore/colpali-v1.3", device="cuda" if torch.cuda.is_available() else "cpu")

        # Timing
        start_time = time.time()

        # Evaluate the retriever
        metrics = evaluate_dataset(
            vision_retriever=retriever,
            ds=dataset,
            batch_query=batch_query,
            batch_passage=batch_passage,
            batch_score=batch_score
        )

        # Stop timing
        end_time = time.time()
        elapsed_time = end_time - start_time

        # Print timing and evaluation metrics
        print(f"Dataset evaluation completed in {elapsed_time:.2f} seconds.")
        print("Evaluation Metrics:", metrics)

    except Exception as e:
        print(f"An error occurred: {e}")

    finally:
        # Stop system monitoring
        stop_monitoring_system()

if __name__ == "__main__":
    main()

In [None]:
df = read_log_file()
df.head()

In [None]:
import matplotlib.pyplot as plt

# Convert 'timestamp' to a datetime object
df['timestamp'] = pd.to_datetime(df['timestamp'], format='%Y-%m-%d %H:%M:%S')

# Normalize time to start at 0 seconds
df['time'] = (df['timestamp'] - df['timestamp'].iloc[0]).dt.total_seconds()

# Plot the data
plt.figure(figsize=(10, 5))
plt.plot(df['time'], df['cpu_usage'], label='CPU Usage (%)')
plt.plot(df['time'], df['ram_percent'], label='RAM Usage (%)')
plt.xlabel('Time (s)')
plt.ylabel('Usage (%)')
plt.title('System Resource Usage Over Time')
plt.legend()
plt.show()

##### **Batched dataset**

In [None]:
import os
import time
import pandas as pd
from datetime import datetime
from datasets import Dataset
from vidore_benchmark.retrievers.colpali_retriever import ColPaliRetriever
import torch

from colpali_engine.models import ColPali, ColPaliProcessor
from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA
from vidore_benchmark.retrievers import VisionRetriever
from vidore_benchmark.utils.data_utils import get_datasets_from_collection

# Configure batch sizes
batch_query = 16
batch_passage = 16
batch_score = 256

dataset_cache_path = "PLACEHOLDER"
num_shards_to_process = None

metrics_file = "shard_metrics_colpali.csv"

def process_shards():
    retriever = ColPaliRetriever("vidore/colpali-v1.3", device="cuda" if torch.cuda.is_available() else "cpu")

    model_name = "vidore/colpali-v1.3"
    processor = ColPaliProcessor.from_pretrained(model_name)
    model = ColPali.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
    ).eval()

    vision_retriever = VisionRetriever(model=model, processor=processor)
    vidore_evaluator = ViDoReEvaluatorQA(vision_retriever)

    shard_files = sorted([f for f in os.listdir(dataset_cache_path) if f.endswith(".arrow")])
    if num_shards_to_process is not None:
        shard_files = shard_files[:num_shards_to_process]

    # Load existing metrics file if it exists
    if os.path.exists(metrics_file):
        try:
            existing_metrics = pd.read_csv(metrics_file)
            processed_shards = set(existing_metrics["shard_idx"].tolist())
            shard_metrics = existing_metrics.to_dict(orient="records")  # Convert existing data to list of dicts
        except Exception as e:
            print(f"Warning: Could not read {metrics_file}. Starting from scratch. Error: {e}")
            processed_shards = set()
            shard_metrics = []
    else:
        processed_shards = set()
        shard_metrics = []

    print(f"Found {len(shard_files)} shard(s) to process. Resuming from last successful shard...")

    total_rows_processed = sum(d["rows_processed"] for d in shard_metrics)
    total_time = sum(d["elapsed_time"] for d in shard_metrics)

    # Start system monitoring
    start_monitoring()

    print("Starting processing of dataset shards...")

    for shard_idx, shard_file in enumerate(shard_files, 1):
        if shard_idx in processed_shards:
            print(f"Skipping previously processed shard {shard_idx}: {shard_file}")
            continue

        shard_path = os.path.join(dataset_cache_path, shard_file)
        print(f"Processing shard {shard_idx}/{len(shard_files)}: {shard_file}")

        try:
            shard_dataset = Dataset.from_file(shard_path)
            num_rows = len(shard_dataset)
            print(f"Number of rows in shard: {num_rows}")

            start_time = time.time()
            metrics = vidore_evaluator.evaluate_dataset(
                vision_retriever=retriever,
                ds=shard_dataset,
                batch_query=batch_query,
                batch_passage=batch_passage,
                batch_score=batch_score
            )
            elapsed_time = time.time() - start_time

            metrics.update({
                "shard_idx": shard_idx,
                "rows_processed": num_rows,
                "elapsed_time": elapsed_time,
                "completion_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            })
            shard_metrics.append(metrics)

            print(f"Shard {shard_idx} processed in {elapsed_time:.2f} seconds.")
            print(f"Metrics for shard {shard_idx}: {metrics}")
            print(f"Total rows processed so far: {total_rows_processed + num_rows}")
            print(f"Total time taken so far: {total_time + elapsed_time:.2f} seconds")
            if total_rows_processed + num_rows > 0:
                print(f"Average time per row so far: {(total_time + elapsed_time) / (total_rows_processed + num_rows):.4f} seconds")

            # Convert list of dictionaries to DataFrame and save
            metrics_df = pd.DataFrame(shard_metrics)
            metrics_df.to_csv(metrics_file, index=False)
            print(f"Metrics for shard {shard_idx} saved to '{metrics_file}'.")

            total_rows_processed += num_rows
            total_time += elapsed_time

        except Exception as e:
            print(f"Error processing shard {shard_idx} ({shard_file}): {e}")
            print("Skipping this shard and continuing with the next one.")

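        finally:
            # Drop the shard reference; unlike `del`, this does not raise if loading failed before assignment
            shard_dataset = None
            torch.cuda.empty_cache()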

    # Stop system monitoring
    stop_monitoring_system()

    print("\nDataset processing completed.")
    print(f"Total rows processed: {total_rows_processed}")
    print(f"Total time taken: {total_time:.2f} seconds")
    if total_rows_processed > 0:
        print(f"Average time per row: {total_time / total_rows_processed:.4f} seconds")
    else:
        print("No rows were processed.")

if __name__ == "__main__":
    process_shards()
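
The resume logic above identifies finished shards by their position in the sorted file list, so the mapping shifts if the set of shard files changes between runs. A minimal sketch of the same checkpointing idea keyed by file name is shown below. All names (`load_processed`, `append_result`, `run`, the `shard_file` and `score` columns) are hypothetical, and `fake_evaluate` stands in for the real evaluator call.

```python
import csv
import os
import tempfile


def load_processed(metrics_path):
    """Return the set of shard file names already recorded in metrics_path."""
    if not os.path.exists(metrics_path):
        return set()
    with open(metrics_path, newline="") as f:
        return {row["shard_file"] for row in csv.DictReader(f)}


def append_result(metrics_path, row, fieldnames):
    """Append one result row, writing the header only when the file is new."""
    is_new = not os.path.exists(metrics_path) or os.path.getsize(metrics_path) == 0
    with open(metrics_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        if is_new:
            writer.writeheader()
        writer.writerow(row)


def run(shard_files, metrics_path, evaluate):
    """Evaluate every shard not yet recorded, saving after each one."""
    done = load_processed(metrics_path)
    for shard_file in shard_files:
        if shard_file in done:
            continue
        try:
            score = evaluate(shard_file)
        except Exception as exc:
            print(f"Skipping {shard_file}: {exc}")
            continue
        append_result(
            metrics_path,
            {"shard_file": shard_file, "score": score},
            ["shard_file", "score"],
        )


# Demo with a stand-in evaluator that records which shards it was asked to process.
calls = []


def fake_evaluate(name):
    calls.append(name)
    return len(name)


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "metrics.csv")
    run(["a.arrow", "b.arrow"], path, fake_evaluate)
    # Second run resumes: only the new shard is evaluated.
    run(["a.arrow", "b.arrow", "c.arrow"], path, fake_evaluate)

print(calls)
```

Appending one row per shard also avoids rewriting the whole CSV after every shard, at the cost of requiring a fixed set of columns.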

##### **Metrics**

In [None]:
import pandas as pd

# Load the metrics CSV
metrics_file = "shard_metrics_colpali.csv"  # File written by process_shards() for ColPali
metrics_df = pd.read_csv(metrics_file)

# Display the first few rows
metrics_df.head()

In [None]:
# Reorder columns: move shard-related details to the front
columns_order = ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp'] + \
                [col for col in metrics_df.columns if col not in ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp']]
metrics_df = metrics_df[columns_order]

# Summarize shard-level statistics
summary = {
    "Total Shards Processed": len(metrics_df),
    "Total Rows Processed": metrics_df['rows_processed'].sum(),
    "Total Time (seconds)": metrics_df['elapsed_time'].sum(),
    "Average Time per Row (seconds)": metrics_df['elapsed_time'].sum() / metrics_df['rows_processed'].sum(),
    "Average NDCG@5": metrics_df['ndcg_at_5'].mean(),
    "Max NDCG@5": metrics_df['ndcg_at_5'].max(),
    "Min NDCG@5": metrics_df['ndcg_at_5'].min(),
}

# Display the high-level summary

summary


In [None]:
import pandas as pd

# Load the metrics CSV
metrics_file = "shard_metrics_colpali.csv"
metrics_df = pd.read_csv(metrics_file)

# Display the full metrics table
metrics_df

In [None]:
# Reorder columns: move shard-related details to the front
columns_order = ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp'] + \
                [col for col in metrics_df.columns if col not in ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp']]
metrics_df = metrics_df[columns_order]

# Summarize shard-level statistics
summary = {
    "Total Shards Processed": len(metrics_df),
    "Total Rows Processed": metrics_df['rows_processed'].sum(),
    "Total Time (seconds)": metrics_df['elapsed_time'].sum(),
    "Average Time per Row (seconds)": metrics_df['elapsed_time'].sum() / metrics_df['rows_processed'].sum(),
    "Average NDCG@5": metrics_df['ndcg_at_5'].mean(),
    "Max NDCG@5": metrics_df['ndcg_at_5'].max(),
    "Min NDCG@5": metrics_df['ndcg_at_5'].min(),
}

# Display the high-level summary

summary
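
The summary above averages NDCG@5 with `.mean()`, which gives every shard the same weight regardless of how many rows it contains. If shard sizes differ, a row-weighted mean may be the more representative figure. The sketch below illustrates the difference. `weighted_mean` is a hypothetical helper and the numbers are made up, not results from this notebook.

```python
def weighted_mean(values, weights):
    """Return the mean of values, weighting each value by its matching weight."""
    total_weight = sum(weights)
    if total_weight == 0:
        raise ValueError("weights sum to zero")
    return sum(v * w for v, w in zip(values, weights)) / total_weight


# Illustrative per-shard scores and row counts (invented for this example).
ndcg_per_shard = [0.80, 0.60]
rows_per_shard = [300, 100]

print("Unweighted mean:", sum(ndcg_per_shard) / len(ndcg_per_shard))
print("Row-weighted mean:", weighted_mean(ndcg_per_shard, rows_per_shard))
```

With the metrics DataFrame, the same computation would use the `ndcg_at_5` column as values and `rows_processed` as weights.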


#### **ColQwen2 Retriever**

##### **Batched dataset**

In [None]:
# Supplementary library for Qwen
!pip install flash-attn --no-build-isolation
!pip install --upgrade triton

In [None]:
# Upgrade transformers to support flash attention
!pip install --upgrade transformers

In [None]:
import os
import time
from tqdm import tqdm
import pandas as pd
from datetime import datetime
from datasets import Dataset
import torch

from colpali_engine.models import ColQwen2, ColQwen2Processor

from vidore_benchmark.evaluation.vidore_evaluators import ViDoReEvaluatorQA
from vidore_benchmark.retrievers import VisionRetriever
from vidore_benchmark.retrievers.colqwen2_retriever import ColQwen2Retriever
from vidore_benchmark.utils.data_utils import get_datasets_from_collection
from transformers.utils.import_utils import is_flash_attn_2_available

# Let the CUDA caching allocator grow existing segments, which can reduce memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Configure batch sizes
batch_query = 32
batch_passage = 32
batch_score = 512

dataset_cache_path = "PLACEHOLDER"
num_shards_to_process = None

metrics_file = "shard_metrics_colqwen.csv"

def process_shards():
    retriever = ColQwen2Retriever("vidore/colqwen2-v1.0", device="cuda" if torch.cuda.is_available() else "cpu")

    model_name = "vidore/colqwen2-v1.0"
    processor = ColQwen2Processor.from_pretrained(model_name)
    model = ColQwen2.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="cuda",
        attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None
    ).eval()

    vision_retriever = VisionRetriever(model=model, processor=processor)
    vidore_evaluator = ViDoReEvaluatorQA(vision_retriever)

    shard_files = sorted([f for f in os.listdir(dataset_cache_path) if f.endswith(".arrow")])
    if num_shards_to_process is not None:
        shard_files = shard_files[:num_shards_to_process]

    # Load existing metrics file if it exists
    if os.path.exists(metrics_file):
        try:
            existing_metrics = pd.read_csv(metrics_file)
            processed_shards = set(existing_metrics["shard_idx"].tolist())
            shard_metrics = existing_metrics.to_dict(orient="records")  # Convert existing data to list of dicts
        except Exception as e:
            print(f"Warning: Could not read {metrics_file}. Starting from scratch. Error: {e}")
            processed_shards = set()
            shard_metrics = []
    else:
        processed_shards = set()
        shard_metrics = []

    print(f"Found {len(shard_files)} shard(s) to process. Resuming from last successful shard...")

    total_rows_processed = sum(d["rows_processed"] for d in shard_metrics)
    total_time = sum(d["elapsed_time"] for d in shard_metrics)

    # Start system monitoring
    start_monitoring()

    print("Starting processing of dataset shards...")

    for shard_idx, shard_file in enumerate(shard_files, 1):
        if shard_idx in processed_shards:
            print(f"Skipping previously processed shard {shard_idx}: {shard_file}")
            continue

        shard_path = os.path.join(dataset_cache_path, shard_file)
        print(f"Processing shard {shard_idx}/{len(shard_files)}: {shard_file}")

        try:
            shard_dataset = Dataset.from_file(shard_path)
            num_rows = len(shard_dataset)
            print(f"Number of rows in shard: {num_rows}")

            start_time = time.time()
            metrics = vidore_evaluator.evaluate_dataset(
                vision_retriever=retriever,
                ds=shard_dataset,
                batch_query=batch_query,
                batch_passage=batch_passage,
                batch_score=batch_score
            )
            elapsed_time = time.time() - start_time

            metrics.update({
                "shard_idx": shard_idx,
                "rows_processed": num_rows,
                "elapsed_time": elapsed_time,
                "completion_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            })
            shard_metrics.append(metrics)

            print(f"Shard {shard_idx} processed in {elapsed_time:.2f} seconds.")
            print(f"Metrics for shard {shard_idx}: {metrics}")
            print(f"Total rows processed so far: {total_rows_processed + num_rows}")
            print(f"Total time taken so far: {total_time + elapsed_time:.2f} seconds")
            if total_rows_processed + num_rows > 0:
                print(f"Average time per row so far: {(total_time + elapsed_time) / (total_rows_processed + num_rows):.4f} seconds")

            # Convert list of dictionaries to DataFrame and save
            metrics_df = pd.DataFrame(shard_metrics)
            metrics_df.to_csv(metrics_file, index=False)
            print(f"Metrics for shard {shard_idx} saved to '{metrics_file}'.")

            total_rows_processed += num_rows
            total_time += elapsed_time

        except Exception as e:
            print(f"Error processing shard {shard_idx} ({shard_file}): {e}")
            print("Skipping this shard and continuing with the next one.")

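        finally:
            # Drop the shard reference; unlike `del`, this does not raise if loading failed before assignment
            shard_dataset = None
            torch.cuda.empty_cache()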

    # Stop system monitoring
    stop_monitoring_system()

    print("\nDataset processing completed.")
    print(f"Total rows processed: {total_rows_processed}")
    print(f"Total time taken: {total_time:.2f} seconds")
    if total_rows_processed > 0:
        print(f"Average time per row: {total_time / total_rows_processed:.4f} seconds")
    else:
        print("No rows were processed.")

if __name__ == "__main__":
    process_shards()

##### **Metrics**

In [None]:
import pandas as pd

# Load the metrics CSV
metrics_file = "shard_metrics_colqwen.csv"
metrics_df = pd.read_csv(metrics_file)

# Display the first few rows
metrics_df.head()

In [None]:
# Reorder columns: move shard-related details to the front
columns_order = ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp'] + \
                [col for col in metrics_df.columns if col not in ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp']]
metrics_df = metrics_df[columns_order]

# Summarize shard-level statistics
summary = {
    "Total Shards Processed": len(metrics_df),
    "Total Rows Processed": metrics_df['rows_processed'].sum(),
    "Total Time (seconds)": metrics_df['elapsed_time'].sum(),
    "Average Time per Row (seconds)": metrics_df['elapsed_time'].sum() / metrics_df['rows_processed'].sum(),
    "Average NDCG@5": metrics_df['ndcg_at_5'].mean(),
    "Max NDCG@5": metrics_df['ndcg_at_5'].max(),
    "Min NDCG@5": metrics_df['ndcg_at_5'].min(),
}

# Display the high-level summary

summary


In [None]:
import pandas as pd

# Load the metrics CSV
metrics_file = "shard_metrics_colqwen.csv"
metrics_df = pd.read_csv(metrics_file)

# Display the full metrics table
metrics_df

In [None]:
# Reorder columns: move shard-related details to the front
columns_order = ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp'] + \
                [col for col in metrics_df.columns if col not in ['shard_idx', 'rows_processed', 'elapsed_time', 'completion_timestamp']]
metrics_df = metrics_df[columns_order]

# Summarize shard-level statistics
summary = {
    "Total Shards Processed": len(metrics_df),
    "Total Rows Processed": metrics_df['rows_processed'].sum(),
    "Total Time (seconds)": metrics_df['elapsed_time'].sum(),
    "Average Time per Row (seconds)": metrics_df['elapsed_time'].sum() / metrics_df['rows_processed'].sum(),
    "Average NDCG@5": metrics_df['ndcg_at_5'].mean(),
    "Max NDCG@5": metrics_df['ndcg_at_5'].max(),
    "Min NDCG@5": metrics_df['ndcg_at_5'].min(),
}

# Display the high-level summary

summary


## **Papermage**

##### **Dataset**

In [None]:
import shutil
import os
from tqdm import tqdm

In [None]:
# Specify the source folder in Google Drive and the destination folder in Colab
source_folder = 'PLACEHOLDER'
destination_folder = '/content/dataset_PDF_sample'

# Copy the folder to Colab with progress bar
if os.path.exists(source_folder):
    os.makedirs(destination_folder, exist_ok=True)
    files = os.listdir(source_folder)

    for file_name in tqdm(files, desc="Copying files", unit="file"):
        src_file = os.path.join(source_folder, file_name)
        dst_file = os.path.join(destination_folder, file_name)

        if os.path.isdir(src_file):
            shutil.copytree(src_file, dst_file, dirs_exist_ok=True)  # Allow re-running the cell (Python 3.8+)
        else:
            shutil.copy2(src_file, dst_file)

    print(f"Folder '{source_folder}' has been copied to '{destination_folder}'.")
else:
    print(f"Source folder '{source_folder}' does not exist.")


##### **Environment**

In [None]:
!apt-get install -y poppler-utils
!pip install 'papermage[dev,predictors,visualizers]'

In [None]:
# Run for one PDF to load models and test the library
from papermage.recipes.core_recipe import CoreRecipe
dummy_pdf_path = "PLACEHOLDER"
recipe = CoreRecipe()
doc = recipe.run(dummy_pdf_path)

##### **Sample PDF**

In [None]:
pdf_path = 'PLACEHOLDER'

In [None]:
from papermage.recipes import CoreRecipe

recipe = CoreRecipe()
doc = recipe.run(pdf_path)

In [None]:
doc.layers

In [None]:
doc.images[0]

In [None]:
from papermage.visualizers import plot_entities_on_page

page = doc.pages[0]
highlighted = plot_entities_on_page(page.images[0], page.tokens, box_width=0, box_alpha=0.3, box_color="yellow")
highlighted = plot_entities_on_page(highlighted, page.abstracts, box_width=2, box_alpha=0.1, box_color="red")
display(highlighted)

In [None]:
highlighted = plot_entities_on_page(doc.pages[0].images[0], doc.pages[0].titles, box_color="blue", box_alpha=0.2)
highlighted = plot_entities_on_page(highlighted, doc.pages[0].authors, box_color="green", box_alpha=0.2)
display(highlighted)

print('TITLE:')
print(doc.pages[0].titles[0].text)
print('\n\nAUTHORS:')
print(doc.pages[0].authors[0].text)

In [None]:
from papermage.visualizers.visualizer import plot_entities_on_page
from IPython.display import display

# Iterate through all pages in the document
for page_index, page in enumerate(doc.pages):
    print(f"\n=== Page {page_index + 1} ===")

    # Extract and display metadata
    if hasattr(page, 'metadata'):
        print("\nMetadata:")
        for key, value in page.metadata.items():
            print(f"{key}: {value}")

    # Prepare highlights only if there are entities to highlight
    page_image = page.images[0] if hasattr(page, 'images') and page.images else None
    highlighted = None

    # Highlight titles and authors
    if page_image and (hasattr(page, 'titles') and page.titles or hasattr(page, 'authors') and page.authors):
        print("\nHighlighting Titles and Authors...")
        highlighted = plot_entities_on_page(
            page_image,
            page.titles if hasattr(page, 'titles') else [],
            box_color="blue",
            box_alpha=0.2
        )
        highlighted = plot_entities_on_page(
            highlighted,
            page.authors if hasattr(page, 'authors') else [],
            box_color="green",
            box_alpha=0.2
        )

    # Print titles and authors
    if hasattr(page, 'titles') and page.titles:
        print("\nTitles:")
        for title in page.titles:
            print(f"- {title.text}")

    if hasattr(page, 'authors') and page.authors:
        print("\nAuthors:")
        for author in page.authors:
            print(f"- {author.text}")

    # Highlight figures and captions
    if page_image and (hasattr(page, 'figures') and page.figures or hasattr(page, 'captions') and page.captions):
        print("\nHighlighting Figures and Captions...")
        if highlighted is None:
            highlighted = page_image  # Initialize with the page image
        highlighted = plot_entities_on_page(
            highlighted,
            page.figures if hasattr(page, 'figures') else [],
            box_color="red",
            box_alpha=0.2
        )
        highlighted = plot_entities_on_page(
            highlighted,
            page.captions if hasattr(page, 'captions') else [],
            box_color="yellow",
            box_alpha=0.2
        )

    # Print figures and captions
    if hasattr(page, 'figures') and page.figures:
        print("\nFigures:")
        for figure in page.figures:
            print(f"Bounding Box: {figure.boxes if hasattr(figure, 'boxes') else 'N/A'}")

    if hasattr(page, 'captions') and page.captions:
        print("\nCaptions:")
        for caption in page.captions:
            print(f"- {caption.text}")

    # Display the highlighted image only if something was highlighted
    if highlighted is not None and highlighted != page_image:
        display(highlighted)
    else:
        print("No entities to highlight on this page.")

    # Extract and display paragraphs and sentences
    if hasattr(page, 'paragraphs') and page.paragraphs:
        print("\nParagraphs:")
        for i, paragraph in enumerate(page.paragraphs[:3]):  # Display up to 3 paragraphs per page
            print(f"Paragraph {i + 1}: {paragraph.text}")

    if hasattr(page, 'sentences') and page.sentences:
        print("\nSentences:")
        for i, sentence in enumerate(page.sentences[:5]):  # Display up to 5 sentences per page
            print(f"Sentence {i + 1}: {sentence.text}")

    # Separate output between pages
    print("\n" + "=" * 50 + "\n")

##### **XML**

In [None]:
import os
import json
from lxml import etree

def extract_relevant_data(xml_file):
    """
    Extracts titles, authors, abstracts, keywords, and bibliographies from an XML file.
    Args:
    - xml_file: Path to the XML file.

    Returns:
    - A dictionary containing the extracted data.
    """
    try:
        tree = etree.parse(xml_file)
        root = tree.getroot()

        # Extract titles
        titles = []

        # Extract `article-title` (relevant for articles)
        for elem in root.findall('.//article-title'):
            if elem.text:
                titles.append(elem.text.strip())

        # Extract `<title>` inside `<book-part>` (relevant for book chapters)
        for book_part in root.findall('.//book-part'):
            title_group = book_part.find('.//title-group')
            if title_group is not None:
                for title in title_group.findall('title'):
                    if title.text:
                        titles.append(title.text.strip())

        # Extract authors
        authors = []
        for contrib in root.findall('.//contrib'):
            name = contrib.find('.//name')
            surname = name.findtext('surname') if name is not None else None
            given_names = name.findtext('given-names') if name is not None else None
            # Join only the parts that are present, so a missing name part never becomes the string "None"
            full_name = " ".join(part.strip() for part in (given_names, surname) if part and part.strip())
            if full_name:
                authors.append(full_name)

        # Extract abstracts
        abstracts = []
        for abstract in root.findall('.//abstract'):
            text = ' '.join(p.text.strip() for p in abstract.findall('.//p') if p.text)
            if text:
                abstracts.append(text)

        # Extract keywords
        keywords = []
        for kwd in root.findall('.//kwd'):
            if kwd.text:
                keywords.append(kwd.text.strip())

        # Extract bibliographies
        bibliographies = []
        for ref in root.findall('.//ref'):
            mixed_citation = ref.find('.//mixed-citation')
            if mixed_citation is not None and mixed_citation.text:
                bibliographies.append(mixed_citation.text.strip())

        return {
            "titles": titles,
            "authors": authors,
            "abstracts": abstracts,
            "keywords": keywords,
            "bibliographies": bibliographies,
        }

    except Exception as e:
        print(f"Error processing {xml_file}: {e}")
        return None

def process_directory(xml_directory, output_file):
    """
    Processes all XML files in a directory, extracts relevant data, and saves it to a JSON file.
    Args:
    - xml_directory: Path to the directory containing XML files.
    - output_file: Path to the output JSON file.
    """
    extracted_data = {}

    # Find all XML files in the directory
    xml_files = [os.path.join(root, file)
                 for root, _, files in os.walk(xml_directory) for file in files if file.endswith('.xml')]

    print(f"Found {len(xml_files)} XML files. Processing...")

    for xml_file in xml_files:
        data = extract_relevant_data(xml_file)
        if data:
            extracted_data[os.path.basename(xml_file)] = data  # Use the file name as the key

    # Save extracted data to JSON
    with open(output_file, "w") as f:
        json.dump(extracted_data, f, indent=4)

    print(f"Extracted data from {len(extracted_data)} files and saved to {output_file}.")

# Example usage
xml_directory = "/content/dataset_PDF_sample"  # Replace with the path to your XML files
output_json_file = "extracted_data_XML.json"  # Replace with the desired output JSON file name
process_directory(xml_directory, output_json_file)
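
The extractor above reads `elem.text`, which in ElementTree-style APIs holds only the text that precedes an element's first child tag. Titles or citations that contain inline markup such as `<italic>` would therefore be cut short. The following is a minimal sketch of one possible alternative, assuming whitespace-normalized full text is what the evaluation needs. It uses the standard library's `xml.etree.ElementTree` instead of `lxml` so that it runs without extra dependencies (`itertext()` exists in both), and `full_text` and the sample string are hypothetical names introduced only for illustration.

```python
import xml.etree.ElementTree as ET

def full_text(elem):
    """Return all text inside an element, including text nested in child tags."""
    # itertext() walks the element and its descendants in document order;
    # split()/join collapses runs of whitespace into single spaces.
    return " ".join("".join(elem.itertext()).split())

# Hypothetical sample: a title with inline markup
sample = "<article-title>Learning <italic>deep</italic> document models</article-title>"
title = ET.fromstring(sample)

text_only = (title.text or "").strip()  # text before the first child tag only
print(text_only)
print(full_text(title))
```

Whether the nested text should count toward the ground truth is a design choice; including it tends to make the XML side closer to what a PDF parser sees on the page.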

##### **Multiprocessing**

In [None]:
import multiprocessing

cpu_count = multiprocessing.cpu_count()
print(f"Available CPU cores: {cpu_count}")

In [None]:
import json
import os
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from papermage.recipes.core_recipe import CoreRecipe
import logging
import warnings

# Suppress warnings
logging.getLogger("papermage.magelib.entity").setLevel(logging.ERROR)
warnings.filterwarnings("ignore")


# Function: Load Existing Data (if any)
def load_existing_data(output_json_path):
    if os.path.exists(output_json_path):
        with open(output_json_path, "r", encoding="utf-8") as file:
            return json.load(file)
    return {}


# Function: Extract Data from a Single PDF
def process_single_pdf(pdf_path):
    from papermage.recipes.core_recipe import CoreRecipe
    recipe = CoreRecipe()

    try:
        doc = recipe.run(pdf_path)
        data = {
            "titles": [title.text.strip() for page in doc.pages for title in page.titles] if hasattr(doc, "titles") else [],
            "authors": [author.text.strip() for page in doc.pages for author in page.authors] if hasattr(doc, "authors") else [],
            "abstracts": [abstract.text.strip() for page in doc.pages for abstract in page.abstracts] if hasattr(doc, "abstracts") else [],
            "keywords": [keyword.text.strip() for page in doc.pages for keyword in page.keywords] if hasattr(doc, "keywords") else [],
            "bibliographies": [bib.text.strip() for page in doc.pages for bib in page.bibliographies] if hasattr(doc, "bibliographies") else []
        }
        return os.path.basename(pdf_path).replace(".pdf", ".xml"), data
    except Exception as e:
        tqdm.write(f"Error processing {pdf_path}: {e}")
        return os.path.basename(pdf_path), {"error": str(e)}


# Function: Save Data Incrementally
def save_incremental_data(output_json_path, document_name, data):
    """
    Save the result of a single file to the JSON file incrementally.
    """
    existing_data = load_existing_data(output_json_path)
    existing_data[document_name] = data

    with open(output_json_path, "w", encoding="utf-8") as json_file:
        json.dump(existing_data, json_file, ensure_ascii=False, indent=4)


# Function: Process PDFs in Parallel with Incremental Saving
def process_pdfs_in_parallel(input_dir, output_json_path, num_workers, limit_files=None):
    # Load existing processed data to resume from last run
    existing_data = load_existing_data(output_json_path)
    processed_files = set(existing_data.keys())

    # Get list of PDF files and exclude already processed ones
    pdf_files = [os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith(".pdf")]
    pdf_files_to_process = [f for f in pdf_files if os.path.basename(f).replace(".pdf", ".xml") not in processed_files]

    if limit_files:
        pdf_files_to_process = pdf_files_to_process[:limit_files]

    if not pdf_files_to_process:
        print("All files have already been processed.")
        return

    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = {executor.submit(process_single_pdf, pdf_path): pdf_path for pdf_path in pdf_files_to_process}

        for future in tqdm(futures, desc="Processing PDFs", unit="file"):
            result = future.result()
            if result:
                document_name, data = result
                save_incremental_data(output_json_path, document_name, data)  # Save after each file

    print(f"All data extracted and saved to {output_json_path}")


# Input and Output Paths
input_directory = "/content/dataset_PDF_sample"  # Directory containing all PDFs
output_json_path = "/content/extracted_data_PAPERMAGE.json"  # Combined JSON output

# Run the code with incremental saving and crash recovery
process_pdfs_in_parallel(input_directory, output_json_path, num_workers=10, limit_files=None)
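
The loop above consumes futures in submission order, so one slow PDF can delay saving results that are already finished. A minimal sketch of an alternative pattern, assuming results can be collected as they complete: `concurrent.futures.as_completed` yields each future when it finishes. `fake_extract` is a hypothetical stand-in for `process_single_pdf`, and `run_parallel` is an illustrative name, not part of the notebook.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def fake_extract(name):
    # Hypothetical stand-in for process_single_pdf: fails on names starting with "bad"
    if name.startswith("bad"):
        raise ValueError("unreadable file")
    return name.replace(".pdf", ".xml"), {"titles": [name.upper()]}

def run_parallel(names, num_workers=4):
    """Run fake_extract over names in threads, separating results from errors."""
    results, errors = {}, {}
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        # Map each future back to its input so failures can be attributed
        futures = {executor.submit(fake_extract, n): n for n in names}
        for future in as_completed(futures):
            name = futures[future]
            try:
                key, data = future.result()
                results[key] = data
            except Exception as exc:
                errors[name] = str(exc)
    return results, errors

results, errors = run_parallel(["a.pdf", "bad.pdf", "b.pdf"])
print(sorted(results), errors)
```

Keeping errors in a separate dictionary is one option for retrying failed files later without mixing them into the extracted data.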

##### **Evaluation**

###### **Initial Evaluation**

In [None]:
import json
from sklearn.metrics import precision_score, recall_score, f1_score
from difflib import SequenceMatcher
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cosine
import numpy as np
import re
import pandas as pd

def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def normalize_text(text):
    return text.strip().lower()

def normalize_list(data_list):
    return [normalize_text(item) for item in data_list]

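def levenshtein_ratio(s1, s2):
    # SequenceMatcher.ratio() is a gestalt-pattern-matching similarity in [0, 1], used here as a stand-in for a Levenshtein ratio
    return SequenceMatcher(None, normalize_text(s1), normalize_text(s2)).ratio()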

def compute_cosine_similarity(text1, text2):
    if not text1.strip() and not text2.strip():
        return 1.0  # Both abstracts are empty, considered identical
    if not text1.strip() or not text2.strip():
        return 0.0  # One is empty, considered completely dissimilar
    vectorizer = TfidfVectorizer().fit_transform([normalize_text(text1), normalize_text(text2)])
    vectors = vectorizer.toarray()
    return 1 - cosine(vectors[0], vectors[1])

def evaluate_field(gt_field, pred_field, metric):
    if metric == 'levenshtein':
        return levenshtein_ratio(gt_field, pred_field)
    elif metric == 'cosine':
        return compute_cosine_similarity(gt_field, pred_field)
    else:
        raise ValueError(f"Unknown metric: {metric}")

def evaluate_lists(gt_list, pred_list, metric):
    gt_list = normalize_list(gt_list)
    pred_list = normalize_list(pred_list)

    if metric == 'f1':
        gt_set, pred_set = set(gt_list), set(pred_list)
        tp = len(gt_set & pred_set)
        precision = tp / len(pred_set) if pred_set else 0
        recall = tp / len(gt_set) if gt_set else 0
        f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
        return precision, recall, f1
    elif metric == 'jaccard':
        gt_set, pred_set = set(gt_list), set(pred_list)
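        union = gt_set | pred_set
        # Two empty lists are treated as identical, which also avoids dividing by zero
        return len(gt_set & pred_set) / len(union) if union else 1.0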
    else:
        raise ValueError(f"Unknown metric: {metric}")

def compare_authors(gt_list, pred_list):
    def extract_names(authors):
        names = []
        for author in authors:
            clean_author = re.sub(r'[^a-zA-Z\s]', '', author)
            names.extend(clean_author.split())
        return set(normalize_list(names))

    gt_tokens = extract_names(gt_list)
    pred_tokens = extract_names(pred_list)

    tp = len(gt_tokens & pred_tokens)
    precision = tp / len(pred_tokens) if pred_tokens else 0
    recall = tp / len(gt_tokens) if gt_tokens else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1

def compare_raw_strings(gt_string, pred_string):
    gt_tokens = set(re.findall(r'\b\w+\b', normalize_text(gt_string)))
    pred_tokens = set(re.findall(r'\b\w+\b', normalize_text(pred_string)))
    tp = len(gt_tokens & pred_tokens)
    precision = tp / len(pred_tokens) if pred_tokens else 0
    recall = tp / len(gt_tokens) if gt_tokens else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1

def main():
    gt_data = load_json('extracted_data_XML.json')
    pred_data = load_json('extracted_data_PAPERMAGE.json')

    matching_docs = set(gt_data.keys()) & set(pred_data.keys())

    results = {
        'titles': [],
        'authors': [],
        'abstracts': [],
        'keywords': [],
        'bibliographies': []
    }

    per_document_results = []

    for doc_id in matching_docs:
        gt_doc = gt_data[doc_id]
        pred_doc = pred_data[doc_id]

        # Titles
        gt_title = gt_doc.get('titles', [''])
        gt_title = gt_title[0] if gt_title else ''
        pred_title = pred_doc.get('titles', [''])
        pred_title = pred_title[0] if pred_title else ''
        title_score = evaluate_field(gt_title, pred_title, 'levenshtein')
        results['titles'].append(title_score)

        # Authors
        author_precision, author_recall, author_f1 = compare_authors(gt_doc.get('authors', []), pred_doc.get('authors', []))
        results['authors'].append({'precision': author_precision, 'recall': author_recall, 'f1': author_f1})

        # Abstracts
        gt_abstract = gt_doc.get('abstracts', [''])
        gt_abstract = gt_abstract[0] if gt_abstract else ''
        pred_abstract = pred_doc.get('abstracts', [''])
        pred_abstract = pred_abstract[0] if pred_abstract else ''
        abstract_score = evaluate_field(gt_abstract, pred_abstract, 'cosine')
        results['abstracts'].append(abstract_score)

        # Keywords
        gt_keywords = " ".join(gt_doc.get('keywords', []))
        pred_keywords = " ".join(pred_doc.get('keywords', []))
        keyword_precision, keyword_recall, keyword_f1 = compare_raw_strings(gt_keywords, pred_keywords)
        results['keywords'].append({'precision': keyword_precision, 'recall': keyword_recall, 'f1': keyword_f1})

        # Bibliographies
        gt_bib = " ".join(gt_doc.get('bibliographies', []))
        pred_bib = " ".join(pred_doc.get('bibliographies', []))
        bib_precision, bib_recall, bib_f1 = compare_raw_strings(gt_bib, pred_bib)
        results['bibliographies'].append({'precision': bib_precision, 'recall': bib_recall, 'f1': bib_f1})

        # Collect per-document results
        per_document_results.append({
            'Document': doc_id,
            'Title Score': title_score,
            'Author Precision': author_precision,
            'Author Recall': author_recall,
            'Author F1': author_f1,
            'Abstract Score': abstract_score,
            'Keyword Precision': keyword_precision,
            'Keyword Recall': keyword_recall,
            'Keyword F1': keyword_f1,
            'Bibliography Precision': bib_precision,
            'Bibliography Recall': bib_recall,
            'Bibliography F1': bib_f1
        })

    # Save per-document results to CSV
    df = pd.DataFrame(per_document_results)
    df.to_csv('per_document_evaluation.csv', index=False)

    # Compute overall averages
    print("Evaluation Results:")
    for field, scores in results.items():
        if field == 'titles' or field == 'abstracts':
            avg_score = np.nanmean(scores)
            print(f"{field.capitalize()} Average Score: {avg_score:.4f}")
        else:
            avg_precision = np.nanmean([x['precision'] for x in scores])
            avg_recall = np.nanmean([x['recall'] for x in scores])
            avg_f1 = np.nanmean([x['f1'] for x in scores])
            print(f"{field.capitalize()} - Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, F1 Score: {avg_f1:.4f}")

if __name__ == '__main__':
    main()
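
The author, keyword, and bibliography scores above are all set-based token overlaps. As a small worked example of that calculation, here is a sketch with hypothetical author names; `token_prf` is an illustrative helper, not a function from the notebook.

```python
def token_prf(gt_tokens, pred_tokens):
    """Set-based precision, recall, and F1 between two token collections."""
    gt, pred = set(gt_tokens), set(pred_tokens)
    tp = len(gt & pred)  # tokens present in both sets
    precision = tp / len(pred) if pred else 0.0
    recall = tp / len(gt) if gt else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
    return precision, recall, f1

# Hypothetical case: the parser abbreviated "john" to "j"
gt = "jane doe john smith".split()
pred = "jane doe j smith".split()
precision, recall, f1 = token_prf(gt, pred)
print(precision, recall, f1)
```

Three of the four tokens match on each side, so precision, recall, and F1 all come out at 0.75. Because sets discard duplicates and order, this metric cannot tell whether names were paired correctly.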

###### **Updated Evaluation with Sentence Transformers and NER**

In [None]:
!pip install --upgrade transformers

In [None]:
import json
from sklearn.metrics import precision_score, recall_score, f1_score
from difflib import SequenceMatcher
from sklearn.feature_extraction.text import TfidfVectorizer
from scipy.spatial.distance import cosine
import numpy as np
import re
import pandas as pd
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
import spacy

In [None]:
# Load sentence transformer model before running evaluation
print("Loading Sentence Transformer model...")
model = SentenceTransformer('all-MiniLM-L6-v2')
print("Model loaded successfully.")

In [None]:
# Load necessary models before running evaluation
print("Loading Sentence Transformer model and Named Entity Recognition (NER) model...")
model = SentenceTransformer('all-MiniLM-L6-v2')
nlp = spacy.load("en_core_web_sm")
print("Models loaded successfully.")

def load_json(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def normalize_text(text):
    return text.strip().lower()

def levenshtein_ratio(s1, s2):
    # Despite the name, this is a token-overlap (Dice-style) score, not an edit distance:
    # it counts tokens of s1 that also appear in s2.
    s1_tokens = normalize_text(s1).split()
    s2_tokens = normalize_text(s2).split()
    matches = sum(1 for token in s1_tokens if token in s2_tokens)
    return 2 * matches / (len(s1_tokens) + len(s2_tokens)) if (len(s1_tokens) + len(s2_tokens)) > 0 else 0

def compute_cosine_similarity(text1, text2):
    if not text1.strip() and not text2.strip():
        return 1.0
    if not text1.strip() or not text2.strip():
        return 0.0
    text1_embedding = model.encode(normalize_text(text1.replace("ABSTRACT", "")), convert_to_tensor=True)
    text2_embedding = model.encode(normalize_text(text2.replace("ABSTRACT", "")), convert_to_tensor=True)
    # scipy's `cosine` is a distance (1 - similarity), so convert it back to a similarity
    return float(1 - cosine(text1_embedding.cpu().numpy(), text2_embedding.cpu().numpy()))

def evaluate_field(gt_field, pred_field, metric):
    if metric == 'levenshtein':
        return levenshtein_ratio(gt_field, pred_field)
    elif metric == 'cosine':
        return compute_cosine_similarity(gt_field, pred_field)
    else:
        raise ValueError(f"Unknown metric: {metric}")

def evaluate_lists(gt_list, pred_list):
    gt_tokens = set(" ".join(gt_list).split())
    pred_tokens = set(" ".join(pred_list).split())
    tp = len(gt_tokens & pred_tokens)
    precision = tp / len(pred_tokens) if pred_tokens else 0
    recall = tp / len(gt_tokens) if gt_tokens else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    return precision, recall, f1

def main():
    print("Loading JSON files...")
    gt_data = load_json('extracted_data_XML.json')
    pred_data = load_json('extracted_data_PAPERMAGE.json')
    print("JSON files loaded successfully.")

    matching_docs = set(gt_data.keys()) & set(pred_data.keys())
    results = {'titles': [], 'authors': [], 'abstracts': [], 'keywords': [], 'bibliographies': []}
    per_document_results = []

    print("Processing documents...")
    for doc_id in tqdm(matching_docs, desc="Evaluating Documents"):
        gt_doc = gt_data[doc_id]
        pred_doc = pred_data[doc_id]

        gt_title = gt_doc.get('titles', [''])[0] if gt_doc.get('titles', ['']) else ''
        pred_title = pred_doc.get('titles', [''])[0] if pred_doc.get('titles', ['']) else ''
        title_score = evaluate_field(gt_title, pred_title, 'levenshtein')
        results['titles'].append(title_score)

        author_precision, author_recall, author_f1 = evaluate_lists(gt_doc.get('authors', []), pred_doc.get('authors', []))
        results['authors'].append({'precision': author_precision, 'recall': author_recall, 'f1': author_f1})

        gt_abstract = gt_doc.get('abstracts', [''])[0] if gt_doc.get('abstracts', ['']) else ''
        pred_abstract = pred_doc.get('abstracts', [''])[0] if pred_doc.get('abstracts', ['']) else ''
        abstract_score = evaluate_field(gt_abstract, pred_abstract, 'cosine')
        results['abstracts'].append(abstract_score)

        keyword_precision, keyword_recall, keyword_f1 = evaluate_lists(gt_doc.get('keywords', []), pred_doc.get('keywords', []))
        results['keywords'].append({'precision': keyword_precision, 'recall': keyword_recall, 'f1': keyword_f1})

        bib_precision, bib_recall, bib_f1 = evaluate_lists(gt_doc.get('bibliographies', []), pred_doc.get('bibliographies', []))
        results['bibliographies'].append({'precision': bib_precision, 'recall': bib_recall, 'f1': bib_f1})

        per_document_results.append({
            'Document': doc_id,
            'Title Score': title_score,
            'Author Precision': author_precision,
            'Author Recall': author_recall,
            'Author F1': author_f1,
            'Abstract Score': abstract_score,
            'Keyword Precision': keyword_precision,
            'Keyword Recall': keyword_recall,
            'Keyword F1': keyword_f1,
            'Bibliography Precision': bib_precision,
            'Bibliography Recall': bib_recall,
            'Bibliography F1': bib_f1
        })

    df = pd.DataFrame(per_document_results)
    df.to_csv('per_document_evaluation.csv', index=False)
    print("Evaluation complete. Results saved to per_document_evaluation.csv")

    print("Overall Averages:")
    for field, scores in results.items():
        if field == 'titles' or field == 'abstracts':
            avg_score = np.nanmean(scores)
            print(f"{field.capitalize()} Average Score: {avg_score:.4f}")
        else:
            avg_precision = np.nanmean([x['precision'] for x in scores])
            avg_recall = np.nanmean([x['recall'] for x in scores])
            avg_f1 = np.nanmean([x['f1'] for x in scores])
            print(f"{field.capitalize()} - Precision: {avg_precision:.4f}, Recall: {avg_recall:.4f}, F1 Score: {avg_f1:.4f}")

if __name__ == '__main__':
    main()
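
`scipy.spatial.distance.cosine`, imported in both evaluation cells, returns a cosine *distance*, that is, one minus the cosine similarity. An abstract score meant as a similarity (1.0 for identical texts) therefore has to be computed as one minus that value. The sketch below illustrates the relationship in plain Python, without SciPy; `cosine_similarity` and `cosine_distance` are illustrative helpers written for this example.

```python
import math

def cosine_similarity(u, v):
    """Cosine of the angle between two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0 or norm_v == 0:
        return 0.0
    return dot / (norm_u * norm_v)

def cosine_distance(u, v):
    """Distance form: 0 for vectors pointing the same way, 1 for orthogonal vectors."""
    return 1.0 - cosine_similarity(u, v)

same = cosine_similarity([1.0, 2.0, 3.0], [2.0, 4.0, 6.0])   # parallel vectors
orthogonal = cosine_similarity([1.0, 0.0], [0.0, 1.0])       # unrelated vectors
print(same, orthogonal, cosine_distance([1.0, 0.0], [0.0, 1.0]))
```

Parallel vectors give a similarity of about 1 and a distance of about 0, so averaging distances as if they were similarities would invert the ranking of good and bad extractions.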

###### **Updated Evaluation Using Token-Level Comparison**

In [None]:
# Import necessary libraries
import json
import pandas as pd
import re
from sklearn.metrics import precision_recall_fscore_support

# Load the extracted data from Papermage (Predicted Data)
with open("extracted_data_PAPERMAGE.json", "r", encoding="utf-8") as file:
    papermage_data = json.load(file)

# Load the ground truth extracted data (Actual Data)
with open("extracted_data_XML.json", "r", encoding="utf-8") as file:
    xml_data = json.load(file)

# Define categories to evaluate
categories = ["titles", "authors", "abstracts", "keywords", "bibliographies"]

# Text preprocessing function
def preprocess_text(text):
    text = text.lower().strip()  # Lowercase and strip spaces
    text = re.sub(r"\s+", " ", text)  # Normalize spaces
    text = re.sub(r"[^\w\s]", "", text)  # Remove special characters
    return text

# Tokenization function
def tokenize_text(text):
    return set(preprocess_text(text).split())

# Evaluation metrics storage
evaluation_results = {cat: {"precision": [], "recall": [], "f1": []} for cat in categories}
per_document_evaluation = []

# Get common document keys
common_keys = set(papermage_data.keys()).intersection(set(xml_data.keys()))

# Iterate over documents
for paper_id in common_keys:
    papermage_entry = papermage_data[paper_id]
    xml_entry = xml_data[paper_id]
    document_results = {"document_id": paper_id}

    for category in categories:
        # Extract text and tokenize
        papermage_text = " ".join(papermage_entry.get(category, []))
        xml_text = " ".join(xml_entry.get(category, []))

        papermage_tokens = tokenize_text(papermage_text)
        xml_tokens = tokenize_text(xml_text)

        # Compute Precision, Recall, and F1-score
        if not xml_tokens and not papermage_tokens:
            precision, recall, f1 = 1.0, 1.0, 1.0  # Perfect match for empty fields
        elif not xml_tokens:
            precision, recall, f1 = 0.0, 0.0, 0.0  # Extracted something that shouldn't be there
        elif not papermage_tokens:
            precision, recall, f1 = 0.0, 0.0, 0.0  # Missed extraction
        else:
            true_positives = len(papermage_tokens.intersection(xml_tokens))
            precision = true_positives / len(papermage_tokens) if papermage_tokens else 0
            recall = true_positives / len(xml_tokens) if xml_tokens else 0
            f1 = (2 * precision * recall / (precision + recall)) if (precision + recall) > 0 else 0

        # Store results
        evaluation_results[category]["precision"].append(precision)
        evaluation_results[category]["recall"].append(recall)
        evaluation_results[category]["f1"].append(f1)
        document_results[f"{category}_precision"] = precision
        document_results[f"{category}_recall"] = recall
        document_results[f"{category}_f1"] = f1

    per_document_evaluation.append(document_results)

# Compute Macro-Averaged Results (unweighted mean over all evaluated documents)
macro_avg_results = {}
for category in categories:
    scores = evaluation_results[category]
    n_docs = len(scores["f1"])
    # Every document counts, including exact 0.0 and 1.0 scores; NaN if nothing was evaluated
    macro_avg_results[category] = {
        metric: (sum(scores[metric]) / n_docs if n_docs else float("nan"))
        for metric in ("precision", "recall", "f1")
    }

# Convert results to DataFrame
macro_results_df = pd.DataFrame.from_dict(macro_avg_results, orient="index")
per_document_df = pd.DataFrame(per_document_evaluation)

# Save detailed per-document evaluation to a file
per_document_df.to_csv("per_document_evaluation.csv", index=False)
macro_results_df.to_csv("macro_evaluation_results.csv", index_label="category")  # Keep category names as the first column

# Display final results
print("Macro-Averaged Evaluation Results:")
print(macro_results_df)

print("\nPer-document evaluation saved to 'per_document_evaluation.csv'")
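
The results above are macro-averaged: each document contributes one score regardless of its length. A micro average, which pools token counts across documents first, can give a noticeably different number. The following sketch contrasts the two for precision, using hypothetical counts; `macro_and_micro_precision` is an illustrative helper, not part of the notebook.

```python
def macro_and_micro_precision(per_doc_counts):
    """per_doc_counts: list of (true_positives, predicted_tokens) pairs, one per document."""
    # Macro: average the per-document precisions
    per_doc = [tp / pred if pred else 0.0 for tp, pred in per_doc_counts]
    macro = sum(per_doc) / len(per_doc) if per_doc else 0.0
    # Micro: pool the counts, then divide once
    total_tp = sum(tp for tp, _ in per_doc_counts)
    total_pred = sum(pred for _, pred in per_doc_counts)
    micro = total_tp / total_pred if total_pred else 0.0
    return macro, micro

# Hypothetical: a short document scoring 1/2 and a long one scoring 9/10
counts = [(1, 2), (9, 10)]
macro, micro = macro_and_micro_precision(counts)
print(f"macro={macro:.4f} micro={micro:.4f}")
```

Here the macro average is 0.7 while the micro average is 10/12, because the longer document dominates the pooled counts. Which one to report depends on whether documents or tokens are the unit of interest.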

In [None]:
import random

# Identify some examples where there are mismatches
incorrect_predictions = []

for doc in per_document_evaluation:
    for category in categories:
        if doc[f"{category}_f1"] < 0.5:  # Only inspect low F1-score cases
            incorrect_predictions.append((doc["document_id"], category))

# Randomly sample some cases for manual inspection
# random.seed(45)
sample_cases = random.sample(incorrect_predictions, min(10, len(incorrect_predictions)))

# Display mismatched extractions
for doc_id, category in sample_cases:
    print(f"\n--- Document: {doc_id} | Category: {category} ---")
    print(f"Papermage Extracted: {papermage_data[doc_id].get(category, 'N/A')}")
    print(f"Ground Truth: {xml_data[doc_id].get(category, 'N/A')}")

In [None]:
import matplotlib.pyplot as plt

# Extract values
categories = list(macro_avg_results.keys())
precision_values = [macro_avg_results[cat]["precision"] for cat in categories]
recall_values = [macro_avg_results[cat]["recall"] for cat in categories]

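# Create grouped bar chart (side-by-side bars so precision and recall do not overlap)
x = list(range(len(categories)))
bar_width = 0.4
plt.figure(figsize=(10, 6))
plt.bar([i - bar_width / 2 for i in x], precision_values, width=bar_width, alpha=0.7, label="Precision")
plt.bar([i + bar_width / 2 for i in x], recall_values, width=bar_width, alpha=0.7, label="Recall")
plt.xticks(x, categories)
plt.xlabel("Categories")
plt.ylabel("Score")
plt.title("Precision vs Recall for Papermage Extraction")
plt.legend()
plt.show()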

# **Phase 3 - RAG Development and DEMO**

## **Setup & Install Dependencies**

In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import shutil
from tqdm import tqdm

# Define the source and destination paths
source_path = 'PLACEHOLDER'
destination_path = 'PLACEHOLDER'

# Define a function to copy the file with a progress bar
def copy_with_progress(src, dest):
    # Get the file size for the progress bar
    file_size = os.path.getsize(src)
    with open(src, 'rb') as fsrc, open(dest, 'wb') as fdest:
        with tqdm(total=file_size, unit='B', unit_scale=True, desc="Copying file") as pbar:
            while True:
                buffer = fsrc.read(1024 * 1024)  # Read in chunks of 1MB
                if not buffer:
                    break
                fdest.write(buffer)
                pbar.update(len(buffer))

# Copy the file
copy_with_progress(source_path, destination_path)

print("✅ File copied successfully!")

import zipfile
import os

zip_path = "PLACEHOLDER"  # Change to your ZIP file path
extract_path = "PLACEHOLDER"

# Extract ZIP file
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
    zip_ref.extractall(extract_path)

# Locate PDFs inside the extracted folder
pdf_folder = os.path.join(extract_path, "pdf_sample_10")  # Adjust if subfolder exists

print("✅ PDFs extracted to:", pdf_folder)

In [None]:
# Core Libraries
!pip install 'papermage[dev,predictors,visualizers]'
!apt-get install -y poppler-utils
!pip install --upgrade byaldi
!pip install flash-attn
!pip install gradio
!pip install accelerate einops

# RESTART RUNTIME AFTER IT COMPLETES EXECUTION

## **Index & Model Initialization**

In [None]:
# Import libraries
from pathlib import Path
from byaldi import RAGMultiModalModel
from papermage.recipes import CoreRecipe

# Path to folder of PDFs
pdf_folder = "PLACEHOLDER"
index_name = "document_index"

# Initialize ColQwen via Byaldi
rag_model = RAGMultiModalModel.from_pretrained("vidore/colqwen2-v1.0")
rag_model.index(input_path=pdf_folder, index_name=index_name, overwrite=True, store_collection_with_index=True)

In [None]:
# Import libraries
from transformers import AutoTokenizer, AutoModelForCausalLM
from sentence_transformers import SentenceTransformer

# Path to folder of PDFs
pdf_folder = "PLACEHOLDER"
index_name = "document_index"

# Load Qwen2.5-7B-Instruct
llm_model_name = "Qwen/Qwen2.5-7B-Instruct"
llm_model = AutoModelForCausalLM.from_pretrained(llm_model_name, device_map="auto", torch_dtype="auto")
llm_tokenizer = AutoTokenizer.from_pretrained(llm_model_name)

# Load sentence embedding model
sentence_encoder = SentenceTransformer("all-MiniLM-L6-v2")

In [None]:
# Run once for a sample PDF to trigger Papermage model instance
doc = CoreRecipe().run('/content/pdf_sample_10_documents/pdf_sample_10/1833349.1778852.pdf')

## **Helper Functions**

In [None]:
import re
from sentence_transformers import util
from papermage.visualizers import plot_entities_on_page

def ask_qwen(question, context, max_new_tokens=150):

    messages = [
        {
            "role": "system",
            "content": "A user asks a question based on the paper's context. Answer the question shortly and comprehensively. "
                       "The answer should be found in the context. If the answer is not found, say: 'The answer is not in the provided context'."
        },
        {
            "role": "user",
            "content": f"Question: {question}\n\nContext:\n{context}\n\nAnswer:"
        }
    ]
    prompt = llm_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = llm_tokenizer([prompt], return_tensors="pt").to(llm_model.device)

    output = llm_model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        eos_token_id=llm_tokenizer.eos_token_id
    )

    trimmed_ids = output[0][inputs.input_ids.shape[-1]:]
    return llm_tokenizer.decode(trimmed_ids, skip_special_tokens=True).strip()

def split_into_sentences(text):
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

def get_text_from_sentence(sentence):
    return sentence.text

def highlight_with_semantic_alignment(answer, page, threshold=0.70):
    from sentence_transformers import util

    answer_sentences = split_into_sentences(answer)
    answer_embeddings = sentence_encoder.encode(answer_sentences, convert_to_tensor=True)

    page_sentences = page.sentences
    page_texts = [s.text for s in page_sentences]
    page_embeddings = sentence_encoder.encode(page_texts, convert_to_tensor=True)

    matched_tokens = []

    for a_emb in answer_embeddings:
        sims = util.cos_sim(a_emb, page_embeddings)[0]
        for idx, score in enumerate(sims):
            if score >= threshold:
                matched_tokens.extend(page_sentences[idx].tokens)

    # Fallback if no page sentence reached the threshold: highlight the longest block
    if not matched_tokens:
        print("⚠️ No sentence reached the similarity threshold; falling back to the longest block.")
        if hasattr(page, "blocks") and page.blocks:
            longest_block = max(page.blocks, key=lambda b: len(b.tokens))
            matched_tokens = longest_block.tokens
        else:
            print("⚠️ No blocks available; fallback failed.")

    return list({id(t): t for t in matched_tokens}.values())
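
The alignment step above can be illustrated without loading any model. The following is a minimal sketch, assuming toy 2-D vectors in place of real sentence embeddings; `cosine` and `match_indices` are hypothetical helper names introduced only for this illustration, and the plain-Python cosine stands in for `util.cos_sim`. It repeats the sentence splitter as written and shows which page sentences would pass the 0.70 threshold.

```python
import math
import re

def split_into_sentences(text):
    # Same regex as the helper above: split on whitespace that follows ., ! or ?
    return [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]

def cosine(u, v):
    # Cosine similarity of two equal-length vectors (stand-in for util.cos_sim)
    dot = sum(x * y for x, y in zip(u, v))
    norm_u = math.sqrt(sum(x * x for x in u))
    norm_v = math.sqrt(sum(y * y for y in v))
    return dot / (norm_u * norm_v)

def match_indices(answer_vecs, page_vecs, threshold=0.70):
    # A page sentence is kept if ANY answer sentence reaches the threshold
    matched = set()
    for a in answer_vecs:
        for idx, p in enumerate(page_vecs):
            if cosine(a, p) >= threshold:
                matched.add(idx)
    return sorted(matched)

# Toy "embeddings" (hypothetical values, not produced by all-MiniLM-L6-v2)
answer_vecs = [[1.0, 0.0]]
page_vecs = [[1.0, 0.1], [0.0, 1.0], [1.0, 1.0]]

matched = match_indices(answer_vecs, page_vecs)
sentences = split_into_sentences("It works. Does it? Yes!")
print(matched)    # indices of page sentences that would be highlighted
print(sentences)
```

Two details are worth noting. The third toy vector has a cosine of about 0.707 with the answer vector, so it only just clears the 0.70 threshold; small changes to the threshold can noticeably change how much of the page is highlighted. The regex splitter also breaks after any period, so a string such as "See Fig. 2 for details." is split into two pieces.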

## **Core RAG Pipeline**

In [None]:
from pathlib import Path

def rag_qa(question: str):
    # Step 1: Retrieve most relevant page
    results = rag_model.search(question, k=1)
    top_result = results[0]
    doc_id = top_result["doc_id"]
    page_num = top_result["page_num"]

    doc_id_to_filename = rag_model.get_doc_ids_to_file_names()
    pdf_path = Path(pdf_folder) / doc_id_to_filename[doc_id]

    # Step 2: Extract tokens and sentences with Papermage
    doc = CoreRecipe().run(pdf_path)
    page = doc.pages[page_num - 1]
    page_text = " ".join([t.text for t in page.tokens])

    # Step 3: Ask Qwen
    answer = ask_qwen(question, page_text)

    # Step 4: Align answer back to page
    matched_tokens = highlight_with_semantic_alignment(answer, page)
    highlighted_image = plot_entities_on_page(page.images[0], matched_tokens, box_color="yellow", box_alpha=0.4)

    # Return answer, image, and paper reference
    doc_info = f"📄 Document: {pdf_path.name} (doc_id: {doc_id})"
    return answer, highlighted_image.pilimage, doc_info

In [None]:
from pathlib import Path

def rag_qa(question: str):
    try:
        results = rag_model.search(question, k=1)
        top_result = results[0]
        doc_id = top_result["doc_id"]
        page_num = top_result["page_num"]

        doc_id_to_filename = rag_model.get_doc_ids_to_file_names()
        full_path = Path(pdf_folder) / doc_id_to_filename[doc_id]
        pdf_filename = full_path.name  # ONLY filename like "2330784.2330981.pdf"

        doc = CoreRecipe().run(full_path)
        page = doc.pages[page_num - 1]
        page_text = " ".join(t.text for t in page.tokens)

        answer = ask_qwen(question, page_text)
        matched_tokens = highlight_with_semantic_alignment(answer, page)
        highlighted_image = plot_entities_on_page(page.images[0], matched_tokens, box_color="yellow", box_alpha=0.4)

        return answer, highlighted_image.pilimage, pdf_filename  # returns just filename

    except Exception as e:
        import traceback
        return f"❌ Error: {str(e)}\n\n{traceback.format_exc()}", None, "❌ Error loading document"
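
The `page_num - 1` indexing in the pipeline implies that the retriever reports 1-based page numbers. A minimal sketch of a bounds-checked conversion follows; `page_index` is a hypothetical helper, not part of Papermage or the retriever, and it assumes `doc.pages` behaves like a Python list, where an index of `-1` would silently select the last page.

```python
def page_index(page_num, n_pages):
    """Convert a 1-based page number to a 0-based index, with bounds checking."""
    if not 1 <= page_num <= n_pages:
        raise IndexError(f"page_num {page_num} is outside 1..{n_pages}")
    return page_num - 1

# Example: the first and last pages of a 5-page document
first = page_index(1, 5)
last = page_index(5, 5)
print(first, last)
```

Inside `rag_qa`, such a helper could be used as `doc.pages[page_index(page_num, len(doc.pages))]`, so that an unexpected page number raises an error instead of highlighting the wrong page.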


In [None]:
def rag_qa(question: str):
    try:
        # Step 1: Search for the most relevant page
        results = rag_model.search(question, k=1)
        top_result = results[0]
        retrieval_score = top_result["score"]

        # Step 1.5: Filter based on score
        MIN_RETRIEVAL_SCORE = 10.0  # Adjust as needed
        if retrieval_score < MIN_RETRIEVAL_SCORE:
            return (
                "The answer is not in the provided context.",
                None,
                "⚠️ No relevant page found (retrieval score too low)."
            )

        # Step 2: Locate and parse the PDF
        doc_id = top_result["doc_id"]
        page_num = top_result["page_num"]
        doc_id_to_filename = rag_model.get_doc_ids_to_file_names()
        pdf_filename = Path(doc_id_to_filename[doc_id]).name
        pdf_path = Path(pdf_folder) / pdf_filename

        from papermage.recipes import CoreRecipe
        doc = CoreRecipe().run(pdf_path)
        page = doc.pages[page_num - 1]
        page_text = " ".join([t.text for t in page.tokens])

        # Step 3: Generate answer
        answer = ask_qwen(question, page_text)

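        # Step 4: If Qwen says answer not found, skip highlighting.
        # Compare loosely: the refusal text in the system prompt has no trailing period.
        if answer.strip().rstrip(".").lower() == "the answer is not in the provided context":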
            return (
                answer,
                None,
                f"📄 Document: {pdf_filename} (doc_id: {doc_id})"
            )

        # Step 5: Highlight matched tokens
        matched_tokens = highlight_with_semantic_alignment(answer, page)
        highlighted_image = plot_entities_on_page(page.images[0], matched_tokens, box_color="yellow", box_alpha=0.4)

        # Step 6: Return outputs
        return (
            answer,
            highlighted_image.pilimage,
            f"📄 Document: {pdf_filename} (doc_id: {doc_id})"
        )

    except Exception as e:
        import traceback
        return f"❌ Error: {str(e)}\n\n{traceback.format_exc()}", None, "❌ Failed to complete retrieval"
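
The two gates in this version of `rag_qa` (the retrieval-score cutoff and the refusal check) can be factored into small, testable functions. This is a sketch under stated assumptions: `is_refusal` and `should_answer` are hypothetical names, the cutoff of 10.0 is the placeholder value from the cell above, and the normalization rules (case, surrounding quotes, trailing period) are one reasonable choice, not a verified description of how Qwen phrases its refusals.

```python
REFUSAL = "The answer is not in the provided context"
MIN_RETRIEVAL_SCORE = 10.0  # placeholder cutoff, same value as in rag_qa

def is_refusal(answer):
    # Ignore case, surrounding whitespace and quotes, and a trailing period
    normalized = answer.strip().strip("'\"").rstrip(".").strip().lower()
    return normalized == REFUSAL.lower()

def should_answer(retrieval_score, min_score=MIN_RETRIEVAL_SCORE):
    # Only pass the page to the LLM when retrieval is confident enough
    return retrieval_score >= min_score

print(is_refusal("The answer is not in the provided context."))
print(is_refusal("The answer is 42."))
print(should_answer(9.9), should_answer(10.0))
```

Normalizing before comparing matters because the system prompt quotes the refusal sentence without a final period, so an exact match against the version with a period can miss a literal echo of the prompt text.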

## **Gradio Interface**

In [None]:
import gradio as gr

gr.Interface(
    fn=rag_qa,
    inputs=gr.Textbox(label="Ask a question about the papers"),
    outputs=[
        gr.Textbox(label="Answer from Qwen"),
        gr.Image(type="pil", label="Highlighted PDF Page"),
        gr.Textbox(label="PDF Filename")  # Just the filename here
    ],
    title="📄 RAG PDF Q&A",
    description="Ask a question and see the page + filename that contains the answer.",
).launch(debug=True)