In [None]:
%%writefile util.py
from pathlib import Path
import os
def get_epoch_checkpoints(model_dir):
    """Return the paths of all ``checkpoint-<step>`` entries in ``model_dir``, ordered by step.

    Entries whose suffix after the last ``-`` is not an integer (e.g. ``checkpoint-best``)
    are skipped instead of crashing the integer parse.

    Args:
        model_dir: Directory containing Trainer checkpoints (str or path-like).

    Returns:
        List of ``str`` paths ``os.path.join(model_dir, "checkpoint-<step>")`` in ascending step order.
    """
    suffixes = (p.name.split("-")[-1] for p in Path(model_dir).glob("checkpoint-*"))
    steps = sorted(int(suffix) for suffix in suffixes if suffix.isdigit())
    return [os.path.join(model_dir, "checkpoint-{}".format(step)) for step in steps]

# def get_all_chunks(checkpoint_path,gradient_input_dir, gradients_per_file):
#     return [ os.path.join(gradient_input_dir, checkpoint_path.split("-")[-1] + "_" + str(i) + "_" + str(i + gradients_per_file)) for i in range(0, len(dataset["train"]), args.gradients_per_file)]
def get_epoch(checkpoint_path):
    """Return the 0-based position of ``checkpoint_path`` among its sibling checkpoints.

    Siblings are the ``checkpoint-<step>`` entries in the same parent directory, sorted by step.
    When one checkpoint is saved per epoch (as pretrain.py does with ``save_strategy="epoch"``),
    this position is the epoch index. Trailing path separators are tolerated, and siblings with a
    non-integer suffix (e.g. ``checkpoint-best``) are ignored.

    Args:
        checkpoint_path: Path (str or path-like) of a ``checkpoint-<step>`` directory.

    Returns:
        Integer index of the checkpoint in step order.

    Raises:
        ValueError: if the suffix of ``checkpoint_path`` is not an integer, or the checkpoint is
            not among its siblings.
    """
    checkpoint = Path(checkpoint_path)  # Path() drops trailing separators, so .name/.parent are reliable
    sibling_suffixes = (p.name.split("-")[-1] for p in checkpoint.parent.glob("checkpoint-*"))
    checkpoint_ids = sorted(int(s) for s in sibling_suffixes if s.isdigit())
    return checkpoint_ids.index(int(checkpoint.name.split("-")[-1]))


import xxhash

# Kept for backward compatibility with code that imports ``util.h``; get_seed_for_document no longer uses it.
h = xxhash.xxh64()
def get_seed_for_document(document, epoch):
    """Return the seed that drives the random masking of ``document`` in ``epoch``.

    The seed is the 64-bit xxHash of the document's raw token bytes followed by ``bytes(epoch)``,
    so the same (document, epoch) pair always yields the same seed.

    Args:
        document: Token-id tensor of one example (anything exposing ``.cpu().numpy()``).
        epoch: Non-negative int epoch index.

    Returns:
        An unsigned 64-bit integer (``xxh64.intdigest()``).
    """
    # A fresh default-seeded hasher per call starts from the same state h.reset() would restore.
    # Unlike the shared module-level hasher, it cannot leak partially-updated state into later
    # seeds if an update raises before the reset.
    hasher = xxhash.xxh64()
    hasher.update(document.cpu().numpy())
    # NOTE(review): bytes(epoch) is ``epoch`` zero bytes, not an encoding of the integer. It still
    # differs per epoch, but it is not unambiguous with respect to the document bytes. It is kept
    # unchanged so that previously produced masks remain reproducible.
    hasher.update(bytes(epoch))
    return hasher.intdigest()

from transformers import DataCollatorForLanguageModeling
import torch
class DeterministicDataCollatorForLanguageModeling(DataCollatorForLanguageModeling):
    """MLM data collator whose dynamic masking is reproducible per (document, epoch).

    ``set_epoch`` must be called before collating: ``torch_mask_tokens`` reads ``self.epoch``
    (an AttributeError is raised if it was never set).
    """

    def torch_mask_tokens(self, inputs, special_tokens_mask = None):
        """Mask ``inputs`` for MLM, seeding each row's randomness from (row tokens, epoch).

        Adapted to make dynamic masking deterministic based on (text, epoch). The original
        implementation is wrapped in a per-row loop. Each row draws from its own
        ``torch.Generator`` seeded with ``get_seed_for_document(labels[i], self.epoch)``.
        A private generator produces the same draws as seeding the global RNG with
        ``torch.manual_seed``. Unlike that call, it does not overwrite the process-wide
        (CPU and CUDA) torch RNG state used by everything else.

        Args:
            inputs: Integer token-id tensor whose rows (dim 0) are examples. Modified in place.
            special_tokens_mask: Optional mask with the same shape as ``inputs`` marking positions
                that must never be masked. Computed from the tokenizer when ``None``.

        Returns:
            ``(inputs, labels)``. ``labels`` is -100 everywhere except at masked positions, which
            keep the original token id.

        NOTE(review): the seed hashes the whole (padded) row, so a document's mask may depend on
        the padded length of the batch it is collated in. Confirm batching matches between
        training and influence estimation.
        """
        labels = inputs.clone()

        if special_tokens_mask is None:
            special_tokens_mask = [
                self.tokenizer.get_special_tokens_mask(val, already_has_special_tokens=True) for val in labels.tolist()
            ]
            special_tokens_mask = torch.tensor(special_tokens_mask, dtype=torch.bool)
        else:
            special_tokens_mask = special_tokens_mask.bool()

        # Loop-invariant: the id written at positions selected for [MASK] replacement.
        mask_token_id = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        for i in range(labels.shape[0]):
            # Seed from row i before its labels are overwritten with -100 below.
            generator = torch.Generator()
            generator.manual_seed(get_seed_for_document(labels[i], self.epoch))

            # (1, seq_len) views: masked assignments below write through to labels / inputs.
            row_labels = labels[i:i+1]
            row_inputs = inputs[i:i+1]

            # We sample a few tokens in each sequence for MLM training (with probability `self.mlm_probability`)
            probability_matrix = torch.full(row_labels.shape, self.mlm_probability)
            probability_matrix.masked_fill_(special_tokens_mask[i:i+1], value=0.0)
            masked_indices = torch.bernoulli(probability_matrix, generator=generator).bool()
            row_labels[~masked_indices] = -100  # We only compute loss on masked tokens

            # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
            indices_replaced = torch.bernoulli(torch.full(row_labels.shape, 0.8), generator=generator).bool() & masked_indices
            row_inputs[indices_replaced] = mask_token_id

            # 10% of the time (half of the remaining 20%), we replace masked input tokens with a random word
            indices_random = torch.bernoulli(torch.full(row_labels.shape, 0.5), generator=generator).bool() & masked_indices & ~indices_replaced
            random_words = torch.randint(len(self.tokenizer), row_labels.shape, dtype=torch.long, generator=generator)
            row_inputs[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input tokens unchanged
        return inputs, labels

    def set_epoch(self, epoch):
        """Set the epoch mixed into every row's masking seed."""
        self.epoch = epoch

Overwriting util.py


In [None]:
# Effective batch size = GPUs x per-device batch x gradient-accumulation steps (8-GPU setup).
NUM_GPUs, per_device_batch_size, update_freq = 8, 16, 16
per_device_batch_size * update_freq * NUM_GPUs

2048

In [None]:
# Effective batch size = GPUs x per-device batch x gradient-accumulation steps (2-GPU setup).
NUM_GPUs, per_device_batch_size, update_freq = 2, 64, 16
per_device_batch_size * update_freq * NUM_GPUs

2048

In [6]:
%%writefile pretrain.py

import argparse
import os
from pathlib import Path
from tokenizers import ByteLevelBPETokenizer
from torch.utils.data import Dataset
from tokenizers.implementations import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
import torch
from random import randrange
import cloudpickle
from datasets import load_dataset
import random
from transformers import RobertaForMaskedLM
import evaluate
from datasets import load_metric
from collections import defaultdict
import math
from datasets import load_metric
import util
from torch.utils.data import DataLoader
from transformers.trainer_utils import (seed_worker)
from tokenizers import ByteLevelBPETokenizer
import datasets
from transformers import RobertaTokenizerFast
from transformers import Trainer, TrainingArguments
from torch.utils.data import SequentialSampler
import json
from transformers import RobertaConfig


parser = argparse.ArgumentParser("pretraining")
parser.add_argument("config", help="Path to a config.json file")
parser.add_argument("--per_device_train_batch_size", help="per_device_train_batch_size", type=int, nargs="?", const=1, default=64) # TODO
parser.add_argument("--cuda_visible_devices", help="Comma seperated GPU ids to use", nargs="?", const=1, default="")

args = parser.parse_args()

# Run configuration; keys read in this script: curriculum_path, dataset_folder, eval_dataset_folder.
with open(args.config) as f:
    config = json.load(f)
print(config)
# Each curriculum file gets its own output/model directory, named after the curriculum file.
config["model_path"] = os.path.join("./models/", os.path.basename(config["curriculum_path"]))

# Only restrict GPU visibility when ids were actually given: exporting the empty default would hide
# every GPU. str() guards the `const=1` case (flag given without a value), since os.environ only accepts strings.
if args.cuda_visible_devices != "":
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda_visible_devices)
os.environ["WANDB_PROJECT"] = "babylm_pretraining"

os.makedirs(config["model_path"], exist_ok=True)




# Load or create tokenizer: reuse one saved in model_path, otherwise train a byte-level BPE tokenizer on the training set.
tokenizer = None
try:
    tokenizer = RobertaTokenizerFast.from_pretrained(config["model_path"], model_max_length=512)
except Exception as err:
    # Best effort: failing to load (e.g. nothing saved to model_path yet) means we train a fresh tokenizer.
    # Catching Exception instead of using a bare except lets KeyboardInterrupt/SystemExit propagate, and the log keeps the cause visible.
    print("No usable tokenizer in {} ({}: {}); training a new one.".format(config["model_path"], type(err).__name__, err))

    dataset_tokenizer = datasets.load_from_disk(config["dataset_folder"])  # without set_transform
    # https://github.com/huggingface/transformers/tree/main/examples/flax/language-modeling#train-tokenizer
    tokenizer = ByteLevelBPETokenizer()

    def batch_iterator(batch_size=1000):
        """Yield the training texts in chunks of ``batch_size``."""
        for i in range(0, len(dataset_tokenizer), batch_size):
            yield dataset_tokenizer[i: i + batch_size]["text"]

    # Customized training
    tokenizer.train_from_iterator(batch_iterator(), vocab_size=52_000, min_frequency=2, special_tokens=[
        "<s>",
        "<pad>",
        "</s>",
        "<unk>",
        "<mask>",
    ])

    # Save files to disk, then reload them as a transformers fast tokenizer
    tokenizer.save_model(config["model_path"])
    tokenizer = RobertaTokenizerFast.from_pretrained(config["model_path"], model_max_length=512)

# Setup custom data_collator:
#   we still use dynamic masking (mask differently at each epoch) as in the original RoBERTa paper, but do so deterministically (by setting the torch seed based on a hash of the document and epoch).
#       this is done to make the influence estimation more realistic
#   Aside: we do not use sentence packing, as that would defeat the purpose of applying an influence estimation method on a per-document basis

data_collator = util.DeterministicDataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)
# Masking seeds start at epoch 0; the dataloader's set_epoch updates this once training runs.
data_collator.set_epoch(0)

#   we need to change some classes so that they pass down the current epoch to this datacollator, as well as to the dataloader:


class EpochVariableDataLoader(DataLoader):
    """DataLoader whose ``set_epoch`` also forwards the epoch to a user-supplied callback.

    The Trainer only calls ``set_epoch`` on the dataloader, but deterministic dynamic masking also
    needs the current epoch inside the data collator. ``set_epoch`` therefore stores the epoch on the
    sampler (``OrderedSampler`` reads it to pick the curriculum stage) and then calls
    ``passtrough_function(epoch)``.

    NOTE(review): with persistent dataloader workers, the worker processes may keep the collator copy
    they started with, so the forwarded epoch might not reach them. Confirm persistent workers stay off.
    """

    def __init__(self, train_dataset, passtrough_function, **dataloader_params):
        super().__init__(train_dataset, **dataloader_params)
        # Not one of DataLoader's guarded attributes, so it may be set after initialisation.
        self.passtrough_function = passtrough_function

    def set_epoch(self, epoch):
        """Record ``epoch`` on the sampler and forward it to the pass-through callback."""
        sampler = self.sampler
        sampler.epoch = epoch
        self.passtrough_function(epoch)
 

class OrderedSampler(SequentialSampler):
    """Yield dataset indices in the order prescribed by a curriculum file.

    The curriculum is loaded with ``torch.load(..., weights_only=True)``. It is either a tensor of
    shape (num_epochs, n), where each row is one epoch, or a list of tensors with one element per
    epoch. The curriculum (and the dataset) may vary in length by epoch.
    *The huggingface Trainer is oblivious of this: ``len()`` is inherited from SequentialSampler and
    still reports ``len(data_source)``, so keep that in mind when looking at tqdm runtime estimates!*

    Args:
        data_source: The dataset being sampled.
        epoch: Index of the curriculum stage to iterate. ``EpochVariableDataLoader.set_epoch`` updates it.
        curriculum_path: File to load the curriculum from. Defaults to ``config["curriculum_path"]``.
    """
    def __init__(self, data_source, epoch, curriculum_path=None):
        self.data_source = data_source
        self.epoch = epoch
        if curriculum_path is None:
            curriculum_path = config["curriculum_path"]
        self.curriculum = torch.load(curriculum_path, weights_only=True)

    def __iter__(self):
        # The epoch may arrive as a float (e.g. a Trainer state.epoch of 1.0); stages need an int index.
        return iter(self.curriculum[int(self.epoch)].tolist())

# Load and pre-tokenize the datasets (TODO: unclear whether that actually increases performance with our custom dataloader and data collator)

def _load_tokenized(folder):
    """Load a saved dataset from ``folder`` and replace its ``text`` column with tokenizer outputs.

    Texts are tokenized with special-token masks (consumed by the MLM data collator) and truncated
    to 512 tokens. The result is formatted to return torch tensors.
    """
    ds = datasets.load_from_disk(folder)
    # Batched mapping hands the fast tokenizer lists of texts; per-example outputs are unchanged.
    ds = ds.map(
        lambda batch: tokenizer(batch["text"], return_special_tokens_mask=True, truncation=True, max_length=512),
        batched=True,
    )
    ds = ds.remove_columns(["text"])
    ds.set_format("torch")
    return ds

dataset = _load_tokenized(config["dataset_folder"])
dataset_eval = _load_tokenized(config["eval_dataset_folder"])





class CurriculumTrainer(Trainer):
    """Trainer whose training dataloader follows the curriculum and forwards the epoch to the collator."""

    def get_train_dataloader(self) -> DataLoader:
        """Build the training dataloader from ``EpochVariableDataLoader`` and ``OrderedSampler``.

        Facilitates passing the current epoch down to the data_collator (for deterministic dynamic
        masking) and to the sampler (for loading the correct stage of the curriculum). The Trainer
        calls ``set_epoch`` on the returned dataloader, which forwards it to ``data_collator.set_epoch``.
        Skips the accelerator, because that would re-instantiate the default DataLoader class.

        NOTE(review): because the dataloader bypasses the accelerator, no distributed sharding is
        applied here. Under multi-process training each process would iterate the full curriculum.
        Confirm this is intended before using several GPUs.
        NOTE(review): with ``dataloader_persistent_workers=True`` the workers may keep a stale
        collator epoch. Confirm it stays False.
        """
        train_dataset = self._remove_unused_columns(self.train_dataset, description="training")
        data_collator = self.data_collator

        dataloader_params = {
            "batch_size": self._train_batch_size,
            "collate_fn": data_collator,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
            "persistent_workers": self.args.dataloader_persistent_workers,
        }

        if not isinstance(train_dataset, torch.utils.data.IterableDataset):
            # state.epoch is None before training; otherwise it may be a float, so cast to an int stage index.
            start_epoch = int(self.state.epoch) if self.state.epoch is not None else 0
            dataloader_params["sampler"] = OrderedSampler(self.train_dataset, start_epoch)
            dataloader_params["drop_last"] = self.args.dataloader_drop_last
            dataloader_params["worker_init_fn"] = seed_worker
            dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

        # The Trainer class calls set_epoch on the dataloader, but we also need it in the data_collator.
        return EpochVariableDataLoader(train_dataset, data_collator.set_epoch, **dataloader_params)


# set up eval 
from collections import defaultdict
# Running sums across eval batches; reset after each final (compute_result=True) call.
batch_metrics = defaultdict(lambda: 0)
def compute_metrics(eval_pred, compute_result=True):
    """Accumulate masked-LM accuracy and loss over eval batches, as in https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_mlm_flax.py

    Only positions whose label is not -100 (the tokens the collator masked) are scored. The function
    supports per-batch calls (``batch_eval_metrics=True``) as well as one call over the whole eval set.

    Args:
        eval_pred: ``(logits, labels)`` pair (tuple or EvalPrediction). Logits have shape
            ``labels.shape + (vocab_size,)``. Tensors or numpy arrays are accepted.
        compute_result: Trainer sets this to True once all batches are complete. When True, the
            aggregated metrics are returned and the accumulator is reset. Defaults to True.

    Returns:
        ``{}`` for intermediate batches. Otherwise, metrics logged to W&B: ``accuracy``, ``mlm_loss``
        (mean token cross-entropy), ``mlm_perplexity`` and ``normalizer`` (number of scored tokens).
        The metric values are NaN if no token was scored.
    """
    global batch_metrics
    logits, labels = eval_pred
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels)

    # Select scored positions: (num_scored, vocab_size) logits and (num_scored,) targets.
    scored = labels != -100
    scored_logits = logits[scored].float()
    scored_labels = labels[scored].long()

    # Element-wise comparison (torch.equal would collapse the whole tensor into one bool).
    batch_metrics["accuracy"] += (scored_logits.argmax(dim=-1) == scored_labels).sum().item()
    # Summed (not averaged) so that dividing by the normalizer at the end yields a per-token mean.
    batch_metrics["mlm_loss"] += torch.nn.functional.cross_entropy(scored_logits, scored_labels, reduction="sum").item()
    batch_metrics["normalizer"] += scored.sum().item()  # number of masked (scored) tokens

    if not compute_result:
        return {}

    normalizer = batch_metrics["normalizer"]
    if normalizer > 0:
        accuracy = batch_metrics["accuracy"] / normalizer
        mlm_loss = batch_metrics["mlm_loss"] / normalizer
        try:
            mlm_perplexity = math.exp(mlm_loss)
        except OverflowError:  # a diverged model can produce a loss too large for exp()
            mlm_perplexity = float("inf")
    else:
        accuracy = mlm_loss = mlm_perplexity = float("nan")

    result = {
        "accuracy": accuracy,
        "mlm_perplexity": mlm_perplexity,
        "mlm_loss": mlm_loss,
        "normalizer": normalizer,
    }
    batch_metrics = defaultdict(lambda: 0)
    return result

# configs
# Model architecture: 12 layers, 12 attention heads, hidden size 768, feed-forward size 3072.
roberta_config = RobertaConfig(
    # len(tokenizer) also counts added tokens (tokenizer.vocab_size does not), so the embedding matrix
    # covers every id the tokenizer can emit. That includes the random replacements the data collator
    # draws from range(len(tokenizer)).
    vocab_size=len(tokenizer),
    max_position_embeddings=514,  # presumably 512 tokens + RoBERTa's 2 offset positions — TODO confirm
    num_attention_heads=12,
    num_hidden_layers=12,
    type_vocab_size=1,
    layer_norm_eps=1e-05,
    attention_probs_dropout_prob=0.1,
    hidden_act="gelu",
    hidden_dropout_prob=0.1,
    hidden_size=768,
    initializer_range=0.02,
    intermediate_size=3072,
)

# One curriculum stage per epoch: the length of the curriculum fixes the number of training epochs.
EPOCHS = len(torch.load(config["curriculum_path"], weights_only=True))
print("Detected {} epochs in the curriculum provided".format(EPOCHS))

training_args = TrainingArguments(
    seed=42,
    output_dir=config["model_path"],
    save_strategy="epoch",
    overwrite_output_dir=True,

    num_train_epochs=EPOCHS,  # do not change this manually: see the custom OrderedSampler
    dataloader_num_workers=10,
    fp16=False,  # TODO was True in https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/config/pretraining/base.yaml but loss extreme at start
    prediction_loss_only=False,
    remove_unused_columns=True,

    # https://github.com/facebookresearch/fairseq/blob/main/examples/roberta/README.pretraining.md
    # Effective batch size = per_device_train_batch_size * gradient_accumulation_steps * number of GPUs,
    # e.g. 2048 = 64 (CLI default) * 16 * 2 GPUs.
    per_device_train_batch_size=args.per_device_train_batch_size,  # honour the --per_device_train_batch_size flag
    gradient_accumulation_steps=16,
    learning_rate=5e-4,

    adam_beta1=0.9,
    adam_beta2=0.98,
    adam_epsilon=1e-06,
    weight_decay=0.01,
    lr_scheduler_type="polynomial",
    warmup_steps=10000,

    # eval
    eval_strategy="epoch",
    label_names=["labels"],  # of eval_dataset
    batch_eval_metrics=True,
    per_device_eval_batch_size=64,
    eval_on_start=True,

    # logging
    report_to="wandb",
    logging_steps=50,

    # debug: forces CPU-only training. `use_cpu` replaces the deprecated `no_cuda` flag with the same effect.
    # TODO remove for real pretraining runs.
    use_cpu=True,
)

# Fresh (randomly initialised) RoBERTa MLM, trained with the curriculum-aware Trainer.
model = RobertaForMaskedLM(config=roberta_config)
trainer = CurriculumTrainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=dataset_eval,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

# Run all curriculum epochs, then write the final weights into the model directory.
trainer.train()
trainer.save_model(config["model_path"])

Overwriting pretrain.py


In [None]:
# set up eval 
# from collections import defaultdict
# batch_metrics = defaultdict(lambda:0) 
# def compute_metrics(eval_pred, compute_result=True):
#     """Computes accuracy and MLM loss (ignore masked tokens) as in https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_mlm_flax.py 

#     Args:
#         eval_pred: Tuple of logits and labels
#         compute_result: Trainer will set this to true once all batches are complete. Defaults to True.

#     Returns:
#         Metrics that are logged to W&B
#     """
#     global batch_metrics 
#     logits, labels = eval_pred
#     logits = torch.tensor(logits)
#     labels = torch.tensor(labels)

#     predictions = torch.argmax(logits, axis=-1)
#     label_mask = torch.where(labels > 0, 1.0, 0.0)

#     batch_metrics["accuracy"] += ((torch.equal(predictions, labels))* label_mask).sum()
#     batch_metrics["mlm_loss"] += (torch.nn.functional.cross_entropy(logits, torch.nn.functional.one_hot((labels).to(torch.int64), logits.shape[-1]).to(torch.float64))* label_mask).sum()
#     batch_metrics["normalizer"] += label_mask.sum() # number of non-masked labels, divide this when compute_result to get mean 

#     if compute_result:
#         result = {
#             "accuracy": batch_metrics["accuracy"] / batch_metrics["normalizer"],
#             "mlm_perplexity": math.exp(batch_metrics["mlm_loss"] / batch_metrics["normalizer"]),
#             "mlm_loss": batch_metrics["mlm_loss"] / batch_metrics["normalizer"],
#             "normalizer" : batch_metrics["normalizer"]
#             }
#         batch_metrics = defaultdict(lambda:0) 
#         return result
#     else:
#         return {}
