In [364]:
import re
import torch
import numpy as np
import pandas as pd
from glob import glob
from transformers import *
from sklearn.model_selection import train_test_split
from itertools import zip_longest

In [365]:
# Let pandas render up to 999 rows before truncating displayed frames.
pd.set_option('display.max_rows', 999)

## BERT tokenizer loading

In [366]:
# Checkpoint name; reused below when the model itself is loaded.
pretrained_weights = 'bert-base-uncased'
# Load the tokenizer that belongs to that checkpoint.
# (`BertTokenizer` is presumably provided by the star import from `transformers`.)
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

In [367]:
# Maximum sequence length supported by the tokenizer's model.
# Newer transformers releases expose it as `model_max_length` (the old `max_len`
# attribute was dropped), so prefer the new name and fall back to the old one.
max_len = getattr(tokenizer, 'model_max_length', None)
if max_len is None:
    max_len = tokenizer.max_len
print(max_len)

512




In [369]:
# Load the BERT model for the same checkpoint name that the tokenizer was loaded from.
model = BertModel.from_pretrained(pretrained_weights)

In [370]:
# Encode one sample sentence (special tokens included) and wrap it in an outer
# list so the resulting tensor has a leading dimension of size 1.
input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)])  

In [371]:
# Display the encoded token ids.
input_ids

tensor([[  101,  2182,  2003,  2070,  3793,  2000,  4372, 16044,   102]])

In [372]:
# Forward pass with gradient tracking disabled (inference only).
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # keep only the first element of the model output

In [373]:
# Display the tensor produced by the forward pass.
last_hidden_states

tensor([[[-0.0549,  0.1053, -0.1065,  ..., -0.3550,  0.0686,  0.6506],
         [-0.5759, -0.3650, -0.1383,  ..., -0.6782,  0.2092, -0.1639],
         [-0.1641, -0.5597,  0.0150,  ..., -0.1603, -0.1346,  0.6216],
         ...,
         [ 0.2448,  0.1254,  0.1587,  ..., -0.2749, -0.1163,  0.8809],
         [ 0.0481,  0.4950, -0.2827,  ..., -0.6097, -0.1212,  0.2527],
         [ 0.9046,  0.2137, -0.5897,  ...,  0.3040, -0.6172, -0.1950]]])

In [374]:
# Probe how the tokenizer splits a handful of words and raw numeric strings
# (presumably to pick placeholder words that stay a single token — confirm).
# Only the value of the last expression is displayed by the notebook; the
# results of the earlier calls are discarded.
tokenizer.tokenize('float')
tokenizer.tokenize('total')
tokenizer.tokenize('gst')
tokenizer.tokenize('round')
tokenizer.tokenize('10%') 
tokenizer.tokenize('9.24')
tokenizer.tokenize('273500') 
tokenizer.tokenize('currency')
tokenizer.tokenize('row')
tokenizer.tokenize('date')

['date']

## Data loading

In [375]:
# Load the serialized data dictionary and transpose its values so that the
# first field of every entry ends up in `texts` and the second in `labels`.
# NOTE(review): the path is absolute and machine-specific, and `torch.load`
# unpickles the file — only load files from a trusted source.
data_dict = torch.load(
    '/home/long8v/ICDAR-2019-SROIE/task3/data/data_dict4.pth'
)
zipped_data = list(zip(*data_dict.values()))
texts, labels = zipped_data[0], zipped_data[1]

In [376]:
# texts = list(map(lambda e: e.replace('\n', ' row '), texts))

In [377]:
# Inspect the first document together with its label array.
(texts[0], labels[0])

('TAN WOON YANN\nBOOK TA .K(TAMAN DAYA) SDN BND\n789417-W\nNO.53 55,57 & 59, JALAN SAGU 18,\nTAMAN DAYA,\n81100 JOHOR BAHRU,\nJOHOR.\nDOCUMENT NO : TD01167104\nDATE:\t25/12/2018 8:13:39 PM\nCASHIER:\tMANIS\nMEMBER:\nCASH BILL\nCODE/DESC\tPRICE\tDISC\tAMOUNT\nQTY\tRM\tRM\n9556939040116\tKF MODELLING CLAY KIDDY FISH\n1 PC\t*\t9.000\t0.00\t9.00\nTOTAL:\t9.00\nROUNDING ADJUSTMENT:\t0.00\nROUNDED TOTAL (RM):\t9.00\nCASH\t10.00\nCHANGE\t1.00\nGOODS SOLD ARE NOT RETURNABLE OR\nEXCHANGEABLE\n***\n***\nTHANK YOU\nPLEASE COME AGAIN !',
 array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0

## Replace numeric patterns with special tokens

In [378]:
# Patterns for the numeric spans that get replaced by placeholder words.
# Raw strings keep the regex escapes (\d, \.) from being parsed as (invalid)
# Python string escapes.
re_int = re.compile(r'\d+')
re_float = re.compile(r'(\d+\.\d+)')
# Integer or decimal number followed by '%'. The dot is escaped so it no longer
# matches an arbitrary character, and the fractional part is optional so that
# single-digit percentages such as "6%" are matched as well.
re_percent = re.compile(r'(\d+(?:\.\d+)?%)')
# Two-digit day/month separated by '/' or '-', followed by a 2-4 digit year.
re_date = re.compile(r'(\d{2}[/-]\d{2}[/-]\d{2,4})')

In [379]:
# Quick check that the date pattern matches a dash-separated date.
re_date.findall('20-02-2020')

['20-02-2020']

In [380]:
# Maps each compiled pattern to the placeholder word that replaces its matches.
# Insertion order matters: the replacement loop walks this dict in order, and
# `re_int` must come last because it would otherwise consume the digits inside
# floats, percentages and dates.
re_dict = {re_float:'float', re_percent:'percent', re_date:'date', re_int:'int'}

In [381]:
# Sanity check before substitution: count how often each placeholder word
# already occurs in the raw corpus, so that pre-existing occurrences are not
# confused with inserted placeholders later on.
corpus = ' '.join(texts)
for special_token in re_dict.values():
    print('{} : {}'.format(special_token, corpus.count(special_token)))

float : 0
percent : 0
date : 0
int : 0


In [382]:
# Show the first integer match in the second document.
re_int.search(texts[1]).group()

'27'

In [383]:
# Report every document whose character count differs from its label count;
# no output means the two sequences line up everywhere.
for text, label in zip(texts, labels):
    n_chars, n_labels = len(text), len(label)
    if n_chars == n_labels:
        continue
    print(n_chars, n_labels)

In [384]:
def replace_special_token(text, label, patterns=None):
    """Replace regex matches in ``text`` with placeholder words, keeping labels aligned.

    ``label`` holds one entry per character of ``text``. Every match of every
    pattern is replaced by that pattern's placeholder word, and the label
    sequence is stretched or shrunk at the same position so that it still has
    exactly one entry per character afterwards.

    Args:
        text: The document string.
        label: Per-character labels (any sequence; it is copied to a list).
        patterns: Optional mapping ``{compiled regex: placeholder word}``,
            applied in iteration order. Defaults to the module-level ``re_dict``.
            A placeholder must not itself match its own pattern, otherwise the
            replacement loop never terminates.

    Returns:
        A tuple ``(new_text, new_labels)`` where ``new_labels`` is a list with
        ``len(new_text)`` entries.

    Raises:
        ValueError: If ``text`` and ``label`` differ in length on entry.
    """
    if patterns is None:
        patterns = re_dict
    label_copy = list(label)
    if len(text) != len(label_copy):
        raise ValueError(
            'text and label must have the same length, got {} and {}'.format(
                len(text), len(label_copy)))

    for re_exp, special_token in patterns.items():
        match = re_exp.search(text)
        while match is not None:
            span_start, span_end = match.span()
            len_diff = len(special_token) - (span_end - span_start)
            if len_diff > 0:
                # Placeholder is longer than the match: repeat the label of the
                # first matched character to fill the extra positions.
                label_copy[span_start:span_start] = [label_copy[span_start]] * len_diff
            elif len_diff < 0:
                # Placeholder is shorter: drop the surplus labels from the
                # start of the matched span.
                del label_copy[span_start:span_start - len_diff]
            text = text[:span_start] + special_token + text[span_end:]
            # Search again from the beginning of the rewritten text.
            match = re_exp.search(text)

    assert len(text) == len(label_copy)
    return text, label_copy

In [385]:
# Apply the placeholder replacement to every (text, label) pair.
# NOTE(review): the pairs are collected into a dict keyed by the rewritten text,
# so documents whose rewritten text is identical collapse into a single entry
# (the last one wins) — confirm this de-duplication is intended.
replace_list = dict([replace_special_token(text, label) for text, label in zip(texts, labels)])

In [386]:
# Split the mapping back into two parallel lists: rewritten texts (the keys)
# and their aligned label lists (the values).
texts, labels = list(replace_list), list(replace_list.values())

## BIO tagging for the BERT tokenizer

In [387]:
def get_tokenized_word(text):
    """Return the tokens that the module-level ``tokenizer`` produces for ``text``."""
    return tokenizer.tokenize(text)

In [388]:
def get_token_labels(token_word, text, label, max_tokens=None):
    """Slice the per-character labels into one label group per token.

    Whitespace characters of ``text`` (and their labels) are dropped first;
    the remaining labels are then consumed left to right, each token taking as
    many labels as it has characters once a leading ``'##'`` continuation
    marker is removed.

    NOTE(review): the alignment relies on every token having the same number
    of characters as the text it was produced from. Tokens for which that does
    not hold (e.g. an unknown-token placeholder) would shift all later labels —
    confirm against the tokenizer output.

    Args:
        token_word: Tokens produced for ``text``.
        text: The document string.
        label: Per-character labels aligned with ``text``.
        max_tokens: Only the first ``max_tokens`` tokens are processed.
            Defaults to the module-level ``max_len``.

    Returns:
        A list with one list of labels per processed token. A group is empty
        when the labels run out before the tokens do.
    """
    if max_tokens is None:
        max_tokens = max_len
    label_clean = [lbl for txt, lbl in zip(text, label) if txt.strip()]
    token_labels = []
    index = 0
    for token in token_word[:max_tokens]:
        # Strip the continuation marker only where it is a prefix.
        token_clean = token[2:] if token.startswith('##') else token
        token_labels.append(label_clean[index:index + len(token_clean)])
        index += len(token_clean)
    return token_labels

In [389]:
def get_bio_tag(token_labels):
    """Convert per-token label groups into BIO tag strings.

    The label of a token is the first entry of its group. Label 0 maps to
    ``'O'``; any other label yields ``'B-<NAME>'`` when it differs from the
    previous token's label and ``'I-<NAME>'`` when it repeats it.

    A token with an empty group reuses the previous token's label. If there is
    no previous token it is treated as label 0, where the earlier code failed
    on an unbound variable.

    Args:
        token_labels: One sequence of integer labels (0-4) per token.

    Returns:
        A list of tag strings, one per token.
    """
    label_dict = {0: 'O', 1: 'COMPANY', 2: 'DATE', 3: 'ADDRESS', 4: 'TOTAL'}
    token_label_bio = []
    current = 0
    temp_label = 0  # fallback for a leading token that has no labels
    for token_label in token_labels:
        if len(token_label) > 0:
            temp_label = token_label[0]
        if temp_label == 0:
            token_label_bio.append(label_dict[temp_label])
        elif temp_label != current:
            token_label_bio.append('B-{}'.format(label_dict[temp_label]))
        else:
            token_label_bio.append('I-{}'.format(label_dict[temp_label]))
        current = temp_label
    return token_label_bio

In [390]:
def get_paired_token(text, label):
    """Tokenize ``text`` and pair every token with its BIO tag.

    Returns a two-column DataFrame: column 0 holds the tokens, column 1 the tags.
    """
    tokens = get_tokenized_word(text)
    bio_tags = get_bio_tag(get_token_labels(tokens, text, label))
    return pd.DataFrame(zip(tokens, bio_tags))

In [391]:
def get_paired_token_text_label(texts, labels):
    """Build one token/tag DataFrame per document.

    Each frame starts with a ``-DOCSTART-`` row, continues with the token/tag
    rows produced by ``get_paired_token`` and ends with an empty-token row;
    the marker rows are tagged ``'O'``. ``replace_digits`` is then applied to
    every entry of the token column.

    Args:
        texts: Iterable of document strings.
        labels: Iterable of per-character label sequences, parallel to ``texts``.

    Returns:
        A list of DataFrames (columns 0 = token, 1 = tag), one per document.
    """
    df_list = []
    for text, label in zip(texts, labels):
        # Assemble the frame with a single concat; `DataFrame.append` is no
        # longer available in current pandas releases.
        df = pd.concat(
            [
                pd.DataFrame([('-DOCSTART-', 'O')]),
                get_paired_token(text, label),
                pd.DataFrame([('', 'O')]),
            ],
            ignore_index=True,
        )
        # NOTE(review): `replace_digits` is not defined in the visible part of
        # this notebook — make sure it is defined before this function runs.
        df[0] = df[0].apply(lambda e: replace_digits(e))
        df_list.append(df)
    return df_list

In [392]:
# Hold out a test set. The seed is fixed so that re-running the notebook
# reproduces the same split.
train_text, test_text,  train_label, test_label = train_test_split(texts, labels, random_state=42)

In [393]:
# Carve a validation set out of the remaining training data. The seed is fixed
# so that re-running the notebook reproduces the same split.
train_text, val_text,  train_label, val_label = train_test_split(train_text, train_label, random_state=42)

In [394]:
# Build the per-document token/tag frames for the train, validation and test splits.
train_df, val_df, test_df = (
    get_paired_token_text_label(split_text, split_label)
    for split_text, split_label in (
        (train_text, train_label),
        (val_text, val_label),
        (test_text, test_label),
    )
)

In [395]:
import pickle

In [400]:
# Persist the list of per-document training frames as a pickle file.
with open('data/train_df.p', mode='wb') as f:
    pickle.dump(train_df, file=f)

In [397]:
from functools import reduce

In [398]:
# Stack the per-document frames of each split into one long frame.
# A single `pd.concat` over the whole list avoids re-copying the accumulated
# frame once per document, which the pairwise reduce did.
train_df_long = pd.concat(train_df)
val_df_long = pd.concat(val_df)
test_df_long = pd.concat(test_df)

In [399]:
# Write each split as space-separated "token tag" lines, without index or header.
for split_frame, split_path in (
    (train_df_long, 'data/train.txt'),
    (val_df_long, 'data/valid.txt'),
    (test_df_long, 'data/test.txt'),
):
    split_frame.to_csv(split_path, sep=' ', index=False, header=False)