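## Main workflow for fine-tuning a model with transformers, covering:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics
- An introduction to the Trainer
- Hands-on training
- Saving the model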

In [1]:
### import pkgs
from datasets import load_dataset

import random
import pandas as pd
import datasets
from IPython.display import display, HTML

from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
import numpy as np
import evaluate

In [2]:
#base 
model_dir = "models/bert-base-cased-finetune-yelp"
model_name_or_path = "hugging face model file path" # TODO

## function
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."

    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column]=df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))
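The loop in `show_random_elements` keeps redrawing an index until it finds one it has not picked yet. Python's standard `random.sample` does sampling without replacement in a single call. Below is a minimal sketch; `pick_distinct_indices` is a hypothetical helper name, not part of the original notebook:

```python
import random

def pick_distinct_indices(dataset_len, num_examples, rng=None):
    """Return num_examples distinct indices in [0, dataset_len) (hypothetical helper)."""
    if num_examples > dataset_len:
        raise ValueError("Can't pick more elements than there are in the dataset.")
    if rng is None:
        rng = random.Random()
    # random.sample draws without replacement, so no duplicate check is needed
    return rng.sample(range(dataset_len), num_examples)

picks = pick_distinct_indices(650000, 10, random.Random(0))
print(picks)
```

The resulting `picks` list can be passed to `dataset[picks]` exactly as in the original function.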

### Download the dataset

In [3]:
# To download the model manually, use huggingface-cli:
# huggingface-cli download --resume-download bert-base-cased --local-dir bert-base-cased
dataset = load_dataset("yelp_review_full")

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [5]:
dataset["train"][100]

{'label': 0,
 'text': 'My expectations for McDonalds are t rarely high. But for one to still fail so spectacularly...that takes something special!\\nThe cashier took my friends\'s order, then promptly ignored me. I had to force myself in front of a cashier who opened his register to wait on the person BEHIND me. I waited over five minutes for a gigantic order that included precisely one kid\'s meal. After watching two people who ordered after me be handed their food, I asked where mine was. The manager started yelling at the cashiers for \\"serving off their orders\\" when they didn\'t have their food. But neither cashier was anywhere near those controls, and the manager was the one serving food to customers and clearing the boards.\\nThe manager was rude when giving me my order. She didn\'t make sure that I had everything ON MY RECEIPT, and never even had the decency to apologize that I felt I was getting poor service.\\nI\'ve eaten at various McDonalds restaurants for over 30 years. 
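`label` is an integer class id from 0 to 4. `show_random_elements` converts it to a readable name through `ClassLabel.names`, which is why the table below shows "4 stars", "1 star" and so on. Here is a minimal sketch of that mapping. The names list is an assumption reconstructed from the values displayed in this notebook. With the real dataset, `dataset["train"].features["label"].names` (or `.int2str(i)`) returns the authoritative list.

```python
# Assumed label names for yelp_review_full, in class-id order
LABEL_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]

def label_to_name(label_id):
    """Map an integer class id to its star-rating name."""
    return LABEL_NAMES[label_id]

def label_to_stars(label_id):
    """Class ids are 0-based, star ratings are 1-based."""
    return label_id + 1

# The review printed above has label 0
print(label_to_name(0), label_to_stars(0))
```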

In [6]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,4 stars,Good food but you'll be surprised by your bill. Go to share a small order of fries with your small burger...seriously.
1,5 stars,"Just got done eating the most delicious Veggie Benny at Egg Works. I can never eat breakfast anywhere but here. There food is always so scrumptious and there staff is super friendly. Speaking of staff, Donna is quite the wonderful waitress. She was fun and attentive to our needs. We didn't even have to wait for a refill. I can't wait to come back next weekend."
2,4 stars,"Love this place...but I love frozen yogurt in general.\n\nThey have yogurt. They have toppings.\nThat's what you want when you want frozen yogurt, right?\n\nIf you're at the mall, why not stop by?\n\nWhy then did I rate this place so high since it's your typical frozen yogurt place?\nFor the taro.\n\nNo other frozen yogurt place that I've frequented around here offers that flavor. If I could just have taro, I would! If you haven't tried it, please do."
3,1 star,"I was in Las Vegas for my daughter's wedding. I arrived with plenty of time to grab my reserved rental car. There were about 15 people in line at the Payless counter and nobody at any of the other counters. I though ok, this won't take too long because there were three agents taking care of customers. There were two employees alerting Nevada residents that they needed an alternative proof of address other then their drivers licenses. This of course, made people angry because who is going to bring their bank statements, which they suggested, to the airport with them? Each agent also spent at least 20 minutes with each customer. The whole time there was a woman at the little EZ car rental business, who like a carnival barker, was yelling that there are no lines at EZ Car, get your car fast here. You don't need two forms of ID here! This barking was pretty much continuous during my entire hour wait. The EZ car rental was 1 station carved out of the end of the Payless counter.\n\nRight before it was my turn at the counter, while the people in the line were audibly complaining, calling customer service on their phones, etc., I noticed that the woman carnival barking for EZ Car kept going over to the Payless agents with paperwork as if they all worked for the same place. I was kind of confused. It dawned on me at that point that this company manufactures customer dissatisfaction in order to encourage customers to switch over to the \""other\"" easy car rental place, but without their pre-reserved on line discounts. I have a very strong suspicion that they are one and the same company.\n\nAt this point I was really cutting it close to being late for my daughter's wedding. Finally my turn, I walked up to the agent with drivers license, credit card, and reservation in hand and immediately explained that I didn't have time to talk at the counter for 20 minutes because I would be late for the wedding. 
She immediately informed me that there would be many questions she would have to ask me before she could rent me a car. I told her that I just want the car with no added insurance and I was in a hurry. Then she immediately told me that she couldn't rent me a car because I was being abusive!\n\nThis made me absolutely furious. I got my credit card and license back from her and went over to Enterprise, and not the \""other\"" company at the end of their counter. It took about three minutes there. As I walked past Payless the people who were behind me in line were still chatting at the counter. Happy to say that I made it to the wedding on time. Avoid Payless at all costs. Not sure if what they're doing is legal, but it sure smells like fraud to me.\n\nI"
4,2 star,"Stayed here overnight on trip back to California.\n\nPros:\n- Complementary breakfast\n- Free Wi-Fi\n- Comfortable beds & decent size room.\n- Spacious entry & nice modern look.\n\nCons:\n- Crazy small parking lot with almost no room! Had to park next to someone already double parked.\n- Hard to access with one way streets on each side.\n- * Free shuttle to the strip was nice, but they didn't pick up at the stated time, nor in the spot they said the would. We arrived 10 minutes before the said arrival at the Tropicanna (where they said they pick up) and they never came for the last pick up. I had to call the hotel. Al they said was they had already been there and would not come back. \n- Noisy from airport\n\nThe bottom line:\nNot a good shuttle service. Decent rooms but shuttle service was aweful! Felt like they lied to me. Left 4 of us to take a cab back. Sucks."
5,4 stars,"Amazing customer service, workers here actually care about helping you out! I was looking for an uncommon tri-wing screwdriver here and Maria (I believe this is her name) hunted down the someone who might know where to find it, and got back to me promptly. A++ WOULD SHOP AGAIN!"
6,3 stars,Had to stay here for couple hours while waiting for a flight back home.. it was okay.. nice clean and had plenty of restaurants and bars to choose from..
7,3 stars,"about 3.5 stars\n\nhands down much better than tao beach. that place doesn't let you drink while you're in the water, lame! and tao beach is tiny. granted, i didn't go to wet republic here, just their regular pools outside, but i think it was nice.\n\nmy friend and i bought our own drinks at fat tuesday, i got a 1/2 yard and she got a huge mug and met up with other friends at this pool. thanks goodness for knowing someone with a room key! i believe it's the only way to get in for free. it was great b.c they already had seats, which were in a shaded area.\n\nnow for the pool: once you step in, it's pretty refreshing. it may be a little too crowded, but get that and that fact that you are stewing in everyone's mess out of your mind, and you can have a pretty good time. i guess it doesn't help to have a couple drinks in hand.\n\nbeware: don't taste or swallow the water! i got some splashed on my face, and the shit tasted like nothing but salty nastiness. good thing i had my 1/2 yard to wash that down with the awesome taste of sin city juice."
8,2 star,"Noisy.\n\nLimited menu.\n\nControlled chaotic restaurant environment.\n\nTook forever to get a table, so we left.\n\nWent to the Breakfast Club in Scottsdale. Nice location, seating outdoors, and food that rivals Matt's.\n\nNo longer have any reason or desire to go back to Matt's."
9,3 stars,"Had this for take out the first time a couple months ago when we had a pizza craving. The place seemed pretty spacious and great for a kid's birthday party or family gathering. The pizza was okay. The cinnamon pie/roll/sticks thing was AMAZING. But other than that, I thought it was pretty meh."


### Preprocess the data

In [7]:
#tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)




In [11]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,2 star,"This joint is pretty average and so is their food . \n\n\nThere is maybe around 6-7 choices of hot dogs and this definitely cannot satisfy everyone's taste --restaurant should consider doing custom orders.\n\nAnywho , my Morgan freeman tasted good but the overall experience was average and at that price (close to 10$) I was expecting a lot better . For example , the restaurant could have given us a fries or a drink through such an expensive order . Come the heck on it's a hot dog , will you seriously charge me 10$ ? Even Japadog in Vancouver makes better hot dogs in term of taste and friendly prices .\n\nService was courteous , but cooking these took a while , walking in such a joint makes you think you can get out with your food in 10 mins but it is certainly not the case as it took close to 20 mins .\n\nAnyway , if you really crave for a hot dog on mont royal you can definitely come to this joint , it's an experience nonetheless. However , note that they are not great and are certainly overpriced . 
You can definitely get better and more of them at Julep or Decarie hot dogs if you are willing to make the Trip","[101, 1188, 4091, 1110, 2785, 1903, 1105, 1177, 1110, 1147, 2094, 119, 165, 183, 165, 183, 165, 183, 1942, 12807, 1110, 2654, 1213, 127, 118, 128, 9940, 1104, 2633, 6363, 1105, 1142, 5397, 2834, 13692, 2490, 112, 188, 5080, 118, 118, 4382, 1431, 4615, 1833, 8156, 3791, 119, 165, 183, 165, 183, 1592, 3382, 2246, 5114, 117, 1139, 4461, 1714, 1399, 12876, 1363, 1133, 1103, 2905, 2541, 1108, 1903, 1105, 1120, 1115, 3945, 113, 1601, 1106, 1275, 109, 114, 146, 1108, 7805, 170, 1974, 1618, 119, 1370, 1859, 117, 1103, 4382, 1180, 1138, 1549, 1366, 170, 175, 3377, 1137, 170, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
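`tokenize_function` uses `padding="max_length"` and `truncation=True`, so every review becomes exactly `tokenizer.model_max_length` tokens long, which is 512 for bert-base-cased. Long reviews are cut, and short ones are filled with the padding token. `attention_mask` marks real tokens with 1 and padding with 0. This also explains why the `input_ids` above start with 101, the [CLS] token.

Below is a simplified pure-Python sketch. It assumes BERT's special-token ids [CLS]=101, [SEP]=102 and [PAD]=0, and it covers only single-sequence inputs. The real tokenizer also handles sentence pairs and other truncation strategies.

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # assumed BERT special-token ids

def pad_and_truncate(token_ids, max_length=512):
    """Wrap content ids in [CLS]/[SEP], truncate to max_length, then pad with [PAD].

    Returns (input_ids, attention_mask), both exactly max_length long.
    """
    # Reserve two positions for the special tokens
    body = token_ids[: max_length - 2]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)
    # Fill the remainder with padding that attention should ignore
    pad_len = max_length - len(input_ids)
    input_ids += [PAD_ID] * pad_len
    attention_mask += [0] * pad_len
    return input_ids, attention_mask

ids, mask = pad_and_truncate([1188, 4091, 1110], max_length=8)
print(ids)   # [101, 1188, 4091, 1110, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Padding every example to 512 tokens is simple but wastes computation on short reviews. A common alternative is to tokenize without padding and let `DataCollatorWithPadding` pad each batch only to its longest example.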


In [8]:
### Data sampling

# shuffle() randomly reorders the rows of the dataset. For more control over the shuffling algorithm, pass a different numpy.random.Generator through the generator parameter.

small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
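`shuffle(seed=42).select(range(1000))` works in two steps. It first permutes all row indices with a fixed seed, then keeps the first 1000. Because the seed is fixed, the subset is the same on every run.

The pure-Python sketch below illustrates the idea. `shuffled_subset_indices` is a hypothetical name. Hugging Face Datasets uses a NumPy generator internally, so the concrete indices it picks will differ from these.

```python
import random

def shuffled_subset_indices(num_rows, k, seed):
    """Deterministically permute row indices and keep the first k."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)  # same seed -> same permutation
    return indices[:k]

subset_a = shuffled_subset_indices(650000, 1000, seed=42)
subset_b = shuffled_subset_indices(650000, 1000, seed=42)
print(subset_a == subset_b)  # True
```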

### Fine-tuning configuration

In [9]:
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)
#model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path, num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
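This warning is expected. bert-base-cased was pretrained without a classification head, so `AutoModelForSequenceClassification` adds a newly initialized linear layer (`classifier.weight` and `classifier.bias`). That layer maps BERT's pooled 768-dimensional output to `num_labels=5` logits.

The NumPy sketch below shows only the shapes involved. Random numbers stand in for the real BERT output and the real weights. The 0.02 standard deviation roughly mirrors BERT's default initializer range and is used here for illustration.

```python
import numpy as np

hidden_size, num_labels, batch_size = 768, 5, 16

rng = np.random.default_rng(0)
# Stand-ins: in the real model these come from BERT's pooler and the new head
pooled_output = rng.standard_normal((batch_size, hidden_size))
classifier_weight = rng.normal(0.0, 0.02, size=(num_labels, hidden_size))
classifier_bias = np.zeros(num_labels)

# Linear layer: logits = x @ W^T + b, one score per star class
logits = pooled_output @ classifier_weight.T + classifier_bias
print(logits.shape)  # (16, 5)

# Parameters that have to be learned from scratch during fine-tuning
print(classifier_weight.size + classifier_bias.size)  # 3845
```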


In [10]:
# The most important setting: output_dir, the directory where the model weights (checkpoints) are saved

training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=5,
                                  logging_steps=100)

print(training_args)

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_le

### Evaluating metrics during training (Evaluate)

In [11]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [12]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions,references=labels)
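During evaluation, the Trainer passes `compute_metrics` a prediction object that unpacks into two parts: the raw logits, an array of shape (num_examples, 5), and the integer labels. `np.argmax(..., axis=-1)` then picks the highest-scoring class for each example.

The sketch below runs the same logic on a made-up toy batch. It computes accuracy directly with NumPy instead of `evaluate`, so it works offline. `compute_metrics_offline` is a hypothetical stand-in for the function above.

```python
import numpy as np

def compute_metrics_offline(eval_pred):
    """Same logic as compute_metrics above, with accuracy computed by NumPy."""
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)  # index of the largest logit per row
    return {"accuracy": float((predictions == labels).mean())}

# Toy batch of 4 examples x 5 classes (illustrative values only)
toy_logits = np.array([
    [2.0, 0.1, 0.0, -1.0, -2.0],  # predicts class 0
    [0.0, 0.2, 1.5, 0.3, 0.1],    # predicts class 2
    [0.1, 0.0, 0.0, 0.2, 3.0],    # predicts class 4
    [1.0, 2.0, 0.0, 0.0, 0.0],    # predicts class 1
])
toy_labels = np.array([0, 2, 3, 1])
print(compute_metrics_offline((toy_logits, toy_labels)))  # {'accuracy': 0.75}
```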

In [13]:
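# Monitor metrics during training
training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch",  # evaluate at the end of every epoch (named eval_strategy in newer transformers releases)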
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  logging_steps=30)

### Start training

In [14]:
trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=small_train_dataset,
                  eval_dataset=small_eval_dataset,
                  compute_metrics=compute_metrics)


In [15]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1 | 1.4719 | 1.198524 | 0.487 |
| 2 | 0.994 | 1.026885 | 0.561 |
| 3 | 0.6945 | 0.947726 | 0.584 |


TrainOutput(global_step=189, training_loss=1.0942042113611938, metrics={'train_runtime': 119.7812, 'train_samples_per_second': 25.046, 'train_steps_per_second': 1.578, 'total_flos': 789354427392000.0, 'train_loss': 1.0942042113611938, 'epoch': 3.0})
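The `global_step=189` in the output follows directly from the configuration. There are 1000 training examples and `per_device_train_batch_size=16` on one GPU. That gives ceil(1000/16) = 63 steps per epoch, and the last batch holds only 8 examples because `dataloader_drop_last=False`. Over 3 epochs this comes to 189 steps. The same arithmetic explains the 18750 steps of the homework run later.

Below is a small hypothetical helper. It assumes a single device and no gradient accumulation.

```python
import math

def total_training_steps(num_examples, batch_size, num_epochs):
    """Optimizer steps for a single-device run that keeps the last partial batch."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * num_epochs

print(total_training_steps(1000, 16, 3))    # 189
print(total_training_steps(100000, 16, 3))  # 18750
```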

In [None]:
# Check GPU usage with nvidia-smi
!watch -n 1 nvidia-smi

Every 1.0s: nvidia-smi                                                                                                                                                                         ubuntu-lesleyll: Mon May 26 11:52:19 2025

Mon May 26 11:52:19 2025
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.127.08             Driver Version: 550.127.08     CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  Tesla V100-SXM2-16GB           On  |   00000000:00:07.0 Off |                    0 |
| N/A   36C    P0             55W /  300W |   12137MiB /  16384MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|    0   N/A  N/A     11119      C   ...user/miniconda3/envs/llm/bin/python      12134MiB |
+-----------------------------------------------------------------------------------------+
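For a one-shot, machine-readable reading inside the notebook, you can run `nvidia-smi --query-gpu=memory.used,memory.total --format=csv,noheader,nounits`. It prints one line per GPU, for example `12137, 16384`. You could capture that text with `subprocess.run([...], capture_output=True, text=True).stdout`.

The sketch below parses this line format. `parse_gpu_memory` is a hypothetical helper, and the sample string reuses the V100 numbers from the output above.

```python
def parse_gpu_memory(csv_text):
    """Parse nvidia-smi query output with memory.used,memory.total per line.

    Returns a list of (used_mib, total_mib, fraction_used), one per GPU.
    """
    readings = []
    for line in csv_text.strip().splitlines():
        used, total = (int(field) for field in line.split(","))
        readings.append((used, total, used / total))
    return readings

# Sample line built from the reading shown above
for used, total, frac in parse_gpu_memory("12137, 16384\n"):
    print(f"{used} / {total} MiB ({frac:.0%})")  # 12137 / 16384 MiB (74%)
```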


In [17]:
## Test

small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))

trainer.evaluate(small_test_dataset)

{'eval_loss': 0.9957807064056396,
 'eval_accuracy': 0.53,
 'eval_runtime': 1.0787,
 'eval_samples_per_second': 92.704,
 'eval_steps_per_second': 12.051,
 'epoch': 3.0}

### Save the model and training state

In [19]:
trainer.save_model(model_dir)
trainer.save_state()
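`trainer.save_model(model_dir)` writes the model weights and config. No tokenizer was passed to this `Trainer`, so the tokenizer files are not saved next to them. For that reason, the sketch below reloads the tokenizer from `bert-base-cased`. Alternatively, you could call `tokenizer.save_pretrained(model_dir)` yourself.

`load_finetuned`, `logits_to_stars` and `predict_stars` are hypothetical helper names. The star conversion assumes that class id k means k+1 stars.

```python
def load_finetuned(model_dir, tokenizer_name="bert-base-cased"):
    """Reload the fine-tuned classifier together with the original tokenizer."""
    # Imported lazily so the pure helper below works without transformers installed
    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.eval()  # disable dropout for inference
    return model, tokenizer

def logits_to_stars(logits_rows):
    """Map each row of 5 logits to a 1-5 star rating (argmax + 1)."""
    return [max(range(len(row)), key=row.__getitem__) + 1 for row in logits_rows]

def predict_stars(texts, model, tokenizer):
    """Tokenize a list of reviews, run the model on CPU and return star ratings."""
    import torch
    inputs = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return logits_to_stars(logits.tolist())
```

Usage: `model, tok = load_finetuned("models/bert-base-cased-finetune-yelp")`, then `predict_stars(["Great food, friendly staff."], model, tok)`.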

### Homework: train on the full YelpReviewFull dataset and see how high the accuracy can go

In [33]:
# Homework: fine-tune bert-base-cased on the full YelpReviewFull dataset and see how high the accuracy can go
metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions,references=labels)

training_args = TrainingArguments(output_dir=model_dir,
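                                  evaluation_strategy="epoch",             # evaluate once per epoch
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  save_total_limit=1,                      # keep only the most recent checkpoint; older ones are deleted automatically
                                  logging_strategy="steps",                # log by step count
                                  logging_steps=1000,                      # log every 1000 steps (lower frequency)
                                  load_best_model_at_end=False,            # do not reload the best checkpoint at the end (saves disk space; enable when tuning)
                                  fp16=True,                               # mixed-precision training on a CUDA GPU: less memory, faster training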
                                  )  


In [36]:
large_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(100000))
#large_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))

#full_train_dataset = tokenized_datasets["train"]
full_eval_dataset = tokenized_datasets["test"]
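`full_eval_dataset` and the test set evaluated later are the same `test` split. As a result, the final "test" numbers simply repeat the last epoch's validation metrics rather than coming from an independent split. One option is to carve a validation set out of the training rows. With Hugging Face Datasets, `Dataset.train_test_split(test_size=..., seed=...)` does this directly.

Below is a pure-Python sketch of the index split. `split_train_validation` is a hypothetical name.

```python
import random

def split_train_validation(num_rows, val_fraction, seed):
    """Shuffle row indices once and cut them into disjoint train/validation lists."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    num_val = round(num_rows * val_fraction)
    return indices[num_val:], indices[:num_val]

train_idx, val_idx = split_train_validation(100000, 0.05, seed=42)
print(len(train_idx), len(val_idx))  # 95000 5000
```

The index lists could then be passed to `large_train_dataset.select(...)` to build separate training and validation datasets. The `test` split would then be used only once, at the end.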

In [37]:
trainer = Trainer(model=model,
                  args=training_args,
                  train_dataset=large_train_dataset,
                  eval_dataset=full_eval_dataset,
                  compute_metrics=compute_metrics)

trainer.train() ## check training loss, validation loss and accuracy

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1 | No log | 0.797095 | 0.65362 |
| 2 | 0.767100 | 0.798259 | 0.66478 |
| 3 | 0.767100 | 0.935516 | 0.65972 |


Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.


TrainOutput(global_step=18750, training_loss=0.6514964583333334, metrics={'train_runtime': 4590.1639, 'train_samples_per_second': 65.357, 'train_steps_per_second': 4.085, 'total_flos': 7.89354427392e+16, 'train_loss': 0.6514964583333334, 'epoch': 3.0})
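In this run, validation loss is lowest after epoch 1 and accuracy peaks at epoch 2. By epoch 3, validation loss rises to about 0.94 and accuracy slips slightly, which suggests overfitting. Because `load_best_model_at_end=False`, the trainer keeps the epoch-3 weights.

The sketch below picks the best epoch from the logged metrics. The values are copied from the training table of this run. Setting `load_best_model_at_end=True` together with `metric_for_best_model="accuracy"` asks the Trainer to do this automatically. That option requires the save strategy to match the evaluation strategy.

```python
# (epoch, validation_loss, accuracy) copied from the training table above
history = [
    (1, 0.797095, 0.65362),
    (2, 0.798259, 0.66478),
    (3, 0.935516, 0.65972),
]

best_by_accuracy = max(history, key=lambda row: row[2])
best_by_loss = min(history, key=lambda row: row[1])
print(best_by_accuracy[0], best_by_loss[0])  # 2 1
```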

In [39]:
## Test

#large_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(30000))
full_test_dataset = tokenized_datasets["test"]  # same split as full_eval_dataset, so the results match the epoch-3 validation metrics
trainer.evaluate(full_test_dataset) ## check test loss and accuracy

{'eval_loss': 0.9355159997940063,
 'eval_accuracy': 0.65972,
 'eval_runtime': 242.906,
 'eval_samples_per_second': 205.841,
 'eval_steps_per_second': 25.73,
 'epoch': 3.0}