XLM-R comes in two versions, base and large. Their architectural differences are listed in Table 7 of the original paper (https://arxiv.org/abs/1911.02116). In our experiments the large model took about 3x longer to train: on a single V100 GPU, training on English for the UD POS task took ~1 hour for base and ~3 hours for large. Using the base model would therefore significantly reduce the compute needed for the experiments. In this notebook we compare base and large on every task when trained on English and evaluated on all languages. Averaged over languages, large outperformed base by about 2 percentage points in accuracy on both tasks (about 2 to 4 points in F1), so we decided to use the base model for further experiments.

In [4]:
import pandas as pd

In [5]:
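def get_accuracy(filename):
    # Parse a results file into one row per language (lang, acc, f1).
    results = {"lang":[],"acc":[],"f1":[]}
    with open(filename) as f:
        # The with-block closes the file, so no explicit close() is needed.
        lines = f.read().split("\n")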

    for line in lines:
        if "language" in line:
            language = line.split("=")[1]
            results["lang"].append(language)
        elif "accuracy" in line:
            accuracy = line.split()[2]
            results["acc"].append(float(accuracy))
        elif "f1" in line:
            f1 = line.split()[2]
            results["f1"].append(float(f1))
        else:
            continue
    results_df = pd.DataFrame.from_dict(results)
    return results_df
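
The parser above implies a particular layout for the result files: a line containing `language=<code>`, followed by metric lines whose third whitespace-separated token is the value (for example `accuracy = 0.95`). The files themselves are not shown in this notebook, so the sample below is an assumed layout, and `parse_results_text` is a hypothetical helper that applies the same rules to an in-memory string (with an added `strip()` on the language code).

```python
import pandas as pd

# Assumed layout of a results file; the real files are not shown in this notebook.
SAMPLE = """language=en
accuracy = 0.95
f1 = 0.94
language=de
accuracy = 0.88
f1 = 0.86
"""


def parse_results_text(text):
    """Hypothetical helper: same parsing rules as get_accuracy, applied to a string."""
    results = {"lang": [], "acc": [], "f1": []}
    for line in text.splitlines():
        if "language" in line:
            results["lang"].append(line.split("=")[1].strip())
        elif "accuracy" in line:
            results["acc"].append(float(line.split()[2]))
        elif "f1" in line:
            results["f1"].append(float(line.split()[2]))
    return pd.DataFrame(results)


sample_df = parse_results_text(SAMPLE)
print(sample_df)
```

Note that this style of parsing relies on every language block containing both metrics: if one is missing, the three lists end up with different lengths and the `DataFrame` constructor raises a `ValueError`.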

In [6]:
def get_merged_results(task):
    if task not in ["udpos","panx"]:
        raise ValueError("Wrong task name. Possible task names are \"udpos\" or \"panx\"")
    base = get_accuracy("dev_results_base_{}.txt".format(task))
    large = get_accuracy("dev_results_large_{}.txt".format(task))
    merged = base.merge(large,on="lang",suffixes=("_base","_large"))
    merged["acc_diff"] = merged["acc_large"]-merged["acc_base"]
    merged["f1_diff"] = merged["f1_large"]-merged["f1_base"]
    return merged
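
`DataFrame.merge` performs an inner join by default, so a language that appears in only one of the two result files is silently dropped from `merged`. A minimal sketch with made-up numbers (the language codes and scores below are illustrative, not experimental results) shows the effect and how the `_diff` columns are derived:

```python
import pandas as pd

# Toy values for illustration only; these are not results from the experiments.
base = pd.DataFrame(
    {"lang": ["en", "de", "fr"], "acc": [0.90, 0.80, 0.85], "f1": [0.88, 0.78, 0.83]}
)
large = pd.DataFrame(
    {"lang": ["en", "de"], "acc": [0.92, 0.83], "f1": [0.91, 0.80]}
)

# Same steps as get_merged_results: join on language, then large minus base.
merged = base.merge(large, on="lang", suffixes=("_base", "_large"))
merged["acc_diff"] = merged["acc_large"] - merged["acc_base"]
merged["f1_diff"] = merged["f1_large"] - merged["f1_base"]

# "fr" has no row in `large`, so the inner join drops it.
print(merged["lang"].tolist())
print(merged[["lang", "acc_diff", "f1_diff"]].round(3))
```

If dropped languages should be visible instead, passing `how="outer"` and `indicator=True` to `merge` keeps those rows and marks which side they came from.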

## UD POS

In [12]:
udpos = get_merged_results("udpos")

In [9]:
udpos.sort_values(by="acc_diff",ascending=False).head()

index,lang,acc_base,f1_base,acc_large,f1_large,acc_diff,f1_diff
28,zh,0.415462,0.32144,0.599937,0.518073,0.184474,0.196633
17,ja,0.334931,0.283111,0.476954,0.39178,0.142023,0.108669
8,eu,0.749243,0.678295,0.801121,0.745933,0.051878,0.067638
26,ur,0.719224,0.622772,0.749194,0.672522,0.029971,0.04975
19,mr,0.8325,0.766082,0.86,0.815029,0.0275,0.048947


In [10]:
udpos.describe()

statistic,acc_base,f1_base,acc_large,f1_large,acc_diff,f1_diff
count,29.0,29.0,29.0,29.0,29.0,29.0
mean,0.779038,0.734121,0.798541,0.757257,0.019503,0.023136
std,0.155571,0.176603,0.126849,0.1491,0.042378,0.042229
min,0.334931,0.283111,0.476954,0.39178,-0.014233,-0.014566
25%,0.733418,0.649947,0.7543,0.690602,0.002616,0.003844
50%,0.855693,0.815821,0.857593,0.815029,0.005819,0.007879
75%,0.886631,0.87247,0.887233,0.875398,0.020882,0.025952
max,0.966487,0.960648,0.969856,0.964492,0.184474,0.196633


### Languages used

In [18]:
" ".join(sorted(udpos.lang.unique()))

'af ar bg de el en es et eu fa fi fr he hi hu id it ja ko mr nl pt ru ta te tr ur vi zh'

In [21]:
print("Nr of languages:",udpos.lang.nunique())

Nr of languages: 29


For UD POS we are missing kk, th, tl and yo (Kazakh, Thai, Tagalog, Yoruba) because they don't have dev sets.

## PAN-X (NER)

In [13]:
panx = get_merged_results("panx")

In [15]:
panx.sort_values(by="acc_diff",ascending=False).head()

index,lang,acc_base,f1_base,acc_large,f1_large,acc_diff,f1_diff
3,id,0.705829,0.483915,0.785134,0.537441,0.079305,0.053526
34,zh,0.672967,0.203989,0.74719,0.286597,0.074222,0.082608
2,vi,0.818271,0.644834,0.888013,0.773377,0.069742,0.128543
16,bn,0.780784,0.684811,0.850034,0.769163,0.06925,0.084352
27,ja,0.659142,0.151699,0.715916,0.191296,0.056774,0.039597


In [16]:
panx.describe()

statistic,acc_base,f1_base,acc_large,f1_large,acc_diff,f1_diff
count,40.0,40.0,40.0,40.0,40.0,40.0
mean,0.807968,0.594411,0.830343,0.633658,0.022375,0.039246
std,0.125255,0.18227,0.119325,0.181245,0.02589,0.036152
min,0.215551,0.04069,0.242977,0.013995,-0.061531,-0.04678
25%,0.789286,0.488263,0.808318,0.535642,0.010069,0.019739
50%,0.822197,0.6395,0.845238,0.684223,0.019869,0.041113
75%,0.888497,0.736594,0.913811,0.775874,0.030133,0.066622
max,0.92838,0.832816,0.934936,0.845551,0.079305,0.128543


### Languages used

In [19]:
" ".join(sorted(panx.lang.unique()))

'af ar bg bn de el en es et eu fa fi fr he hi hu id it ja jv ka kk ko ml mr ms my nl pt ru sw ta te th tl tr ur vi yo zh'

In [20]:
print("Nr of languages:",panx.lang.nunique())

Nr of languages: 40
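
The two language lists printed in this notebook can be compared directly to see which languages are covered by only one task. The sketch below copies the codes from the two "Languages used" outputs into plain strings, so it runs without the result files; the variable names are illustrative.

```python
# Language codes copied from the two "Languages used" outputs in this notebook.
udpos_langs = set(
    "af ar bg de el en es et eu fa fi fr he hi hu id it ja ko mr nl pt ru "
    "ta te tr ur vi zh".split()
)
panx_langs = set(
    "af ar bg bn de el en es et eu fa fi fr he hi hu id it ja jv ka kk ko ml "
    "mr ms my nl pt ru sw ta te th tl tr ur vi yo zh".split()
)

# Set differences give the languages present in one task's results only.
only_panx = sorted(panx_langs - udpos_langs)
only_udpos = sorted(udpos_langs - panx_langs)
shared = sorted(udpos_langs & panx_langs)

print("Only in PAN-X:", " ".join(only_panx))
print("Only in UD POS:", " ".join(only_udpos))
print("Shared:", len(shared))
```

The PAN-X-only list includes the four languages noted earlier as lacking UD POS dev sets (kk, th, tl, yo), alongside languages that appear only in the PAN-X results.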
