In [1]:
import os

# Every path used below ('./roberta_baseline/', './datasets/...', './dummy') is
# relative to the parent of the directory the kernel starts in.
# A bare os.chdir('..') climbs one more level each time the cell is re-run, so
# the target is resolved once and cached; later runs reuse the same directory.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = os.path.abspath('..')
os.chdir(PROJECT_ROOT)

In [8]:
import torch
from torch import nn
from transformers import RobertaForMaskedLM, RobertaTokenizer
from transformers import Trainer, TrainingArguments

In [4]:
# Checkpoint directory handed to `from_pretrained` for both the model and the
# tokenizer. Relative path: resolved against the current working directory.
ckpt_path = './roberta_baseline/'

In [5]:
# Restore the tokenizer and the masked-LM model from the same checkpoint
# directory so the vocabulary and the model weights come from one source.
tokenizer = RobertaTokenizer.from_pretrained(ckpt_path)
model = RobertaForMaskedLM.from_pretrained(ckpt_path)

In [6]:
from transformers import LineByLineTextDataset

# Training corpus, tokenized with the checkpoint's tokenizer.
# NOTE(review): block_size presumably caps the number of tokens kept per line —
# confirm against the LineByLineTextDataset documentation.
dataset = LineByLineTextDataset(
    file_path="./datasets/babylm_10M_merged.train",
    block_size=128,
    tokenizer=tokenizer,
)



In [14]:
from transformers import DataCollatorForLanguageModeling

# Batch collator configured for masked language modelling (mlm=True) with a
# masking probability of 0.15.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)


In [53]:
class ReinforceMLMTrainer(Trainer):
    """Trainer whose loss is the MLM loss scaled by masked-token accuracy.

    For each batch the base ``Trainer`` loss is multiplied by the fraction of
    labelled positions (label != -100) whose arg-max prediction equals the
    label. The accuracy enters as a plain Python float, so gradients flow only
    through the original loss term.

    NOTE(review): with this weighting a batch with zero accuracy contributes
    zero gradient, and better batches get larger updates — confirm this is the
    intended objective.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the accuracy-weighted loss for one batch.

        Args:
            model: The model being trained; forwarded to ``Trainer.compute_loss``.
            inputs: Batch dict; must contain a ``"labels"`` tensor in which
                ignored positions are marked with ``-100``.
            return_outputs: When true, return ``(loss, outputs)`` instead of
                just the loss. The base ``Trainer`` passes this flag when it
                calls ``compute_loss`` from its prediction step, so the
                override has to accept it.
            **kwargs: Extra keyword arguments supplied by the base ``Trainer``
                (e.g. ``num_items_in_batch`` in recent transformers releases);
                forwarded unchanged to ``super().compute_loss``.

        Returns:
            The weighted loss tensor, or ``(weighted_loss, outputs)`` when
            ``return_outputs`` is true.
        """
        # Read the labels before delegating: the base implementation may pop
        # "labels" out of `inputs` (label-smoothing path), which would make a
        # later lookup fail.
        labels = inputs["labels"]

        loss, outputs = super().compute_loss(
            model, inputs, return_outputs=True, **kwargs
        )

        # The accuracy is only a scalar weight, so no graph is needed here.
        with torch.no_grad():
            mask = labels.ne(-100)
            predictions = torch.argmax(outputs.logits, dim=-1)
            correct = torch.sum(predictions.eq(labels) & mask)
            # clamp avoids 0/0 (NaN) when a batch has no labelled positions.
            accuracy = correct / torch.sum(mask).clamp(min=1)

        reward = loss * accuracy.item()

        return (reward, outputs) if return_outputs else reward


In [54]:
training_args = TrainingArguments(
    # Checkpoint output location and retention.
    output_dir='./dummy',
    overwrite_output_dir=True,
    save_steps=100,
    save_total_limit=100,
    # Optimisation schedule.
    num_train_epochs=1,
    per_device_train_batch_size=16,
    learning_rate=5e-5,
    # Reproducibility and reporting.
    seed=12,
    prediction_loss_only=True,
)


PyTorch: setting up devices
The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).


In [55]:
# Wire the custom trainer to the model, training data, collator and
# hyper-parameters defined in the earlier cells.
trainer = ReinforceMLMTrainer(
    args=training_args,
    model=model,
    train_dataset=dataset,
    data_collator=data_collator,
)


In [None]:
# Start training with the trainer built above; the call's return value is the
# cell output.
trainer.train()