In [2]:
%load_ext autoreload
%autoreload 2

In [6]:
import random
from torch import optim
from tqdm import tqdm

import numpy as np
import pandas as pd
import string
import torch
import re

from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

from transformers import BertConfig, BertTokenizer
from nltk.tokenize import word_tokenize

In [10]:
class NerGritDataset(Dataset):
    """Word-level NER dataset read from a tab-separated text file.

    Every non-blank line of the file holds ``token<TAB>label``; a blank (or
    whitespace-only) line marks the end of a sentence. Each sample pairs the
    tokens of one sentence with the integer ids of their labels, and
    ``__getitem__`` additionally tokenizes the words into subword ids.
    """

    # Static constant variable
    LABEL2INDEX = {'I-PERSON': 0, 'B-ORGANISATION': 1, 'I-ORGANISATION': 2, 'B-PLACE': 3, 'I-PLACE': 4, 'O': 5, 'B-PERSON': 6}
    # Derived from LABEL2INDEX so the two mappings can never drift apart.
    INDEX2LABEL = {index: label for label, index in LABEL2INDEX.items()}
    NUM_LABELS = len(LABEL2INDEX)

    def load_dataset(self, path):
        """Parse the file at ``path`` into a list of sentence records.

        Args:
            path: Location of the ``token<TAB>label`` file.

        Returns:
            A list of dicts with keys ``'sentence'`` (list of token strings)
            and ``'seq_label'`` (list of label ids from ``LABEL2INDEX``).
            Runs of blank lines never produce empty records, and a final
            sentence that is not followed by a blank line is still returned.

        Raises:
            ValueError: If a non-blank line does not split into exactly two
                tab-separated fields.
            KeyError: If a label is not a key of ``LABEL2INDEX``.
        """
        dataset = []
        sentence = []
        seq_label = []

        # The context manager closes the handle deterministically; the
        # encoding is pinned so parsing does not depend on the host locale.
        with open(path, 'r', encoding='utf-8') as handle:
            for line in handle:
                if line.strip():
                    # Strip only the line terminator: slicing off the last
                    # character would corrupt the label of a final line that
                    # has no trailing newline.
                    token, label = line.rstrip('\r\n').split('\t')
                    sentence.append(token)
                    seq_label.append(self.LABEL2INDEX[label])
                elif sentence:
                    # Sentence boundary; skipped while the buffer is empty so
                    # consecutive blank lines do not create empty samples.
                    dataset.append({
                        'sentence': sentence,
                        'seq_label': seq_label
                    })
                    sentence = []
                    seq_label = []

        # Flush the last sentence when the file does not end with a blank line.
        if sentence:
            dataset.append({
                'sentence': sentence,
                'seq_label': seq_label
            })
        return dataset

    def __init__(self, dataset_path, tokenizer, *args, **kwargs):
        """Load the file eagerly and keep the tokenizer for ``__getitem__``.

        Args:
            dataset_path: Path handed to ``load_dataset``.
            tokenizer: Object exposing ``cls_token_id``, ``sep_token_id`` and
                ``encode(word, add_special_tokens=False)``.
            *args, **kwargs: Accepted and ignored.
        """
        self.data = self.load_dataset(dataset_path)
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        """Return the subword encoding of the sentence at ``index``.

        Returns:
            A 4-tuple ``(subwords, subword_to_word_indices, seq_label, sentence)``:
            the subword ids wrapped in the tokenizer's CLS/SEP ids, for each
            subword the index of the word it came from (-1 for CLS and SEP),
            the per-word label ids, and the original list of word strings.
            The first three are numpy arrays.
        """
        data = self.data[index]
        sentence, seq_label = data['sentence'], data['seq_label']

        # Add CLS token
        subwords = [self.tokenizer.cls_token_id]
        subword_to_word_indices = [-1]  # For CLS

        # Add subwords, remembering which word every subword belongs to
        for word_idx, word in enumerate(sentence):
            subword_list = self.tokenizer.encode(word, add_special_tokens=False)
            subword_to_word_indices += [word_idx] * len(subword_list)
            subwords += subword_list

        # Add last SEP token
        subwords += [self.tokenizer.sep_token_id]
        subword_to_word_indices += [-1]

        return np.array(subwords), np.array(subword_to_word_indices), np.array(seq_label), data['sentence']

    def __len__(self):
        """Number of sentences loaded from the file."""
        return len(self.data)
        
class NerDataLoader(DataLoader):
    """DataLoader that pads NER samples into fixed-shape numpy batches.

    Each dataset item is expected to be a 4-tuple
    ``(subwords, subword_to_word_indices, seq_label, raw_seq)``, as produced
    by ``NerGritDataset.__getitem__``.

    Note that ``max_seq_len`` is the first positional parameter, so the
    dataset and every other ``DataLoader`` option should be passed by keyword.
    """

    def __init__(self, max_seq_len=512, *args, **kwargs):
        """Create the loader.

        Args:
            max_seq_len: Upper bound on the number of subwords kept per sample.
            *args, **kwargs: Forwarded unchanged to ``DataLoader.__init__``.
                The collate function is always replaced by ``_collate_fn``.
        """
        super(NerDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = self._collate_fn
        self.max_seq_len = max_seq_len

    def _collate_fn(self, batch):
        """Pad (and, if necessary, truncate) a list of samples into arrays.

        Args:
            batch: Non-empty list of dataset items (see class docstring).

        Returns:
            A 5-tuple ``(subword_batch, mask_batch,
            subword_to_word_indices_batch, seq_label_batch, seq_list)``:

            * ``subword_batch`` - int64, subword ids padded with 0.
            * ``mask_batch`` - float32, 1 on real subwords and 0 on padding.
            * ``subword_to_word_indices_batch`` - int64, padded with -1.
            * ``seq_label_batch`` - int64, per-word labels padded with -100
              (the default ``ignore_index`` of torch's cross-entropy loss;
              the loss actually used is not visible here).
            * ``seq_list`` - the raw word lists, in batch order.

            When a sample is longer than ``max_seq_len`` its last kept
            position is overwritten with the sample's own final subword and
            word index (the closing special token for ``NerGritDataset``
            items), and the labels of words that no longer have any subword
            in the input are set to -100. This assumes word indices are
            non-decreasing along the sequence.
        """
        batch_size = len(batch)
        longest_seq = max(len(item[0]) for item in batch)
        max_seq_len = min(self.max_seq_len, longest_seq)
        max_tgt_len = max(len(item[2]) for item in batch)

        subword_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        mask_batch = np.zeros((batch_size, max_seq_len), dtype=np.float32)
        subword_to_word_indices_batch = np.full((batch_size, max_seq_len), -1, dtype=np.int64)
        seq_label_batch = np.full((batch_size, max_tgt_len), -100, dtype=np.int64)

        seq_list = []
        for i, (subwords, subword_to_word_indices, seq_label, raw_seq) in enumerate(batch):
            kept = min(len(subwords), max_seq_len)

            subword_batch[i, :kept] = subwords[:kept]
            mask_batch[i, :kept] = 1
            subword_to_word_indices_batch[i, :kept] = subword_to_word_indices[:kept]
            seq_label_batch[i, :len(seq_label)] = seq_label

            if kept < len(subwords):
                # Truncation cut off the tail of the sample. Only the batch
                # arrays are written, so the dataset's arrays stay untouched.
                if kept >= 2:
                    # Put the closing token back in the last kept slot.
                    subword_batch[i, kept - 1] = subwords[-1]
                    subword_to_word_indices_batch[i, kept - 1] = subword_to_word_indices[-1]
                # Words with no surviving subword must not contribute a label.
                if kept:
                    surviving_words = int(subword_to_word_indices_batch[i, :kept].max()) + 1
                else:
                    surviving_words = 0
                seq_label_batch[i, surviving_words:] = -100

            seq_list.append(raw_seq)

        return subword_batch, mask_batch, subword_to_word_indices_batch, seq_label_batch, seq_list

In [22]:
# Location of the tab-separated training split, relative to the working directory.
dataset_path = 'data/nergrit_ner-grit/train_preprocess.txt'
# Checkpoint name handed to AutoTokenizer.
# NOTE(review): 'bert-base-uncased' is presumably an English, lower-casing
# vocabulary, while the sentences printed by the next cell look Indonesian —
# confirm this is the intended tokenizer.
pretrained_model = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
# Parses the whole file eagerly; tokenization happens per item on access.
trainset = NerGritDataset(dataset_path, tokenizer)


In [23]:
# Sanity check: show the four fields of the first sample only, one per line.
for first_sample in trainset:
    subwords, subword_to_word_indices, seq_label, sentence = first_sample
    print(subwords, subword_to_word_indices, seq_label, sentence, sep='\n')
    break

[  101 12849  3372  3089  8286 24300  2050 28774 16102  9331 27746 18886
 28076 10093  4430 27955 24237  5313  9126  7221  3148  2243  3653  9153
  5332  2744  3022  6968 12967  8922  2982  1010 14262  2696  4241  2050
 20252  2015  2053 22311  5332  1025  4241  2050 17323  2088  2636  1025
  4907  7279  9103  7911 10695  3148 16510  2121 23630 11905  2078  7367
 23615  2906  4185 18414  2696  2128 27052  2319  1012   102]
[-1  0  0  0  0  0  0  1  1  1  2  2  3  4  4  5  5  5  5  6  6  6  7  7
  7  8  8  8  9 10 11 12 13 13 14 14 15 15 16 16 16 17 18 18 19 20 21 22
 23 24 24 24 24 24 25 25 25 25 25 26 26 26 27 28 28 29 29 29 30 -1]
[5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5 5]
['Kontribusinya', 'terhadap', 'industri', 'musik', 'telah', 'mengumpulkan', 'banyak', 'prestasi', 'termasuk', 'lima', 'Grammy', 'Awards', ',', 'serta', 'dua', 'belas', 'nominasi', ';', 'dua', 'Guinness', 'World', 'Records', ';', 'dan', 'penjualannya', 'diperkirakan', 'sekitar', '64', 'juta', 'r

In [24]:
# Batch the training set. `dataset` is passed by keyword because the first
# positional parameter of NerDataLoader is `max_seq_len`, not the dataset.
# NOTE(review): num_workers=32 is hard-coded — confirm the machine running
# this notebook has that many CPU cores to spare.
loader = NerDataLoader(dataset=trainset, batch_size=32, num_workers=32)

In [29]:
%%time
# Smoke-test the loader: print the first three collated batches, then stop.
for i, collated in enumerate(loader):
    subwords, mask, subword_to_word_indices, seq_label, seq_list = collated
    print(*collated)
    if i >= 2:
        break

[[  101 12849  3372 ...     0     0     0]
 [  101  7367  3672 ...     0     0     0]
 [  101 10556 17311 ...     0     0     0]
 ...
 [  101 11687  2050 ...     0     0     0]
 [  101 27829 14289 ...     0     0     0]
 [  101  3520  2050 ...     0     0     0]] [[1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 ...
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]
 [1. 1. 1. ... 0. 0. 0.]] [[-1  0  0 ... -1 -1 -1]
 [-1  0  0 ... -1 -1 -1]
 [-1  0  0 ... -1 -1 -1]
 ...
 [-1  0  0 ... -1 -1 -1]
 [-1  0  0 ... -1 -1 -1]
 [-1  0  0 ... -1 -1 -1]] [[   5    5    5 ... -100 -100 -100]
 [   5    5    5 ... -100 -100 -100]
 [   5    3    5 ... -100 -100 -100]
 ...
 [   5    5    5 ... -100 -100 -100]
 [   3    4    5 ... -100 -100 -100]
 [   5    5    5 ... -100 -100 -100]] [['Kontribusinya', 'terhadap', 'industri', 'musik', 'telah', 'mengumpulkan', 'banyak', 'prestasi', 'termasuk', 'lima', 'Grammy', 'Awards', ',', 'serta', 'dua', 'belas', 'nominasi', ';', 'dua', 'Gu