In [1]:
import numpy as np
import pandas as pd
import glob
import os
import tensorflow as tf
import transformers
from transformers import TFBertForTokenClassification
from tqdm.notebook import tqdm
import IPython

import sys
sys.path.append("..")
from data_preparation.data_preparation_pos import ABSATokenizer, convert_examples_to_tf_dataset, read_conll
import utils.utils as utils
import utils.pos_utils as pos_utils

In [2]:
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

In [4]:
data_dir = "../data/ud/"

code_dicts = utils.make_lang_code_dicts()
code_to_name = code_dicts["code_to_name"]
name_to_code = code_dicts["name_to_code"]

# Model parameters
max_length = 256
batch_size = 256
model_name = "bert-base-multilingual-cased"
tagset = ["O", "_", "ADJ", "ADP", "ADV", "AUX", "CCONJ", "DET", "INTJ", "NOUN", "NUM", 
          "PART", "PRON", "PROPN", "PUNCT", "SCONJ", "SYM", "VERB", "X"]
num_labels = len(tagset)
label_map = {label: i for i, label in enumerate(tagset)}

# Model creation
tokenizer = ABSATokenizer.from_pretrained(model_name)
config = transformers.BertConfig.from_pretrained(model_name, num_labels=num_labels)
model = TFBertForTokenClassification.from_pretrained(model_name,
                                                     config=config)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing TFBertForTokenClassification: ['nsp___cls', 'mlm___cls']
- This IS expected if you are initializing TFBertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFBertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of TFBertForTokenClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['classifier', 'dropout_75']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
full_preds = {}
reconstructed_preds = {}
full_labels = {}
reconstructed_labels = {}
data_dicts = [full_preds, reconstructed_preds, full_labels, reconstructed_labels]

for weights_filepath in tqdm(glob.glob("E:/TFM_CCIL/checkpoints/*/*_pos.hdf5")):
    training_lang = weights_filepath.split("\\")[1]
    train_lang_name = code_to_name[training_lang]
    for d in data_dicts:
        d[train_lang_name] = {}
    
    # Load weights
    model.load_weights(weights_filepath)
    print("\nUsing weights from", weights_filepath)
    
    
    # Evaluation
    for directory in tqdm(os.listdir(data_dir)):
        # Load and preprocess
        path = os.path.join(data_dir, directory)
        test_examples, test_dataset = pos_utils.load_data(path, batch_size, tokenizer, tagset, max_length)

        # Predict
        preds = model.predict(test_dataset, steps=np.ceil(len(test_examples) / batch_size), verbose=1)

        # Postprocessing
        tokens, labels, filtered_preds, logits = pos_utils.filter_padding_tokens(test_examples, preds, label_map, tokenizer)
        subword_locations = pos_utils.find_subword_locations(tokens)
        new_tokens, new_labels, new_preds = pos_utils.reconstruct_subwords(subword_locations, tokens, labels, 
                                                                           filtered_preds, logits)
        
        for d, l in zip(data_dicts, [filtered_preds, new_preds, labels, new_labels]):
            d[train_lang_name][code_to_name[directory]] = l

HBox(children=(FloatProgress(value=0.0, max=14.0), HTML(value='')))


Using weights from E:/TFM_CCIL/checkpoints\ar\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\bg\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\en\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\eu\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\fi\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\he\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\hr\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\ja\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\ko\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\ru\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\sk\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\tr\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\vi\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))



Using weights from E:/TFM_CCIL/checkpoints\zh\bert-base-multilingual-cased_pos.hdf5


HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))





In [61]:
lang_to_group = utils.make_lang_group_dict()

In [72]:
mw_langs = ["Arabic", "Hebrew", "Finnish", "Turkish"]

for lang_name in mw_langs:
    print("\n", lang_name, sep="")
    agglutinative = []
    others = []
    for test_lang in full_preds[lang_name].keys():
        if test_lang not in mw_langs:
            preds = full_preds[lang_name][test_lang]
            preds_agg = reconstructed_preds[lang_name][test_lang]
            print("{:<20}{:<20.2f}{:<20.2f}".format(test_lang, (np.array(preds) == 1).sum() / len(preds) * 100, 
                  (np.array(preds_agg) == 1).sum() / len(preds_agg) * 100))
            if lang_to_group[test_lang] == "Agglutinative":
                agglutinative.extend(preds)
            else:
                others.extend(preds)
    print("\n")
    overall = np.array([values for values in full_preds[lang_name].values()]).sum()
    print("Average: {:.2f}".format((np.array(overall) == 1).sum() / len(overall) * 100))
    print("Agglutinative: {:.2f}".format((np.array(agglutinative) == 1).sum() / len(agglutinative) * 100))
    print("Others: {:.2f}".format((np.array(others) == 1).sum() / len(others) * 100))
    print("\n")


Arabic
Bulgarian           0.69                0.62                
English             0.44                0.45                
Basque              2.46                1.80                
Croatian            0.25                0.22                
Japanese            11.08               11.27               
Korean              5.61                4.70                
Russian             0.40                0.32                
Slovak              0.28                0.21                
Thai                1.92                2.22                
Vietnamese          0.29                0.28                
Chinese             0.37                0.38                


Average: 3.67
Agglutinative: 6.26
Others: 0.77



Hebrew
Bulgarian           2.35                1.23                
English             0.24                0.15                
Basque              0.35                0.17                
Croatian            1.30                0.75                
Japanese          

In [49]:
non_mw_langs = [lang for lang in full_preds.keys() if lang not in ["Arabic", "Hebrew", "Finnish", "Turkish"]]

for lang_name in non_mw_langs:
    print("\n", lang_name)
    for test_lang, preds in full_preds[lang_name].items():
        if test_lang not in non_mw_langs:
            print("{:<20}{:<20}".format(test_lang, (np.array(preds) == 1).sum() / len(preds) * 100))


 Bulgarian
Arabic              0.0                 
Finnish             0.0                 
Hebrew              0.0                 
Thai                0.0                 
Turkish             0.0                 

 English
Arabic              0.0                 
Finnish             0.0                 
Hebrew              0.0                 
Thai                0.0                 
Turkish             0.0                 

 Basque
Arabic              0.0                 
Finnish             0.0                 
Hebrew              0.0                 
Thai                0.0                 
Turkish             0.0                 

 Croatian
Arabic              0.0                 
Finnish             0.0                 
Hebrew              0.0                 
Thai                0.0                 
Turkish             0.0                 

 Japanese
Arabic              0.0                 
Finnish             0.0                 
Hebrew              0.0                 
Thai

In [73]:
scores = pd.read_excel("../results/melted_results_pos.xlsx")

In [83]:
print("MW over non-MW")
print("{:.2f}".format(scores[scores["Train Language"].isin(mw_langs) & \
                             scores["Test Language"].isin(non_mw_langs)]["Transfer"].mean()))
print("\n")
print("MW over MW")
print("{:.2f}".format(scores[scores["Train Language"].isin(mw_langs) & \
                             scores["Test Language"].isin(mw_langs) & \
                             (scores["Train Language"] != scores["Test Language"])]["Transfer"].mean()))
print("\n")
print("-"*50)
print("non-MW over MW")
print("{:.2f}".format(scores[scores["Train Language"].isin(non_mw_langs) & \
                             scores["Test Language"].isin(mw_langs)]["Transfer"].mean()))
print("\n")
print("MW over non-MW")
print("{:.2f}".format(scores[scores["Train Language"].isin(non_mw_langs) & \
                             scores["Test Language"].isin(non_mw_langs) & \
                             (scores["Train Language"] != scores["Test Language"])]["Transfer"].mean()))

MW over non-MW
22.80


MW over MW
21.43


--------------------------------------------------
non-MW over MW
27.01


MW over non-MW
27.99
