# Getting Started with Fine-Tuning in Hugging Face Transformers

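This example walks through the main steps of fine-tuning a model with Transformers:
- Downloading the dataset
- Preprocessing the data
- Configuring training hyperparameters
- Setting up evaluation metrics
- A brief introduction to the Trainer
- Running the training
- Saving the model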

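## The YelpReviewFull Dataset

**Hugging Face dataset: [YelpReviewFull](https://huggingface.co/datasets/yelp_review_full)**

### Dataset Summary

The Yelp reviews dataset consists of reviews from Yelp. It is extracted from the Yelp Dataset Challenge 2015 data.

### Supported Tasks and Leaderboards
Text classification, sentiment classification: the dataset is mainly used for text classification, that is, predicting the sentiment of a given text.

### Language
The reviews are written mainly in English.

### Dataset Structure

#### Data Instances
A typical data point consists of a text and its corresponding label.

An example from the YelpReviewFull test set: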

```json
{
    'label': 0,
    'text': 'I got \'new\' tires from them and within two weeks got a flat. I took my car to a local mechanic to see if i could get the hole patched, but they said the reason I had a flat was because the previous patch had blown - WAIT, WHAT? I just got the tire and never needed to have it patched? This was supposed to be a new tire. \\nI took the tire over to Flynn\'s and they told me that someone punctured my tire, then tried to patch it. So there are resentful tire slashers? I find that very unlikely. After arguing with the guy and telling him that his logic was far fetched he said he\'d give me a new tire \\"this time\\". \\nI will never go back to Flynn\'s b/c of the way this guy treated me and the simple fact that they gave me a used tire!'
}
```

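#### Data Fields

- 'text': the review text, escaped using double quotes ("); any internal double quote is escaped by two double quotes (""). Newlines are escaped by a backslash followed by an "n" character, that is, "\n".
- 'label': the review score (between 1 and 5 stars), stored as a class index from 0 ('1 star') to 4 ('5 stars').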
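
To make the field conventions concrete, here is a minimal sketch that turns the escaped sequences in a review back into real characters and maps a class index to a star count. The helper names `unescape_review` and `label_to_stars` are hypothetical, the sample string is made up, and the index order in `LABEL_NAMES` is an assumption based on the label names that appear in this notebook's outputs.

```python
# Assumed order of the class names (index 0 = "1 star").
LABEL_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]

def unescape_review(text: str) -> str:
    """Replace the literal backslash-n and backslash-quote sequences with a newline and a double quote."""
    return text.replace("\\n", "\n").replace('\\"', '"')

def label_to_stars(label: int) -> int:
    """Map a class index (0 to 4) to a star rating (1 to 5)."""
    return label + 1

# Made-up sample containing the escaped sequences described above.
raw = 'Great food.\\nWould \\"definitely\\" return.'
print(unescape_review(raw))
print(label_to_stars(0), LABEL_NAMES[0])
```

Whether to unescape before tokenizing is a design choice; the rest of this notebook feeds the text to the tokenizer as-is.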

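#### Data Splits

The Yelp reviews full star dataset is constructed by randomly taking 130,000 training samples and 10,000 testing samples for each review star from 1 to 5. In total there are 650,000 training samples and 50,000 testing samples.

## Downloading the Dataset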

In [24]:
import subprocess
import os

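# Optional and environment-specific: source /etc/network_turbo (if the host provides it)
# and copy the proxy variables it exports into this Python process.
result = subprocess.run('bash -c "source /etc/network_turbo && env | grep proxy"', shell=True, capture_output=True, text=True)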
output = result.stdout
for line in output.splitlines():
    if '=' in line:
        var, value = line.split('=', 1)
        os.environ[var] = value

In [2]:
from datasets import load_dataset
# Download the dataset; its name on the Hugging Face Hub is yelp_review_full
dataset = load_dataset("yelp_review_full")

In [3]:
# The dataset has 2 features, label and text: 650,000 training rows and 50,000 test rows
dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'text'],
        num_rows: 650000
    })
    test: Dataset({
        features: ['label', 'text'],
        num_rows: 50000
    })
})

In [4]:
dataset["train"][111]

{'label': 2,
 'text': "As far as Starbucks go, this is a pretty nice one.  The baristas are friendly and while I was here, a lot of regulars must have come in, because they bantered away with almost everyone.  The bathroom was clean and well maintained and the trash wasn't overflowing in the canisters around the store.  The pastries looked fresh, but I didn't partake.  The noise level was also at a nice working level - not too loud, music just barely audible.\\n\\nI do wish there was more seating.  It is nice that this location has a counter at the end of the bar for sole workers, but it doesn't replace more tables.  I'm sure this isn't as much of a problem in the summer when there's the space outside.\\n\\nThere was a treat receipt promo going on, but the barista didn't tell me about it, which I found odd.  Usually when they have promos like that going on, they ask everyone if they want their receipt to come back later in the day to claim whatever the offer is.  Today it was one of th

In [5]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

In [6]:
def show_random_elements(dataset, num_examples=10):
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    picks = []
    for _ in range(num_examples):
        pick = random.randint(0, len(dataset)-1)
        while pick in picks:
            pick = random.randint(0, len(dataset)-1)
        picks.append(pick)
    
    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [7]:
show_random_elements(dataset["train"])

Unnamed: 0,label,text
0,2 star,"Went in to see a friend to stand up comedy.\n\nBartender was friendly but they lacked certain spirits like Jack Daniels honey or Wild Turkey honey and the bartender only knew how to make drinks only containing two ingredients. Not as \""up to date\"" as I would have liked to see."
1,5 stars,"Found this place on the Triple D site and had to go. We were staying on the strip so took a cab. There is a little sign above a black door with a door bell button on the right. Don't need to buzz just walk in.\nWe weren't sure we were at the right place when we walked in but the staff was friendly and assured us we were. Pretty typical dive bar atmosphere, the owner is from Buffalo so there was a ton of Buffalo stuff on the walls which was nice as that is the area we are from. We ordered two wee's to split between four of us: \n\nThe Ginuea: This was featured on the show so we tried it. Fantastic!! The sauce was what put it over the top for me.\n\nThe Margarita: I was a little disappointed in the lack of sauce and basil(I was just expecting/wanted more) but every one thought it was good.\n\nWe finished with the fried dough: powdered sugar, chocolate sauce and fried=mmmmm.\n\nWe inquired about getting a cab to get back to the strip and the waitress called one up for us. I told her I would include her name and I forgot it (I'm very sorry!) but she was great. If I go back to Vegas I'll definitely go back. Even with the bar bill and tip it was cheaper than anywhere else we ate and as good or better."
2,3 stars,"Unsure weather to give this place 2 or 3 stars. \nJust go back after staying here for 5 nights. Had a room on the 10th floor which had an OK view of the strip, but a little low for the Belligio show, but room was comped, so who am I to argue. \nDecent sized room, a nice bathroom ( I like bathrooms with separate showers) and nice counter space. \nNow the bad. Shower had a low flow head which sucked, plus it was unadjustable so you couldn't move the spray away from you. The blow dryer had so much lint in it that it barely blew any air. The shampoo at all Harrahs properties is watery just not good. The TV is an old crt with very limited number of channels to choose from. But worst thing of all, as reviewer DL stated, both my wife and I had weird bite marks on out lower legs at the end of the trip. I pulled back the sheets and looked closely for bugs, but didn't see any.( I don't know how tiny bedbugs are) Just seem like a strange coincidence with the DL story. \nAside from the room, it has a nice selection of restaurants, a decent casino, and is centrally located on the strip so you can go over to Caesars or Belliago without too much effort."
3,5 stars,"We were going out on a very busy night, and decided to go to our local Streets location for dinner. The service was excellent. They introduced a new item on the menu - the Broasted Chicken. My husband and I are both big fans of good, fresh fried chicken, and we're really not into KFC - so we decided to give this a try - and we've been back 4 times for the chicken since! It's crispy on the outside, very lightly coated, and juicy on the inside - and not greasy! It's pressure cooked, so you don't have the grease of typical fried chicken. We've invited friends there and they raved about it as well. Give Streets of New York at Ray and Kyrene a chance - the chicken is what you'll be back for, time and again!"
4,1 star,"DO NOT PATRONIZE THIS SHOP!!\n\nI drive a Dodge Charger and I have been having some trouble with a shimmy in my steering wheel. The first time I went in, I had them balance all 4 tires. I was very wary because when I was waiting for them to write up the service order, the attendant didn't know how to ring in the service. When he asked his associate, that second guy told him to write it up at a lower price than the guy I was talking to quoted me. When I asked, he gave me some excuse that right now I don't even remember. It wasn't too big of a difference, but I still remember feeling a little ripped off because of course I was made to pay the higher price. \n\nThe real problem happened when I returned a couple weeks later to get an alignment as the balancing did not fix the problem. All seemed to be going well until I was about to be picked up by my girlfriend. The mechanic who was going to test drive the car backed out of the parking spot I was in. She was pulling into the parking lot at the same time coming down the aisle. The mechanic proceeded to hammer the gas pedal, squealing the tires and almost hitting my girlfriend head on as he maneuvered around her and out of the parking lot. Needless to say I was furious and confronted the manager. He assured me that it would be handled. \n\nWhen we were leaving, the mechanic pulled into the parking lot making an ILLEGAL left turn into the parking lot which was blocked by a median. Instead of going to the nearest light and making a safe and legal u-turn, he decided to make the left heading up the wrong side of the road. I was livid at this point. I approached the mechanic in my car and his response was that he did not squeal the tires and that he makes that turn all the time basically asking me what the big deal was. The manager came up to me giving me his best effort to rectify the situation by promising me a free oil change on my next visit. REALLY?!?!?!? 
I told him that I perform my own oil changes and just to fix my car.\n\nWhen I went back to pick up my car, the manager continued to assure me that everything would be taken care of and offered apologies. He gave me his best effort I guess, but his resolution was to charge me full price for the awful experience I had just gone through and to give me a free oil change on my next visit which will NEVER happen even after I told him that I change my own oil. \n\nWell, my shimmy problem is still not fixed, which there might be something wrong with the rotors, still haven't figured that out. I don't hold the continued problem against them, they performed the work I asked them to do, but of course there was no diagnosis of the actual problem or extra effort. Kinda feels like I just got a hardly satisfying haircut at Super Clips or something. Their only concern is getting cars in and out of the shop, not developing lifelong customers. I received no comfort by the actions of their mechanic while driving my car. If I saw him peel out in my car, almost hit my girlfriend, and perform an illegal left turn in the 5 seconds I saw him driving my car, what kind of driving maneuvers is he doing while he's out of sight and what other corners is he cutting while performing work on my car?\n\nI will never be going back to this shop or any other Tire Works shops for any service even if it is given to me for free. AWFUL!!"
5,5 stars,"Love this theater! It's clean, great service, popcorn hot & perfect. Perfect date night."
6,5 stars,"Still one of my favorites. Don't waste your time at any other spa on the strip. I have yet to find one that surpasses this, and I've tried most.\n\nMake sure you save time to enjoy the variety of baths. I could fall asleep in that room on the chairs for hours."
7,3 stars,"I was stressed out at the beginning of my wait at the airport and all I wanted was a Bloody Mary. Luckily I hatched that thought right outside of Home Turf Sports Bar which advertised \""mile high bloody mary's\"". As I walked in I found it to be messy and confusing. It's really unclear as to whether you should sit down and be waited on or go to the bar and order. This fact is exacerbated by the fact that they have a standing room only bar 5 feet from the bar they serve alcohol at which makes it hard to tell if there's a line or just people chilling. \n Aside from that I got a double Bloody Mary for 10 bucks, not too bad for an airport, but it definitely wasn't a mile high or anything that should be boasting awesomeness. I could have made just as good of one at home. Honestly. That being said... I was glad to have my drink and a place to rest my hat for a bit. Definitely 3 star worthy."
8,3 stars,"It wasn't busy at all when I went in. The hostess was nice and got me right to a seat. These servers here are quick. Soon as I picked up my menu the girl was asking what drink I wanted. I asked her to come back seeing that they have a lot of chocolate drinks, and I wanted some extra calories today. \n\nSo the prices turned me off about those good yummy sounding drinks, and I just got a coke, which was pricey $2.25 for a diet coke sucks. \nI didn't want to spend too much, and It looked like I had no choice. I had a lot of good choices to eat. The menu items sounded so good, and I went with the Really Crunchy Mac n Cheese. \n\nI drank my coke quickly, and they asked me several times if I needed a refill but I really actually was worried that they would charge me an extra $2.25 for it, so I said no. \n\nThe meal was ok.. I mean it was good don't get me wrong, but not the best I tried. Soon half way in the meal I bite on some seasoning, they had in the mac and cheese, and I hated it. I have tasted this before in another meal once, and I don't know what it is! I changed my mind as I was eating it and instend of ok..it was now bad because of that one little seasoning. I just stopped eating and hurried to check the menu for a drink, they could get the taste out of my mouth because it makes me sick to my stomach. \n\nI got a cookie shake to go that cost $7.75.\nThat was the best part of this whole meal was that shake! It was soo good! It had white chocolate cream and Oreo cookies! I would come back just for that shake! \n\nI think this place was decent. I just picked the wrong meal for me. If I come back, I will try something with less ingredients. \n\nMy whole meal was about $25.00, and that's not including the tip, and that was just for one individual. If you arrive with friends and family, you know where the price will be."
9,1 star,"It's always fun to see new restaurants open in the Burgh. Not so much this one. I was there with a friend, both of us really hungry. We ordered soup that took an hour to be served. When we \""reminded\"" our server we were told they were swamped. I counted 12 other people. I asked for no onions on my chili, but it was covered in onions. Maybe it was to hide the taste. I ordered catfish which I couldn't eat. I saw two other people send their food back. My soft drink glass was never refilled. The bill for 2 people was over $50. That was for soup and an entree'. When we left the hostess was busy chatting on her cell phone. This is a miss. Much better places to eat in the Waterfront."


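## Preprocessing the Data

Once the dataset has been downloaded locally, use a tokenizer to process the text. Inputs of unequal length can be handled with padding and truncation strategies.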
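
As a rough illustration of what padding and truncation do to a sequence of token ids, here is a toy sketch in pure Python. It is not the tokenizer's implementation: `pad_or_truncate` is a hypothetical helper, and unlike a real BERT tokenizer it does not keep the closing `[SEP]` token when it cuts a sequence. The pad id of 0 is assumed to match the zeros visible in the tokenized output of this notebook.

```python
from typing import List, Tuple

PAD_ID = 0  # assumed id of the [PAD] token

def pad_or_truncate(ids: List[int], max_length: int, pad_id: int = PAD_ID) -> Tuple[List[int], List[int]]:
    """Return (input_ids, attention_mask), both exactly max_length long."""
    ids = ids[:max_length]                # truncation: drop everything past max_length
    attention_mask = [1] * len(ids)       # 1 marks real tokens
    n_pad = max_length - len(ids)         # padding: fill the remainder
    return ids + [pad_id] * n_pad, attention_mask + [0] * n_pad

print(pad_or_truncate([101, 1188, 102], max_length=5))
print(pad_or_truncate([101, 1188, 1110, 170, 102], max_length=3))
```

The attention mask is what lets the model ignore the padded positions.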

The `map` method of Datasets applies a preprocessing function to the entire dataset in one pass.
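
With `batched=True`, the preprocessing function receives a dictionary of columns in which each value is a list covering several rows, and it returns new columns of the same length. The sketch below mimics that contract in pure Python; `batched_map` is a hypothetical stand-in for illustration, not the library's code, and the tiny batch size is chosen only to make the batching visible.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List]

def batched_map(columns: Batch, fn: Callable[[Batch], Batch], batch_size: int = 2) -> Batch:
    """Apply fn to slices of the columns and merge the columns it returns."""
    n_rows = len(next(iter(columns.values())))
    new_columns: Batch = {}
    for start in range(0, n_rows, batch_size):
        # Each batch is a dict of lists, one list per column.
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            new_columns.setdefault(name, []).extend(values)
    return {**columns, **new_columns}

def count_words(examples: Batch) -> Batch:
    # examples["text"] is a list of strings, one per row in the batch.
    return {"n_words": [len(text.split()) for text in examples["text"]]}

toy = {"text": ["good food", "bad", "would come back again"], "label": [3, 0, 4]}
result = batched_map(toy, count_words)
print(result["n_words"])
```

This is why a tokenization function can pass `examples["text"]` straight to the tokenizer: it is already a list of strings.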

The following pads every sequence to the maximum length and processes the whole dataset:

In [8]:
from transformers import AutoTokenizer
# Use the bert-base-cased model
tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")

# Preprocessing function: tokenize the text
def tokenize_function(examples):
    # Tokenize the text stored under the "text" key of the examples dict
    # padding="max_length" pads every sequence to the maximum length (512 for bert-base-cased)
    # truncation=True truncates sequences that exceed the maximum length
    return tokenizer(examples["text"], padding="max_length", truncation=True)


tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [9]:
show_random_elements(tokenized_datasets["train"], num_examples=1)

Unnamed: 0,label,text,input_ids,token_type_ids,attention_mask
0,4 stars,"Great pizza! I normally travel all the way out to Coolidge for authentic New York style pizza. NYPD will not replace it, but it certainly will be a regular stop for me.","[101, 2038, 13473, 106, 146, 5156, 3201, 1155, 1103, 1236, 1149, 1106, 13297, 15091, 1111, 16047, 1203, 1365, 1947, 13473, 119, 5883, 15481, 1209, 1136, 4971, 1122, 117, 1133, 1122, 4664, 1209, 1129, 170, 2366, 1831, 1111, 1143, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...]"


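### Sampling the Data

A subset of 1,000 samples is enough to demonstrate small-scale training on BERT with the PyTorch Trainer. The cell below instead selects the full training and test splits; lower the `range(...)` arguments (for example to 1000) for a quick run.

The `shuffle()` method randomly reorders the rows of the dataset. For more control over the shuffling algorithm, pass the `generator` parameter to use a different `numpy.random.Generator`.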
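
Conceptually, shuffling with a fixed seed and then selecting the first k rows draws a reproducible random subset. The sketch below shows the idea with the standard library only; `shuffled_subset` is a hypothetical helper, and because it uses Python's `random` module rather than NumPy, the indices it picks will differ from the ones Datasets would pick for the same seed.

```python
import random
from typing import List

def shuffled_subset(n_rows: int, k: int, seed: int = 42) -> List[int]:
    """Return the first k row indices of a seeded permutation of range(n_rows)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)  # a private generator keeps the global state untouched
    return indices[:k]

first = shuffled_subset(650_000, 1000)
second = shuffled_subset(650_000, 1000)
print(len(first), first == second)
```

Fixing the seed means reruns of the notebook train and evaluate on the same rows.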

In [10]:
# Shuffle, then select the full dataset
small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(650000))
small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(50000))

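## Fine-Tuning Configuration

### Loading the BERT Model

Loading the checkpoint prints a warning that some weights are newly initialized (the `classifier.weight` and `classifier.bias` tensors of the classification head); the weights of the pretraining heads in the checkpoint are discarded. This is perfectly normal when fine-tuning: we remove the head used for the pretraining task (masked language modeling) and replace it with a new head for which there are no pretrained weights. The library therefore warns that the model should be fine-tuned before being used for inference, which is exactly what we are about to do.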

In [11]:
from transformers import AutoModelForSequenceClassification
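# AutoModelForSequenceClassification loads a pretrained model with a sequence classification head. Sequence classification tasks include sentiment analysis, text classification, and opinion mining: the input is a text sequence and the output is a class label for that sequence.
# num_labels=5 sets the number of output classes, one per star rating (it does not control how many layers are loaded)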
model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=5)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
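
The two tensors named in the warning make up a single linear layer on top of the encoder. As a back-of-the-envelope check, assuming the usual hidden size of 768 for bert-base models, the new head is very small compared with the rest of the network:

```python
hidden_size = 768  # assumed hidden size of bert-base-cased
num_labels = 5     # one output per star rating

weight_shape = (num_labels, hidden_size)  # expected shape of classifier.weight
bias_shape = (num_labels,)                # expected shape of classifier.bias
n_new_params = num_labels * hidden_size + num_labels

print(weight_shape, bias_shape, n_new_params)
```

Only these few thousand parameters start from random values; everything else is loaded from the checkpoint.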


In [12]:
inputs = tokenizer("This is a test sentence.",return_tensors="pt")

In [13]:
inputs

{'input_ids': tensor([[ 101, 1188, 1110,  170, 2774, 5650,  119,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1]])}

In [14]:
outputs  = model(**inputs)

In [15]:
outputs

SequenceClassifierOutput(loss=None, logits=tensor([[-0.9853,  0.3214,  0.1799,  0.2356,  0.2071]],
       grad_fn=<AddmmBackward0>), hidden_states=None, attentions=None)

In [16]:
import torch
logits = outputs.logits
predictions = torch.argmax(logits, dim=-1)
print(predictions)

tensor([1])
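
To read the logits above as class probabilities, apply a softmax and then take the argmax. The sketch below redoes that step with the standard library on the logits printed above; the order of `LABEL_NAMES` is an assumption based on the label names shown in the dataset previews.

```python
import math
from typing import List

LABEL_NAMES = ["1 star", "2 star", "3 stars", "4 stars", "5 stars"]  # assumed index order
logits = [-0.9853, 0.3214, 0.1799, 0.2356, 0.2071]  # copied from the output above

def softmax(xs: List[float]) -> List[float]:
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax(logits)
pred = max(range(len(probs)), key=probs.__getitem__)  # index of the largest probability
print(pred, LABEL_NAMES[pred])
```

The index is 1, matching `tensor([1])`. Since the classification head is still randomly initialized at this point, the prediction carries no meaning yet.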


### Training Hyperparameters (TrainingArguments)

Full list of parameters and defaults: https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

Source definition: https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**Most important setting: the path where model weights are saved (output_dir)**

In [17]:
from transformers import TrainingArguments

# Directory where the model is saved
model_dir = "models/bert-base-cased-finetune-yelp-0716"

# Initialize TrainingArguments, which specifies the parameters of the training run
training_args = TrainingArguments(
    output_dir=model_dir,  # output directory for the trained model and log files
    per_device_train_batch_size=32,  # batch size per GPU/CPU during training
    num_train_epochs=5,  # total number of training epochs
    logging_steps=100  # log every 100 steps (the default is 500)
)


In [18]:
# The full hyperparameter configuration
print(training_args)

TrainingArguments(
_n_gpu=1,
accelerator_config={'split_batches': False, 'dispatch_batches': None, 'even_batches': True, 'use_seedable_sampler': True, 'non_blocking': False, 'gradient_accumulation_kwargs': None},
adafactor=False,
adam_beta1=0.9,
adam_beta2=0.999,
adam_epsilon=1e-08,
auto_find_batch_size=False,
batch_eval_metrics=False,
bf16=False,
bf16_full_eval=False,
data_seed=None,
dataloader_drop_last=False,
dataloader_num_workers=0,
dataloader_persistent_workers=False,
dataloader_pin_memory=True,
dataloader_prefetch_factor=None,
ddp_backend=None,
ddp_broadcast_buffers=None,
ddp_bucket_cap_mb=None,
ddp_find_unused_parameters=None,
ddp_timeout=1800,
debug=[],
deepspeed=None,
disable_tqdm=False,
dispatch_batches=None,
do_eval=False,
do_predict=False,
do_train=False,
eval_accumulation_steps=None,
eval_delay=0,
eval_do_concat_batches=True,
eval_steps=None,
eval_strategy=IntervalStrategy.NO,
evaluation_strategy=None,
fp16=False,
fp16_backend=auto,
fp16_full_eval=False,
fp16_opt_level=O1

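### Evaluating Metrics During Training (Evaluate)

The **[Hugging Face Evaluate library](https://huggingface.co/docs/evaluate/index)** provides dozens of evaluation methods across domains (natural language processing, computer vision, reinforcement learning, and more), each loadable with a single line of code. The **full list of supported metrics is at https://huggingface.co/evaluate-metric**.

The Trainer does not automatically evaluate model performance during training, so we need to pass it a function that computes and reports metrics.

The Evaluate library provides a simple accuracy function, which you can load with `evaluate.load`: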

In [19]:
import numpy as np
import evaluate

metric = evaluate.load("accuracy")


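Next, call the metric's `compute` function to calculate the accuracy of the predictions.

Before passing predictions to `compute`, we need to convert the logits into predicted class ids (**all Transformers models return logits**).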

In [20]:
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)
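
For intuition, accuracy is just the fraction of argmax predictions that match the labels. Here is a dependency-free sketch of the same computation; `accuracy_from_logits` is a hypothetical helper and not the Evaluate implementation, and the toy logits and labels are made up.

```python
from typing import Dict, List

def accuracy_from_logits(logits: List[List[float]], labels: List[int]) -> Dict[str, float]:
    """Argmax each row of logits, then return the share of correct predictions."""
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    correct = sum(int(p == y) for p, y in zip(predictions, labels))
    return {"accuracy": correct / len(labels)}

toy_logits = [[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 3.0], [0.9, 0.8, 0.1]]
toy_labels = [1, 0, 2, 1]
print(accuracy_from_logits(toy_logits, toy_labels))
```

Three of the four toy predictions are correct, so the result is `{'accuracy': 0.75}`, the same dictionary shape the Trainer expects back from `compute_metrics`.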

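#### Monitoring Metrics During Training

To monitor how the evaluation metrics change during training, set the `evaluation_strategy` parameter in `TrainingArguments` so that metrics are reported at the end of each epoch. Recent Transformers releases rename this parameter to `eval_strategy` and deprecate `evaluation_strategy`.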

In [21]:
from transformers import TrainingArguments, Trainer

# Initialize TrainingArguments, the set of parameters that configures the training run
training_args = TrainingArguments(
    output_dir=model_dir,  # output directory for the trained model and log files
    evaluation_strategy="epoch",  # evaluate at the end of every training epoch
    per_device_train_batch_size=32,  # samples processed per device (GPU or CPU) per step; affects throughput and memory use
    num_train_epochs=3,  # number of full passes over the training set
    logging_steps=100,  # log every 100 training steps
    save_steps=5000,  # save a checkpoint every 5000 training steps
    save_total_limit=10,  # keep at most 10 checkpoints
    learning_rate=1e-3,  # note: far above the 5e-5 default and the 2e-5 to 5e-5 range commonly used for BERT
    weight_decay=0,
    warmup_ratio=0.1,  # linear warmup over the first 10% of training steps
)
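
It helps to translate these settings into step counts before launching a long run. The arithmetic below is a back-of-the-envelope sketch that assumes a single device, no gradient accumulation, and that the last partial batch is kept; the variable names are illustrative only.

```python
import math

n_train = 650_000        # rows in the training split
per_device_batch = 32
n_devices = 1            # assumption: a single GPU
epochs = 3
warmup_ratio = 0.1
save_steps = 5000
save_total_limit = 10

steps_per_epoch = math.ceil(n_train / (per_device_batch * n_devices))
total_steps = steps_per_epoch * epochs
warmup_steps = math.ceil(total_steps * warmup_ratio)
checkpoints_written = total_steps // save_steps
checkpoints_kept = min(checkpoints_written, save_total_limit)

print(steps_per_epoch, total_steps, warmup_steps, checkpoints_written, checkpoints_kept)
```

Under these assumptions an epoch is 20,313 steps, the run is 60,939 steps with roughly 6,094 warmup steps, and of the 12 periodic checkpoints only the 10 most recent are kept.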



## Start Training

### Instantiating the Trainer

A `kernel version` warning may appear here; it does not currently affect running this example.

In [22]:
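trainer = Trainer(
    model=model,  # the model to fine-tune
    args=training_args,  # training arguments
    train_dataset=small_train_dataset,  # training dataset
    eval_dataset=small_eval_dataset,  # evaluation dataset
    compute_metrics=compute_metrics,  # metric function used during evaluation
)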

## Monitoring GPU Usage with nvidia-smi

To watch GPU usage in real time, poll with the `watch` command: `watch -n 1 nvidia-smi`:

```shell
Every 1.0s: nvidia-smi                                                   Wed Dec 20 14:37:41 2023

Wed Dec 20 14:37:41 2023
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03             Driver Version: 535.129.03   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  Tesla T4                       Off | 00000000:00:0D.0 Off |                    0 |
| N/A   64C    P0              69W /  70W |   6665MiB / 15360MiB |     98%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+

+---------------------------------------------------------------------------------------+
| Processes:                                                                            |
|  GPU   GI   CI        PID   Type   Process name                            GPU Memory |
|        ID   ID                                                             Usage      |
|=======================================================================================|
|    0   N/A  N/A     18395      C   /root/miniconda3/bin/python                6660MiB |
+---------------------------------------------------------------------------------------+
```

In [23]:
trainer.train()
trainer.save_model(model_dir)
trainer.save_state()

Epoch,Training Loss,Validation Loss,Accuracy
1,1.6103,1.609557,0.2
2,1.6092,1.60945,0.2



KeyboardInterrupt
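
The logged numbers are worth a sanity check. A 5-class model that assigns equal probability to every class has a cross-entropy loss of ln 5, and on a test set with equally many reviews per star rating, a model that always predicts the same class scores 1/5. The sketch below computes both reference values; reading the run this way is an interpretation, not something the training log states.

```python
import math

num_classes = 5

# Cross-entropy of a uniform prediction: -log(1 / num_classes)
uniform_loss = -math.log(1.0 / num_classes)

# Accuracy of a constant prediction on a class-balanced evaluation set
chance_accuracy = 1.0 / num_classes

print(round(uniform_loss, 4), chance_accuracy)
```

This prints `1.6094 0.2`, which lines up with the validation loss and accuracy in the table above and suggests the model is not learning. One plausible suspect is `learning_rate=1e-3`, which is much higher than the values typically used to fine-tune BERT.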



In [None]:
# Load the test set (defined here so that the evaluate call in the next cell can run)
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(50000))

In [None]:
trainer.evaluate(small_test_dataset)

### Saving the Model and Training State

- Use `trainer.save_model` to save the model; it can later be reloaded with `from_pretrained()`
- Use `trainer.save_state` to save the training state

In [None]:
# trainer.save_model(model_dir)

In [None]:
# trainer.save_state()

In [None]:
# trainer.model.save_pretrained("./")

In [None]:
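# Load the fine-tuned model
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Path where the model was saved: model_dir, defined earlier


# Load the tokenizer and the model
# Optional: uncomment the next line if a tokenizer was saved to model_dir; otherwise the tokenizer loaded earlier is reused
#tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

# Put the model in evaluation mode
model.eval()

# The model is now loaded and ready for inference or further processing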


In [None]:
# Input text
text = "Here is some text to classify."

# Tokenize the text
encoded_input = tokenizer(text, return_tensors='pt')

# Run the model without tracking gradients
with torch.no_grad():
    outputs = model(**encoded_input)

# Get the predicted class index
predictions = torch.argmax(outputs.logits, dim=-1)
print(predictions)


## Homework: Train on the full YelpReviewFull dataset and see how high the accuracy can go