In [1]:
# =========================================
# patRoBERTa Toy MLM — Simple & Stable
# =========================================
import math
from pathlib import Path
from datasets import load_dataset, load_from_disk
from transformers import (
    RobertaConfig, RobertaForMaskedLM, RobertaTokenizerFast,
    DataCollatorForLanguageModeling, Trainer, TrainingArguments
)

# ---------- CONSTANTS (edit) ----------
# Plain-text corpora (read with the `text` dataset loader) and artifact paths.
TRAIN_TXT     = "../data/ep-b1-claim1-corpus/ep-b1-claim1-cpc_train.txt"
VAL_TXT       = "../data/ep-b1-claim1-corpus/ep-b1-claim1-cpc_val.txt"
ENCODINGS_DIR = Path("../data/patroberta-encoded-512-vs8000")  # tokenized-dataset cache
TOKENIZER_DIR = "../artifacts/patroberta-tokenizers/vs8000"

SEQ_LEN = 512                  # truncation length used when tokenizing
# Position-embedding capacity. NOTE(review): RoBERTa numbers positions starting
# just past the padding index, so SEQ_LEN tokens need SEQ_LEN + 2 slots with the
# usual pad id of 1 — keep this >= SEQ_LEN + 2 if SEQ_LEN is raised.
ARCH_MAX_POSITIONS = 514
MLM_PROB = 0.15                # fraction of tokens selected for masking

# Tiny model
HIDDEN_SIZE = 128
NUM_LAYERS  = 2
NUM_HEADS   = 2                # must divide HIDDEN_SIZE
INTER_SIZE  = 512
DROPOUT     = 0.1

# Training knobs (keep tiny to avoid OOM)
# The per-device batch was 64, which contradicted the documented effective
# batch of 32 (64 * 2 = 128). 16 * 2 = 32 matches both the stated intent and the
# recorded run (10,552 optimizer steps per epoch over 337,648 training examples).
PER_DEVICE_TRAIN_BS = 16       # fits 4GB with fp16 + checkpointing
PER_DEVICE_EVAL_BS  = 8
GRAD_ACCUM_STEPS    = 2        # effective batch = 16 * 2 = 32 sequences
LEARNING_RATE       = 5e-4
WEIGHT_DECAY        = 0.01
WARMUP_RATIO        = 0.06
FP16                = True
NUM_EPOCHS          = 2        # or set MAX_STEPS instead of epochs

OUT_DIR = "../artifacts/patroberta-mlm-512-simple"

# ---------- Tokenizer & data ----------
# Load the trained tokenizer and cap its advertised maximum length at SEQ_LEN.
tok = RobertaTokenizerFast.from_pretrained(TOKENIZER_DIR)
tok.model_max_length = SEQ_LEN


def enc(b):
    """Tokenize the ``"text"`` column of batch ``b``, truncating to SEQ_LEN.

    Used as the batched ``map`` function when building the encoded dataset.
    """
    texts = b["text"]
    return tok(texts, truncation=True, max_length=SEQ_LEN)

# Tokenize once and cache on disk; later runs reuse the cached encodings.
if not ENCODINGS_DIR.exists():
    raw = load_dataset(
        "text",
        data_files={"train": TRAIN_TXT, "validation": VAL_TXT},
    )
    # Drop the raw text column so only the tokenizer outputs are kept.
    ds = raw.map(enc, batched=True, remove_columns=["text"])
    ds.save_to_disk(str(ENCODINGS_DIR))
else:
    ds = load_from_disk(str(ENCODINGS_DIR))

# ---------- Collator ----------
# Builds masked-LM batches, selecting tokens for masking with probability MLM_PROB.
collator = DataCollatorForLanguageModeling(
    tokenizer=tok,
    mlm=True,
    mlm_probability=MLM_PROB,
)

# ---------- Model ----------
# Randomly initialised RoBERTa masked-LM sized by the constants above.
cfg = RobertaConfig(
    # len(tok) counts added tokens as well; tok.vocab_size reports only the base
    # vocabulary, which would leave the embedding table too small for any id
    # the tokenizer can emit beyond it. The two are equal when nothing was added.
    vocab_size=len(tok),
    hidden_size=HIDDEN_SIZE,
    num_hidden_layers=NUM_LAYERS,
    num_attention_heads=NUM_HEADS,
    intermediate_size=INTER_SIZE,
    hidden_dropout_prob=DROPOUT,
    attention_probs_dropout_prob=DROPOUT,
    max_position_embeddings=ARCH_MAX_POSITIONS,
    # Take the special-token ids from the tokenizer so model and data agree.
    pad_token_id=tok.pad_token_id,
    bos_token_id=tok.bos_token_id,
    eos_token_id=tok.eos_token_id,
)
model = RobertaForMaskedLM(cfg)
model.gradient_checkpointing_enable()  # big saver on 4GB

# ---------- Auto-scale logging by epoch fraction ----------
# Sequences consumed per optimizer step (per-device batch x accumulation).
_seqs_per_update = PER_DEVICE_TRAIN_BS * GRAD_ACCUM_STEPS
steps_per_epoch = max(1, math.ceil(len(ds["train"]) / _seqs_per_update))

# Every interval is floored at 1 so a very small dataset still yields a valid cadence.
LOGGING_STEPS  = max(1, steps_per_epoch // 10)  # roughly 10 log lines per epoch
EVAL_STEPS     = max(1, steps_per_epoch // 4)   # roughly 4 evaluations per epoch
SAVE_STEPS     = EVAL_STEPS                     # checkpoint on the evaluation cadence

# ---------- Training args ----------
args = TrainingArguments(
    output_dir=OUT_DIR,
    report_to="none",
    # Batch geometry
    per_device_train_batch_size=PER_DEVICE_TRAIN_BS,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BS,
    gradient_accumulation_steps=GRAD_ACCUM_STEPS,
    # Optimisation schedule
    num_train_epochs=NUM_EPOCHS,
    learning_rate=LEARNING_RATE,
    weight_decay=WEIGHT_DECAY,
    warmup_ratio=WARMUP_RATIO,
    # Precision / memory
    fp16=FP16,
    fp16_full_eval=True,
    torch_empty_cache_steps=LOGGING_STEPS,
    # Logging, evaluation and checkpoint cadence (SAVE_STEPS mirrors EVAL_STEPS)
    logging_steps=LOGGING_STEPS,
    eval_strategy="steps",
    eval_steps=EVAL_STEPS,
    save_steps=SAVE_STEPS,
    # Best-checkpoint selection: lowest evaluation loss wins
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)

# ---------- Train & evaluate ----------
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=ds["train"],
    eval_dataset=ds["validation"],
    processing_class=tok,
    data_collator=collator,
)

trainer.train()

# Final pass over the validation split; perplexity is exp(mean eval loss).
out = trainer.evaluate()
# `math` is already imported at the top of the file, so the former
# `import math as _m` alias was redundant and has been dropped.
print(out, "Perplexity:", math.exp(out["eval_loss"]))


Map:   0%|          | 0/337648 [00:00<?, ? examples/s]

Map:   0%|          | 0/3445 [00:00<?, ? examples/s]

Saving the dataset (0/2 shards):   0%|          | 0/337648 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/3445 [00:00<?, ? examples/s]

Step,Training Loss,Validation Loss
2638,5.4427,5.16184
5276,4.9561,4.875057
7914,4.8214,4.717948
10552,4.4193,3.814816
13190,3.7354,3.270748
15828,3.4815,3.116742
18466,3.3936,3.008044
21104,3.3362,2.979048


There were missing keys in the checkpoint model loaded: ['lm_head.decoder.weight', 'lm_head.decoder.bias'].


{'eval_loss': 2.9789295196533203, 'eval_runtime': 21.742, 'eval_samples_per_second': 158.449, 'eval_steps_per_second': 19.823, 'epoch': 2.0} Perplexity: 19.666752500362946
