In [1]:
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments, EarlyStoppingCallback, DataCollatorForLanguageModeling
from foe_foundry_nl.data.training import load_mlm_dataset
from pathlib import Path
import torch
import math
import numpy as np
from nltk.corpus import wordnet

In [6]:

# model = "sentence-transformers/msmarco-MiniLM-L12-cos-v5"
model_name = "sentence-transformers/all-MiniLM-L12-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

dataset = load_mlm_dataset()
print(f"Training: {len(dataset['train'])}")
print(f"Eval: {len(dataset['eval'])}")
print(f"Test: {len(dataset['test'])}")

Some weights of BertForMaskedLM were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L12-v2 and are newly initialized: ['cls.predictions.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training: 3485
Eval: 463
Test: 487


In [9]:
# tokenize data
def tokenize_function(examples):
    return tokenizer(
        examples["text"], padding="max_length", truncation=True, max_length=512
    )

tokenized_dataset = dataset.map(
    tokenize_function, batched=True, remove_columns=["text"]
)

# Set up Trainer
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)


n_train = len(dataset["train"])
num_train_epochs = 5
train_batch_size = 32
total_steps = (n_train // train_batch_size) * num_train_epochs
warmup_steps = total_steps // 10


output_dir = Path.cwd().parent / "models" / "minilm-finetuned"
output_dir.mkdir(exist_ok=True)
trainer = Trainer(
    model=model,
    args=TrainingArguments(
        output_dir=str(output_dir),
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_dir="./logs",
        logging_strategy="steps",
        logging_steps=10,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=train_batch_size,
        per_device_eval_batch_size=2 * train_batch_size,
        learning_rate=5e-5,
        lr_scheduler_type="cosine",
        warmup_steps=warmup_steps,
        load_best_model_at_end=True,
    ),
    train_dataset=tokenized_dataset["train"],  # type: ignore
    eval_dataset=tokenized_dataset["test"],  # type: ignore
    data_collator=data_collator,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train(resume_from_checkpoint=True)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)


Map:   0%|          | 0/3485 [00:00<?, ? examples/s]

Map:   0%|          | 0/463 [00:00<?, ? examples/s]

Map:   0%|          | 0/487 [00:00<?, ? examples/s]

There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].
  torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location)


  0%|          | 0/545 [00:00<?, ?it/s]

There were missing keys in the checkpoint model loaded: ['cls.predictions.decoder.weight', 'cls.predictions.decoder.bias'].


{'train_runtime': 0.1108, 'train_samples_per_second': 157248.385, 'train_steps_per_second': 4918.242, 'train_loss': 0.0, 'epoch': 5.0}


('c:\\code\\foe_foundry\\models\\minilm-finetuned\\tokenizer_config.json',
 'c:\\code\\foe_foundry\\models\\minilm-finetuned\\special_tokens_map.json',
 'c:\\code\\foe_foundry\\models\\minilm-finetuned\\vocab.txt',
 'c:\\code\\foe_foundry\\models\\minilm-finetuned\\added_tokens.json',
 'c:\\code\\foe_foundry\\models\\minilm-finetuned\\tokenizer.json')

In [2]:
fine_tuned_dir = Path.cwd().parent / "models" / "minilm-finetuned"

baseline_model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(baseline_model_name)
baseline_model = AutoModelForMaskedLM.from_pretrained(baseline_model_name)
finetuned_model = AutoModelForMaskedLM.from_pretrained(fine_tuned_dir)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
baseline_model.to(device)
finetuned_model.to(device)

def predict_mask(sentence: str, model: AutoModelForMaskedLM) -> str:
    inputs = tokenizer(sentence, return_tensors="pt").to(device)  # type: ignore
    with torch.no_grad():
        outputs = model(**inputs)  # type: ignore
        logits = outputs.logits

    # Find the index of the [MASK] token
    mask_token_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]  # type: ignore

    # Get the top predicted token for the [MASK]
    mask_token_logits = logits[0, mask_token_index, :]
    top_token = torch.argmax(mask_token_logits, dim=1)
    predicted_token = tokenizer.decode(top_token)  # type: ignore

    return predicted_token

def calculate_perplexity(model, sentence: str):
    # Tokenize and prepare input
    inputs = tokenizer(sentence, return_tensors="pt").to(device)
    inputs["labels"]  = inputs["input_ids"].clone()  # type: ignore
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Identify the position of the [MASK] token
    mask_token_index = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    mask_logits = logits[0, mask_token_index, :]

    # Calculate softmax probabilities for the masked position
    log_prob = torch.log_softmax(mask_logits, dim=-1)

    # Get perplexity from the log-probability of the correct token
    correct_token_id = inputs["labels"][0, mask_token_index]
    token_log_prob = log_prob[0, correct_token_id].item()
    perplexity = math.exp(-token_log_prob)

    return perplexity



# Evaluation Sentences
eval_sentences = [
    # Monster category (50 sentences)
    (1, "Monster", "The [MASK] prowled through the forest under the cover of darkness.", "beast"),
    (2, "Monster", "With every swing, the giant's [MASK] echoed across the battlefield.", "club"),
    (3, "Monster", "The wyvern's tail lashed out, coated in deadly [MASK].", "venom"),
    (4, "Monster", "A [MASK] of goblins stormed the village, setting fire to everything in sight.", "horde"),
    (5, "Monster", "Legends tell of a creature that [MASK] fire from its maw.", "breathes"),
    (6, "Monster", "The ancient red dragon's scales shimmered with a [MASK] glow.", "fiery"),
    (7, "Monster", "Its eyes gleamed with [MASK], promising destruction.", "malice"),
    (8, "Monster", "The ogre charged forward with a deafening [MASK].", "roar"),
    (9, "Monster", "A cloud of chill [MASK] surrounded the banshee as it hovered above the ground.", "mist"),
    (10, "Monster", "The troll's regeneration ability allows it to heal unless exposed to [MASK].", "fire"),
    (11, "Monster", "An aura of [MASK] surrounds the demon, weakening all within reach.", "fear"),
    (12, "Monster", "The hydra grows two heads for every [MASK] one.", "severed"),
    (13, "Monster", "A dire wolf's [MASK] is powerful enough to knock down a grown man.", "bite"),
    (14, "Monster", "The undead knight's blade was cursed with [MASK] magic.", "necrotic"),
    (15, "Monster", "A chimera's three heads grant it a [MASK] of attacks.", "variety"),
    (16, "Monster", "The hidden [MASK] lurked just beyond the flickering torchlight.", "shadow"),
    (17, "Monster", "The behir's lightning breath deals [MASK] damage in a straight line.", "electric"),
    (18, "Monster", "The gelatinous cube absorbs anything that it [MASK].", "touches"),
    (19, "Monster", "A vampire cannot enter a home without an invitation from a [MASK].", "mortal"),
    (20, "Monster", "The harpy's song draws travelers into deadly [MASK].", "traps"),
    (21, "Monster", "When the basilisk's gaze falls upon its victim, they turn to [MASK].", "stone"),
    (22, "Monster", "A lich's phylactery contains its [MASK].", "soul"),
    (23, "Monster", "The dragon turtle can breathe underwater and expel a cone of [MASK].", "steam"),
    (24, "Monster", "The griffon can serve as a loyal [MASK] for a Paladin.", "steed"),
    (25, "Monster", "A kobold's trap is more dangerous than the [MASK] itself.", "creature"),
    (26, "Monster", "The minotaur guards the [MASK] deep within the labyrinth.", "treasure"),
    (27, "Monster", "A werewolf transforms under the [MASK] of a full moon.", "light"),
    (28, "Monster", "The cyclops wields a tree as if it were a [MASK].", "club"),
    (29, "Monster", "A manticore can fire [MASK] from its spiked tail.", "spines"),
    (30, "Monster", "The massive eagle-like [MASK] spreads its wings and takes to the sky.", "roc"),
    (31, "Monster", "The salamander thrives in the [MASK] of active volcanoes.", "heat"),
    (32, "Monster", "The banshee's wail can [MASK] the souls of the living.", "shatter"),
    (33, "Monster", "A gelatinous cube's form is nearly [MASK].", "invisible"),
    (34, "Monster", "A succubus charms its victims with [MASK].", "beauty"),
    (35, "Monster", "The golem follows the commands of its [MASK] without question.", "creator"),
    (36, "Monster", "A mimic disguises itself as mundane [MASK] to ambush prey.", "objects"),
    (37, "Monster", "The mind flayer consumes the [MASK] of its victims.", "brains"),
    (38, "Monster", "The kraken's tentacles can drag entire [MASK] into the sea.", "ships"),
    (39, "Monster", "A hellhound's breath smells of [MASK] sulfur.", "burning"),
    (40, "Monster", "The frost giant's ax is enchanted with [MASK] magic.", "ice"),
    (41, "Monster", "A night hag can invade dreams and spread [MASK].", "nightmares"),
    (42, "Monster", "The beholder's central eye projects an anti-[MASK] cone.", "magic"),
    (43, "Monster", "The basilisk's venom causes [MASK] to spread through the veins.", "paralysis"),
    (44, "Monster", "The dracolich is an [MASK] dragon and a powerful spellcaster.", "undead"),
    (45, "Monster", "The djinn can grant wishes, but always with a [MASK].", "price"),
    (46, "Monster", "The owlbear's screech is enough to scare off most [MASK].", "predators"),
    (47, "Monster", "A fire elemental engulfs everything in [MASK].", "flames"),
    (48, "Monster", "The ghost is trapped in the mortal plane by unfinished [MASK].", "business"),
    (49, "Monster", "A ghast's stench is so foul it causes those nearby to [MASK].", "retreat"),
    (50, "Monster", "The hydra retreats only when all of its heads are [MASK].", "destroyed"),

    # Spells category (10 sentences)
    (51, "Spells", "Casting [MASK] allows the wizard to hover for a short duration.", "levitate"),
    (52, "Spells", "The sorcerer can summon a [MASK] to fight by their side.", "familiar"),
    (53, "Spells", "The cleric's healing spell can restore [MASK] to the injured.", "life"),
    (54, "Spells", "The mage conjures an [MASK] of a roaring demon to deceive enemies.", "illusion"),
    (55, "Spells", "The paladin casts [MASK] of faith to protect their allies from harm.", "shield"),
    (56, "Spells", "A druid can speak with [MASK] using a simple incantation.", "animals"),
    (57, "Spells", "The bard's song inspires [MASK] in their companions.", "courage"),
    (58, "Spells", "With a flick of her wrist, she cast the [MASK] spell as a beam of red light streaks forth.", "fireball"),
    (59, "Spells", "The necromancer raises a [MASK] with a word of power.", "skeleton"),
    (60, "Spells", "The spell of invisibility renders the caster [MASK].", "undetectable"),

    # Rules category (10 sentences)
    (61, "Rules", "Casting a spell is an example of an [MASK].", "action"),
    (62, "Rules", "Rolling a [MASK] causes an automatic success.", "natural 20"),
    (63, "Rules", "Players can take an [MASK], bonus action, and a reaction on their turn", "action"),
    (64, "Rules", "Critical hits deal [MASK] the normal damage.", "double"),
    (65, "Rules", "A failed [MASK] throw can result in the paralyzed status.", "saving"),
    (66, "Rules", "An attack of opportunity occurs when an enemy [MASK].", "retreats"),
    (67, "Rules", "Spellcasters must maintain [MASK] to keep concentration spells active.", "focus"),
    (68, "Rules", "A short rest allows players to recover [MASK].", "hit points"),
    (69, "Rules", "During combat, players act in order of their [MASK].", "initiative"),
    (70, "Rules", "The DM can award [MASK] for excellent roleplaying.", "inspiration")
]

def get_synonyms(word):
    """Retrieve synonyms for a word from WordNet."""
    synonyms = {word}
    for syn in wordnet.synsets(word):
        for lemma in syn.lemmas():
            synonyms.add(lemma.name().lower())
    return synonyms

def is_correct_prediction(predicted, expected):
    """Check if the prediction matches the expected word or a synonym."""
    predicted = predicted.strip().lower()
    expected_synonyms = get_synonyms(expected)
    return predicted in expected_synonyms



def evaluate_model(model):
    correct_predictions = 0
    perplexities = []

    print(f"\nEvaluating Model: {model.config._name_or_path}\n")
    for (idx, category, sentence, expected) in eval_sentences:
        predicted = predict_mask(sentence, model)
        is_correct = is_correct_prediction(predicted, expected)
        perplexity = calculate_perplexity(model, sentence)
        perplexities.append(perplexity)
        correct_predictions += int(is_correct)

        print(f"ID: {idx} | Category: {category}")
        print(f"  Sentence: {sentence}")
        print(f"  Predicted: {predicted} | Expected: {expected} | Correct: {is_correct}\n")
        print(f"  Perplexity: {perplexity:.2f}\n")


    accuracy = correct_predictions / len(eval_sentences) * 100
    avg_perplexity = np.mean(perplexities)
    print(f"Accuracy: {accuracy:.2f}% ({correct_predictions}/{len(eval_sentences)})")
    print(f"Average Perplexity: {avg_perplexity:.2f}")


print("EVALUATION\n")
print("BASELINE:")
evaluate_model(baseline_model)
print("FINE TUNED:")
evaluate_model(finetuned_model)






BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

EVALUATION

BASELINE:

Evaluating Model: bert-base-uncased



  attn_output = torch.nn.functional.scaled_dot_product_attention(


ID: 1 | Category: Monster
  Sentence: The [MASK] prowled through the forest under the cover of darkness.
  Predicted: creature | Expected: beast | Correct: True

  Perplexity: 32256937.62

ID: 2 | Category: Monster
  Sentence: With every swing, the giant's [MASK] echoed across the battlefield.
  Predicted: scream | Expected: club | Correct: False

  Perplexity: 451404147.98

ID: 3 | Category: Monster
  Sentence: The wyvern's tail lashed out, coated in deadly [MASK].
  Predicted: venom | Expected: venom | Correct: True

  Perplexity: 91465673.26

ID: 4 | Category: Monster
  Sentence: A [MASK] of goblins stormed the village, setting fire to everything in sight.
  Predicted: mob | Expected: horde | Correct: False

  Perplexity: 83590448.52

ID: 5 | Category: Monster
  Sentence: Legends tell of a creature that [MASK] fire from its maw.
  Predicted: drew | Expected: breathes | Correct: False

  Perplexity: 742565048.87

ID: 6 | Category: Monster
  Sentence: The ancient red dragon's scales s