In [2]:
import sys
import glob
import json
import torch
import pandas as pd
sys.path.append('/home/long8v/BERT-NER')
from bert import Ner

# Data Loading

In [3]:
path = '/home/long8v/ICDAR-2019-SROIE/task3/data/test_dict.pth'
data_dict = torch.load(path)

In [4]:
data_dict['X00016469670']

'TAN CHAY YEE\n*** COPY ***\nOJC MARKETING SDN BHD\nROC NO: 538358-H\nNO 2 & 4, JALAN BAYU 4,\nBANDAR SERI ALAM,\n81750 MASAI, JOHOR\nTEL:07-388 2218 FAX:07-388 8218\nEMAIL:NG@OJCGROUP.COM\nTAX INVOICE\nINVOICE NO\t: PEGIV-1030765\nDATE\t: 15/01/2019 11:05:16 AM\nCASHIER\t: NG CHUAN MIN\nSALES PERSON : FATIN\nBILL TO\t: THE PEAK QUARRY WORKS\nADDRESS\t:.\nDESCRIPTION\tQTY\tPRICE\tAMOUNT\n000000111\t1\t193.00\t193.00 SR\nKINGS SAFETY SHOES KWD B05\nQTY: 1\tTOTAL EXCLUDE GST:\t193.00\nTOTAL GST @6%:\t0.00\nTOTAL INCLUSIVE GST:\t193.00\nROUND AMT:\t0.00\nTOTAL:\t193.00\nVISA CARD\t193.00\nXXXXXXXXXXXX4318\nAPPROVAL CODE:000\nGOODS SOLD ARE NOT RETURNABLE & REFUNDABLE\n****THANK YOU. PLEASE COME AGAIN.****'

# inference

In [5]:
import re
import pickle
from transformers import *

In [6]:
model = Ner('/home/long8v/sroie_data/raw_bert_replace_special_token')

In [7]:
tokenizer = model.tokenizer
tokenizer.vocab_size

28996

In [8]:
re_int = re.compile('\d+')
re_float = re.compile('(\d+\.\d+)')
re_percent = re.compile('(\d+.?\d+%)')
re_date = re.compile('(\d{2}[/-]\d{2}[/-]\d{2,4})')

In [9]:
re_dict = {re_float:'float', re_percent:'percent', re_date:'date', re_int:'int'}

In [10]:
def re_text(text):
    for key, value in re_dict.items():
        text = key.sub(value, text)
    return text

In [11]:
def re_find_pattern(text):
    pattern = re.compile('float|percent|date|int')
    if pattern.findall(text):
        return True
    else:
        return False

In [12]:
def re_find_which_pattern(text):
    pattern = re.compile('float|percent|date|int')
    try:
        return pattern.findall(text)[0]
    except:
        return text

In [13]:
def get_re_text(text):
    patterns = {}
    for key, value in re_dict.items():
        patterns[value] = re.findall(key, text) 
        text = key.sub(value, text)
    return patterns

In [14]:
def get_tokenized_word(text):
    token_word = tokenizer.tokenize(text)
    return token_word

In [15]:
def preprocess(text):
    text = re_text(text)
    return text

In [16]:
preprocess_data = {}
re_data = {}
for key, value in data_dict.items():
    preprocess_data[key] = preprocess(value)
    re_data[key] = get_re_text(value)

In [17]:
result_data = {}
for key, value in preprocess_data.items():
    result_data[key] = model.predict(value[:512])
    result_data[key].extend(model.predict(value[512:1024]))

In [56]:
from collections import defaultdict


def get_result_json(result_list):
    result_json = defaultdict(list)
    re_dict_json = defaultdict(int)
    for sample in result_list:
        tag = sample['tag'].lower()
        word = sample['word']
        word = word.replace('##', '')
        re_dict_json[re_find_which_pattern(word)] += 1
        if tag != 'o':
            try:
                result_json[tag.split('-')[1]] += [(word, int(re_dict_json[re_find_which_pattern(word)]))]
            except Exception as e:
                print(tag, e)
    return dict(result_json)


json_data = {}
for key, value in result_data.items():
    json_data[key] = get_result_json(value)
    new_json_data = {}
    for k, v in json_data[key].items():
        words = [re_data[key][re_find_which_pattern(word)][count-1] 
                 if re_find_pattern(word) else word 
                 for word, count in v]
        if words:
            words = [re.escape(word) for word in words]
            print(words)
            pattern = '\s?'.join(words)
            try:
                v_with_space = list(filter(lambda e: e, re.findall(pattern, data_dict[key])))
                new_json_data[k] = v_with_space
            except Exception as e:
                pass
    json_data[key] = new_json_data

['TAN', 'OJC', 'MARKETING', 'SDN', 'BHD', '1']
['NO', '2', '\\&', '4', '\\,', 'JALAN', 'BAYU', '4', '\\,', 'BANDAR', 'SERI', 'ALAM', '\\,', '81750', 'MASAI', '\\,', 'JOHOR', 'RETURNABLE', '\\&', 'REFUNDABLE']
['15\\/01\\/2019']
['193\\.00', '0\\.00']
['TAN', 'CHAY', 'YEE', 'OJC', 'MARKETING', 'SDN', 'BHD', 'LE', '\\&', 'REFUNDABLE']
['NO', '2', '\\&', '4', '\\,', 'JALAN', 'BAYU', '4', '\\,', 'BANDAR', 'SERI', 'ALAM', '\\,', '81750', 'MASAI', '\\,', 'JOHOR', 'YOU', '\\,', 'PLEASE', 'COME', 'AGAIN']
['02\\/01\\/2019']
['0\\.00']
['PERNIAGAAN', 'ZHENG', 'HUI', 'GEL']
['59', 'JALAN', 'PERMAS', '9', 'BANDAR', 'BARU', 'PERMAS', 'JAYA', '5', 'JOHOR', 'BAHRU']
['09\\/02\\/2018']
['411\\.50']
['PETRON', 'BKT', 'LANJAN', 'SB', 'ALSERKAM', 'ENTERPRISE']
['458\\.4']
['BKT', 'LANJAN', 'UTARA', '\\,', 'L\\/RAYA', 'UTARA', 'SELATAN', '\\,', 'SG', 'BULOH', '001083069', 'SUNGAI']
['01\\/02\\/2018']
['GERBANG', 'ALAF', 'RESTAURANTS', 'SDN', 'BHD']
['LICENSEE', 'OF', "MCDONALD\\'S", 'LEVEL', '6', '\\,', 

['UNIHAKKA', 'INTERNATIONAL', 'SDN', 'BHD', 'GST']
['12', '\\,', 'JALAN', 'TAMPOI', '7', '\\,', 'KAWASAN', 'PERINDUSTRIAN', 'TAMPOI', '\\,', '4', 'JOHOR', 'BAHRU', '\\,', 'JOHOR', 'TAXINVOICE', 'THANK']
['8\\.90']
['AEON', 'CO\\.', '\\(', 'M', '\\)', 'BHD', 'GST']
['3', 'FLR', '\\,', 'AEON', 'TAMAN', 'MALURI', 'SC', 'JLN', 'JEJAKA', '\\,', 'TAMAN', 'MALURI', 'CHERAS', '\\,', '55100', 'KUALA', 'LUMPUR', 'SUBAIDAH', 'BINTI', 'TEBRAU']
['50\\.90']
['14\\/04\\/2018']
['TRIPLE', 'SIX', 'POINT', 'ENTERPRISE', 'OYA']
['666', 'NO', '14', '\\&', '16', 'JALAN', 'PERMAS', '4', 'BANDAR', 'BARU', 'PERMAS', 'JAY']
['23\\-04\\-2018']
['RESTORAN', 'HWA', 'MUI', 'SUTERA', 'SDN', 'BHD']
['50', '\\,', 'JALAN', 'SUTERA', 'TANJUNG', '5', 'TAMAN', 'SUTERA', 'UTAMA', '\\,', '4', 'SKUDAI', '\\,', 'JOHOR']
['29\\/04\\/2018']
['9\\.00', '3\\.30', '91\\.50']
['SECRET', 'RECIPE', 'RESTAURANT', 'YOU']
['LOT', '16', '\\,', 'PERMAS', 'JAYA', 'JUSCO']
['6']
['AEON', 'CO\\.', '\\(', 'M', '\\)', 'BHD', 'GST', 'AMOUNT']

In [57]:
# json_data

In [63]:
new_json_data = {}
for key, value in json_data.items():
    new_dict = defaultdict(str)
    for k, v in value.items():
        if v:
            if k == 'total':
                try:
                    v = max(list(map(lambda e: float(e), v)))
                except ValueError:
                    v = ''
            else:
                v = v[0]
        else:
            v = ''
        new_dict[k] = str(v).replace('\n', ' ')
    new_json_data[key] = new_dict

In [68]:
new_json_data['X00016469670']['day']

''

In [71]:
path = '/home/long8v/docrv2_sroie/submission/SROIE_example_t3'
for key, value in new_json_data.items():
    with open('{}/{}.txt'.format(path, key), 'w') as f:
        f.write('{\n')
        f.write('    "company": "{}",\n'.format(value['company']))
        f.write('    "date": "{}",\n'.format(value['date']))
        f.write('    "address": "{}",\n'.format(value['address']))
        f.write('    "total": "{}"\n'.format(value['total']))
        f.write('}')
        print('{} saved.'.format(key))

X00016469670 saved.
X00016469671 saved.
X51005200931 saved.
X51005230605 saved.
X51005230616 saved.
X51005230621 saved.
X51005230648 saved.
X51005230657 saved.
X51005230659 saved.
X51005268275 saved.
X51005268408 saved.
X51005288570 saved.
X51005301666 saved.
X51005337867 saved.
X51005337877 saved.
X51005361906 saved.
X51005361908 saved.
X51005361912 saved.
X51005361923 saved.
X51005365187 saved.
X51005433518 saved.
X51005433543 saved.
X51005433548 saved.
X51005433556 saved.
X51005442322 saved.
X51005442334 saved.
X51005442343 saved.
X51005442366 saved.
X51005442375 saved.
X51005442382 saved.
X51005442388 saved.
X51005444040 saved.
X51005444041 saved.
X51005444044 saved.
X51005444046 saved.
X51005447841 saved.
X51005447842 saved.
X51005447844 saved.
X51005447851 saved.
X51005447859 saved.
X51005568855 saved.
X51005568866 saved.
X51005568885 saved.
X51005568887 saved.
X51005568889 saved.
X51005568890 saved.
X51005568892 saved.
X51005568894 saved.
X51005568895 saved.
X51005577191 saved.
