# 1. Imports & setup

In [2]:
import random
from dataclasses import dataclass

import torch
from torch.utils.data import Dataset
from transformers import (
    AutoTokenizer,
    BertForNextSentencePrediction,
    BertForPreTraining,
    TrainingArguments,
    Trainer
)

SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
DEVICE = torch.device("mps")

# 2. Tiny documents & simple sentence splitter

In [3]:
DOCS = [
    "Transformers are powerful models. They use self attention. BERT is bidirectional. We can pretrain with MLM and NSP.",
    "Sentence pairs can test coherence. Random pairs are negatives. True next sentences are positives. Sampling matters.",
    "Pretraining builds contextual embeddings. Fine tuning adapts to tasks. Classification uses CLS. NER uses token labels.",
    "Tokenization splits text into wordpieces. WordPiece reduces OOV. Special tokens include CLS and SEP. Padding helps batching.",
]

def naive_split_sentence(doc):
    """split based on . and use strip filter empty string"""
    return [s.strip() for s in doc.split(".") if s.strip()]

docs_sents = [naive_split_sentence(d) for d in DOCS]

In [4]:
docs_sents

[['Transformers are powerful models',
  'They use self attention',
  'BERT is bidirectional',
  'We can pretrain with MLM and NSP'],
 ['Sentence pairs can test coherence',
  'Random pairs are negatives',
  'True next sentences are positives',
  'Sampling matters'],
 ['Pretraining builds contextual embeddings',
  'Fine tuning adapts to tasks',
  'Classification uses CLS',
  'NER uses token labels'],
 ['Tokenization splits text into wordpieces',
  'WordPiece reduces OOV',
  'Special tokens include CLS and SEP',
  'Padding helps batching']]

# 3. Build NSP pairs (50/50, optional hard negatives)

In [5]:
def build_nsp_pairs(docs, num_pairs=800, pos_ratio=0.5, use_hard_neg=False):
    """
    Return a list od dicts: {"text_a": str, "text_b":str, "label": int}
    label 1=IsNext 0=NotNext
    pos example: A next(A)
    neg example: A random_other_doc_sentence
    if use_hard_neg, then part of the neg examples using same doc not neighboring sentences: A not_neighboring_same_doc_sentence
    """
    rng = random.Random(SEED)
    all_pairs = []
    
    #flatten for easy sampling
    all_sents = [(d_i, i, s) for d_i, doc in enumerate(docs) for i, s in enumerate(doc)]

    while len(all_pairs) < num_pairs:
        want_pos = rng.random() < pos_ratio

        #sample a doc with at least 2 sentences for positives
        d_i = rng.randrange(len(docs))
        doc = docs[d_i]
        if len(doc) < 2:
            continue

        if want_pos:
            #positive example pick i whre next exists
            i = rng.randrange(len(doc) - 1)
            a, b = doc[i], doc[i + 1]
            all_pairs.append({"text_a": a, "text_b": b, "label": 1})
        else:
            #negative example
            a_idx = rng.randrange(len(doc))
            a = doc[a_idx]
            if use_hard_neg and len(doc) > 2 and rng.random() < 0.5:
                #same doc, but not adgancet
                candidates = [j for j in range(len(doc)) if j not in {a_idx-1, a_idx, a_idx+1} and j >= 0 and j < len(doc)]
                if not candidates:
                    continue
                j = rng.choice(candidates)
                b = doc[j]
            else:
                #different doc
                other_doc_idx = rng.randrange(len(docs))
                while other_doc_idx == d_i:
                    other_doc_idx = rng.randrange(len(docs))
                other_doc = docs[other_doc_idx]
                b = rng.choice(other_doc)
            all_pairs.append({"text_a": a, "text_b": b, "label": 0})

    rng.shuffle(all_pairs)
    return all_pairs

In [6]:
# small splits for demo
train_pairs = build_nsp_pairs(docs_sents, num_pairs=800, pos_ratio=0.5, use_hard_neg=True)
val_pairs   = build_nsp_pairs(docs_sents, num_pairs=200, pos_ratio=0.5, use_hard_neg=True)
print(f"[Build] train_pairs={len(train_pairs)}, val_pairs={len(val_pairs)}")

[Build] train_pairs=800, val_pairs=200


# 4. Datasets yielding raw sentence pairs

In [18]:
class NSPPairDataset(Dataset):
    """
    Each item is a dict with raw strings
    Collator will tokenize and build tensors
    """
    def __init__(self, pairs):
        self.pairs = pairs
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        return self.pairs[idx]

train_ds = NSPPairDataset(train_pairs)
val_ds   = NSPPairDataset(val_pairs)

# 5. Collator A: NSP-only (no MLM)

In [19]:
@dataclass
class NSPOnlyCollator:
    """
    Tokenize sentence pairs for NSP-only training
    Output keys:
        - input_ids       [B, L]
        - attention_mask  [B, L]
        - token_type_ids  [B, L]
        - next_sentence_label [B] (0/1)
    """
    tokenizer: any
    max_length: int = 128

    def __call__(self, batch):
        a_list = [x["text_a"] for x in batch]
        b_list = [x["text_b"] for x in batch]
        labels = torch.tensor([x["label"] for x in batch], dtype=torch.long)
        enc = self.tokenizer(
            a_list, b_list,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        )
        enc["next_sentence_label"] = labels
        return enc

# 6. Train NSP-only with BertForNextSentencePredition

In [20]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", use_fast=True)
collator_nsp = NSPOnlyCollator(tokenizer=tokenizer, max_length=64)

model_nsp = BertForNextSentencePrediction.from_pretrained("bert-base-uncased").to(DEVICE)

args_nsp = TrainingArguments(
    output_dir="out_nsp",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    learning_rate=5e-5,
    weight_decay=0.01,
    optim="adamw_torch",
    warmup_ratio=0.1,
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="no",
    report_to=[],
    seed=SEED,
    dataloader_pin_memory=False,
    remove_unused_columns=False,
    label_names=["next_sentence_label"]
)

def nsp_accuracy(eval_pred):
    """
    predictions: [N, 2]
    labels_ids: [N]
    """
    if hasattr(eval_pred, "predictions"):
        preds = eval_pred.predictions
        labels = eval_pred.label_ids
    else:
        preds, labels = eval_pred
    if isinstance(preds, (tuple, list)):
        preds = preds[0]
    pred = torch.from_numpy(preds).argmax(-1)
    labels = torch.from_numpy(labels)
    acc = (pred == labels).float().mean().item()
    return {"nsp_acc": acc}

trainer_nsp = Trainer(
    model=model_nsp,
    args=args_nsp,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator_nsp,
    tokenizer=tokenizer,
    compute_metrics=nsp_accuracy,
)

print("\n[Train-NSP] Starting …")
trainer_nsp.train()
print("[Eval-NSP] Final:", trainer_nsp.evaluate())

  trainer_nsp = Trainer(



[Train-NSP] Starting …


Epoch,Training Loss,Validation Loss,Nsp Acc
1,0.931,0.019237,0.995
2,0.041,0.000181,1.0


[Eval-NSP] Final: {'eval_loss': 0.0001811510737752542, 'eval_nsp_acc': 1.0, 'eval_runtime': 0.3052, 'eval_samples_per_second': 655.267, 'eval_steps_per_second': 42.592, 'epoch': 2.0}


# 7. Collator B: MLM + NSP (dynamic MLM 15% + 80/10/10)

In [30]:
@dataclass
class PretrainCollator:
    """
    Build inputs for MLM + NSP:
        - Tokenize pair (A, B) with BERT template
        - Dynamic MLM (15% + 80/10/10) to produce "labels" (MLM) and masked input_ids
        - Provide "next_sentence_label" (0/1) for NSP
    Outputs:
        inputs_ids [B, L] (masked)
        attention_mask [B, L]
        token_type_ids [B, L]
        labels [B, L] (MLM: -100 for non-masked)
        next_sentence_label [B]
    """
    tokenizer: any
    mlm_probability: float = 0.15
    max_length: int = 128

    def __call__(self, batch):
        a_list = [x["text_a"] for x in batch]
        b_list = [x["text_b"] for x in batch]
        nsp = torch.tensor([x["label"] for x in batch], dtype=torch.long)

        enc = self.tokenizer(
            a_list,
            b_list,
            padding="longest",
            truncation=True,
            max_length=self.max_length,
            return_special_tokens_mask=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = enc["input_ids"]                     #[B, L]
        attention_mask = enc["attention_mask"]           #[B, L]
        special_mask = enc["special_tokens_mask"].bool()  #[B, L]

        labels = torch.full_like(input_ids, -100)
        candidate = (~special_mask) & attention_mask.bool()

        #Use Bernuolli (can change to WWM/Span)
        probs = torch.full_like(input_ids, self.mlm_probability, dtype=torch.float32)
        chosen = (torch.bernoulli(probs).bool()) & candidate

        labels[chosen] = input_ids[chosen]

        r = torch.rand_like(input_ids, dtype=torch.float32)
        replace_mask = chosen & (r < 0.8)
        replace_rand = chosen & (r >= 0.8) & (r < 0.9)

        masked = input_ids.clone()
        masked[replace_mask] = self.tokenizer.mask_token_id

        if replace_rand.any():
            vocab_size = self.tokenizer.vocab_size
            rand_ids = torch.randint(low=0, high=vocab_size, size=masked.shape, dtype=torch.long)
            specials = set(self.tokenizer.all_special_ids)
            bad = torch.isin(rand_ids, torch.tensor(list(specials)))
            if bad.any():
                rand_ids[bad] = self.tokenizer.unk_token_id
            masked[replace_rand] = rand_ids[replace_rand]

        enc_out = {
            "input_ids": masked,
            "attention_mask": attention_mask,
            "labels": labels,
            "next_sentence_label": nsp
        }

        if "token_type_ids" in enc:
            enc_out["token_type_ids"] = enc["token_type_ids"]
        return enc_out

# 8. Train MLM + NSP with BertForPreTraining

In [31]:
collator_pre = PretrainCollator(tokenizer=tokenizer, mlm_probability=0.15, max_length=64)
model_pre = BertForPreTraining.from_pretrained("bert-base-uncased").to(DEVICE)

args_pre = TrainingArguments(
    output_dir="out_pretrain",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,            # demo
    learning_rate=5e-5,
    weight_decay=0.01,
    optim="adamw_torch",
    warmup_ratio=0.1,
    eval_strategy="epoch",
    logging_strategy="steps",
    logging_steps=50,
    save_strategy="no",
    report_to=[],
    seed=SEED,
    dataloader_pin_memory=False,
    remove_unused_columns=False,
    # return both labels
    label_names=["labels", "next_sentence_label"],
)


def metrics_pretrain(eval_pred):
    """
    For BertForPreTraining:
      predictions is a tuple: (prediction_scores [N,L,V], seq_relationship_logits [N,2])
      label_ids corresponds to label_names: (mlm_labels [N,L], nsp_labels [N])
    We'll compute:
      - masked-token accuracy (on MLM)
      - NSP accuracy
    """
    preds, labels = eval_pred
    # preds
    if isinstance(preds, (tuple, list)) and len(preds) >= 2:
        mlm_logits, nsp_logits = preds[0], preds[1]
    else:
        return {"masked_acc": 0.0, "nsp_acc": 0.0}

    # labels
    if isinstance(labels, (tuple, list)) and len(labels) >= 2:
        mlm_labels, nsp_labels = labels[0], labels[1]
    else:
        mlm_labels, nsp_labels = labels, None

    # MLM masked-token accuracy
    mlm_pred = mlm_logits.argmax(-1)               # [N,L]
    mask = (mlm_labels != -100)
    mt_correct = ((mlm_pred == mlm_labels) & mask).sum()
    mt_total = mask.sum().clip(min=1)
    masked_acc = (mt_correct / mt_total).item()

    # NSP accuracy
    if nsp_labels is not None:
        nsp_pred = nsp_logits.argmax(-1)
        nsp_acc = (nsp_pred == nsp_labels).mean().item()
    else:
        nsp_acc = 0.0

    return {"masked_acc": masked_acc, "nsp_acc": nsp_acc}

trainer_pre = Trainer(
    model=model_pre,
    args=args_pre,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator_pre,
    tokenizer=tokenizer,
    compute_metrics=metrics_pretrain,
)

print("\n[Train-Pretraining] Starting …")
trainer_pre.train()
print("[Eval-Pretraining] Final:", trainer_pre.evaluate())

  trainer_pre = Trainer(



[Train-Pretraining] Starting …


Epoch,Training Loss,Validation Loss,Masked Acc,Nsp Acc
1,3.2938,0.710997,0.908847,0.965
2,0.5837,0.355018,0.948229,0.99


[Eval-Pretraining] Final: {'eval_loss': 0.2849549651145935, 'eval_masked_acc': 0.9342465753424658, 'eval_nsp_acc': 0.99, 'eval_runtime': 3.3367, 'eval_samples_per_second': 59.939, 'eval_steps_per_second': 3.896, 'epoch': 2.0}
