### 1. 数据处理

In [1]:
import json
from rich import print

def parse_BIO_file(file_path):
    """
    Parse a BIO-tagged file into a list of sentences.

    Each non-blank line is expected to hold ``<char> <tag>`` separated by
    whitespace; a blank line terminates the current sentence.

    Args:
        file_path: Path of the UTF-8 encoded BIO file.

    Returns:
        A list of ``(chars, tags)`` tuples, one per non-empty sentence, where
        ``chars`` and ``tags`` are parallel lists of strings.

    Raises:
        ValueError: If a line (longer than one character after stripping)
            does not consist of exactly two whitespace-separated fields.
    """
    data = []
    chars, tags = [], []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()

            # A blank line marks the end of a sentence. Only flush when
            # something was collected, so leading / consecutive / trailing
            # blank lines do not produce empty sentences.
            if not line:
                if chars:
                    data.append((chars, tags))
                    chars, tags = [], []
                continue

            if len(line) == 1:
                # Only one character survived strip(); presumably the
                # character column was itself whitespace and only the "O"
                # tag is left. Record it as a space tagged "O".
                char, tag = ' ', 'O'
            else:
                fields = line.split()
                if len(fields) != 2:
                    raise ValueError(
                        f"{file_path}:{line_no}: expected '<char> <tag>', got {raw_line!r}"
                    )
                char, tag = fields
            chars.append(char)
            tags.append(tag)

    # The file may not end with a blank line; keep the last sentence.
    if chars:
        data.append((chars, tags))
    return data

def extract_entities(chars, tags):
    """
    Group BIO-tagged characters into entity strings.

    An entity starts at a ``B-<type>`` tag and extends over the directly
    following run of ``I-<type>`` tags of the same type. ``I-`` tags that do
    not follow a matching ``B-`` are ignored.

    Args:
        chars: Sequence of characters (strings) of one sentence.
        tags: Sequence of BIO tags parallel to ``chars``.

    Returns:
        A dict mapping entity type to the list of entity strings, in order of
        appearance.
    """
    entities = {}
    total = len(tags)
    pos = 0
    while pos < total:
        label = tags[pos]
        if not label.startswith("B-"):
            pos += 1
            continue

        kind = label[2:]  # drop the "B-" prefix to get the entity type
        bucket = entities.setdefault(kind, [])
        inside_label = f"I-{kind}"

        # Advance past every consecutive I- tag of the same type.
        stop = pos + 1
        while stop < total and tags[stop] == inside_label:
            stop += 1

        pieces = [chars[k] for k in range(pos, stop)]
        bucket.append("".join(pieces))
        pos = stop
    return entities

def convert_to_json(input_file, output_file):
    """
    Convert a BIO file into chat-format JSONL training samples.

    Every sentence becomes one JSON line with a ``messages`` list holding a
    system prompt, the sentence text as the user turn, and the extracted
    entities (serialized as JSON) as the assistant turn. Sentences without
    entities get the fixed "not found" reply announced in the system prompt.

    Args:
        input_file: Path of the BIO-tagged input file.
        output_file: Path of the JSONL file to write (overwritten).
    """
    raw_data = parse_BIO_file(input_file)
    system_prompt = (
        "你是一个中医药领域的命名实体识别专家。请从文本中提取以下类别的实体并以JSON格式输出："
        "['中医治则', '中医治疗', '中医证候', '中医诊断', '中药', '临床表现', '其他治疗', '方剂', '西医治疗', '西医诊断']。"
        "如果文本中不包含任何上述实体，请输出：{\"result\": \"未找到相关实体\"}。"
        )
    # Reply used for sentences with no entities; loop-invariant, so built once.
    no_entity_reply = json.dumps({"result": "未找到相关实体"}, ensure_ascii=False)

    with open(output_file, 'w', encoding='utf-8') as f:
        for chars, tags in raw_data:
            # Skip empty sentences (e.g. produced by consecutive blank lines)
            # so no sample with an empty user turn is written.
            if not chars:
                continue

            text = "".join(chars)
            entity_dict = extract_entities(chars, tags)

            # An entity-free sentence is a valid negative sample, not an
            # error, so nothing is printed for it.
            if entity_dict:
                assistant_content = json.dumps(entity_dict, ensure_ascii=False)
            else:
                assistant_content = no_entity_reply

            message = {
                "messages": [
                    {
                        "role": "system",
                        "content": system_prompt
                    },
                    {
                        "role": "user",
                        "content": text
                    },
                    {
                        "role": "assistant",
                        "content": assistant_content
                    }
                ]
            }
            f.write(json.dumps(message, ensure_ascii=False) + "\n")
    print(f"转换完成：{output_file}")

In [2]:
# Convert every dataset split from BIO format to chat-style JSONL.
for _split in ("dev", "test", "train"):
    convert_to_json(f"./data/medical.{_split}", f"./data/medical_{_split}.jsonl")

转换完成：./data/medical_dev.jsonl
转换完成：./data/medical_test.jsonl
转换完成：./data/medical_train.jsonl


### 2. 加载 Base Model

In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# Load the base model in 4-bit and run a short chat generation as a smoke test.
# `tokenizer` and `model` defined here are module-level names used by later cells.
model_id = "Qwen/Qwen2.5-7B-Instruct"
print(f"加载 Tokenizer: {model_id} ...")
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# to run locally — confirm it is actually required for this model.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
print(f"加载 Tokenizer: Done! ")

print("加载模型...")
# 4-bit quantization settings: NF4 quant type, double quantization enabled,
# bfloat16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16, 
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

# device_map="auto" leaves device placement to the library.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=True
)
print("加载模型: Done! ")

print("测试模型：")
prompt = "简单介绍下你自己"
messages = [
    {"role": "system", "content": "你是一个人工智能助手。"},
    {"role": "user", "content": prompt}
]
# Render the conversation to a prompt string (not token ids) with the
# generation prompt appended.
text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,
    add_generation_prompt=True
)
model_inputs = tokenizer([text], return_tensors='pt').to(model.device)
# Generation is capped at 50 new tokens, so the printed reply may be cut short.
generated_ids = model.generate(
    **model_inputs,
    max_new_tokens=50
)
# Drop the first len(input_ids) tokens of each output sequence — presumably
# the echoed prompt — so only the newly generated part is decoded.
generated_ids = [
    output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(f"用户：{prompt}")
print(f"模型：{response}")

加载 Tokenizer: Qwen/Qwen2.5-7B-Instruct ...
加载 Tokenizer: Done! 
加载模型...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

加载模型: Done! 
测试模型：
用户：简单介绍下你自己
模型：我叫通义千问，是由阿里云开发的大型语言模型。我的主要功能是生成各种文本内容，比如撰写文章、编写代码、制定计划等，还可以提供问题解答和对话交流。我会不断学习和进步，


### 3. 配置 LoRA

In [5]:
from peft import LoraConfig, TaskType, get_peft_model

# Names of the sub-modules that receive LoRA adapters — presumably the
# attention and MLP projection layers of the base model; verify against the
# model's module names.
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]

# LoRA configuration
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=target_modules
)

# NOTE(review): `model` is rebound to the wrapped model, so re-running this
# cell passes an already-wrapped model to get_peft_model again — restart the
# kernel (or reload the base model) before re-running.
# NOTE(review): the base model is loaded 4-bit quantized; PEFT documents
# prepare_model_for_kbit_training() for that case — confirm whether it is needed here.
model = get_peft_model(model, peft_config)

# Prints the trainable vs. total parameter counts of the wrapped model.
model.print_trainable_parameters()

trainable params: 40,370,176 || all params: 7,655,986,688 || trainable%: 0.5273


### 4. 配置训练参数

In [22]:
import numpy as np
import json
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq
from datasets import Dataset

def extract_entities_from_json(json_str):
    """
    Extract ``(label, entity)`` pairs from a model-generated string.

    The text between the first ``{`` and the last ``}`` is parsed as a JSON
    object whose values are either a single string or a list of strings.
    Parsing is best-effort: malformed JSON yields an empty set instead of
    raising.

    Args:
        json_str: Text that is expected to contain one JSON object.

    Returns:
        A set of ``(label, entity)`` tuples with surrounding whitespace
        stripped from the entity. Empty strings, non-string items and the
        "no entity found" placeholder are excluded.
    """
    entities = set()
    # Placeholder reply used for entity-free sentences; it is not an entity
    # and must not be counted as a true/false positive.
    no_entity = ("result", "未找到相关实体")

    start = json_str.find('{')
    end = json_str.rfind('}') + 1
    if start == -1 or end == 0:
        return entities

    try:
        data = json.loads(json_str[start:end])
    except ValueError as e:
        # json.JSONDecodeError is a ValueError subclass. Report the problem
        # and treat the output as containing no entities.
        print(e)
        return entities

    for label, items in data.items():
        if isinstance(items, str):
            items = [items]
        if not isinstance(items, list):
            continue
        for item in items:
            if isinstance(item, str) and item.strip():
                entity = (label, item.strip())
                if entity != no_entity:
                    entities.add(entity)
    return entities

def compute_metrics(eval_preds):
    """
    Compute micro-averaged entity-level precision, recall and F1.

    Predictions and labels are decoded to text with the module-level
    ``tokenizer``, parsed into ``(label, entity)`` sets, and the per-sample
    true-positive / false-positive / false-negative counts are summed before
    the ratios are taken.

    Args:
        eval_preds: A ``(predictions, labels)`` pair of token-id arrays;
            ``predictions`` may be a tuple whose first element holds the ids.

    Returns:
        A dict with the keys ``precision``, ``recall`` and ``f1``.
    """
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    # -100 is not a valid token id and cannot be decoded. Labels use it for
    # ignored positions, and predictions may be padded with it as well
    # (NOTE(review): depends on the Trainer version/config), so replace it
    # with the pad token in both arrays before decoding.
    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Accumulate counts over all samples (micro average).
    total_tp, total_fp, total_fn = 0, 0, 0
    for pred_str, label_str in zip(decoded_preds, decoded_labels):
        pred_entities = extract_entities_from_json(pred_str)
        true_entities = extract_entities_from_json(label_str)

        total_tp += len(pred_entities & true_entities)
        total_fp += len(pred_entities - true_entities)
        total_fn += len(true_entities - pred_entities)

    # The small epsilon avoids division by zero when a count is empty.
    precision = total_tp / (total_tp + total_fp + 1e-5)
    recall = total_tp / (total_tp + total_fn + 1e-5)
    f1 = 2 * precision * recall / (precision + recall + 1e-5)
    print(f"\n[Eval] Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
    return {"precision": precision, "recall": recall, "f1": f1}

In [38]:
def process_func(example, max_length=512):
    """
    Tokenize one chat sample into causal-LM training features.

    All messages except the last are rendered with the tokenizer's chat
    template (generation prompt included) and form the prompt; the content of
    the last message plus the EOS token forms the target. Prompt positions
    are labelled -100 so that only the response contributes to the loss.

    Args:
        example: A dict with a ``messages`` list of ``{"role", "content"}``
            dicts whose last entry is the assistant reply.
        max_length: Maximum number of tokens kept; longer sequences are cut
            from the right.

    Returns:
        A dict with ``input_ids``, ``attention_mask`` and ``labels`` lists of
        equal length.
    """
    # Prompt = system + user turns, rendered as text by the chat template.
    instruction = tokenizer.apply_chat_template(
        example["messages"][:-1],
        add_generation_prompt=True,
        tokenize=False
    )
    response = example["messages"][-1]["content"]  # assistant reply

    instruction_ids = tokenizer(instruction, add_special_tokens=False).input_ids
    response_ids = tokenizer(response, add_special_tokens=False).input_ids + [tokenizer.eos_token_id]

    # NOTE(review): right-truncation can drop the tail of the response,
    # including the EOS token, for over-long samples.
    input_ids = (instruction_ids + response_ids)[:max_length]
    # Mask the prompt with -100 (ignored by the loss); keep response tokens.
    labels = ([-100] * len(instruction_ids) + response_ids)[:max_length]
    attention_mask = [1] * len(input_ids)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels
    }

def load_jsonl(path):
    """
    Read a JSON Lines file.

    Args:
        path: Path of the UTF-8 encoded file with one JSON document per line.

    Returns:
        A list with the parsed object of every non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a trailing empty line) are skipped instead of
        # being passed to json.loads, which would raise on them.
        return [json.loads(line) for line in f if line.strip()]


# Read the chat-format JSONL files for the training and validation splits.
train_data, eval_data = (
    load_jsonl("./data/medical_train.jsonl"),
    load_jsonl("./data/medical_dev.jsonl"),
)

# Wrap the raw records as datasets.
train_ds, eval_ds = Dataset.from_list(train_data), Dataset.from_list(eval_data)

# Tokenize every sample; the original columns are dropped so that only the
# features returned by process_func remain.
train_dataset = train_ds.map(process_func, remove_columns=train_ds.column_names)
eval_dataset = eval_ds.map(process_func, remove_columns=eval_ds.column_names)
print(train_dataset)
print(eval_dataset)

Map:   0%|          | 0/5259 [00:00<?, ? examples/s]

Map:   0%|          | 0/657 [00:00<?, ? examples/s]

In [31]:

# Sanity check: parse only the first record of the dev file and show the
# features process_func produces for it.
# NOTE(review): this rebinds the module-level name `data`; if the file is
# empty the loop body never runs and `data` keeps whatever value it had before.
with open("./data/medical_dev.jsonl", 'r', encoding='utf-8') as f:
    for line in f:
        data = json.loads(line)
        break  # only the first line is needed
process_func(data)