In [1]:
import sys
import glob
import json
import torch
import pandas as pd
sys.path.append('/home/long8v/BERT-NER')
from bert import Ner

# Data Loading

In [2]:
# NOTE(review): the `data_dict` loaded here is overwritten by the next cell
# (data_dict.pth) before anything reads it, so this load has no effect on the
# results below. Confirm which split is meant to be evaluated and drop the other.
path = '/home/long8v/ICDAR-2019-SROIE/task3/data/test_dict.pth'
data_dict = torch.load(path)

In [3]:
# Load the document dict that every later cell actually uses.
# torch.load unpickles the file, so only point it at trusted files.
path = '/home/long8v/ICDAR-2019-SROIE/task3/data/data_dict.pth'
data_dict = torch.load(path)

In [4]:
# Keep only the first element of each entry. Presumably that element is the
# document text, since later cells run regex substitution on it (TODO confirm).
# Not idempotent: re-running this cell on already-unwrapped string values would
# keep just their first character.
data_dict = dict((doc_id, entry[0]) for doc_id, entry in data_dict.items())

In [5]:
# NOTE(review): `key_dict` is not referenced by any later cell in this notebook.
# Presumably these are the ground-truth key fields; confirm, or drop this load.
path = '/home/long8v/ICDAR-2019-SROIE/task3/data/keys.pth'
key_dict = torch.load(path)

# Inference

In [6]:
import re
import pickle
from transformers import *

In [7]:
# Load the NER model from this checkpoint directory. `Ner` comes from the
# BERT-NER checkout that is appended to sys.path in the first cell.
model = Ner('/home/long8v/sroie_data/raw_bert_replace_special_token')

In [8]:
# Reuse the tokenizer bundled with the model; get_tokenized_word relies on it.
tokenizer = model.tokenizer
tokenizer.vocab_size  # displayed as a quick sanity check of the loaded vocabulary

28996

In [9]:
# Patterns for numeric/date spans that preprocessing replaces with placeholder
# words. They are written as raw strings: '\d' and '\.' in ordinary string
# literals are invalid escape sequences (DeprecationWarning, and a
# SyntaxWarning since Python 3.12). The resulting patterns are unchanged.
re_int = re.compile(r'\d+')
re_float = re.compile(r'(\d+\.\d+)')
# NOTE(review): the unescaped '.' matches any character, not only a decimal
# point (e.g. '5 10%' matches as one span). It is kept as-is because the model
# was presumably trained with the same preprocessing; confirm before changing.
re_percent = re.compile(r'(\d+.?\d+%)')
re_date = re.compile(r'(\d{2}[/-]\d{2}[/-]\d{2,4})')

In [10]:
# Compiled pattern -> placeholder word. Insertion order matters: re_text and
# get_re_text apply the patterns in this order, so floats are replaced before
# the bare-integer pattern can break them into pieces.
re_dict = dict([
    (re_float, 'float'),
    (re_percent, 'percent'),
    (re_date, 'date'),
    (re_int, 'int'),
])

In [11]:
def re_text(text):
    """Replace every numeric/date span in ``text`` with its placeholder word.

    Patterns are applied in ``re_dict`` insertion order, and each pattern sees
    the output of the earlier substitutions.
    """
    result = text
    for compiled, placeholder in re_dict.items():
        result = compiled.sub(placeholder, result)
    return result

In [12]:
def re_find_pattern(text):
    """Return True if ``text`` contains any placeholder word, else False.

    This is a plain substring search for 'float', 'percent', 'date' or 'int',
    so words such as 'print' also count as containing a placeholder.
    """
    return re.search('float|percent|date|int', text) is not None

In [13]:
def re_find_which_pattern(text):
    """Return the first placeholder word found in ``text``, or ``text`` itself.

    Placeholder words are 'float', 'percent', 'date' and 'int', matched as
    substrings. When none occurs, the input is returned unchanged, so callers
    can use the result as a grouping key either way.
    """
    # An explicit search replaces the previous bare `except:`, which also
    # swallowed unrelated errors (including KeyboardInterrupt).
    match = re.search('float|percent|date|int', text)
    return match.group(0) if match else text

In [14]:
def get_re_text(text):
    """Collect, per placeholder word, the original spans that re_text replaces.

    Each pattern is matched against the text *after* the earlier substitutions,
    in the same ``re_dict`` order that ``re_text`` uses. Assuming the
    placeholder words do not otherwise occur in the text, the i-th entry of
    ``result[placeholder]`` is the value behind the i-th occurrence of that
    placeholder in the preprocessed text.
    """
    found = {}
    remaining = text
    for compiled, placeholder in re_dict.items():
        found[placeholder] = compiled.findall(remaining)
        remaining = compiled.sub(placeholder, remaining)
    return found

In [15]:
def get_tokenized_word(text):
    """Return the tokens the model's tokenizer produces for ``text``.

    Not called by any other cell in this notebook.
    """
    return tokenizer.tokenize(text)

In [16]:
def preprocess(text):
    """Normalize document text before NER by swapping numeric/date spans for placeholder words."""
    return re_text(text)

In [17]:
# Model input: text with numbers/dates replaced by placeholder words.
preprocess_data = {doc_id: preprocess(raw_text) for doc_id, raw_text in data_dict.items()}
# Original values behind each placeholder, used later to restore predicted entities.
re_data = {doc_id: get_re_text(raw_text) for doc_id, raw_text in data_dict.items()}

In [18]:
MAX_CHUNK_CHARS = 512


def _split_into_chunks(text, max_chars=MAX_CHUNK_CHARS):
    """Split ``text`` into consecutive pieces of at most ``max_chars`` characters.

    Each cut is moved back to the nearest preceding whitespace, so a word (in
    particular a placeholder such as 'float') is never split across two
    pieces. A hard cut at ``max_chars`` is used only when a window contains no
    whitespace. Concatenating the pieces reproduces ``text`` exactly.
    """
    chunks = []
    start = 0
    while start < len(text):
        end = min(start + max_chars, len(text))
        if end < len(text):
            cut = end
            # Walk back until the character just after the cut is whitespace.
            while cut > start and not text[cut].isspace():
                cut -= 1
            if cut > start:
                end = cut
        chunks.append(text[start:end])
        start = end
    return chunks


# Predict over the whole document in chunks of at most MAX_CHUNK_CHARS
# characters. The previous fixed slices [:512] and [512:1024] could cut a
# placeholder in half, which misaligns the placeholder counts used to restore
# values. They also silently dropped everything past 1024 characters.
# NOTE(review): the limit is in characters, not model tokens, as before.
result_data = {}
for key, value in preprocess_data.items():
    predictions = []
    for chunk in _split_into_chunks(value):
        predictions.extend(model.predict(chunk))
    result_data[key] = predictions

In [32]:
from collections import defaultdict


def get_result_json(result_list):
    """Group the model's tagged tokens by entity type.

    Args:
        result_list: model predictions. Each item is read as a mapping with
            'word' and 'tag' keys; a tag of 'O' (case-insensitive) means the
            token is not part of an entity, and other tags are split on '-'
            with the second part used as the entity type.

    Returns:
        dict mapping entity type (lower-cased) to a list of ``(word, count)``
        tuples. ``word`` has '##' removed. ``count`` is the running number of
        tokens seen so far, tagged or not, that share the word's placeholder
        (or, for non-placeholder words, the word itself). It is used later to
        look up the original value behind a placeholder.
    """
    result_json = defaultdict(list)
    seen = defaultdict(int)
    for sample in result_list:
        tag = sample['tag'].lower()
        word = sample['word'].replace('##', '')
        group = re_find_which_pattern(word)
        # Count every token, including untagged ones, so the running index
        # matches the order in which get_re_text collected the values.
        seen[group] += 1
        if tag == 'o':
            continue
        try:
            entity = tag.split('-')[1]
        except IndexError as e:
            # A tag without '-' has no entity type; report it and skip it.
            print(tag, e)
            continue
        result_json[entity].append((word, seen[group]))
    return dict(result_json)


def _restore_token(key, word, count):
    """Map a predicted token back to the original text it stands for.

    Tokens containing a placeholder word ('float', 'percent', 'date', 'int')
    are replaced by the ``count``-th value that get_re_text extracted for that
    placeholder in document ``key``. Other tokens are returned unchanged. If
    fewer values were extracted than ``count`` (token/placeholder
    misalignment), the token is returned as-is instead of raising IndexError.
    """
    if not re_find_pattern(word):
        return word
    originals = re_data[key][re_find_which_pattern(word)]
    if count > len(originals):
        return word
    return originals[count - 1]


def _find_entity_spans(key, tokens):
    """Return every non-empty match of the restored entity tokens in ``data_dict[key]``.

    Adjacent tokens may be separated by one optional whitespace character in
    the raw text, presumably because tokenization discards whitespace.
    """
    words = [_restore_token(key, word, count) for word, count in tokens]
    # Raw string: '\s' in an ordinary literal is an invalid escape sequence.
    pattern = r'\s?'.join(re.escape(word) for word in words)
    return [span for span in re.findall(pattern, data_dict[key]) if span]


# For each document, map entity type -> list of matching spans from the raw text.
json_data = {}
for key, value in result_data.items():
    restored = {}
    for entity, tokens in get_result_json(value).items():
        if not tokens:
            continue
        try:
            restored[entity] = _find_entity_spans(key, tokens)
        except re.error as e:
            # Best-effort: skip an entity whose search pattern fails to
            # compile, but say so instead of silently passing.
            print(key, entity, e)
    json_data[key] = restored

In [33]:
def _pick_total(candidates):
    """Return the candidate with the largest numeric value, as its original string.

    The original text (whitespace-stripped) is returned rather than
    ``str(float(...))``, so formatting such as trailing zeros ('9.00', not
    '9.0') is preserved. Candidates that do not parse as floats are skipped
    instead of discarding the whole field. Returns '' when none parse.
    """
    parsed = []
    for candidate in candidates:
        try:
            parsed.append((float(candidate), candidate))
        except ValueError:
            continue
    if not parsed:
        return ''
    return max(parsed, key=lambda pair: pair[0])[1].strip()


# Reduce each entity's candidate list to a single submission string:
# the largest parseable value for 'total', otherwise the first match.
# Missing fields read as '' thanks to defaultdict(str).
new_json_data = {}
for key, value in json_data.items():
    new_dict = defaultdict(str)
    for k, v in value.items():
        if not v:
            v = ''
        elif k == 'total':
            v = _pick_total(v)
        else:
            v = v[0]
        new_dict[k] = str(v).replace('\n', ' ')
    new_json_data[key] = new_dict

In [34]:
path = '/home/long8v/docrv2_sroie/submission/SROIE_example_t3'
SUBMISSION_FIELDS = ('company', 'date', 'address', 'total')
for key, value in new_json_data.items():
    # value is a defaultdict(str), so a missing field is written as "".
    submission = {field: value[field] for field in SUBMISSION_FIELDS}
    with open('{}/{}.txt'.format(path, key), 'w', encoding='utf-8') as f:
        # json.dump escapes quotes and backslashes that the hand-written
        # format emitted raw, which produced invalid JSON. indent=4 and
        # ensure_ascii=False reproduce the previous layout for ordinary values.
        json.dump(submission, f, indent=4, ensure_ascii=False)

In [35]:
# in_path: directory holding the per-document .txt files to archive.
# out_path: base directory for archive member names. It equals in_path, so
#   members are stored by their path relative to the submission directory.
# out_zip_file: destination archive.
in_path = '/home/long8v/docrv2_sroie/submission/SROIE_example_t3'
out_path = '/home/long8v/docrv2_sroie/submission/SROIE_example_t3'
out_zip_file = '/home/long8v/docrv2_sroie/evaluation/task3/submit.zip'

In [36]:
import os
import zipfile
 
# Archive every file under in_path (deflate-compressed), naming members
# relative to out_path. The context manager closes the archive even if a
# write fails; previously an exception left the file handle open.
with zipfile.ZipFile(out_zip_file, 'w', compression=zipfile.ZIP_DEFLATED) as submission_zip:
    for folder, subfolders, files in os.walk(in_path):
        for file in files:
            file_path = os.path.join(folder, file)
            submission_zip.write(file_path, os.path.relpath(file_path, out_path))