In [2]:
import evaluate
import numpy as np
import sys
import os
from datasets import load_from_disk, disable_caching
from sklearn.metrics import f1_score
from transformers import (
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
    set_seed,
)
sys.path.append(os.path.abspath('../../modules'))
from experiment_1.BertEntity import BertEntity

In [4]:
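# Kept for reference: presumably how the fixed seed list used further below was drawn (TODO confirm).
# The upper bound is wrapped in int() because random.randint rejects float arguments on Python 3.12+.
# import random
# seeds = [random.randint(0, int(1e9)) for _ in range(5)]
# seeds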

In [3]:
# Turn off the `datasets` library cache before any dataset is loaded or mapped.
# NOTE(review): per the `datasets` docs this means `.map` results are recomputed
# instead of being reloaded from cache files — confirm that is the intent here.
disable_caching()

In [4]:
# Relation classes; a name's position in the list is its integer class id.
id2label = dict(
    enumerate(["reject", "B_supplies_A", "A_supplies_B", "ambiguous", "ownership"])
)
# Inverse lookup and class count are derived so they cannot drift from id2label.
label2id = {name: class_id for class_id, name in id2label.items()}
num_labels = len(id2label)

# F1 metric object from the `evaluate` library.
metric = evaluate.load("f1")

In [5]:
# Checkpoint used for both the tokenizer and (in model_init) the model weights.
model_name = "google-bert/bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Register the three marker tokens as special tokens. This grows the vocabulary,
# which is why model_init resizes the embedding matrix to len(tokenizer).
# NOTE(review): presumably these are the entity placeholders found in the
# "masked_text" column — confirm against the dataset.
tokenizer.add_special_tokens(
    dict(additional_special_tokens=["__NE_FROM__", "__NE_TO__", "__NE_OTHER__"])
)

# Collator built from the same tokenizer; handed to the Trainer in run_experiment.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [6]:
def preprocess_function(examples):
    """Tokenize the "text" field of a batch of examples.

    Args:
        examples: Mapping with a "text" entry (used with `ds.map(..., batched=True)`).

    Returns:
        Whatever the module-level `tokenizer` returns for that text, with
        truncation enabled.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded

def compute_metrics(eval_pred):
    """Compute micro, macro and per-class F1 from a (logits, labels) pair.

    Args:
        eval_pred: Pair `(logits, labels)`; the predicted class is the argmax of
            `logits` over its last axis.

    Returns:
        Dict with "f1_micro", "f1_macro" and one "f1_class_{i}" entry per class
        id `i`, where the number of classes is the size of the last logits axis.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    f1_micro = f1_score(labels, predictions, average='micro')
    f1_macro = f1_score(labels, predictions, average='macro')

    # Without an explicit `labels=` list, average=None only returns scores for
    # classes that occur in `labels` or `predictions`; a missing class would
    # shorten the array and shift every later "f1_class_{i}" onto the wrong id.
    # Pinning the label set to all class ids keeps index i == class id i.
    class_ids = np.arange(np.shape(logits)[-1])
    f1_classwise = f1_score(
        labels,
        predictions,
        average=None,
        labels=class_ids,
        zero_division=0,  # a class with no support and no predictions scores 0
    )

    return {
        "f1_micro": f1_micro,
        "f1_macro": f1_macro,
        **{f"f1_class_{i}": score for i, score in enumerate(f1_classwise)}
    }


def model_init():
    """Build a fresh BertEntity model from the pretrained checkpoint.

    Passed to the Trainer as `model_init`, so every run starts from a newly
    loaded model rather than reusing weights from a previous run.

    Returns:
        A BertEntity instance whose token embeddings have been resized to the
        tokenizer's current vocabulary size.
    """
    model = BertEntity.from_pretrained(
        model_name,
        # Use the module-level class count instead of a second hard-coded 5,
        # so the classifier size always matches id2label / label2id.
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )
    # The tokenizer gained extra special tokens after the checkpoint was
    # published, so the embedding matrix must be resized to cover them.
    model.resize_token_embeddings(len(tokenizer))
    return model

In [7]:
# Load the saved dataset, keep only the masked text and the label, expose the
# text under the "text" key expected by preprocess_function, then tokenize.
# The result is indexed by "train" / "valid" / "test" in run_experiment.
ds = (
    load_from_disk("../../datasets/ManualDataset")
    .select_columns(["masked_text", "label"])
    .rename_column("masked_text", "text")
    .map(preprocess_function, batched=True)
)

Map:   0%|          | 0/2559 [00:00<?, ? examples/s]

Map:   0%|          | 0/418 [00:00<?, ? examples/s]

Map:   0%|          | 0/745 [00:00<?, ? examples/s]

In [8]:
def run_experiment(seed):
    """Fine-tune a fresh model with the given seed and score it on the test split.

    Args:
        seed: Integer passed to `set_seed` and used as both `seed` and
            `data_seed` in the training arguments.

    Returns:
        The result of `trainer.predict(ds["test"])`; its `.metrics` dict holds
        the `test_*` scores produced by `compute_metrics`.
    """
    set_seed(seed)
    training_args = TrainingArguments(
        seed=seed,
        data_seed=seed,
        tf32=True,
        # One directory per seed: with a shared output_dir every run overwrote
        # the previous seed's checkpoints, and checkpoint rotation
        # (save_total_limit) had to deal with stale checkpoints left by
        # earlier seeds.
        output_dir=f"logs/experiment_1_bert_entity/seed_{seed}",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=10,
        weight_decay=0.01,
        # Evaluate and checkpoint once per epoch so a best epoch can be restored.
        eval_strategy="epoch",
        save_strategy="epoch",
        warmup_ratio=0.1,
        # NOTE(review): metric_for_best_model is not set, so per the
        # TrainingArguments docs the "best" checkpoint is chosen by eval loss,
        # not by F1 — confirm this is intended.
        load_best_model_at_end=True,
        save_total_limit=1,
        report_to=[],
        save_only_model=True,
    )
    trainer = Trainer(
        model_init=model_init,
        args=training_args,
        train_dataset=ds["train"],
        eval_dataset=ds["valid"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    # The test split is only touched here, after training has finished.
    test_results = trainer.predict(ds["test"])
    return test_results

In [9]:
# Fixed seeds so the five runs can be repeated exactly.
seeds = [656566143, 497239539, 527721645, 74564875, 180967469]

# One test-set result per seed, in the same order as `seeds`. An explicit loop
# keeps the results of completed runs available if a later run fails.
all_results = []
for seed in seeds:
    all_results.append(run_experiment(seed))

Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/1600 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.811123788356781, 'eval_f1_micro': 0.6961722488038278, 'eval_f1_macro': 0.6456507285136183, 'eval_f1_class_0': 0.7066246056782335, 'eval_f1_class_1': 0.4528301886792453, 'eval_f1_class_2': 0.7412587412587412, 'eval_f1_class_3': 0.6911764705882353, 'eval_f1_class_4': 0.6363636363636364, 'eval_runtime': 0.3852, 'eval_samples_per_second': 1085.234, 'eval_steps_per_second': 70.099, 'epoch': 1.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6324194073677063, 'eval_f1_micro': 0.7990430622009569, 'eval_f1_macro': 0.7713093757722664, 'eval_f1_class_0': 0.8255813953488372, 'eval_f1_class_1': 0.6349206349206349, 'eval_f1_class_2': 0.8215767634854771, 'eval_f1_class_3': 0.7659574468085106, 'eval_f1_class_4': 0.8085106382978723, 'eval_runtime': 0.3886, 'eval_samples_per_second': 1075.674, 'eval_steps_per_second': 69.481, 'epoch': 2.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5924275517463684, 'eval_f1_micro': 0.8181818181818182, 'eval_f1_macro': 0.7942578503787702, 'eval_f1_class_0': 0.8546511627906976, 'eval_f1_class_1': 0.6896551724137931, 'eval_f1_class_2': 0.8278688524590164, 'eval_f1_class_3': 0.7619047619047619, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3674, 'eval_samples_per_second': 1137.773, 'eval_steps_per_second': 73.493, 'epoch': 3.0}
{'loss': 0.7616, 'grad_norm': 12.672073364257812, 'learning_rate': 1.5277777777777777e-05, 'epoch': 3.12}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.579056978225708, 'eval_f1_micro': 0.8349282296650717, 'eval_f1_macro': 0.8141931750495456, 'eval_f1_class_0': 0.8604651162790697, 'eval_f1_class_1': 0.676923076923077, 'eval_f1_class_2': 0.8326848249027238, 'eval_f1_class_3': 0.84375, 'eval_f1_class_4': 0.8571428571428571, 'eval_runtime': 0.388, 'eval_samples_per_second': 1077.3, 'eval_steps_per_second': 69.586, 'epoch': 4.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6744696497917175, 'eval_f1_micro': 0.8301435406698564, 'eval_f1_macro': 0.8084029035007563, 'eval_f1_class_0': 0.8545994065281899, 'eval_f1_class_1': 0.6984126984126984, 'eval_f1_class_2': 0.8413284132841329, 'eval_f1_class_3': 0.7966101694915254, 'eval_f1_class_4': 0.851063829787234, 'eval_runtime': 0.3373, 'eval_samples_per_second': 1239.239, 'eval_steps_per_second': 80.047, 'epoch': 5.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6501762866973877, 'eval_f1_micro': 0.8421052631578947, 'eval_f1_macro': 0.8294743539069298, 'eval_f1_class_0': 0.8562874251497006, 'eval_f1_class_1': 0.7419354838709677, 'eval_f1_class_2': 0.8475836431226765, 'eval_f1_class_3': 0.832, 'eval_f1_class_4': 0.8695652173913043, 'eval_runtime': 0.3568, 'eval_samples_per_second': 1171.512, 'eval_steps_per_second': 75.672, 'epoch': 6.0}
{'loss': 0.1539, 'grad_norm': 11.089712142944336, 'learning_rate': 8.333333333333334e-06, 'epoch': 6.25}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7359728813171387, 'eval_f1_micro': 0.8373205741626795, 'eval_f1_macro': 0.8169533406822904, 'eval_f1_class_0': 0.8645533141210374, 'eval_f1_class_1': 0.7096774193548387, 'eval_f1_class_2': 0.84375, 'eval_f1_class_3': 0.8031496062992126, 'eval_f1_class_4': 0.8636363636363636, 'eval_runtime': 0.324, 'eval_samples_per_second': 1290.247, 'eval_steps_per_second': 83.341, 'epoch': 7.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7991121411323547, 'eval_f1_micro': 0.8421052631578947, 'eval_f1_macro': 0.8191743550625631, 'eval_f1_class_0': 0.8728323699421965, 'eval_f1_class_1': 0.7096774193548387, 'eval_f1_class_2': 0.8470588235294118, 'eval_f1_class_3': 0.8091603053435115, 'eval_f1_class_4': 0.8571428571428571, 'eval_runtime': 0.3622, 'eval_samples_per_second': 1153.934, 'eval_steps_per_second': 74.536, 'epoch': 8.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8177751898765564, 'eval_f1_micro': 0.84688995215311, 'eval_f1_macro': 0.8307908993058541, 'eval_f1_class_0': 0.8680351906158358, 'eval_f1_class_1': 0.7419354838709677, 'eval_f1_class_2': 0.8527131782945736, 'eval_f1_class_3': 0.8217054263565892, 'eval_f1_class_4': 0.8695652173913043, 'eval_runtime': 0.3603, 'eval_samples_per_second': 1160.024, 'eval_steps_per_second': 74.93, 'epoch': 9.0}
{'loss': 0.0401, 'grad_norm': 1.8022270202636719, 'learning_rate': 1.3888888888888892e-06, 'epoch': 9.38}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.827440619468689, 'eval_f1_micro': 0.8397129186602871, 'eval_f1_macro': 0.8175387596899226, 'eval_f1_class_0': 0.8662790697674418, 'eval_f1_class_1': 0.71875, 'eval_f1_class_2': 0.84375, 'eval_f1_class_3': 0.8217054263565892, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3288, 'eval_samples_per_second': 1271.387, 'eval_steps_per_second': 82.123, 'epoch': 10.0}
{'train_runtime': 89.6188, 'train_samples_per_second': 285.543, 'train_steps_per_second': 17.853, 'train_loss': 0.3000320391356945, 'epoch': 10.0}


  0%|          | 0/47 [00:00<?, ?it/s]

Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/1600 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7663498520851135, 'eval_f1_micro': 0.7248803827751196, 'eval_f1_macro': 0.6351080309250128, 'eval_f1_class_0': 0.7828418230563002, 'eval_f1_class_1': 0.5813953488372093, 'eval_f1_class_2': 0.7673469387755102, 'eval_f1_class_3': 0.6153846153846154, 'eval_f1_class_4': 0.42857142857142855, 'eval_runtime': 0.3726, 'eval_samples_per_second': 1121.848, 'eval_steps_per_second': 72.464, 'epoch': 1.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5353245139122009, 'eval_f1_micro': 0.8014354066985646, 'eval_f1_macro': 0.7686074632586045, 'eval_f1_class_0': 0.8173913043478261, 'eval_f1_class_1': 0.696969696969697, 'eval_f1_class_2': 0.8377358490566038, 'eval_f1_class_3': 0.7540983606557377, 'eval_f1_class_4': 0.7368421052631579, 'eval_runtime': 0.363, 'eval_samples_per_second': 1151.526, 'eval_steps_per_second': 74.381, 'epoch': 2.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5335806608200073, 'eval_f1_micro': 0.8301435406698564, 'eval_f1_macro': 0.808478162161545, 'eval_f1_class_0': 0.8531073446327684, 'eval_f1_class_1': 0.7333333333333333, 'eval_f1_class_2': 0.8482490272373541, 'eval_f1_class_3': 0.7704918032786885, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3724, 'eval_samples_per_second': 1122.478, 'eval_steps_per_second': 72.505, 'epoch': 3.0}
{'loss': 0.7515, 'grad_norm': 4.689485549926758, 'learning_rate': 1.5277777777777777e-05, 'epoch': 3.12}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5541859865188599, 'eval_f1_micro': 0.8301435406698564, 'eval_f1_macro': 0.8071702970767373, 'eval_f1_class_0': 0.8539325842696629, 'eval_f1_class_1': 0.7540983606557377, 'eval_f1_class_2': 0.8432835820895522, 'eval_f1_class_3': 0.7663551401869159, 'eval_f1_class_4': 0.8181818181818182, 'eval_runtime': 0.3469, 'eval_samples_per_second': 1204.817, 'eval_steps_per_second': 77.823, 'epoch': 4.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6108567714691162, 'eval_f1_micro': 0.8349282296650717, 'eval_f1_macro': 0.8259468034571562, 'eval_f1_class_0': 0.8490028490028491, 'eval_f1_class_1': 0.7540983606557377, 'eval_f1_class_2': 0.842911877394636, 'eval_f1_class_3': 0.8, 'eval_f1_class_4': 0.8837209302325582, 'eval_runtime': 0.3385, 'eval_samples_per_second': 1234.98, 'eval_steps_per_second': 79.771, 'epoch': 5.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6522165536880493, 'eval_f1_micro': 0.8397129186602871, 'eval_f1_macro': 0.8281739852040184, 'eval_f1_class_0': 0.8563218390804598, 'eval_f1_class_1': 0.7936507936507936, 'eval_f1_class_2': 0.8505747126436781, 'eval_f1_class_3': 0.7903225806451613, 'eval_f1_class_4': 0.85, 'eval_runtime': 0.3701, 'eval_samples_per_second': 1129.384, 'eval_steps_per_second': 72.951, 'epoch': 6.0}
{'loss': 0.1396, 'grad_norm': 5.446402072906494, 'learning_rate': 8.333333333333334e-06, 'epoch': 6.25}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7122995257377625, 'eval_f1_micro': 0.8373205741626795, 'eval_f1_macro': 0.8270422749951993, 'eval_f1_class_0': 0.8488372093023255, 'eval_f1_class_1': 0.7878787878787878, 'eval_f1_class_2': 0.8473282442748091, 'eval_f1_class_3': 0.8067226890756303, 'eval_f1_class_4': 0.8444444444444444, 'eval_runtime': 0.3536, 'eval_samples_per_second': 1181.97, 'eval_steps_per_second': 76.347, 'epoch': 7.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7961209416389465, 'eval_f1_micro': 0.8444976076555024, 'eval_f1_macro': 0.8436749930157182, 'eval_f1_class_0': 0.8537313432835821, 'eval_f1_class_1': 0.8307692307692308, 'eval_f1_class_2': 0.849624060150376, 'eval_f1_class_3': 0.8062015503875969, 'eval_f1_class_4': 0.8780487804878049, 'eval_runtime': 0.3954, 'eval_samples_per_second': 1057.29, 'eval_steps_per_second': 68.294, 'epoch': 8.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7953670024871826, 'eval_f1_micro': 0.8349282296650717, 'eval_f1_macro': 0.8267778134345682, 'eval_f1_class_0': 0.8461538461538461, 'eval_f1_class_1': 0.8125, 'eval_f1_class_2': 0.8473282442748091, 'eval_f1_class_3': 0.7906976744186046, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3764, 'eval_samples_per_second': 1110.555, 'eval_steps_per_second': 71.734, 'epoch': 9.0}
{'loss': 0.0318, 'grad_norm': 1.2701342105865479, 'learning_rate': 1.3888888888888892e-06, 'epoch': 9.38}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7995858788490295, 'eval_f1_micro': 0.8421052631578947, 'eval_f1_macro': 0.8420799367880999, 'eval_f1_class_0': 0.8513119533527697, 'eval_f1_class_1': 0.8125, 'eval_f1_class_2': 0.8416988416988417, 'eval_f1_class_3': 0.816, 'eval_f1_class_4': 0.8888888888888888, 'eval_runtime': 0.3909, 'eval_samples_per_second': 1069.453, 'eval_steps_per_second': 69.08, 'epoch': 10.0}
{'train_runtime': 90.6949, 'train_samples_per_second': 282.155, 'train_steps_per_second': 17.642, 'train_loss': 0.28960464663803576, 'epoch': 10.0}


  0%|          | 0/47 [00:00<?, ?it/s]

Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/1600 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8248170018196106, 'eval_f1_micro': 0.6961722488038278, 'eval_f1_macro': 0.5960513060827772, 'eval_f1_class_0': 0.7539267015706806, 'eval_f1_class_1': 0.41509433962264153, 'eval_f1_class_2': 0.725, 'eval_f1_class_3': 0.6417910447761194, 'eval_f1_class_4': 0.4444444444444444, 'eval_runtime': 0.3547, 'eval_samples_per_second': 1178.556, 'eval_steps_per_second': 76.127, 'epoch': 1.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6172729730606079, 'eval_f1_micro': 0.7799043062200957, 'eval_f1_macro': 0.7385629006687281, 'eval_f1_class_0': 0.7987616099071208, 'eval_f1_class_1': 0.5915492957746479, 'eval_f1_class_2': 0.8164794007490637, 'eval_f1_class_3': 0.7801418439716312, 'eval_f1_class_4': 0.7058823529411765, 'eval_runtime': 0.3535, 'eval_samples_per_second': 1182.295, 'eval_steps_per_second': 76.368, 'epoch': 2.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6593921184539795, 'eval_f1_micro': 0.7799043062200957, 'eval_f1_macro': 0.751415546577667, 'eval_f1_class_0': 0.7948717948717948, 'eval_f1_class_1': 0.6268656716417911, 'eval_f1_class_2': 0.8148148148148148, 'eval_f1_class_3': 0.7586206896551724, 'eval_f1_class_4': 0.7619047619047619, 'eval_runtime': 0.3447, 'eval_samples_per_second': 1212.48, 'eval_steps_per_second': 78.318, 'epoch': 3.0}
{'loss': 0.7578, 'grad_norm': 2.4464385509490967, 'learning_rate': 1.5277777777777777e-05, 'epoch': 3.12}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.708695650100708, 'eval_f1_micro': 0.8014354066985646, 'eval_f1_macro': 0.7792907305078854, 'eval_f1_class_0': 0.8408408408408409, 'eval_f1_class_1': 0.6666666666666666, 'eval_f1_class_2': 0.808, 'eval_f1_class_3': 0.7516778523489933, 'eval_f1_class_4': 0.8292682926829268, 'eval_runtime': 0.3599, 'eval_samples_per_second': 1161.316, 'eval_steps_per_second': 75.013, 'epoch': 4.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7824329137802124, 'eval_f1_micro': 0.8086124401913876, 'eval_f1_macro': 0.7846558745912231, 'eval_f1_class_0': 0.8438356164383561, 'eval_f1_class_1': 0.6779661016949152, 'eval_f1_class_2': 0.8065843621399177, 'eval_f1_class_3': 0.765625, 'eval_f1_class_4': 0.8292682926829268, 'eval_runtime': 0.3733, 'eval_samples_per_second': 1119.601, 'eval_steps_per_second': 72.319, 'epoch': 5.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8282323479652405, 'eval_f1_micro': 0.8301435406698564, 'eval_f1_macro': 0.8090645810884564, 'eval_f1_class_0': 0.844311377245509, 'eval_f1_class_1': 0.6984126984126984, 'eval_f1_class_2': 0.8505747126436781, 'eval_f1_class_3': 0.8148148148148148, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3377, 'eval_samples_per_second': 1237.86, 'eval_steps_per_second': 79.957, 'epoch': 6.0}
{'loss': 0.1414, 'grad_norm': 0.34553807973861694, 'learning_rate': 8.333333333333334e-06, 'epoch': 6.25}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8446513414382935, 'eval_f1_micro': 0.8373205741626795, 'eval_f1_macro': 0.820062249621704, 'eval_f1_class_0': 0.8613569321533924, 'eval_f1_class_1': 0.6875, 'eval_f1_class_2': 0.8549618320610687, 'eval_f1_class_3': 0.7874015748031497, 'eval_f1_class_4': 0.9090909090909091, 'eval_runtime': 0.3255, 'eval_samples_per_second': 1284.081, 'eval_steps_per_second': 82.943, 'epoch': 7.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8809350728988647, 'eval_f1_micro': 0.8421052631578947, 'eval_f1_macro': 0.8205471354660304, 'eval_f1_class_0': 0.8621700879765396, 'eval_f1_class_1': 0.71875, 'eval_f1_class_2': 0.8593155893536122, 'eval_f1_class_3': 0.8125, 'eval_f1_class_4': 0.85, 'eval_runtime': 0.3539, 'eval_samples_per_second': 1181.14, 'eval_steps_per_second': 76.294, 'epoch': 8.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.9254699349403381, 'eval_f1_micro': 0.8492822966507177, 'eval_f1_macro': 0.8254070424772925, 'eval_f1_class_0': 0.8716417910447761, 'eval_f1_class_1': 0.6885245901639344, 'eval_f1_class_2': 0.8679245283018868, 'eval_f1_class_3': 0.8208955223880597, 'eval_f1_class_4': 0.8780487804878049, 'eval_runtime': 0.3692, 'eval_samples_per_second': 1132.156, 'eval_steps_per_second': 73.13, 'epoch': 9.0}
{'loss': 0.03, 'grad_norm': 0.7800927758216858, 'learning_rate': 1.3888888888888892e-06, 'epoch': 9.38}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.9382026195526123, 'eval_f1_micro': 0.8421052631578947, 'eval_f1_macro': 0.8218425327297695, 'eval_f1_class_0': 0.8604651162790697, 'eval_f1_class_1': 0.7, 'eval_f1_class_2': 0.8615384615384616, 'eval_f1_class_3': 0.8091603053435115, 'eval_f1_class_4': 0.8780487804878049, 'eval_runtime': 0.3403, 'eval_samples_per_second': 1228.174, 'eval_steps_per_second': 79.332, 'epoch': 10.0}
{'train_runtime': 91.7529, 'train_samples_per_second': 278.901, 'train_steps_per_second': 17.438, 'train_loss': 0.2916619960963726, 'epoch': 10.0}


  0%|          | 0/47 [00:00<?, ?it/s]

Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/1600 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7930118441581726, 'eval_f1_micro': 0.6985645933014354, 'eval_f1_macro': 0.6417689697107645, 'eval_f1_class_0': 0.7202380952380952, 'eval_f1_class_1': 0.6206896551724138, 'eval_f1_class_2': 0.7453874538745388, 'eval_f1_class_3': 0.6376811594202898, 'eval_f1_class_4': 0.48484848484848486, 'eval_runtime': 0.3406, 'eval_samples_per_second': 1227.255, 'eval_steps_per_second': 79.272, 'epoch': 1.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5819430351257324, 'eval_f1_micro': 0.7822966507177034, 'eval_f1_macro': 0.7568244998336827, 'eval_f1_class_0': 0.7746031746031746, 'eval_f1_class_1': 0.6349206349206349, 'eval_f1_class_2': 0.8243727598566308, 'eval_f1_class_3': 0.7883211678832117, 'eval_f1_class_4': 0.7619047619047619, 'eval_runtime': 0.3555, 'eval_samples_per_second': 1175.929, 'eval_steps_per_second': 75.957, 'epoch': 2.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5689365863800049, 'eval_f1_micro': 0.8086124401913876, 'eval_f1_macro': 0.7841603910276893, 'eval_f1_class_0': 0.8414985590778098, 'eval_f1_class_1': 0.65625, 'eval_f1_class_2': 0.8211382113821138, 'eval_f1_class_3': 0.7647058823529411, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3547, 'eval_samples_per_second': 1178.548, 'eval_steps_per_second': 76.126, 'epoch': 3.0}
{'loss': 0.7429, 'grad_norm': 10.931737899780273, 'learning_rate': 1.5277777777777777e-05, 'epoch': 3.12}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.568697988986969, 'eval_f1_micro': 0.8253588516746412, 'eval_f1_macro': 0.8150847961945062, 'eval_f1_class_0': 0.8404907975460123, 'eval_f1_class_1': 0.71875, 'eval_f1_class_2': 0.8357142857142857, 'eval_f1_class_3': 0.7967479674796748, 'eval_f1_class_4': 0.8837209302325582, 'eval_runtime': 0.3552, 'eval_samples_per_second': 1176.797, 'eval_steps_per_second': 76.013, 'epoch': 4.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6025853753089905, 'eval_f1_micro': 0.8397129186602871, 'eval_f1_macro': 0.8157214123601649, 'eval_f1_class_0': 0.8538011695906432, 'eval_f1_class_1': 0.6896551724137931, 'eval_f1_class_2': 0.8560606060606061, 'eval_f1_class_3': 0.8346456692913385, 'eval_f1_class_4': 0.8444444444444444, 'eval_runtime': 0.3468, 'eval_samples_per_second': 1205.389, 'eval_steps_per_second': 77.86, 'epoch': 5.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7130756378173828, 'eval_f1_micro': 0.8397129186602871, 'eval_f1_macro': 0.8219685336836113, 'eval_f1_class_0': 0.8628571428571429, 'eval_f1_class_1': 0.6875, 'eval_f1_class_2': 0.8384615384615385, 'eval_f1_class_3': 0.8429752066115702, 'eval_f1_class_4': 0.8780487804878049, 'eval_runtime': 0.3591, 'eval_samples_per_second': 1164.036, 'eval_steps_per_second': 75.189, 'epoch': 6.0}
{'loss': 0.1477, 'grad_norm': 0.7047950029373169, 'learning_rate': 8.333333333333334e-06, 'epoch': 6.25}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7300108671188354, 'eval_f1_micro': 0.8492822966507177, 'eval_f1_macro': 0.8349985378712056, 'eval_f1_class_0': 0.863768115942029, 'eval_f1_class_1': 0.7586206896551724, 'eval_f1_class_2': 0.8517110266159695, 'eval_f1_class_3': 0.84375, 'eval_f1_class_4': 0.8571428571428571, 'eval_runtime': 0.357, 'eval_samples_per_second': 1170.705, 'eval_steps_per_second': 75.62, 'epoch': 7.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7230962514877319, 'eval_f1_micro': 0.8732057416267942, 'eval_f1_macro': 0.8728490242720135, 'eval_f1_class_0': 0.8816568047337278, 'eval_f1_class_1': 0.8181818181818182, 'eval_f1_class_2': 0.8593155893536122, 'eval_f1_class_3': 0.896, 'eval_f1_class_4': 0.9090909090909091, 'eval_runtime': 0.3501, 'eval_samples_per_second': 1194.018, 'eval_steps_per_second': 77.126, 'epoch': 8.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7484925985336304, 'eval_f1_micro': 0.8732057416267942, 'eval_f1_macro': 0.8724738520668925, 'eval_f1_class_0': 0.8823529411764706, 'eval_f1_class_1': 0.8307692307692308, 'eval_f1_class_2': 0.8625954198473282, 'eval_f1_class_3': 0.8818897637795275, 'eval_f1_class_4': 0.9047619047619048, 'eval_runtime': 0.337, 'eval_samples_per_second': 1240.534, 'eval_steps_per_second': 80.13, 'epoch': 9.0}
{'loss': 0.039, 'grad_norm': 10.46073055267334, 'learning_rate': 1.3888888888888892e-06, 'epoch': 9.38}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7715902328491211, 'eval_f1_micro': 0.868421052631579, 'eval_f1_macro': 0.8635251109179585, 'eval_f1_class_0': 0.8823529411764706, 'eval_f1_class_1': 0.8307692307692308, 'eval_f1_class_2': 0.8582375478927203, 'eval_f1_class_3': 0.8682170542635659, 'eval_f1_class_4': 0.8780487804878049, 'eval_runtime': 0.372, 'eval_samples_per_second': 1123.756, 'eval_steps_per_second': 72.587, 'epoch': 10.0}
{'train_runtime': 88.9845, 'train_samples_per_second': 287.578, 'train_steps_per_second': 17.981, 'train_loss': 0.29178381726145747, 'epoch': 10.0}


  0%|          | 0/47 [00:00<?, ?it/s]

Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertEntity were not initialized from the model checkpoint at google-bert/bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/1600 [00:00<?, ?it/s]

  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7698206305503845, 'eval_f1_micro': 0.7296650717703349, 'eval_f1_macro': 0.5996415393141383, 'eval_f1_class_0': 0.7751479289940828, 'eval_f1_class_1': 0.24390243902439024, 'eval_f1_class_2': 0.7835051546391752, 'eval_f1_class_3': 0.6956521739130435, 'eval_f1_class_4': 0.5, 'eval_runtime': 0.3583, 'eval_samples_per_second': 1166.651, 'eval_steps_per_second': 75.358, 'epoch': 1.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5775346159934998, 'eval_f1_micro': 0.7990430622009569, 'eval_f1_macro': 0.7686405055362281, 'eval_f1_class_0': 0.8467966573816156, 'eval_f1_class_1': 0.6909090909090909, 'eval_f1_class_2': 0.7950819672131147, 'eval_f1_class_3': 0.7299270072992701, 'eval_f1_class_4': 0.7804878048780488, 'eval_runtime': 0.3667, 'eval_samples_per_second': 1140.006, 'eval_steps_per_second': 73.637, 'epoch': 2.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.5302379727363586, 'eval_f1_micro': 0.8181818181818182, 'eval_f1_macro': 0.7927423458910849, 'eval_f1_class_0': 0.8417910447761194, 'eval_f1_class_1': 0.6896551724137931, 'eval_f1_class_2': 0.8290909090909091, 'eval_f1_class_3': 0.7936507936507936, 'eval_f1_class_4': 0.8095238095238095, 'eval_runtime': 0.3449, 'eval_samples_per_second': 1211.831, 'eval_steps_per_second': 78.276, 'epoch': 3.0}
{'loss': 0.7529, 'grad_norm': 9.289810180664062, 'learning_rate': 1.5277777777777777e-05, 'epoch': 3.12}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6226517558097839, 'eval_f1_micro': 0.80622009569378, 'eval_f1_macro': 0.7770291227689217, 'eval_f1_class_0': 0.8439306358381503, 'eval_f1_class_1': 0.6984126984126984, 'eval_f1_class_2': 0.8130081300813008, 'eval_f1_class_3': 0.7605633802816901, 'eval_f1_class_4': 0.7692307692307693, 'eval_runtime': 0.3567, 'eval_samples_per_second': 1171.984, 'eval_steps_per_second': 75.702, 'epoch': 4.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.6948862671852112, 'eval_f1_micro': 0.8277511961722488, 'eval_f1_macro': 0.809747822334702, 'eval_f1_class_0': 0.8282208588957055, 'eval_f1_class_1': 0.7, 'eval_f1_class_2': 0.8602941176470589, 'eval_f1_class_3': 0.8091603053435115, 'eval_f1_class_4': 0.851063829787234, 'eval_runtime': 0.3598, 'eval_samples_per_second': 1161.885, 'eval_steps_per_second': 75.05, 'epoch': 5.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.7316552996635437, 'eval_f1_micro': 0.8349282296650717, 'eval_f1_macro': 0.8109704354395448, 'eval_f1_class_0': 0.8563049853372434, 'eval_f1_class_1': 0.6764705882352942, 'eval_f1_class_2': 0.8625954198473282, 'eval_f1_class_3': 0.7899159663865546, 'eval_f1_class_4': 0.8695652173913043, 'eval_runtime': 0.347, 'eval_samples_per_second': 1204.668, 'eval_steps_per_second': 77.813, 'epoch': 6.0}
{'loss': 0.1375, 'grad_norm': 1.5498089790344238, 'learning_rate': 8.333333333333334e-06, 'epoch': 6.25}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8203293085098267, 'eval_f1_micro': 0.8205741626794258, 'eval_f1_macro': 0.7982319788155743, 'eval_f1_class_0': 0.8333333333333334, 'eval_f1_class_1': 0.696969696969697, 'eval_f1_class_2': 0.862453531598513, 'eval_f1_class_3': 0.7611940298507462, 'eval_f1_class_4': 0.8372093023255814, 'eval_runtime': 0.3441, 'eval_samples_per_second': 1214.843, 'eval_steps_per_second': 78.471, 'epoch': 7.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8511439561843872, 'eval_f1_micro': 0.8349282296650717, 'eval_f1_macro': 0.8075598603103561, 'eval_f1_class_0': 0.8588235294117647, 'eval_f1_class_1': 0.676923076923077, 'eval_f1_class_2': 0.8679245283018868, 'eval_f1_class_3': 0.7704918032786885, 'eval_f1_class_4': 0.8636363636363636, 'eval_runtime': 0.3517, 'eval_samples_per_second': 1188.414, 'eval_steps_per_second': 76.764, 'epoch': 8.0}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8591747879981995, 'eval_f1_micro': 0.8373205741626795, 'eval_f1_macro': 0.8086969489445904, 'eval_f1_class_0': 0.8648648648648649, 'eval_f1_class_1': 0.676923076923077, 'eval_f1_class_2': 0.8603773584905661, 'eval_f1_class_3': 0.796875, 'eval_f1_class_4': 0.8444444444444444, 'eval_runtime': 0.3662, 'eval_samples_per_second': 1141.309, 'eval_steps_per_second': 73.721, 'epoch': 9.0}
{'loss': 0.0353, 'grad_norm': 12.705297470092773, 'learning_rate': 1.3888888888888892e-06, 'epoch': 9.38}


  0%|          | 0/27 [00:00<?, ?it/s]

{'eval_loss': 0.8726463317871094, 'eval_f1_micro': 0.8301435406698564, 'eval_f1_macro': 0.804288968878722, 'eval_f1_class_0': 0.8502994011976048, 'eval_f1_class_1': 0.6875, 'eval_f1_class_2': 0.8614232209737828, 'eval_f1_class_3': 0.7777777777777778, 'eval_f1_class_4': 0.8444444444444444, 'eval_runtime': 0.3519, 'eval_samples_per_second': 1187.765, 'eval_steps_per_second': 76.722, 'epoch': 10.0}
{'train_runtime': 88.1156, 'train_samples_per_second': 290.414, 'train_steps_per_second': 18.158, 'train_loss': 0.2905592140555382, 'epoch': 10.0}


  0%|          | 0/47 [00:00<?, ?it/s]

In [10]:
# Aggregate the per-seed test metrics: mean and standard deviation across runs.
# The per-class keys follow num_labels rather than a hard-coded class count.
metrics = ["test_f1_micro", "test_f1_macro"] + [
    f"test_f1_class_{i}" for i in range(num_labels)
]

avg_results = {}
# The loop variable is deliberately not called `metric`: that name is already
# bound to the object created by evaluate.load("f1") and must not be clobbered.
for metric_name in metrics:
    scores = [result.metrics[metric_name] for result in all_results]
    avg_results[metric_name] = {
        'mean': np.mean(scores),
        # np.std uses ddof=0 by default, i.e. the population standard deviation.
        'std': np.std(scores),
    }

# Print results
for metric_name in metrics:
    print(f"Average {metric_name}: {avg_results[metric_name]['mean']:.4f} ± {avg_results[metric_name]['std']:.4f}")

Average test_f1_micro: 0.7895 ± 0.0252
Average test_f1_macro: 0.7756 ± 0.0305
Average test_f1_class_0: 0.8215 ± 0.0211
Average test_f1_class_1: 0.7727 ± 0.0612
Average test_f1_class_2: 0.8094 ± 0.0076
Average test_f1_class_3: 0.6664 ± 0.0324
Average test_f1_class_4: 0.8079 ± 0.0405
