# 模型推理 - 使用 QLoRA 微调后的 ChatGLM3-6B

In [1]:
import torch
from transformers import AutoModel, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel, PeftConfig

# Global settings shared by the cells below.
# Hugging Face model ID (a local checkpoint path works as well).
model_name_or_path = 'THUDM/chatglm3-6b'
# Directory holding the fine-tuned adapter, nested under the model ID.
peft_model_path = "models/advertisegen/" + model_name_or_path


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Read the adapter's saved configuration; its `base_model_name_or_path`
# field tells us which base checkpoint to load below.
config = PeftConfig.from_pretrained(peft_model_path)

# bitsandbytes settings: 4-bit NF4 weights with double quantization,
# float32 as the compute dtype.
q_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float32,
)

# NOTE: trust_remote_code=True runs model code fetched from the Hub (see the
# download warnings in this cell's output); pin a revision to avoid silently
# picking up new versions of those files.
base_model = AutoModel.from_pretrained(
    config.base_model_name_or_path,
    quantization_config=q_config,
    trust_remote_code=True,
    device_map='auto',
)

# Inference only: disable gradients and switch to eval mode.
base_model.requires_grad_(False)
# Last expression of the cell, so the model structure is displayed as output.
base_model.eval()

A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm3-6b:
- configuration_chatglm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm3-6b:
- quantization.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm3-6b:
- modeling_chatglm.py
- quantization.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Downloading shards: 100%|██████████| 7/7 [00:24<00:00,  3.44s/it]
Loading checkpoint shards: 100%|██████████| 7/7 [00:04<00:00,  1.61it/s]


ChatGLMForConditionalGeneration(
  (transformer): ChatGLMModel(
    (embedding): Embedding(
      (word_embeddings): Embedding(65024, 4096)
    )
    (rotary_pos_emb): RotaryEmbedding()
    (encoder): GLMTransformer(
      (layers): ModuleList(
        (0-27): 28 x GLMBlock(
          (input_layernorm): RMSNorm()
          (self_attention): SelfAttention(
            (query_key_value): Linear4bit(in_features=4096, out_features=4608, bias=True)
            (core_attention): CoreAttention(
              (attention_dropout): Dropout(p=0.0, inplace=False)
            )
            (dense): Linear4bit(in_features=4096, out_features=4096, bias=False)
          )
          (post_attention_layernorm): RMSNorm()
          (mlp): MLP(
            (dense_h_to_4h): Linear4bit(in_features=4096, out_features=27392, bias=False)
            (dense_4h_to_h): Linear4bit(in_features=13696, out_features=4096, bias=False)
          )
        )
      )
      (final_layernorm): RMSNorm()
    )
    (output_la

In [3]:
# Prompt in the "key#value*key#value" attribute format used throughout this notebook.
input_text = (
    '类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领'
)
print(f'输入：\n{input_text}')

# Load the tokenizer that belongs to the same base checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(
    config.base_model_name_or_path,
    trust_remote_code=True,
)

输入：
类型#裙*版型#显瘦*风格#文艺*风格#简约*图案#印花*图案#撞色*裙下摆#压褶*裙长#连衣裙*裙领型#圆领


A new version of the following files was downloaded from https://huggingface.co/THUDM/chatglm3-6b:
- tokenization_chatglm.py
. Make sure to double-check they do not contain any added malicious code. To avoid downloading new versions of the code file, you can pin a revision.
Setting eos_token is not supported, use the default one.
Setting pad_token is not supported, use the default one.
Setting unk_token is not supported, use the default one.


In [4]:
# Baseline: ask the quantized base model, before any PEFT adapter is attached.
# The model's remote-code chat() uses sampling by default (TODO confirm against
# the pinned modeling_chatglm.py), so seed torch first; without a seed the
# answer changes on every run and the before/after comparison is not
# reproducible.
torch.manual_seed(42)
response, history = base_model.chat(tokenizer=tokenizer, query=input_text)
print(f'ChatGLM3-6B 微调前：\n{response}')

2024-04-08 08:03:55.586284: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


ChatGLM3-6B 微调前：
根据您的描述，这款连衣裙的风格是文艺简约，图案是印花撞色，版型是显瘦的A字裙，裙下摆是压褶设计，裙长到膝盖附近，裙领型是圆领。这款连衣裙的设计既具有时尚感，又充满了文艺气息，非常适合喜欢简约风格的女性。穿上这款连衣裙，您一定会显得更加优雅和自信。


In [5]:
# Wrap the base model with the fine-tuned adapter stored at peft_model_path.
model = PeftModel.from_pretrained(base_model, peft_model_path)

# The model's remote-code chat() uses sampling by default (TODO confirm against
# the pinned modeling_chatglm.py), so seed torch right before generating to
# make the fine-tuned answer reproducible across runs.
torch.manual_seed(42)
response, history = model.chat(tokenizer=tokenizer, query=input_text)
print(f'ChatGLM3-6B 微调后: \n{response}')

ChatGLM3-6B 微调后: 
连衣裙，简约圆领设计，简洁大方，衬托脸型。腰间撞色压褶设计，修身显瘦，修饰身形。裙摆设计，优雅大方，时尚减龄。裙身采用撞色印花设计，个性时尚，尽显文艺气息。
