In [1]:
import os

import sys

import numpy as np
import conlleval

from common import encode, label_encode, write_result
from common import load_pretrained
from common import create_ner_model, create_optimizer, argument_parser
from common import read_conll, process_sentences, get_labels
from common import save_ner_model


In [8]:
class arguments:
    """Hard-coded stand-in for the command-line options of the training script.

    The values are class attributes; ``args`` below is an instance, so attribute
    lookups on it fall through to these values. Paths are derived from
    ``datadir`` / ``modeldir`` so switching dataset or checkpoint means editing
    a single line.
    """

    # Dataset files, read with common.read_conll.
    datadir = "./data/finer-news"
    train_data = datadir + "/train.tsv"
    dev_data = datadir + "/dev.tsv"  # not referenced directly by this notebook
    test_data = datadir + "/test.tsv"
    # Where the trained NER model is saved; set to None to skip saving.
    ner_model_dir = "./finer-news-model"

    # Pretrained BERT checkpoint and its companion files.
    modeldir = "./models/bert-base-finnish-cased-v1"
    init_checkpoint = modeldir + "/bert_model.ckpt"
    vocab_file = modeldir + "/vocab.txt"
    bert_config_file = modeldir + "/bert_config.json"

    # Training / inference settings (presumably consumed by common helpers
    # such as create_optimizer, which receive ``args``).
    batch_size = 4
    learning_rate = 5e-5
    max_seq_length = 512
    num_train_epochs = 1
    warmup_proportion = 0.1
    # File that write_result writes the test-set predictions to.
    output_file = "./output.tsv"
    # False because the checkpoint name indicates a cased model.
    do_lower_case = False

args = arguments()

In [6]:
# Fail fast on a wrong data path: the next cell loads the pretrained model
# before reading the TSV files, so a missing file would otherwise only surface
# after that (potentially slow) step.
missing_data_files = [
    path for path in (args.train_data, args.test_data)
    if not os.path.isfile(path)
]
if missing_data_files:
    raise FileNotFoundError(
        "data file(s) not found: " + ", ".join(missing_data_files))
args.datadir

'./data/finer-news'

In [9]:
# Load the pretrained model, read and tokenize the train/test splits, build the
# tag vocabulary, encode inputs and targets, and compile the NER model.
seq_len = args.max_seq_length    # abbreviation

pretrained_model, tokenizer = load_pretrained(args)

# Read words and tags from the TSV files, then tokenize the sentences with the
# BERT tokenizer (sequence length presumably capped at seq_len by
# process_sentences — confirm in common.py).
train_words, train_tags = read_conll(args.train_data)
test_words, test_tags = read_conll(args.test_data)
train_data = process_sentences(train_words, train_tags, tokenizer, seq_len)
test_data = process_sentences(test_words, test_tags, tokenizer, seq_len)

# The tag vocabulary comes from the TRAINING labels only.
# NOTE(review): a tag that occurs in test but never in train has no entry in
# tag_map; confirm how label_encode handles that (possibly a KeyError).
label_list = get_labels(train_data.labels)
# tag string -> integer id (ids are the positions in label_list)
tag_map = { l: i for i, l in enumerate(label_list) }
# integer id -> tag string, used later to decode predictions
inv_tag_map = { v: k for k, v in tag_map.items() }

# Model inputs; their exact structure is defined by common.encode.
train_x = encode(train_data.combined_tokens, tokenizer, seq_len)
test_x = encode(test_data.combined_tokens, tokenizer, seq_len)

# Integer targets plus weights that are passed to fit() as temporal (per
# position) sample weights. test_y / test_weights are not referenced by the
# later cells shown here.
train_y, train_weights = label_encode(
    train_data.combined_labels, tag_map, seq_len)
test_y, test_weights = label_encode(
    test_data.combined_labels, tag_map, seq_len)

# Tagging model on top of the pretrained encoder, sized by the number of tags.
ner_model = create_ner_model(pretrained_model, len(tag_map))
# NOTE(review): len(train_x[0]) is presumably the number of training examples
# (train_x likely being [token_ids, segment_ids]), used to size the LR
# schedule — confirm against common.create_optimizer.
optimizer = create_optimizer(len(train_x[0]), args)

# Integer (sparse) targets; sample_weight_mode='temporal' lets fit() accept
# one weight per sequence position instead of one per example.
ner_model.compile(
    optimizer,
    loss='sparse_categorical_crossentropy',
    sample_weight_mode='temporal',
    metrics=['sparse_categorical_accuracy']
)


W0114 12:01:00.234465 140345459726144 deprecation_wrapper.py:119] From /home/joffe/projektit/keras-bert-ner/bert/tokenization.py:125: The name tf.gfile.GFile is deprecated. Please use tf.io.gfile.GFile instead.



In [10]:
# Train with per-position sample weights. Keep the returned History object
# (history.history maps each metric to its per-epoch values) so training
# curves can be inspected later, instead of just echoing its repr.
history = ner_model.fit(
    train_x,
    train_y,
    sample_weight=train_weights,
    epochs=args.num_train_epochs,
    batch_size=args.batch_size
)

W0114 12:01:45.013004 140345459726144 deprecation.py:323] From /home/joffe/anaconda3/envs/kerasner/lib/python3.7/site-packages/tensorflow/python/ops/math_grad.py:1250: add_dispatch_support.<locals>.wrapper (from tensorflow.python.ops.array_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.where in 2.0, which has the same broadcast rule as np.where




<tensorflow.python.keras.callbacks.History at 0x7fa05068cc50>

In [11]:
# Persist the trained model only when an output directory is configured.
if args.ner_model_dir is not None:
    # Tag strings ordered by their integer id.
    label_list = [inv_tag_map[tag_id] for tag_id in sorted(inv_tag_map)]
    save_ner_model(ner_model, tokenizer, label_list, args)

In [12]:
# Model scores for the encoded test set; the highest-scoring tag id is taken
# along the last axis (presumably the tag dimension).
probs = ner_model.predict(test_x, batch_size=args.batch_size)
preds = np.asarray(probs).argmax(axis=-1)

In [13]:
# Decode predicted tag ids back to tag strings, one list per test sentence.
# Prediction i is assumed to belong to test sentence i, so fail loudly if the
# counts disagree instead of misaligning results passed to write_result.
if len(preds) != len(test_data.tokens):
    raise ValueError(
        f"got {len(preds)} predictions for {len(test_data.tokens)} test sentences")

# Position 0 of each example is skipped (presumably the [CLS] token), and only
# as many positions as the sentence has tokens are kept (dropping padding and
# anything that follows the sentence).
pred_tags = [
    [inv_tag_map[tag_id] for tag_id in pred[1:len(sentence_tokens) + 1]]
    for pred, sentence_tokens in zip(preds, test_data.tokens)
]

In [14]:
# Write the test-set results to args.output_file; the returned lines are what
# gets scored below.
lines = write_result(
    args.output_file,
    test_data.words,
    test_data.lengths,
    test_data.tokens,
    test_data.labels,
    pred_tags,
)

# Score with conlleval and print its overall and per-type report.
c = conlleval.evaluate(lines)
conlleval.report(c)

processed 46363 tokens with 4124 phrases; found: 4216 phrases; correct: 3860.
accuracy:  98.76%; precision:  91.56%; recall:  93.60%; FB1:  92.57
             DATE: precision:  97.08%; recall:  97.90%; FB1:  97.49  240
            EVENT: precision:  94.12%; recall:  88.89%; FB1:  91.43  17
              LOC: precision:  95.91%; recall:  96.28%; FB1:  96.09  513
              ORG: precision:  93.15%; recall:  96.33%; FB1:  94.71  1943
              PER: precision:  91.69%; recall:  97.78%; FB1:  94.64  433
              PRO: precision:  85.23%; recall:  85.07%; FB1:  85.15  1070
