Skip to content

Commit

Permalink
Add basic end-to-end tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ankrgyl committed Sep 6, 2022
1 parent 9f9fba8 commit 7a862d9
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
max-line-length = 119
-ignore = E402, E203
+ignore = E402, E203, E501
5 changes: 4 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,12 @@ ${VENV_PRE_COMMIT}: ${VENV_PYTHON_PACKAGES}
bash -c 'source venv/bin/activate && pre-commit install'
@touch $@

-.PHONY: develop fixup
+.PHONY: develop fixup test
develop: ${VENV_PRE_COMMIT}
@echo 'Run "source venv/bin/activate" to enter development mode'

fixup:
pre-commit run --all-files

test:
python -m pytest -s -v ./tests/
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
"flake8-isort",
"isort==5.10.1",
"pre-commit",
"pytest",
"twine",
],
"donut": [
Expand Down
92 changes: 92 additions & 0 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from pathlib import Path
from typing import Any, Dict, List

import pytest
from pydantic import BaseModel
from transformers.testing_utils import nested_simplify

from docquery.document import load_document
from docquery.pipeline import get_pipeline


# Checkpoint identifiers handed to get_pipeline(), keyed by the model label
# that also indexes the expected answers of every QA pair in this file.
CHECKPOINTS = dict(
    LayoutLMv1="impira/layoutlm-document-qa",
    Donut="naver-clova-ix/donut-base-finetuned-docvqa",
)


class QAPair(BaseModel):
    """A question together with the answer expected from each model."""

    # Text of the question put to the pipeline.
    question: str
    # Expected pipeline response, keyed by model label (a key of CHECKPOINTS).
    answers: Dict[str, Dict]


class Example(BaseModel):
    """A sample document and the question/answer pairs checked against it."""

    # Short label identifying the example.
    name: str
    # Location handed to load_document(); the examples in this file use URLs.
    path: str
    # Questions asked about the document, each with its expected answers.
    qa_pairs: List[QAPair]


# Sample documents are taken from the DocQuery space, which doubles as their
# hosting. Each path is pinned to a specific revision hash in the URL.
#
# For every question, "answers" holds the expected pipeline response per model
# label (see CHECKPOINTS). The test compares these against the response after
# it has been simplified to 4 decimals.
EXAMPLES = [
    # Contract (JPEG image).
    Example(
        name="contract",
        path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/contract.jpeg",
        qa_pairs=[
            {
                "question": "What is the purchase amount?",
                "answers": {
                    "LayoutLMv1": {
                        "score": 0.9999,
                        "answer": "$1,000,000,000",
                        "start": 97,
                        "end": 97,
                        "page": 0,
                    },
                    "Donut": {
                        "answer": "$1,0000,000,00",
                    },
                },
            },
        ],
    ),
    # Invoice (PNG image).
    Example(
        name="invoice",
        path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/invoice.png",
        qa_pairs=[
            {
                "question": "What is the invoice number?",
                "answers": {
                    "LayoutLMv1": {
                        "score": 0.9997,
                        "answer": "us-001",
                        "start": 15,
                        "end": 15,
                        "page": 0,
                    },
                    "Donut": {
                        "answer": "us-001",
                    },
                },
            },
        ],
    ),
    # Financial statement (PDF).
    Example(
        name="statement",
        path="https://huggingface.co/spaces/impira/docquery/resolve/2f6c96314dc84dfda62d40de9da55f2f5165d403/statement.pdf",
        qa_pairs=[
            {
                "question": "What are net sales for 2020?",
                "answers": {
                    "LayoutLMv1": {
                        "score": 0.9429,
                        "answer": "$ 3,750\n",
                        "start": 15,
                        "end": 16,
                        "page": 0,
                    },
                    "Donut": {
                        "answer": "$ 3,750",
                    },
                },
            },
        ],
    ),
]


# Label each case with the example's name so a failing test ID identifies the
# document, instead of pytest's positional fallback (example0, example1, ...).
@pytest.mark.parametrize("example", EXAMPLES, ids=lambda example: example.name)
@pytest.mark.parametrize("model", list(CHECKPOINTS))
def test_impira_dataset(example, model):
    """Check a model's top answer for every question of one example document.

    Args:
        example: The ``Example`` under test (document location plus QA pairs).
        model: A key of ``CHECKPOINTS`` selecting which checkpoint to load.

    The document is loaded from ``example.path`` and a pipeline is built for
    the selected checkpoint. Each question is asked with ``top_k=1``, and the
    response, passed through ``nested_simplify`` with ``decimals=4``, must
    equal the expected answer recorded for ``model``.
    """
    document = load_document(example.path)
    pipeline = get_pipeline(checkpoint=CHECKPOINTS[model])
    for qa in example.qa_pairs:
        resp = pipeline(question=qa.question, **document.context, top_k=1)
        assert nested_simplify(resp, decimals=4) == qa.answers[model]

0 comments on commit 7a862d9

Please sign in to comment.