In [None]:
! pip install transformers datasets sentence-transformers

# Evaluate different masking rates

In [None]:
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForMaskedLM, Trainer, TrainingArguments, DataCollatorForLanguageModeling
import torch
from torch.utils.data import DataLoader
import math

# Load Model & Tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Load Dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")

# Tokenize Function
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        padding='max_length',
        max_length=128,
        return_special_tokens_mask=True
    )

# Tokenize Dataset
tokenized_datasets = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=["text"]
)

# Select 100 samples from train and validation
train_subset = tokenized_datasets["train"]
eval_subset = tokenized_datasets["validation"]

# Training Arguments
training_args = TrainingArguments(
    output_dir="./bert-mlm-custom-masking",
    evaluation_strategy="steps",
    eval_steps=10,
    logging_steps=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=1,
    save_steps=50,
    save_total_limit=2
)

# Load Wikipedia-style question dataset (e.g., SQuAD or TriviaQA) for evaluation
wiki_dataset = load_dataset("squad", split="validation[:100]")  # Sample 100 questions

In [3]:
def calculate_perplexity(text):
    """
    Calculates the perplexity of a given text using a pre-trained language model.

    Perplexity is computed as the exponential of the loss from the model.
    Lower perplexity indicates that the model is more confident in predicting the text.

    Args:
        text (str): The input text for which to compute perplexity.

    Returns:
        float: The perplexity score of the given text.
    """
    # Tokenize the input text and convert it into tensors
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    input_ids = inputs["input_ids"]

    # Perform inference without tracking gradients (to save memory and speed up computation)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)  # Compute the model loss
        loss = outputs.loss.item()  # Extract the loss value

    # Compute perplexity using the exponent of the loss
    perplexity = math.exp(loss)
    return perplexity


def evaluate_masking_rates(model, training_args, train_subset, eval_subset, masking_rate=0.3, dataset=wiki_dataset):
    """
    Evaluates the effect of different masking rates on model performance.

    This function:
    - Configures a data collator for masked language modeling.
    - Trains the model on a given train subset.
    - Evaluates the model on an evaluation subset.
    - Computes and prints the loss on a single batch.
    - Computes perplexity for Wikipedia-style questions.

    Args:
        model (transformers.PreTrainedModel): The pre-trained language model to be trained and evaluated.
        training_args (transformers.TrainingArguments): Training configuration parameters.
        train_subset (datasets.Dataset): The subset of the training dataset.
        eval_subset (datasets.Dataset): The subset of the evaluation dataset.
        masking_rate (float, optional): Probability of masking tokens during training. Defaults to 0.3.
        dataset (datasets.Dataset, optional): Dataset containing Wikipedia questions for perplexity evaluation.

    Returns:
        None: Prints evaluation results including loss and perplexity scores.
    """

    # Initialize the data collator for masked language modeling with the specified masking rate
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=masking_rate
    )

    # Set up the Trainer for training the model
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_subset,
        eval_dataset=eval_subset
    )

    # Train the model using the provided training data
    trainer.train()

    # Create a DataLoader to handle masked input processing
    eval_dataloader = DataLoader(
        eval_subset,
        batch_size=8,  # Process data in mini-batches
        collate_fn=data_collator  # Use the collator to apply masking
    )

    # Fetch one batch of evaluation data
    batch = next(iter(eval_dataloader))

    # Move the batch data to the model's device (e.g., GPU if available)
    batch = {k: v.to(model.device) for k, v in batch.items()}

    # Set the model to evaluation mode
    model.eval()
    with torch.no_grad():
        # Compute the loss on the evaluation batch
        outputs = model(**batch)
        loss = outputs.loss

    # Compute perplexity for Wikipedia-style questions from the dataset
    perplexities = [calculate_perplexity(q) for q in dataset["question"]]

    # Print the masking rate and the loss value
    print(f"Evaluating masking rate of: {masking_rate:.2%} \n")
    print(f"Loss on a single batch: {loss.item()}")
    print("Average Perplexity:", sum(perplexities) / len(perplexities))


In [8]:
def calculate_perplexity(text):
    """
    Calculates the perplexity of a given text using a pre-trained language model.

    Perplexity is computed as the exponential of the loss from the model.
    Lower perplexity indicates that the model is more confident in predicting the text.

    Args:
        text (str): The input text for which to compute perplexity.

    Returns:
        float: The perplexity score of the given text.
    """
    # Tokenize the input text and convert it into tensors
    inputs = tokenizer(text, return_tensors="pt", max_length=512, truncation=True)
    input_ids = inputs["input_ids"]

    # Perform inference without tracking gradients (to save memory and speed up computation)
    with torch.no_grad():
        outputs = model(input_ids, labels=input_ids)  # Compute the model loss
        loss = outputs.loss.item()  # Extract the loss value

    # Compute perplexity using the exponent of the loss
    perplexity = math.exp(loss)
    return perplexity


def evaluate_masking_rates(model, training_args, train_subset, eval_subset, masking_rate=0.3, dataset=wiki_dataset):
    """
    Evaluates the effect of different masking rates on model performance.

    This function:
    - Configures a data collator for masked language modeling.
    - Trains the model on a given train subset.
    - Evaluates the model on an evaluation subset.
    - Computes and prints the loss on a single batch.
    - Computes perplexity for Wikipedia-style questions.

    Args:
        model (transformers.PreTrainedModel): The pre-trained language model to be trained and evaluated.
        training_args (transformers.TrainingArguments): Training configuration parameters.
        train_subset (datasets.Dataset): The subset of the training dataset.
        eval_subset (datasets.Dataset): The subset of the evaluation dataset.
        masking_rate (float, optional): Probability of masking tokens during training. Defaults to 0.3.
        dataset (datasets.Dataset, optional): Dataset containing Wikipedia questions for perplexity evaluation.

    Returns:
        None: Prints evaluation results including loss and perplexity scores.
    """

    # Initialize the data collator for masked language modeling with the specified masking rate
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=True,
        mlm_probability=masking_rate
    )

    # Set up the Trainer for training the model
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_subset,
        eval_dataset=eval_subset
    )

    # Train the model using the provided training data
    trainer.train()

    # Create a DataLoader to handle masked input processing
    eval_dataloader = DataLoader(
        eval_subset,
        batch_size=8,  # Process data in mini-batches
        collate_fn=data_collator  # Use the collator to apply masking
    )

    # Move model to evaluation mode
    model.eval()

    total_loss = 0.0
    num_batches = 0

    # Iterate through the full evaluation dataset
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**batch)
            total_loss += outputs.loss.item()
            num_batches += 1

    # Compute the average loss across all batches
    avg_loss = total_loss / num_batches
    perplexity = math.exp(avg_loss)

    # Print loss
    print(f"Evaluating masking rate of: {masking_rate:.2%} \n")
    print(f"Average Loss across Eval Set: {avg_loss:.4f}")
    print(f"Perplexity across Eval Set: {perplexity:.4f}")

    # Compute perplexity for Wikipedia-style questions from the dataset
    perplexities = [calculate_perplexity(q) for q in dataset["question"]]

    # Print the perplexity
    print("Average Perplexity:", sum(perplexities) / len(perplexities))

Note to reader: this is where you can try different masking rates and evaluate based on loss and perplexity yourself.

In [9]:
evaluate_masking_rates(model, training_args, train_subset, eval_subset, 0.3)

Step,Training Loss,Validation Loss
10,1.3742,2.980839


Evaluating masking rate of: 30.00%
Average Loss across Eval Set: 3.0823
Perplexity across Eval Set: 21.8085
Evaluating masking rate of: 30.00% 

Average Perplexity: 43.06413626229333


In [None]:
evaluate_masking_rates(model, training_args, train_subset, eval_subset, 0.5)

# Masking based on heuristics:

In this example, we will mask based on frequency of a word's occurence within the corpus. If its frequency is below a certain threshold, it will be more likely to be masked.

In [10]:
from transformers import Trainer, TrainingArguments
from transformers import AutoTokenizer, AutoModelForMaskedLM
from datasets import load_dataset
import torch


# Load model & tokenizer
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)

# Load dataset
dataset = load_dataset("wikitext", "wikitext-2-raw-v1")
train_dataset = dataset['train']
eval_dataset = dataset['validation']

# Build Frequency Dictionary
all_tokens = [word for sentence in train_dataset['text'] for word in sentence.split(' ')]
token_ids = tokenizer(all_tokens, add_special_tokens=False)['input_ids']
flat_token_ids = [token_id for sublist in token_ids for token_id in sublist]
freq_dict = {token_id: flat_token_ids.count(token_id) for token_id in set(flat_token_ids)}


# Custom Collator For Frequency Based Masking
class LowFrequencyMaskingCollator:
    def __init__(self, tokenizer, freq_dict, mask_prob=0.3, rare_threshold=5):
        self.tokenizer = tokenizer
        self.freq_dict = freq_dict
        self.mask_prob = mask_prob
        self.rare_threshold = rare_threshold

    def __call__(self, examples):
        # Extract the text column
        texts = [example['text'] for example in examples]

        # Tokenize
        batch = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt')

        input_ids = batch['input_ids']
        labels = input_ids.clone()

        for i in range(input_ids.shape[0]):
            for j in range(input_ids.shape[1]):
                token_id = input_ids[i, j].item()
                freq = self.freq_dict.get(token_id, 0)

                if freq < self.rare_threshold:
                    mask_prob = self.mask_prob * 2
                else:
                    mask_prob = self.mask_prob

                if torch.rand(1).item() < mask_prob:
                    input_ids[i, j] = self.tokenizer.mask_token_id

        batch['input_ids'] = input_ids
        batch['labels'] = labels

        return batch


collator = LowFrequencyMaskingCollator(tokenizer, freq_dict)


# Training arguments
training_args = TrainingArguments(
    output_dir="./bert-mlm-custom-masking",
    evaluation_strategy="steps",
    eval_steps=10,
    logging_steps=10,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=1,
    save_steps=50,
    save_total_limit=1,
    remove_unused_columns=False
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=collator,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# Train
trainer.train()


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


Step,Training Loss,Validation Loss
10,8.0692,5.784812
20,5.0741,4.49767


Step,Training Loss,Validation Loss
10,8.0692,5.784812
20,5.0741,4.49767


TrainOutput(global_step=25, training_loss=6.260419082641602, metrics={'train_runtime': 452.5074, 'train_samples_per_second': 0.221, 'train_steps_per_second': 0.055, 'total_flos': 9080566041600.0, 'train_loss': 6.260419082641602, 'epoch': 1.0})

Note to reader: experiment with different masking and frequency threshold rates by inputting different values into the data collator in the cell above. How does it affect loss and perplexity below?

In [18]:
def evaluate_frequency_based_masking( model, data_collator,eval_subset=eval_dataset, dataset=wiki_dataset):
    """
    Evaluates the effect of different masking rates on model performance.

    This function:
    - Evaluates the model on an evaluation subset.
    - Computes and prints the loss on a single batch.
    - Computes perplexity for Wikipedia-style questions.

    Args:
        model (transformers.PreTrainedModel): The pre-trained language model to be trained and evaluated.
        eval_subset (datasets.Dataset, optional): The subset of the evaluation dataset. Defaults to eval subset from previous cell.
        dataset (datasets.Dataset, optional): Dataset containing Wikipedia questions for perplexity evaluation.

    Returns:
        None: Prints evaluation results including loss and perplexity scores.
    """
    # Create a DataLoader to handle masked input processing
    eval_dataloader = DataLoader(
        eval_subset,
        batch_size=8,  # Process data in mini-batches
        collate_fn=data_collator  # Use the collator to apply masking
    )

    # Move model to evaluation mode
    model.eval()

    total_loss = 0.0
    num_batches = 0

    # Iterate through the full evaluation dataset
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(model.device) for k, v in batch.items()}
            outputs = model(**batch)
            total_loss += outputs.loss.item()
            num_batches += 1

    # Compute the average loss across all batches
    avg_loss = total_loss / num_batches
    perplexity = math.exp(avg_loss)

    # Print loss
    print(f"Average Loss across Eval Set: {avg_loss:.4f}")
    print(f"Perplexity across Eval Set: {perplexity:.4f}")

    # Compute perplexity for Wikipedia-style questions from the dataset
    perplexities = [calculate_perplexity(q) for q in dataset["question"]]

    # Print the loss and average perplexity
    print("Average Perplexity:", sum(perplexities) / len(perplexities))


In [19]:
evaluate_frequency_based_masking(model, collator)

Average Loss across Eval Set: 4.1630
Perplexity across Eval Set: 64.2658
Average Perplexity: 12.48361186605892
Average Perplexity: 12.48361186605892


# Note to reader: can we also mask based on part of speech tags?