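# Getting Started with Fine-Tuning in Hugging Face Transformers

This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics
- Introducing the Trainer
- Running the training
- Saving the model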

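## The YelpReviewFull Dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset Summary

The Yelp reviews dataset consists of reviews from Yelp. It is extracted from the Yelp Dataset Challenge 2015 data.

### Supported Tasks and Leaderboards
Text classification, sentiment classification: the dataset is mainly used for text classification, that is, predicting the sentiment of a given text.

### Languages
The reviews are written mainly in English.

### Dataset Structure

#### Data Instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set looks like this: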

```json
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

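#### Data Fields

- 'text': The review text is escaped with double quotes ("), and any internal double quote is escaped with two double quotes (""). Newlines are escaped with a backslash followed by an "n" character, i.e. "\n".
- 'label': Corresponds to the review's score (between 1 and 5 stars). It is stored as a class index from 0 to 4, which is why the example above carries `'label': 0`.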
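The two escaping rules for the 'text' field can be reversed with plain string replacements. The sketch below is illustrative only: `unescape_review` is a hypothetical helper, not part of the `datasets` library, and it assumes the raw text follows exactly the two rules listed above.

```python
def unescape_review(raw: str) -> str:
    """Undo the two escaping rules described for the 'text' field."""
    # A doubled double quote ("") stands for one literal double quote.
    text = raw.replace('""', '"')
    # A backslash followed by "n" stands for a real newline character.
    text = text.replace("\\n", "\n")
    return text


# Made-up raw string in the escaped form (not a real dataset row).
raw = 'She said, ""you can sit anywhere""\\nSo we seated ourselves'
clean = unescape_review(raw)
print(clean)
```

Whether this step is needed at all depends on how the text is consumed; the tokenizer used later in this notebook is applied to the field as loaded.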

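#### Data Splits

The Yelp reviews full star dataset is built by randomly taking 130,000 training samples and 10,000 test samples for each star rating from 1 to 5. In total there are 650,000 training samples and 50,000 test samples.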
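Because every star rating contributes the same number of samples, the split sizes and the chance-level accuracy follow from simple arithmetic. The short sketch below only restates the figures from the paragraph above; the constant names are chosen for illustration.

```python
# Figures taken from the split description above.
NUM_CLASSES = 5             # star ratings 1 to 5
TRAIN_PER_CLASS = 130_000
TEST_PER_CLASS = 10_000

train_total = NUM_CLASSES * TRAIN_PER_CLASS
test_total = NUM_CLASSES * TEST_PER_CLASS

# With balanced classes, a uniform random guess is right 1 time in 5.
chance_accuracy = 1 / NUM_CLASSES

print(train_total, test_total, chance_accuracy)
```

The chance level of 0.2 is a useful baseline when reading the accuracy figures reported after training.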

## Downloading the Dataset

In [1]:
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")

Downloading readme: 100%|██████████| 6.72k/6.72k [00:00<00:00, 6.72MB/s]
Downloading data: 100%|██████████| 299M/299M [00:40<00:00, 7.35MB/s] 
Downloading data: 100%|██████████| 23.5M/23.5M [00:03<00:00, 7.44MB/s]
Generating train split: 100%|██████████| 650000/650000 [00:00<00:00, 907195.51 examples/s]
Generating test split: 100%|██████████| 50000/50000 [00:00<00:00, 950589.26 examples/s]


In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [3]:
dataset["train"][10]

{'label': 0,
 'text': "Owning a driving range inside the city limits is like a license to print money.  I don't think I ask much out of a driving range.  Decent mats, clean balls and accessible hours.  Hell you need even less people now with the advent of the machine that doles out the balls.  This place has none of them.  It is april and there are no grass tees yet.  BTW they opened for the season this week although it has been golfing weather for a month.  The mats look like the carpet at my 107 year old aunt Irene's house.  Worn and thread bare.  Let's talk about the hours.  This place is equipped with lights yet they only sell buckets of balls until 730.  It is still light out.  Finally lets you have the pit to hit into.  When I arrived I wasn't sure if this was a driving range or an excavation site for a mastodon or a strip mining operation.  There is no grass on the range. Just mud.  Makes it a good tool to figure out how far you actually are hitting the ball.  Oh, they are cash 

In [4]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [5]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,3 stars,"I'm giving a 3 star rating because it isn't good enough to deserve 4 and not so bad that I'd give it a two. \n\nMy wife and I will stop in on a Thursday night to get the \""Date Night\"" special. An appetizer, two small salads and a pizza for $20. Great deal for a thrifty couple. Any other night would be too much of a strain on the pocketbook as items tend to be a little pricey. \n\nYeah the wait can be long, service is slow and they pack too many people in a small space. The outdoor patio is nice during the right time of year and there has been decent live entertainment at the bar. Wait staff can be a bit odd but not enough to bother me. \n\nOne night they ran out of mozzarella for the appetizer. I know, a pizza place ran out of mozzarella right? They didn't tell us until way too late but made up for it with a free dessert. \n\nBottom line, I will go back for the special but not much else."
1,5 stars,"Had a great experience here AGAIN! Wonderful peanut butter bacon burger, would definitely order it again! Best burgers so far BBQ bacon and peanut butter bacon. I have tried the bleu cheese but it's a little strong for me! Definitely a must try!!"
2,5 stars,"Good food, amazing service. We even got to tour the kitchen."
3,1 star,"What a brutal dining experience. In my many, many years of dining out, I have never experienced such terrible, unfriendly and unwelcoming service. We walked in to the restaurant and it had a distinct odor, much like that of natural gas. We hoped for the best and definitely got the worst. From the moment we sat down, my friend and I were assaulted by murderous glances from the owner. Was it because we brought our children to dine in his establishment? How dare we? We had to ask for bread, drinks, napkins a couple of times before receiving them. The waitress was overwhelmed and not too kind either. The food came out in waves, not all together as you would expect when dining in a restaurant. The adult's food arrived at least 8 minutes before the children's, you'd think it would be the opposite. The food was actually quite good, yet not good enough, to overcome the terrible discomfort that we felt being in this restaurant. Truly, I would have felt more welcome in McDonald's. \nIn this economy, you should welcome each and every paying customer into your restaurant with open arms, or at least a smile. A little kindness goes a long way....."
4,3 stars,"The positive: love the Oreo shake its thick, creamy and rich definitely will order it again when in town. The frozen chocolate kids devour it in nano seconds. They loved it. Kids love the burgers juicy and moist. Th crab cakes sandwiches not bad but I had better. \n\nThe Negative: kids meal pasta with sausage not good at all. My three girls didn't touch it after a few bites. It had no flavor so bland and nothing that resembles pasta sauce at all yuk! The Caesar salad with steak, steak is bland and chewy, could use a little bit more Caesar dressing . \nOverall I will go there again just for the Oreo shakes and maybe try their other desserts. I will skip their food its $$$$ not much flavor food just isn't good."
5,2 star,"We walked in around four thirty pm. Maybe it was just a bad evening for them, but no-one greeted us at the door. When a waitress did show up, she just waved her hand and said, \"" you can sit anywhere\"" So we seated ourselves\n\nWe were taken care of promptly, no problems there. However when our sushi arrived, we had to ask for soy sauce, wasabi and ginger to be served with it. The waitress seemed surprised that we wanted it\n\nOverall, the sushi was average. The service was less than stellar. The servers really need to step it up. I won't be back."
6,3 stars,"We went again last night...we sat in the back row. The seats are really, really, uncomfortable. They seem to have updated the hosting seating process...so we arrived 6:45 for a 7:50 show to be sure we had good seats...luckily I am used to fly SWA and I was able to barge my way through the cattle call to be one of the first people into the theater so we could pick what I thought would be decent seats...but wait...we screwed up...we choose the back row, upper right where the isle was. The spot we choose was dirty and not clean...we had to call someone to clean it. After she finished cleaning it, we had to press the button again to order our bottle of wine. And wait...the spot we choose...HUGE MISTAKE...all the servers ducked when they were in front of you, but not when walking down the stairs...so obviously whoever designed that row was smoking a doobie and thinking about his food service over the entertainment/movie part of this entire experience...because clearly these seats suck because there were always people walking in front of you. Oh and did I mention the seats were so uncomfortable I thought I was going to have to go to the Jacuzzi afterwards to get all the kinks out of my back? Ok...the good...I was able to drink a bottle of wine during a movie!"
7,2 star,"I remember visiting Lotus of Siam in my pre-Yelp days. I read all the (Zagat) hype, heard all the praise, and, on more than one occasion, mentioned the reviews of it being the \""Best Thai in North America.\""\n\nAll in all, it ended up being a huge letdown and I was embarrassed at bringing a large party of 8 to this restaurant and making reservations one month in advance. \n\n1) Location - Lotus of Siam is located in the middle of a strip mall and it isn't an easy location for most cabs to find. The neighborhood is questionable and in the evenings is probably not the safest locale.\n\n2) Decor - reflective of its strip mall location, the interior is nothing special. There are, however, typical Thai touches (flowers and statues) and plenty of newspaper articles touting it's reputation as a great Thai restaurant.\n\n3) Food - While many reviews talk about the amazing flavors and the incredible tastes, I found everything to be absolutely average. Given the high expectations, I was obviously disappointed, but even without those expectations, I'd be hard pressed to even call the food good. All the dishes chosen were \""signature\"" dishes recommended by our waiter and each failed to impress.\n\nWith every bite I took, I just wondered how the heck the Gourmet magazine food review could be so off and what kool-aid all the reviewers were drinking. Even now, I can think of a dozen Thai restaurants better than Lotus of Siam and probably only one or two that I've considered worse. \n\nIn the end, as I waited 45 minutes for a cab to pick up our group from the seedy strip mall location, I figured it all must be a tourist trap. One stellar review must have spawned the rest or maybe, none of the people who were gushing over Lotus had ever been to a Thai restaurant. \n\nFor the Las Vegas that is home to the $4.99 Steak and Egg breakfast and the 24 hour buffet, Lotus of Siam may be a revolutionary restaurant. But for North America?? Lotus of Siam falls amazingly flat."
8,4 stars,"I went here for lunch on a weekday and it was extremely affordable and worth the money. They had a wide range of foods from beef, poultry, pasta, seafood, even sushi. Our server was very attentive and brought us our drinks and took away our plates without us having to ask. All of the food was very well prepared. Th cheesecake was very creamy as was the raspberry chocolate mousse. I definitely recommend this place for a lunch, but be prepared to be EXTREMELY full after eating here."
9,2 star,"You get what you pay for. It's cheap and it looks cheap. \n\nMy friends and I headed to Vegas on a whim and nothing was planned, so we needed a room and fast. We called multiple places that were either booked or too expensive, so we finally checked in here.\n\nThe beds aren't comfortable and there really isn't anything special about this place, but it's a good choice if you are in Vegas to party and aren't planning on being in the room much.\n\nIt's not directly on The Strip but it's not far from it. We walked and it took us maybe 15-20 min."


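## Preprocessing the Data

Once the dataset has been downloaded locally, use a tokenizer to process the text. Inputs of varying length can be handled with padding and truncation strategies.

The `map` method of Datasets applies a preprocessing function to the entire dataset in one pass.

The cell below processes the whole dataset, padding every example to the maximum length: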

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

tokenizer_config.json: 100%|██████████| 29.0/29.0 [00:00<?, ?B/s]
config.json: 100%|██████████| 570/570 [00:00<?, ?B/s] 
vocab.txt: 100%|██████████| 213k/213k [00:00<00:00, 494kB/s]
tokenizer.json: 100%|██████████| 436k/436k [00:00<00:00, 712kB/s]
Map: 100%|██████████| 650000/650000 [02:02<00:00, 5304.39 examples/s]
Map: 100%|██████████| 50000/50000 [00:09<00:00, 5365.02 examples/s]


In [10]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,5 stars,"Wonderful fresh Persian food! But even better, great customer service!! We arrived after the kitchen closed, but to help us eat a healthy, hearty meal, Aram took our order and gave us great take out. Generous portions and kindness to strangers are much appreciated. Don't miss Royal Persis and try the Fesenjoon.","[101, 20361, 4489, 3886, 2094, 106, 1252, 1256, 1618, 117, 1632, 8132, 1555, 106, 106, 1284, 2474, 1170, 1103, 3119, 1804, 117, 1133, 1106, 1494, 1366, 3940, 170, 8071, 117, 1762, 1183, 7696, 117, 25692, 1306, 1261, 1412, 1546, 1105, 1522, 1366, 1632, 1321, 1149, 119, 9066, 13149, 8924, 1105, 18569, 1106, 15712, 1132, 1277, 12503, 119, 1790, 112, 189, 5529, 1787, 14286, 4863, 1105, 2222, 1103, 11907, 3792, 5077, 1320, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]"
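The output above shows `input_ids` ending in a run of zeros and `attention_mask` switching from 1 to 0 at the same position. The toy function below sketches that relationship on lists of integer ids. It is not the tokenizer's implementation: `pad_or_truncate` and `PAD_ID` are hypothetical names, the padding id 0 is read off the output above, and the real tokenizer additionally works on text and takes care of special tokens when truncating.

```python
from typing import Dict, List

PAD_ID = 0  # padding id as seen in the output above (assumption for this toy)


def pad_or_truncate(input_ids: List[int], max_length: int) -> Dict[str, List[int]]:
    """Toy version of padding="max_length" combined with truncation=True."""
    # Truncation: drop everything beyond max_length.
    ids = input_ids[:max_length]
    # Attention mask: 1 marks a real token.
    mask = [1] * len(ids)
    # Padding: fill the remainder with PAD_ID and mask it out with 0.
    n_pad = max_length - len(ids)
    return {
        "input_ids": ids + [PAD_ID] * n_pad,
        "attention_mask": mask + [0] * n_pad,
    }


short_example = pad_or_truncate([101, 20361, 4489, 102], max_length=8)
long_example = pad_or_truncate(list(range(1, 13)), max_length=8)
print(short_example)
print(long_example)
```

Padding everything to one fixed length keeps the tensors rectangular at the cost of computing over many padding positions, which the attention mask tells the model to ignore.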


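### Sampling the Data

A small-scale demonstration would fine-tune BERT on 1,000 samples (using the PyTorch Trainer). Because this notebook trains on the complete dataset (see the homework at the end), the cell below shuffles and selects all 650,000 training samples and all 50,000 test samples. For a quick trial run, pass a smaller range such as `range(1000)` to `select()`.

The `shuffle()` function randomly rearranges the rows of the dataset. For more control over the shuffling algorithm, pass the `generator` argument to use a different `numpy.random.Generator`.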

In [23]:
train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(650000))
eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(50000))
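The pattern of a seeded shuffle followed by `select(range(n))` amounts to taking the first n rows of a reproducible random permutation. The sketch below mimics that idea on a plain Python list. It is an analogy only: `shuffle_and_select` is a hypothetical helper, and it uses Python's `random` module rather than the NumPy generator that `datasets` relies on, so the resulting order will differ from the real one.

```python
import random
from typing import List, Sequence


def shuffle_and_select(rows: Sequence[str], seed: int, n: int) -> List[str]:
    """Return the first n rows of a seeded random permutation of rows."""
    indices = list(range(len(rows)))
    # A private RNG keeps the global random state untouched.
    random.Random(seed).shuffle(indices)
    return [rows[i] for i in indices[:n]]


rows = [f"review-{i}" for i in range(10)]
subset_a = shuffle_and_select(rows, seed=42, n=3)
subset_b = shuffle_and_select(rows, seed=42, n=3)
everything = shuffle_and_select(rows, seed=42, n=len(rows))
print(subset_a)
```

Fixing the seed makes the subset repeatable across runs, and selecting as many rows as the dataset holds simply yields the whole dataset in shuffled order.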

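## Fine-Tuning Configuration

### Loading the BERT Model

Loading the model prints a warning that some pretrained weights are discarded (the `cls.predictions.*` and `cls.seq_relationship.*` weights listed in the output below) and that others are randomly initialized (the new classification head). This is completely normal when fine-tuning: the heads used for pretraining are dropped and replaced by a new head for sequence classification, for which there are no pretrained weights. The library therefore warns that the model should be fine-tuned before it is used for inference, which is exactly what we are about to do.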

In [24]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initi

### Training Hyperparameters (TrainingArguments)

Full list of parameters and their defaults: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**The most important setting: the path where model weights are saved (`output_dir`)**

In [26]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased"

# logging_steps defaults to 500; with the full dataset and this many training steps, set it to 5000
training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer",
                                  logging_dir=f"{model_dir}/test_trainer/runs",
                                  logging_steps=5000)
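The choice of `logging_steps` is easier to judge once the total number of optimizer steps is known. The sketch below works that out under stated assumptions: a batch size of 8 and 3 epochs are taken to be the `TrainingArguments` defaults on a single GPU, since the cell above does not override them. The total agrees with the 243750 steps that the training progress bar reports later in this notebook.

```python
import math

train_samples = 650_000   # size of the training set selected earlier
batch_size = 8            # assumed default per_device_train_batch_size, one GPU
num_epochs = 3            # assumed default num_train_epochs
logging_steps = 5000      # value passed to TrainingArguments above

steps_per_epoch = math.ceil(train_samples / batch_size)
total_steps = steps_per_epoch * num_epochs

# Number of loss lines printed during training for two logging intervals.
log_events = total_steps // logging_steps
log_events_default = total_steps // 500

print(steps_per_epoch, total_steps, log_events, log_events_default)
```

With the default interval of 500, the run would print several hundred loss lines; an interval of 5000 keeps the log to a few dozen.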

In [27]:
# Full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'fsdp_min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
greater_is_better=None,
group_by_length=False,
half_precision_backend=auto,
hub_model_id=None,
hub_private_repo=False,
hub_strategy=every_save,
hub_token=

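### Evaluating Metrics During Training (Evaluate)

The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** gives access to dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more) with a single line of code. The full list of currently supported metrics is at https://huggingface.co/evaluate-metric

The Trainer does not automatically evaluate model performance during training, so we need to pass it a function that computes and reports metrics.

The Evaluate library provides a simple accuracy metric, which you can load with the `evaluate.load` function: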

In [28]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


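Next, call the metric's `compute` function to calculate the accuracy of the predictions.

Before passing predictions to `compute`, the logits must be converted to predicted class indices (**all Transformers models return logits**).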

In [29]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
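To make the logits-to-accuracy step concrete, here is a dependency-free sketch of what `compute_metrics` calculates. It is a plain-Python stand-in written for illustration, not the code of the Evaluate library; the helper names `argmax` and `accuracy` and the sample logits are made up.

```python
from typing import List, Sequence


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value in row (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def accuracy(logits: List[Sequence[float]], labels: List[int]) -> float:
    """Fraction of rows whose highest-scoring class equals the label."""
    predictions = [argmax(row) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return correct / len(labels)


# Four made-up examples with five class scores each.
logits = [
    [2.0, 0.1, -1.0, 0.0, 0.3],   # predicts class 0
    [0.0, 0.2, 0.1, 3.5, 1.0],    # predicts class 3
    [-0.5, 0.0, 0.4, 0.1, 0.2],   # predicts class 2
    [0.1, 0.0, 0.0, 0.2, 1.7],    # predicts class 4
]
labels = [0, 3, 1, 4]

result = {"accuracy": accuracy(logits, labels)}
print(result)
```

Three of the four predictions match their labels, so the sketch reports an accuracy of 0.75. No softmax is needed because the largest logit is also the most probable class.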

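#### Monitoring Metrics During Training

To track how the evaluation metrics change during training, set the `evaluation_strategy` parameter in `TrainingArguments`; with `"epoch"`, the metrics are reported at the end of each epoch.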

In [30]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=f"{model_dir}/test_trainer",
                                  evaluation_strategy="epoch", 
                                  logging_dir=f"{model_dir}/test_trainer/runs",
                                  logging_steps=5000)

## Starting Training

### Instantiating the Trainer

`kernel version` warning: this does not currently affect running the example code.

In [31]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

## Monitoring GPU Usage with nvidia-smi

To watch GPU usage in real time, poll `nvidia-smi` with the `watch` command: `watch -n 1 nvidia-smi`

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```

In [32]:
trainer.train()

  0%|          | 167/243750 [3:49:18<5574:26:59, 82.39s/it]
                                                        
  2%|▏         | 5000/243750 [09:42<7:37:36,  8.70it/s]    

{'loss': 1.0327, 'learning_rate': 4.8974358974358975e-05, 'epoch': 0.06}


                                                        
  4%|▍         | 10000/243750 [19:26<7:31:16,  8.63it/s]   

{'loss': 0.932, 'learning_rate': 4.7948717948717955e-05, 'epoch': 0.12}


                                                         
  6%|▌         | 15000/243750 [29:08<7:21:44,  8.63it/s]   

{'loss': 0.9157, 'learning_rate': 4.692307692307693e-05, 'epoch': 0.18}


                                                         
  8%|▊         | 20000/243750 [38:48<7:10:52,  8.66it/s]   

{'loss': 0.8932, 'learning_rate': 4.5897435897435895e-05, 'epoch': 0.25}


                                                         
 10%|█         | 25000/243750 [48:29<6:57:35,  8.73it/s]   

{'loss': 0.8943, 'learning_rate': 4.4871794871794874e-05, 'epoch': 0.31}


                                                         
 12%|█▏        | 30000/243750 [58:11<6:50:07,  8.69it/s]   

{'loss': 0.9034, 'learning_rate': 4.384615384615385e-05, 'epoch': 0.37}


                                                           
 14%|█▍        | 35000/243750 [1:07:53<6:41:30,  8.67it/s] 

{'loss': 0.8883, 'learning_rate': 4.282051282051282e-05, 'epoch': 0.43}


                                                           
 16%|█▋        | 40000/243750 [1:17:35<6:38:34,  8.52it/s] 

{'loss': 0.8796, 'learning_rate': 4.17948717948718e-05, 'epoch': 0.49}


                                                           
 18%|█▊        | 45000/243750 [1:27:17<6:20:17,  8.71it/s] 

{'loss': 0.8783, 'learning_rate': 4.0769230769230773e-05, 'epoch': 0.55}


                                                           
 21%|██        | 50000/243750 [1:36:58<6:06:45,  8.80it/s] 

{'loss': 0.8766, 'learning_rate': 3.974358974358974e-05, 'epoch': 0.62}


                                                           
 23%|██▎       | 55000/243750 [1:46:39<6:02:11,  8.69it/s] 

{'loss': 0.8781, 'learning_rate': 3.871794871794872e-05, 'epoch': 0.68}


                                                           
 25%|██▍       | 60000/243750 [1:56:20<5:51:17,  8.72it/s] 

{'loss': 0.8709, 'learning_rate': 3.769230769230769e-05, 'epoch': 0.74}


                                                           
 27%|██▋       | 65000/243750 [2:06:02<5:38:58,  8.79it/s] 

{'loss': 0.8683, 'learning_rate': 3.6666666666666666e-05, 'epoch': 0.8}


                                                           
 29%|██▊       | 70000/243750 [2:15:42<5:32:39,  8.70it/s] 

{'loss': 0.8735, 'learning_rate': 3.5641025641025646e-05, 'epoch': 0.86}


                                                           
 31%|███       | 75000/243750 [2:25:20<5:20:49,  8.77it/s] 

{'loss': 0.8772, 'learning_rate': 3.461538461538462e-05, 'epoch': 0.92}


                                                           
 33%|███▎      | 80000/243750 [2:34:59<5:11:24,  8.76it/s] 

{'loss': 0.8645, 'learning_rate': 3.358974358974359e-05, 'epoch': 0.98}


 33%|███▎      | 81250/243750 [2:37:24<5:10:44,  8.72it/s] 


{'eval_loss': 0.8560718894004822, 'eval_accuracy': 0.6355, 'eval_runtime': 256.3551, 'eval_samples_per_second': 195.042, 'eval_steps_per_second': 24.38, 'epoch': 1.0}


                                                             
 35%|███▍      | 85000/243750 [2:48:56<5:03:53,  8.71it/s] 

{'loss': 0.8459, 'learning_rate': 3.2564102564102565e-05, 'epoch': 1.05}


                                                           
 37%|███▋      | 90000/243750 [2:58:35<4:52:54,  8.75it/s] 

{'loss': 0.8284, 'learning_rate': 3.153846153846154e-05, 'epoch': 1.11}


                                                           
 39%|███▉      | 95000/243750 [3:08:20<4:47:04,  8.64it/s] 

{'loss': 0.8479, 'learning_rate': 3.0512820512820518e-05, 'epoch': 1.17}


                                                           
 41%|████      | 100000/243750 [3:18:01<4:33:47,  8.75it/s]

{'loss': 0.84, 'learning_rate': 2.948717948717949e-05, 'epoch': 1.23}


                                                            
 43%|████▎     | 105000/243750 [3:27:42<4:27:06,  8.66it/s]

{'loss': 0.8453, 'learning_rate': 2.846153846153846e-05, 'epoch': 1.29}


                                                            
 45%|████▌     | 110000/243750 [3:37:22<4:13:00,  8.81it/s]

{'loss': 0.8576, 'learning_rate': 2.743589743589744e-05, 'epoch': 1.35}


                                                            
 47%|████▋     | 115000/243750 [3:47:04<4:03:59,  8.79it/s]

{'loss': 0.8448, 'learning_rate': 2.6410256410256413e-05, 'epoch': 1.42}


                                                            
 49%|████▉     | 120000/243750 [3:56:44<3:56:35,  8.72it/s]

{'loss': 0.8495, 'learning_rate': 2.5384615384615383e-05, 'epoch': 1.48}


                                                            
 51%|█████▏    | 125000/243750 [4:06:24<3:45:28,  8.78it/s]

{'loss': 0.8338, 'learning_rate': 2.435897435897436e-05, 'epoch': 1.54}


                                                            
 53%|█████▎    | 130000/243750 [4:16:03<3:34:20,  8.85it/s]

{'loss': 0.8495, 'learning_rate': 2.3333333333333336e-05, 'epoch': 1.6}


                                                            
 55%|█████▌    | 135000/243750 [4:25:42<3:27:59,  8.71it/s]

{'loss': 0.864, 'learning_rate': 2.230769230769231e-05, 'epoch': 1.66}


                                                            
 57%|█████▋    | 140000/243750 [4:35:21<3:16:38,  8.79it/s]

{'loss': 0.8477, 'learning_rate': 2.1282051282051282e-05, 'epoch': 1.72}


                                                            
 59%|█████▉    | 145000/243750 [4:45:00<3:07:38,  8.77it/s]

{'loss': 0.8332, 'learning_rate': 2.025641025641026e-05, 'epoch': 1.78}


                                                           
 62%|██████▏   | 150000/243750 [4:54:40<2:56:50,  8.84it/s]

{'loss': 0.8306, 'learning_rate': 1.923076923076923e-05, 'epoch': 1.85}


                                                           
 64%|██████▎   | 155000/243750 [5:04:19<2:48:44,  8.77it/s]

{'loss': 0.8157, 'learning_rate': 1.8205128205128204e-05, 'epoch': 1.91}


                                                           
 66%|██████▌   | 160000/243750 [5:13:59<2:41:22,  8.65it/s]

{'loss': 0.8177, 'learning_rate': 1.717948717948718e-05, 'epoch': 1.97}


 67%|██████▋   | 162500/243750 [5:18:50<2:34:12,  8.78it/s]


{'eval_loss': 0.8337429761886597, 'eval_accuracy': 0.64566, 'eval_runtime': 255.6521, 'eval_samples_per_second': 195.578, 'eval_steps_per_second': 24.447, 'epoch': 2.0}


                                                              
 68%|██████▊   | 165000/243750 [5:27:56<2:30:17,  8.73it/s]

{'loss': 0.8001, 'learning_rate': 1.6153846153846154e-05, 'epoch': 2.03}


                                                           
 70%|██████▉   | 170000/243750 [5:37:36<2:22:13,  8.64it/s]

{'loss': 0.7928, 'learning_rate': 1.5128205128205129e-05, 'epoch': 2.09}


                                                           
 72%|███████▏  | 175000/243750 [5:47:16<2:09:35,  8.84it/s]

{'loss': 0.7915, 'learning_rate': 1.4102564102564104e-05, 'epoch': 2.15}


                                                           
 74%|███████▍  | 180000/243750 [5:56:56<2:02:20,  8.68it/s]

{'loss': 0.7779, 'learning_rate': 1.3076923076923078e-05, 'epoch': 2.22}


                                                           
 76%|███████▌  | 185000/243750 [6:06:35<1:50:57,  8.82it/s]

{'loss': 0.7772, 'learning_rate': 1.2051282051282051e-05, 'epoch': 2.28}


                                                           
 78%|███████▊  | 190000/243750 [6:16:13<1:41:11,  8.85it/s]

{'loss': 0.7703, 'learning_rate': 1.1025641025641026e-05, 'epoch': 2.34}


                                                           
 80%|████████  | 195000/243750 [6:25:52<1:33:41,  8.67it/s] 

{'loss': 0.7718, 'learning_rate': 1e-05, 'epoch': 2.4}


                                                           
 82%|████████▏ | 200000/243750 [6:35:32<1:22:10,  8.87it/s] 

{'loss': 0.7617, 'learning_rate': 8.974358974358976e-06, 'epoch': 2.46}


                                                           
 84%|████████▍ | 205000/243750 [6:45:11<1:13:06,  8.83it/s] 

{'loss': 0.7711, 'learning_rate': 7.948717948717949e-06, 'epoch': 2.52}


                                                           
 86%|████████▌ | 210000/243750 [6:54:53<1:04:44,  8.69it/s] 

{'loss': 0.7583, 'learning_rate': 6.923076923076923e-06, 'epoch': 2.58}


                                                           
 88%|████████▊ | 215000/243750 [7:04:38<55:37,  8.61it/s]   

{'loss': 0.7511, 'learning_rate': 5.897435897435897e-06, 'epoch': 2.65}


                                                           
 90%|█████████ | 220000/243750 [7:14:20<45:33,  8.69it/s]   

{'loss': 0.7478, 'learning_rate': 4.871794871794872e-06, 'epoch': 2.71}


                                                           
 92%|█████████▏| 225000/243750 [7:24:02<35:47,  8.73it/s]   

{'loss': 0.7496, 'learning_rate': 3.846153846153847e-06, 'epoch': 2.77}


                                                           
 94%|█████████▍| 230000/243750 [7:33:44<26:22,  8.69it/s]   

{'loss': 0.751, 'learning_rate': 2.8205128205128207e-06, 'epoch': 2.83}


                                                           
 96%|█████████▋| 235000/243750 [7:43:25<16:32,  8.82it/s]   

{'loss': 0.7449, 'learning_rate': 1.7948717948717948e-06, 'epoch': 2.89}


                                                         
 98%|█████████▊| 240000/243750 [7:52:53<07:07,  8.77it/s]   

{'loss': 0.7478, 'learning_rate': 7.692307692307694e-07, 'epoch': 2.95}


100%|██████████| 243750/243750 [8:00:13<00:00,  8.67it/s]

{'eval_loss': 0.8057957291603088, 'eval_accuracy': 0.66116, 'eval_runtime': 259.6802, 'eval_samples_per_second': 192.545, 'eval_steps_per_second': 24.068, 'epoch': 3.0}
{'train_runtime': 29072.7788, 'train_samples_per_second': 67.073, 'train_steps_per_second': 8.384, 'train_loss': 0.832849788661859, 'epoch': 3.0}


TrainOutput(global_step=243750, training_loss=0.832849788661859, metrics={'train_runtime': 29072.7788, 'train_samples_per_second': 67.073, 'train_steps_per_second': 8.384, 'train_loss': 0.832849788661859, 'epoch': 3.0})
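The `learning_rate` values in the log above fall in a straight line, which is consistent with a linear decay from a base rate of 5e-5 to zero without warmup. The sketch below reproduces a few logged values under that assumption. The base rate and the schedule shape are assumed to be the `TrainingArguments` defaults (the notebook does not set them explicitly), and `linear_lr` and `epoch_at` are hypothetical helpers.

```python
BASE_LR = 5e-5           # assumed default learning_rate
TOTAL_STEPS = 243_750    # global_step reported by TrainOutput above
STEPS_PER_EPOCH = TOTAL_STEPS // 3


def linear_lr(step: int) -> float:
    """Linear decay from BASE_LR at step 0 to zero at TOTAL_STEPS, no warmup."""
    return BASE_LR * (1 - step / TOTAL_STEPS)


def epoch_at(step: int) -> float:
    """Fractional epoch reached after the given number of steps."""
    return step / STEPS_PER_EPOCH


# Compare with the logged lines at steps 5000, 195000 and 240000.
for step in (5_000, 195_000, 240_000):
    print(step, f"{linear_lr(step):.6e}", f"{epoch_at(step):.2f}")
```

If the printed rates agree with the log, the steadily shrinking step size in the last epoch is the scheduler at work rather than a sign of a problem.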

In [33]:
test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(10000))

In [34]:
trainer.evaluate(test_dataset)

100%|██████████| 1250/1250 [00:51<00:00, 24.41it/s]


{'eval_loss': 0.8071225881576538,
 'eval_accuracy': 0.6661,
 'eval_runtime': 51.482,
 'eval_samples_per_second': 194.243,
 'eval_steps_per_second': 24.28,
 'epoch': 3.0}
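The throughput figures in this result can be cross-checked against the sample count and the runtime. The sketch below does that arithmetic, assuming an evaluation batch size of 8 (taken to be the default `per_device_eval_batch_size`, which matches the 1250 steps in the progress bar).

```python
import math

eval_samples = 10_000     # size of test_dataset selected above
eval_batch_size = 8       # assumed default per_device_eval_batch_size
eval_runtime = 51.482     # seconds, from the result above
eval_accuracy = 0.6661    # from the result above

eval_steps = math.ceil(eval_samples / eval_batch_size)
samples_per_second = eval_samples / eval_runtime
steps_per_second = eval_steps / eval_runtime

# Accuracy times sample count gives the number of correctly classified reviews.
correct = round(eval_accuracy * eval_samples)

print(eval_steps, round(samples_per_second, 3), round(steps_per_second, 2), correct)
```

Turning the accuracy into a count (6,661 of 10,000 reviews) makes it easier to compare against the chance level of one in five for this balanced five-class task.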

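### Saving the Model and Training State

- Use the `trainer.save_model` method to save the model; it can later be reloaded with the `from_pretrained()` method
- Use the `trainer.save_state` method to save the training state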

In [35]:
trainer.save_model(f"{model_dir}/finetuned-trainer")

In [36]:
trainer.save_state()

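## Homework: Train on the complete YelpReviewFull dataset and see how high the accuracy can go

After 3 epochs on the full training set, accuracy reached 0.6661 on the 10,000-sample test subset.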