In [1]:
# Reload edited modules automatically before each cell runs, so changes to the
# imported project package take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
from pathlib import Path

from transformers import AutoProcessor, LayoutLMv3ForTokenClassification

from mozilla_sec_eia.ex_21.inference import (
    create_inference_dataset,
    perform_inference
)
from mozilla_sec_eia.utils.pdf import (
    draw_boxes_on_img,
    unnormalize_box
)

In [3]:
# Whether the data comes with labels. The flag is passed to both
# create_inference_dataset and perform_inference, and each visualization
# section below asserts the value it expects.
has_labels = False

In [4]:
# Tag vocabulary handed to the token classifier; a tag's position in the list
# is its integer id.
# NOTE(review): there is a "B-Own_Per" tag but no "I-Own_Per" — confirm this is intentional.
label_list = [
    "O",
    "B-Subsidiary",
    "I-Subsidiary",
    "B-Loc",
    "I-Loc",
    "B-Own_Per",
]
id2label = dict(enumerate(label_list))
label2id = {tag: idx for idx, tag in id2label.items()}

In [5]:
# Local directory that LayoutLMv3ForTokenClassification.from_pretrained loads the model from.
# NOTE(review): this comment used to say "load_model function loads a model with mlflow",
# but no load_model/mlflow call appears in this notebook — confirm which loading path is intended.
model_path = Path("../models/layoutlm_v1_50_labeled_docs")

In [6]:
# Load the token-classification model from the local directory, attaching the
# tag <-> id maps defined above to its config.
model = LayoutLMv3ForTokenClassification.from_pretrained(
    model_path,
    id2label=id2label,
    label2id=label2id,
)

In [8]:
# OCR is switched off in the processor: words and boxes are supplied explicitly
# wherever the processor is called in this notebook.
processor = AutoProcessor.from_pretrained(
    "microsoft/layoutlmv3-base",
    apply_ocr=False,
)

In [9]:
# Directory of input PDFs; passed to both create_inference_dataset and perform_inference.
pdf_dir = Path("../sec10k_filings/pdfs")

In [10]:
# Only necessary if using data with labels. The path is nevertheless passed to
# create_inference_dataset and perform_inference whatever has_labels is set to.
labeled_json_dir = Path("../sec10k_filings/labeled_jsons_v0.1")

# Create a Dataset and Perform Inference

In [11]:
# Build the dataset the visualization cells index into.
dataset = create_inference_dataset(
    pdfs_dir=pdf_dir,
    labeled_json_dir=labeled_json_dir,
    has_labels=has_labels,
)

In [12]:
# Restrict inference to the first three examples of the dataset.
dataset_index = list(range(3))

In [26]:
# This call is slow. Unverified suspicion: most of the time goes to checking
# that the PDFs and JSONs are cached — profile before changing anything.
# NOTE(review): the output saved with this cell shows a `df.merge(...)` line
# followed by "TypeError: sequence item 0: expected str instance, float found";
# the cells below need pred_list, so that error has to be resolved first.
logit_list, pred_list, output_dfs = perform_inference(
    pdf_dir,
    model,
    processor,
    dataset_index,
    labeled_json_dir,
    has_labels,
)

  df = df.merge(words_df, how="left", on=BBOX_COLS).drop_duplicates(


TypeError: sequence item 0: expected str instance, float found

# Visualize without Labels

In [None]:
# Guard: the cells in this section are meant to run only when has_labels is False.
assert not has_labels

In [None]:
def visual_inputs():
    """Yield ``(predictions, encoding, image)`` for each example inference ran on.

    Each entry of ``pred_list`` is paired with the dataset example at the
    matching position of ``dataset_index`` (the indices handed to
    ``perform_inference``) rather than with a running counter, so the pairing
    stays correct when ``dataset_index`` is not ``[0, 1, 2, ...]``.

    NOTE(review): assumes ``pred_list`` holds one entry per index in
    ``dataset_index``, in the same order — confirm against ``perform_inference``.

    Yields:
        tuple: the predictions for one example, the processor output for that
        example's image/tokens/bboxes (requested with ``return_tensors="pt"``,
        truncation, and ``max_length`` padding), and the example's image.
    """
    for example_idx, predictions in zip(dataset_index, pred_list):
        example = dataset[example_idx]
        image = example["image"]
        encoding = processor(
            image,
            example["tokens"],
            boxes=example["bboxes"],
            return_tensors="pt",
            truncation=True,
            padding="max_length",
        )
        yield predictions, encoding, image


# Lazy: nothing is encoded until next(gen) is called.
gen = visual_inputs()

In [None]:
# Pull the next (predictions, encoding, image) triple from the generator.
predictions, encoding, image = next(gen)
width, height = image.size

# Translate predicted class ids into their string tags.
id_to_tag = model.config.id2label
preds = [id_to_tag[pred] for pred in predictions]

# Convert each encoded box with unnormalize_box, using the image's width and height.
token_boxes = encoding.bbox.squeeze().tolist()
boxes = [unnormalize_box(token_box, width, height) for token_box in token_boxes]

image = image.convert("RGB")
draw_boxes_on_img(preds, boxes, image, width, height)
image

# Visualize with Labels

In [None]:
# Guard: the cells in this section are meant to run only when has_labels is True.
assert has_labels

In [None]:
def convert_ner_tags_to_id(ner_tags):
    """Look up every tag in ``label2id`` and return the ids as a list of ints."""
    tag_ids = []
    for tag in ner_tags:
        tag_ids.append(int(label2id[tag]))
    return tag_ids

def visual_inputs_with_labels():
    """Yield ``(predictions, encoding, image)`` per example, with labels encoded.

    Each entry of ``pred_list`` is paired with the dataset example at the
    matching position of ``dataset_index`` (the indices handed to
    ``perform_inference``) rather than with a running counter, so the pairing
    stays correct when ``dataset_index`` is not ``[0, 1, 2, ...]``.

    NOTE(review): assumes ``pred_list`` holds one entry per index in
    ``dataset_index``, in the same order — confirm against ``perform_inference``.

    Yields:
        tuple: the predictions for one example, the processor output for that
        example's image/tokens/bboxes with its ``ner_tags`` (converted to ids)
        passed as ``word_labels``, and the example's image.
    """
    for example_idx, predictions in zip(dataset_index, pred_list):
        example = dataset[example_idx]
        image = example["image"]
        encoding = processor(
            image,
            example["tokens"],
            boxes=example["bboxes"],
            word_labels=convert_ner_tags_to_id(example["ner_tags"]),
            return_tensors="pt",
            truncation=True,
            padding="max_length",
        )
        yield predictions, encoding, image


# Lazy: nothing is encoded until next(gen_w_labels) is called.
gen_w_labels = visual_inputs_with_labels()

In [None]:
# TODO: also add visualization for wrong predictions
predictions, encoding, image = next(gen_w_labels)
width, height = image.size
token_boxes = encoding.bbox.squeeze().tolist()
boxes = [unnormalize_box(token_box, width, height) for token_box in token_boxes]
labels = encoding.labels.squeeze().tolist()

# Positions whose label equals -100 are left out of everything built below
# (presumably the processor's marker for tokens to ignore — TODO confirm).
IGNORED_LABEL = -100
id_to_tag = model.config.id2label

true_predictions = [
    id_to_tag[pred]
    for pred, label in zip(predictions, labels)
    if label != IGNORED_LABEL
]
# Tags of the kept labels; not used by the drawing call below.
true_labels = [
    id_to_tag[label]
    for _, label in zip(predictions, labels)
    if label != IGNORED_LABEL
]
true_boxes = [
    unnormalize_box(token_box, width, height)
    for token_box, label in zip(token_boxes, labels)
    if label != IGNORED_LABEL
]

image = image.convert("RGB")
draw_boxes_on_img(true_predictions, true_boxes, image, width, height)
image