# Getting Started with Fine-Tuning in Hugging Face Transformers

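This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics for training
- A basic introduction to the Trainer
- Hands-on training
- Saving the model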

## The YelpReviewFull Dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset Summary

The Yelp reviews dataset consists of reviews from Yelp. It was extracted from the Yelp Dataset Challenge 2015 data.

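### Supported Tasks and Leaderboards
Text classification and sentiment classification: the dataset is mainly used for text classification, i.e. predicting the sentiment of a given text.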

### Languages
The reviews are written mainly in English.

### Dataset Structure

#### Data Instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set:

```python
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

#### Data Fields

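- 'text': The review text, escaped with double quotes ("); any internal double quote is escaped by two double quotes (""). Newlines are escaped as a backslash followed by an "n" character, i.e. "\n".
- 'label': The score of the review. It is stored as an integer from 0 to 4 corresponding to 1 to 5 stars (for example, label 0 is "1 star").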
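Because the stored labels are 0-based while the star ratings are 1-based, it helps to keep the mapping explicit. The sketch below hardcodes the class names exactly as they appear in the `show_random_elements` output later in this notebook (note the dataset's own spelling "2 star"); in the `datasets` library the same lookup is provided by the `ClassLabel` feature (e.g. `int2str`). The helper names are hypothetical.

```python
# Class names as displayed in this notebook's output (assumed to match the
# dataset's ClassLabel feature): index 0 -> "1 star", ..., index 4 -> "5 stars".
LABEL_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]


def label_to_name(label: int) -> str:
    """Map a stored integer label (0-4) to its star-rating name."""
    return LABEL_NAMES[label]


def label_to_stars(label: int) -> int:
    """Map a stored integer label (0-4) to the 1-5 star score."""
    if not 0 <= label < len(LABEL_NAMES):
        raise ValueError(f"label must be in 0..4, got {label}")
    return label + 1


print(label_to_name(0), label_to_stars(0))  # 1 star 1
```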

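#### Data Splits

The Yelp reviews full-star dataset was built by randomly sampling 130,000 training examples and 10,000 test examples for each star rating from 1 to 5, for a total of 650,000 training examples and 50,000 test examples.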
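The split sizes follow directly from the per-class counts; the quick check below also shows that the classes are perfectly balanced (each star rating is 20% of each split), which is why plain accuracy is a reasonable headline metric for this dataset.

```python
STARS = 5
TRAIN_PER_STAR = 130_000
TEST_PER_STAR = 10_000

train_total = STARS * TRAIN_PER_STAR
test_total = STARS * TEST_PER_STAR
share_per_class = TRAIN_PER_STAR / train_total  # identical for the test split

print(train_total, test_total)  # 650000 50000
```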

## Downloading the Dataset

In [1]:
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")



In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [3]:
dataset["train"][11]

{'label': 0,
 'text': "This place is absolute garbage...  Half of the tees are not available, including all the grass tees.  It is cash only, and they sell the last bucket at 8, despite having lights.  And if you finish even a minute after 8, don't plan on getting a drink.  The vending machines are sold out (of course) and they sell drinks inside, but close the drawers at 8 on the dot.  There are weeds grown all over the place.  I noticed some sort of batting cage, but it looks like those are out of order as well.  Someone should buy this place and turn it into what it should be."}

In [4]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [5]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,2 star,Awesome shark tank! Great wait staff! My whole family had a great time. I had the turkey platter which tasted good. Then I spent the next 24 hours with food poisoning.
1,1 star,"we started from no clean tables available, the only one clean had no chairs, the cashier was so distracted, we payed him $7 for the special (2 pizza slices & beer) with a $20 bill he gave me back $8, when i asked him what was this he didn't even know what was that change from, after reminding him he asked me how much is 20 minus 7, i said 13, then he grabs a pack of bills from his pocket and gives me the correct change. i was never explained that the special came with only cheese pizza, i asked to have pepperoni, he mentioned if i pay $1 he will change it for me, after getting my slices one is pepperoni the other one is cheese... after paying $1 I asked him what happened and he goes \""ill give you your $1 back and keep the slices\"" the worst place in downtown... the pizza was ok, i wouldn't come back."
2,2 star,"I'm always in Chinatown and last night was the first time I've checked out Little Macau. \n\nI don't know how this place would fit into my social life. I've never seen it a full house, so its not good for people-watching. Its alright for an overpriced cocktail, but its too far away from home to be my corner bar. I paid $14 for two Tanqueray and tonics. Its a nice low key place to have a drink if you're in the neighborhood. \n\nI think they had happy hour from 10p to 1a. We ordered our first round at 950p and later ordered two more rounds and an appetizer. We didn't get any discount. \n\nThey have a small menu with dumplings ($7-8) and fried rice and a lot of Portugese-influenced dishes, like Carne de Porco with Green Beans ($13), as Macau was colonized by the Portugese. \n\nI think the bartender was saying something about the joint being built right around the time the whole no smoking in establishments that serve food law was being passed, so they came up with an agreement to outsource food from the \""Korean restaurant next door\"" (I believe he was referring to Dae Jang Keum or DJK). Its convenient for them because they don't have to worry about kitchen health codes and pay for all of the equipment and worry about taking up precious bar stool space with a kitchen. But it took longer then it should've for my friend to get her steamed pork dumplings. And the bartender didn't seem that knowledgeable about the foods. \""Everything that I've tried is good.\"" But its friggin' bar food, who cares how good it is? \n\nThe decor was nice. My friend loved the gold wall with velvet elaborations. The lounge in the back is really laid back. But there was no bar wench, so you have to order your drinks at the bar every round. \n\n* off the strip, non-casino"
3,3 stars,Great happy hour and decent food. Always a fun crowd.
4,5 stars,The hours are wrong for Sunday. Beware. The food is wonderful but yelp is posting the wrong hours on Sundays. I'm here at 7pm and this amazing restaurant is closed.\nI am sad
5,4 stars,This is one of my favorite late night spots in Vegas. I love stumbling here at 4am and get my favorite ZEN noodles! It's like PHO but with an american twist. YUM. I also love the fact that this restaurant is open 24/7!! The sandwiches here are also really good.\n\nWatch out ladies for the slippery floor! \n\nDrunkeness + slippery floor = no buenos
6,3 stars,"This Mexican restaurant is just ok. I think for the price, the portions are pretty small. Got the chicken fajitas and there wasn't much chicken. Chips and salsa are pretty good, but I wouldn't recommend overall. Much better choices close by."
7,1 star,"Was very disappointed with level of quality and service. Ordered to go and gave clear instructions to have two separate meals. Each with naan and rice and utensils for work. Arrived to find one bag. One rice and one naan. I pointed this out and they fixed it but they forgot the naan. Looked at $35!!!! Bill and saw they double charged for the naan too. The vindaloo chicken was not spicy enough and very oily. The employees spoke very little English and that was a huge barrier. They also seemed arrogant and rude. I am thinking that most of these reviews are fake or people do not know what good Indian is. This was the worst Indian I've had in Charlotte. At twice the price. I now know why this place is always empty. I have also tried the lunch buffet and while the price was better. The spice was horrible. All bland, boring curries and an always empty naan bin. Had to ask to get naan. This place is PERFECT if you have money to throw away and want mediocre Indian."
8,1 star,"When we arrived on a Saturday night at about 8:30. The place was empty. Not a good sign. As we were seated, the host said, I have a party coming later, you might want to sit far away - and we did. \n\nHowever, the food was uneventful. And nothing like I've had in Persian Resturants in the Middle East. The service was average. However, the group - of about eight men - totall destroyed the evening. They had a keyboard player/singer. As soon as they started up -- well, my ears are still ringing this morning! \n\nWill I go back. No. Service/Average. Food/Average. Pricing/Normal for off strip. Ambiance - Nothing to write home about."
9,5 stars,I was a little nervous because I just moved out here from CA and didn't have a hairstylist. I found The Root on Yelp and I loved it. It lived up to its high rating. My hair looks amazing and the prices were reasonable. I am a customer for good!


## Preprocessing the Data

Once the dataset has been downloaded, use a tokenizer to process the text. Inputs of different lengths are handled with padding and truncation strategies.
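To make padding and truncation concrete, here is a toy, pure-Python sketch of what `padding="max_length", truncation=True` produces for one already-tokenized sequence. The special-token ids 101 ([CLS]), 102 ([SEP]) and 0 ([PAD]) are taken from the `bert-base-cased` outputs shown later in this notebook; the helper name and the tiny `max_length=8` are illustrative assumptions (the real tokenizer works on raw text, inserts special tokens itself, and pads to the model's maximum length).

```python
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0  # bert-base-cased ids, as seen in this notebook's output


def pad_and_truncate(token_ids, max_length):
    """Toy version of padding="max_length", truncation=True for a single sequence."""
    body = token_ids[: max_length - 2]          # truncate, leaving room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    attention_mask = [1] * len(input_ids)       # 1 = real token the model should attend to
    n_pad = max_length - len(input_ids)
    input_ids += [PAD_ID] * n_pad               # pad up to exactly max_length
    attention_mask += [0] * n_pad               # 0 = padding, ignored by attention
    return input_ids, attention_mask


# Token ids for "Was not welcoming" (from the tokenized sample shown later).
ids, mask = pad_and_truncate([3982, 1136, 20028], max_length=8)
print(ids)   # [101, 3982, 1136, 20028, 102, 0, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 0, 0, 0]
```

Every sequence comes out with the same length, which is what allows the examples to be stacked into fixed-size tensors; the attention mask tells the model which positions are padding.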

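The Datasets `map` method applies a preprocessing function to the entire dataset in a single call. With `batched=True`, the function receives a batch of examples at a time (a dict mapping each column name to a list of values), which lets the fast tokenizer process many texts at once.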
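The following pure-Python sketch mimics the calling convention of `map(fn, batched=True)`: rows are grouped into batches, each batch is handed to `fn` in columnar form, and the columns `fn` returns are added to every row. It is an illustration under that assumption, not the library's implementation (which works on Arrow tables and caches results); the default `batch_size=1000` mirrors the default of `Dataset.map`, and `batched_map`/`add_length` are hypothetical names.

```python
def batched_map(rows, fn, batch_size=1000):
    """Toy illustration of Dataset.map(fn, batched=True) over a list of row dicts."""
    out = []
    for start in range(0, len(rows), batch_size):
        chunk = rows[start:start + batch_size]
        # Columnar batch, e.g. {"text": [...], "label": [...]}
        batch = {key: [row[key] for row in chunk] for key in chunk[0]}
        new_columns = fn(batch)                 # fn returns lists aligned with the batch
        merged = {**batch, **new_columns}
        for i in range(len(chunk)):
            out.append({key: values[i] for key, values in merged.items()})
    return out


def add_length(batch):
    """Example batched function: add one new column computed from "text"."""
    return {"n_chars": [len(text) for text in batch["text"]]}


rows = [{"text": "Great food", "label": 4}, {"text": "Too slow", "label": 1}]
print(batched_map(rows, add_length))
```

`tokenize_function` below plays the role of `add_length`: it returns `input_ids`, `token_type_ids` and `attention_mask` lists, which become new columns next to `label` and `text`.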

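Below, the whole dataset is processed with the pad-to-maximum-length strategy (without an explicit `max_length`, sequences are padded to the model's maximum input length, 512 tokens for `bert-base-cased`):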

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)



In [8]:
show_random_elements(tokenized_datasets["train"], num_examples=2)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,4 stars,"What can I say I love this place. The food is always fresh and everything is hot. I have been here a lot and love to see the chef in the dining room talking with customers. I come from an Italian family and I have to say that their eggplant parm is better than my grandmother's. I do have two favorites the lasagna and manicotti always delight. I love the views at sunset, this place has the best views of Vegas. I'm not sure how they make their sauce but I can't get enough of it. I'm not sure why people would go to an Italian restaurant and get fish so I can't really comment on the fish but I've tried most of the menu and everything has been wonderful. I don't understand how people can say the food is bland, I'm Italian from an Italian family and I've been to a lot of Italian restaurants and have eaten Italian food my whole life, this place by far has the best tasting sauce, I wish I knew thier secret. The dining room has a lot of seating and the staff is always friendly and when they mention you by name shows class. 
I'm only giving 4 stars because nobody beats my grandmother's cooking and I leave that star for her.","[101, 1327, 1169, 146, 1474, 146, 1567, 1142, 1282, 119, 1109, 2094, 1110, 1579, 4489, 1105, 1917, 1110, 2633, 119, 146, 1138, 1151, 1303, 170, 1974, 1105, 1567, 1106, 1267, 1103, 13628, 1107, 1103, 7659, 1395, 2520, 1114, 5793, 119, 146, 1435, 1121, 1126, 2169, 1266, 1105, 146, 1138, 1106, 1474, 1115, 1147, 9069, 1643, 9180, 14247, 1306, 1110, 1618, 1190, 1139, 6907, 112, 188, 119, 146, 1202, 1138, 1160, 25735, 1103, 17496, 8517, 1605, 1105, 1299, 10658, 6154, 1579, 13657, 119, 146, 1567, 1103, 4696, 1120, 16855, 117, 1142, 1282, 1144, 1103, 1436, 4696, 1104, 6554, 119, 146, 112, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"
1,4 stars,"Here's the deal. The food sounds special but it really isn't, and normally that would doom a place to me. But... I keep coming back. Why? Because I don't come here to eat amazing food. I come here to hang out in a casually stylish environment with good music and truly good coffee in a dedicated daytime establishment. VVM definitely owns its slice of the restaurant ocean by sticking to a simple and stylish formula.","[101, 3446, 112, 188, 1103, 2239, 119, 1109, 2094, 3807, 1957, 1133, 1122, 1541, 2762, 112, 189, 117, 1105, 5156, 1115, 1156, 26375, 170, 1282, 1106, 1143, 119, 1252, 119, 119, 119, 146, 1712, 1909, 1171, 119, 2009, 136, 2279, 146, 1274, 112, 189, 1435, 1303, 1106, 3940, 6929, 2094, 119, 146, 1435, 1303, 1106, 7311, 1149, 1107, 170, 13725, 188, 2340, 10550, 3750, 1114, 1363, 1390, 1105, 5098, 1363, 3538, 1107, 170, 3256, 14907, 4544, 119, 159, 2559, 2107, 5397, 8300, 1157, 16346, 1104, 1103, 4382, 5969, 1118, 14103, 1106, 170, 3014, 1105, 188, 2340, 10550, 7893, 119, 102, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...]"


### Sampling the Data

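Use a subset of the data (65,000 training examples and 5,000 evaluation examples in the cell below) to demonstrate small-scale training of BERT with the PyTorch-based `Trainer`.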

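The `shuffle()` method randomly reorders the rows of the dataset, and `select(range(n))` then keeps the first `n` rows of the shuffled result. Passing a fixed `seed` makes the sample reproducible. For more control over the shuffling algorithm, you can pass a different `numpy.random.Generator` through the `generator` parameter.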
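The idea behind `shuffle(seed=...).select(range(k))` can be sketched with the standard library: build a seeded permutation of row indices and keep the first `k`. This is only an analogy; the `datasets` library uses its own NumPy-based generator, so the actual indices it picks will differ from this sketch. `shuffle_and_select` is a hypothetical helper.

```python
import random


def shuffle_and_select(n_rows, seed, k):
    """Toy analogue of dataset.shuffle(seed=seed).select(range(k)): returns row indices."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # deterministic permutation for a given seed
    return indices[:k]                    # keep the first k rows of the shuffled order


sample = shuffle_and_select(n_rows=650_000, seed=2, k=5)
print(sample)
```

Because the permutation depends only on the seed, rerunning the notebook with the same seed selects the same subset, and a different seed (as in the homework section, which uses `seed=42`) selects a different one.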

In [9]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=2).select(range(65000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=2).select(range(5000))

## Fine-Tuning Configuration

### Loading the BERT Model

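The warning below tells us that some weights were not loaded from the checkpoint and were newly (randomly) initialized: here the classification head, `classifier.weight` and `classifier.bias`. This is entirely normal when fine-tuning. The checkpoint's pretraining head (used for masked language modeling) is not needed, and a new classification head with `num_labels=5` outputs takes its place. Since this new head has no pretrained weights, the library warns that the model should be fine-tuned before it is used for inference, which is exactly what we are about to do.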
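The newly initialized head is small. Assuming, as in `BertForSequenceClassification`, that it is a single linear layer applied to the pooled [CLS] representation, and the standard bert-base hidden size of 768, the sketch below counts its parameters and shows the computation that turns a hidden vector into 5 logits. The helper names and the tiny 2-dimensional example are illustrative.

```python
def head_param_count(hidden_size, num_labels):
    """classifier.weight has shape (num_labels, hidden_size); classifier.bias has shape (num_labels,)."""
    return num_labels * hidden_size + num_labels


def linear_head(pooled, weight, bias):
    """logits[j] = sum_i weight[j][i] * pooled[i] + bias[j]"""
    return [sum(w * x for w, x in zip(row, pooled)) + b for row, b in zip(weight, bias)]


print(head_param_count(768, 5))  # 3845

# Tiny example: a 2-dimensional "hidden state" mapped to 3 logits.
print(linear_head([1.0, 2.0], [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]], [0.0, 0.0, 1.0]))  # [1.0, 2.0, 4.0]
```

Only these few thousand weights start from random values; the roughly hundred million parameters of the BERT encoder start from the pretrained checkpoint, which is why fine-tuning converges far faster than training from scratch.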

In [10]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Training Hyperparameters (TrainingArguments)

Full list of parameters and their defaults: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**The most important setting: the directory where model weights are saved (`output_dir`)**

In [11]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased-finetune-yelp"

# logging_steps defaults to 500; given our training data size and number of steps, set it to 100
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=5,
                                  logging_steps=100)
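A useful sanity check when choosing `logging_steps` is how many optimizer steps an epoch contains. With a single GPU, `gradient_accumulation_steps=1` and `dataloader_drop_last=False` (as in the printed arguments), each step consumes one batch and the final partial batch still counts, so steps per epoch is a ceiling division. The helper below is a hypothetical sketch under exactly those assumptions; it reproduces the `global_step` values reported by the training runs in this notebook.

```python
import math


def steps_per_epoch(num_samples, batch_size):
    """Optimizer steps per epoch: one per batch, last partial batch included (single device, no accumulation)."""
    return math.ceil(num_samples / batch_size)


main_run = steps_per_epoch(65_000, 16)          # 1-epoch run in the main section
homework_run = 3 * steps_per_epoch(30_000, 16)  # 3-epoch run in the homework section

print(main_run, homework_run)  # 4063 5625
print(main_run // 100)         # number of loss log entries per epoch with logging_steps=100: 40
```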

In [12]:
# Print the full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_le

### Evaluating Metrics During Training (Evaluate)

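The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** gives you access to dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more) with a single line of code. **Full list of available metrics: https://huggingface.co/evaluate-metric**

By default, the Trainer does not evaluate the model during training, and it cannot compute task metrics such as accuracy on its own. We therefore pass it a function that computes and reports the metrics.

The Evaluate library provides a simple accuracy metric, which you can load with `evaluate.load`: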

In [13]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


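Next, call `compute` to calculate the accuracy of the predictions.

Before passing predictions to `compute`, the logits must be converted into predicted class indices, because Transformers classification models return **logits** (unnormalized scores for each class), not labels. Taking the `argmax` over the last axis picks the highest-scoring class.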

In [14]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
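For readers who want to see exactly what `compute_metrics` does, here is a dependency-free sketch of the same logic: take the argmax of each row of logits, then compute the fraction of predictions that match the labels. It is an illustration, not a replacement for `evaluate`; like `np.argmax`, the hypothetical `argmax` helper returns the first index on ties.

```python
def argmax(row):
    """Index of the largest value in a row (first one on ties, like np.argmax)."""
    return max(range(len(row)), key=lambda j: row[j])


def accuracy_from_logits(logits, labels):
    """Same result as compute_metrics above: {"accuracy": correct / total}."""
    predictions = [argmax(row) for row in logits]
    correct = sum(p == y for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}


logits = [[0.1, 2.0, -1.0],  # predicted class 1
          [3.0, 0.0, 0.0],   # predicted class 0
          [0.0, 0.0, 5.0]]   # predicted class 2
labels = [1, 0, 1]
print(accuracy_from_logits(logits, labels))
```

Two of the three predictions match their labels here, so the reported accuracy is 2/3.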

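#### Monitoring Metrics During Training

To track how the evaluation metrics change during training, set the `evaluation_strategy` parameter in `TrainingArguments`. With `evaluation_strategy="epoch"`, the metrics are computed and reported at the end of every epoch. (In newer Transformers releases this argument is named `eval_strategy`.)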

In [15]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=16,
                                  num_train_epochs=1,
                                  logging_steps=30)

## Start Training

### Instantiating the Trainer

A `kernel version` warning may appear when running this cell; it does not affect this example.

In [16]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

## Monitoring GPU Usage with nvidia-smi

To watch GPU usage in real time, use `watch` to poll `nvidia-smi` every second: `watch -n 1 nvidia-smi`. Sample output during training:

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```

In [17]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.7992 | 0.814003 | 0.6434 |


TrainOutput(global_step=4063, training_loss=0.8986900211290101, metrics={'train_runtime': 5640.9436, 'train_samples_per_second': 11.523, 'train_steps_per_second': 0.72, 'total_flos': 1.710267926016e+16, 'train_loss': 0.8986900211290101, 'epoch': 1.0})

In [18]:
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))

In [20]:
trainer.evaluate(small_test_dataset)

{'eval_loss': 0.9093865752220154,
 'eval_accuracy': 0.53,
 'eval_runtime': 2.9229,
 'eval_samples_per_second': 34.212,
 'eval_steps_per_second': 4.448,
 'epoch': 1.0}

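### Saving the Model and Training State

- Use `trainer.save_model` to save the model; it can later be reloaded with `from_pretrained()`.
- Use `trainer.save_state` to save the training state (including the log history), written as `trainer_state.json` in `output_dir`.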

In [21]:
trainer.save_model(model_dir)

In [22]:
trainer.save_state()
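The saved state is handy for finding the best evaluation result after the fact. The sketch below assumes the layout of `trainer_state.json` (a top-level `log_history` list whose evaluation entries contain keys such as `eval_accuracy` and `epoch`); the helper names are hypothetical, and the path in the usage comment simply reuses this notebook's `output_dir`.

```python
import json


def load_log_history(state_path):
    """Read the log_history list from a trainer_state.json file (assumed layout)."""
    with open(state_path, encoding="utf-8") as f:
        return json.load(f).get("log_history", [])


def best_eval_entry(log_history, metric="eval_accuracy"):
    """Return the logged evaluation entry with the highest `metric`, or None if there is none."""
    evals = [entry for entry in log_history if metric in entry]
    return max(evals, key=lambda entry: entry[metric]) if evals else None


# Usage (path assumes this notebook's output_dir):
# history = load_log_history("models/bert-base-cased-finetune-yelp/trainer_state.json")
# print(best_eval_entry(history))
```

Training-loss entries have no `eval_accuracy` key, so they are filtered out before the maximum is taken.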

In [23]:
# Alternative: save the model weights and config directly with save_pretrained
# trainer.model.save_pretrained("./")

## Homework: Train on the full YelpReviewFull dataset and see how high the accuracy can go

In [1]:
from datasets import load_dataset

dataset = load_dataset("yelp_review_full")



In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [3]:
dataset["train"][111]

{'label': 2,
 'text': "As far as Starbucks go, this is a pretty nice one.  The baristas are friendly and while I was here, a lot of regulars must have come in, because they bantered away with almost everyone.  The bathroom was clean and well maintained and the trash wasn't overflowing in the canisters around the store.  The pastries looked fresh, but I didn't partake.  The noise level was also at a nice working level - not too loud, music just barely audible.\\n\\nI do wish there was more seating.  It is nice that this location has a counter at the end of the bar for sole workers, but it doesn't replace more tables.  I'm sure this isn't as much of a problem in the summer when there's the space outside.\\n\\nThere was a treat receipt promo going on, but the barista didn't tell me about it, which I found odd.  Usually when they have promos like that going on, they ask everyone if they want their receipt to come back later in the day to claim whatever the offer is.  Today it was one of th

In [4]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [5]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [6]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,4 stars,"This was my first time visiting this bistro and it was very good. The service was on point, my water glass was never empty! That's one of my biggest pet peeves because I drink a ton of water! \nOn to the food, everything looked so good however my whole table ended up ordering the special halibut and it was almost worth the 35$ a plate price tag. Being from California I'm used to the sea food being less expensive so that was my only qualm. It was do good I could eat it every day! I personally didn't care for the calamari appetizer, the flavor was ok but the squid was a little chewy for my taste. \nI have to say this is a great little bistro with an excellent staff and great menu. I must note that I called about bringing a small cake in and they wanted to charge me $2 dollars a slice to cut it! I thought that was absurd when restaurants on the strip only charge .50 cents a slice. Either way, I will be back once my bank account recovers!"
1,4 stars,Nice aquarium. Went for the first a couple of weekends ago. Downside is that since it's part of AZ mills mall we had to stand in life which pretty much blocked the mall entrance/ exit and shoppers had to filter through the aquarium line to get by. I would recommend buying tickets in advance online plus save a few extra bucks. Advance tickets paid for go through fastlane and dont have to wait. Offered military discount as well with valid ID.
2,1 star,Yuk. Sorry but the food was awful. Tried Nicks three times now & won't be back. The service was really bad too. I really wanted to like this as it's so close to my house.
3,4 stars,"Came here with my Chinese friends for Chinese New Year. They asked for the Chinese menu. There were 11 of us and the bill was $300.00 including the tip. Everything was fantastic. Chinese broccoli, fried rice with scallops, salty fried chicken, a wonderful beef dish with beautiful steamed broccoli and a plethera of other things. I know there are two menus, English and Chinese. I have heard that there are different prices between them both. I wish I read/spoke Chinese because I am definitely going back here. My husband and I loved it. It took about 30 minutes for food to come after ordering and everything was piping hot. I took off 1 star for decor, that could be a bit better. Other than that, I'm glad I went and look forward to going back again."
4,4 stars,"Cracker Barrel has some great food at reasonable prices. \n\nIf you are used to southern cooking like I am ( My Grandma ) then this is the place for you. I love their Chicken N Dumplins !\n\nYou can not go wrong at Cracker Barrel. Yes, I'm a fan !"
5,1 star,Unprofessional employees. Completely rude on the phone and couldn't answer a simple question. 0.5/10 would not recommend unless you don't mind losing your sanity going back and forth about a simple question about your delivery.
6,5 stars,"The tacos are the best. The guacomole is very flavorful, but the chips shouldve been crispy- they were dry.\n\nOverall, I would recomend this restaurant."
7,5 stars,"Wow! Always a sure fire hit if you're entertaining guests from out of town or just want to grap some awesome asian fusion cusine! We had rolls of all variety, spicy tuna, crunch lobster roll, crab tempura. It was great. The only small thing was the music makes conversation a little labored, but not too bad really.\n\nRock on Noh!"
8,5 stars,"Okay, I LOVE this sushi restaurant. \n\nThe number of rolls were awesome, the food was fresh and well made when it came out. They had TORO (fatty tuna) at a GREAT price. ::Sigh:: Memories.\n\nI will say we ended up going for a Saturday evening and it seems our waitress forgot about us, because another one had to come take care of us the rest of the night, even though the other waitress was still there? Normally I would have knocked it down to a four star, but the Toro was so clean and at such a good price, I had some spine tingling moments with my favorite food of all time. \n\nPlease work on the service (I know it's a busy night but you should know that going in), but Chef's keep up the amazingness. : )"
9,4 stars,"Friendly service, good tasting food. Lots of TVs and good happy hour drink prices."


In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")


def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)



In [8]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,1 star,Was not welcoming...\nFifteen min for a coffee...\nNot a great value...,"[101, 3982, 1136, 20028, 119, 119, 119, 165, 183, 2271, 17368, 9561, 11241, 1111, 170, 3538, 119, 119, 119, 165, 183, 2249, 3329, 170, 1632, 2860, 119, 119, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]"


In [12]:
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(30000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(5000))

In [13]:
from transformers import AutoModelForSequenceClassification

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
from transformers import TrainingArguments

model_dir = "models/bert-base-cased-finetune-yelp"

# logging_steps defaults to 500; given our training data size and number of steps, set it to 100
training_args = TrainingArguments(output_dir=model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=5,
                                  logging_steps=100)

In [15]:
# Print the full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=1,
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_steps=None,
evaluation_strategy=no,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1,
fsdp=[],
fsdp_config={'min_num_params': 0, 'xla': False, 'xla_fsdp_grad_ckpt': False},
fsdp_min_num_params=0,
fsdp_transformer_layer_cls_to_wrap=None,
full_determinism=False,
gradient_accumulation_steps=1,
gradient_checkpointing=False,
gradient_checkpointing_kwargs=None,
greater_is_better=None,
group_by_le

In [16]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")

In [17]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

In [18]:
from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(output_dir=model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  logging_steps=30)

In [19]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics,
)

In [20]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy |
|---|---|---|---|
| 1 | 0.8929 | 0.865457 | 0.6116 |
| 2 | 0.7079 | 0.862276 | 0.6364 |
| 3 | 0.4168 | 1.0759 | 0.6312 |


Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory models/bert-base-cased-finetune-yelp/checkpoint-3000 already exists and is non-empty.Saving will proceed but

TrainOutput(global_step=5625, training_loss=0.7048972791883681, metrics={'train_runtime': 8155.3039, 'train_samples_per_second': 11.036, 'train_steps_per_second': 0.69, 'total_flos': 2.368063282176e+16, 'train_loss': 0.7048972791883681, 'epoch': 3.0})

In [21]:
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))

In [22]:
trainer.evaluate(small_test_dataset)

{'eval_loss': 1.24530827999115,
 'eval_accuracy': 0.6,
 'eval_runtime': 3.028,
 'eval_samples_per_second': 33.025,
 'eval_steps_per_second': 4.293,
 'epoch': 3.0}

In [23]:
trainer.save_model(model_dir)

In [24]:
trainer.save_state()