## 1. Train tokenizer

In [None]:
%%capture
!pip install tokenizers transformers datasets[s3]

In [None]:
from datasets import load_dataset                                          
from tokenizers import ByteLevelBPETokenizer                              
from transformers import PreTrainedTokenizerFast                           

# Load the "sample-10BT" configuration of FineWeb as a regular
# (non-streaming) dataset.
dataset = load_dataset(
    path="HuggingFaceFW/fineweb",
    name="sample-10BT",
    split="train",
    streaming=False,
)

In [None]:
# Keep only the "text" column. Background on why the other columns are
# dropped before tokenizer training:
# https://discuss.huggingface.co/t/speed-issues-using-tokenizer-train-new-from-iterator-on-50gb-dataset/29125
dataset = dataset.remove_columns(
    [name for name in dataset.column_names if name != "text"]
)

In [None]:
# Special tokens handed to the BPE trainer; the same strings are registered
# as bos/pad/eos/unk/mask tokens when the tokenizer file is loaded later on.
special_tokens = ["<s>", "<pad>", "</s>", "<unk>", "<mask>"]

# Untrained byte-level BPE tokenizer; the vocabulary is learned below.
tokenizer = ByteLevelBPETokenizer()

def batch_iterator(dataset, batch_size=500):
    """Yield the ``"text"`` entry of ``dataset`` one batch at a time.

    Args:
        dataset: Object exposing ``iter(batch_size=...)`` whose batches can
            be indexed with the ``"text"`` key.
        batch_size: Number of rows requested per batch.

    Yields:
        The ``"text"`` entry of each batch, in iteration order.
    """
    yield from (rows["text"] for rows in dataset.iter(batch_size=batch_size))

# Learn the BPE vocabulary from the text batches produced by batch_iterator.
tokenizer.train_from_iterator(
    batch_iterator(dataset),
    special_tokens=special_tokens,
    vocab_size=52_000,
    min_frequency=2,
    show_progress=True,
)

In [None]:
# Persist the trained tokenizer to a JSON file; the tokenization and
# pre-training sections load it again under this exact file name.
tokenizer.save("fineweb-10bt-tokenizer-bpe.json")

## 2. Tokenize datasets

In [None]:
import os
from transformers import PreTrainedTokenizerFast
from datasets import load_dataset

# Length, in tokens, of every packed training example produced below.
context_length = 1024

# Wrap the trained tokenizer file and register its special tokens.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="fineweb-10bt-tokenizer-bpe.json",
    unk_token="<unk>",
    pad_token="<pad>",
    bos_token="<s>",
    eos_token="</s>",
    mask_token="<mask>",
)

def concatenate_and_chunk(element):
    """Tokenize a batch of texts and pack them into fixed-length examples.

    Every text in ``element["text"]`` is encoded without special tokens and
    followed by the tokenizer's EOS id. The ids of all texts are joined into
    one stream, which is cut into consecutive chunks of ``context_length``
    tokens. Tokens left over after the last full chunk are discarded.

    Args:
        element: Batch mapping whose ``"text"`` entry is an iterable of
            strings.

    Returns:
        A dict with ``"input_ids"`` (list of chunks, each exactly
        ``context_length`` ids long) and ``"labels"`` (a shallow copy of
        that list). Both lists are empty when the batch produces fewer than
        ``context_length`` tokens.
    """
    eos_id = tokenizer.eos_token_id
    stream = []
    for document in element["text"]:
        stream += tokenizer.encode(document, add_special_tokens=False)
        stream.append(eos_id)

    # Only whole chunks are kept; a too-short stream yields zero chunks.
    usable = len(stream) - len(stream) % context_length
    chunks = [
        stream[start : start + context_length]
        for start in range(0, usable, context_length)
    ]
    return {"input_ids": chunks, "labels": list(chunks)}


print("Loading raw dataset...")
raw_dataset = load_dataset("HuggingFaceFW/fineweb", "sample-10BT", split="train")

# Only the raw text is needed for tokenization; drop every other column.
raw_dataset = raw_dataset.remove_columns(
    [name for name in raw_dataset.column_names if name != "text"]
)

print("Applying tokenization and chunking...")
# The mapper consumes whole batches of texts and emits packed chunks, so the
# original columns are removed and replaced by its output columns.
tokenized_dataset = raw_dataset.map(
    concatenate_and_chunk,
    batched=True,
    num_proc=os.cpu_count(),
    remove_columns=raw_dataset.column_names,
)
print("Tokenization complete.")

# Show one packed example as a sanity check, then persist the result for the
# pre-training section.
print(tokenized_dataset[0])
tokenized_dataset.save_to_disk("./tokenized-dataset")

## 3. Model Pre-training

In [None]:
%%capture
!pip install wandb torch torchvision torchaudio transformers[torch] 'accelerate>=0.26.0' tokenizers datasets[s3]


In [None]:
# Log in to Weights & Biases so the training run can report metrics.
# NOTE(review): `<token>` looks like a placeholder — presumably it must be
# replaced with a real API key before running; avoid committing the real key.
!wandb login <token>

In [None]:
from datasets import load_from_disk
from transformers import (
    GPT2Config,
    GPT2LMHeadModel,
    PreTrainedTokenizerFast,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
)

# Reload the artefacts written by the earlier sections: the trained tokenizer
# file and the packed, tokenized dataset.
tokenizer = PreTrainedTokenizerFast(
    tokenizer_file="fineweb-10bt-tokenizer-bpe.json",
    unk_token="<unk>",
    pad_token="<pad>",
    bos_token="<s>",
    eos_token="</s>",
    mask_token="<mask>",
)

tokenized_dataset = load_from_disk("./tokenized-dataset")

# GPT-2 architecture: 24 layers, 16 attention heads, hidden size 1024, and a
# 1024-token context matching the length of the packed training examples.
config = GPT2Config(
    # len(tokenizer) also counts added/special tokens, whereas
    # tokenizer.vocab_size reports only the base vocabulary. Sizing the
    # embedding with the full count guarantees that every id the tokenizer
    # can emit has an embedding row.
    vocab_size=len(tokenizer),
    n_positions=1024,
    # NOTE(review): n_ctx may be unused by recent transformers releases —
    # kept unchanged here; confirm before removing.
    n_ctx=1024,
    n_embd=1024,
    n_layer=24,
    n_head=16,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
)

# Randomly initialised model built from the config (no pretrained weights).
model = GPT2LMHeadModel(config)

# Collator for causal language modelling: mlm=False turns off masked-LM
# masking, and no padding-to-multiple constraint is requested.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=False, pad_to_multiple_of=None
)

training_args = TrainingArguments(
    # Run identity and output locations.
    output_dir="./fineweb-gpt2-356m",
    overwrite_output_dir=True,
    logging_dir="./logs-356m",
    report_to="wandb",
    run_name="fineweb-gpt2-356m-0p2",
    seed=3047,
    # Schedule: one pass over the data; 64 sequences per device step times
    # 16 accumulation steps = 1024 sequences per optimizer step per device.
    num_train_epochs=1,
    per_device_train_batch_size=64,
    gradient_accumulation_steps=16,
    # Optimisation.
    learning_rate=1e-4,
    weight_decay=0.01,
    warmup_ratio=0.03,
    lr_scheduler_type="cosine",
    # Precision and compilation.
    fp16=True,
    torch_compile=True,
    # Logging and checkpointing cadence (in optimizer steps).
    logging_steps=25,
    save_steps=500,
    save_total_limit=3,
    prediction_loss_only=True,
)

# Wire the model, arguments, packed dataset and collator together, then run
# the pre-training loop.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

trainer.train()

In [None]:
# Write the final model weights and config.
trainer.save_model("./fineweb-gpt2-final")
# The Trainer was not given the tokenizer, so store it explicitly next to the
# model; this makes the output directory loadable on its own.
tokenizer.save_pretrained("./fineweb-gpt2-final")