# Homework: 使用完整的 YelpReviewFull 数据集训练

## 公共

In [1]:
# Shared configuration used by the cells below: local paths for the pretrained
# checkpoint, the dataset, and the fine-tuned output, plus the tokenizer length.
bert_base_cased = "/mnt/workspace/models/google-bert/bert-base-cased"
yelp_review_full = "/mnt/workspace/dataset/yelp_review_full"
output_model_dir = "/mnt/workspace/models/bert-base-cased-finetune-yelp"
# Must be an assignment: the previous `max_length==512` was a comparison, which
# raises NameError on a fresh kernel and never defines the value used by the tokenizer.
max_length = 512

## 数据集

In [135]:
import random
import pandas as pd
import datasets
from IPython.display import display, HTML

# import os
# os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'

def show_random_elements(dataset, num_examples=10):
    """Render ``num_examples`` distinct, randomly chosen rows of ``dataset`` as an HTML table.

    Args:
        dataset: Object supporting ``len()``, indexing with a list of integer
            positions, and a ``features`` mapping of column name to feature type.
        num_examples: Number of distinct rows to show; must not exceed ``len(dataset)``.

    Raises:
        AssertionError: If ``num_examples`` is larger than the dataset.
    """
    assert num_examples <= len(dataset), "Can't pick more elements than there are in the dataset."
    # random.sample draws without replacement, so no draw-and-retry loop is needed.
    picks = random.sample(range(len(dataset)), num_examples)

    df = pd.DataFrame(dataset[picks])
    for column, typ in dataset.features.items():
        if isinstance(typ, datasets.ClassLabel):
            # Show the class names instead of the integer label ids.
            df[column] = df[column].transform(lambda i: typ.names[i])
    display(HTML(df.to_html()))

In [75]:
from datasets import load_dataset

# Load the dataset from the local path configured in `yelp_review_full`.
dataset = load_dataset(yelp_review_full)

In [None]:
# Bare expression: the notebook displays the loaded dataset object.
dataset

In [None]:
# Inspect a single record: index 12 of the "train" split.
dataset["train"][12]

In [None]:
# Preview three randomly chosen rows of the train split.
show_random_elements(dataset["train"], num_examples=3)

## 预处理数据

### 预处理

In [None]:
from transformers import AutoTokenizer

# Load the tokenizer from the local checkpoint path configured in `bert_base_cased`.
tokenizer = AutoTokenizer.from_pretrained(bert_base_cased)


def tokenize_function(examples):
    """Tokenize the ``text`` field of a batch, padding to ``max_length`` with truncation enabled."""
    tokenizer_kwargs = {
        "padding": "max_length",
        "truncation": True,
        "max_length": max_length,
    }
    return tokenizer(examples["text"], **tokenizer_kwargs)


# num_proc: run the tokenization in parallel (8 worker processes requested here).
# Version without num_proc, kept for reference:
# tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True, num_proc=8)

In [None]:
# Spot-check one randomly chosen row of the tokenized train split.
show_random_elements(tokenized_datasets["train"], num_examples=1)

### 数据重排

In [141]:
# Shuffle the full train and test splits with a fixed seed (42); no subsampling is done here.
train_dataset = tokenized_datasets["train"].shuffle(seed=42)
eval_dataset = tokenized_datasets["test"].shuffle(seed=42)

## 微调训练配置

### 加载 BERT 模型

In [None]:
from transformers import AutoModelForSequenceClassification

# Load the pretrained checkpoint with a sequence-classification head.
# NOTE(review): num_labels=5 presumably matches the number of label classes in
# the dataset — confirm against dataset["train"].features.
llm_model = AutoModelForSequenceClassification.from_pretrained(bert_base_cased, num_labels=5)

### 训练超参数（TrainingArguments）

完整配置参数与默认值：https://huggingface.co/docs/transformers/v4.36.1/en/main_classes/trainer#transformers.TrainingArguments

源代码定义：https://github.com/huggingface/transformers/blob/v4.36.1/src/transformers/training_args.py#L161

**最重要配置：模型权重保存路径(output_dir)**

In [143]:
from transformers import TrainingArguments

# logging_steps defaults to 500; it is set to 100 here based on the training data and step count.
# NOTE: `training_args` is reassigned further down (evaluation_strategy="epoch",
# num_train_epochs=3, logging_steps=500) before the Trainer is built, so when the
# notebook is run top to bottom these values are not the ones used for training.
training_args = TrainingArguments(output_dir=output_model_dir,
                                  per_device_train_batch_size=16,
                                  num_train_epochs=5,
                                  logging_steps=100)

In [None]:
# Print the full hyperparameter configuration.
print(training_args)

### 训练过程中的指标评估（Evaluate）

In [147]:
import numpy as np
import evaluate

# Load the accuracy metric from a local copy of the metric script.
metric = evaluate.load("/mnt/workspace/evaluate/metrics/accuracy/accuracy.py")
# Alternative: load the metric by name instead of from the local script.
# metric = evaluate.load("accuracy")

In [148]:
def compute_metrics(eval_pred):
    """Score one evaluation pass with the loaded ``metric``.

    ``eval_pred`` unpacks into ``(logits, labels)``; the predicted class for each
    example is the index of the largest logit along the last axis.
    """
    raw_scores, gold_labels = eval_pred
    predicted_classes = np.argmax(raw_scores, axis=-1)
    return metric.compute(references=gold_labels, predictions=predicted_classes)

#### 训练过程指标监控

In [149]:
from transformers import TrainingArguments, Trainer

# Training configuration passed to the Trainer below: evaluation_strategy="epoch",
# 3 epochs, batch size 16 per device, logging every 500 steps.
# NOTE(review): no save_strategy / save_total_limit is set, so checkpointing follows
# the library defaults — confirm the disk budget before a full-dataset run.
training_args = TrainingArguments(output_dir=output_model_dir,
                                  evaluation_strategy="epoch", 
                                  per_device_train_batch_size=16,
                                  num_train_epochs=3,
                                  logging_steps=500)

## 开始训练



### 实例化训练器（Trainer）

In [None]:
# Build the Trainer from the model, the training arguments, the shuffled
# train/eval splits, and the accuracy callback defined above.
trainer = Trainer(
    model=llm_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
)

### 使用 nvidia-smi 查看 GPU 使用

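请在单独的终端中运行以下命令，实时查看 GPU 使用情况（`watch` 会持续刷新且不会自行退出，因此不建议在 notebook 单元格中以 `!watch -n 1 nvidia-smi` 的方式执行）：

```shell
watch -n 1 nvidia-smi
```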

### 启动训练

In [None]:
# Run fine-tuning with the Trainer configured above.
trainer.train()

### 抽样验证

In [133]:
# Spot-check: evaluate on the first 100 rows of the test split after shuffling with
# seed 64 (a different seed from the one used for eval_dataset).
small_test_dataset = tokenized_datasets["test"].shuffle(seed=64).select(range(100))
trainer.evaluate(small_test_dataset)

## 保存模型和训练状态

In [20]:
# Save the model through the Trainer, passing `output_model_dir` as the target.
trainer.save_model(output_model_dir)

In [21]:
# Persist the Trainer state.
# NOTE(review): presumably written under the configured output_dir — confirm.
trainer.save_state()

In [23]:
# trainer.model.save_pretrained("./")