In [None]:
%pip install -q transformers datasets evaluate accelerate scikit-learn torch

## Imports

In [None]:
import math
from typing import Any, Dict

import torch
from datasets import (
    DatasetDict,
    load_dataset,
)
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
    pipeline,
)

## Load Dataset

In [None]:
# Task 1: load ag_news (train + test) and hold out a random 10% of train as a validation split.

SEED = 42

# Shuffle every split with the same seed (an int `seeds` applies to all splits).
dataset = load_dataset("ag_news").shuffle(seeds=SEED)

# dataset["train"] = load_dataset("ag_news", split="train[:4000]") # Note: This is useful for sanity checking the training process. Comment out/Uncomment as necessary

# 90/10 split of the training data; the "test" half of this split becomes validation.
train_val_dataset = dataset["train"].train_test_split(test_size=0.1, seed=SEED)

# Final layout: new train, original test, new validation (key "val").
dataset = DatasetDict(
    train=train_val_dataset["train"],
    test=dataset["test"],
    val=train_val_dataset["test"],
)

## Preprocessing Function

In [None]:
# Task 3
# distilroberta-base already defines a padding token ("<pad>"), whose id is the one the
# model's embeddings treat as padding. Overriding it with the EOS token (a GPT-2 idiom)
# is unnecessary here and makes padded positions look like real "</s>" tokens.
# The RoBERTa tokenizer adds <s> ... </s> around each sequence by default
# (add_special_tokens=True), so no manual EOS insertion is needed.
tokenizer = AutoTokenizer.from_pretrained("distilroberta-base")


def preprocess_function(sample: Dict[str, Any], seq_len: int):
    """Tokenize ag_news rows for masked-language-model training.

    Args:
        sample: a row (or, with ``batched=True``, a batch of rows); ``sample["text"]``
            is a string or a list of strings.
        seq_len: maximum number of tokens per sequence; longer texts are truncated.

    Returns:
        The tokenizer's encoding (``input_ids``, ``attention_mask``). Sequences are not
        padded here: DataCollatorForLanguageModeling pads each batch to its longest
        member, which avoids wasting compute on padding to a fixed ``seq_len``.
    """
    return tokenizer(sample["text"], truncation=True, max_length=seq_len)


encoded_ds = dataset.map(
    preprocess_function,
    batched=True,  # tokenize many rows per call instead of one at a time
    fn_kwargs={"seq_len": 256},
    # Drop the raw columns ("text", "label"); the MLM only needs the token ids.
    remove_columns=dataset["train"].column_names,
)

## Data Collator

In [None]:
# Task 4
# Masked-LM collator: pads each batch and picks ~10% of the tokens as prediction targets.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,  # the library default, spelled out for readability
    mlm_probability=0.10,
)

## Load Model

In [None]:
# Task 5: load pretrained distilroberta-base with its masked-LM head, on GPU when one is available.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = AutoModelForMaskedLM.from_pretrained("distilroberta-base")
model = model.to(device)

In [None]:
# Task 6: show the module tree of the loaded model.
print(repr(model))

## Define TrainingArguments

In [None]:
# Task 7
# TODO: Learning Rate Scheduler, Weight Decay

# Accepted values for `lr_scheduler_type`:
# ['linear', 'cosine', 'cosine_with_restarts', 'polynomial', 'constant', 'constant_with_warmup', 'inverse_sqrt', 'reduce_lr_on_plateau']
lr_scheduler_type = "linear"

training_args = TrainingArguments(
    output_dir="./checkpoints/",
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    learning_rate=2e-5,
    weight_decay=0.1,
    lr_scheduler_type=lr_scheduler_type,
    num_train_epochs=5,
    # Evaluate and checkpoint once per epoch; the two strategies must match for
    # load_best_model_at_end. `eval_strategy` replaces the deprecated
    # `evaluation_strategy` argument.
    eval_strategy="epoch",
    save_strategy="epoch",
    # Reload the checkpoint with the best eval loss after training. Pass this keyword
    # only once: repeating a keyword argument is a SyntaxError.
    load_best_model_at_end=True,
)

## Define Trainer

In [None]:
# Task 8
# Evaluate on the validation split: with load_best_model_at_end=True, the eval set is
# used to pick the final checkpoint. Using the test split here would leak test data
# into model selection, so keep "test" for the final report only.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=encoded_ds["train"],
    eval_dataset=encoded_ds["val"],
    data_collator=data_collator,
)

## Train Model

In [None]:
# Task 9: fine-tune the masked LM.
#
# TODO: Hyper parameter tuning
#   - batch size
#   - number of epochs
#   - weight decay
#   - learning rate
# Note: Should be executed in Google Colab (GPU).
# Note: Does not yet work as intended... Training loss does not seem to go down...

train_result = trainer.train()
train_result

## Evaluation on Validation and Test Splits with Perplexity

In [None]:
# Task 10: perplexity on the validation and test splits.
# For a masked LM, perplexity is computed as exp(mean cross-entropy over the masked
# tokens), i.e. exp(eval loss). The `evaluate` "perplexity" metric cannot be used here:
# it scores texts with a causal LM loaded from a model id.
# The collator masks tokens at random, so the values vary slightly from run to run.
# Reference: https://huggingface.co/docs/transformers/perplexity


def masked_lm_perplexity(split_name: str) -> float:
    """Evaluate ``trainer.model`` on ``encoded_ds[split_name]`` and return exp(eval loss)."""
    metrics = trainer.evaluate(
        eval_dataset=encoded_ds[split_name],
        metric_key_prefix=split_name,  # metrics come back as f"{split_name}_loss", ...
    )
    return math.exp(metrics[f"{split_name}_loss"])


perplexities = {split: masked_lm_perplexity(split) for split in ("val", "test")}
perplexities

## Inference

In [None]:
# Task 11: predict the masked word with the fine-tuned model.

text = "E-mail scam targets police chief Wiltshire Police warns about <mask> after its fraud squad chief was targeted."
# The fill-mask pipeline needs exactly one mask token in the input; for RoBERTa it is "<mask>".
assert tokenizer.mask_token in text, f"input must contain {tokenizer.mask_token!r}"

# The pipeline task is "fill-mask" (not "mask-filler"). It takes the fine-tuned model,
# not the Trainer, and runs on the device the model already lives on.
mask_filler = pipeline(
    "fill-mask",
    model=trainer.model,
    tokenizer=tokenizer,
    device=trainer.model.device,
)
mask_filler(text, top_k=5)