46 changes: 43 additions & 3 deletions docling_eval/cli/main.py
@@ -1,15 +1,23 @@
 import json
 import logging
 import os
+import sys
 from enum import Enum
 from pathlib import Path
-from typing import Annotated, Optional, Tuple
+from typing import Annotated, Dict, Optional, Tuple
 
 import typer
 from docling.datamodel.base_models import InputFormat
-from docling.datamodel.pipeline_options import PdfPipelineOptions
-from docling.document_converter import PdfFormatOption
+from docling.datamodel.pipeline_options import (
+    PaginatedPipelineOptions,
+    PdfPipelineOptions,
+    VlmPipelineOptions,
+    smoldocling_vlm_conversion_options,
+    smoldocling_vlm_mlx_conversion_options,
+)
+from docling.document_converter import FormatOption, PdfFormatOption
 from docling.models.factories import get_ocr_factory
+from docling.pipeline.vlm_pipeline import VlmPipeline
 from tabulate import tabulate  # type: ignore
 
 from docling_eval.datamodels.types import (
@@ -74,6 +82,7 @@ class PredictionProviderType(str, Enum):
     DOCLING = "docling"
     TABLEFORMER = "tableformer"
     FILE = "file"
+    SMOLDOCLING = "smoldocling"
 
 
 def log_and_save_stats(
@@ -184,6 +193,7 @@ def get_prediction_provider(
     file_source_path: Optional[Path] = None,
     file_prediction_format: Optional[PredictionFormats] = None,
 ):
+    pipeline_options: PaginatedPipelineOptions
     """Get the appropriate prediction provider with default settings."""
     if provider_type == PredictionProviderType.DOCLING:
         ocr_factory = get_ocr_factory()
@@ -211,6 +221,36 @@
             ignore_missing_predictions=True,
         )
 
+    elif provider_type == PredictionProviderType.SMOLDOCLING:
+        pipeline_options = VlmPipelineOptions()
+
+        pipeline_options.vlm_options = smoldocling_vlm_conversion_options
+        if sys.platform == "darwin":
+            try:
+                import mlx_vlm  # type: ignore
+
+                pipeline_options.vlm_options = smoldocling_vlm_mlx_conversion_options
+            except ImportError:
+                _log.warning(
+                    "To run SmolDocling faster, please install mlx-vlm:\n"
+                    "pip install mlx-vlm"
+                )
+
+        pdf_format_option = PdfFormatOption(
+            pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
+        )
+
+        format_options: Dict[InputFormat, FormatOption] = {
+            InputFormat.PDF: pdf_format_option,
+            InputFormat.IMAGE: pdf_format_option,
+        }
+
+        return DoclingPredictionProvider(
+            format_options=format_options,
+            do_visualization=True,
+            ignore_missing_predictions=True,
+        )
+
     elif provider_type == PredictionProviderType.TABLEFORMER:
         return TableFormerPredictionProvider(
             do_visualization=True,
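The changes to docling_eval/cli/main.py register a "smoldocling" provider type that routes PDF and image inputs through Docling's VLM pipeline, preferring the MLX variant of SmolDocling on macOS when mlx-vlm is installed. A minimal sketch of how the new branch can be exercised, assuming only the default arguments of get_prediction_provider shown in this diff; the sketch is illustrative and not part of the change:

    # Illustrative sketch, not part of the diff: request the SmolDocling-backed provider.
    # On macOS the MLX model options are used only if mlx-vlm is importable; otherwise a
    # warning is logged and the default SmolDocling conversion options apply.
    from docling_eval.cli.main import PredictionProviderType, get_prediction_provider

    provider = get_prediction_provider(PredictionProviderType.SMOLDOCLING)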
64 changes: 11 additions & 53 deletions tests/test_dataset_builder.py
@@ -1,19 +1,14 @@
 import os
 from pathlib import Path
-from typing import List, Optional
 
 import pytest
-from docling.datamodel.base_models import InputFormat
-from docling.datamodel.pipeline_options import (
-    EasyOcrOptions,
-    OcrOptions,
-    PdfPipelineOptions,
-    TableFormerMode,
-)
-from docling.document_converter import PdfFormatOption
-from docling.models.factories import get_ocr_factory
 
-from docling_eval.cli.main import evaluate, visualize
+from docling_eval.cli.main import (
+    PredictionProviderType,
+    evaluate,
+    get_prediction_provider,
+    visualize,
+)
 from docling_eval.datamodels.types import (
     BenchMarkNames,
     EvaluationModality,
@@ -33,55 +28,18 @@
     PubTabNetDatasetBuilder,
 )
 from docling_eval.dataset_builders.xfund_builder import XFUNDDatasetBuilder
-from docling_eval.prediction_providers.docling_provider import DoclingPredictionProvider
 from docling_eval.prediction_providers.file_provider import FilePredictionProvider
 from docling_eval.prediction_providers.tableformer_provider import (
     TableFormerPredictionProvider,
 )
 
-ocr_factory = get_ocr_factory()
-
 IS_CI = os.getenv("RUN_IN_CI") == "1"
 
 
-def create_docling_prediction_provider(
-    page_image_scale: float = 2.0,
-    do_ocr: bool = False,
-    ocr_lang: Optional[List[str]] = None,
-    ocr_engine: str = EasyOcrOptions.kind,
-    artifacts_path: Optional[Path] = None,
-):
-    ocr_options: OcrOptions = ocr_factory.create_options(  # type: ignore
-        kind=ocr_engine,
-    )
-    if ocr_lang is not None:
-        ocr_options.lang = ocr_lang
-
-    pipeline_options = PdfPipelineOptions(
-        do_ocr=do_ocr,
-        ocr_options=ocr_options,
-        do_table_structure=True,
-        artifacts_path=artifacts_path,
-    )
-
-    pipeline_options.table_structure_options.mode = TableFormerMode.ACCURATE
-
-    pipeline_options.images_scale = page_image_scale
-    pipeline_options.generate_page_images = True
-    pipeline_options.generate_picture_images = True
-
-    return DoclingPredictionProvider(
-        format_options={
-            InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options)
-        },
-        do_visualization=True,
-    )
-
-
 @pytest.mark.dependency()
 def test_run_dpbench_e2e():
     target_path = Path(f"./scratch/{BenchMarkNames.DPBENCH.value}/")
-    docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
+    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
 
     dataset_layout = DPBenchDatasetBuilder(
         target=target_path / "gt_dataset",
@@ -207,7 +165,7 @@ def test_run_doclaynet_with_doctags_fileprovider():
 )
 def test_run_omnidocbench_e2e():
     target_path = Path(f"./scratch/{BenchMarkNames.OMNIDOCBENCH.value}/")
-    docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
+    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
 
     dataset_layout = OmniDocBenchDatasetBuilder(
         target=target_path / "gt_dataset",
@@ -339,7 +297,7 @@ def test_run_omnidocbench_tables():
 )
 def test_run_doclaynet_v1_e2e():
     target_path = Path(f"./scratch/{BenchMarkNames.DOCLAYNETV1.value}/")
-    docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
+    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
 
     dataset_layout = DocLayNetV1DatasetBuilder(
         # prediction_provider=docling_provider,
@@ -390,7 +348,7 @@ def test_run_doclaynet_v1_e2e():
 @pytest.mark.skip("Test needs local data which is unavailable.")
 def test_run_doclaynet_v2_e2e():
     target_path = Path(f"./scratch/{BenchMarkNames.DOCLAYNETV2.value}/")
-    docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
+    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
 
     dataset_layout = DocLayNetV2DatasetBuilder(
         dataset_source=Path("/path/to/doclaynet_v2_benchmark"),
@@ -594,7 +552,7 @@ def test_run_docvqa_builder():
     )
 
     dataset_layout.save_to_disk()  # does all the job of iterating the dataset, making GT+prediction records, and saving them in shards as parquet.
-    docling_provider = create_docling_prediction_provider(page_image_scale=2.0)
+    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)
 
     docling_provider.create_prediction_dataset(
         name=dataset_layout.name,
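On the test side, the bespoke create_docling_prediction_provider helper is removed and every end-to-end test obtains its provider from the shared CLI factory. A sketch of the resulting pattern, with the imports and calls taken from the diff and the dataset-builder setup omitted because its full argument list is not shown here:

    # Sketch of the updated test setup: the provider comes from the CLI factory with its
    # default pipeline options instead of a hand-built PdfPipelineOptions instance.
    from pathlib import Path

    from docling_eval.cli.main import PredictionProviderType, get_prediction_provider
    from docling_eval.datamodels.types import BenchMarkNames

    target_path = Path(f"./scratch/{BenchMarkNames.DPBENCH.value}/")
    docling_provider = get_prediction_provider(PredictionProviderType.DOCLING)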