In [1]:
import os
import logging
# These environment variables are set before TensorFlow is imported further down
# in this cell, so they are already in place when the library loads.
# TF_CPP_MIN_LOG_LEVEL='3' is TensorFlow's setting for silencing its native (C++) log output.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# Ask TensorFlow to grow GPU memory on demand instead of reserving it all up front.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
# Expose only CUDA device 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = '0'

# Raise the Python-side 'tensorflow' logger threshold to FATAL.
logging.getLogger('tensorflow').setLevel(logging.FATAL)

from dataset.dataset_loader import NerProcessor, FewNERDProcessor, create_tf_dataset_for_client, split_to_tf_datasets, batch_features
from utils.fl_utils import *

import tensorflow as tf
from tqdm.notebook import tqdm

import numpy as np

from models.model import build_BertNer, MaskedSparseCategoricalCrossentropy
from tokenization import FullTokenizer



# Pretrained model directory names (joined under "models/" below).
# NOTE(review): the names look like the BERT checkpoint convention
# (L = layers, H = hidden size, A = attention heads) — confirm.
TINY = 'uncased_L-2_H-128_A-2'
TINY_1 = 'uncased_L-4_H-128_A-2'
TINY_8 = 'uncased_L-8_H-128_A-2'
TINY_12 = 'uncased_L-12_H-128_A-2'
MINI = 'uncased_L-4_H-256_A-4'
SMALL = 'uncased_L-4_H-512_A-8'
MEDIUM = 'uncased_L-8_H-512_A-8'
BASE = 'uncased_L-12_H-768_A-12'

# Same value as TINY_8 above (duplicate alias).
TINY_8_128 = 'uncased_L-8_H-128_A-2'

# Directory of the model used by the cells below; change the constant here to switch models.
MODEL = os.path.join("models", TINY)
# Sequence length passed to feature extraction, batching and build_BertNer.
SEQ_LEN = 128
# NOTE(review): not referenced by the visible cells; train_single takes its own
# `pretrained` argument instead.
PRETRAINED = False

# Dataset processor for the CoNLL data directory.
processor = NerProcessor('dataset/conll')
# Alternative dataset (disabled):
# processor = FewNERDProcessor('dataset/few_nerd')
# Tokenizer built from the selected model's vocab file.
# NOTE(review): the positional True is presumably do_lower_case — confirm against FullTokenizer.
tokenizer = FullTokenizer(os.path.join(MODEL, "vocab.txt"), True)
# Training and test splits converted to features of length SEQ_LEN.
train_features = processor.get_train_as_features(SEQ_LEN, tokenizer)
eval_features = processor.get_test_as_features(SEQ_LEN, tokenizer)


def eval_model(model, eval_data, do_print=True):
    """Evaluate `model` on `eval_data` with the notebook's `processor` settings.

    Looks up the label map and the label indices of the 'O', '[SEP]' and
    '[PAD]' tokens on the module-level `processor` and forwards them to
    `evaluate`.

    Args:
        model: Model handed to `evaluate`.
        eval_data: Evaluation data handed to `evaluate`.
        do_print: Forwarded to `evaluate` as its `do_print` keyword.

    Returns:
        Whatever `evaluate` returns.
    """
    label_map = processor.get_label_map()
    outside_idx = processor.token_ind('O')
    sep_idx = processor.token_ind('[SEP]')
    pad_idx = processor.token_ind('[PAD]')
    return evaluate(
        model,
        eval_data,
        label_map,
        outside_idx,
        sep_idx,
        pad_idx,
        do_print=do_print,
    )
    
# Batch the evaluation features once (batch_size=64); train_single passes this
# same object to eval_model after every epoch.
eval_data_batched = batch_features(eval_features, processor.get_labels(), SEQ_LEN, tokenizer, batch_size=64)


# Train a single model

In [2]:
# @tf.function  (left disabled, as in the original cell)
def train_batch(model, x, y):
    """Run one gradient-descent step of `model` on a single batch.

    Args:
        model: A compiled Keras model; its `loss` and `optimizer` attributes
            are used directly.
        x: Batch inputs, called as `model(x, training=True)`.
        y: Batch targets, passed as `model.loss(y, logits)`.

    Returns:
        The loss computed for this batch (evaluated before the weights are
        updated), so callers can log or monitor training. Existing callers
        that ignore the return value are unaffected.
    """
    # Record the forward pass so the loss can be differentiated afterwards.
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = model.loss(y, logits)
    trainable = model.trainable_variables
    grads = tape.gradient(loss, trainable)
    model.optimizer.apply_gradients(zip(grads, trainable))
    return loss

def train_single(epochs=1, lr=5e-3, batch_size=50, pretrained=True):
    """Build, train and evaluate one BertNer model on `train_features`.

    The model is built with `build_BertNer` from the module-level `MODEL`
    directory and compiled with Adam and `MaskedSparseCategoricalCrossentropy`.
    After each epoch the epoch index is printed and the model is evaluated on
    `eval_data_batched` via `eval_model`.

    Args:
        epochs: Number of passes over the training data.
        lr: Learning rate given to the Adam optimizer.
        batch_size: Batch size passed to `split_to_tf_datasets`.
        pretrained: When truthy, `restore_model_ckpt(model, MODEL)` is called
            before training.

    Returns:
        The trained model.
    """
    ner_model = build_BertNer(MODEL, processor.label_len(), SEQ_LEN)
    adam = tf.keras.optimizers.Adam(lr)
    masked_loss = MaskedSparseCategoricalCrossentropy()
    ner_model.compile(optimizer=adam, loss=masked_loss)

    if pretrained:
        restore_model_ckpt(ner_model, MODEL)

    # Disabled experiment: freeze the BERT layer.
    # ner_model.layers[3].trainable = False

    # Request a single split and use it as the whole training set.
    all_splits = split_to_tf_datasets(train_features, 1, batch_size)
    train_batches = all_splits[0]

    for epoch_idx in range(epochs):
        progress = tqdm(train_batches, position=0, leave=False, desc="Training")
        for inputs, targets in progress:
            train_batch(ner_model, inputs, targets)
        print("Epoch", epoch_idx)
        eval_model(ner_model, eval_data_batched)

    return ner_model

In [3]:
# Train for 3 epochs with Adam lr=5e-4 and batch size 32, skipping the
# restore_model_ckpt step; the returned model is kept in `model`.
model = train_single(pretrained=False, lr=5e-4, batch_size=32, epochs=3)

Training:   0%|          | 0/438 [00:00<?, ?it/s]

Epoch 0


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6939    0.6930    0.6935      1668
        MISC     0.6230    0.5556    0.5873       702
         ORG     0.4287    0.4401    0.4343      1661
         PER     0.3535    0.4737    0.4049      1617

   micro avg     0.4937    0.5388    0.5152      5648
   macro avg     0.5248    0.5406    0.5300      5648
weighted avg     0.5096    0.5388    0.5214      5648



Training:   0%|          | 0/438 [00:00<?, ?it/s]

Epoch 1


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7278    0.7518    0.7396      1668
        MISC     0.6102    0.5798    0.5946       702
         ORG     0.5169    0.5051    0.5110      1661
         PER     0.4188    0.4972    0.4546      1617

   micro avg     0.5569    0.5850    0.5706      5648
   macro avg     0.5684    0.5835    0.5749      5648
weighted avg     0.5627    0.5850    0.5728      5648



Training:   0%|          | 0/438 [00:00<?, ?it/s]

Epoch 2


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6822    0.8070    0.7394      1668
        MISC     0.6173    0.6296    0.6234       702
         ORG     0.5406    0.5370    0.5388      1661
         PER     0.4464    0.5077    0.4751      1617

   micro avg     0.5667    0.6199    0.5921      5648
   macro avg     0.5716    0.6203    0.5942      5648
weighted avg     0.5650    0.6199    0.5903      5648



In [11]:
# Repeat of the earlier run with identical arguments; it overwrites `model`.
# NOTE(review): no random seed is set in the visible cells, which would explain
# the slightly different scores between the two runs — consider seeding.
model = train_single(pretrained=False, lr=5e-4, batch_size=32, epochs=3)

Training:   0%|          | 0/438 [00:00<?, ?it/s]

Epoch 0


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7587    0.7086    0.7328      1668
        MISC     0.6342    0.5385    0.5824       702
         ORG     0.4802    0.4973    0.4886      1661
         PER     0.4316    0.4490    0.4401      1617

   micro avg     0.5601    0.5510    0.5555      5648
   macro avg     0.5762    0.5483    0.5610      5648
weighted avg     0.5677    0.5510    0.5585      5648



Training:   0%|          | 0/438 [00:00<?, ?it/s]

Epoch 1


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7234    0.7446    0.7338      1668
        MISC     0.6313    0.6197    0.6254       702
         ORG     0.5390    0.5780    0.5578      1661
         PER     0.4101    0.5232    0.4598      1617

   micro avg     0.5573    0.6167    0.5855      5648
   macro avg     0.5760    0.6164    0.5942      5648
weighted avg     0.5680    0.6167    0.5901      5648



Training:   0%|          | 0/438 [00:00<?, ?it/s]

Epoch 2


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7019    0.7890    0.7429      1668
        MISC     0.6000    0.6410    0.6198       702
         ORG     0.5866    0.5605    0.5733      1661
         PER     0.4693    0.4632    0.4662      1617

   micro avg     0.5933    0.6101    0.6016      5648
   macro avg     0.5895    0.6134    0.6006      5648
weighted avg     0.5887    0.6101    0.5985      5648



In [3]:
# 10 epochs, Adam lr=0.005, default batch_size=50, no checkpoint restore.
# The trailing comment records the model directory used for this run; it equals
# the TINY constant, so MODEL must be built from TINY to reproduce the output.
train_single(10, lr=0.005, pretrained=False) # uncased_L-2_H-128_A-2

Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6464    0.6049    0.6250      1668
        MISC     0.4362    0.1168    0.1843       702
         ORG     0.3777    0.2324    0.2877      1661
         PER     0.1244    0.2239    0.1599      1617

   micro avg     0.3237    0.3256    0.3246      5648
   macro avg     0.3961    0.2945    0.3142      5648
weighted avg     0.3918    0.3256    0.3379      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6663    0.6715    0.6689      1668
        MISC     0.5460    0.5584    0.5521       702
         ORG     0.4141    0.3612    0.3859      1661
         PER     0.2008    0.2975    0.2397      1617

   micro avg     0.4153    0.4591    0.4361      5648
   macro avg     0.4568    0.4721    0.4616      5648
weighted avg     0.4439    0.4591    0.4483      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6924    0.6871    0.6897      1668
        MISC     0.5936    0.5285    0.5592       702
         ORG     0.3580    0.3811    0.3692      1661
         PER     0.2740    0.3673    0.3139      1617

   micro avg     0.4414    0.4858    0.4626      5648
   macro avg     0.4795    0.4910    0.4830      5648
weighted avg     0.4620    0.4858    0.4716      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6903    0.7176    0.7037      1668
        MISC     0.5704    0.5484    0.5592       702
         ORG     0.3895    0.4022    0.3957      1661
         PER     0.3233    0.3915    0.3541      1617

   micro avg     0.4740    0.5104    0.4916      5648
   macro avg     0.4934    0.5149    0.5032      5648
weighted avg     0.4819    0.5104    0.4951      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6428    0.6894    0.6653      1668
        MISC     0.5421    0.5228    0.5323       702
         ORG     0.3521    0.4028    0.3757      1661
         PER     0.2376    0.2925    0.2622      1617

   micro avg     0.4183    0.4708    0.4430      5648
   macro avg     0.4436    0.4769    0.4589      5648
weighted avg     0.4288    0.4708    0.4482      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6878    0.7146    0.7010      1668
        MISC     0.5513    0.5741    0.5625       702
         ORG     0.4471    0.3564    0.3966      1661
         PER     0.2939    0.3469    0.3182      1617

   micro avg     0.4824    0.4865    0.4844      5648
   macro avg     0.4950    0.4980    0.4946      5648
weighted avg     0.4873    0.4865    0.4847      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6686    0.7002    0.6840      1668
        MISC     0.5699    0.5399    0.5545       702
         ORG     0.3994    0.4028    0.4011      1661
         PER     0.3001    0.4224    0.3509      1617

   micro avg     0.4556    0.5133    0.4827      5648
   macro avg     0.4845    0.5163    0.4976      5648
weighted avg     0.4717    0.5133    0.4893      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6829    0.7140    0.6981      1668
        MISC     0.4969    0.5627    0.5277       702
         ORG     0.4097    0.4028    0.4062      1661
         PER     0.3093    0.3389    0.3234      1617

   micro avg     0.4716    0.4963    0.4836      5648
   macro avg     0.4747    0.5046    0.4889      5648
weighted avg     0.4725    0.4963    0.4838      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6273    0.7578    0.6864      1668
        MISC     0.5078    0.5556    0.5306       702
         ORG     0.3654    0.4046    0.3840      1661
         PER     0.3619    0.3111    0.3346      1617

   micro avg     0.4706    0.5009    0.4852      5648
   macro avg     0.4656    0.5072    0.4839      5648
weighted avg     0.4594    0.5009    0.4774      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6591    0.7524    0.7027      1668
        MISC     0.5864    0.5755    0.5809       702
         ORG     0.4786    0.3504    0.4046      1661
         PER     0.2889    0.2907    0.2898      1617

   micro avg     0.4987    0.4800    0.4892      5648
   macro avg     0.5032    0.4922    0.4945      5648
weighted avg     0.4910    0.4800    0.4817      5648



<tensorflow.python.keras.engine.functional.Functional at 0x7ff1201c2dc0>

In [3]:
# 10 epochs, Adam lr=5e-5, default batch_size=50, no checkpoint restore.
# The trailing comment records the model directory used for this run; it equals
# the MINI constant, so MODEL must be built from MINI (not TINY, as currently
# set in the setup cell) to reproduce the output.
train_single(10, lr=5e-5, pretrained=False) # uncased_L-4_H-256_A-4

Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6667    0.0012    0.0024      1668
        MISC     0.0000    0.0000    0.0000       702
         ORG     0.4096    0.0409    0.0744      1661
         PER     0.0000    0.0000    0.0000      1617

   micro avg     0.4142    0.0124    0.0241      5648
   macro avg     0.2691    0.0105    0.0192      5648
weighted avg     0.3174    0.0124    0.0226      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.3467    0.6595    0.4545      1668
        MISC     0.2200    0.0157    0.0293       702
         ORG     0.3883    0.0680    0.1158      1661
         PER     0.0735    0.0334    0.0459      1617

   micro avg     0.3008    0.2263    0.2583      5648
   macro avg     0.2571    0.1941    0.1614      5648
weighted avg     0.2650    0.2263    0.1850      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.4776    0.7362    0.5794      1668
        MISC     0.5314    0.2051    0.2960       702
         ORG     0.3160    0.2306    0.2666      1661
         PER     0.1640    0.2183    0.1873      1617

   micro avg     0.3397    0.3732    0.3557      5648
   macro avg     0.3723    0.3476    0.3323      5648
weighted avg     0.3470    0.3732    0.3399      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6682    0.7074    0.6872      1668
        MISC     0.5758    0.4601    0.5115       702
         ORG     0.3705    0.3444    0.3569      1661
         PER     0.2559    0.3006    0.2765      1617

   micro avg     0.4438    0.4534    0.4486      5648
   macro avg     0.4676    0.4531    0.4580      5648
weighted avg     0.4511    0.4534    0.4507      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6407    0.7506    0.6913      1668
        MISC     0.5502    0.5385    0.5443       702
         ORG     0.3679    0.3763    0.3720      1661
         PER     0.3057    0.3760    0.3372      1617

   micro avg     0.4524    0.5069    0.4781      5648
   macro avg     0.4661    0.5103    0.4862      5648
weighted avg     0.4533    0.5069    0.4778      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6806    0.7398    0.7090      1668
        MISC     0.5850    0.5342    0.5585       702
         ORG     0.3733    0.4612    0.4126      1661
         PER     0.3245    0.3655    0.3438      1617

   micro avg     0.4688    0.5251    0.4954      5648
   macro avg     0.4909    0.5252    0.5060      5648
weighted avg     0.4764    0.5251    0.4986      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7053    0.7476    0.7258      1668
        MISC     0.5720    0.5655    0.5688       702
         ORG     0.4411    0.4708    0.4554      1661
         PER     0.3655    0.4218    0.3916      1617

   micro avg     0.5094    0.5503    0.5291      5648
   macro avg     0.5210    0.5514    0.5354      5648
weighted avg     0.5137    0.5503    0.5311      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7136    0.7350    0.7242      1668
        MISC     0.5903    0.5869    0.5886       702
         ORG     0.4533    0.4973    0.4743      1661
         PER     0.3607    0.4341    0.3940      1617

   micro avg     0.5120    0.5606    0.5352      5648
   macro avg     0.5295    0.5633    0.5453      5648
weighted avg     0.5207    0.5606    0.5393      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7195    0.7428    0.7310      1668
        MISC     0.6106    0.5584    0.5833       702
         ORG     0.4673    0.5123    0.4888      1661
         PER     0.3554    0.4484    0.3965      1617

   micro avg     0.5152    0.5678    0.5402      5648
   macro avg     0.5382    0.5655    0.5499      5648
weighted avg     0.5276    0.5678    0.5456      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7192    0.7494    0.7340      1668
        MISC     0.6009    0.5812    0.5909       702
         ORG     0.4979    0.4931    0.4955      1661
         PER     0.3748    0.4583    0.4124      1617

   micro avg     0.5329    0.5698    0.5507      5648
   macro avg     0.5482    0.5705    0.5582      5648
weighted avg     0.5408    0.5698    0.5540      5648



<tensorflow.python.keras.engine.functional.Functional at 0x7f700efc9ac0>

In [3]:
# 10 epochs, Adam lr=5e-5, default batch_size=50, no checkpoint restore.
# The trailing comment records the model directory used for this run; it equals
# the TINY_8 / TINY_8_128 constants, so MODEL must be built from one of them
# (not TINY, as currently set in the setup cell) to reproduce the output.
train_single(10, lr=5e-5, pretrained=False) # uncased_L-8_H-128_A-2

Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 0


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.2762    0.1091    0.1564      1668
        MISC     0.0000    0.0000    0.0000       702
         ORG     0.3898    0.0138    0.0267      1661
         PER     0.2500    0.0012    0.0025      1617

   micro avg     0.2851    0.0367    0.0650      5648
   macro avg     0.2290    0.0310    0.0464      5648
weighted avg     0.2678    0.0367    0.0548      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 1


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.3314    0.6966    0.4492      1668
        MISC     0.5333    0.0114    0.0223       702
         ORG     0.3487    0.0957    0.1502      1661
         PER     0.0663    0.0328    0.0439      1617

   micro avg     0.2893    0.2447    0.2651      5648
   macro avg     0.3199    0.2091    0.1664      5648
weighted avg     0.2857    0.2447    0.1922      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 2


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.5232    0.7086    0.6020      1668
        MISC     0.4846    0.1795    0.2620       702
         ORG     0.3990    0.2438    0.3027      1661
         PER     0.1463    0.1793    0.1612      1617

   micro avg     0.3631    0.3546    0.3588      5648
   macro avg     0.3883    0.3278    0.3319      5648
weighted avg     0.3740    0.3546    0.3455      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 3


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6537    0.7230    0.6866      1668
        MISC     0.5899    0.4487    0.5097       702
         ORG     0.4350    0.3528    0.3896      1661
         PER     0.2210    0.2900    0.2509      1617

   micro avg     0.4405    0.4561    0.4482      5648
   macro avg     0.4749    0.4536    0.4592      5648
weighted avg     0.4576    0.4561    0.4525      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 4


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6683    0.7500    0.7068      1668
        MISC     0.5573    0.5199    0.5380       702
         ORG     0.4033    0.4419    0.4217      1661
         PER     0.2800    0.3779    0.3217      1617

   micro avg     0.4535    0.5243    0.4863      5648
   macro avg     0.4772    0.5224    0.4970      5648
weighted avg     0.4654    0.5243    0.4917      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 5


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7124    0.7500    0.7307      1668
        MISC     0.5702    0.5499    0.5598       702
         ORG     0.4369    0.4545    0.4456      1661
         PER     0.3277    0.3871    0.3550      1617

   micro avg     0.4971    0.5343    0.5151      5648
   macro avg     0.5118    0.5354    0.5228      5648
weighted avg     0.5036    0.5343    0.5180      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 6


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6892    0.7830    0.7331      1668
        MISC     0.5458    0.5855    0.5649       702
         ORG     0.4815    0.4467    0.4635      1661
         PER     0.3724    0.4403    0.4035      1617

   micro avg     0.5198    0.5614    0.5398      5648
   macro avg     0.5222    0.5639    0.5413      5648
weighted avg     0.5196    0.5614    0.5385      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 7


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7213    0.7758    0.7475      1668
        MISC     0.5697    0.5883    0.5788       702
         ORG     0.5262    0.4648    0.4936      1661
         PER     0.4037    0.4966    0.4454      1617

   micro avg     0.5493    0.5811    0.5647      5648
   macro avg     0.5552    0.5814    0.5663      5648
weighted avg     0.5542    0.5811    0.5654      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 8


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.6811    0.7992    0.7354      1668
        MISC     0.5671    0.5897    0.5782       702
         ORG     0.5087    0.4937    0.5011      1661
         PER     0.4028    0.4613    0.4301      1617

   micro avg     0.5386    0.5866    0.5616      5648
   macro avg     0.5399    0.5860    0.5612      5648
weighted avg     0.5366    0.5866    0.5596      5648



Training:   0%|          | 0/280 [00:00<?, ?it/s]

Epoch 9


Evaluating:   0%|          | 0/54 [00:00<?, ?it/s]

              precision    recall  f1-score   support

         LOC     0.7263    0.7746    0.7496      1668
        MISC     0.5819    0.5869    0.5844       702
         ORG     0.5382    0.5129    0.5253      1661
         PER     0.4195    0.5121    0.4612      1617

   micro avg     0.5599    0.5992    0.5789      5648
   macro avg     0.5665    0.5966    0.5801      5648
weighted avg     0.5652    0.5992    0.5805      5648



<tensorflow.python.keras.engine.functional.Functional at 0x7f54c32faca0>