In [1]:
import numpy as np
import pandas as pd
import torch
from transformers import *
from sklearn.model_selection import train_test_split
from itertools import zip_longest

## BERT tokenizer loading

In [2]:
# Name of the pretrained checkpoint; the same string is reused further down to load the model.
pretrained_weights = 'bert-base-uncased'
# Tokenizer matching that checkpoint. The helper functions below read this
# notebook-level name directly instead of taking it as an argument.
tokenizer = BertTokenizer.from_pretrained(pretrained_weights)

In [3]:
# Maximum number of tokens the checkpoint accepts; `get_token_labels` uses it
# to truncate long documents.
# `tokenizer.max_len` is the legacy attribute name: newer `transformers`
# releases expose the same value as `model_max_length`. Prefer the new name
# and fall back to the old one so the cell runs on either version.
max_len = getattr(tokenizer, 'model_max_length', None)
if max_len is None:
    max_len = tokenizer.max_len
print(max_len)

512




In [4]:
# Probe how the tokenizer treats a symbol outside its vocabulary: the output
# below shows it collapses to the single placeholder token '[UNK]', which is
# longer than the character it replaced.
tokenizer.tokenize('※')

['[UNK]']

In [5]:
# Load the pretrained BERT model for the same checkpoint as the tokenizer.
# In the cells shown here it is only used for the small encoding demo below.
model = BertModel.from_pretrained(pretrained_weights)

In [6]:
# Encode a sample sentence to vocabulary ids, asking the tokenizer to add its
# special tokens. The encoded list is wrapped in an outer list, so the tensor
# gets a leading dimension of size 1 (see the nested brackets in the output).
input_ids = torch.tensor([tokenizer.encode("Here is some text to encode", add_special_tokens=True)])

In [7]:
# Display the encoded ids for inspection.
input_ids

tensor([[  101,  2182,  2003,  2070,  3793,  2000,  4372, 16044,   102]])

In [8]:
# Inference only: run the forward pass with gradient tracking disabled.
with torch.no_grad():
    last_hidden_states = model(input_ids)[0]  # keep only the first element of the model's output

In [9]:
# Display the tensor returned by the model for the sample sentence.
last_hidden_states

tensor([[[-0.0549,  0.1053, -0.1065,  ..., -0.3550,  0.0686,  0.6506],
         [-0.5759, -0.3650, -0.1383,  ..., -0.6782,  0.2092, -0.1639],
         [-0.1641, -0.5597,  0.0150,  ..., -0.1603, -0.1346,  0.6216],
         ...,
         [ 0.2448,  0.1254,  0.1587,  ..., -0.2749, -0.1163,  0.8809],
         [ 0.0481,  0.4950, -0.2827,  ..., -0.6097, -0.1212,  0.2527],
         [ 0.9046,  0.2137, -0.5897,  ...,  0.3040, -0.6172, -0.1950]]])

## BIO tagging for the BERT tokenizer

In [10]:
def get_tokenized_word(text):
    """Tokenize ``text`` with the notebook-level ``tokenizer``.

    Args:
        text: The string to tokenize.

    Returns:
        Whatever ``tokenizer.tokenize(text)`` returns (a list of token
        strings for the tokenizer loaded above).
    """
    return tokenizer.tokenize(text)

In [11]:
def _unk_width(flat_text, index, tokens, pieces, pos, unk_token):
    """Estimate how many characters the unknown token at ``tokens[pos]`` replaced.

    The placeholder's own length says nothing about the text it stands for, so
    the width is inferred by locating the next known token in ``flat_text``.
    Every unknown token is assumed to cover at least one character; when
    several unknown tokens are adjacent, one character is reserved for each
    later one. This is a heuristic: if the next known token cannot be found
    the width falls back to a single character.
    """
    # Find the end of the run of consecutive unknown tokens starting at pos.
    run_end = pos
    while run_end < len(tokens) and tokens[run_end] == unk_token:
        run_end += 1
    run_len = run_end - pos

    if run_end < len(tokens):
        # Anchor on the first known token after the run.
        anchor = flat_text.find(pieces[run_end], index + run_len)
    else:
        # The run reaches the end of the token list: use the remaining text.
        anchor = len(flat_text)

    if anchor < 0:
        return 1
    return max(1, anchor - index - (run_len - 1))


def get_token_labels(token_word, text, label, max_tokens=None, unk_token='[UNK]'):
    """Slice character-level labels into one label list per token.

    Whitespace characters of ``text`` (and their labels) are dropped, then the
    remaining labels are handed out token by token: a token consumes as many
    labels as it has characters once the ``'##'`` continuation marker is
    removed.

    Fixes over the previous version:
      * An unknown-token placeholder used to consume ``len('[UNK]') == 5``
        labels regardless of what it replaced, shifting the labels of every
        later token. Its width is now estimated from the text (see
        ``_unk_width``).
      * When ``label`` was longer than ``text`` the padding ``None`` character
        raised ``AttributeError``; such surplus labels are now ignored.

    Args:
        token_word: Tokens produced by the tokenizer for ``text``.
        text: The original string; it is iterated character by character.
        label: One label per character of ``text``.
        max_tokens: Maximum number of tokens to process. Defaults to the
            notebook-level ``max_len``.
        unk_token: The tokenizer's unknown-token placeholder.

    Returns:
        A list with one entry per processed token, each a (possibly empty)
        list of the labels of the characters that token covers. If ``text`` is
        longer than ``label`` the missing labels appear as ``None``.
    """
    limit = max_len if max_tokens is None else max_tokens
    tokens = list(token_word[:limit])

    chars = []
    label_clean = []
    for txt, lbl in zip_longest(text, label):
        # txt is None when label is longer than text; blank txt is whitespace.
        if txt is None or not txt.strip():
            continue
        lowered = txt.lower()
        # Tokens come from an uncased tokenizer, so compare case-insensitively,
        # but never let lowercasing change the character count.
        chars.append(lowered if len(lowered) == len(txt) else txt)
        label_clean.append(lbl)
    flat_text = ''.join(chars)

    pieces = [token.replace('##', '') for token in tokens]

    token_labels = []
    index = 0
    for pos, (token, piece) in enumerate(zip(tokens, pieces)):
        if token == unk_token:
            width = _unk_width(flat_text, index, tokens, pieces, pos, unk_token)
        else:
            width = len(piece)
        token_labels.append(label_clean[index:index + width])
        index += width
    return token_labels

In [14]:
def get_bio_tag(token_labels):
    """Convert per-token label lists into BIO tag strings.

    Only the first label of each token is used. Label ``0`` maps to ``'O'``;
    any other label gets a ``'B-'`` prefix when it differs from the previous
    token's label and an ``'I-'`` prefix when it repeats it. Two adjacent
    entities of the same class are therefore tagged as one span.

    A token with no labels (an empty list) inherits the label of the previous
    token. If it is the very first token it is treated as ``'O'``; previously
    that case raised ``UnboundLocalError`` because the label variable had
    never been assigned.

    Args:
        token_labels: One sequence of integer labels per token.

    Returns:
        A list of tag strings, one per entry of ``token_labels``.

    Raises:
        KeyError: If a label is not one of the keys of the label mapping.
    """
    label_dict = {0: 'O', 1: 'COMPANY', 2: 'DATE', 3: 'ADDRESS', 4: 'TOTAL'}
    token_label_bio = []
    current = 0
    temp_label = 0  # label assumed until the first non-empty token is seen
    for token_label in token_labels:
        # Explicit check instead of swallowing IndexError: an empty token
        # keeps the label carried over from the previous token.
        if len(token_label) > 0:
            temp_label = token_label[0]
        if temp_label == 0:
            token_label_bio.append(label_dict[temp_label])
        elif temp_label != current:
            token_label_bio.append('B-{}'.format(label_dict[temp_label]))
        else:
            token_label_bio.append('I-{}'.format(label_dict[temp_label]))
        current = temp_label
    return token_label_bio

In [15]:
def get_paired_token(text, label):
    """Build a two-column frame pairing each token of ``text`` with its BIO tag.

    Args:
        text: The document string.
        label: The per-character labels for ``text``.

    Returns:
        A ``pandas.DataFrame`` with the tokens in column 0 and the tags in
        column 1. The two lists are paired with ``zip_longest``, so if they
        differ in length the shorter side is padded with ``None``.
    """
    tokens = get_tokenized_word(text)
    bio_tags = get_bio_tag(get_token_labels(tokens, text, label))
    rows = zip_longest(tokens, bio_tags)
    return pd.DataFrame(rows)

In [16]:
def get_paired_token_text_label(texts, labels):
    """Build one token/tag frame per document, framed by sentinel rows.

    Each frame starts with a ``'-DOCSTART-'`` row, continues with the rows
    from ``get_paired_token(text, label)``, and ends with an empty-token row;
    both sentinel rows carry the tag ``'O'``. The index of every frame is
    renumbered from 0.

    The frame is assembled with a single ``pd.concat`` rather than repeated
    ``DataFrame.append`` calls: ``append`` is deprecated and no longer exists
    in pandas 2.0.

    Args:
        texts: Iterable of document strings.
        labels: Iterable of per-character label sequences, parallel to
            ``texts``. The two are paired with ``zip_longest``, so a length
            mismatch passes ``None`` for the missing side.

    Returns:
        A list of ``pandas.DataFrame`` objects, one per document.
    """
    # The sentinel rows are identical for every document; concat copies them.
    doc_start = pd.DataFrame([{0: '-DOCSTART-', 1: 'O'}])
    doc_end = pd.DataFrame([{0: '', 1: 'O'}])

    df_list = []
    for text, label in zip_longest(texts, labels):
        df = pd.concat(
            [doc_start, get_paired_token(text, label), doc_end],
            ignore_index=True,
        )
        df_list.append(df)
    return df_list

In [17]:
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load files from a trusted source.
# NOTE(review): this is an absolute, machine-specific path; the cell fails on
# any other machine until it is adjusted.
data_dict = torch.load('/home/long8v/ICDAR-2019-SROIE/task3/data/data_dict4.pth')
# Transpose the dict values: position 0 of every value goes into one tuple,
# position 1 into the next, and so on.
# Presumably each value is a (text, label) pair -- verify against the file.
zipped_data = list(zip_longest(*data_dict.values()))
texts, labels = zipped_data[0], zipped_data[1]

In [19]:
# Hold out a test set using scikit-learn's default split proportions.
# A fixed random_state makes the split identical on every run; without it a
# Restart & Run All produced different train/test files each time.
train_text, test_text, train_label, test_label = train_test_split(
    texts, labels, random_state=42
)

In [20]:
# Carve a validation set out of the training portion. A fixed random_state
# makes the split reproducible across runs.
# NOTE: this rebinds train_text / train_label, so re-running this cell on its
# own shrinks the training set again -- re-run the previous split first.
train_text, val_text, train_label, val_label = train_test_split(
    train_text, train_label, random_state=42
)

In [21]:
# Let pandas print up to 999 rows so long token/tag frames can be inspected.
pd.set_option('display.max_rows', 999)

In [23]:
# Build the token/tag frames for each split.
# Despite the *_df names, each of these is a list of DataFrames (one per
# document), as returned by get_paired_token_text_label.
train_df = get_paired_token_text_label(train_text, train_label)
val_df = get_paired_token_text_label(val_text, val_label)
test_df = get_paired_token_text_label(test_text, test_label)

In [24]:
from functools import reduce

In [25]:
# Stack the per-document frames of each split into one long frame.
# pd.concat accepts the whole list at once; folding it pairwise with reduce
# re-copied the growing frame on every step (quadratic in the document count).
# Row indices are kept as they are in the per-document frames.
train_df_long = pd.concat(train_df)
val_df_long = pd.concat(val_df)
test_df_long = pd.concat(test_df)

In [32]:
# Write each split as space-separated "token tag" lines, without the index
# column and without a header row.
# NOTE(review): the paths are relative to the working directory and the
# `data/` folder presumably has to exist already -- confirm before running.
train_df_long.to_csv('data/train.txt', sep=' ', index=False, header=False)
val_df_long.to_csv('data/valid.txt', sep=' ', index=False, header=False)
test_df_long.to_csv('data/test.txt', sep=' ', index=False, header=False)