In [None]:
# Step 1. Import the required packages
from transformers import AutoTokenizer, AutoModelForSequenceClassification,Trainer, TrainingArguments
from datasets import load_dataset

# Step 2. Load the dataset
# Read the local CSV as a single "train" split, then drop every row whose
# "review" field is missing so later steps never receive None as text.
dataset = load_dataset(
    "csv",
    data_files="./ChnSentiCorp_htl_all.csv",
    split="train",
).filter(lambda row: row["review"] is not None)
dataset

In [None]:
# Step 3. Split the dataset
# Hold out 10% of the rows as the "test" split. The fixed seed makes the
# shuffle deterministic, so the held-out rows (and therefore the reported
# metrics) are comparable from one run to the next.
datasets = dataset.train_test_split(test_size=0.1, seed=42)
datasets

In [None]:
# Step 4. Data preprocessing

import torch
# Load the tokenizer published with the "hfl/rbt3" checkpoint; the same
# checkpoint name is used when the model is created in Step 5.
tokenizer = AutoTokenizer.from_pretrained("hfl/rbt3")

def process_function(examples):
    """Tokenize a batch of reviews and attach their class labels.

    Args:
        examples: A batch mapping column names to lists of values. It must
            contain a "review" column and a "label" (or "labels") column.

    Returns:
        The tokenizer output for the reviews, truncated to at most 128
        tokens (no padding is requested here), with an added "labels"
        entry holding the class labels.
    """
    tokenized_examples = tokenizer(examples["review"], max_length=128, truncation=True)
    # The previous lookup only tried "labels". NOTE(review): the ChnSentiCorp
    # hotel CSV is believed to name this column "label" -- confirm against the
    # file header. Accepting either spelling avoids a KeyError on that point.
    label_column = "label" if "label" in examples else "labels"
    tokenized_examples["labels"] = examples[label_column]
    return tokenized_examples

# Apply process_function to the split datasets one batch at a time. The
# original columns are listed for removal, so the result is expected to keep
# only the entries that process_function returns.
tokenized_datasets = datasets.map(
    process_function,
    batched=True,
    remove_columns=datasets["train"].column_names,
)
tokenized_datasets

In [None]:
# Step 5. Create the model
# Load the "hfl/rbt3" checkpoint through the sequence-classification auto class.
model = AutoModelForSequenceClassification.from_pretrained("hfl/rbt3")

# Step 6. Create the evaluation function
import evaluate
# Metric objects consumed by eval_metric below.
acc_metrics = evaluate.load("accuracy")
f1_metrics = evaluate.load("f1")

def eval_metric(eval_predict):
    """Compute accuracy and F1 for one evaluation pass.

    Args:
        eval_predict: A pair ``(predictions, labels)``. ``predictions`` must
            support ``argmax(axis=-1)`` (presumably the per-class model
            outputs -- confirm against the Trainer); ``labels`` are the
            reference class ids.

    Returns:
        A single dict holding the accuracy metric's results followed by the
        F1 metric's results.
    """
    raw_outputs, references = eval_predict
    # Collapse the last axis to the index of the highest-scoring class.
    predicted_ids = raw_outputs.argmax(axis=-1)

    scores = {}
    # Accuracy first, then F1, merged into one result dict.
    for metric in (acc_metrics, f1_metrics):
        scores.update(metric.compute(predictions=predicted_ids, references=references))
    return scores

# Step 7. Create the TrainingArguments
# Log every 10 steps, evaluate every 110 steps, and save a checkpoint once per
# epoch, keeping at most 3 of them under ./checkpoints.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy` -- confirm which spelling the installed version accepts.
train_args = TrainingArguments(
    output_dir="./checkpoints",
    per_device_train_batch_size=64,
    per_device_eval_batch_size=128,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=110,
    save_strategy="epoch",
    save_total_limit=3,
    learning_rate=2e-5,
    weight_decay=0.01,
)
train_args

# Step 8. Create the Trainer
from transformers import DataCollatorWithPadding
# The examples were tokenized without padding, so a padding collator is
# supplied to assemble each batch.
trainer = Trainer(
    model=model,
    args=train_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
    compute_metrics=eval_metric,
)

# Step 9. Train the model
trainer.train()

# Step 10. Evaluate the model
# No dataset is passed here; presumably the eval_dataset given to the Trainer
# is used -- confirm against the Trainer documentation.
trainer.evaluate()

# Step 11. Run prediction on the held-out "test" split
trainer.predict(tokenized_datasets["test"])
