In [1]:
# Project-level training settings; later cells read attributes from it (e.g. `block_size`).
from utils import TrainingConfig

training_config = TrainingConfig()

In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer

# Pretrained GPT-2 tokenizer (replace "gpt2" with your own tokenizer name/path if needed).
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="gpt2")

# 1. Load the raw text: the full OpenWebText training split (≈38 GB of text).
# NOTE(review): trust_remote_code=True allows the dataset repository to execute its
# own loading code on this machine; keep it only for a trusted (ideally pinned) repo.
ds = load_dataset(
    path="openwebtext",
    split="train",
    trust_remote_code=True,
)

Loading dataset shards:   0%|          | 0/80 [00:00<?, ?it/s]

In [3]:
# Keep only the first N_DOCS documents so the rest of the notebook iterates quickly.
N_DOCS = 1000
ds = ds.select(range(N_DOCS))

# Peek at the first document.
ds[0]

{'text': 'Port-au-Prince, Haiti (CNN) -- Earthquake victims, writhing in pain and grasping at life, watched doctors and nurses walk away from a field hospital Friday night after a Belgian medical team evacuated the area, saying it was concerned about security.\n\nThe decision left CNN Chief Medical Correspondent Sanjay Gupta as the only doctor at the hospital to get the patients through the night.\n\nCNN initially reported, based on conversations with some of the doctors, that the United Nations ordered the Belgian First Aid and Support Team to evacuate. However, Belgian Chief Coordinator Geert Gijs, a doctor who was at the hospital with 60 Belgian medical personnel, said it was his decision to pull the team out for the night. Gijs said he requested U.N. security personnel to staff the hospital overnight, but was told that peacekeepers would only be able to evacuate the team.\n\nHe said it was a "tough decision" but that he accepted the U.N. offer to evacuate after a Canadian medical t

In [4]:
# The GPT-2 tokenizer defines no pad token by default; reuse EOS so later padding/collation has one.
tokenizer.pad_token = tokenizer.eos_token


def tokenize(batch, add_eos=True):
    """Tokenize a batch of raw documents, producing one token list per document.

    No truncation or padding is applied: documents are packed into fixed-size
    blocks by a later step, so their full length is kept. transformers may warn
    that some sequences exceed the model's maximum length; that is expected here.

    Args:
        batch: batched ``Dataset.map`` input with a "text" column (list of str).
        add_eos: append the tokenizer's EOS id to every document so that, once
            documents are concatenated into one stream, the boundary between
            them is still marked. Skipped when the tokenizer has no EOS token.

    Returns:
        The tokenizer's encoding, with per-document "input_ids" (and
        "attention_mask", when the tokenizer returns one) lists.
    """
    encoded = tokenizer(batch["text"])
    eos_id = tokenizer.eos_token_id
    if add_eos and eos_id is not None:
        encoded["input_ids"] = [ids + [eos_id] for ids in encoded["input_ids"]]
        # Keep the mask aligned with the ids; EOS is a real token, so its mask value is 1.
        if "attention_mask" in encoded:
            encoded["attention_mask"] = [mask + [1] for mask in encoded["attention_mask"]]
    return encoded


tokenized = ds.map(tokenize, batched=True, remove_columns=["text"])
tokenized

Dataset({
    features: ['input_ids', 'attention_mask'],
    num_rows: 1000
})

In [7]:
from itertools import chain


def group(batch, block_size=None):
    """Pack a batch of tokenized documents into fixed-length next-token-prediction blocks.

    All documents in the batch are concatenated into one token stream, which is
    cut into consecutive blocks of ``block_size`` tokens. ``labels`` are the
    inputs shifted left by one token (the next-token targets), so each block
    consumes ``block_size + 1`` tokens of the stream. Tokens at the end of the
    batch that cannot form a complete block plus its one-token label lookahead
    are dropped.

    NOTE(review): labels are pre-shifted here. Hugging Face ``*LMHeadModel``
    classes shift labels internally, so confirm the training model expects
    pre-shifted labels (otherwise pass ``labels = input_ids``).

    Args:
        batch: mapping with list-of-lists "input_ids" and "attention_mask"
            (as produced by ``Dataset.map(..., batched=True)``).
        block_size: tokens per block; defaults to ``training_config.block_size``.

    Returns:
        dict with "input_ids", "attention_mask" and "labels", each a list of
        lists that are all exactly ``block_size`` long.

    Raises:
        ValueError: if ``block_size`` is not positive.
    """
    if block_size is None:
        block_size = training_config.block_size
    if block_size <= 0:
        raise ValueError(f"block_size must be positive, got {block_size}")

    # chain.from_iterable flattens in linear time; sum(lists, []) re-copies the
    # accumulated list on every addition (quadratic in the batch's token count).
    flat_ids = list(chain.from_iterable(batch["input_ids"]))
    flat_masks = list(chain.from_iterable(batch["attention_mask"]))

    # Reserve one trailing token so the last block's shifted labels are full length too.
    num_blocks = max(len(flat_ids) - 1, 0) // block_size
    starts = range(0, num_blocks * block_size, block_size)

    return {
        "input_ids": [flat_ids[i:i + block_size] for i in starts],
        "attention_mask": [flat_masks[i:i + block_size] for i in starts],
        "labels": [flat_ids[i + 1:i + block_size + 1] for i in starts],
    }

# Pack the tokenized documents into blocks, handing group() up to 10,000 documents
# per call; each call drops whatever tail tokens cannot fill a complete block.
lm_ds = tokenized.map(
    group,
    batched=True,
    batch_size=10000,
)

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]