## 模型加载
### QLoRA

In [None]:
from modelscope import AutoTokenizer, AutoModelForCausalLM
from transformers import BitsAndBytesConfig
import torch
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model_name_or_path = "Qwen/Qwen2.5-7B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True, force_download=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    trust_remote_code=True,
    device_map="auto",
    quantization_config=bnb_config,
)
# model.enable_input_require_grads() ## set it if use gradient checkpointing to save memory

Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct


2025-10-08 17:17:30,997 - modelscope - INFO - Target directory already exists, skipping creation.


Downloading Model from https://www.modelscope.cn to directory: /root/.cache/modelscope/hub/models/Qwen/Qwen2.5-7B-Instruct


2025-10-08 17:17:44,790 - modelscope - INFO - Target directory already exists, skipping creation.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

## 数据集转换及加载

In [2]:
import json
import os

def transfer_dataset(origin_path, new_path): 
    with open(origin_path, "r") as f:
        messages = []
        contents = [s.strip() for s in f.read().split('\n\n')]
        for content in contents:
            parts = content.split('\n')
            input_text = ""
            labels_list = []
            for part in parts:
                input_text += part[0]
                if part[2] == 'B':
                    labels_list.append({
                        "entity_name": part[0],
                        "entity_label": part[4:],
                    })
                elif part[2] == "I":
                    labels_list[-1]["entity_name"] += part[0]
            message = {
                "text": input_text,
                "labels": json.dumps(labels_list, ensure_ascii=False)  # 用空格连接标签
            }
            messages.append(message)
    with open(new_path, "w") as f:
        for message in messages:
            f.write(json.dumps(message, ensure_ascii=False) + '\n')

origin_paths = ["medical.dev", "medical.test", "medical.train"]
new_paths = ["eval.jsonl", "test.jsonl", "train.jsonl"]
data_dir = "data"
for origin_path, new_path in zip(origin_paths, new_paths):
    transfer_dataset(os.path.join(data_dir, origin_path), os.path.join(data_dir, new_path))

In [3]:
from datasets import load_dataset
train_file = os.path.join(data_dir, "train.jsonl")
eval_file = os.path.join(data_dir, "eval.jsonl")
test_file = os.path.join(data_dir, "test.jsonl")
data_files = {
    "train": train_file,
    "eval": eval_file,
}
dataset_dict = load_dataset("json", data_files=data_files)
dataset_dict["eval"] = dataset_dict["eval"].shuffle(seed=42).select(range(100))
print(dataset_dict)

Generating train split: 0 examples [00:00, ? examples/s]

Generating eval split: 0 examples [00:00, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['text', 'labels'],
        num_rows: 5259
    })
    eval: Dataset({
        features: ['text', 'labels'],
        num_rows: 100
    })
})


In [4]:
SYSTEM_PROMPT = """你是一个中医药领域的专家，你需要从给定的句子中提取实体信息。所有的实体种类: 中医治则 中医治疗 中医证候 中医诊断 中药 临床表现 其他治疗 方剂 西医治疗 西医诊断。
每一个实体对应一个json格式，共同组成一个json列表，例如"[{"entity_name": "口苦", "entity_label": "临床表现"}]". """
MAX_LENGTH = 512
LABELS = ["中医治则", "中医治疗", "中医证候", "中医诊断", "中药", "临床表现", "其他治疗", "方剂", "西医治疗", "西医诊断"]
def preprocess_function(example):
    model_inputs = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["text"]},
        ],
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer(model_inputs, add_special_tokens=False)
    labels = tokenizer(example["labels"] + tokenizer.eos_token, add_special_tokens=False)
    input_ids = model_inputs["input_ids"] + labels["input_ids"]
    attention_mask = model_inputs["attention_mask"] + labels["attention_mask"]
    labels = [-100] * len(model_inputs["input_ids"]) + labels["input_ids"]
    if len(input_ids) > MAX_LENGTH:
        input_ids = input_ids[:MAX_LENGTH]
        attention_mask = attention_mask[:MAX_LENGTH]
        labels = labels[:MAX_LENGTH]
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }
dataset_dict = dataset_dict.map(preprocess_function, remove_columns=dataset_dict["train"].column_names)

Map:   0%|          | 0/5259 [00:00<?, ? examples/s]

Map:   0%|          | 0/100 [00:00<?, ? examples/s]

## 准备LoRA模型

In [5]:
from peft import LoraConfig, get_peft_model, TaskType
config = LoraConfig(
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    task_type=TaskType.CAUSAL_LM
)

model = get_peft_model(model, config)
model.print_trainable_parameters()

trainable params: 20,185,088 || all params: 7,635,801,600 || trainable%: 0.2643


## Train & Eval

In [6]:
from transformers import DataCollatorForSeq2Seq, TrainingArguments, Trainer
from collections import defaultdict
def compute_metrics(eval_pred):
    logits, labels = eval_pred
    logits = logits.argmax(axis=-1)
    true_labels = [[label for label in seq if label != -100] for seq in labels]
    true_predictions = [[p for (p, l) in zip(pred[:-1], lab[1:]) if l != -100] for pred, lab in zip(logits, labels)]
    true_labels = tokenizer.batch_decode(true_labels, skip_special_tokens=True)
    true_predictions = tokenizer.batch_decode(true_predictions, skip_special_tokens=True)
    def parse_entities(text):
        try:
            entities = json.loads(text)
            return set((entity["entity_name"], entity["entity_label"]) for entity in entities)
        except:
            return set()
    true_labels = [parse_entities(text) for text in true_labels]
    true_predictions = [parse_entities(text) for text in true_predictions]
    category_metrics = defaultdict(lambda: {"tp": 0, "fp": 0, "fn": 0})
    for labels, predictions in zip(true_labels, true_predictions):
        pred_by_category = defaultdict(set)
        label_by_category = defaultdict(set)
        for name, label in labels:
            if label in LABELS:
                label_by_category[label].add(name)
        for name, label in predictions:
            if label in LABELS:
                pred_by_category[label].add(name)
        for category in LABELS:
            label_set = label_by_category[category]
            pred_set = pred_by_category[category]
            category_metrics[category]["tp"] += len(label_set & pred_set)
            category_metrics[category]["fp"] += len(pred_set - label_set)
            category_metrics[category]["fn"] += len(label_set - pred_set)
    overall_tp = overall_tp = overall_fn = 0
    results = {}
    for category, metrics in category_metrics.items():
        tp, fp, fn = metrics["tp"], metrics["fp"], metrics["fn"]
        overall_tp += tp
        overall_tp += fp
        overall_fn += fn
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
        results[category] = {
            "precision": f"{precision:.4f}",
            "recall": f"{recall:.4f}",
            "f1": f"{f1:.4f}"
        }
    overall_precision = overall_tp / (overall_tp + overall_tp) if (overall_tp + overall_tp) > 0 else 0.0
    overall_recall = overall_tp / (overall_tp + overall_fn) if (overall_tp + overall_fn) > 0 else 0.0
    overall_f1 = 2 * overall_precision * overall_recall / (overall_precision + overall_recall) if (overall_precision + overall_recall) > 0 else 0.
    results["overall_precision"] = overall_precision
    results["overall_recall"] = overall_recall
    results["overall_f1"] = overall_f1
    return results

args = TrainingArguments(
    output_dir="output",
    learning_rate=2e-5,
    weight_decay=0.01,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=4,
    eval_accumulation_steps=4,
    num_train_epochs=3,
    save_strategy="best",
    eval_strategy="steps",
    eval_steps=200,
    logging_steps=200,
    save_total_limit=3,
    metric_for_best_model="overall_f1",
    load_best_model_at_end=True,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=dataset_dict["train"],
    eval_dataset=dataset_dict["eval"],
    data_collator=DataCollatorForSeq2Seq(tokenizer, padding=True),
    compute_metrics=compute_metrics,
)

trainer.train()

Step,Training Loss,Validation Loss,中医治则,中医治疗,中医证候,中医诊断,中药,临床表现,其他治疗,方剂,西医治疗,西医诊断,Overall Precision,Overall Recall,Overall F1
200,0.1041,0.115781,"{'precision': '0.6667', 'recall': '0.3333', 'f1': '0.4444'}","{'precision': '0.5556', 'recall': '0.5000', 'f1': '0.5263'}","{'precision': '0.8333', 'recall': '0.5172', 'f1': '0.6383'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8269', 'recall': '0.7288', 'f1': '0.7748'}","{'precision': '0.7368', 'recall': '0.2456', 'f1': '0.3684'}","{'precision': '0.0000', 'recall': '0.0000', 'f1': '0.0000'}","{'precision': '0.6364', 'recall': '0.3500', 'f1': '0.4516'}","{'precision': '0.2500', 'recall': '0.1667', 'f1': '0.2000'}","{'precision': '0.7500', 'recall': '0.6316', 'f1': '0.6857'}",0.5,0.559701,0.528169
400,0.0934,0.097643,"{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.5556', 'recall': '0.5000', 'f1': '0.5263'}","{'precision': '0.7500', 'recall': '0.5172', 'f1': '0.6122'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8409', 'recall': '0.6271', 'f1': '0.7184'}","{'precision': '0.6000', 'recall': '0.2105', 'f1': '0.3117'}","{'precision': '0.0000', 'recall': '0.0000', 'f1': '0.0000'}","{'precision': '0.7500', 'recall': '0.4500', 'f1': '0.5625'}","{'precision': '0.0000', 'recall': '0.0000', 'f1': '0.0000'}","{'precision': '0.8485', 'recall': '0.7368', 'f1': '0.7887'}",0.5,0.550562,0.524064
600,0.0851,0.079456,"{'precision': '0.6667', 'recall': '0.3333', 'f1': '0.4444'}","{'precision': '0.8571', 'recall': '0.6000', 'f1': '0.7059'}","{'precision': '0.6471', 'recall': '0.3793', 'f1': '0.4783'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8261', 'recall': '0.6441', 'f1': '0.7238'}","{'precision': '0.7273', 'recall': '0.2807', 'f1': '0.4051'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.6364', 'recall': '0.3500', 'f1': '0.4516'}","{'precision': '1.0000', 'recall': '0.3333', 'f1': '0.5000'}","{'precision': '0.7742', 'recall': '0.6316', 'f1': '0.6957'}",0.5,0.537879,0.518248
800,0.0695,0.069358,"{'precision': '0.0000', 'recall': '0.0000', 'f1': '0.0000'}","{'precision': '0.7143', 'recall': '0.5000', 'f1': '0.5882'}","{'precision': '0.7619', 'recall': '0.5517', 'f1': '0.6400'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8462', 'recall': '0.7458', 'f1': '0.7928'}","{'precision': '0.7576', 'recall': '0.4386', 'f1': '0.5556'}","{'precision': '0.0000', 'recall': '0.0000', 'f1': '0.0000'}","{'precision': '0.6364', 'recall': '0.3500', 'f1': '0.4516'}","{'precision': '0.6667', 'recall': '0.3333', 'f1': '0.4444'}","{'precision': '0.8438', 'recall': '0.7105', 'f1': '0.7714'}",0.5,0.6171,0.552413
1000,0.0619,0.074125,"{'precision': '0.5000', 'recall': '0.3333', 'f1': '0.4000'}","{'precision': '0.6250', 'recall': '0.5000', 'f1': '0.5556'}","{'precision': '0.8095', 'recall': '0.5862', 'f1': '0.6800'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8269', 'recall': '0.7288', 'f1': '0.7748'}","{'precision': '0.7273', 'recall': '0.4211', 'f1': '0.5333'}","{'precision': '1.0000', 'recall': '1.0000', 'f1': '1.0000'}","{'precision': '0.7273', 'recall': '0.4000', 'f1': '0.5161'}","{'precision': '0.6667', 'recall': '0.3333', 'f1': '0.4444'}","{'precision': '0.7576', 'recall': '0.6579', 'f1': '0.7042'}",0.5,0.625926,0.555921
1200,0.0572,0.072394,"{'precision': '0.5000', 'recall': '0.3333', 'f1': '0.4000'}","{'precision': '0.7143', 'recall': '0.5000', 'f1': '0.5882'}","{'precision': '0.8636', 'recall': '0.6552', 'f1': '0.7451'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8600', 'recall': '0.7288', 'f1': '0.7890'}","{'precision': '0.6800', 'recall': '0.2982', 'f1': '0.4146'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.7692', 'recall': '0.5000', 'f1': '0.6061'}","{'precision': '0.4000', 'recall': '0.3333', 'f1': '0.3636'}","{'precision': '0.8788', 'recall': '0.7632', 'f1': '0.8169'}",0.5,0.61597,0.551959
1400,0.0541,0.070908,"{'precision': '0.2500', 'recall': '0.1667', 'f1': '0.2000'}","{'precision': '0.7143', 'recall': '0.5000', 'f1': '0.5882'}","{'precision': '0.7391', 'recall': '0.5862', 'f1': '0.6538'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8600', 'recall': '0.7288', 'f1': '0.7890'}","{'precision': '0.6400', 'recall': '0.2807', 'f1': '0.3902'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8750', 'recall': '0.7000', 'f1': '0.7778'}","{'precision': '0.3333', 'recall': '0.3333', 'f1': '0.3333'}","{'precision': '0.9167', 'recall': '0.8684', 'f1': '0.8919'}",0.5,0.636704,0.560132
1600,0.0473,0.072242,"{'precision': '0.5000', 'recall': '0.3333', 'f1': '0.4000'}","{'precision': '0.7500', 'recall': '0.6000', 'f1': '0.6667'}","{'precision': '0.7917', 'recall': '0.6552', 'f1': '0.7170'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8269', 'recall': '0.7288', 'f1': '0.7748'}","{'precision': '0.6923', 'recall': '0.3158', 'f1': '0.4337'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.9286', 'recall': '0.6500', 'f1': '0.7647'}","{'precision': '0.5000', 'recall': '0.3333', 'f1': '0.4000'}","{'precision': '0.8571', 'recall': '0.7895', 'f1': '0.8219'}",0.5,0.641509,0.561983
1800,0.045,0.068128,"{'precision': '0.5000', 'recall': '0.3333', 'f1': '0.4000'}","{'precision': '0.7500', 'recall': '0.6000', 'f1': '0.6667'}","{'precision': '0.8750', 'recall': '0.7241', 'f1': '0.7925'}","{'precision': '1.0000', 'recall': '0.5000', 'f1': '0.6667'}","{'precision': '0.8600', 'recall': '0.7288', 'f1': '0.7890'}","{'precision': '0.6923', 'recall': '0.3158', 'f1': '0.4337'}","{'precision': '0.5000', 'recall': '0.5000', 'f1': '0.5000'}","{'precision': '0.8571', 'recall': '0.6000', 'f1': '0.7059'}","{'precision': '0.6667', 'recall': '0.3333', 'f1': '0.4444'}","{'precision': '0.9167', 'recall': '0.8684', 'f1': '0.8919'}",0.5,0.65,0.565217


TrainOutput(global_step=1974, training_loss=0.06665211873697051, metrics={'train_runtime': 2721.6049, 'train_samples_per_second': 5.797, 'train_steps_per_second': 0.725, 'total_flos': 1.342817220885719e+17, 'train_loss': 0.06665211873697051, 'epoch': 3.0})

## Test

In [7]:
model.eval()
test_dataset = load_dataset("json", data_files={"test": test_file})["test"]
with torch.no_grad():
    example = test_dataset[0]
    model_inputs = tokenizer.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["text"]},
        ],
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer(model_inputs, return_tensors="pt").to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=128)
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    print(f"Input: {example['text']}")
    print(f"label: {example["labels"]}")
    print(f"Prediction: {generated_text}")
    

Generating test split: 0 examples [00:00, ? examples/s]

Input: 药进１０帖，黄疸稍退，饮食稍增，精神稍振
label: [{"entity_name": "黄疸", "entity_label": "中医诊断"}]
Prediction: system
你是一个中医药领域的专家，你需要从给定的句子中提取实体信息。所有的实体种类: 中医治则 中医治疗 中医证候 中医诊断 中药 临床表现 其他治疗 方剂 西医治疗 西医诊断。
每一个实体对应一个json格式，共同组成一个json列表，例如"[{"entity_name": "口苦", "entity_label": "临床表现"}]". 
user
药进１０帖，黄疸稍退，饮食稍增，精神稍振
assistant
[{"entity_name": "黄疸", "entity_label": "中医诊断"}, {"entity_name": "饮食稍增", "entity_label": "临床表现"}, {"entity_name": "精神稍振", "entity_label": "临床表现"}]
