In [1]:
from datasets import load_dataset
from transformers import TrainingArguments, LlamaConfig, LlamaForCausalLM, AutoTokenizer, Trainer

In [2]:
# 1. Preparing Datasets
dataset = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")

In [3]:
# 2. Preparing Tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B")

In [4]:
# 3. Model Architecture
config = LlamaConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size = 256,
    intermediate_size=512,
    num_hidden_layers=4,
    num_attention_heads=4,
)

In [5]:
model = LlamaForCausalLM(config)

In [6]:
# 4. Tokenizing, Dataset Processing
def tokenize_function(examples):
    return tokenizer(examples["text"])

tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=4, remove_columns=["text"])

block_size = 128 # Model Memorization Scope

def group_texts(examples):
    concatenated_examples = {k: sum(examples[k], []) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

lm_datasets = tokenized_datasets.map(
    group_texts,
    batched=True,
    batch_size=1000,
    num_proc=4,
)

Map (num_proc=4):   0%|          | 0/36718 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/36718 [00:00<?, ? examples/s]

In [7]:
# 5. Training Processs Setting and Execution
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    learning_rate=2e-5,
    save_steps=10_000,
    save_total_limit=2,
    logging_strategy="steps",
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=lm_datasets,
)

In [8]:
# Code run
trainer.train()



Step,Training Loss
10,11.5025
20,11.4846
30,11.3571
40,11.3135
50,11.2008
60,11.1429
70,11.0714
80,10.9884
90,10.9458
100,10.8124


TrainOutput(global_step=14469, training_loss=6.393296638218269, metrics={'train_runtime': 921.6925, 'train_samples_per_second': 62.787, 'train_steps_per_second': 15.698, 'total_flos': 1572956333015040.0, 'train_loss': 6.393296638218269, 'epoch': 3.0})