In [1]:
from transformers import TFBertForSequenceClassification, BertTokenizer, AutoTokenizer, TFAutoModelForSequenceClassification
import tensorflow as tf
import pandas as pd
import numpy as np
import os
import glob
from tqdm.notebook import tqdm
from sklearn.metrics import precision_score, recall_score, f1_score, accuracy_score
import IPython

import sys
sys.path.append("..")
from data_preparation.data_preparation_sentiment import Example, convert_examples_to_tf_dataset, make_batches
import utils.utils as utils

## Zero-shot

### Training language setup

In [9]:
code_dicts = utils.make_lang_code_dicts("../utils/lang_codes.xlsx")
code_to_name = code_dicts["code_to_name"]
name_to_code = code_dicts["name_to_code"]

results_path = "../results/results_sentiment.xlsx"

# Look for languages that have sentiment weights but are not in the results file
file = open("../data_exploration/sentiment_table.txt", "r")
all_langs = [line.split("&")[1].strip() for line in file.readlines()]
all_langs = [lang for lang in all_langs if lang not in ["Turkish", "Japanese", "Russian"]]
trained_langs = [code_to_name[x.split("\\")[1]] for x in glob.glob("E:/TFM_CCIL/checkpoints/*/*sentiment.hdf5")]

if os.path.isfile(results_path):
    results = pd.read_excel(results_path, sheet_name=None)
    remaining_langs = [lang for lang in trained_langs if lang not in results["Accuracy"].columns]
else:
    remaining_langs = trained_langs
    
untrained_langs = [lang for lang in all_langs if lang not in trained_langs]
evaluated_langs = [lang for lang in trained_langs if lang not in remaining_langs]

if remaining_langs:
    training_lang = remaining_langs[0]
    print("Evaluating with:   ", training_lang, "\n")
    training_lang = name_to_code[training_lang]
    print(IPython.utils.text.columnize(["Already evaluated:"] + evaluated_langs, displaywidth=150))
    print(IPython.utils.text.columnize(["Not yet evaluated:"] + remaining_langs[1:], displaywidth=150))
    print(IPython.utils.text.columnize(["Still to train:   "] + untrained_langs, displaywidth=150))
else:
    print("No languages remaining")
    print(IPython.utils.text.columnize(["Already evaluated:"] + evaluated_langs, displaywidth=150))
    print(IPython.utils.text.columnize(["Still to train:   "] + untrained_langs, displaywidth=150))

Evaluating with:    Thai 

Already evaluated:  Bulgarian  English  Basque  Finnish  Hebrew  Croatian  Slovak  Vietnamese  Chinese

Not yet evaluated:

Still to train:     Korean  Arabic



### Model setup

In [3]:
gpu_devices = tf.config.experimental.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

In [4]:
# Model parameters
model_name = "bert-base-multilingual-cased"
max_length = 512
batch_size = 64

# Model creation and loading weights
model = TFBertForSequenceClassification.from_pretrained(model_name)
weights_path = "E:/TFM_CCIL/checkpoints/" + training_lang + "/"
weights_filename = model_name + "_sentiment.hdf5"
model.load_weights(weights_path + weights_filename)
print("Using weights from", weights_path + weights_filename)
tokenizer = BertTokenizer.from_pretrained(model_name)

Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing TFBertForSequenceClassification: ['mlm___cls', 'nsp___cls']
- This IS expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing TFBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of TFBertForSequenceClassification were not initialized from the model checkpoint at bert-base-multilingual-cased and are newly initialized: ['dropout_37', 'classifier']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Using weights from E:/TFM_CCIL/checkpoints/zh/bert-base-multilingual-cased_sentiment.hdf5


### Evaluation

In [5]:
import logging
logging.getLogger("transformers.tokenization_utils_base").setLevel(logging.ERROR)

In [6]:
path = "../data/sentiment/"
sentiment_eval = []

for lang in tqdm(os.listdir(path)):
    if lang not in ["tr", "ja", "ru"]:
        # Load and preprocess
        test = pd.read_csv(path + lang + "/test.csv", header=None)
        test.columns = ["sentiment", "review"]
        lengths = test["review"].apply(lambda x: len(tokenizer.encode(x)))
        test = test[lengths <= 512].reset_index(drop=True) # Remove long examples
        
        # Convert to TF dataset
        test_dataset = convert_examples_to_tf_dataset([(Example(text=text, category_index=label)) for label, 
                                                       text in test.values], 
                                                      tokenizer, max_length=max_length)
        test_dataset, test_batches = make_batches(test_dataset, batch_size, repetitions=1, shuffle=False)
        
        # Predict
        preds = model.predict(test_dataset, steps=np.ceil(test.shape[0] / batch_size), verbose=1)
        clean_preds = preds[0].argmax(axis=-1)
        
        # Metrics
        accuracy = accuracy_score(test["sentiment"].values, clean_preds)
        precision = precision_score(test["sentiment"].values, clean_preds, average="macro")
        recall = recall_score(test["sentiment"].values, clean_preds, average="macro")
        f1 = f1_score(test["sentiment"].values, clean_preds, average="macro")
        sentiment_eval.append((lang, accuracy, precision, recall, f1))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




Build the table for this training language

In [7]:
sentiment_eval = np.array(sentiment_eval, dtype=object)
table = pd.DataFrame({"Language": sentiment_eval[:,0],
                      "Accuracy": sentiment_eval[:,1],
                      "Macro_Precision": sentiment_eval[:,2],
                      "Macro_Recall": sentiment_eval[:,3],
                      "Macro_F1": sentiment_eval[:,4]})
table["Language"] = table["Language"].apply(lambda x: code_to_name[x])

Reorder so that language types are grouped

In [8]:
file = open("../data_exploration/sentiment_table.txt", "r")
lang_order = [line.split("&")[1].strip() for line in file.readlines()]
lang_order = [lang for lang in lang_order if lang not in ["Turkish", "Japanese", "Russian"]]
table["sort"] = table["Language"].apply(lambda x: lang_order.index(x))
table = table.sort_values(by=["sort"]).drop("sort", axis=1).reset_index(drop=True)

In [9]:
table

Unnamed: 0,Language,Accuracy,Macro_Precision,Macro_Recall,Macro_F1
0,Bulgarian,0.456723,0.562093,0.591042,0.443895
1,English,0.584843,0.586812,0.584716,0.582267
2,Slovak,0.760338,0.495069,0.489536,0.480364
3,Croatian,0.727689,0.564179,0.548601,0.551284
4,Chinese,0.932025,0.926653,0.932968,0.929456
5,Vietnamese,0.526316,0.524553,0.522831,0.51597
6,Thai,0.580908,0.547758,0.53807,0.529074
7,Finnish,0.747449,0.609397,0.518707,0.480503
8,Basque,0.585903,0.518059,0.533259,0.483691
9,Korean,0.329144,0.421348,0.457106,0.306814


Update results excel file

In [10]:
results_path = "../results/results_sentiment.xlsx"

if os.path.isfile(results_path):
    results = pd.read_excel(results_path, sheet_name=None)
else:
    results = dict.fromkeys(table.columns[1:].values, pd.DataFrame({"Language": table["Language"].values}))

In [11]:
with pd.ExcelWriter(results_path) as writer:
    full_training_lang = code_to_name[training_lang]
    for sheet_name, df in results.items():
        # Add each the column for each metric in the corresponding sheet
        df[full_training_lang] = table[sheet_name]
        df.to_excel(writer, index=False, sheet_name=sheet_name)

## Checking dataset balance

In [110]:
path = "../data/sentiment/"

for lang in tqdm(os.listdir(path)):
    if lang not in ["tr", "ja", "ru"]:
        test = pd.read_csv(path + lang + "/test.csv", header=None)
        test.columns = ["sentiment", "review"]
        print(lang, test["sentiment"].mean())

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))

ar 0.8325724493846764
bg 0.8111050626020686
en 0.49917627677100496
eu 0.8458149779735683
fi 0.7455919395465995
he 0.040654997176736304
hr 0.7803203661327232
ko 2.8757886435331232
sl 0.9219924812030075
th 0.40784982935153585
vi 0.5138686131386861
zh 0.6045710139669871



## Calculating excluded examples

In [153]:
path = "../data/sentiment/"
max_lengths = {}
for directory in tqdm(os.listdir(path)):
    lang_path = os.path.join(path, directory)
    test = pd.read_csv(lang_path + "/test.csv", header=None)
    test.columns = ["sentiment", "review"]
    lengths = test["review"].apply(lambda x: len(tokenizer.encode(x)))
    above_512 = lengths > 512
    above_256 = lengths > 256
    max_lengths[directory] = (lengths.max(), 
                              above_512.sum(), round(above_512.mean() * 100, 2), 
                              above_256.sum(), round(above_256.mean() * 100, 2))

HBox(children=(FloatProgress(value=0.0, max=15.0), HTML(value='')))




In [155]:
max_lengths

{'ar': (5040, 446, 4.43, 1242, 12.33),
 'bg': (65, 0, 0.0, 0, 0.0),
 'en': (77, 0, 0.0, 0, 0.0),
 'eu': (79, 0, 0.0, 0, 0.0),
 'fi': (657, 5, 1.26, 26, 6.55),
 'he': (1183, 3, 0.17, 16, 0.9),
 'hr': (316, 0, 0.0, 1, 0.23),
 'ja': (6836, 685, 34.25, 1677, 83.85),
 'ko': (212, 0, 0.0, 0, 0.0),
 'ru': (9751, 496, 57.21, 750, 86.51),
 'sk': (411, 0, 0.0, 9, 0.85),
 'th': (1336, 8, 0.34, 48, 2.05),
 'tr': (2055, 184, 100.0, 184, 100.0),
 'vi': (826, 1, 0.15, 3, 0.44),
 'zh': (847, 11, 0.2, 106, 1.92)}

In [166]:
print(("{:<20}" * 6).format("Language", "Longest", ">512 tokens", ">512 tokens (%)", 
                                        ">256 tokens", ">256 tokens (%)") + "\n")
for lang, values in max_lengths.items():
    print(("{:<20}" * 6).format(*([lang] + list(values))))

Language            Longest             >512 tokens         >512 tokens (%)     >256 tokens         >256 tokens (%)     

ar                  5040                446                 4.43                1242                12.33               
bg                  65                  0                   0.0                 0                   0.0                 
en                  77                  0                   0.0                 0                   0.0                 
eu                  79                  0                   0.0                 0                   0.0                 
fi                  657                 5                   1.26                26                  6.55                
he                  1183                3                   0.17                16                  0.9                 
hr                  316                 0                   0.0                 1                   0.23                
ja                  6836       

In [167]:
file = open("../data_exploration/pos_table.txt", "r")
output = ""
lang_codes = pd.read_excel("../data_exploration/lang_codes.xlsx", header=0)

for line in file.readlines():
    lang_name = line.split("&")[1].strip()
    lang_code = lang_codes["ISO 639-1 Code"][lang_codes["English name of Language"] == lang_name].values[0]
    
    if lang_code in max_lengths:
        values = max_lengths[lang_code]
        split_line = line.split("\\")
        start = split_line[0] + "\\" + "&".join(split_line[1].split("&")[:2])
        end = r"\\" + "".join(split_line[2:])
        new_line = start + "& " + " & ".join(np.array(values[1:]).astype(str)) + end
        
    else:
        new_line = line
        
    output += new_line

In [168]:
print(output)

    \fusional{Fusional}  & Bulgarian & 0.0 & 0.0 & 0.0 & 0.0\\ 
    \fusional{Fusional} & English & 0.0 & 0.0 & 0.0 & 0.0\\
    \fusional{Fusional}  & Russian & 496.0 & 57.21 & 750.0 & 86.51\\ 
    \fusional{Fusional} & Slovak & 0.0 & 0.0 & 9.0 & 0.85\\
    \fusional{Fusional}  & Croatian & 0.0 & 0.0 & 1.0 & 0.23\\
    \isolating{Isolating} & Chinese & 11.0 & 0.2 & 106.0 & 1.92\\ 
    \isolating{Isolating} & Vietnamese  & 1.0 & 0.15 & 3.0 & 0.44\\
    \isolating{Isolating} & Thai & 8.0 & 0.34 & 48.0 & 2.05\\
    \agglutinative{Agglutinative} & Finnish & 5.0 & 1.26 & 26.0 & 6.55\\ 
    \agglutinative{Agglutinative} & Basque & 0.0 & 0.0 & 0.0 & 0.0\\
    \agglutinative{Agglutinative} & Japanese & 685.0 & 34.25 & 1677.0 & 83.85\\ 
    \agglutinative{Agglutinative} & Korean & 0.0 & 0.0 & 0.0 & 0.0\\ 
    \agglutinative{Agglutinative} & Turkish & 184.0 & 100.0 & 184.0 & 100.0\\
    \introflexive{Introflexive} & Arabic & 446.0 & 4.43 & 1242.0 & 12.33\\
    \introflexive{Introflexive} & Hebre

In [None]:
"  "