# Prefix_Tuning微调

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer,AutoModelForCausalLM,Trainer,TrainingArguments,DataCollatorForSeq2Seq
from peft import PrefixTuningConfig,get_peft_model,TaskType
import warnings
warnings.filterwarnings("ignore")

  from .autonotebook import tqdm as notebook_tqdm


## 1、加载模型和分词器

In [2]:
model_name="Qwen/Qwen3-0.6B"
model=AutoModelForCausalLM.from_pretrained(model_name,device_map='auto',torch_dtype="auto",)
tokenizer=AutoTokenizer.from_pretrained(model_name)

Loading weights: 100%|██████████| 311/311 [00:02<00:00, 145.87it/s, Materializing param=model.norm.weight]                              


## 2、PrefixTuningConfig配置

In [3]:
config = PrefixTuningConfig(
    task_type=TaskType.CAUSAL_LM, 
    num_virtual_tokens=10,
    prefix_projection=False
)
model=get_peft_model(model=model,peft_config=config)
model.print_trainable_parameters()

trainable params: 573,440 || all params: 752,205,824 || trainable%: 0.0762


## 3、加载数据集

In [4]:
data=load_dataset('json',data_files='../../dataset/chinese_law_ft_dataset.json',split="train[:1000]")

In [5]:
dataset=data.train_test_split(
    train_size=0.7,
    shuffle=True,
    seed=7
)

In [6]:
dataset

DatasetDict({
    train: Dataset({
        features: ['instruction', 'input', 'output', 'id'],
        num_rows: 700
    })
    test: Dataset({
        features: ['instruction', 'input', 'output', 'id'],
        num_rows: 300
    })
})

In [7]:
dataset['train']

Dataset({
    features: ['instruction', 'input', 'output', 'id'],
    num_rows: 700
})

## 4、数据预处理：分词、编码

In [8]:
def process_fun(example):
    content=[]
    for instruction,input,output in zip(example['instruction'],example['input'],example['output']):
        if input.strip():
            text=f'Human:{instruction}\n{input}\nAI:{output}'
            content.append(text)
        else:
            text=f'Human:{instruction}\nAI:{output}'
            content.append(text)
    
    encoded = tokenizer(
        content,
        max_length=512,
        truncation=True,
        padding="max_length",
        return_tensors="pt"
    )
    labels = encoded["input_ids"].clone()
    for index,text in enumerate(content):
        answer_start=text.find('AI:')+len('AI:')
        question=text[:answer_start]
        question_ids=tokenizer.encode(question, add_special_tokens=False)
        question_length=len(question_ids)
        labels[index,:question_length]=-100
    return {
        "input_ids": encoded["input_ids"],
        "attention_mask": encoded["attention_mask"],
        "labels": labels
    }


In [9]:
train_process_data=dataset['train'].map(process_fun,batched=True,remove_columns=dataset['train'].column_names)
test_process_data=dataset['test'].map(process_fun,batched=True,remove_columns=dataset['test'].column_names)

In [10]:
tokenizer.decode(train_process_data[1],skip_special_tokens=True)

'Human:某公司计划与一家同行公司合并，合并后的权利和义务如何分配？\nAI:根据民法典总则第六十七条规定，法人合并的，其权利和义务由合并后的法人享有和承担。因此，在合并后，原公司和同行公司的权利和义务将合并在一起，并由合并后的新公司享有和承担。'

## 5、模型训练配置

In [11]:
# 定义训练参数
training_args = TrainingArguments(
    output_dir="../../models/prefinx_tuning",
    logging_steps=10,
    logging_dir='./runs',
    eval_strategy='epoch',
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    learning_rate=2e-5,
    weight_decay=0.01,
    load_best_model_at_end=True,
    metric_for_best_model='eval_loss',
    # gradient_accumulation_steps=4,  # 如果GPU内存有限
)

`logging_dir` is deprecated and will be removed in v5.2. Please set `TENSORBOARD_LOGGING_DIR` instead.


In [12]:
trainer=Trainer(
    model=model,
    args=training_args,
    eval_dataset=test_process_data,
    train_dataset=train_process_data,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer,padding=True),
)

## 6、模型训练

In [13]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,10.07543,9.681171
2,8.105386,8.280047
3,7.667658,7.69573


TrainOutput(global_step=525, training_loss=9.406994425455729, metrics={'train_runtime': 6371.5276, 'train_samples_per_second': 0.33, 'train_steps_per_second': 0.082, 'total_flos': 3845237243904000.0, 'train_loss': 9.406994425455729, 'epoch': 3.0})

## 7、保存模型

In [None]:
trainer.save_model('../../models/prefinx_tuning')