In [1]:
import torch
from modeling_layoutlm import LayoutLMForTokenClassification
from transformers import (
    BertConfig,
    BertTokenizer,
)
from utils_docvqa import (
    read_docvqa_examples,
    convert_examples_to_features)
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset
from transformers.data.processors.squad import SquadResult
from tqdm import tqdm
import numpy as np
from tabulate import tabulate

  from .autonotebook import tqdm as notebook_tqdm


In [11]:
# --- Notebook configuration -------------------------------------------------
# Directory handed to every `from_pretrained` call (config, weights, tokenizer).
MODEL_FOLDER = "./models/"
# Input file parsed by `read_docvqa_examples`.
SAMPLE_DATA = "./image.json"

# Label names for the two outputs of the token classifier. `labels` is what the
# feature-conversion call consumes; it is derived from LABELS so the two names
# cannot drift apart (they were previously two independent literals).
LABELS = ["start", "end"]
labels = list(LABELS)

# Forwarded to `convert_examples_to_features` as the padding label id.
# NOTE(review): -100 presumably matches the loss ignore_index used at training
# time -- confirm against the training script.
pad_token_label_id = -100

# Windowing parameters forwarded to `convert_examples_to_features`.
max_seq_length = 512
max_query_length = 64
doc_stride = 128

In [12]:
# Inference runs on the CPU. To run on a GPU instead, build the device from
# "cuda:0" and select it with torch.cuda.set_device(device).
device = torch.device("cpu")

# Classes used to load the checkpoint stored in MODEL_FOLDER.
model_class = LayoutLMForTokenClassification
config_class = BertConfig
tokenizer_class = BertTokenizer

# NOTE(review): `config` is loaded here but is not passed to the model below,
# so the model uses whatever configuration `from_pretrained` resolves on its
# own -- confirm whether `config=config` was intended.
config = config_class.from_pretrained(
    MODEL_FOLDER,
    num_labels=2,
    cache_dir=None,
)
model = model_class.from_pretrained(MODEL_FOLDER)
tokenizer = tokenizer_class.from_pretrained(
    MODEL_FOLDER,
    do_lower_case=True,
)

In [13]:
# Parse the sample file into DocVQA examples. `is_training=False` is passed
# because this notebook only runs inference (presumably no answer annotations
# are needed in the input -- confirm in utils_docvqa).
examples = read_docvqa_examples(SAMPLE_DATA, is_training=False)

In [14]:
# Turn the parsed examples into model-ready features using the tokenizer and
# the windowing settings from the configuration cell. Inference only, hence
# is_training=False.
features = convert_examples_to_features(
    examples=examples,
    tokenizer=tokenizer,
    label_list=labels,
    max_seq_length=max_seq_length,
    max_query_length=max_query_length,
    doc_stride=doc_stride,
    pad_token_label_id=pad_token_label_id,
    is_training=False,
)

INFO:tensorflow:*** Example ***
INFO:tensorflow:unique_id: 1000000000
INFO:tensorflow:example_index: 0
INFO:tensorflow:doc_span_index: 0
INFO:tensorflow:tokens: [CLS] hello there [SEP] ( exception to sf - 50 approved by bureau of the budget december 1965 ) gas civil service com ##iss ##ion fe ##m cha ##p . 2000 ##0 & notification of personnel action & ( employee — see general information on reverse ) ( for agency use ) 2 . ( for agency use ) 3 . birth date 4 . social security no . 1 . name ( caps ) last — first — middle mr . — miss — mrs ( mo , day , year ) 6 . tenure group 7 . service com ##p . date | 8 . physical handicap code , robert e . dr . 80 ##19 ##5 | _ 09 - 02 - 14 49 ##9 - 34 - 05 ##9 ##7 5 . veteran preference 3 - 10 pt . di ##sa ##b . 5 - 10 pt . other 4 — 10 pt . com ##p . 9 . fe ##gli 10 . retirement 2 ‘ covered 2 — ineligible 3 - waived fi ##e pa ##ela ##te - cs 3 - f ##s 5 — other | | 2 - fi ##ca 4 — none 12 . code nature of action [ 3 effective date * civil service or

In [15]:
# Stack the per-feature fields into one long tensor per field; row i of every
# tensor belongs to features[i].
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
all_bboxes = torch.tensor([f.boxes for f in features], dtype=torch.long)
# Position of each row inside `features`, carried through the DataLoader so the
# evaluation loop can map a batch row back to its feature.
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

# Fail early, with a readable message, when a bounding-box coordinate cannot be
# used as an index into the model's 2-D position embedding tables. Without this
# check the forward pass dies with an opaque
# "IndexError: index out of range in self".
# NOTE(review): assumes the model looks coordinates up directly in embedding
# tables sized by `config.max_2d_position_embeddings` (as LayoutLM does); the
# upper bound is skipped when the config does not expose that attribute.
bbox_upper_bound = getattr(model.config, "max_2d_position_embeddings", None)
bbox_invalid = all_bboxes < 0
if bbox_upper_bound is not None:
    bbox_invalid = bbox_invalid | (all_bboxes >= bbox_upper_bound)
if bool(bbox_invalid.any()):
    bad_rows = bbox_invalid.view(bbox_invalid.size(0), -1).any(dim=1)
    bad_feature_indices = torch.nonzero(bad_rows).flatten().tolist()
    raise ValueError(
        "Bounding-box coordinates outside the supported range "
        f"[0, {bbox_upper_bound}) in features {bad_feature_indices} "
        f"(min={int(all_bboxes.min())}, max={int(all_bboxes.max())}). "
        "Normalize the boxes (LayoutLM expects a 0-1000 scale) before inference."
    )

eval_dataset = TensorDataset(
    all_input_ids,
    all_input_mask,
    all_segment_ids,
    all_bboxes,
    all_example_index,
)
eval_batch_size = 1
# Sequential order keeps results aligned with `features`.
eval_sampler = SequentialSampler(eval_dataset)

eval_dataloader = DataLoader(
    eval_dataset,
    sampler=eval_sampler,
    batch_size=eval_batch_size,
)

In [16]:
# Fresh accumulators for this run: one SquadResult per feature, and one
# [question, answer] row per feature for the printed table.
all_results, table_data = [], []
# Move the model's parameters to the selected device before evaluating.
model.to(device)

def to_list(tensor):
    """Convert a tensor to plain Python data via ``Tensor.tolist()``.

    The tensor is first detached from the autograd graph and copied to the
    CPU, so the call works for tensors on any device and for tensors that
    require gradients.
    """
    host_tensor = tensor.detach().cpu()
    return host_tensor.tolist()

def _join_wordpieces(tokens):
    """Join WordPiece tokens into readable text.

    Continuation pieces (prefixed with '##') are glued onto the preceding
    piece. A span that starts in the middle of a word would otherwise keep a
    stray leading '##', so that marker is dropped as well.
    """
    text = " ".join(tokens).replace(" ##", "")
    return text[2:] if text.startswith("##") else text


# Switch to inference mode once, before the loop, rather than on every batch.
model.eval()

for batch in tqdm(eval_dataloader, desc="Evaluating"):
    # Batch layout follows the TensorDataset column order:
    # input_ids, input_mask, segment_ids, bboxes, feature index.
    batch = tuple(t.to(device) for t in batch)
    with torch.no_grad():
        outputs = model(
            input_ids=batch[0],
            attention_mask=batch[1],
            token_type_ids=batch[2],
            bbox=batch[3],
        )

    for row, feature_index in enumerate(batch[4].tolist()):
        eval_feature = features[feature_index]
        unique_id = int(eval_feature.unique_id)

        # The model output is unpacked as exactly two per-token score tensors
        # (start, end); take this row from each.
        start_logits, end_logits = (to_list(scores[row]) for scores in outputs)
        all_results.append(SquadResult(unique_id, start_logits, end_logits))

predictions_json = {}
# The sequential sampler yields features in order, so results and features
# line up one-to-one.
assert len(all_results) == len(features)
for feature, result in zip(features, all_results):
    # Highest-scoring start and end positions, chosen independently. If the
    # end falls before the start the slice is empty and the answer is "".
    start_index = int(np.argmax(result.start_logits))
    end_index = int(np.argmax(result.end_logits))
    pred_text = _join_wordpieces(feature.tokens[start_index:end_index + 1])

    # The question is everything between [CLS] (position 0) and the first [SEP].
    question_text = _join_wordpieces(feature.tokens[1:feature.tokens.index('[SEP]')])
    table_data.append([question_text, pred_text])


headers = ["Question", "Answer"]
print(tabulate(table_data, headers=headers, tablefmt="grid"))

Evaluating:   0%|          | 0/3 [00:00<?, ?it/s]

tensor([[   0,    0,    0, 1000,   58,   58,  114,  126,  126,  126,   59,  107,
          121,  160,  172,  191,   58,  110,  110,   59,   83,  102,  144,  144,
          144,   58,   58,   83,   83,   83,  141,  141,  271,  304,  394,  416,
          493,  544,  318,  322,  366,  378,  396,  428,  477,  489,  489,   72,
           72,   97,  139,  139,  419,  419,  429,  429,  454,  495,  495,  533,
          533,  544,  582,  648,  648,  658,  693,  742,  742,   56,   56,   66,
           95,   95,   95,  132,  132,  132,  132,  132,  306,  306,  306,  306,
          306,  306,  541,  541,  541,  564,  564,  584,  584,  420,  420,  430,
          470,  533,  533,  543,  583,  583,  583,  616,  640,  649,  649,  660,
          707,  755,  114,  129,  187,  187,  345,  345,  429,  429,  429,  525,
          525,  525,  525,  525,  525,  525,  661,  661,  661,  661,  661,  661,
          661,  661,   53,   53,   67,  111,  198,  198,  198,  219,  219,  234,
          234,  234,  234,  

Evaluating:  33%|███▎      | 1/3 [00:00<00:00,  2.58it/s]

tensor([[   0,    0,    0, 1000,  661,  661,   53,   53,   67,  111,  198,  198,
          198,  219,  219,  234,  234,  234,  234,  303,  303,  303,  324,  324,
          339,  198,  198,  198,  220,  220,  231,  231,  231,   56,   56,   67,
           67,  415,  415,  431,   71,  105,  105,  212,  212,  212,  328,  328,
          328,  409,  409,  429,  429,  429,  429,  429,  519,  519,  519,  519,
          585,  585,  585,   96,  449,  456,  456,  456,  456,  520,  520,  520,
           57,   57,   70,  105,  144,  159,  408,  408,  430,  482,  525,  545,
          572,  613,  629,  663,  694,   71,   71,  112,  204,  204,  204,  204,
          204,  279,  409,   72,   72,  112,  112,  144,  144,  186,  186,  220,
          220,  220,  220,  220,  296,  429,  429,  429,  429,  429,  545,  570,
          604,  637,  525,  525,  525,  545,  545,  545,  559,  599,  599,  599,
          614,  647,  647,  661,  568,  618,  562,  612,  504,  646,  684,  684,
          705,  705,   64,  




IndexError: index out of range in self