In [16]:
import numpy as np
import tritonclient.http as httpclient

from transformers import BertTokenizer, AutoTokenizer
from layoutlm.data.funsd import read_examples_from_file, convert_examples_to_features
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss

# Util functions

In [18]:
def get_labels(path):
    with open(path, "r") as f:
        labels = f.read().splitlines()
    if "O" not in labels:
        labels = ["O"] + labels
    return labels

# Prepare inputs

In [2]:
client = httpclient.InferenceServerClient('localhost:8000')

In [25]:
tokenizer = AutoTokenizer.from_pretrained('model/layoutlm-base-uncased/')
labels = get_labels('data/infer_data/labels.txt')
num_labels = len(labels)
pad_token_label_id = CrossEntropyLoss().ignore_index

In [26]:
examples = read_examples_from_file('data/infer_data/', 'test')
features = convert_examples_to_features(
    examples, labels, 512,
    tokenizer, cls_token_at_end=False, # xlnet has a cls token at the end
    cls_token=tokenizer.cls_token, cls_token_segment_id=0,
    sep_token=tokenizer.sep_token, sep_token_extra=False,
    # roberta uses an extra separator b/w pairs of sentences, cf. github.com/pytorch/fairseq/commit/1684e166e3da03f5b600dbb7855cb98ddfcd0805
    pad_on_left=False,
    # pad on the left for xlnet
    pad_token=tokenizer.convert_tokens_to_ids([tokenizer.pad_token])[0],
    pad_token_segment_id=0,
    pad_token_label_id=pad_token_label_id,
)

In [27]:
all_input_ids = torch.tensor(
    [f.input_ids for f in features], dtype=torch.long
)
all_input_mask = torch.tensor(
    [f.input_mask for f in features], dtype=torch.long
)
all_segment_ids = torch.tensor(
    [f.segment_ids for f in features], dtype=torch.long
)
all_label_ids = torch.tensor(
    [f.label_ids for f in features], dtype=torch.long
)
all_bboxes = torch.tensor([f.boxes for f in features], dtype=torch.long)

In [30]:
inputs= [
    httpclient.InferInput(
        "input_ids",
        [1, 512],
        "INT64"
    ),
    httpclient.InferInput(
        "attention_mask",
        [1, 512],
        "INT64"
    ),
    httpclient.InferInput(
        "token_type_ids",
        [1, 512],
        "INT64"
    ),
    httpclient.InferInput(
        "bbox",
        [1, 512, 4],
        "INT64"
    ),
]

outputs = [
    httpclient.InferRequestedOutput('output', binary_data=True)
]

In [32]:
inputs[0].set_data_from_numpy(all_input_ids[0].unsqueeze(0).numpy())
inputs[1].set_data_from_numpy(all_input_mask[0].unsqueeze(0).numpy())
inputs[2].set_data_from_numpy(all_segment_ids[0].unsqueeze(0).numpy())
inputs[3].set_data_from_numpy(all_bboxes[0].unsqueeze(0).numpy())

# Run inference

In [33]:
results = client.infer(
    'layoutlm_onnx',inputs, outputs=outputs,    
)

In [40]:
preds = results.as_numpy('output')[0]
preds = np.argmax(preds, axis=1)

In [42]:
out_label_ids = all_label_ids[0].numpy()
label_map = {i: label for i, label in enumerate(labels)}

In [43]:
out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]

for i in range(out_label_ids.shape[0]):
    if out_label_ids[i] != pad_token_label_id:
        out_label_list[i].append(label_map[out_label_ids[i]])
        preds_list[i].append(label_map[preds[i]])

In [44]:
out_label_list

[[],
 ['S-QUESTION'],
 ['S-QUESTION'],
 ['S-QUESTION'],
 ['S-QUESTION'],
 ['S-QUESTION'],
 ['O'],
 ['O'],
 ['O'],
 ['S-QUESTION'],
 ['S-QUESTION'],
 ['S-QUESTION'],
 [],
 [],
 ['S-QUESTION'],
 ['S-QUESTION'],
 ['S-ANSWER'],
 [],
 ['S-ANSWER'],
 [],
 ['O'],
 [],
 [],
 [],
 [],
 ['O'],
 ['O'],
 ['B-QUESTION'],
 ['E-QUESTION'],
 ['B-QUESTION'],
 ['E-QUESTION'],
 ['B-HEADER'],
 ['E-HEADER'],
 [],
 [],
 [],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 ['O'],
 [],
 [],
 [],
 ['O'],
 [],
 ['O'],
 [],
 ['O'],
 [],
 ['O'],
 [],
 [],
 ['B-ANSWER'],
 ['I-ANSWER'],
 [],
 ['E-ANSWER'],
 ['B-ANSWER'],
 ['I-ANSWER'],
 ['I-ANSWER'],
 ['I-ANSWER'],
 [],
 ['I-ANSWER'],
 ['I-ANSWER'],
 [],
 [],
 ['E-ANSWER'],
 [],
 ['B-QUESTION'],
 ['E-QUESTION'],
 ['B-ANSWER'],
 ['I-ANSWER'],
 ['E-ANSWER'],
 ['B-ANSWER'],
 [],
 [],
 ['I-ANSWER'],
 ['E-ANSWER'],
 ['B-ANSWER'],
 [],
 ['I-ANSWER'],
 [],
 ['E-ANSWER'],
 [],
 [],
 ['B-ANSWER'],
 ['E-ANSWER'],
 ['B-QUESTION'],
 ['E-QUESTION'],
 ['B