In [None]:
# Enable IPython's autoreload extension in mode 2 so that imported modules are
# re-imported before each cell runs and edits to local .py files are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import random
from torch import optim
from tqdm import tqdm

import numpy as np
import pandas as pd
import string
import torch
import re

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from transformers import BertConfig, BertTokenizer
from nltk.tokenize import word_tokenize

In [13]:
class NerIDCardDataset(Dataset):
    """Word-level NER dataset built from a CSV of tagged words.

    The CSV is read with ``pd.read_csv`` and must provide three columns:
    ``sentence_idx`` (rows sharing a value belong to the same sentence),
    ``word`` and ``tag`` (one of :attr:`LABELS`).

    Each item is a 4-tuple ``(subwords, subword_to_word_indices, seq_label,
    sentence)``; see :meth:`__getitem__` for details.
    """

    # Closed set of tags; a tag's position in this list is its integer id.
    LABELS = [
        'U-FLD_PROV', 'B-VAL_PROV', 'L-VAL_PROV', 'U-FLD_KAB', 'U-VAL_KAB',
        'U-FLD_NIK', 'U-VAL_NIK', 'U-FLD_NAMA', 'B-VAL_NAMA', 'L-VAL_NAMA',
        'B-FLD_TTL', 'L-FLD_TTL', 'B-VAL_TTL', 'L-VAL_TTL', 'B-FLD_GDR',
        'L-FLD_GDR', 'U-VAL_GDR', 'B-FLD_GLD', 'L-FLD_GLD', 'U-VAL_GLD',
        'U-FLD_ADR', 'B-VAL_ADR', 'I-VAL_ADR', 'L-VAL_ADR', 'U-FLD_RTW',
        'U-VAL_RTW', 'U-FLD_KLH', 'U-VAL_KLH', 'U-FLD_KCM', 'U-VAL_KCM',
        'U-FLD_RLG', 'U-VAL_RLG', 'B-FLD_KWN', 'L-FLD_KWN', 'B-VAL_KWN',
        'L-VAL_KWN', 'U-FLD_KRJ', 'U-VAL_KRJ', 'U-FLD_WRG', 'U-VAL_WRG',
        'B-FLD_BLK', 'L-FLD_BLK', 'B-VAL_BLK', 'L-VAL_BLK', 'U-VAL_SGP',
        'U-VAL_SGD', 'B-VAL_KAB', 'L-VAL_KAB', 'U-VAL_NAMA', 'B-VAL_KLH',
        'L-VAL_KLH', 'B-VAL_KRJ', 'I-VAL_KRJ', 'L-VAL_KRJ', 'B-VAL_SGP',
        'L-VAL_SGP', 'I-VAL_TTL', 'L-VAL_KCM', 'B-VAL_KCM', 'U-VAL_KWN',
        'U-VAL_PROV', 'I-VAL_NAMA', 'I-VAL_PROV', 'I-VAL_KAB', 'I-VAL_KCM',
        'I-VAL_SGP', 'U-VAL_ADR', 'I-VAL_KLH', 'O'
    ]

    # Lookup tables derived from LABELS (tag -> id and id -> tag).
    LABEL2INDEX = {label: idx for idx, label in enumerate(LABELS)}
    INDEX2LABEL = dict(enumerate(LABELS))
    NUM_LABELS = len(LABELS)

    def __init__(self, dataset_path, tokenizer, *args, **kwargs):
        """Load the CSV at ``dataset_path`` and remember the tokenizer.

        Args:
            dataset_path: Anything ``pd.read_csv`` accepts (path or buffer).
            tokenizer: Object exposing ``cls_token_id``, ``sep_token_id`` and
                ``encode(word, add_special_tokens=False)``.
            *args, **kwargs: Accepted for signature compatibility; unused.
        """
        self.data = self.load_dataset(dataset_path)
        self.tokenizer = tokenizer

    def load_dataset(self, path):
        """Read the CSV and group its rows into sentences.

        Rows are grouped by their ``sentence_idx`` value, so the ids do not
        have to be 0-based or contiguous. Sentences are returned in ascending
        id order and the words of a sentence keep their order from the file.

        Args:
            path: Anything ``pd.read_csv`` accepts (path or buffer).

        Returns:
            A list with one dict per sentence:
            ``{'sentence': [str, ...], 'seq_label': [int, ...]}`` where
            ``seq_label`` holds the :attr:`LABEL2INDEX` id of every word's tag.

        Raises:
            KeyError: If a tag in the file is not listed in :attr:`LABELS`.
        """
        dframe = pd.read_csv(path)

        dataset = []
        # A single groupby pass replaces "filter the whole frame once per
        # sentence id in range(n)": that form silently produced empty or
        # missing sentences whenever the ids were not exactly 0..n-1.
        for _, sframe in dframe.groupby('sentence_idx'):
            dataset.append({
                'sentence': [str(word) for word in sframe['word']],
                'seq_label': [self.LABEL2INDEX[str(tag)] for tag in sframe['tag']],
            })
        return dataset

    def __getitem__(self, index):
        """Tokenize sentence ``index`` into subword ids.

        Returns:
            A 4-tuple of

            * ``np.ndarray`` of subword ids, wrapped in the tokenizer's CLS
              and SEP ids;
            * ``np.ndarray`` of the same length mapping each subword to the
              index of the word it came from (``-1`` for CLS and SEP);
            * ``np.ndarray`` of label ids, one per word (not per subword);
            * the original list of words.
        """
        data = self.data[index]
        sentence, seq_label = data['sentence'], data['seq_label']

        # Leading CLS token; -1 marks "belongs to no word".
        subwords = [self.tokenizer.cls_token_id]
        subword_to_word_indices = [-1]

        # Tokenize word by word so every subword can be traced to its word.
        for word_idx, word in enumerate(sentence):
            subword_list = self.tokenizer.encode(word, add_special_tokens=False)
            subword_to_word_indices.extend([word_idx] * len(subword_list))
            subwords.extend(subword_list)

        # Trailing SEP token.
        subwords.append(self.tokenizer.sep_token_id)
        subword_to_word_indices.append(-1)

        return np.array(subwords), np.array(subword_to_word_indices), np.array(seq_label), sentence

    def __len__(self):
        """Return the number of sentences loaded from the CSV."""
        return len(self.data)
        
class NerDataLoader(DataLoader):
    """DataLoader that pads NER samples into dense NumPy batches.

    Samples are expected to be the 4-tuples
    ``(subwords, subword_to_word_indices, seq_label, raw_sentence)``.

    Note that ``max_seq_len`` is the first positional parameter, so the
    dataset and all other ``DataLoader`` options should be passed by keyword.
    """

    def __init__(self, max_seq_len=512, *args, **kwargs):
        """Build the loader and install the padding collate function.

        Args:
            max_seq_len: Upper bound on the subword dimension of a batch.
            *args, **kwargs: Forwarded unchanged to ``DataLoader.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.max_seq_len = max_seq_len
        # Installed after the base initialiser, so it replaces whatever
        # collate function the base class set up.
        self.collate_fn = self._collate_fn

    def _collate_fn(self, batch):
        """Pad (and, for subwords, truncate) a list of samples into arrays.

        Returns:
            A 5-tuple of

            * subword ids, ``int64``, padded with ``0``;
            * attention mask, ``float32``, ``1`` on real subwords, else ``0``;
            * subword-to-word indices, ``int64``, padded with ``-1``;
            * label ids, ``int64``, padded with ``-100``;
            * the list of raw sentences.

            The subword arrays are as wide as the longest sample, capped at
            ``self.max_seq_len``. The label array is as wide as the longest
            label sequence and is not truncated.
        """
        n_samples = len(batch)
        longest_input = max(len(sample[0]) for sample in batch)
        width = min(self.max_seq_len, longest_input)
        label_width = max(len(sample[2]) for sample in batch)

        input_shape = (n_samples, width)
        subword_batch = np.zeros(input_shape, dtype=np.int64)
        mask_batch = np.zeros(input_shape, dtype=np.float32)
        subword_to_word_indices_batch = np.full(input_shape, -1, dtype=np.int64)
        seq_label_batch = np.full((n_samples, label_width), -100, dtype=np.int64)

        seq_list = []
        for row, sample in enumerate(batch):
            token_ids, word_indices, labels, raw_seq = sample

            # Cut over-long inputs down to the batch width.
            token_ids = token_ids[:width]
            word_indices = word_indices[:width]
            used = len(token_ids)

            subword_batch[row, :used] = token_ids
            mask_batch[row, :used] = 1
            subword_to_word_indices_batch[row, :used] = word_indices
            seq_label_batch[row, :len(labels)] = labels

            seq_list.append(raw_seq)

        return subword_batch, mask_batch, subword_to_word_indices_batch, seq_label_batch, seq_list

In [14]:
# CSV with the tagged words and the pretrained checkpoint whose tokenizer is used.
dataset_path = 'data/idcard/ktp_ner_dataset.csv'
# NOTE(review): 'bert-base-uncased' is an English checkpoint — confirm it is the
# intended tokenizer for this dataset.
pretrained_model = 'bert-base-uncased'
# Load the tokenizer, then build the dataset (the CSV is read eagerly in __init__).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
trainset = NerIDCardDataset(dataset_path, tokenizer)

In [16]:
# NerIDCardDataset.LABEL2INDEX

In [17]:
# Sanity check: show the first sample together with the length of each field.
first_sample = next(iter(trainset), None)
if first_sample is not None:
    subwords, subword_to_word_indices, seq_label, sentence = first_sample
    print(subwords, len(subwords))
    print(subword_to_word_indices, len(subword_to_word_indices))
    print(seq_label, len(seq_label))
    print(sentence, len(sentence))

[  101  4013  6371  5332  5730  2050  5199  3126 10556  8569 17585  2078
  8945 14339 20265  3217 23205 28906 17465 21057 12740  2575  2683 21057
 14142  2629 15125  2050  9152  3490  6520 22734  3736 18780  2527  8915
  8737  4017  1013  1056 23296  2474 11961  8945 14339 20265  3217  1010
  5840  1011  5757  1011  2639 15419  2483 17710 10278  2378 23976  8737
 13860  2175  2140  1012 18243  4430  1051 26234  4017  1046  2140  3103
  3334  9152  2099 21761  2004  3089 12849  8737  2140  9152  2099 21761
  3103  3334  9097  1038  2140  1037  1013 10715 19387  1013  1054  2860
  5890  2549  1013  5890  2487 17710  2140  1013  4078  2050 17710 11735
  8490  5575 17710 28727  6790  2078 11687 18222  2078 12943  8067  7025
  3570  2566  2912 10105  2319  8292 14995 13523  2072 21877  5484  3900
  2319  9004  7088  1013 21877  3489  8569  2078 17710  9028  5289 29107
  2527  2319  1059  3490  2022 12190  4817  2226  7632  3070  3654  7367
  2819  3126 11041  6279  8945 14339 20265  3217  5

In [21]:
# Batches of 32 sentences, loaded in the main process (num_workers=0).
# `dataset` is passed by keyword because NerDataLoader's first positional
# parameter is `max_seq_len`, not the dataset.
loader = NerDataLoader(dataset=trainset, batch_size=32, num_workers=0)

In [20]:
# %%time
# for i, (subwords, mask, subword_to_word_indices, seq_label, seq_list) in enumerate(loader):
#     print(subwords, mask, subword_to_word_indices, seq_label, seq_list)
#     if i == 2:
#         break