
# NER with BERT

In [1]:
!pip install transformers seqeval
# !pip install -U tensorflow

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 14.7MB/s 
[?25hCollecting seqeval
[?25l  Downloading https://files.pythonhosted.org/packages/9d/2d/233c79d5b4e5ab1dbf111242299153f3caddddbb691219f363ad55ce783d/seqeval-1.2.2.tar.gz (43kB)
[K     |████████████████████████████████| 51kB 8.0MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 46.9MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████

In [2]:
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


In [None]:
# !git clone https://github.com/NVIDIA/apex
# !cd apex
# !pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./

## Import

In [3]:
import os
import re
import time

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import (
    BertTokenizer, BertForTokenClassification, AdamW,
    get_linear_schedule_with_warmup)
from seqeval.metrics import accuracy_score, classification_report
from tqdm import tqdm, tqdm_notebook
from jieba import cut

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

# from apex import amp  # for distributed training in Pytorch
# import psutil
# from multiprocessing import Pool

# NUM_CORES = psutil.cpu_count()  # number of cores on your machine
# print("number of cores:", NUM_CORES)

# def df_parallelize_run(df, func, num_partitions=20):
#     df_split = np.array_split(df, num_partitions)
#     pool = Pool(NUM_CORES)
#     df = pd.concat(pool.map(func, df_split))
#     #df = sp.vstack(pool.map(func, df_split), format="csr") faster and mem efficient for
#     pool.close()
#     pool.join()
#     return df

In [4]:
print("Pytorch Version: {}".format(torch.__version__))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.device_count() > 0:
    print("Found GPU at: {}".format(torch.cuda.get_device_name(0)))

DATA_DIR = "drive/MyDrive/Colab Notebooks/data/ner"

Pytorch Version: 1.8.1+cu101
Device: cuda
Found GPU at: Tesla T4


## Data preprocessing

In [5]:
def whitespace_punctuation(s):
    """Add whitespace before punctuation."""
    s = re.sub(r"([.,!?()])", r" \1 ", s)
    s = re.sub(r"\s{2,}", " ", s)
    return s


def word_tagging(raw_line):
    """"Tag each word for given line."""
    raw_line = re.sub(r"\t\d+\n", "", raw_line)
    line_split0 = re.split(r'<ENAMEX TYPE="(.*?)">(.*?)</ENAMEX>', raw_line)

    raw_tags = []
    line_split1 = []
    flag = 0
    for x in line_split0:
        if x not in ["PERSON", "ORGANIZATION", "LOCATION"]:
            if flag == 0:
                raw_tags.append("O")
            line_split1.append(x)
            flag = 0
        else:
            flag = 1
            raw_tags.append(x)

    line_split = []
    tags = []
    for x, t in zip(line_split1, raw_tags):
        y = whitespace_punctuation(x).split()
        line_split.extend(y)
        tags.extend([t] * len(y))
    return line_split, tags


class InputExample(object):
    """A single training/test example for simple sequence classification."""
    def __init__(self, guid, text_a, text_b=None, label=None):
        """Constructs a InputExample.
        Args:
            guid: Unique id for the example.
            text_a: string. The untokenized text of the first sequence. For single
            sequence tasks, only this sequence must be specified.
            text_b: (Optional) string. The untokenized text of the second sequence.
            Only must be specified for sequence pair tasks.
            label: (Optional) string. The label of the example. This should be
            specified for train and dev examples, but not for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


def load_data(file_path):
    """Load data."""
    data_examples = []
    with open(file_path, "r") as f:
        for i, raw_line in enumerate(f):
            line_split, tags = word_tagging(raw_line)
            data_examples.append(InputExample(i, " ".join(line_split), label=tags))
    return data_examples

In [6]:
train_examples = load_data(DATA_DIR + "/ner_train_data.txt")
print("Training data size =", len(train_examples))

Training data size = 1700


In [7]:
val_examples = load_data(DATA_DIR + "/ner_val_data.txt")
print("Validation data size =", len(val_examples))

Validation data size = 426


In [8]:
class InputFeatures(object):
    """A single set of features of data."""
    def __init__(self, input_ids, input_mask, segment_ids, label_id):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.label_id = label_id


def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer,
                                 print_examples=False):
    """Loads a data file into a list of `InputBatch`s."""
    label_map = {label: i for i, label in enumerate(label_list, 1)}

    features = []
    for ex_index, example in enumerate(examples):
        textlist = example.text_a.split(" ")
        labellist = example.label
        if labellist is None:
            labellist = ["O"] * len(textlist)
        tokens = []
        labels = []
        for i, word in enumerate(textlist):
            token = tokenizer.tokenize(word)
            tokens.extend(token)
            label_1 = labellist[i]
            for m in range(len(token)):
                if m == 0:
                    labels.append(label_1)
                else:
                    labels.append("X")
        if len(tokens) >= max_seq_length - 1:
            tokens = tokens[0:(max_seq_length - 2)]
            labels = labels[0:(max_seq_length - 2)]
        ntokens = []
        segment_ids = []
        label_id = []
        ntokens.append("[CLS]")
        segment_ids.append(0)
        label_id.append(label_map["[CLS]"])
        for i, token in enumerate(tokens):
            ntokens.append(token)
            segment_ids.append(0)
            label_id.append(label_map[labels[i]])
        ntokens.append("[SEP]")
        segment_ids.append(0)
        label_id.append(label_map["[SEP]"])
        input_ids = tokenizer.convert_tokens_to_ids(ntokens)
        input_mask = [1] * len(input_ids)
        while len(input_ids) < max_seq_length:
            input_ids.append(0)
            input_mask.append(0)
            segment_ids.append(0)
            label_id.append(0)
        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length
        assert len(label_id) == max_seq_length

        if print_examples and ex_index < 3:
            print("*** Example ***")
            print("tokens: %s" % " ".join([str(x) for x in tokens]))
            print("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            print("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            print("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))

        features.append(
            InputFeatures(input_ids=input_ids,
                          input_mask=input_mask,
                          segment_ids=segment_ids,
                          label_id=label_id))
    return features


In [9]:
PRETRAINED_MODEL_NAME = "bert-base-multilingual-cased"
MAX_SEQUENCE_LENGTH = 100
LABEL_LIST = ["O", "PERSON", "ORGANIZATION", "LOCATION", "[CLS]", "[SEP]", "X"]
REV_LABEL_MAP = {i: label for i, label in enumerate(LABEL_LIST, 1)}
NUM_LABELS = len(LABEL_LIST) + 1
EPOCHS = 4
BATCH_SIZE = 16
LR = 2e-5
WARMUP = 0.1
LOGGING_STEPS = 20
ACCUMULATION_STEPS = 1
FINETUNED_MODEL_PATH = "finetuned_bert.bin"

In [10]:
bert_tokenizer = BertTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, do_lower_case=False)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=995526.0, style=ProgressStyle(descripti…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=29.0, style=ProgressStyle(description_w…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1961828.0, style=ProgressStyle(descript…




In [11]:
train_features = convert_examples_to_features(
    train_examples, LABEL_LIST, MAX_SEQUENCE_LENGTH, bert_tokenizer)

In [12]:
val_features = convert_examples_to_features(
    val_examples, LABEL_LIST, MAX_SEQUENCE_LENGTH, bert_tokenizer)

In [13]:
def get_dataloader(data_features, batch_size, shuffle=False, drop_last=False):
    """Output dataloader."""
    data_input_ids = torch.tensor(
        [f.input_ids for f in data_features], dtype=torch.long)
    data_input_mask = torch.tensor(
        [f.input_mask for f in data_features], dtype=torch.long)
    data_segment_ids = torch.tensor(
        [f.segment_ids for f in data_features], dtype=torch.long)
    data_label_id = torch.tensor(
        [f.label_id for f in data_features], dtype=torch.long)
    data_dataset = torch.utils.data.TensorDataset(
        data_input_ids, data_input_mask, data_segment_ids, data_label_id)
    data_loader = DataLoader(
        data_dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
    return data_loader
    

train_loader = get_dataloader(train_features, BATCH_SIZE, shuffle=True)
val_loader = get_dataloader(val_features, BATCH_SIZE)

## Train model

In [14]:
model = BertForTokenClassification.from_pretrained(
    PRETRAINED_MODEL_NAME, num_labels=NUM_LABELS)

model.to(device)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=625.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=714314041.0, style=ProgressStyle(descri…




Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at 

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [15]:
def evaluate_model(model, val_loader, device):
    """Evaluate model."""
    val_loss = 0
    nb_val_steps = 0
    y_true = []
    y_pred = []
    for batch in val_loader:
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_id = batch

        with torch.no_grad():
            outputs = model(
                input_ids,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=label_id)
            loss, logits = outputs[:2]

        val_loss += loss.item()
        nb_val_steps += 1

        indices = torch.argmax(logits, dim=2).detach().cpu().numpy()
        input_mask = input_mask.to("cpu").numpy()
        label_id = label_id.to("cpu").numpy()

        for i, mask in enumerate(input_mask):
            tmp_true = []
            tmp_pred = []
            for j, m in enumerate(mask):
                if j == 0:
                    continue
                if m:
                    if REV_LABEL_MAP[label_id[i][j]] != "X":
                        tmp_true.append(REV_LABEL_MAP[label_id[i][j]])
                        tmp_pred.append(REV_LABEL_MAP[indices[i][j]])
                else:
                    tmp_true.pop()
                    tmp_pred.pop()
                    break
            y_true.append(tmp_true)
            y_pred.append(tmp_pred)

    val_loss /= nb_val_steps
    val_acc = accuracy_score(y_true, y_pred)
    return val_loss, val_acc, y_true, y_pred


def train_model(model, train_loader, val_loader, device):
    """Train model."""
    max_grad_norm = 1.0
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
         "weight_decay": 0.01},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
         "weight_decay": 0.0}
    ]

    num_total_steps = int(
        EPOCHS * len(train_loader) / ACCUMULATION_STEPS)
    num_warmup_steps = WARMUP * num_total_steps
    optimizer = AdamW(optimizer_grouped_parameters, lr=LR,
                      correct_bias=False)  # To reproduce BertAdam specific behavior set correct_bias=False
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps, num_total_steps)  # PyTorch scheduler

    best_loss = np.inf
    global_step = 0
    tr_loss = 0.0
    logging_loss = 0.0
    model.zero_grad()
    for epoch in range(EPOCHS):
        # TRAIN loop
        t0 = time.time()
        model.train()
        
        for step, batch in enumerate(train_loader):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_id = batch
            # forward pass
            outputs = model(
                input_ids,
                token_type_ids=segment_ids,
                attention_mask=input_mask,
                labels=label_id)
            
            loss = outputs[0]
            if ACCUMULATION_STEPS > 1:
                loss = loss / ACCUMULATION_STEPS

            # backward pass
            loss.backward()

            # Gradient clipping is not in AdamW anymore (so you can use amp without issue)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            tr_loss += loss.item()

            # update parameters
            if (step + 1) % ACCUMULATION_STEPS == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()  # same as optimizer.zero_grad()
                global_step += 1

            if global_step % LOGGING_STEPS == 0:
                loss_scalar = (tr_loss - logging_loss) / LOGGING_STEPS
                logging_loss = tr_loss
                print(f"Epoch {epoch + 1}: global step = {global_step}  train loss = {loss_scalar:.4f}")

        model.eval()
        val_loss, val_acc, _, _ = evaluate_model(model, val_loader, device)
        print(f"Epoch {epoch + 1}/{EPOCHS}: elapsed time = {time.time() - t0:.0f}s"
              f"  val loss = {val_loss:.4f}  val accuracy = {val_acc:.4f}")
        
        if val_loss < best_loss:
            # Save model artefact
            print(f"Epoch {epoch + 1}: val loss improved from {best_loss:.5f} to {val_loss:.5f}, "
                  f"saving model to {FINETUNED_MODEL_PATH}\n")
            best_loss = val_loss
            torch.save(model.state_dict(), FINETUNED_MODEL_PATH)
        else:
            print(f"Epoch {epoch + 1}: val loss did not improve from {best_loss:.5f}\n")


In [16]:
train_model(model, train_loader, val_loader, device)

Epoch 1: global step = 20  train loss = 1.3307
Epoch 1: global step = 40  train loss = 0.3001
Epoch 1: global step = 60  train loss = 0.1180
Epoch 1: global step = 80  train loss = 0.0831
Epoch 1: global step = 100  train loss = 0.0588
Epoch 1/4: elapsed time = 35s  val loss = 0.0771  val accuracy = 0.9612
Epoch 1: val loss improved from inf to 0.07708, saving model to finetuned_bert.bin

Epoch 2: global step = 120  train loss = 0.0589
Epoch 2: global step = 140  train loss = 0.0409
Epoch 2: global step = 160  train loss = 0.0416
Epoch 2: global step = 180  train loss = 0.0427
Epoch 2: global step = 200  train loss = 0.0372
Epoch 2/4: elapsed time = 35s  val loss = 0.0619  val accuracy = 0.9739
Epoch 2: val loss improved from 0.07708 to 0.06190, saving model to finetuned_bert.bin

Epoch 3: global step = 220  train loss = 0.0347
Epoch 3: global step = 240  train loss = 0.0187
Epoch 3: global step = 260  train loss = 0.0224
Epoch 3: global step = 280  train loss = 0.0175
Epoch 3: global 

In [23]:
# !cp finetuned_bert.bin drive/MyDrive/Colab\ Notebooks/models/finetuned_bert.bin

## Validation

In [17]:
model = BertForTokenClassification.from_pretrained(
    PRETRAINED_MODEL_NAME, num_labels=NUM_LABELS)
model.load_state_dict(torch.load(FINETUNED_MODEL_PATH))
model.to(device)
model.eval()

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at 

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [None]:
val_loader = get_dataloader(val_features, batch_size=BATCH_SIZE)

val_loss, val_acc, y_true, y_pred = evaluate_model(model, val_loader, device)
print(f"Val loss = {val_loss:.4f}  Val accuracy = {val_acc:.4f}")  # Val loss = 0.0664  Val accuracy = 0.9736

Val loss = 0.0588  Val accuracy = 0.9719


In [None]:
# with open("eval_results.txt", "w") as writer:
#     writer.write(classification_report(y_true, y_pred, digits=6))

## Batch scoring

In [18]:
def split_text(text, lang):
    """Split text."""
    if lang in ["zh-tw", "zh-cn"]:
        return [el for el in cut(text, cut_all=False) if el != " "]
    return whitespace_punctuation(text).split()


# def convert_text_to_example(text, lang=None):
#     """Convert text to input example."""
#     text_split = split_text(text, lang)
#     return InputExample(0, " ".join(text_split), label=["O"] * len(text_split))


def print_text_with_tags(text_split, tags):
    # 0 black
    # 1 red
    # 2 green
    # 3 yellow
    # 4 blue
    # 5 magenta
    # 6 cyan
    # 7 white
    # 9 default

    dict_background = {
        "PERSON": "\033[46m", # cyan
        "ORGANIZATION": "\033[43m", # yellow
        "LOCATION": "\033[45m" # magenta
    }
    for k, v in dict_background.items():
        print(v+k+"\033[49m")

    print_str = []
    for word, tag in zip(text_split, tags):
        c = dict_background.get(tag)
        if c is not None:
            print_str.append(c+word+"\033[49m")
        else:
            print_str.append(word)
        
    print(" ".join(print_str))
    return

In [None]:
def batch_score(model, test_loader, device):
    """Perform batch scoring."""
    y_pred = []
    for batch in test_loader:
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids, label_id = batch

        with torch.no_grad():
            outputs = model(input_ids, segment_ids, input_mask, labels=None)
            logits = outputs[0]
        
        indices = torch.argmax(logits, dim=2).detach().to("cpu").numpy()
        input_mask = input_mask.to("cpu").numpy()
        label_id = label_id.to("cpu").numpy()

        for i, mask in enumerate(input_mask):
            tmp_pred = []
            for j, m in enumerate(mask):
                if j == 0:
                    continue
                if m:
                    if REV_LABEL_MAP[label_id[i][j]] != "X":
                        tmp_pred.append(REV_LABEL_MAP[indices[i][j]])
                else:
                    tmp_pred.pop()
                    break
            y_pred.append(tmp_pred)
    return y_pred

In [None]:
test_df = pd.read_csv(DATA_DIR + "/ner_eval_data.csv")

print("Testing data size =", test_df.shape[0])

In [None]:
test_text_splits = []
test_examples = []
for i in range(test_df.shape[0]):
    text_split = split_text(test_df["title"].iloc[i], test_df["language_s"].iloc[i])
    test_text_splits.append(text_split)
    test_examples.append(
        InputExample(0, " ".join(text_split), label=["O"] * len(text_split)))

test_features = convert_examples_to_features(
    test_examples, LABEL_LIST, MAX_SEQUENCE_LENGTH, bert_tokenizer)

In [None]:
model = BertForTokenClassification.from_pretrained(
    PRETRAINED_MODEL_NAME, num_labels=NUM_LABELS)
model.load_state_dict(torch.load(FINETUNED_MODEL_PATH))
model.to(device)

for param in model.parameters():
    param.requires_grad = False
model.eval()

In [None]:
test_loader = get_dataloader(test_features, BATCH_SIZE)

y_pred = predict(model, test_loader, device)

## Serve

In [19]:
def predict(text, lang="en"):
    """Predict."""
    text_split = split_text(text, lang)
    sam_features = convert_examples_to_features(
        [InputExample(0, " ".join(text_split), label=["O"] * len(text_split))],
        LABEL_LIST, MAX_SEQUENCE_LENGTH, bert_tokenizer)
    
    sam_input_ids = torch.tensor([f.input_ids for f in sam_features], dtype=torch.long)
    sam_input_mask = torch.tensor([f.input_mask for f in sam_features], dtype=torch.long)
    sam_segment_ids = torch.tensor([f.segment_ids for f in sam_features], dtype=torch.long)

    with torch.no_grad():
        logits = model(sam_input_ids, sam_segment_ids, sam_input_mask, labels=None)[0]
        
    indices = torch.argmax(logits, dim=2).detach().numpy()[0]
    mask = sam_features[0].input_mask
    label_id = sam_features[0].label_id

    tags = []
    for j, m in enumerate(mask):
        if j == 0:
            continue
        if m:
            if REV_LABEL_MAP[label_id[j]] != "X":
                tags.append(REV_LABEL_MAP[indices[j]])
        else:
            tags.pop()
            break
    return text_split, tags

In [20]:
model = BertForTokenClassification.from_pretrained(
    PRETRAINED_MODEL_NAME, num_labels=NUM_LABELS)
model.load_state_dict(torch.load("finetuned_bert.bin", map_location=torch.device("cpu")))
model.eval()

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at 

BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(119547, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwi

In [21]:
# sample_text = "Mantan Gubernur DKI Jakarta Basuki Tjahaja Purnama atau Ahok menyatakan, dia heran Gubernur Anies Baswedan menerbitkan surat izin mendirikan bangunan (IMB) untuk bangunan di Pulau D, pulau hasil reklamasi, berdasarkan Peraturan Gubernur Nomor 206 Tahun 2016 yang dulu diteken Ahok."

# sample_text = "Presiden Joko Widodo membagikan video vlognya bersama Wakil Presiden Jusuf Kalla, di instagram @jokowi. Dalam video, Presiden Jokowi memberikan beberapa pertanyaan kepada Jusuf Kalla mengenai aktivitas selama lebaran. Presiden Jokowi tampak memegang sebuah ponsel berwarna hitam. Jokowi bertanya mengenai aktivitas Jusuf Kalla ketika lebaran. Jusuf Kalla kemudian menjelaskan kalau dirinya hanya istirahat di rumah." \
#     + "Jokowi kemudian melontarkan pertanyaan mengenai makanan favorit Jusuf Kalla saat merayakan lebaran. Jusuf Kalla mengaku rindu pada ayam opor, ketupat, dan sate. Ketika ditanya soal bermain dengan cucu, Jusuf Kalla menjelaskan dirinya memiliki 15 cucu dari kelima putra dan putrinya. Sementara melalui caption-nya, Jokowi menjelaskan bahwa dirinya menerima kedatangan Jusuf Kalla di istana negara, Jakarta. Kedatangan Jusuf Kalla tersebut bermaksud untuk membahas soal pekerjaan juga bercerita kegiatan dalam merayakan lebaran bersama keluarga."

# sample_text = "Wakil Presiden Jusuf Kalla memuji langkah safari lebaran yang dilakukan keluarga Ketua Kogasma (Komandan Tugas Bersama) Partai Demokrat Agus Harimukti Yudhoyono dan Eddhie Baskoro Yudhoyono (Ibbas), saat perayaan Hari Raya Idulfitri lalu, ke sejumlah tokoh besar bangsa Indonesia.Diketahui, keluarga generasi kedua dari presiden RI ke-6 Susilo Bambang Yudhoyono (SBY) itu, bertemu dengan \nPresiden Jokowi, Presiden RI ke-5 Megawati Soekarnoputri, serta Presiden RI ke-3 BJ.Habibie.Menurut JK, pertemuan tersebut, terjalin hangat dan memungkinkan mencairkan suasana politik nasional."Itu suatu hal yang baek sebenarnya apabila dalam kondisi lebaran untuk silaturahim dengan siapa saja, justru kita saling memaafkan, ya semua tau bahwa hubungan Bu Mega dengan SBY agak renggang kan. Jadi justru anaknya generasi keduanya bagus, berselfie ria, itu berarti mencairkan suasana politik nasional," kata JK di kantor Wapres RI, Jalan Medan Merdeka Utara, Jakarta Pusat, Selasa (11/6/2019).JK menilai, semua pertemuan tak melulu terkait politik.Dari pertemuan-pertemuan tersebut diharapkan, dapat merekatkan kembali silaturahmi antar tokoh bangsa."Jadi kita sambut baik pertemuan-pertemuan itu, jangan diliat hanya dari sisi politik tapi dari sisi hubungan-hubungan secara nasional," ungkap mantan ketum partai Golkar ini.Pertemuan AHY dan Megawati terlaksana pada 5 Juni 2019 di rumah Megawati, di Teuku Umar, Jakarta Pusat.Baik AHY, Ibbas dan Puan, terlihat hangat dengan berswa foto bersama."

# sample_text = "Ir. H. Joko Widodo atau Jokowi adalah Presiden ke-7 Indonesia yang mulai menjabat sejak 20 Oktober 2014. Ia terpilih bersama Wakil Presiden Muhammad Jusuf Kalla dalam Pemilu Presiden 2014 dan kembali terpilih bersama Wakil Presiden Ma'ruf Amin dalam Pemilu Presiden 2019. Jokowi pernah menjabat Gubernur DKI Jakarta sejak 15 Oktober 2012 hingga 16 Oktober 2014 didampingi Basuki Tjahaja Purnama sebagai wakil gubernur. Sebelumnya, ia adalah Wali Kota Surakarta (Solo), sejak 28 Juli 2005 hingga 1 Oktober 2012 didampingi F.X. Hadi Rudyatmo sebagai wakil wali kota. Dua tahun menjalani periode keduanya menjadi Wali Kota Solo, Jokowi ditunjuk oleh partainya, Partai Demokrasi Indonesia Perjuangan (PDI-P), untuk bertarung dalam pemilihan Gubernur DKI Jakarta berpasangan dengan Basuki Tjahaja Purnama (Ahok)."

sample_text = "Google, headquartered in Mountain View (1600 Amphitheatre Pkwy, Mountain View, CA 940430), unveiled the new Android phone for $799 at the Consumer Electronic Show. Sundar Pichai said in his keynote that users love their new Android phones."

In [22]:
text_split, tags = predict(sample_text)
print_text_with_tags(text_split, tags)

[46mPERSON[49m
[43mORGANIZATION[49m
[45mLOCATION[49m
[43mGoogle[49m , headquartered in [45mMountain[49m [45mView[49m ( [45m1600[49m [45mAmphitheatre[49m [45mPkwy[49m , [45mMountain[49m [45mView[49m , [45mCA[49m [45m940430[49m ) , unveiled the new Android phone for $799 at the Consumer Electronic Show . [46mSundar[49m [46mPichai[49m said in his keynote that users love their new Android phones .


In [None]:
sample_text = "腾讯科技股份有限公司是中國大陸规模最大的互联网公司，1998年11月由马化腾、张志东、陈一丹、许晨晔、曾李青5位创始人共同创立，總部位於深圳南山区騰訊濱海大廈。腾讯业务拓展至社交、娱乐、金融、资讯、工具和平台等不同领域。目前，腾讯拥有中国大陸使用人数最多的社交软件腾讯QQ和微信，以及最大的网络游戏社区腾讯游戏。在電子書領域 ，旗下有閱文集團，運營有QQ讀書和微信讀書。"

# sample_text = "近日，韩国男团GOT7的成员Jackson王嘉尔（嘎嘎）参加朋友的婚礼，一组迎娶婚礼,伴郎,王嘉尔,彭于晏,Jackson,胡歌,"

In [None]:
text_split, tags = predict(sample_text, "zh-cn")
print_text_with_tags(text_split, tags)

[46mPERSON[49m
[43mORGANIZATION[49m
[45mLOCATION[49m
[43m腾讯[49m [43m科技股份[49m [43m有限公司[49m 是 [45m中國大陸[49m 规模 最大 的 互联网 公司 ， 1998 年 11 月 由 [46m马化腾[49m 、 [46m张志东[49m 、 [46m陈一丹[49m 、 [46m许晨晔[49m 、 [46m曾[49m [46m李青[49m 5 位 创始人 共同 创立 ， 總 部位 於 [45m深圳[49m [45m南山区[49m [43m騰訊濱[49m 海大 [45m廈[49m 。 [43m腾讯[49m 业务 拓展 至 社交 、 娱乐 、 金融 、 资讯 、 工具 和 平台
