In [None]:
import json
import os
from pprint import pprint

from paddlenlp import Taskflow
from paddlenlp.trainer import TrainingArguments

# ----------------------------------
# Configuration
# ----------------------------------
# Extraction schema: the opinion word plus a sentiment polarity chosen from
# the four labels listed inside the brackets.
schema = ["观点词", "情感倾向[正向,负向,中性,未提及]"]

# Aspect terms are read from a plain-text file, one term per line. The file
# is closed deterministically and blank lines are skipped so that no empty
# aspect string is handed to the model.
with open("aspect_word.txt", encoding='utf-8') as aspect_file:
    aspects = [line.strip() for line in aspect_file if line.strip()]

# Dataset locations and the directory where checkpoints are written.
train_data = "./data/train.json"
dev_data = "./data/dev.json"
output_dir = "./checkpoints"

# ----------------------------------
# Step 1: Build the base sentiment-analysis pipeline
# ----------------------------------
# The pipeline is given the schema and the aspect list defined in the
# configuration section above.
# NOTE(review): confirm that Taskflow accepts a `device` keyword; it is
# passed through unchanged here.
senta = Taskflow(
    "sentiment_analysis",
    model="uie-senta-base",
    device='gpu',
    schema=schema,
    aspects=aspects,
)

# ----------------------------------
# Step 2: Fine-tune the model
# ----------------------------------
# Training hyper-parameters, grouped by concern.
training_args = TrainingArguments(
    output_dir=output_dir,
    # Optimisation schedule.
    num_train_epochs=10,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=32,
    # Step-based cadence: log every 50 steps, evaluate every 100 steps,
    # save a checkpoint every 200 steps.
    logging_steps=50,
    evaluation_strategy="steps",
    eval_steps=100,
    save_steps=200,
    # Reload the best checkpoint when training ends; "best" means the
    # lowest eval_loss, hence greater_is_better=False.
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# Launch fine-tuning on the training set, evaluating against the dev set.
# NOTE(review): verify that the Taskflow object actually exposes a
# `fine_tune` method with these keyword arguments in the installed
# PaddleNLP version.
senta.fine_tune(
    training_args=training_args,
    train_data_path=train_data,
    eval_data_path=dev_data,
)

# ----------------------------------
# Step 3: Save the model
# ----------------------------------
# The model is written to a "best_model" sub-directory of the checkpoint
# directory. This relies on `os` being imported at the top of the file.
# NOTE(review): verify that the Taskflow object exposes `save_model` in the
# installed PaddleNLP version.
best_model_path = os.path.join(output_dir, "best_model")
senta.save_model(best_model_path)
print(f"最佳模型已保存至：{best_model_path}")

# ----------------------------------
# Step 4: Load the fine-tuned model and run inference
# ----------------------------------
# Load the weights from the directory written in Step 3 (`best_model_path`).
# The model name and schema mirror Step 1 so the checkpoint is loaded with
# the same architecture and extraction targets it was built with.
# NOTE(review): confirm against the PaddleNLP docs that `model` and `schema`
# are required alongside `task_path` for uie-senta checkpoints.
custom_senta = Taskflow(
    task="sentiment_analysis",
    model="uie-senta-base",
    schema=schema,
    task_path=best_model_path,
    aspects=aspects,
    device='gpu'
)

# Smoke test on a single sentence.
text = "医生服务态度一般"
result = custom_senta(text)
print("\n测试结果：")
# NOTE(review): this loop assumes each result item is a dict with 'aspect'
# and 'sentiment' keys (and optionally 'opinion'); confirm against the
# actual output format of the pipeline.
for item in result:
    print(f"{item['aspect']}: {item['sentiment']} ({item.get('opinion','')})")