# Continue Pretraining on T5

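In this notebook, I continue pre-training a T5 model (the code below loads `google-t5/t5-large`) on the `bo` and `en` text of the `billingsmoore/Aggregated-bo-en` dataset.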

In [None]:
from datasets import load_dataset

# Corpus with `bo` and `en` text columns (used by the corruption functions below);
# only the training split is loaded.
DATASET_ID = "billingsmoore/Aggregated-bo-en"
dataset = load_dataset(DATASET_ID, split='train')

## Corrupt Training Text

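T5 is pre-trained with a span-corruption objective: spans of the input are replaced by sentinel tokens (`<extra_id_0>`, `<extra_id_1>`, …), and the model learns to generate the missing spans, each introduced by its sentinel. The training data therefore needs masked inputs paired with targets that list the dropped spans.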

In [None]:
import random


def bo_corrupt_text(example, words_per_mask=6):
    """Build a T5 span-corruption pair from the ``"bo"`` field of *example*.

    Roughly one whitespace-separated word in ``words_per_mask`` (at least one)
    is chosen at random. Each run of consecutive chosen words is replaced in
    the input by a single sentinel ``<extra_id_k>`` (k = 0, 1, ...). The target
    lists every sentinel followed by the words it hides.

    NOTE(review): Tibetan script usually separates syllables with the tsheg
    (U+0F0B) rather than spaces, so ``str.split()`` may yield very coarse
    units for ``bo`` text. Confirm this granularity is intended.

    Args:
        example: Mapping with a ``"bo"`` key holding a string (``None`` is
            treated as empty).
        words_per_mask: Average number of words per masked word; must be >= 1.
            The default of 6 masks about 17% of the words.

    Returns:
        ``{"input_text": ..., "target_text": ...}``. Both are ``""`` when the
        text contains no words.
    """
    text = example["bo"] or ""
    words = text.split()
    if not words:
        # random.sample(range(0), 1) would raise; emit an empty pair instead.
        return {"input_text": "", "target_text": ""}

    num_masks = max(1, len(words) // words_per_mask)
    # A set makes the per-word membership test O(1).
    masked = set(random.sample(range(len(words)), num_masks))

    new_text = []
    labels = []
    sentinel_id = 0

    for i, word in enumerate(words):
        if i not in masked:
            new_text.append(word)
        elif i - 1 in masked:
            # Previous word is masked too: extend its span, no new sentinel.
            labels[-1] += f" {word}"
        else:
            sentinel = f"<extra_id_{sentinel_id}>"
            new_text.append(sentinel)
            labels.append(f"{sentinel} {word}")
            sentinel_id += 1

    return {"input_text": " ".join(new_text), "target_text": " ".join(labels)}

# Apply span corruption to the `bo` column; the original columns are kept and
# `input_text` / `target_text` are added.
bo_train_dataset = dataset.map(function=bo_corrupt_text)


In [None]:
def en_corrupt_text(example, words_per_mask=6):
    """Build a T5 span-corruption pair from the ``"en"`` field of *example*.

    Roughly one whitespace-separated word in ``words_per_mask`` (at least one)
    is chosen at random. Each run of consecutive chosen words is replaced in
    the input by a single sentinel ``<extra_id_k>`` (k = 0, 1, ...). The target
    lists every sentinel followed by the words it hides.

    Args:
        example: Mapping with an ``"en"`` key holding a string (``None`` is
            treated as empty).
        words_per_mask: Average number of words per masked word; must be >= 1.
            The default of 6 masks about 17% of the words.

    Returns:
        ``{"input_text": ..., "target_text": ...}``. Both are ``""`` when the
        text contains no words.
    """
    text = example["en"] or ""
    words = text.split()
    if not words:
        # random.sample(range(0), 1) would raise; emit an empty pair instead.
        return {"input_text": "", "target_text": ""}

    num_masks = max(1, len(words) // words_per_mask)
    # A set makes the per-word membership test O(1).
    masked = set(random.sample(range(len(words)), num_masks))

    new_text = []
    labels = []
    sentinel_id = 0

    for i, word in enumerate(words):
        if i not in masked:
            new_text.append(word)
        elif i - 1 in masked:
            # Previous word is masked too: extend its span, no new sentinel.
            labels[-1] += f" {word}"
        else:
            sentinel = f"<extra_id_{sentinel_id}>"
            new_text.append(sentinel)
            labels.append(f"{sentinel} {word}")
            sentinel_id += 1

    return {"input_text": " ".join(new_text), "target_text": " ".join(labels)}

# Apply span corruption to the `en` column; the original columns are kept and
# `input_text` / `target_text` are added.
en_train_dataset = dataset.map(function=en_corrupt_text)

In [None]:
from datasets import concatenate_datasets

# Stack both corrupted copies into one training set: `en`-derived rows first,
# then `bo`-derived rows.
corrupted_parts = [en_train_dataset, bo_train_dataset]
ds = concatenate_datasets(corrupted_parts)

## Tokenize the Data for Training

In [None]:
from transformers import AutoTokenizer

# Custom tokenizer; presumably a local directory saved earlier — confirm path.
TOKENIZER_PATH = 'my_tokenizer'
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)

In [None]:
def tokenize_data(example, max_length=512):
    """Tokenize a batch of corrupted input/target pairs for seq2seq training.

    Both sides are truncated and padded to ``max_length`` tokens. In the
    labels, padding positions are replaced by -100 so the loss ignores them.
    Otherwise the model would be trained to emit hundreds of pad tokens after
    every short target.

    Args:
        example: Batch mapping with ``"input_text"`` and ``"target_text"``
            lists of strings (this function is used with ``batched=True``).
        max_length: Token length to truncate and pad both sides to.

    Returns:
        Dict with ``input_ids``, ``attention_mask`` and ``labels`` lists.
    """
    inputs = tokenizer(example["input_text"], max_length=max_length, truncation=True, padding="max_length")
    targets = tokenizer(example["target_text"], max_length=max_length, truncation=True, padding="max_length")
    pad_id = tokenizer.pad_token_id
    labels = [
        [token if token != pad_id else -100 for token in sequence]
        for sequence in targets.input_ids
    ]
    return {
        "input_ids": inputs.input_ids,
        "attention_mask": inputs.attention_mask,
        "labels": labels,
    }

# Drop every original column (raw text, metadata, intermediate strings) so only
# the model tensors remain, whatever the source schema contains.
tokenized_dataset = ds.map(tokenize_data, batched=True, remove_columns=ds.column_names)

## Train the Model

In [None]:
# IPython magic: set the WANDB_PROJECT environment variable for this kernel.
# Presumably read by the Weights & Biases logging integration during
# trainer.train() — confirm that `report_to` enables wandb.
%env WANDB_PROJECT=translation-v1

In [None]:
from transformers import TrainingArguments, Trainer, T5ForConditionalGeneration

# Start from the public T5-large checkpoint, placed on the first CUDA device.
model = T5ForConditionalGeneration.from_pretrained(
    "google-t5/t5-large",
    device_map='cuda:0',
)
# Resize the embedding matrix to the custom tokenizer's vocabulary size.
# NOTE(review): ids shared with the original T5 vocabulary keep their pretrained
# vectors; if `my_tokenizer` maps different pieces to those ids, the reused
# embeddings will not correspond — confirm the vocabularies align.
model.resize_token_embeddings(len(tokenizer))

training_args = TrainingArguments(
    output_dir="pretrain-model",
    num_train_epochs=1,
    learning_rate=3e-4,
    auto_find_batch_size=True,
    save_strategy="epoch",  # a checkpoint at the end of each epoch
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
)

trainer.train()

In [None]:
# Write the trained weights and config to `pretrained-model/`. The tokenizer is
# saved next to them because the embedding matrix was resized to its
# vocabulary, so the checkpoint is only usable with this exact tokenizer.
model.save_pretrained('pretrained-model')
tokenizer.save_pretrained('pretrained-model')