#  Hugging Face Transformers 微调训练

- 数据集下载
- 数据预处理
- 训练超参数设置
- 训练评估指标设置
- 实战训练
- 模型保存



## 数据集下载

数据集首页: [https://huggingface.co/datasets/Yelp/yelp_review_full](https://huggingface.co/datasets/Yelp/yelp_review_full)



In [3]:
from datasets import load_dataset
dataset_path = "/home/hengzq/workspace/modelscope/datasets/yelp_review_full"

dataset = load_dataset(dataset_path)

In [4]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [5]:
dataset["train"][10]

{'label': 0,
 'text': "Owning a driving range inside the city limits is like a license to print money.  I don't think I ask much out of a driving range.  Decent mats, clean balls and accessible hours.  Hell you need even less people now with the advent of the machine that doles out the balls.  This place has none of them.  It is april and there are no grass tees yet.  BTW they opened for the season this week although it has been golfing weather for a month.  The mats look like the carpet at my 107 year old aunt Irene's house.  Worn and thread bare.  Let's talk about the hours.  This place is equipped with lights yet they only sell buckets of balls until 730.  It is still light out.  Finally lets you have the pit to hit into.  When I arrived I wasn't sure if this was a driving range or an excavation site for a mastodon or a strip mining operation.  There is no grass on the range. Just mud.  Makes it a good tool to figure out how far you actually are hitting the ball.  Oh, they are cash 

In [4]:
dataset["test"][10]

{'label': 1,
 'text': "Think Chuck E. Cheese for adults.  Skee Ball, video games pool tables.  Clean environment.  Good fun.\\n\\nUnfortunately, I went for a bite to eat and it was impossible to find anything good and healthy on the menu.  I ended up settling for spinach dip.  Sadly, they topped the dip off with horrible orange shredded cheese that appeared to have been popped in the microwave for a few seconds.  Blahhhh.  Trying to get something healthy, I ordered the apple pecan salad.  I swear the dressing came right out of the grocery store bottle.  I could barely eat the salad.  Too sweet.\\n\\nMy mom ordered a steak roll.....holy friedness!  The steak was more like hamburger fried with cheese and then stuffed into breading that was fried AGAIN!  Yowzer!  Artery clogger for sure.  \\n\\nI like the atmosphere.  I like the bar area.  Perhaps next time we'll just stop by for drinks instead."}

In [6]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [7]:
def show_random_elements(dataset, num_examples = 10):
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset) - 1)
        while pick in picks:
            pick = random.randint(0, len(dataset) - 1)
        picks.append(pick)

    df = pd.DataFrame(dataset[picks])
    for col, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[col] = df[col].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [8]:
show_random_elements(dataset["train"], 8)

Unnamed: 0,label,text
0,2 star,"Listen, I don't enjoy writing bad reviews, but this place really earns it. This is not a happy place. You feel like you have really hit the dirty end of Vegas when you come here. My company made me stay here while working a trade show. While I can't speak to the price, it made me realize my company doesn't care much for me or my comfort. \n\nI stayed in the Central Tower, which I think is the reason for my negative experience. I hear the North Tower has been renovated and is decent. \n\nCons; the hotel rooms (at least the ones not renovated) are gross, the bathrooms are filthy, you only get two small towels, the sink was broken\n\nPros; it is attached to the monorail, there are several semi decent restaurants, there is a good gift shop (albeit expensive)"
1,4 stars,"If you are looking for great authentic style pizza, wings, and subs, come and enjoy a warm family oriented restaurant in the heart of Phoenix. Are you tired of the artificiality of the popular food chains, come and enjoy a unique experience at Angie and Jimmy's. The pizza and wings is the hot commodity, due to taste and daily specials, yet Angie and Jimmy's also provide a variety of authentic subs. The interior design is layered with precious Elvis artifacts, which is the reason why some people come in and dine in the first place. Once there however, it is hard to not fall in love with anything and everything you eat. Yet, if you can't make it to the restaurant, Angie and Jimmy's cater to parties and local events. So if you can't enjoy the fun from the inside, Angie and Jimmy's lets you take it outside."
2,2 star,"We've come here a couple times, and the experience has been fine. As I dined here tonight, not only was the service horrible but all the dishes came out cold. It's a Thursday night and I've never had such slow service. Our group ordered maybe 5 rolls total but only 3 came out. As we asked our waiter where they were, he said that the chef didn't get them all in. (Who's fault is that really?) I'm not sure if I'll dine here again after tonight. He won't be getting a good surprise at the end of tonight."
3,4 stars,"Always good lunch, excellent soup... Gooood sushi. The reason for the missing star is restroom. Clean but cheap stuff... Doesnt have to be that way. Soap n paper towel is important!"
4,4 stars,I noticed I had a low tire on the way in to town so I looked for a Discount Tire and found one just down the street. In and out in less than two minutes. One of the reasons Discount had been getting all my tire business for many years.
5,4 stars,"The Rivers is a great spot downtown to play some slots or table games! They have non-smoking sections, a wide variety of machines, beverage service, great buffets with a ton of freshly made cuisines and free parking. Everything is incredibly clean, too! The machines are usually pretty tight but it's still fun to go!"
6,5 stars,"If there were 10 Stars, I would give it 10 stars!!!!!! Thank you Thank you THANK YOU my dearest Junior for introducing me to MacAlpines! This is My absolute favorite Place. Walk In here, and feel like you have traveled back in time, so amazingly kept original! Wooden Booths with a hat rack at each end, sooo cool!! Eat their Mouthwatering food while you sing along to the coolest oldies overhead. I must say my favorite meal is a tie between their Egg Salad and Chicken Salad Sandwich... mmmmmmmm... oh and their split pea soup!!! Cocunut Egg Cream Soda!!!! AND THEY SERVE THRIFTY'S ICE CREAM!! THE BEST Banana Split you will find! And you get to shop while you wait for your food! YES! 2 rooms filled with the neatest trinkets,a room filled with amazingly classy clothes/scarves/hats/pins, a phone booth, furniature, You will want to buy it all. You MUST check this place out."
7,4 stars,"The casino is a good size, has beautiful architecture out front and elegant waterfalls inside. Since its so new, the carpeting and walls are still clean. The staff are all very helpful, and especially good with the old folk. \n\nI didn't try out any of the gambling, since I was just there to eat. I was amazed at how busy it was on an early Saturday. The place is full of depressing overweight smokers, and unlike Vegas where everyone is in a party mood, the customers here seemed lifeless. \n\nWe ate at the 24/7 cafe which has your regular fried up comfort food. I was kind of shocked at how unhealthily they were able to prepare every single dish. If butter, lard, or a fried covering could be added to the food, it was. But the prices are cheap and give you the energy to keep throwing your money at the slots for another few hours. \n\nI want to come back and try their nice looking steakhouse. \nThis place gets an extra star for being open day and night, which is rare in the East Valley."


## 数据预处理

     下载数据集到本地后，使用Tokenizer来处理文本，对于长度不等的输入数据，可以使用填充（padding）和截断（truncation）策略来处理。

     Datasets的map方法，支持一次性在整个数据集上应用预处理函数。
     


In [9]:
#  使用填充到最大长度的策略，处理整个数据集：

from transformers import AutoTokenizer

model_path = "/home/hengzq/workspace/modelscope/models/bert-base-cased"

tokenizer = AutoTokenizer.from_pretrained(model_path)

def tokenize_function(examples):
    return tokenizer(examples["text"], padding = "max_length", truncation = True)

tokenized_datasets = dataset.map(tokenize_function, batched = True)

In [10]:
show_random_elements(tokenized_datasets["train"], num_examples = 2)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,2 star,"This will absolutely be the last time I review this place- promise! They are changing their name to \""Spectrum Fitness\"", and this week is the week to sign your new contract. Meaning, whatever deal you had is null and void. Everyone has to pay $19.99 a month. No exceptions. Do not pass Go, do not collect $200.\n\nI hate to do it, because I really do love this facility, but it's not worth paying the extra money when I can still pay what I always have at Fitness Works, who took over our contracts. It's just a little annoying that they would rather lose two people (they lost my training partner too) who wanted to stay, than work out some sort of deal. Oh well, their loss. We'll see how long they last, after all it's the third company to be in the space in I think as many years.","[101, 1188, 1209, 7284, 1129, 1103, 1314, 1159, 146, 3189, 1142, 1282, 118, 4437, 106, 1220, 1132, 4787, 1147, 1271, 1106, 165, 107, 22046, 28074, 165, 107, 117, 1105, 1142, 1989, 1110, 1103, 1989, 1106, 2951, 1240, 1207, 2329, 119, 25030, 1158, 117, 3451, 2239, 1128, 1125, 1110, 26280, 1105, 13340, 119, 6064, 1144, 1106, 2653, 109, 1627, 119, 4850, 170, 2370, 119, 1302, 12408, 119, 2091, 1136, 2789, 3414, 117, 1202, 1136, 7822, 109, 2363, 119, 165, 183, 165, 183, 2240, 4819, 1106, 1202, 1122, 117, 1272, 146, 1541, 1202, 1567, 1142, 3695, 117, 1133, 1122, 112, 188, 1136, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
1,5 stars,"Had a chance to get back to Java Cat today. Seriously. If you love frozen treats and you haven't been here yet, I am shedding a tear for you. \n\nThe best gelato I've had, hands down. Both in quality, and creativity of flavors. \n\n(as a side note, the man and I split a BLT today as well, which was also yummy in that good-old-classic-sandwich kind of way).","[101, 6467, 170, 2640, 1106, 1243, 1171, 1106, 9155, 8572, 2052, 119, 18725, 119, 1409, 1128, 1567, 7958, 20554, 1105, 1128, 3983, 112, 189, 1151, 1303, 1870, 117, 146, 1821, 8478, 3408, 170, 7591, 1111, 1128, 119, 165, 183, 165, 183, 1942, 4638, 1436, 27426, 10024, 146, 112, 1396, 1125, 117, 1493, 1205, 119, 2695, 1107, 3068, 117, 1105, 17980, 1104, 16852, 1116, 119, 165, 183, 165, 183, 113, 1112, 170, 1334, 3805, 117, 1103, 1299, 1105, 146, 3325, 170, 139, 26909, 2052, 1112, 1218, 117, 1134, 1108, 1145, 194, 1818, 4527, 1107, 1115, 1363, 118, 1385, 118, 5263, 118, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


### 数据抽样

`shuffle()`函数会随机重新排列列的值。

In [11]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed = 42).select(range(1000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed = 42).select(range(1000))

## 加载模型


In [12]:
from transformers import AutoModelForSequenceClassification

model_path = "/home/hengzq/workspace/modelscope/models/bert-base-cased"

model = AutoModelForSequenceClassification.from_pretrained(model_path, num_labels = 5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at /home/hengzq/workspace/modelscope/models/bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


## 训练超参数设置

完整配置参数与默认值：[https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments](https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments)

源代码定义：[https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161](https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161)




In [13]:
from transformers import TrainingArguments

model_dir = "/home/hengzq/workspace/modelscope/models/bert-base-cased"

# logging_steps 默认值为500，根据我们的训练数据和步长，设置为100
training_args = TrainingArguments(output_dir = f"{model_dir}/test_trainer",
                                  logging_dir = f"{model_dir}/test_trainer/runs",
                                  logging_steps = 100)



In [13]:
print(training_args)

TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None, 'use_configured_state': False},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
average_tokens_across_devices=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_on_start=False,
eval_steps=None,
eval_strategy=no,
eval_use_gather_object=F

## 训练评估指标设置(Evaluate)

    Hugging Face Evaluate库 支持使用一行代码，获得数十种不同领域（自然语言处理，计算机视觉、强化学习等）的评估方法。

    完整的评估指标：[https://huggingface.co/evaluate-metric](https://huggingface.co/evaluate-metric)

    训练器（Trainer)在训练过程中不会自动评估模型性能。因此，我们需要向训练器传递一个函数来计算和报告评估。

    Evaluate库提供了一个简单的准确率函数，你可以使用`evaluate.load`函数加载



In [14]:
#  定义度量函数
def compute_accuracy(predictions, references):
    correct = sum(p == r for p, r in zip(predictions, references))
    return correct / len(predictions) if len(predictions) > 0 else 0.0

### 训练过程指标监控
    通常，为了监控训练过程中的评估指标变化，我们可以在`TrainingArguments`指定`evaluation_strategy`参数，以便在epoch结束时评估指标。

In [16]:
from transformers import TrainingArguments

training_args = TrainingArguments(output_dir = f"{model_dir}/test_trainer",
                                  logging_dir = f"{model_dir}/test_trainer/runs",
                                  logging_steps = 100)

## 实战训练




In [14]:
from transformers import Trainer

trainer = Trainer(
    model = model,
    args = training_args,
    train_dataset = small_train_dataset,
    eval_dataset = small_eval_dataset
)

In [15]:
trainer.train()

Step,Training Loss
100,1.4034
200,1.0554
300,0.7762


TrainOutput(global_step=375, training_loss=0.9796944274902344, metrics={'train_runtime': 163.9516, 'train_samples_per_second': 18.298, 'train_steps_per_second': 2.287, 'total_flos': 789354427392000.0, 'train_loss': 0.9796944274902344, 'epoch': 3.0})

### 训练状态
- 使用`trainer.save_state`方法保存训练状态

In [16]:
trainer.save_state()

### 模型保存

- 使用`trainer.save_model`方法保存模型，后续可以通过`from_pretrained()`方法重新加载


In [17]:
trainer.save_model()