# BART Denoising Language Modeling Domain Adaptation
This notebook takes inspiration from HuggingFace's [run_mlm.py](https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/examples/pytorch/language-modeling/run_mlm.py) script, but replaces the Masked Language Modeling task with Denoising Language Modeling, using the data collator from the [run_bart_dlm.py](https://github.com/huggingface/transformers/blob/ecd7de3dff7ea5713004b2f05e3869c24b8eb6e2/examples/flax/language-modeling/run_bart_dlm_flax.py) script, which was written for Flax and is adapted here for PyTorch. With this setup we continue BART's pretraining on the MITOCW+OpenHPI+VT-SSum+Yale transcripts dataset.

In [1]:
#!pip install transformers datasets evaluate accelerate torch torchvision numpy nltk scikit-learn tensorboard wandb ray[tune] hyperopt ipynbname gdown matplotlib ipywidgets rouge_score

In [2]:
import os
import sys
from sys import stderr
import torch
from torch.cuda import OutOfMemoryError
import gc
import transformers
from transformers import (
    AutoTokenizer,
    BartConfig,
    BartForConditionalGeneration,
    BatchEncoding,
    PreTrainedTokenizerBase,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
import datasets
from datasets import DatasetDict, Dataset
import evaluate
from huggingface_hub import login as notebook_login
import gdown
from glob import glob
import subprocess
from subprocess import Popen, PIPE, CalledProcessError
import numpy as np
import json
from datetime import datetime
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from ray.tune.schedulers import PopulationBasedTraining
from ray.tune.search.hyperopt import HyperOptSearch
import ipynbname
import wandb
from wandb import AlertLevel
import random
from PIL import ImageDraw, ImageFont, Image
import matplotlib.pyplot as plt
import pandas as pd
from IPython.display import display, HTML
from itertools import chain
import math
from dataclasses import dataclass
from typing import List, Dict
import nltk
nltk.download("punkt")

COLAB = 'google.colab' in sys.modules  # True if on Google Colab
os.environ['COLAB'] = "true" if COLAB else ""
PLATFORM = "Colab" if COLAB else "local"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# whether to train one model or run a hyperparameter search with Ray Tune
search_hyperparams = False

# https://docs.wandb.ai/guides/integrations/huggingface
# set the wandb project name
try:
    current_notebook_name = ipynbname.name()
except:
    current_notebook_name = "BART_DLM_domain_adaptation"  # TODO: set manually
if search_hyperparams:
    current_notebook_name += "_HPO"
%env WANDB_PROJECT=$current_notebook_name
# save the best model as an artifact on wandb
%env WANDB_LOG_MODEL=true
# log everything
%env WANDB_WATCH=all

def get_gpu_mem_free_percent():
    """Return the free GPU memory as a percentage of the total reported by `torch.cuda.mem_get_info()`.

    When CUDA is unavailable the placeholder pair (0, 1) is used, so the result is 0.0.
    """
    if torch.cuda.is_available():
        free_bytes, total_bytes = torch.cuda.mem_get_info()
    else:
        free_bytes, total_bytes = 0, 1
    return free_bytes / total_bytes * 100

def print_gpu_mem_free_percent():
    """Print the percentage of free GPU memory, formatted with two decimals."""
    free_percent = get_gpu_mem_free_percent()
    print(f"GPU memory available {free_percent:.2f}%")

def empty_gpu_mem():
    """Try to give unused GPU memory back to the driver, to avoid `RuntimeError: CUDA out of memory`.

    The Python garbage collector runs first so that tensors which are only kept alive by
    reference cycles are destroyed; `torch.cuda.empty_cache()` can then release the blocks
    they occupied from PyTorch's caching allocator. Objects that are still referenced
    (e.g. a global `model` or `trainer`) are not freed: `del` them before calling this.
    """
    gc.collect()
    torch.cuda.empty_cache()

print(f"Platform is {PLATFORM}")
print(f"Running in directory {os.getcwd()}")
print(f"PyTorch version {torch.__version__}")
print(f"Transformers version {transformers.__version__}")
print(f"Datasets version {datasets.__version__}")
print(f"GPU {'available' if torch.cuda.is_available() else 'unavailable'}")
print_gpu_mem_free_percent()
if get_gpu_mem_free_percent() < 80:
    print("Trying to free a bit of GPU memory...")
    empty_gpu_mem()
    print_gpu_mem_free_percent()

# export the current GPU name to the environment, may be used later (e.g. for the batch size)
# should be Tesla T4 (-> 16GB), Tesla K80 (-> 12GB), GeForce GTX 1060 (-> 6GB)
!nvidia-smi --query-gpu=gpu_name --format=csv,noheader > /tmp/gpu_name
os.environ['GPU_NAME'] = open("/tmp/gpu_name", "r").read().strip()

!nvidia-smi
!df -h /

if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')  # used to save model checkpoints
    from google.colab import output
    output.enable_custom_widget_manager()

env: WANDB_PROJECT=BART_DLM_domain_adaptation
env: WANDB_LOG_MODEL=true
env: WANDB_WATCH=all
Platform is local
Running in directory /home/caste/Documents/thesis/src/project/experiments/text-mlm-denoising
PyTorch version 1.13.0+cu117
Transformers version 4.25.1
Datasets version 2.7.1
GPU available
GPU memory available 90.49%


[nltk_data] Downloading package punkt to /home/caste/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Fri Jan 20 09:19:36 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.60.11    Driver Version: 525.60.11    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  Off  | 00000000:2B:00.0  On |                  N/A |
|  0%   43C    P2    46W / 350W |   2306MiB / 24576MiB |     12%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
# Checkpoint size switch: True -> facebook/bart-large, False -> facebook/bart-base.
# The same flag also selects the batch sizes, the gradient accumulation steps and
# gradient checkpointing in the hyperparameter cell further down.
use_bart_large = True

In [4]:
# Log in to the Hugging Face Hub only when no token file exists at ~/.huggingface/token.
# NOTE(review): the token location may differ between huggingface_hub versions — confirm
# this path matches the installed version, otherwise the login prompt is shown every time.
if not os.path.isfile(os.path.expanduser("~/.huggingface/token")):
    notebook_login()

### Convert the DataCollator for Flax to a PyTorch-compatible one
Here we can also see that while Flax works with numpy arrays internally, PyTorch works with (obviously) PyTorch tensors.  
The difference in code can be overcome by simply converting the numpy arrays to PyTorch tensors at the end of the `DataCollatorForBartDenoisingLM.__call__` method.

In [5]:
# PyTorch
def _pytorch_shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = input_ids.new_zeros(input_ids.shape)
    shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
    shifted_input_ids[:, 0] = decoder_start_token_id

    if pad_token_id is None:
        raise ValueError("self.model.config.pad_token_id has to be defined.")
    # replace possible -100 values in labels by `pad_token_id`
    shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

    return shifted_input_ids


# Flax
def _flax_shift_tokens_right(input_ids: np.array, pad_token_id: int, decoder_start_token_id: int) -> np.ndarray:
    """
    Shift input ids one token to the right.
    """
    shifted_input_ids = np.zeros_like(input_ids)
    shifted_input_ids[:, 1:] = input_ids[:, :-1]
    shifted_input_ids[:, 0] = decoder_start_token_id

    shifted_input_ids = np.where(shifted_input_ids == -100, pad_token_id, shifted_input_ids)
    return shifted_input_ids


# custom wrapper
def shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id):
    """Shift `input_ids` one token to the right by delegating to the NumPy (Flax-style) implementation."""
    shifted = _flax_shift_tokens_right(input_ids, pad_token_id, decoder_start_token_id)
    return shifted

In [6]:
@dataclass
class DataCollatorForBartDenoisingLM:
    """
    From https://github.com/huggingface/transformers/blob/ecd7de3dff7ea5713004b2f05e3869c24b8eb6e2/examples/flax/language-modeling/run_bart_dlm_flax.py

    Data collator used for BART denoising language modeling. The code is largely copied from
    `<https://github.com/morganmcg1/rotobart/blob/main/data_collator.py#L223>`__.
    For more information on how BART denoising language modeling works, one can take a look
    at the `official paper <https://arxiv.org/pdf/1910.13461.pdf>`__
    or the `official code for preprocessing <https://github.com/facebookresearch/fairseq/blob/main/fairseq/data/denoising_dataset.py>`__ .

    All the noising is computed on numpy arrays; the finished batch is converted to PyTorch
    tensors at the very end of `__call__`.

    Args:
        tokenizer (:class:`~transformers.PreTrainedTokenizer` or :class:`~transformers.PreTrainedTokenizerFast`):
            The tokenizer used for encoding the data
        mask_ratio (:obj:`float`):
            The probability with which to (randomly) mask tokens in the input
        poisson_lambda (:obj:`float`):
            Mean parameter of Poisson distribution used to generate span-lengths to be masked
        permute_sentence_ratio (:obj:`float`):
            Ratio of sentences to be permuted in each document
        decoder_start_token_id (:obj:`int`):
            The decoder start token id of the model
    """

    tokenizer: PreTrainedTokenizerBase
    decoder_start_token_id: int
    mask_ratio: float = 0.3
    poisson_lambda: float = 3.0
    permute_sentence_ratio: float = 1.0

    def __post_init__(self):
        """Check that the tokenizer defines the mask and eos tokens required for denoising."""
        if self.tokenizer.mask_token is None or self.tokenizer.eos_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token or eos token token which is necessary for denoising"
                " language modeling. "
            )

    def __call__(self, examples: List[Dict[str, List[int]]]) -> BatchEncoding:
        """
        Collate `examples` into one noised batch.

        Args:
            examples: one dict per example; the "input_ids" key is required.
                NOTE(review): assumes every example has the same length, otherwise
                `np.array` cannot build a rectangular array — confirm against the caller.

        Returns:
            A `BatchEncoding` of PyTorch tensors with the keys "input_ids" (noised input),
            "labels", "decoder_input_ids", "attention_mask" and "decoder_attention_mask".
        """
        # convert list to dict and tensorize input
        batch = BatchEncoding(
            {k: np.array([example[k] for example in examples]) for k in examples[0].keys()}
        )
        # the clean sequence is the target; decoder inputs are built from it before any noising
        batch["labels"] = batch["input_ids"].copy()
        batch["decoder_input_ids"] = shift_tokens_right(
            batch["labels"], self.tokenizer.pad_token_id, self.decoder_start_token_id
        )
        # permuting sentences
        do_permute = False
        if self.permute_sentence_ratio > 0.0:
            batch["input_ids"] = self.permute_sentences(batch["input_ids"])
            do_permute = True

        # masking span of tokens (text infilling in the paper)
        if self.mask_ratio:
            batch["input_ids"], batch["labels"] = self.span_mask_tokens(
                batch["input_ids"], batch["labels"], do_permute
            )

        # ignore pad tokens
        batch["attention_mask"] = (batch["input_ids"] != self.tokenizer.pad_token_id).astype(int)
        batch["decoder_attention_mask"] = (batch["decoder_input_ids"] != self.tokenizer.pad_token_id).astype(int)

        # NOTE: we need to convert the BatchEncoding to PyTorch Tensor to make this Flax dataloader work with PyTorch
        # all the previous computing in this method is done with numpy.ndarray lists
        return batch.convert_to_tensors(tensor_type="pt")

    def permute_sentences(self, input_ids):
        """
        Shuffle sentences in each document.

        A sentence is the run of tokens up to and including a pad token (the pad token is used
        as end-of-sentence marker). Tokens after the last pad token of a row are left in place,
        and rows without any pad token are returned unchanged.

        Args:
            input_ids: 2-D numpy array of token ids; it is not modified.

        Returns:
            A copy of `input_ids` with the sentences of each row reordered.
        """
        results = input_ids.copy()

        # find end locations of sentences
        end_sentence_mask = input_ids == self.tokenizer.pad_token_id
        sentence_ends = np.argwhere(end_sentence_mask)
        sentence_ends[:, 1] += 1  # turn "index of the pad token" into an exclusive slice end
        example_has_multiple_sentences, num_sentences = np.unique(sentence_ends[:, 0], return_counts=True)
        num_sentences_map = {sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, num_sentences)}

        num_to_permute = np.ceil(num_sentences * self.permute_sentence_ratio).astype(int)
        num_to_permute_map = {
            sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, num_to_permute)
        }

        # group the sentence ends by row: one array of end positions per row that has pad tokens
        sentence_ends = np.split(sentence_ends[:, 1], np.unique(sentence_ends[:, 0], return_index=True)[1][1:])
        sentence_ends_map = {sent_idx: count for sent_idx, count in zip(example_has_multiple_sentences, sentence_ends)}

        for i in range(input_ids.shape[0]):
            if i not in example_has_multiple_sentences:
                continue
            # pick which sentences take part in the shuffle, then permute them among themselves
            substitutions = np.random.permutation(num_sentences_map[i])[: num_to_permute_map[i]]
            ordering = np.arange(0, num_sentences_map[i])
            ordering[substitutions] = substitutions[np.random.permutation(num_to_permute_map[i])]

            # write shuffled sentences into results
            index = 0
            for j in ordering:
                sentence = input_ids[i, (sentence_ends_map[i][j - 1] if j > 0 else 0):sentence_ends_map[i][j]]
                results[i, index:index + sentence.shape[0]] = sentence
                index += sentence.shape[0]
        return results

    def span_mask_tokens(self, input_ids, labels, do_permute):
        """
        Sampling text spans with span lengths drawn from a Poisson distribution and masking them.

        Every sampled span is collapsed to a single mask token, so the noised rows get shorter
        and are right-padded with the pad token.

        Args:
            input_ids: 2-D numpy array of (possibly permuted) token ids. Modified in place.
            labels: 2-D numpy array with the clean token ids. Modified in place.
            do_permute: whether the sentences of `input_ids` were permuted. If False, every
                position that is not masked gets label -100; if True, only the special-token
                positions of `labels` get -100.

        Returns:
            Tuple `(new_input_ids, labels)`. If there is nothing to mask, the inputs are
            returned untouched.
        """
        special_tokens_mask_labels = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
        ]
        special_tokens_mask_inputs = [
            self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in input_ids.tolist()
        ]
        special_tokens_mask_labels = np.array(special_tokens_mask_labels, dtype=bool)
        special_tokens_mask_inputs = np.array(special_tokens_mask_inputs, dtype=bool)

        # determine how many tokens we need to mask in total
        is_token_mask = ~(input_ids == self.tokenizer.pad_token_id) & ~special_tokens_mask_inputs
        num_tokens_to_mask = int(math.ceil(is_token_mask.astype(float).sum() * self.mask_ratio))
        if num_tokens_to_mask == 0:
            return input_ids, labels

        # generate a sufficient number of span lengths
        span_lengths = np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))
        while np.cumsum(span_lengths, 0)[-1] < num_tokens_to_mask:
            span_lengths = np.concatenate(
                [span_lengths, np.random.poisson(lam=self.poisson_lambda, size=(num_tokens_to_mask,))]
            )

        # remove all spans of length 0
        # note that BART inserts additional mask tokens where length == 0,
        # which we do not implement for now as it adds additional complexity
        span_lengths = span_lengths[span_lengths > 0]

        # trim to about num_tokens_to_mask tokens
        cutoff_idx = np.argmin(np.abs(np.cumsum(span_lengths, 0) - num_tokens_to_mask)) + 1
        span_lengths = span_lengths[:cutoff_idx]

        # randomly choose starting positions for masking
        token_indices = np.argwhere(is_token_mask == 1)
        span_starts = np.random.permutation(token_indices.shape[0])[: span_lengths.shape[0]]
        # prepare mask
        masked_indices = np.array(token_indices[span_starts])  # rows of (example index, position)
        mask = np.full_like(input_ids, fill_value=False)

        # mask starting positions
        mask[masked_indices[:, 0], masked_indices[:, 1]] = True
        span_lengths -= 1

        # fill up spans: grow every unfinished span by one position per iteration
        max_index = input_ids.shape[1] - 1
        remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)
        while np.any(remaining):
            masked_indices[remaining, 1] += 1
            mask[masked_indices[:, 0], masked_indices[:, 1]] = True
            span_lengths -= 1
            remaining = (span_lengths > 0) & (masked_indices[:, 1] < max_index)

        # place the mask tokens
        mask[np.where(special_tokens_mask_inputs)] = False
        input_ids[np.where(mask)] = self.tokenizer.mask_token_id
        if not do_permute:
            labels[np.where(mask == 0)] = -100
        else:
            labels[np.where(special_tokens_mask_labels)] = -100

        # remove mask tokens that are not starts of spans
        previous_is_masked = np.roll((mask == 1), 1, 1)
        # np.roll wraps around: without this reset the first position would be compared with the
        # LAST position of the row and a span starting at index 0 could lose its only mask token
        previous_is_masked[:, 0] = False
        to_remove = (mask == 1) & previous_is_masked
        new_input_ids = np.full_like(input_ids, fill_value=self.tokenizer.pad_token_id)
        for i, example in enumerate(input_ids):
            new_example = example[~to_remove[i]]
            new_input_ids[i, : new_example.shape[0]] = new_example

        return new_input_ids, labels

### Preprocess the dataset

In [7]:
# https://huggingface.co/docs/transformers/main/en/tasks/language_modeling
# Hub id of the lecture transcripts dataset used for the domain adaptation
dataset_id = "e-caste/mitocw_openhpi_vtssum_yale-lecture_transcripts"

# name of the dataset column that holds the transcript text
text_column_name = "text"
# NOTE(review): use_auth_token=True presumably relies on the Hugging Face login done earlier — confirm
raw_datasets = datasets.load_dataset(dataset_id, use_auth_token=True)
raw_datasets

Using custom data configuration e-caste--mitocw_openhpi_vtssum_yale-lecture_transcripts-09e7b44998532bac
Found cached dataset csv (/home/caste/.cache/huggingface/datasets/e-caste___csv/e-caste--mitocw_openhpi_vtssum_yale-lecture_transcripts-09e7b44998532bac/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317)


  0%|          | 0/2 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['Unnamed: 0', 'text'],
        num_rows: 13209
    })
    validation: Dataset({
        features: ['Unnamed: 0', 'text'],
        num_rows: 1468
    })
})

In [8]:
# Pick the pretrained checkpoint according to `use_bart_large`
model_id = f"facebook/bart-{'large' if use_bart_large else 'base'}"

# The config and the tokenizer are loaded from the same checkpoint id as the model
config = BartConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

In [9]:
sentence_tokenizer = nltk.data.load("tokenizers/punkt/english.pickle")


def sentence_split_function(example):
    """Split the example's text into sentences and re-join them with explicit special tokens.

    The result has the form `<bos>sent_1<pad>sent_2<pad>...sent_n<eos>`, returned as the new
    value of the "text" field.
    """
    sentences = sentence_tokenizer.tokenize(example["text"])
    # use pad token as end of sentence indicator
    separator = f"{tokenizer.pad_token}"
    body = separator.join(sentences)
    return {"text": "".join([tokenizer.bos_token, body, tokenizer.eos_token])}

In [10]:
# Apply the sentence splitting to every example, one example at a time (batched=False).
# All the original columns are dropped, so only the new "text" column remains.
split_datasets = raw_datasets.map(
    sentence_split_function,
    batched=False,
    num_proc=1,
    remove_columns=raw_datasets["train"].column_names,
    load_from_cache_file=True,
)
split_datasets

  0%|          | 0/13209 [00:00<?, ?ex/s]

  0%|          | 0/1468 [00:00<?, ?ex/s]

DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 13209
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 1468
    })
})

In [11]:
# Tokenize every text, then concatenate them together before splitting them in smaller parts.
# Since we make sure that all sequences are of the same length, no attention_mask is needed.
def tokenize_function(examples):
    """Tokenize a batch of texts, without adding special tokens and without an attention mask."""
    texts = examples[text_column_name]
    return tokenizer(texts, return_attention_mask=False, add_special_tokens=False)


# Tokenize the texts in batches; the raw text column is removed afterwards.
# The tokenizer may warn about sequences longer than the model maximum here:
# the texts are cut into chunks of `max_seq_length` tokens in a later cell.
tokenized_datasets = split_datasets.map(
    tokenize_function,
    batched=True,
    num_proc=1,
    remove_columns=text_column_name,
    load_from_cache_file=True,
)
tokenized_datasets

  0%|          | 0/14 [00:00<?, ?ba/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (3242 > 1024). Running this sequence through the model will result in indexing errors


  0%|          | 0/2 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 13209
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 1468
    })
})

In [12]:
max_seq_length = min(tokenizer.model_max_length, 1024)


# Main data processing function that will concatenate all texts from our dataset and generate chunks of
# max_seq_length.
def group_texts(examples):
    """Concatenate the token lists of a batch and cut them into chunks of `max_seq_length` tokens.

    When the concatenation is at least `max_seq_length` tokens long, the incomplete tail is
    dropped; a shorter concatenation is returned as one single (short) chunk.
    """
    # Flatten every column into one long list of tokens.
    flattened = {key: list(chain.from_iterable(examples[key])) for key in examples.keys()}
    first_key = list(examples.keys())[0]
    total_length = len(flattened[first_key])
    # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can
    # customize this part to your needs.
    if total_length >= max_seq_length:
        total_length -= total_length % max_seq_length
    # Start offsets of the chunks of max_len.
    chunk_starts = range(0, total_length, max_seq_length)
    return {
        key: [tokens[start:start + max_seq_length] for start in chunk_starts]
        for key, tokens in flattened.items()
    }


# Re-chunk the tokenized texts batch by batch: each batch is concatenated and split into
# sequences of `max_seq_length` tokens, so the number of rows changes.
tokenized_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    num_proc=1,
    load_from_cache_file=True,
)
tokenized_datasets

  0%|          | 0/14 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

DatasetDict({
    train: Dataset({
        features: ['input_ids'],
        num_rows: 59694
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 6236
    })
})

In [13]:
# Data collator
# This one will take care of randomly masking the tokens and permuting the sentences.
# The noise is sampled every time a batch is collated, not once during preprocessing.
data_collator = DataCollatorForBartDenoisingLM(
    tokenizer=tokenizer,
    decoder_start_token_id=config.decoder_start_token_id,  # taken from the pretrained model config
    mask_ratio=0.30,  # mlm_probability, Ratio of tokens to mask for span masked language modeling loss
    poisson_lambda=3.0,  # Mean of Poisson distribution used to generate span-lengths to be masked
    permute_sentence_ratio=1.0,  # Ratio of sentences to be permuted in each document
)

In [14]:
# Evaluation metric; its name is also used as `metric_for_best_model` in the training arguments
metric_name = "accuracy"
metric = evaluate.load(metric_name)


def preprocess_logits_for_metrics(logits, labels):
    """Reduce the logits to the predicted token ids (argmax over the last dimension).

    `labels` is part of the expected signature but is not used.
    """
    # Depending on the model and config, logits may contain extra tensors,
    # like past_key_values, but logits always come first
    scores = logits[0] if isinstance(logits, tuple) else logits
    return scores.argmax(dim=-1)


def compute_metrics(eval_preds):
    """Compute the metric over the positions whose label is not -100.

    `eval_preds` is a `(predictions, labels)` pair; the predictions already have the same shape
    as the labels because `preprocess_logits_for_metrics` applied the argmax.
    """
    predictions, references = eval_preds
    flat_references = references.reshape(-1)
    flat_predictions = predictions.reshape(-1)
    # -100 marks the positions that are ignored
    keep = flat_references != -100
    return metric.compute(predictions=flat_predictions[keep], references=flat_references[keep])

### Configure the hyperparameters

In [15]:
# Unique, filesystem-friendly run name: platform, GPU name, model id and timestamp
run_name = "_".join([PLATFORM, os.environ['GPU_NAME'], model_id, str(datetime.now())]).replace("\n", "").replace("/", "-").replace(":", "_").replace(" ", "-")
tensorboard_logging_dir = None  # f"./tensorboard/{run_name}"
epochs = 1
eval_delay = 0
batch_size_train = 1 if use_bart_large else 4
batch_size_eval = 2 if use_bart_large else 8
gradient_accumulation_steps = 64 if use_bart_large else 1
evaluations_per_epoch = 10
# (fractional) number of optimizer updates over the whole training
total_steps = len(tokenized_datasets['train']) * epochs / (batch_size_train * gradient_accumulation_steps)
# warm up for 1% of the training; cast to int because `warmup_steps` is an integer argument
warmup_steps = int(total_steps // 100)
# evaluate and save `evaluations_per_epoch` times per epoch; never less than 1 step,
# otherwise a small dataset would produce an invalid step interval of 0
eval_and_save_steps = max(1, int(total_steps / (evaluations_per_epoch * epochs)))
# NOTE(review): not used in this cell, and no EarlyStoppingCallback is passed to the Trainer
# created below — confirm whether early stopping is intended
early_stopping_patience = epochs * evaluations_per_epoch // 2

training_args = TrainingArguments(
    output_dir=current_notebook_name,
    run_name=run_name,
    save_strategy="steps",
    save_steps=eval_and_save_steps,
    save_total_limit=evaluations_per_epoch,
    evaluation_strategy="steps",
    eval_steps=eval_and_save_steps,
    eval_delay=eval_delay,
    auto_find_batch_size=False,
    per_device_train_batch_size=batch_size_train,
    per_device_eval_batch_size=batch_size_eval,
    num_train_epochs=epochs,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-8,
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_steps=warmup_steps,
    lr_scheduler_type="linear",
    # make the model fit in the GPU
    gradient_accumulation_steps=gradient_accumulation_steps,
    gradient_checkpointing=use_bart_large,
    # fp16 vs bf16: https://www.reddit.com/r/MachineLearning/comments/vndtn8/comment/ie6dr2u/?utm_source=share&utm_medium=web2x&context=3
    # for full precision use tf32=True: https://developer.nvidia.com/blog/accelerating-ai-training-with-tf32-tensor-cores/
    fp16=False,
    bf16=False,  # needs RTX 3090
    fp16_full_eval=False,
    bf16_full_eval=False,  # Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm metric values.
    half_precision_backend="cuda_amp",  # apex is deprecated: https://discuss.pytorch.org/t/torch-cuda-amp-vs-nvidia-apex/74994/2
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    logging_dir=tensorboard_logging_dir,
    logging_strategy="steps",
    logging_steps=1,  # crucial for logging training loss -- Number of update steps between two logs if `logging_strategy="steps"`
    remove_unused_columns=True,
    report_to="wandb",
    disable_tqdm=search_hyperparams,
)

run_name

'local_NVIDIA-GeForce-RTX-3090_facebook-bart-large_2023-01-20-09_21_14.832994'

In [16]:
def get_new_trainer(model_name: str):
    """Create a Trainer around a BART model freshly loaded from `model_name`.

    `model_name` is handed to `from_pretrained`; the datasets, tokenizer, collator, metric
    functions and training arguments are the ones defined at notebook level.
    """
    model = BartForConditionalGeneration.from_pretrained(model_name, config=config)
    trainer_kwargs = dict(
        model=model,
        args=training_args,
        train_dataset=tokenized_datasets['train'],
        eval_dataset=tokenized_datasets['validation'],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics,
        preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    )
    return Trainer(**trainer_kwargs)


trainer = get_new_trainer(model_id)

In [17]:
def evaluate_perplexity_on_test_dataset(log_metrics: bool = False):
    """Evaluate the global `trainer` and print the perplexity, i.e. exp(eval_loss).

    Despite the name, the evaluation runs on the 'validation' split of `tokenized_datasets`.

    Args:
        log_metrics: when True, the metrics are also logged and saved through the trainer.
    """
    eval_split = tokenized_datasets['validation']
    metrics = trainer.evaluate(eval_dataset=eval_split)
    metrics["eval_samples"] = len(eval_split)

    # math.exp overflows for very large losses: report an infinite perplexity in that case
    try:
        metrics["perplexity"] = math.exp(metrics["eval_loss"])
    except OverflowError:
        metrics["perplexity"] = float("inf")
    perplexity = metrics["perplexity"]

    if log_metrics:
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    print(f"Perplexity: {perplexity:.2f}")

### Test pretrained model performance
Uncomment the cell to evaluate the pretrained model. Running the training directly after the evaluation results in an OutOfMemoryError.

In [18]:
# evaluate_perplexity_on_test_dataset()

***** Running Evaluation *****
  Num examples = 6236
  Batch size = 2


Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33me-caste[0m. Use [1m`wandb login --relogin`[0m to force relogin


Perplexity: 191.39


### Train the pretrained model

In [19]:
# we can run a hyperparameter search (Population Based Training via Ray Tune) for the best combination
# https://huggingface.co/blog/ray-tune
if search_hyperparams:
    # Default objective is the sum of all metrics
    # when metrics are provided, so we have to maximize it.
    # this doesn't save a model artifact to wandb
    # keep the result so that the best run can be inspected afterwards
    best_run = trainer.hyperparameter_search(
        direction="maximize",
        backend="ray",
        n_trials=3,  # number of trials
        # search_alg=HyperOptSearch(metric="objective", mode="max"),
        scheduler=PopulationBasedTraining(metric="objective", mode="max"),
    )
else:
    try:
        # this saves the best model as an artifact on wandb
        train_result = trainer.train(resume_from_checkpoint=None)
    except OutOfMemoryError:
        print("Out of memory error: trying to free PyTorch cache...", file=stderr)
        empty_gpu_mem()
        wandb.alert(
            title=f"Run out of memory for {current_notebook_name}",
            text=f"Run name: {run_name}.\nSee details at https://wandb.ai/e-caste/{current_notebook_name}",
            level=AlertLevel.ERROR,
        )
        wandb.finish()
        # bare `raise` re-raises the active exception with its original traceback
        raise
    except KeyboardInterrupt:
        # a manual interruption is reported but deliberately not re-raised;
        # note that `train_result` is not defined in this case
        wandb.alert(
            title=f"Run interrupted for {current_notebook_name}",
            text=f"Run name: {run_name}.\nManually interrupted.\nSee details at https://wandb.ai/e-caste/{current_notebook_name}",
            level=AlertLevel.WARN,
        )
        wandb.finish()
    except Exception as e:
        wandb.alert(
            title=f"Run errored out for {current_notebook_name}",
            text=f"Run name: {run_name}.\nException: {e}.\nSee details at https://wandb.ai/e-caste/{current_notebook_name}",
            level=AlertLevel.ERROR,
        )
        wandb.finish()
        raise

***** Running training *****
  Num examples = 59694
  Num Epochs = 1
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 64
  Gradient Accumulation steps = 64
  Total optimization steps = 932
  Number of trainable parameters = 406291456
Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"
[34m[1mwandb[0m: Currently logged in as: [33me-caste[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss,Accuracy
93,0.9921,0.876833,0.818435
186,0.915,0.858755,0.820923
279,0.8946,0.845033,0.822847
372,0.9051,0.837078,0.824371
465,0.8882,0.829452,0.825726
558,0.8765,0.822582,0.826261
651,0.8942,0.819928,0.827048
744,0.8856,0.817075,0.827398
837,0.8321,0.812009,0.828041
930,0.8128,0.812546,0.828188


***** Running Evaluation *****
  Num examples = 6236
  Batch size = 2
Saving model checkpoint to BART_DLM_domain_adaptation/checkpoint-93
Configuration saved in BART_DLM_domain_adaptation/checkpoint-93/config.json
Model weights saved in BART_DLM_domain_adaptation/checkpoint-93/pytorch_model.bin
tokenizer config file saved in BART_DLM_domain_adaptation/checkpoint-93/tokenizer_config.json
Special tokens file saved in BART_DLM_domain_adaptation/checkpoint-93/special_tokens_map.json
***** Running Evaluation *****
  Num examples = 6236
  Batch size = 2
Saving model checkpoint to BART_DLM_domain_adaptation/checkpoint-186
Configuration saved in BART_DLM_domain_adaptation/checkpoint-186/config.json
Model weights saved in BART_DLM_domain_adaptation/checkpoint-186/pytorch_model.bin
tokenizer config file saved in BART_DLM_domain_adaptation/checkpoint-186/tokenizer_config.json
Special tokens file saved in BART_DLM_domain_adaptation/checkpoint-186/special_tokens_map.json
***** Running Evaluation **

In [20]:
# Persist the final model and the training metrics.
# NOTE: `train_result` only exists when the training cell ran `trainer.train()` to completion
# (it is not set by the hyperparameter search branch or after a manual interruption).
trainer.save_model()  # Saves the tokenizer too for easy upload

metrics = train_result.metrics
metrics["train_samples"] = len(tokenized_datasets['train'])
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

Saving model checkpoint to BART_DLM_domain_adaptation
Configuration saved in BART_DLM_domain_adaptation/config.json
Model weights saved in BART_DLM_domain_adaptation/pytorch_model.bin
tokenizer config file saved in BART_DLM_domain_adaptation/tokenizer_config.json
Special tokens file saved in BART_DLM_domain_adaptation/special_tokens_map.json


***** train metrics *****
  epoch                    =         1.0
  total_flos               = 120385974GF
  train_loss               =      0.9186
  train_runtime            =  8:03:59.22
  train_samples            =       59694
  train_samples_per_second =       2.056
  train_steps_per_second   =       0.032


### Test trained model performance

In [21]:
# load saved model from output directory
#trainer = get_new_trainer(model_name=current_notebook_name)
try:
    evaluate_perplexity_on_test_dataset(log_metrics=True)
except Exception as e:
    # Best effort: an evaluation failure is reported (locally and on wandb) but does not stop the
    # notebook. `Exception` instead of a bare `except` lets KeyboardInterrupt/SystemExit through.
    print(f"Evaluation failed: {e}", file=stderr)
    wandb.alert(
        title=f"Evaluation error for {current_notebook_name}",
        text=f"Run name: {run_name}.\nException: {e}.\nSee details at https://wandb.ai/e-caste/{current_notebook_name}",
        level=AlertLevel.WARN,
    )
    wandb.finish()

***** Running Evaluation *****
  Num examples = 6236
  Batch size = 2


***** eval metrics *****
  epoch                   =        1.0
  eval_accuracy           =      0.828
  eval_loss               =     0.8125
  eval_runtime            = 0:10:13.27
  eval_samples            =       6236
  eval_samples_per_second =     10.168
  eval_steps_per_second   =      5.084
  perplexity              =     2.2535
Perplexity: 2.25


In [22]:
# Build a "key: value" summary of the saved evaluation metrics for the final notification.
try:
    # context manager: the file handle is closed even if the JSON cannot be parsed
    with open(f"{current_notebook_name}/eval_results.json", "r") as eval_results_file:
        tmp = json.load(eval_results_file)
    eval_results = "".join(f"{k}: {v}\n" for k, v in tmp.items())
except (OSError, ValueError):
    # missing/unreadable file (OSError) or malformed JSON (json.JSONDecodeError is a ValueError)
    eval_results = "No eval_results.json."
wandb.alert(
    title=f"Run finished for {current_notebook_name}",
    text=f"Run name: {run_name}.\n\n{eval_results}\nSee details at https://wandb.ai/e-caste/{current_notebook_name}",
    level=AlertLevel.INFO,
)
wandb.finish()

VBox(children=(Label(value='1553.453 MB of 1553.453 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.…

0,1
eval/accuracy,▁▃▄▅▆▇▇▇███
eval/loss,█▆▅▄▃▂▂▂▁▁▁
eval/runtime,█▆▆██▂▂▂▂▂▁
eval/samples_per_second,▁▃▃▁▁▇▇▇▇▇█
eval/steps_per_second,▁▃▃▁▁▇▇▇▇▇█
train/epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇████
train/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/learning_rate,▇███▇▇▇▇▇▇▆▆▆▆▆▅▅▅▅▅▅▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁
train/loss,█▄▄▃▃▃▃▂▃▃▂▃▂▃▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▂▂▁▂▁▁▂▂▁▂▁
train/total_flos,▁

0,1
eval/accuracy,0.82799
eval/loss,0.8125
eval/runtime,613.2715
eval/samples_per_second,10.168
eval/steps_per_second,5.084
train/epoch,1.0
train/global_step,932.0
train/learning_rate,0.0
train/loss,0.8769
train/total_flos,1.2926345530677658e+17


### Results

| Model         | Perplexity     | Accuracy     | Loss     | Notes |
|--------------|-----------|------------|------------|------------|
| facebook/bart-base | 36.81 | ? | 3.606 | without training        |
| facebook/bart-base      | 2.636 | 0.8001 | 0.9693 | epochs=1, trained on 80% split, lr=1e-4, bs=4       |
| facebook/bart-base      | 2.934 | 0.7849 | 1.0764 | epochs=1, trained on 90% split, lr=5e-4, bs=4       |
| facebook/bart-base      | 2.6317 | 0.8002 | 0.9676 | epochs=1, trained on 90% split, lr=7.5e-5, bs=4       |
| facebook/bart-base      | ? | 0.799149 | 0.976746 | epochs=1, trained on 90% split, lr=5e-5, bs=4       |
| facebook/bart-base      | 2.5718 | 0.8036 | 0.9446 | epochs=2, trained on 90% split, lr=5e-5, bs=4       |
| facebook/bart-base      | 2.6234 | 0.8008 | 0.9645 | epochs=1, trained on 90% split, lr=1e-4, bs=4       |
| facebook/bart-base      | 2.654 | 0.7993 | 0.9761 | epochs=1, trained on 90% split, lr=2e-4, bs=4       |
| facebook/bart-base      | 2.6298 | 0.8006 | 0.9669 | epochs=1, trained on 90% split, lr=1.5e-4, bs=4       |
| facebook/bart-base      | 2.6248 | 0.8006 | 0.965 | epochs=1, trained on 90% split, lr=9e-5, bs=4       |
| facebook/bart-base      | 2.5443 | 0.8056 | 0.9339 | epochs=2, trained on 90% split, lr=1e-4, bs=4       |
| facebook/bart-large     | 191.39 | ? | 5.254 | without training       |
| facebook/bart-large     | 2.2535 | 0.8282 | 0.8125 | epochs=1, trained on 90% split, lr=1e-4, bs=64       |