# 1. setup

In [1]:
import math, os, random
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


In [5]:
from transformers import (
    AutoTokenizer, BertForMaskedLM,
    get_linear_schedule_with_warmup,
    DataCollatorForLanguageModeling, Trainer, TrainingArguments
)

In [6]:
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device("mps")
print("[1] Device:", DEVICE)

[1] Device: mps


# 2. Tiny toy corpus and dataset

In [12]:
CORPUS = [
    "Transformers encode rich bidirectional context.",
    "Masked language modeling enables bidirectional learning.",
    "BERT uses WordPiece tokenization and special tokens.",
    "We randomly mask about fifteen percent of tokens.",
    "Sometimes we keep original words to avoid overfitting to [MASK].",
    "Occasionally a random token is inserted to add noise.",
    "This training objective builds strong contextual embeddings.",
    "Whole Word Masking groups wordpieces of the same word together.",
    "Span masking masks consecutive tokens to model phrases.",
    "Pretraining can later be fine tuned for downstream tasks."
] * 50  # ~500 lines for a quick demo

split = int(len(CORPUS) * 0.9)
train_texts = CORPUS[:split]
val_texts = CORPUS[split:]

class TextDataset(Dataset):
    """
    Returns raw strings; custom data_collator will tokenize & create MLM labels.
    Output of __getitem__(i): str
    """
    def __init__(self, texts):
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.texts[idx] 

# 3. Manual MLM collator (15% + 80/10/10, supports WWM/Span)

In [10]:
@dataclass
class MLMCollatorManual:
    """
    Build an MLM batch with dynamic masking.
    Outputs dict of tensors:
      - input_ids      [B, L] (masked inputs)
      - attention_mask [B, L]
      - token_type_ids [B, L] (if tokenizer provides)
      - labels         [B, L] (non-masked positions = -100)
    """
    tokenizer: any
    mlm_probability: float = 0.15
    max_length: int = 128
    whole_word_mask: bool = False
    span_mask: bool = False
    mean_span_len: float = 3.0

    def __call__(self, batch_texts):
        # 1) Tokenize (dynamic padding up to batch longest, capped by max_length)
        enc = self.tokenizer(
            batch_texts,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_special_tokens_mask=True,
            return_attention_mask=True,
            return_tensors="pt"
        )

        input_ids = enc["input_ids"]                     # [B, L]
        attention_mask = enc["attention_mask"]           # [B, L]
        special_mask = enc["special_tokens_mask"].bool() # [B, L] such as CLS SEP

        labels = torch.full_like(input_ids, -100)        # [B, L], -100 ignored by CE
        candidate = (~special_mask) & attention_mask.bool() # not special and need attention's

        # 2) Choose masked positions
        if self.whole_word_mask:
            chosen = self._choose_wwm(enc, candidate, self.mlm_probability)  # [B, L] bool
        elif self.span_mask:
            chosen = self._choose_spans(candidate, self.mlm_probability, self.mean_span_len)
        else:
            probs = torch.full_like(input_ids, self.mlm_probability, dtype=torch.float32)
            chosen = (torch.bernoulli(probs).bool()) & candidate

        # 3) 80/10/10 replacement
        labels[chosen] = input_ids[chosen] #only masked positions contribute to loss, other places are all -100, maksed position are input token ids
        r = torch.rand_like(input_ids, dtype=torch.float32) #uniform distribution on interval [0,1) of size input_ids shape
        replace_mask = chosen & (r < 0.8)              # -> [MASK]
        replace_rand = chosen & (r >= 0.8) & (r < 0.9) # -> random token
        #remaining 10% keep original

        masked = input_ids.clone()
        masked[replace_mask] = self.tokenizer.mask_token_id

        if replace_rand.any():
            vocab_size = self.tokenizer.vocab_size
            rand_ids = torch.randint(low=0, high=vocab_size, size=masked.shape, dtype=torch.long)
            specials = set(self.tokenizer.all_special_ids)
            bad = torch.isin(rand_ids, torch.tensor(list(specials)))
            if bad.any():
                rand_ids[bad] = self.tokenizer.unk_token_id

            masked[replace_rand] = rand_ids[replace_rand]

        out = {
            "input_ids": masked,
            "attention_mask": attention_mask,
            "labels": labels
        }

        if "token_type_ids" in enc:
            out["token_type_ids"] = enc["token_type_ids"]
        return out

    def _choose_wwm(self, enc, candidate, p):
        """whole word masking via word_ids grouping"""
        B, L = candidate.shape
        chosen = torch.zeros_like(candidate, dtype=torch.bool)
        for i in range(B):
            word_ids = enc.word_ids(batch_index=i) #len L (ints or None)
            groups = {}
            for t, wid in enumerate(word_ids):
                if wid is None:
                    continue
                if not candidate[i, t]:
                    continue
                groups.setdefault(wid, []).append(t)

            group_list = list(groups.values())
            random.shuffle(group_list)
            target = int(p * candidate[i].sum().item())
            covered = 0
            for g in group_list:
                for idx in g:
                    chosen[i, idx] = True
                covered += len(g)
                if covered >= target:
                    break
        return chosen

    def _choose_spans(self, canditate, p, mean_span_len=3.0):
        """span masking using a geometric span length"""
        B, L = candidate.shape
        chosen = torch.zeros_like(candidate, dtype=torch.bool)
        q = 1.0 / (1.0 + mean_span_len) #geometric param

        for i in range(B):
            cand_idx = candidate[i].nonzero(as_tuple=False).flatten().tolist()
            random.shuffle(cand_idx)
            target = int(p * len(cand_idx))
            covered, used = 0, set()
            while covered < target and cand_idx:
                start = random.choice(cand_idx)
                if start in used:
                    cand_idx.remove(start)
                    continue
                Ls = 1
                while random.random() > q:
                    Ls += 1

                span = []
                t = start
                while len(span) < Ls and t < L and candidate[i, t] and (t not in used):
                    span.append(t)
                    t += 1

                for s in span:
                    chosen[i, s] = True
                    used.add(s)
                covered += len(span)
        return chosen
                    

# 4. Tokenizer and datasets and collator init

In [13]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)

train_ds = TextDataset(train_texts)
val_ds = TextDataset(val_texts)

collator = MLMCollatorManual(
    tokenizer=tokenizer,
    mlm_probability=0.15,
    max_length=64,
    whole_word_mask=True,
    span_mask=False,
    mean_span_len=3.0
)

# 5. Model (BertForMaskedLM)

In [14]:
model = BertForMaskedLM.from_pretrained("bert-base-uncased").to(DEVICE)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


# 6. Training Arguments and compute_metrics

In [16]:
args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    learning_rate=5e-5,
    weight_decay=0.01,
    optim="adamw_torch",     # requested
    warmup_ratio=0.1,        # requested
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="no",      # keep demo light; change to "epoch" to save checkpoints
    report_to=[],            # disable WandB/Comet by default
    seed=SEED,
    dataloader_pin_memory=False,
    remove_unused_columns=False,  # we pass raw strings; collator handles tokenization
)

def masked_token_accuracy(eval_pred):
    """
    Compute masked-token accuracy:
        predictions: [N, L, V]
        labels: [N, L] (non-masked=-100)
    """
    if hasattr(eval_pred, "predictions"):
        preds = eval_pred.predictions
        labels = eval_pred.label_ids
    else:
        preds, labels = eval_pred

    pred_ids = torch.from_numpy(preds).argmax(-1) if not torch.is_tensor(preds) else preds.argmax(-1)
    labels = torch.from_numpy(labels) if not torch.is_tensor(labels) else labels
    mask = labels.ne(-100)
    if mask.sum().item() == 0:
        return {"masked_acc": 0.0}
    correct = (pred_ids.eq(labels) & mask).sum().item()
    total = mask.sum().item()
    return {"masked_acc": correct / total}

# 7. Trainer/Train/Evaluate

In [17]:
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,   # manual dynamic masking (15% + 80/10/10)
    tokenizer=tokenizer,
    compute_metrics=masked_token_accuracy,
)

print("\n[Train] Starting Trainer.fit() …")
trainer.train()
print("[Eval] Final:", trainer.evaluate())

  trainer = Trainer(



[Train] Starting Trainer.fit() …


Epoch,Training Loss,Validation Loss,Masked Acc
1,No log,0.956133,0.855072
2,2.252300,0.30096,0.887324
3,2.252300,0.300889,0.940299


[Eval] Final: {'eval_loss': 0.29481780529022217, 'eval_masked_acc': 0.9661016949152542, 'eval_runtime': 0.1309, 'eval_samples_per_second': 381.916, 'eval_steps_per_second': 30.553, 'epoch': 3.0}
