# PEFT 库 QLoRA 实战 - ChatGLM3-6B

In [1]:
# 定义全局变量和参数
peft_model_path = "models/chatglm3-6b-nf4" # 训练好的模型存储目录

In [2]:
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig

config = PeftConfig.from_pretrained(peft_model_path)
q_config = BitsAndBytesConfig(load_in_4bit=True,
                              bnb_4bit_quant_type='nf4',
                              bnb_4bit_use_double_quant=True,
                              bnb_4bit_compute_dtype=torch.float32)

base_model = AutoModel.from_pretrained(config.base_model_name_or_path,
                                       quantization_config=q_config,
                                       trust_remote_code=True,
                                       device_map='auto')
base_model.requires_grad_(False)
base_model.eval()

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 7/7 [00:44<00:00,  6.35s/it]


ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (embedding): Embedding(
      (word_embeddings): Embedding(65024, 4096)
    )
    (rotary_pos_emb): RotaryEmbedding()
    (encoder): GLMTransformer(
      (layers): ModuleList(
        (0-27): 28 x GLMBlock(
          (input_layernorm): RMSNorm()
          (self_attention): SelfAttention(
            (query_key_value): Linear4bit(in_features=4096, out_features=4608, bias=True)
            (core_attention): CoreAttention(
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (dense): Linear4bit(in_features=4096, out_features=4096, bias=False)
          )
          (post_attention_layernorm): RMSNorm()
          (mlp): MLP(
            (dense_h_to_4h): Linear4bit(in_features=4096, out_features=27392, bias=False)
            (dense_4h_to_h): Linear4bit(in_features=13696, out_features=4096, bias=False)
          )
        )
      )
      (final_layernorm): RMSNorm()
    )
    (output_la

In [7]:
peft_model = PeftModel.from_pretrained(base_model, peft_model_path)
tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, trust_remote_code=True)
def compare_results(query, base_model, peft_model):
    base_response, _ = base_model.chat(tokenizer, query)
    ft_response,_ = peft_model.chat(tokenizer, query)
    print(f"问题：{query}\n\n原始输出：\n{base_response}\n\n\n微调后：\n{ft_response}")

Setting eos_token is not supported, use the default one.
Setting pad_token is not supported, use the default one.
Setting unk_token is not supported, use the default one.


In [8]:
input_text = '类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领'
compare_results(input_text, base_model, peft_model)

问题：类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领

原始输出：
简约的圆领设计，修饰颈部曲线，修饰脸型。裙摆采用压褶设计，穿着更飘逸，优雅大方。印花图案，充满艺术气息，甜美可人。撞色的面料，色彩鲜明，视觉冲击力强，穿着更舒适。修身版型，勾勒出纤细的腰身，展现迷人身材。


微调后：
简约又充满设计感的连衣裙，采用经典的圆领设计，修饰脸型，展现优美的天鹅颈。袖口处采用撞色设计，打破整体色调，增添几分时尚感。腰间的松紧设计，穿着舒适又显瘦，后背处的压褶设计，增加层次感，尽显文艺气质。连衣裙后背的印花设计，让整条裙子更加富有层次感。
