# 使用 PyTorch 微调 BERT

从 Hugging Face Hub 上加载预训练的 BERT 模型，然后使用 PyTorch 纯手工对其进行微调，设定如下：
- 预训练模型：bert-base-uncased
- 下游任务：GLUE/SST-2

使用 PyTorch 微调 BERT 需要以下步骤：
- 数据预处理：加载数据集并定义 `Dataset` 和 `DataLoader`
- 模型定义：给 BERT 基础模型添加一个全连接层作为分类头
- 模型微调：使用 AdamW 优化器对模型进行微调
- 模型验证：计算模型在验证集上的准确率

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix, classification_report
from transformers import (
    BertTokenizerFast,
    BertModel,
    DataCollatorWithPadding,
    set_seed
)
from datasets import load_dataset
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime
from typing import Callable
from tqdm.notebook import tqdm

一些超参数

In [2]:
batch_size = 64
epochs = 1
learning_rate = 5e-5
device = "cuda:1"
set_seed(42)

## 加载数据集

使用 `datasets.load_dataset` 从 Hugging Face Hub 上加载 GLUE/SST-2 任务的数据集

In [3]:
raw_datasets = load_dataset("glue", "sst2")
raw_datasets

Using the latest cached version of the module from /home/wh/.cache/huggingface/modules/datasets_modules/datasets/glue/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad (last modified on Thu Mar 31 13:53:16 2022) since it couldn't be found locally at glue.
Reusing dataset glue (/home/wh/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

对原始数据集分词

In [4]:
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased", use_fast=True)

def preprocessing(examples):
    """用于分词的预处理程序"""
    return tokenizer(examples["sentence"], padding="max_length", max_length=60, truncation=True)

tokenized_datasets = raw_datasets.map(preprocessing, batched=True)
tokenized_datasets

Loading cached processed dataset at /home/wh/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-57abd33ab18d3d8f.arrow
Loading cached processed dataset at /home/wh/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-f1701a10d10b29a1.arrow
Loading cached processed dataset at /home/wh/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad/cache-db86549a2685890a.arrow


DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1821
    })
})

从数据集中取出训练集和验证集，并移除训练过程中不需要的 `sentence` 和 `idx` 字段

In [5]:
train_dataset = tokenized_datasets["train"].remove_columns(["sentence", "idx"])
eval_dataset = tokenized_datasets["validation"].remove_columns(["sentence", "idx"])
train_dataset

Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 67349
})

定义 `DataLoader`

In [6]:
data_collator = DataCollatorWithPadding(tokenizer)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True, collate_fn=data_collator)

In [7]:
batch = next(iter(train_dataloader))
batch.keys()

dict_keys(['input_ids', 'token_type_ids', 'attention_mask', 'labels'])

## 模型定义

In [8]:
class BertForSST2(nn.Module):
    def __init__(self, model_name: str, dropout: float=0.5, use_pooled_output: bool=True):
        super(BertForSST2, self).__init__()
        self.use_pooled_output = use_pooled_output
        self.bert = BertModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Sequential(
            nn.Linear(768, 2),
            nn.Softmax(dim=1)
        )


    def forward(self, input_ids, attention_mask, token_type_ids):
        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )
        if self.use_pooled_output:
            cls_representation = outputs[1]
        else:
            cls_representation = outputs[0][:, 0]
        return self.classifier(self.dropout(cls_representation))

In [9]:
model = BertForSST2("bert-base-uncased", use_pooled_output=True)
no_pooled_model = BertForSST2("bert-base-uncased", use_pooled_output=False)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.tr

## 模型微调

In [10]:
def test_loop(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: Callable,
    device: str,
):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    model.eval()
    with torch.no_grad():
        for X in dataloader:
            X.to(device)
            pred = model(X.input_ids, X.attention_mask, X.token_type_ids)
            test_loss += loss_fn(pred, X.labels).item()
            correct += (pred.argmax(1) == X.labels).type(torch.float).sum().item()

    test_loss /= num_batches
    test_acc = 100 * (correct / size)
    return test_loss, test_acc


def train_loop(
    model: nn.Module, 
    train_dataloader: DataLoader,
    eval_dataloader: DataLoader,
    loss_fn: Callable,
    optimizer,
    lr_scheduler,
    device: str,
    writer: SummaryWriter,
    epoch: int,
):
    size = len(train_dataloader.dataset)
    num_batches = len(train_dataloader)
    loop = tqdm(enumerate(train_dataloader), total =len(train_dataloader))
    loop.set_description(f'Epoch [{epoch}/{epochs}]')
    model.train()
    for batch, X in loop:
        X.to(device)
        # 前向传播并计算loss
        pred = model(X.input_ids, X.attention_mask, X.token_type_ids)
        loss = loss_fn(pred, X.labels)
        # 反向传播并优化模型参数
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # 让进度条显示 acc 和 loss
        acc = 100 * (pred.argmax(1) == X.labels).type(torch.float).sum().item() / X.input_ids.size()[0]
        loop.set_postfix(loss=loss.item(), acc=acc, lr=optimizer.param_groups[0]["lr"])
        # 写入 TensorBoard
        if batch % 50 == 0:
            global_step = epoch * num_batches + batch
            test_loss, test_acc = test_loop(model, eval_dataloader, loss_fn, device)
            writer.add_scalar("Loss/train", loss.item(), global_step)
            writer.add_scalar("Acc/train", acc, global_step)
            writer.add_scalar("Loss/test", test_loss, global_step)
            writer.add_scalar("Acc/test", test_acc, global_step)
            writer.add_scalar("Learning rate", optimizer.param_groups[0]["lr"], global_step)
            lr_scheduler.step()

In [11]:
now = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
writer = SummaryWriter(f"./logs/base-base-uncased-sst2-{now}")
no_pooled_writer = SummaryWriter(f"./logs/base-base-uncased-sst2-no-pooled-{now}")
writer.add_graph(model, [batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]])
no_pooled_writer.add_graph(no_pooled_model, [batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]])

In [12]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=.0)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

model.to(device)
for t in range(epochs):
    train_loop(model, train_dataloader, eval_dataloader, loss_fn, optimizer, lr_scheduler, device, writer, t)
    test_loss, test_acc = test_loop(model, eval_dataloader, loss_fn, device)
    print(f"Acc={test_acc:.4f} Loss={test_loss:.4f}")
print("Done!")

  0%|          | 0/1053 [00:00<?, ?it/s]

Acc=91.1697 Loss=0.3959
Done!


In [13]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(no_pooled_model.parameters(), lr=learning_rate, weight_decay=.0)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9)

no_pooled_model.to(device)
for t in range(epochs):
    train_loop(no_pooled_model, train_dataloader, eval_dataloader, loss_fn, optimizer, lr_scheduler, device, no_pooled_writer, t)
    test_loss, test_acc = test_loop(no_pooled_model, eval_dataloader, loss_fn, device)
    print(f"Acc={test_acc:.4f} Loss={test_loss:.4f}")
print("Done!")

  0%|          | 0/1053 [00:00<?, ?it/s]

Acc=92.3165 Loss=0.3875
Done!


## 模型验证

In [17]:
def predict(
    model: nn.Module,
    dataloader: DataLoader,
    device: str,
):
    y_true = []
    y_pred = []
    model.eval()
    with torch.no_grad():
        for X in dataloader:
            X.to(device)
            pred = model(X.input_ids, X.attention_mask, X.token_type_ids)
            y_true.extend(X.labels.tolist())
            y_pred.extend(pred.argmax(1).tolist())
    return y_true, y_pred

In [18]:
y_true, y_pred = predict(model, eval_dataloader, device)
print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, labels=[0, 1], target_names=["negative", "positive"]))

[[378  50]
 [ 27 417]]
              precision    recall  f1-score   support

    negative       0.93      0.88      0.91       428
    positive       0.89      0.94      0.92       444

    accuracy                           0.91       872
   macro avg       0.91      0.91      0.91       872
weighted avg       0.91      0.91      0.91       872



In [19]:
y_true, y_pred = predict(no_pooled_model, eval_dataloader, device)
print(confusion_matrix(y_true, y_pred))
print(classification_report(y_true, y_pred, labels=[0, 1], target_names=["negative", "positive"]))

[[398  30]
 [ 37 407]]
              precision    recall  f1-score   support

    negative       0.91      0.93      0.92       428
    positive       0.93      0.92      0.92       444

    accuracy                           0.92       872
   macro avg       0.92      0.92      0.92       872
weighted avg       0.92      0.92      0.92       872

