# Fine-Tuning with Transformers



- **Transformers fine-tuning steps**

```mermaid
flowchart LR
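    subgraph step1["1. Data Processing"]
        direction LR
        download["Download dataset"] --> preprocess["Preprocess data"] --> sampling["Sample data"]
    end
    subgraph step2["2. Fine-Tuning Configuration"]
        direction LR
        hyperparams["Set training hyperparameters"] --> metrics["Set evaluation metrics"] --> trainerCfg["Configure the Trainer"]
    end
    subgraph step3["3. Model Training"]
        direction LR
        startTraining["Start training"]
    end
    subgraph step4["4. Saving the Model"]
        direction LR
        saveModel["Save model"] --> saveState["Save training state"]
    end

    startNode((Start)):::startClass --> step1 --> step2 --> step3 --> step4 --> endNode((End)):::endClass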
    classDef startClass fill:#4caf50;
    classDef endClass fill:#f44336;
```

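## 1. Data Processing

### 1.1 Download the Dataset

YelpReviewFull is a widely used sentiment analysis dataset made up of Yelp reviews. It is extracted from the Yelp Dataset Challenge 2015 data and is mainly used for text classification: given the review text, predict its star rating (1 to 5 stars, stored as labels 0 to 4). The reviews are written mostly in English, which makes the dataset suitable for sentiment classification research.

- Dataset homepage: [https://huggingface.co/datasets/Yelp/yelp_review_full](https://huggingface.co/datasets/Yelp/yelp_review_full)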



In [7]:
from datasets import load_dataset

dataset = load_dataset("Yelp/yelp_review_full")

In [8]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [7]:
dataset["train"][0]

{'label': 4,
 'text': "dr. goldberg offers everything i look for in a general practitioner.  he's nice and easy to talk to without being patronizing; he's always on time in seeing his patients; he's affiliated with a top-notch hospital (nyu) which my parents have explained to me is very important in case something happens and you need surgery; and you can get referrals to see specialists without having to see him first.  really, what more do you need?  i'm sitting here trying to think of any complaints i have about him, but i'm really drawing a blank."}

- Random sampling of the data

In [9]:
import pandas as pd
import random
import datasets
from IPython.display import display, HTML

def show_random_elements(dataset, num_examples = 10):
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)
    df = pd.DataFrame(dataset[picks])
    for col, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[col] = df[col].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))


In [10]:
show_random_elements(dataset["train"], 2)

Unnamed: 0,label,text
0,4 stars,Yummy food and nice atmosphere with wonderful service.
1,1 star,"Went there last night in hopes of finding some decent Mexican Food. Unfortunately, this was not the case. First of all, the service was lacking. It took us asking twice to just get a container of salt to the table. Also, the waitress was less than interested in helping us with anything more than taking our order. \n\nSecond, the food was atrocious. You know you have an issue if you need to add salt to Mexican food. The rice was dry (almost hard and old) and completely tasteless. The beans were bland and tasted like paste. There was barely any cheese in the enchilada I had. My husband's tacos wee haphazardly put together, falling over and impossible to pick up, and needless to say did not taste good. \n\nThe only saving grace was the shredded beef burrito and the Tomatillo-Avocado salsa. No, I will not be going back, even with the Yelp offer."
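
`show_random_elements` draws indices one at a time and retries on duplicates, so it never terminates if `num_examples` exceeds the dataset size. As a minimal alternative sketch, `random.sample` from the standard library draws distinct indices in a single call. The helper name `pick_unique_indices` is hypothetical and not part of the original notebook.

```python
import random


def pick_unique_indices(dataset_size, num_examples, seed=None):
    """Return `num_examples` distinct row indices in [0, dataset_size)."""
    if num_examples > dataset_size:
        raise ValueError("Cannot pick more examples than the dataset contains.")
    rng = random.Random(seed)  # private generator; the global random state is untouched
    return rng.sample(range(dataset_size), num_examples)


# 650,000 is the size of the training split shown above.
picks = pick_unique_indices(650_000, 10, seed=0)
print(picks)
```

The resulting list could be passed to `dataset[picks]` in the same way as the `picks` list built inside `show_random_elements`.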


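### 1.2 Data Preprocessing
Once the dataset has been downloaded, use a tokenizer to convert the text into model inputs. Because the reviews vary in length, padding and truncation are used to bring every input to the same length.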


In [11]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [12]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 50000
    })
})

In [15]:
show_random_elements(tokenized_datasets["train"], 2)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,3 stars,"I wasn't too impressed...I thought I was at the wrong place after going to their website! I guess I wasn't expecting a bar, and was expecting a more classier place.\n\nThat aside, the food was okay. I had the Caprese salad which was good and the Asparagus was good. Nothing that wow'ed me. My friend had some kind of pasta...he didn't like the pasta. It was okay.\n\nCheesecake was good though!! :) Overall, eh...not my type of thing to dine on but I gave it a try","[101, 146, 1445, 112, 189, 1315, 7351, 119, 119, 119, 146, 1354, 146, 1108, 1120, 1103, 2488, 1282, 1170, 1280, 1106, 1147, 3265, 106, 146, 3319, 146, 1445, 112, 189, 7805, 170, 2927, 117, 1105, 1108, 7805, 170, 1167, 1705, 2852, 1282, 119, 165, 183, 165, 183, 1942, 11220, 4783, 117, 1103, 2094, 1108, 3008, 119, 146, 1125, 1103, 17212, 4894, 1162, 19359, 1134, 1108, 1363, 1105, 1103, 1249, 17482, 28026, 1116, 1108, 1363, 119, 4302, 1115, 192, 4064, 112, 5048, 1143, 119, 1422, 1910, 1125, 1199, 1912, 1104, 1763, 1161, 119, 119, 119, 1119, 1238, 112, 189, 1176, 1103, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
1,2 star,"1 star for being accommodating and seating us during peak time without a reservation. 1 star for the decent sangaria (I asked for white, but they brought a red one...oh well, it was not bad). The staff was running around the place like crazy. We started with the goat cheese salad, but the cheese was Brie (although the owner tried to assure us it was goat...no way!). Main course was file mignon....I know my steaks and that was at best a sirloin (bad quality sirloin). Well cooked but very chewy and not edible! The other course was an overlooked sole, tasted fresh, but very overlooked. Surprised by the 4-star Yelp rating. Disappointed this was our last dinner in MTL. Had a great dinning experience otherwise at L'express and Bonaparte.","[101, 122, 2851, 1111, 1217, 170, 14566, 6262, 16848, 1916, 1105, 11051, 1366, 1219, 4709, 1159, 1443, 170, 15702, 119, 122, 2851, 1111, 1103, 11858, 6407, 11315, 113, 146, 1455, 1111, 1653, 117, 1133, 1152, 1814, 170, 1894, 1141, 119, 119, 119, 9294, 1218, 117, 1122, 1108, 1136, 2213, 114, 119, 1109, 2546, 1108, 1919, 1213, 1103, 1282, 1176, 4523, 119, 1284, 1408, 1114, 1103, 17497, 9553, 19359, 117, 1133, 1103, 9553, 1108, 139, 5997, 113, 1780, 1103, 3172, 1793, 1106, 14955, 1366, 1122, 1108, 17497, 119, 119, 119, 1185, 1236, 106, 114, 119, 4304, 1736, 1108, 4956, 1940, 25566, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
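
To make the effect of `padding="max_length", truncation=True` concrete, here is a simplified pure-Python sketch of what happens to one sequence of token ids. It is not the tokenizer's actual implementation: `encode_fixed_length` is a hypothetical helper, and it assumes the BERT special-token ids `[CLS]` = 101, `[SEP]` = 102 and `[PAD]` = 0.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed special-token ids of the BERT vocabulary


def encode_fixed_length(content_ids, max_length):
    """Return (input_ids, attention_mask), each exactly `max_length` long."""
    if max_length < 2:
        raise ValueError("max_length must leave room for [CLS] and [SEP].")
    body = content_ids[: max_length - 2]          # truncation: keep room for the two special tokens
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)         # 1 marks a real token
    num_padding = max_length - len(input_ids)
    input_ids += [PAD_ID] * num_padding           # padding: fill up to max_length
    attention_mask += [0] * num_padding           # 0 marks a padding position
    return input_ids, attention_mask


# A short input is padded, a long input is truncated.
short_ids, short_mask = encode_fixed_length([146, 1445], max_length=6)
long_ids, long_mask = encode_fixed_length([146, 1445, 112, 189, 1315, 7351], max_length=6)
print(short_ids, short_mask)
print(long_ids, long_mask)
```

Both results have the same length, which is what allows the examples to be stacked into a batch tensor during training.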


### 1.3 Data Sampling
To keep training short and easy to control, we can draw a small sample from the dataset for training and evaluation.



In [16]:
# Sample 1,000 training examples from the dataset
small_train_dataset = tokenized_datasets["train"].shuffle(seed = 42).select(range(1000))

# Sample 1,000 test examples from the dataset
small_eval_dataset = tokenized_datasets["test"].shuffle(seed = 42).select(range(1000))
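
`shuffle(seed = 42).select(range(1000))` first permutes the row order with a fixed seed and then keeps the first 1,000 rows, so the same subset comes back on every run. The following is a minimal pure-Python sketch of that idea. `sample_subset` is a hypothetical helper, and it will not reproduce the exact rows that `datasets` selects, because the library uses its own random generator.

```python
import random


def sample_subset(rows, num_samples, seed):
    """Shuffle row indices with a fixed seed, then keep the first `num_samples` rows."""
    indices = list(range(len(rows)))
    random.Random(seed).shuffle(indices)  # seeded, so the permutation is reproducible
    return [rows[i] for i in indices[:num_samples]]


# Placeholder rows standing in for tokenized examples.
rows = [f"review-{i}" for i in range(10_000)]
subset_a = sample_subset(rows, 1000, seed=42)
subset_b = sample_subset(rows, 1000, seed=42)
print(subset_a == subset_b)
```

Fixing the seed is what makes separate runs comparable: a change in the final metric can then be attributed to the configuration rather than to a different sample.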

## 2. Fine-Tuning Configuration

### 2.1 Training Hyperparameters

- Full list of parameters and their default values: [https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments)

- Source code definition: [https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161](https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161)


In [17]:
from transformers import TrainingArguments

# Directory where the model is saved
output_dir = "models/bert-base-cased"

training_args = TrainingArguments(
    output_dir = output_dir,
    logging_steps = 100  # log every 100 steps
)


In [18]:
print(training_args)

TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=no,
eval_use_gather_object=F

### 2.2 Evaluation Metrics (Evaluate)

- Full list of available metrics: [https://huggingface.co/evaluate-metric](https://huggingface.co/evaluate-metric)



In [19]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
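
What `compute_metrics` calculates can be reproduced by hand: take the index of the largest logit in each row as the prediction, then measure the fraction of predictions that match the labels. The sketch below does this in plain Python. The helper names and the logits are made up for illustration, and the returned dictionary only mirrors the `{"accuracy": ...}` shape of the metric's result.

```python
def argmax(values):
    """Index of the largest value in a list."""
    return max(range(len(values)), key=values.__getitem__)


def accuracy(logits, labels):
    """Fraction of rows whose highest-scoring class equals the reference label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(pred == label for pred, label in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


# Hypothetical logits for four reviews over the five star classes.
logits = [
    [0.1, 0.2, 0.3, 2.5, 0.4],  # predicted class 3
    [1.9, 0.3, 0.2, 0.1, 0.0],  # predicted class 0
    [0.0, 0.1, 3.0, 0.2, 0.1],  # predicted class 2
    [0.2, 0.1, 0.0, 0.3, 1.2],  # predicted class 4
]
labels = [3, 0, 1, 4]  # the third prediction is wrong

result = accuracy(logits, labels)
print(result)  # three of four predictions match, so 0.75
```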

### 2.3 Trainer Configuration

- Model homepage: [https://huggingface.co/google-bert/bert-base-cased](https://huggingface.co/google-bert/bert-base-cased)

In [24]:
from transformers import AutoModelForSequenceClassification, Trainer

# Load the model
model = AutoModelForSequenceClassification.from_pretrained("google-bert/bert-base-cased", num_labels=5)

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = small_train_dataset, # training dataset
    eval_dataset = small_eval_dataset, # evaluation dataset
    compute_metrics = compute_metrics # function that computes the metrics
) 

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at google-bert/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
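
With `num_labels=5` the model gets a new classification head that outputs five logits per review, one for each star class. That head is the part the warning reports as newly initialized. As a rough sketch of how such logits could be turned into a readable prediction, the code below applies a softmax and looks up the class name. The label names follow the outputs shown earlier, with `"5 stars"` assumed for label 4; `softmax` and `predict_label` are hypothetical helpers, and the logits are made up.

```python
import math

# Class names as they appear in the sample outputs above; "5 stars" is assumed.
LABEL_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]


def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    shift = max(logits)  # subtracting the maximum avoids overflow in exp
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict_label(logits):
    """Return the most likely class name and its probability."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    return LABEL_NAMES[best], probs[best]


name, prob = predict_label([-1.2, -0.5, 0.3, 1.1, 2.4])  # hypothetical logits
print(name, round(prob, 3))
```

Before fine-tuning, the head's weights are random, so predictions of this kind carry no information until `trainer.train()` has run.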


## 3. Model Training


### 3.1 Start Training

Call the `train` method of the `Trainer` class to start training:

In [25]:
trainer.train()

Step,Training Loss
100,1.4366
200,1.0379
300,0.7866


TrainOutput(global_step=375, training_loss=0.9836104227701823, metrics={'train_runtime': 163.7982, 'train_samples_per_second': 18.315, 'train_steps_per_second': 2.289, 'total_flos': 789354427392000.0, 'train_loss': 0.9836104227701823, 'epoch': 3.0})
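
The numbers in `TrainOutput` can be checked with a little arithmetic. The sketch below assumes that `per_device_train_batch_size` (8) and `num_train_epochs` (3) were left at their `TrainingArguments` defaults, since neither was set above, and that training ran on a single GPU, as `_n_gpu=1` indicates.

```python
import math

num_samples = 1000     # size of small_train_dataset
batch_size = 8         # assumed default per_device_train_batch_size, single GPU
num_epochs = 3         # assumed default num_train_epochs, matches 'epoch': 3.0
logging_steps = 100    # set in TrainingArguments above
train_runtime = 163.7982  # seconds, from TrainOutput

steps_per_epoch = math.ceil(num_samples / batch_size)  # 125 optimizer steps per epoch
total_steps = steps_per_epoch * num_epochs             # 375, matches global_step

# Steps at which a training loss is logged: every 100 steps up to the last step.
logged_steps = list(range(logging_steps, total_steps + 1, logging_steps))

steps_per_second = total_steps / train_runtime
samples_per_second = num_samples * num_epochs / train_runtime

print(total_steps, logged_steps)
print(round(steps_per_second, 3), round(samples_per_second, 3))
```

This also explains why the loss table has only three rows: with 375 steps in total and a log every 100 steps, losses are reported at steps 100, 200 and 300.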

**Monitor GPU usage with nvidia-smi**

```shell
watch -n 1 nvidia-smi
```


```
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 560.35.02              Driver Version: 560.94         CUDA Version: 12.6     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 4060 ...    On  |   00000000:01:00.0  On |                  N/A |
| N/A   69C    P0             74W /   78W |    4422MiB /   8188MiB |    100%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A      3017      C   /python3.10                                 N/A      |
+-----------------------------------------------------------------------------------------+

```

## 4. Saving the Model

### 4.1 Save the Model



In [28]:
# By default the model is saved to output_dir = "models/bert-base-cased"
trainer.save_model()

### 4.2 Save the Training State

In [30]:
# Writes trainer_state.json to output_dir
trainer.save_state()
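
The saved state is a JSON file, so it can be inspected without loading the model. The sketch below assumes that the file is named `trainer_state.json` and that it holds a `log_history` list whose logging entries carry `step` and `loss` keys. `read_loss_history` is a hypothetical helper, not part of the Transformers API.

```python
import json
from pathlib import Path


def read_loss_history(output_dir):
    """Return (step, loss) pairs from the trainer_state.json in `output_dir`."""
    state_path = Path(output_dir) / "trainer_state.json"
    with state_path.open(encoding="utf-8") as f:
        state = json.load(f)
    # Keep only entries that report a training loss; the final summary
    # entry is assumed to lack the "loss" key and is skipped.
    return [
        (entry["step"], entry["loss"])
        for entry in state.get("log_history", [])
        if "step" in entry and "loss" in entry
    ]
```

After the run above, `read_loss_history("models/bert-base-cased")` would be one way to recover the logged training losses for plotting or comparison.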

## References
- [https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer](https://huggingface.co/docs/transformers/training#train-with-pytorch-trainer)
