In [4]:
from transformers import (
    BertConfig,
    BertForMaskedLM,
    PreTrainedTokenizerFast,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)
from datasets import load_dataset, Dataset, DatasetDict
import torch

# === 加载模型和数据 ===
tokenizer = PreTrainedTokenizerFast.from_pretrained(
    "dataset/tokenizer/custom_tokenizer"
)
tokenized_dataset = load_from_disk("dataset/huggingface/custom_tokenized")
config = BertConfig(
    vocab_size=tokenizer.vocab_size,
    hidden_size=384,
    num_hidden_layers=12,
    num_attention_heads=8,
    max_position_embeddings=64,
    type_vocab_size=1,
)
model = BertForMaskedLM(config)

In [5]:
# 拆分验证集和测试集
split_dataset = tokenized_dataset.train_test_split(test_size=0.2, seed=42)
valid_test_split = split_dataset["test"].train_test_split(test_size=0.5, seed=42)
tokenized_dataset = DatasetDict(
    {
        "train": split_dataset["train"],
        "valid": valid_test_split["train"],
        "test": valid_test_split["test"],
    }
)

# === 构造小 batch（例如 4 个样本）===
batch = tokenized_dataset["train"].select(range(4))
input_ids = torch.tensor(batch["input_ids"])
attention_mask = torch.tensor(batch["attention_mask"])

In [6]:
# === 构造 MLM 标签 ===
labels = input_ids.clone()
rand = torch.rand(input_ids.shape)
mask_arr = (
    (rand < 0.15)
    & (input_ids != tokenizer.pad_token_id)
    & (input_ids != tokenizer.cls_token_id)
    & (input_ids != tokenizer.sep_token_id)
)
labels[~mask_arr] = -100



In [7]:
import os 
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# === 送入模型设备（GPU）===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
input_ids = input_ids.to(device)
attention_mask = attention_mask.to(device)
labels = labels.to(device)

# === 前向传播 + 反向传播 ===
model.train()
outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
loss = outputs.loss
loss.backward()

# === 计算梯度范数（用于判断是否爆炸）===
total_norm = 0.0
for p in model.parameters():
    if p.grad is not None:
        param_norm = p.grad.data.norm(2)
        total_norm += param_norm.item() ** 2
total_norm = total_norm**0.5

print(f"Loss: {loss.item():.4f}")
print(f"Grad Norm: {total_norm:.2f}")

Loss: 12.1316
Grad Norm: 11.24
