Doc classifier #37

Merged · 5 commits · Sep 26, 2022
6 changes: 3 additions & 3 deletions src/docquery/cmd/scan.py
@@ -36,7 +36,7 @@ def build_parser(subparsers, parent_parser):
     parser.add_argument(
         "--classify-checkpoint",
         default=None,
-        help=f"A custom model checkpoint to use (other than {PIPELINE_DEFAULTS['image-classification']})",
+        help=f"A custom model checkpoint to use (other than {PIPELINE_DEFAULTS['document-classification']})",
     )

     parser.set_defaults(func=main)
@@ -68,7 +68,7 @@ def main(args):
     log.info("Ready to start evaluating!")

     if args.classify:
-        classify = pipeline("image-classification", model=args.classify_checkpoint)
+        classify = pipeline("document-classification", model=args.classify_checkpoint)

     max_fname_len = max(len(str(p)) for p in paths)
     max_question_len = max(len(q) for q in args.questions) if len(args.questions) > 0 else 0
@@ -82,7 +82,7 @@ def main(args):
         if i > 0 and len(args.questions) > 1:
             print("")
         if args.classify:
-            cls = classify(images=d.preview[0])[0]
+            cls = classify(image=d.preview[0])[0]
             print(f"{str(p):<{max_fname_len}} Document Type: {cls['label']}")

         for q in args.questions:
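Taken together, these hunks route the scan command's --classify flag through the new document-classification task. A hypothetical invocation (the question, path, and label are illustrative; the flag and positional-argument handling is inferred from the argparse code above):

    $ docquery scan --classify "What is the total?" invoices/
    invoices/statement.pdf Document Type: invoice

The second line mirrors the print() call in the last hunk: the file name, then the predicted document type.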
329 changes: 329 additions & 0 deletions src/docquery/ext/pipeline_document_classification.py
@@ -0,0 +1,329 @@
# This file is copied from transformers:
# https://github.com/huggingface/transformers/blob/bb6f6d53386bf2340eead6a8f9320ce61add3e96/src/transformers/pipelines/image_classification.py
# And has been modified to support Donut
import re
from typing import List, Optional, Tuple, Union

import torch
from transformers.pipelines.base import PIPELINE_INIT_ARGS, ChunkPipeline
from transformers.pipelines.text_classification import ClassificationFunction, sigmoid, softmax
from transformers.utils import ExplicitEnum, add_end_docstrings, logging

from .pipeline_document_question_answering import ImageOrName, apply_tesseract
from .qa_helpers import TESSERACT_LOADED, VISION_LOADED, load_image


logger = logging.get_logger(__name__)


class ModelType(ExplicitEnum):
    Standard = "standard"
    VisionEncoderDecoder = "vision_encoder_decoder"


def donut_token2json(tokenizer, tokens, is_inner_value=False):
    """
    Convert a (generated) token sequence into an ordered JSON format.
    """
    output = dict()

    while tokens:
        start_token = re.search(r"<s_(.*?)>", tokens, re.IGNORECASE)
        if start_token is None:
            break
        key = start_token.group(1)
        end_token = re.search(rf"</s_{key}>", tokens, re.IGNORECASE)
        start_token = start_token.group()
        if end_token is None:
            tokens = tokens.replace(start_token, "")
        else:
            end_token = end_token.group()
            start_token_escaped = re.escape(start_token)
            end_token_escaped = re.escape(end_token)
            content = re.search(f"{start_token_escaped}(.*?){end_token_escaped}", tokens, re.IGNORECASE)
            if content is not None:
                content = content.group(1).strip()
                if r"<s_" in content and r"</s_" in content:  # non-leaf node
                    value = donut_token2json(tokenizer, content, is_inner_value=True)
                    if value:
                        if len(value) == 1:
                            value = value[0]
                        output[key] = value
                else:  # leaf nodes
                    output[key] = []
                    for leaf in content.split(r"<sep/>"):
                        leaf = leaf.strip()
                        if leaf in tokenizer.get_added_vocab() and leaf[0] == "<" and leaf[-2:] == "/>":
                            leaf = leaf[1:-2]  # for categorical special tokens
                        output[key].append(leaf)
                    if len(output[key]) == 1:
                        output[key] = output[key][0]

            tokens = tokens[tokens.find(end_token) + len(end_token) :].strip()
            if tokens[:6] == r"<sep/>":  # non-leaf nodes
                return [output] + donut_token2json(tokenizer, tokens[6:], is_inner_value=True)

    if len(output):
        return [output] if is_inner_value else output
    else:
        return [] if is_inner_value else {"text_sequence": tokens}
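As a quick sketch of what this parser produces, consider a stub tokenizer (hypothetical; a real Donut tokenizer supplies get_added_vocab() with its own special tokens):

    class StubTokenizer:
        # Minimal stand-in exposing the one method donut_token2json uses
        def get_added_vocab(self):
            return {"<advertisement/>": 57525}

    donut_token2json(StubTokenizer(), "<s_class><advertisement/></s_class>")
    # -> {"class": "advertisement"}; the added "<advertisement/>" token is unwrapped to its name

A sequence with no recognizable <s_...> markers instead falls through to {"text_sequence": ...}.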


DEFAULT_CLS_BBOX = [1000, 1000, 1000, 1000]
DEFAULT_SEP_BBOX = [0, 0, 0, 0]
DEFAULT_PAD_BBOX = [0, 0, 0, 0]
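# LayoutLM-style models take word boxes normalized to a 0-1000 grid; these sentinel
# boxes mark the special tokens (CLS, SEP, PAD) that have no position on the page.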


@add_end_docstrings(PIPELINE_INIT_ARGS)
class DocumentClassificationPipeline(ChunkPipeline):
    """
    Document classification pipeline using any `AutoModelForDocumentClassification`. This pipeline predicts the class
    of a document.

    This document classification pipeline can currently be loaded from [`pipeline`] using the following task
    identifier: `"document-classification"`.

    See the list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=document-classification).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if self.model.config.__class__.__name__ == "VisionEncoderDecoderConfig":
            self.model_type = ModelType.VisionEncoderDecoder
        else:
            self.model_type = ModelType.Standard

    def _sanitize_parameters(
        self,
        doc_stride=None,
        lang: Optional[str] = None,
        tesseract_config: Optional[str] = None,
        max_seq_len=None,
        function_to_apply=None,
        top_k=None,

Reviewer: I took a quick look at huggingface's TextClassificationPipeline and it looks like they set the default value for top_k to an empty string, as None is reserved for something else. Not sure if that would make sense for our use case, but if you wanted to more closely follow huggingface's patterns: https://github.com/huggingface/transformers/blob/main/src/transformers/pipelines/text_classification.py#L77

Author: Interesting, thanks for pointing that out. I think that is some legacy code that is probably still around for back compat. I mostly followed the newer image pipeline, which I think we should do here too.

Author: I set the default to 1 to roughly match this behavior.

    ):
        preprocess_params, postprocess_params = {}, {}
        if doc_stride is not None:
            preprocess_params["doc_stride"] = doc_stride
        if max_seq_len is not None:
            preprocess_params["max_seq_len"] = max_seq_len
        if lang is not None:
            preprocess_params["lang"] = lang
        if tesseract_config is not None:
            preprocess_params["tesseract_config"] = tesseract_config

        if isinstance(function_to_apply, str):
            function_to_apply = ClassificationFunction[function_to_apply.upper()]

        if function_to_apply is not None:
            postprocess_params["function_to_apply"] = function_to_apply

        if top_k is not None:
            if top_k < 1:
                raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
            postprocess_params["top_k"] = top_k

        return preprocess_params, {}, postprocess_params
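The upshot of the defaults discussed above: callers get a single best label unless they ask for more. A hypothetical sketch using docquery's pipeline helper (checkpoint and file name illustrative):

    classify = pipeline("document-classification")
    classify(image="scan.png")           # one-element list of {"label", "score"}; top_k defaults to 1
    classify(image="scan.png", top_k=3)  # up to three labels, sorted by descending score
    classify(image="scan.png", top_k=0)  # ValueError: top_k parameter should be >= 1 (got 0)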

    def __call__(self, image: Union[ImageOrName, List[ImageOrName], List[Tuple]], **kwargs):

Reviewer: Super nit: Huggingface uses the identifier images for the first input param in the call method of similar pipelines, so if we want to be consistent with them we should change the variable name here.

Author: I'm going to keep it as is to be consistent with the document question answering pipeline.

        """
        Assign labels to the document(s) passed as inputs.

        # TODO
        """
        if isinstance(image, list):
            normalized_images = (i if isinstance(i, (tuple, list)) else (i, None) for i in image)
        else:
            normalized_images = [(image, None)]

        return super().__call__({"pages": normalized_images}, **kwargs)
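Concretely, __call__ normalizes every input into (image, word_boxes) pairs under a single "pages" key, so all of these are accepted (paths hypothetical; word_boxes assumed to be a list of (word, box) pairs):

    classify(image="invoice.png")                # one single-page document
    classify(image=["page1.png", "page2.png"])   # one multi-page document; scores are averaged across pages
    classify(image=[("page1.png", word_boxes)])  # pre-computed OCR results, so Tesseract is skipped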

    def preprocess(
        self,
        input,
        doc_stride=None,
        max_seq_len=None,
        word_boxes: Tuple[str, List[float]] = None,
        lang=None,
        tesseract_config="",
    ):
        # NOTE: This code mirrors the code in question answering and will be implemented in a follow up PR
        # to support documents with enough tokens that overflow the model's window
        if max_seq_len is None:
            # TODO: LayoutLM's stride is 512 by default. Is it ok to use that as the min
            # instead of 384 (which the QA model uses)?
            max_seq_len = min(self.tokenizer.model_max_length, 512)

Reviewer: Replace magic numbers like 512 with global variables defined at the top of the file.

Author: I'm going to keep this as is for now to stay consistent with the question answering pipeline (it's basically the same code). I need to refactor that code to incorporate changes from transformers soon, so I'll try to fix both simultaneously.

Reviewer: 👍 That works too


        if doc_stride is None:
            doc_stride = min(max_seq_len // 2, 256)

Reviewer: Just curious, what's the reason for using min(...) here rather than just saying doc_stride = max_seq_len // 2? The latter seems simpler, and it would remove a hardcoded numeric value from the code.

Author: This code is borrowed from the question_answering pipeline in huggingface.
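For the defaults this works out to (a worked example, assuming tokenizer.model_max_length == 512):

    max_seq_len = min(512, 512)        # 512
    doc_stride = min(512 // 2, 256)    # 256; the cap only matters if a caller passes max_seq_len > 512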


        for page_idx, (image, word_boxes) in enumerate(input["pages"]):
            image_features = {}
            if image is not None:
                if not VISION_LOADED:
                    raise ValueError(
                        "If you provide an image, then the pipeline will process it with PIL (Pillow), but"
                        " PIL is not available. Install it with pip install Pillow."
                    )
                image = load_image(image)
                if self.feature_extractor is not None:
                    image_features.update(self.feature_extractor(images=image, return_tensors=self.framework))

            words, boxes = None, None
            if self.model_type != ModelType.VisionEncoderDecoder:
                if word_boxes is not None:
                    words = [x[0] for x in word_boxes]
                    boxes = [x[1] for x in word_boxes]
                elif "words" in image_features and "boxes" in image_features:
                    words = image_features.pop("words")[0]
                    boxes = image_features.pop("boxes")[0]
                elif image is not None:
                    if not TESSERACT_LOADED:

Reviewer: imo this if statement and the next one should both be combined/replaced by if TESSERACT_LOADED: ... else: .... This check could also be put inside apply_tesseract(), which would reduce the number of TESSERACT_LOADED checks we do throughout the repo (especially as more pipelines are added in the future).

Author: Going to keep this as is for consistency with document question answering.

                        raise ValueError(
                            "If you provide an image without word_boxes, then the pipeline will run OCR using"
                            " Tesseract, but pytesseract is not available. Install it with pip install pytesseract."
                        )
                    if TESSERACT_LOADED:
                        words, boxes = apply_tesseract(image, lang=lang, tesseract_config=tesseract_config)
                else:
                    raise ValueError(
                        "You must provide an image or word_boxes. If you provide an image, the pipeline will"
                        " automatically run OCR to derive words and boxes"
                    )

            if self.tokenizer.padding_side != "right":
                raise ValueError(
                    "Document classification only supports tokenizers whose padding side is 'right', not"
                    f" {self.tokenizer.padding_side}"
                )

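            # Donut conditions generation on a task prompt; the "<s_rvlcdip>" token below asks
            # the decoder for an RVL-CDIP-style document class.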
            if self.model_type == ModelType.VisionEncoderDecoder:
                encoding = {
                    "inputs": image_features["pixel_values"],
                    "max_length": self.model.decoder.config.max_position_embeddings,
                    "decoder_input_ids": self.tokenizer(
                        "<s_rvlcdip>",
                        add_special_tokens=False,
                        return_tensors=self.framework,
                    ).input_ids,
                    "return_dict_in_generate": True,
                }
                yield {
                    **encoding,
                    "page": None,
                }
            else:
                encoding = self.tokenizer(
                    text=words,
                    max_length=max_seq_len,
                    stride=doc_stride,
                    return_token_type_ids=True,
                    is_split_into_words=True,
                    truncation=True,
                    return_overflowing_tokens=True,
                )

                num_spans = len(encoding["input_ids"])

Reviewer: nit: Sometimes we access members of encoding by dict keys (i.e. encoding["input_ids"]), and sometimes by attribute (i.e. encoding.input_ids). I think we should pick one or the other, for code consistency.

Author: I'm going to leave this as is for consistency with the document question answering pipeline.


                for span_idx in range(num_spans):
                    if self.framework == "pt":
                        span_encoding = {k: torch.tensor(v[span_idx : span_idx + 1]) for (k, v) in encoding.items()}
                        span_encoding.update(
                            {k: v for (k, v) in image_features.items()}
                        )  # TODO: Verify cardinality is correct
                    else:
                        raise ValueError("Unsupported: Tensorflow preprocessing for DocumentClassification")

                    # For each span, place the [1000, 1000, 1000, 1000] box on CLS tokens, the [0, 0, 0, 0] box on
                    # SEP and PAD tokens, and the word's own bounding box on words from the original document.
                    bbox = []
                    for i, s, w in zip(
                        encoding.input_ids[span_idx],
                        encoding.sequence_ids(span_idx),
                        encoding.word_ids(span_idx),
                    ):
                        if i == self.tokenizer.cls_token_id:
                            bbox.append(DEFAULT_CLS_BBOX)
                        elif i == self.tokenizer.sep_token_id:
                            bbox.append(DEFAULT_SEP_BBOX)
                        elif i == self.tokenizer.pad_token_id:
                            bbox.append(DEFAULT_PAD_BBOX)
                        else:
                            bbox.append(boxes[w])

                    span_encoding["bbox"] = torch.tensor(bbox).unsqueeze(0)

                    yield {
                        **span_encoding,
                        "page": page_idx,
                    }

    def _forward(self, model_inputs):
        page = model_inputs.pop("page", None)

        if "overflow_to_sample_mapping" in model_inputs:
            model_inputs.pop("overflow_to_sample_mapping")

        if self.model_type == ModelType.VisionEncoderDecoder:
            model_outputs = self.model.generate(**model_inputs)
        else:
            model_outputs = self.model(**model_inputs)

Reviewer: nit: It's a bit confusing that model_outputs in this method is a ModelOutput object, but model_outputs in the postprocess() method is a list of ModelOutputs. If I were writing this I would tweak the variable names to try and reflect this, but I won't push too hard on this 😄

Author: I'll change it -- I think that is valid feedback.


        model_outputs["page"] = page
        model_outputs["attention_mask"] = model_inputs.get("attention_mask", None)
        return model_outputs

    def postprocess(self, model_outputs, function_to_apply=None, top_k=1, **kwargs):
        if function_to_apply is None:
            if self.model.config.num_labels == 1:
                function_to_apply = ClassificationFunction.SIGMOID
            elif self.model.config.num_labels > 1:
                function_to_apply = ClassificationFunction.SOFTMAX
            elif hasattr(self.model.config, "function_to_apply") and function_to_apply is None:
                function_to_apply = self.model.config.function_to_apply
            else:
                function_to_apply = ClassificationFunction.NONE

        if self.model_type == ModelType.VisionEncoderDecoder:
            answers = self.postprocess_encoder_decoder(model_outputs, top_k=top_k, **kwargs)
        else:
            answers = self.postprocess_standard(
                model_outputs, function_to_apply=function_to_apply, top_k=top_k, **kwargs
            )

        answers = sorted(answers, key=lambda x: x.get("score", 0), reverse=True)[:top_k]

Reviewer: Will this break if top_k > len(answers)?

Author: I don't think so; slicing past the end of a list just returns the whole list:

> a = [1,2,3]
> a[:10]
# [1, 2, 3]

        return answers

    def postprocess_encoder_decoder(self, model_outputs, **kwargs):
        classes = set()
        for model_output in model_outputs:
            for sequence in self.tokenizer.batch_decode(model_output.sequences):
                sequence = sequence.replace(self.tokenizer.eos_token, "").replace(self.tokenizer.pad_token, "")
                sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()  # remove first task start token
                classes.add(donut_token2json(self.tokenizer, sequence)["class"])

        # Return the unique classes we see; top_k sorting and slicing happen in postprocess()
        return [{"label": v} for v in classes]
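Tracing the Donut branch on a hypothetical decoded sequence:

    # batch_decode          -> "<s_rvlcdip><s_class><invoice/></s_class></s>"
    # strip eos/pad         -> "<s_rvlcdip><s_class><invoice/></s_class>"
    # drop first task token -> "<s_class><invoice/></s_class>"
    # donut_token2json(tokenizer, ...)["class"] -> "invoice"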

    def postprocess_standard(self, model_outputs, function_to_apply, **kwargs):
        # Average the score across pages
        sum_scores = {k: 0 for k in self.model.config.id2label.values()}
        for model_output in model_outputs:
            outputs = model_output["logits"][0]
            outputs = outputs.numpy()

            if function_to_apply == ClassificationFunction.SIGMOID:
                scores = sigmoid(outputs)
            elif function_to_apply == ClassificationFunction.SOFTMAX:
                scores = softmax(outputs)
            elif function_to_apply == ClassificationFunction.NONE:
                scores = outputs
            else:
                raise ValueError(f"Unrecognized `function_to_apply` argument: {function_to_apply}")

            for i, score in enumerate(scores):
                sum_scores[self.model.config.id2label[i]] += score.item()

        return [{"label": label, "score": score / len(model_outputs)} for (label, score) in sum_scores.items()]
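Putting the pieces together, a minimal end-to-end sketch (assuming docquery re-exports pipeline and registers a default document-classification checkpoint; file names are hypothetical):

    from docquery import pipeline

    classify = pipeline("document-classification")

    # Single page: OCR if needed, span encoding, then classification
    print(classify(image="statement_p1.png"))

    # Multi-page document: per-page scores are averaged in postprocess_standard
    print(classify(image=["statement_p1.png", "statement_p2.png"], top_k=2))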
3 changes: 1 addition & 2 deletions src/docquery/ext/pipeline_document_question_answering.py
@@ -197,7 +197,7 @@ def __call__(
             image = image["image"]

         if isinstance(image, list):
-            normalized_images = (i if isinstance(i, tuple) or isinstance(i, list) else (i, None) for i in image)
+            normalized_images = (i if isinstance(i, (tuple, list)) else (i, None) for i in image)
         else:
             normalized_images = [(image, None)]

@@ -297,7 +297,6 @@ def preprocess(
                     padding=padding,
                     max_length=max_seq_len,
                     stride=doc_stride,
-                    return_token_type_ids=True,
                     is_split_into_words=True,
                     truncation="only_second",
                     return_overflowing_tokens=True,