In [1]:
import pandas as pd

from TokenClassificationTrainer import TokenClassificationTrainer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Fine-tuning configuration shared by every evaluation below:
# the task name, the pretrained checkpoint, and the batch size.
task, model_name, batch_size = "ner", "xlm-mlm-17-1280", 16

# True: label all tokens; False: label only the first token of each word.
label_all_tokens = True

# Baseline (English)

In [3]:
# Splits for the English baseline dataset.
# NOTE(review): trainer.evaluate() presumably scores the "validation" entry
# (here the newsgroup test file) — confirm in TokenClassificationTrainer.
# NOTE(review): "test" points at the *training* file, identical to "train".
# This looks like a copy-paste slip — confirm the intended test file before
# reporting any numbers computed on the "test" split.
file_paths = dict(
    train="data/datasets/baseline/en_ewt_nn_train.conll",
    validation="data/datasets/baseline/en_ewt_nn_newsgroup_test.conll",
    test="data/datasets/baseline/en_ewt_nn_train.conll",
)

# Rebuild the trainer for these splits, load the already-trained model
# into it (use_old=True), then score it.
trainer = TokenClassificationTrainer(
    task, model_name, batch_size, label_all_tokens, file_paths
)
trainer.set_trainer(use_old=True)
baseline_eval = trainer.evaluate()

100%|██████████| 18/18 [00:50<00:00,  2.82s/it]                    


In [4]:
# Raw metrics dict returned by trainer.evaluate() for the baseline run.
baseline_eval

{'eval_loss': 0.29972416162490845,
 'eval_precision': 0.6620111731843575,
 'eval_recall': 0.5163398692810458,
 'eval_f1': 0.5801713586291309,
 'eval_span_f1': 0.606516290726817,
 'eval_accuracy': 0.9149110247494375,
 'eval_runtime': 54.3129,
 'eval_samples_per_second': 5.247,
 'eval_steps_per_second': 0.331}

# NoSta-D (German)

In [5]:
# Splits for the German NoSta-D dataset.
# Note the cross-over: the "validation" slot holds the *test* file and the
# "test" slot holds the *dev* file. NOTE(review): presumably this is so that
# trainer.evaluate() reports test-set numbers — confirm which split
# TokenClassificationTrainer.evaluate() actually scores.
file_paths = dict(
    train="data/datasets/NoSta-D/NER-de-train.tsv",
    validation="data/datasets/NoSta-D/NER-de-test.tsv",
    test="data/datasets/NoSta-D/NER-de-dev.tsv",
)

# Rebuild the trainer for these splits, load the already-trained model
# into it (use_old=True), then score it.
trainer = TokenClassificationTrainer(
    task, model_name, batch_size, label_all_tokens, file_paths
)
trainer.set_trainer(use_old=True)
NoStaD_eval = trainer.evaluate()

100%|██████████| 319/319 [17:54<00:00,  3.37s/it]                  


# DaNplus (Danish)

In [6]:
# Splits for the Danish DaNplus news dataset.
# As with NoSta-D, "validation" holds the *test* file and "test" holds the
# *dev* file. NOTE(review): confirm that trainer.evaluate() scores the
# "validation" split, so the reported numbers are test-set numbers.
file_paths = dict(
    train="data/datasets/DaNplus/da_news_train.tsv",
    validation="data/datasets/DaNplus/da_news_test.tsv",
    test="data/datasets/DaNplus/da_news_dev.tsv",
)

# Rebuild the trainer for these splits, load the already-trained model
# into it (use_old=True), then score it.
trainer = TokenClassificationTrainer(
    task, model_name, batch_size, label_all_tokens, file_paths
)
trainer.set_trainer(use_old=True)
DaNplus_eval = trainer.evaluate()

100%|██████████| 36/36 [02:57<00:00,  4.92s/it]                  


# Evaluation summary

In [7]:
# Plain-text dump of each run's metrics, each preceded by its dataset label
# (one item per line).
print(
    "baseline", baseline_eval,
    "NoStaD", NoStaD_eval,
    "DaNplus", DaNplus_eval,
    sep="\n",
)

baseline
{'eval_loss': 0.29972416162490845, 'eval_precision': 0.6620111731843575, 'eval_recall': 0.5163398692810458, 'eval_f1': 0.5801713586291309, 'eval_span_f1': 0.606516290726817, 'eval_accuracy': 0.9149110247494375, 'eval_runtime': 54.3129, 'eval_samples_per_second': 5.247, 'eval_steps_per_second': 0.331}
NoStaD
{'eval_loss': 0.32413148880004883, 'eval_precision': 0.5523602199816682, 'eval_recall': 0.46431667148223055, 'eval_f1': 0.5045261891057506, 'eval_span_f1': 0.5506407163154791, 'eval_accuracy': 0.9186659739932622, 'eval_runtime': 1077.3587, 'eval_samples_per_second': 4.734, 'eval_steps_per_second': 0.296}
DaNplus
{'eval_loss': 0.2991974353790283, 'eval_precision': 0.3388338833883388, 'eval_recall': 0.31950207468879666, 'eval_f1': 0.32888414308595826, 'eval_span_f1': 0.3733333333333333, 'eval_accuracy': 0.9193520191131053, 'eval_runtime': 181.317, 'eval_samples_per_second': 3.122, 'eval_steps_per_second': 0.199}


In [10]:
# Summary table: one row per evaluated dataset, one column per metric.
# The frame is built in a single constructor call from a list of records,
# instead of growing an empty frame with df.loc[i] = [...]. This has two
# effects:
#   * Each value is keyed by metric name, not by dict iteration order. A
#     metric missing from one run becomes NaN instead of raising
#     "cannot set a row with mismatched columns", and a run with a
#     different key order can no longer shift values into the wrong column.
#   * pandas infers float64 for the metric columns directly.
# Column order is "Dataset", "Language", then the metric keys in the order
# they first appear (i.e. the baseline dict's order).
eval_runs = [
    ("Baseline", "English", baseline_eval),
    ("NoSta-D", "German", NoStaD_eval),
    ("DaNplus", "Danish", DaNplus_eval),
]

df = pd.DataFrame(
    [
        {"Dataset": dataset, "Language": language, **metrics}
        for dataset, language, metrics in eval_runs
    ]
)

df

Unnamed: 0,Dataset,Language,eval_loss,eval_precision,eval_recall,eval_f1,eval_span_f1,eval_accuracy,eval_runtime,eval_samples_per_second,eval_steps_per_second
0,Baseline,English,0.299724,0.662011,0.51634,0.580171,0.606516,0.914911,54.3129,5.247,0.331
1,NoSta-D,German,0.324131,0.55236,0.464317,0.504526,0.550641,0.918666,1077.3587,4.734,0.296
2,DaNplus,Danish,0.299197,0.338834,0.319502,0.328884,0.373333,0.919352,181.317,3.122,0.199
