In [1]:
import glob
import os
import sys

import numpy as np
import tensorflow as tf
import transformers
from tensorflow.keras.callbacks import ModelCheckpoint
from transformers import TFBertForTokenClassification

# Make the project-local package importable; must run before the import below.
sys.path.append("..")
from data_preparation.data_preparation_pos import ABSATokenizer, convert_examples_to_tf_dataset, read_conll

In [3]:
# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. The setting is applied to every detected GPU (TensorFlow requires
# memory growth to be consistent across devices), and the loop is a no-op on
# CPU-only machines instead of raising IndexError on an empty device list.
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
for gpu_device in gpu_devices:
    tf.config.experimental.set_memory_growth(gpu_device, True)

In [4]:
# Label inventory for the tagger. The position of each tag in this list is
# its integer class id; index 0 ("O") is the class that the `ignore_acc`
# metric skips by default.
tagset = [
    "O", "_",
    "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM",
    "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X",
]
num_labels = len(tagset)

# Examples whose subword sequence exceeds this length are filtered out when
# the data is loaded; the same value is passed to the dataset conversion.
max_length = 256

# Pretrained checkpoint shared by the tokenizer, the config and the model.
model_name = "bert-base-multilingual-cased"
tokenizer = ABSATokenizer.from_pretrained(model_name)
config = transformers.BertConfig.from_pretrained(model_name, num_labels=num_labels)
model = TFBertForTokenClassification.from_pretrained(model_name, config=config)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing TFBertForTokenClassification: ['mlm___cls', 'nsp___cls']
- This IS expected if you are initializing TFBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of TFBertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['dropout_37', 'classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
training_lang = "he"
path = "../data/ud/"


def load_split(lang, split):
    """Read one CoNLL-U split for `lang` and turn it into example dicts.

    Looks for a file matching ``<path><lang>/*-<split>.conllu`` (``path`` is
    the module-level data directory), parses it with ``read_conll`` and drops
    every sentence whose subword tokenization is longer than ``max_length``.

    Args:
        lang: Language directory name under ``path`` (e.g. ``"he"``).
        split: Split name used in the file suffix (``"train"`` or ``"dev"``).

    Returns:
        A ``(data, examples)`` pair: ``data`` is the raw ``read_conll``
        result, ``examples`` is a list of dicts with the keys ``"id"``,
        ``"tokens"`` and ``"tags"``.

    Raises:
        FileNotFoundError: If no file matches the pattern.
    """
    pattern = path + lang + "/*-" + split + ".conllu"
    # Sort so the chosen file is deterministic when several files match.
    matches = sorted(glob.glob(pattern))
    if not matches:
        raise FileNotFoundError("No CoNLL-U file matches " + pattern)

    data = read_conll(matches[0])
    # The first three entries of `data` are consumed in parallel as
    # sentence ids, token lists and tag lists.
    examples = [{"id": sent_id, "tokens": tokens, "tags": tags}
                for sent_id, tokens, tags in zip(data[0], data[1], data[2])]

    # In case some example is over max length
    examples = [example for example in examples
                if len(tokenizer.subword_tokenize(example["tokens"],
                                                  example["tags"])[0]) <= max_length]
    return data, examples


train_data, train_examples = load_split(training_lang, "train")
dev_data, dev_examples = load_split(training_lang, "dev")

In [12]:
batch_size = 8
epochs = 30

# Training data: shuffled, batched, then repeated once per epoch so that
# `epochs` passes are available to a fit call that uses steps_per_epoch.
train_dataset = convert_examples_to_tf_dataset(examples=train_examples, tokenizer=tokenizer,
                                               tagset=tagset, max_length=max_length)
train_dataset = train_dataset.shuffle(100000).batch(batch_size).repeat(epochs)

# Validation data: batched only. Shuffling the dev set has no training benefit
# and regroups the batches on every evaluation pass, which can add noise to a
# metric that is averaged per batch; `.repeat(1)` was a no-op.
dev_dataset = convert_examples_to_tf_dataset(examples=dev_examples, tokenizer=tokenizer,
                                             tagset=tagset, max_length=max_length)
dev_dataset = dev_dataset.batch(batch_size)

In [13]:
# Sanity check: print the first example of the first training batch as
# (decoded subword, tag) pairs.
example_batch = next(train_dataset.as_numpy_iterator())
first_input_ids = example_batch[0]["input_ids"][0]
first_labels = example_batch[1][0]

for token, label in zip(first_input_ids, first_labels):
    # Stop at the first token id 0 (presumably the padding id -- TODO confirm).
    if token == 0:
        break
    print("{:<25}{:<20}".format(tokenizer.decode(int(token)), tagset[label]))

ה                        _                   
# # מ ס פ ר              _                   
ה                        DET                 
מ ס פ ר                  NOUN                
,                        PUNCT               
נ                        VERB                
# # ר ד                  VERB                
# # ף                    VERB                
ע ל                      ADP                 
-                        PUNCT               
י ד י                    NOUN                
ז                        NOUN                
# # י                    NOUN                
# # כ ר                  NOUN                
# # ו נ ו ת              NOUN                
מ                        ADJ                 
# # ר י ם                ADJ                 
מ                        _                   
# # מ ל                  _                   
# # ח מ ת                _                   
מ                        ADP                 
מ ל ח מ ת                NOUN     

In [14]:
# Keep only the weights of the epoch with the highest validation `ignore_acc`.
checkpoint_path = "../checkpoints_" + training_lang + "/" + model_name + "_pos_checkpoint.hdf5"
# Create the target directory up front: depending on the TensorFlow version,
# ModelCheckpoint may not create it, and the save would then fail mid-training.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
checkpoint = ModelCheckpoint(checkpoint_path,
                             verbose=1, monitor='val_ignore_acc',
                             save_best_only=True, mode='max', save_weights_only=True)

In [15]:
import tensorflow.keras.backend as K
def ignore_acc(y_true_class, y_pred_class, class_to_ignore=0):
    """Accuracy over the positions whose true label is not `class_to_ignore`.

    Args:
        y_true_class: Tensor of true class ids.
        y_pred_class: Tensor of per-class scores; the predicted class is the
            argmax over the last axis.
        class_to_ignore: Class id excluded from both the numerator and the
            denominator (default 0).

    Returns:
        Scalar tensor: number of counted positions predicted correctly divided
        by the number of counted positions. The denominator is clamped to at
        least 1, so a batch with no counted positions yields 0 rather than a
        division by zero.
    """
    gold = K.cast(y_true_class, 'int32')
    predicted = K.cast(K.argmax(y_pred_class, axis=-1), 'int32')
    # 1 where the position counts towards the metric, 0 where it is ignored.
    counted = K.cast(K.not_equal(gold, class_to_ignore), 'int32')
    # 1 only where the position is counted AND the prediction is correct.
    hits = counted * K.cast(K.equal(gold, predicted), 'int32')
    n_counted = K.maximum(K.sum(counted), 1)
    return K.sum(hits) / n_counted

In [16]:
# Fine-tuning setup: Adam with a small learning rate and a sparse
# cross-entropy loss that is told to expect logits (integer class-id targets).
# NOTE(review): the loss is computed over every position, while the
# `ignore_acc` metric skips class 0 -- confirm this asymmetry is intended.
optimizer = tf.keras.optimizers.Adam(learning_rate=2e-5)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    loss=loss,
    optimizer=optimizer,
    metrics=[ignore_acc],
)

In [None]:
# Number of batches needed for one full pass over each split. np.ceil returns
# a float, but Keras expects integer step counts, so cast explicitly.
steps_per_epoch = int(np.ceil(len(train_examples) / batch_size))
validation_steps = int(np.ceil(len(dev_examples) / batch_size))

# Keep the History object so the per-epoch metrics can be inspected afterwards.
history = model.fit(train_dataset, epochs=epochs, steps_per_epoch=steps_per_epoch,
                    validation_data=dev_dataset, validation_steps=validation_steps,
                    callbacks=[checkpoint])