安装与导入必要包

In [None]:
# Install the libraries this notebook needs.
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# `peft` is installed explicitly because the import cell below imports it.
# NOTE(review): versions are unpinned; pin them once a known-good combination is identified.
%pip install datasets
%pip install trl
%pip install peft
%pip install -U bitsandbytes

In [None]:
import os
import random
import re
import shutil
from multiprocessing import cpu_count

import torch
from datasets import load_dataset, DatasetDict
from google.colab import drive
from huggingface_hub import login
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import BitsAndBytesConfig
from transformers import pipeline
from transformers.trainer_utils import get_last_checkpoint
from trl import SFTConfig, SFTTrainer

加载预训练模型

In [None]:
# Authenticate with the Hugging Face Hub.
# The token is read from the HF_TOKEN environment variable instead of being pasted
# into the notebook; when the variable is unset or empty, `login` receives None and
# falls back to its interactive prompt.
login(token=os.environ.get("HF_TOKEN") or None)

# Load the tokenizer and the base model with 8-bit weights.
# The 8-bit flag is passed through `quantization_config`; passing `load_in_8bit=True`
# directly to `from_pretrained` is deprecated.
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B-Instruct')
model = AutoModelForCausalLM.from_pretrained(
    'meta-llama/Llama-3.1-8B-Instruct',
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

加载与处理数据

In [None]:
# Load the dataset, rename the `conversation` column to `messages` (the name the
# later cells read), and keep only the rows whose `source` is 'meddial'.
data = load_dataset("Flmc/DISC-Med-SFT")
meddial_rows = (
    data.rename_column('conversation', 'messages')['train']
    .filter(lambda row: row['source'] == 'meddial')
)

# 70/30 train/test split. The fixed seed keeps the split identical across runs, so
# the training/validation numbers are reproducible after a kernel restart.
data1 = meddial_rows.train_test_split(test_size=0.3, seed=42)
data2 = DatasetDict({
    "train": data1['train'],
    "test": data1['test'],
})
data2

In [None]:
# Chat template (Jinja) used to flatten a conversation into one string.
# Each message is rendered as a role tag line (<|user|>, <|system|> or <|assistant|>),
# followed by the message content and the EOS token. When `add_generation_prompt` is
# set, a trailing <|assistant|> tag is emitted after the last message.
# The template is assembled from adjacent string literals so that it is a single valid
# Python string; a plain double-quoted literal cannot span several physical lines.
chat_template = (
    "{% for message in messages %}\n"
    "{% if message['role'] == 'user' %}\n"
    "{{ '<|user|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'system' %}\n"
    "{{ '<|system|>\n' + message['content'] + eos_token }}\n"
    "{% elif message['role'] == 'assistant' %}\n"
    "{{ '<|assistant|>\n'  + message['content'] + eos_token }}\n"
    "{% endif %}\n"
    "{% if loop.last and add_generation_prompt %}\n"
    "{{ '<|assistant|>' }}\n"
    "{% endif %}\n"
    "{% endfor %}"
)
tokenizer.chat_template = chat_template
def apply_chat_template(example, tokenizer):
    """Render one conversation into a single string stored under "text".

    Args:
        example: Mapping with a "messages" entry, a list of
            {"role": ..., "content": ...} dicts.
        tokenizer: Object exposing `apply_chat_template(messages, tokenize=False)`.

    Returns:
        The same `example` mapping, with the rendered string added under "text".
    """
    messages = example["messages"]
    # Prepend an empty system message when the conversation does not start with one;
    # an empty conversation is treated the same way instead of raising IndexError.
    # A new list is built so the list passed in by the caller is left unmodified.
    if not messages or messages[0]["role"] != "system":
        messages = [{"role": "system", "content": ""}, *messages]
    example["text"] = tokenizer.apply_chat_template(messages, tokenize=False)
    return example

# Render every conversation to text, one worker process per CPU core. All original
# columns are removed, so only the new "text" column produced by the mapper remains.
data3 = data2.map(
            apply_chat_template,
            num_proc=cpu_count(),
            fn_kwargs={"tokenizer": tokenizer},
            remove_columns=list(data2["train"].features),
            desc="Applying chat template",
      )

# Keep at most 1000 examples per split. The count is capped at the split length so
# that a split with fewer than 1000 rows does not make `select` raise.
small_train_dataset = data3["train"].select(range(min(1000, len(data3["train"]))))
small_test_dataset = data3["test"].select(range(min(1000, len(data3["test"]))))

In [None]:
# Display the first training example to check the text produced by the chat template.
small_train_dataset[0]

{'text': '<|system|>\n<|eot_id|>\n<|user|>\n脑梗塞，我爸今年50岁，大约2007年10月左右，突然发现我爸走路时，一条腿像是在地上拖拽着走似的。他自己开始都感觉不到，我让他抬起腿走路，他那一抬腿的动作很不协调。后来到医院检查，说是脑梗塞<|eot_id|>\n<|assistant|>\n非常抱歉听到您父亲的情况。脑梗塞是由于脑血管堵塞导致的血液供应不足，进而引起脑部功能障碍。针对您父亲的症状，我建议他进行脑部CT检查，以了解脑梗塞的具体部位和范围，以及评估目前所使用的溶栓药物的疗效。一般来说，患者在溶栓后锻炼半年左右，基本功能可以恢复。祝您元旦快乐！<|eot_id|>\n<|user|>\n谢谢您大夫，别的医生也是说半年左右基本可以恢复，但是我父亲这都好几年了，下次我把病例给您发邮件过去看看吧<|eot_id|>\n<|assistant|>\n非常感谢您的信任。如果方便的话，您可以将您父亲之前的检查结果和影像片上传到我们的在线平台，或者直接通过QQ（2447848951）发送给我，我会仔细研究并给出进一步的建议。<|eot_id|>\n'}

训练模型

In [None]:
# Use the EOS token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# LoRA adapter settings handed to the trainer.
peft_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

# Training arguments: evaluate and save once per epoch, log every step,
# batch size 1 for both training and evaluation, packing enabled on the "text" field.
training_args = SFTConfig(
    output_dir="/content/model/",
    overwrite_output_dir=True,
    dataset_text_field="text",
    packing=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    logging_steps=1,
    eval_strategy="epoch",
    save_strategy="epoch",
)

trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    args=training_args,
    peft_config=peft_config,
    train_dataset=small_train_dataset,
    eval_dataset=small_test_dataset,
)
trainer.train()



Generating train split: 0 examples [00:00, ? examples/s]

Generating train split: 0 examples [00:00, ? examples/s]



Epoch,Training Loss,Validation Loss
1,1.2952,1.63132
2,1.2735,1.603304
3,1.814,1.596146




TrainOutput(global_step=1524, training_loss=1.6108911678077669, metrics={'train_runtime': 3059.1125, 'train_samples_per_second': 0.498, 'train_steps_per_second': 0.498, 'total_flos': 7.03358508686377e+16, 'train_loss': 1.6108911678077669, 'epoch': 3.0})

模型推理

In [None]:
# After SFT: build a text-generation pipeline from the newest checkpoint found in the
# training output directory. The checkpoint is looked up instead of hard-coding
# "checkpoint-1524", because the step number in the folder name depends on how many
# training steps were actually run.
model_id = get_last_checkpoint("/content/model/")
if model_id is None:
    raise FileNotFoundError(
        "No checkpoint-* directory found in /content/model/; run the training cell first."
    )
pipe = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


In [None]:
# Send a system prompt plus one user question to the fine-tuned pipeline and print
# the last element of the first result's "generated_text".
system_prompt = "您是一个专业的医生，回答我提出的医疗问题。"
user_question = "医生，我最近咳嗽，请问该怎么办？"
messages = [
    {"role": "system", "content": system_prompt},
    {"role": "user", "content": user_question},
]
outputs = pipe(messages, max_new_tokens=256)
print(outputs[0]["generated_text"][-1])

In [None]:
# Before SFT: baseline pipeline on the same base model that the training cell
# fine-tunes (meta-llama/Llama-3.1-8B-Instruct), so that the before/after outputs
# differ only by the fine-tuning and not by model family or size.
# NOTE(review): this loads a second copy of the base model in bfloat16 — confirm the
# runtime has enough memory, or free the earlier pipeline first.
model_id = "meta-llama/Llama-3.1-8B-Instruct"
pipe_before = pipeline(
    "text-generation",
    model=model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)


In [None]:
# Send the same system prompt and user question to the baseline pipeline and print
# the last element of the first result's "generated_text".
system_turn = {"role": "system", "content": "您是一个专业的医生，回答我提出的医疗问题。"}
user_turn = {"role": "user", "content": "医生，我最近咳嗽，请问该怎么办？"}
messages = [system_turn, user_turn]
outputs = pipe_before(messages, max_new_tokens=256)
print(outputs[0]["generated_text"][-1])