# Prep Data for Pretraining on T5

In [1]:
from datasets import load_from_disk

# Read the raw dataset from disk and keep only its 'train' entry, which is
# what every later cell operates on.
dataset = load_from_disk('../RawData/raw-ds')['train']

## Corrupt Training Text

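T5 is pretrained with a span-corruption objective: spans of the input text are replaced by sentinel tokens (`<extra_id_0>`, `<extra_id_1>`, …) and the model learns to reconstruct the missing spans. The training data must therefore be masked in this way before pretraining.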

In [2]:
import random

def bo_corrupt_text_batch(examples):
    """Build span-corruption (input, target) pairs from a batch of Tibetan text.

    Each text in ``examples["bo"]`` is split into pieces on the tsheg
    (``'་'``). Roughly one piece in six (and at least one) is picked at random
    to be masked. Every run of adjacent masked pieces is replaced in the input
    by a single sentinel ``<extra_id_N>``; the target lists each sentinel
    followed by the pieces it hides.

    Args:
        examples: Batch mapping with a ``"bo"`` key holding a list of strings,
            as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        dict with ``"input_text"`` and ``"target_text"`` lists, one entry per
        input text, in the same order. A text with no non-empty pieces yields
        a pair of empty strings so the output stays aligned with the batch.
    """
    input_texts = []
    target_texts = []

    for text in examples["bo"]:
        # Drop empty pieces (trailing/doubled tsheg, empty or missing text) so
        # a sentinel is never spent on hiding nothing.
        words = [piece for piece in (text or "").split('་') if piece]
        if not words:
            input_texts.append("")
            target_texts.append("")
            continue

        num_masks = max(1, len(words) // 6)
        # A set gives O(1) membership checks in the loop below.
        masked_indices = set(random.sample(range(len(words)), num_masks))

        new_text = []
        labels = []
        current_mask = 0
        # True while the previous piece was masked, so that adjacent masked
        # pieces extend the open span instead of starting a new sentinel.
        in_span = False

        for i, word in enumerate(words):
            if i not in masked_indices:
                new_text.append(word)
                in_span = False
            elif in_span:
                labels[-1] += f" {word}"
            else:
                sentinel = f"<extra_id_{current_mask}>"
                new_text.append(sentinel)
                labels.append(f"{sentinel} {word}")
                current_mask += 1
                in_span = True

        # NOTE(review): pieces are re-joined with spaces rather than the tsheg
        # they were split on, so the output no longer matches raw Tibetan
        # spelling - confirm this is intended.
        input_texts.append(" ".join(new_text))
        target_texts.append(" ".join(labels))

    return {"input_text": input_texts, "target_text": target_texts}

In [3]:
# Seed the module-level RNG so the randomly chosen mask positions are the same
# on every fresh run of the notebook.
random.seed(42)
# batched=True passes the function a dict of column lists; the default
# batch_size is fine here.
bo_train_dataset = dataset.map(bo_corrupt_text_batch, batched=True)
bo_train_dataset.save_to_disk('Data/pieces/bo_train_dataset')

Saving the dataset (0/3 shards):   0%|          | 0/861417 [00:00<?, ? examples/s]

In [5]:
# Spot-check the first record to confirm the input_text / target_text columns
# were added alongside the original columns.
bo_train_dataset[0]

{'bo': 'བླ་མ་ལ་ཞུས་པས། མལ་གྱིས་བསེ་མགོན་འདིས་རྒྱ་གར་དུའང་མཐུ་རྩལ་འགྲན་མེད་བྱེད།',
 'en': 'For this reason Lama Sachen said, “The forcefulness of Mal’s Skin Mask Guardian is unrivaled even in India.',
 'input_text': 'བླ <extra_id_0> ལ ཞུས པས། མལ གྱིས བསེ མགོན འདིས རྒྱ གར དུའང མཐུ <extra_id_1> འགྲན མེད བྱེད།',
 'target_text': '<extra_id_0> མ <extra_id_1> རྩལ'}

In [6]:
import random

def en_corrupt_text_batch(examples):
    """Build span-corruption (input, target) pairs from a batch of English text.

    Each text in ``examples["en"]`` is split on whitespace. Roughly one word
    in six (and at least one) is picked at random to be masked. Every run of
    adjacent masked words is replaced in the input by a single sentinel
    ``<extra_id_N>``; the target lists each sentinel followed by the words it
    hides.

    Args:
        examples: Batch mapping with an ``"en"`` key holding a list of strings,
            as passed by ``Dataset.map(..., batched=True)``.

    Returns:
        dict with ``"input_text"`` and ``"target_text"`` lists, one entry per
        input text, in the same order. An empty or whitespace-only text yields
        a pair of empty strings so the output stays aligned with the batch.
    """
    input_texts = []
    target_texts = []

    for text in examples["en"]:
        words = (text or "").split()
        if not words:
            # random.sample cannot draw from an empty range; emit an empty
            # pair instead of raising so one blank row does not abort the map.
            input_texts.append("")
            target_texts.append("")
            continue

        num_masks = max(1, len(words) // 6)
        # A set gives O(1) membership checks in the loop below.
        masked_indices = set(random.sample(range(len(words)), num_masks))

        new_text = []
        labels = []
        current_mask = 0
        # True while the previous word was masked, so that adjacent masked
        # words extend the open span instead of starting a new sentinel.
        in_span = False

        for i, word in enumerate(words):
            if i not in masked_indices:
                new_text.append(word)
                in_span = False
            elif in_span:
                labels[-1] += f" {word}"
            else:
                sentinel = f"<extra_id_{current_mask}>"
                new_text.append(sentinel)
                labels.append(f"{sentinel} {word}")
                current_mask += 1
                in_span = True

        input_texts.append(" ".join(new_text))
        target_texts.append(" ".join(labels))

    return {"input_text": input_texts, "target_text": target_texts}

In [7]:
# Seed the module-level RNG so the randomly chosen mask positions are the same
# on every fresh run of the notebook.
random.seed(42)
# batched=True passes the function a dict of column lists; the default
# batch_size is fine here.
en_train_dataset = dataset.map(en_corrupt_text_batch, batched=True)
en_train_dataset.save_to_disk('Data/pieces/en_train_dataset')

Saving the dataset (0/3 shards):   0%|          | 0/861417 [00:00<?, ? examples/s]

In [8]:
# Spot-check the first record to confirm the input_text / target_text columns
# were added alongside the original columns.
en_train_dataset[0]

{'bo': 'བླ་མ་ལ་ཞུས་པས། མལ་གྱིས་བསེ་མགོན་འདིས་རྒྱ་གར་དུའང་མཐུ་རྩལ་འགྲན་མེད་བྱེད།',
 'en': 'For this reason Lama Sachen said, “The forcefulness of Mal’s Skin Mask Guardian is unrivaled even in India.',
 'input_text': '<extra_id_0> this reason Lama Sachen <extra_id_1> “The forcefulness of Mal’s Skin Mask Guardian is unrivaled even in <extra_id_2>',
 'target_text': '<extra_id_0> For <extra_id_1> said, <extra_id_2> India.'}

## Concatenate Dataset Pieces

In [9]:
from datasets import load_from_disk

# Reload both corrupted pieces from disk so this section can be run without
# re-executing the corruption cells above.
bo_pieces_path = 'Data/pieces/bo_train_dataset'
en_pieces_path = 'Data/pieces/en_train_dataset'
bo_train_dataset = load_from_disk(bo_pieces_path)
en_train_dataset = load_from_disk(en_pieces_path)

In [10]:
from datasets import concatenate_datasets

# Stack the English and Tibetan examples into one dataset, then shuffle with
# a fixed seed so the two languages are interleaved in a repeatable order.
combined_pieces = concatenate_datasets([en_train_dataset, bo_train_dataset])
ds = combined_pieces.shuffle(seed=42)

In [11]:
# Persist the combined, shuffled dataset to disk.
ds.save_to_disk('Data/pretraining-ds')

Saving the dataset (0/5 shards):   0%|          | 0/1722834 [00:00<?, ? examples/s]