In [1]:
import torch
from modeling_layoutlm import LayoutLMForTokenClassification
from transformers import (
    BertConfig,
    BertTokenizer,
)
from utils_docvqa import (
    read_docvqa_examples,
    convert_examples_to_features)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers.data.processors.squad import SquadResult
from tqdm import tqdm
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from tabulate import tabulate

In [None]:
import json

# Load the training annotations from ./train.json. The context manager
# guarantees the file handle is closed as soon as parsing finishes, instead of
# leaving it open until the garbage collector happens to reclaim it.
with open("./train.json") as train_file:
    train_data = json.load(train_file)

In [None]:
# Keep only the first annotation whose image_id is "ffdh0227_1"; `ans` stays
# an empty list when no record matches.
first_match = next(
    (record for record in train_data if record["image_id"] == "ffdh0227_1"),
    None,
)
ans = [first_match] if first_match is not None else []

In [None]:
# Show the selected record(s) as the cell output for a quick visual check.
ans

In [None]:
# Persist the selected record(s) as pretty-printed JSON (4-space indent) in
# ans.json in the current working directory.
serialized_ans = json.dumps(ans, indent=4)
with open('ans.json', 'w') as out_file:
    out_file.write(serialized_ans)

In [29]:
# --- Paths -------------------------------------------------------------------
MODEL_FOLDER = "./models/"   # directory the config, model and tokenizer are loaded from
SAMPLE_DATA = "./ans.json"   # JSON file handed to read_docvqa_examples

# --- Labels ------------------------------------------------------------------
# The same label list is kept under both spellings; `labels` is an independent
# copy, not an alias of LABELS.
LABELS = ["start", "end"]
labels = list(LABELS)
# Forwarded to convert_examples_to_features as pad_token_label_id.
# NOTE(review): -100 presumably matches the ignore_index default of PyTorch's
# cross-entropy loss - confirm against utils_docvqa.
pad_token_label_id = -100

# --- Sequence windowing (passed to convert_examples_to_features) -------------
doc_stride = 128
max_query_length = 64
max_seq_length = 512

In [30]:
# Inference runs on CPU. To use a GPU instead, point `device` at e.g.
# torch.device("cuda:0") (optionally followed by torch.cuda.set_device(device)).
device = torch.device("cpu")

model_class = LayoutLMForTokenClassification
config_class = BertConfig
tokenizer_class = BertTokenizer

# Configuration read from MODEL_FOLDER with num_labels overridden to 2, which
# matches the two-entry label list (start / end) used by this notebook.
config = config_class.from_pretrained(MODEL_FOLDER, num_labels=2, cache_dir=None)

# Pass the config explicitly so the num_labels=2 override actually reaches the
# model; without `config=config` the object built above is never used.
model = model_class.from_pretrained(MODEL_FOLDER, config=config)

# Tokenizer loaded from the same folder with lower-casing enabled.
tokenizer = tokenizer_class.from_pretrained(MODEL_FOLDER, do_lower_case=True)

In [31]:
# Read the sample file through the project's DocVQA reader with
# is_training=False.
examples = read_docvqa_examples(
    SAMPLE_DATA,
    is_training=False,
)

In [32]:
# Turn the examples into model-ready features. Windowing is controlled by
# max_seq_length / doc_stride / max_query_length; is_training=False as in the
# example-reading step.
features = convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    label_list=labels,
    pad_token_label_id=pad_token_label_id,
    max_seq_length=max_seq_length,
    max_query_length=max_query_length,
    doc_stride=doc_stride,
    is_training=False,
)

INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 1000000000
INFO:tensorflow:example_index: 0
INFO:tensorflow:doc_span_index: 0
INFO:tensorflow:tokens: [CLS] what is the notification about ? [SEP] standard form 50 - rev . dec . 1961 civil ser ##vi f ##pm cha ##p . 295 ( exception to sf - 50 notification of personnel action approved by bu ##re ( employee - see general information on reverse ) 6 part december 1965 ) 50 - 126 - 21 ( for agency use ) 1 . name ( caps ) last - first - middle mr . - miss - mrs . 2 . ( for agency use ) 00 320 ##4 . birth date ( mo . . day . year ) 4 . social security no . shan ##k , robert e . dr . 80 ##19 ##5 09 - 02 - 14 49 ##9 - 34 - 05 ##9 ##7 5 . veteran preference 5 - 10 pt . other 6 . tenure group 7 . service com ##p . date | 8 . physical handicap code 1 - no 2 - 5 pt 3 - 10 pt . di ##sa ##b . - 10 pt . com ##p . . fe ##gli 10 . retirement 11 . ( for cs ##c use ) 2 1 - covered 2 - ineligible 3 - waived 1 - cs 2 2 - fi ##ca 3 - f ##s 4 - none 5 

In [33]:
def _feature_tensor(field):
    """Stack attribute `field` of every entry in `features` into one int64 tensor."""
    return torch.tensor([getattr(f, field) for f in features], dtype=torch.long)


all_input_ids = _feature_tensor("input_ids")
all_input_mask = _feature_tensor("input_mask")
all_segment_ids = _feature_tensor("segment_ids")
all_bboxes = _feature_tensor("boxes")
# Position of each row inside `features`; carried through the dataset so a
# batch row can be mapped back to the feature it came from.
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

eval_batch_size = 1

# Column order here fixes the positional layout of every batch:
# (input_ids, input_mask, segment_ids, bboxes, example_index).
eval_dataset = TensorDataset(
    all_input_ids,
    all_input_mask,
    all_segment_ids,
    all_bboxes,
    all_example_index,
)
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(
    eval_dataset,
    sampler=eval_sampler,
    batch_size=eval_batch_size,
)

In [34]:
# Place the model on the selected device and start from empty accumulators:
# `all_results` collects one result per feature, `table_data` the rows of the
# final question/answer table.
model.to(device)
all_results, table_data = [], []

def to_list(tensor):
    """Return `tensor` as a (nested) Python list.

    The tensor is first detached from the autograd graph and copied to CPU,
    then converted with ``tolist()``.
    """
    cpu_tensor = tensor.detach().cpu()
    return cpu_tensor.tolist()

# Evaluation mode is a property of the model, not of a batch: switch once
# before the loop instead of on every iteration.
model.eval()

for batch in tqdm(eval_dataloader, desc="Evaluating"):
    batch = tuple(t.to(device) for t in batch)

    # Positional layout follows the TensorDataset built for eval_dataloader:
    # 0 = input ids, 1 = attention mask, 2 = segment ids, 3 = bounding boxes,
    # 4 = index of each row inside `features`.
    with torch.no_grad():
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
            "bbox": batch[3],
        }
        outputs = model(**inputs)
    example_indices = batch[4]

    for i, example_index in enumerate(example_indices):
        eval_feature = features[example_index.item()]
        unique_id = int(eval_feature.unique_id)

        # `outputs` must hold exactly two tensors; row `i` of each is taken as
        # the start / end logits of this feature (the unpacking fails loudly
        # if the model returns anything else).
        start_logits, end_logits = (to_list(head[i]) for head in outputs)
        all_results.append(SquadResult(unique_id, start_logits, end_logits))
def _merge_wordpieces(tokens):
    """Join `tokens` with spaces, then drop every ' ##' so continuation pieces attach to the preceding token."""
    return ' '.join(tokens).replace(' ##', '')


predictions_json = {}
assert len(all_results)==len(features)

# One table row per feature: the question is everything between the leading
# token and the first '[SEP]'; the answer is the token span from the highest
# scoring start position to the highest scoring end position (inclusive).
for result, feature in zip(all_results, features):
    tokens = feature.tokens
    start_index = np.argmax(result.start_logits)
    end_index = np.argmax(result.end_logits)

    pred_text = _merge_wordpieces(tokens[start_index:end_index + 1])
    question_text = _merge_wordpieces(tokens[1:tokens.index('[SEP]')])
    table_data.append([question_text, pred_text])


headers = ["Question", "Answer"]
print(tabulate(table_data, headers=headers, tablefmt="grid"))

Evaluating: 100%|██████████| 15/15 [00:06<00:00,  2.43it/s]

+----------------------------------------------------+----------------------------------+
| Question                                           | Answer                           |
| what is the notification about ?                   | notification of personnel action |
+----------------------------------------------------+----------------------------------+
| what is the notification about ?                   | [CLS]                            |
+----------------------------------------------------+----------------------------------+
| what is the notification about ?                   | [CLS]                            |
+----------------------------------------------------+----------------------------------+
| what is written at bottom of page ?                | [CLS]                            |
+----------------------------------------------------+----------------------------------+
| what is written at bottom of page ?                | 1 . employee copy                |
+---------


