# 安装依赖

In [None]:
# Clone the ChatGLM-6B repository; later cells use its requirements.txt and ptuning/main.py.
!git clone https://github.com/THUDM/ChatGLM-6B.git

In [None]:
# Install the packages listed in the cloned repository's requirements file.
!pip install -r ChatGLM-6B/requirements.txt

In [None]:
# Install extra packages (rouge_chinese, nltk, jieba, datasets); presumably required by the
# P-tuning script run later — TODO confirm against ChatGLM-6B/ptuning/main.py.
!pip install rouge_chinese nltk jieba datasets 

# 加载模型

In [None]:
# Clone the chatglm-6b-int4 model repository into ./chatglm-6b-int4, the local path that
# the loading cell passes to from_pretrained.
# NOTE(review): the weight files are presumably stored via Git LFS — confirm git-lfs is installed.
!git clone https://huggingface.co/THUDM/chatglm-6b-int4

In [None]:
# Load the tokenizer and the model from the local chatglm-6b-int4 checkout.
from transformers import AutoModel, AutoTokenizer

model_path = "chatglm-6b-int4"
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Build the model first, then cast it to half precision and move it to the GPU.
model = AutoModel.from_pretrained(model_path, trust_remote_code=True)
model = model.half().cuda()
# An explicit switch to evaluation mode is left disabled:
# model = model.eval()

In [None]:
from IPython.display import Markdown, clear_output, display

# Prompt sent to the model (kept in Chinese; it is runtime input, not a comment).
prompt = "如何制作宫保鸡丁"

# Stream the reply: every iteration yields the response text so far together with the
# chat history, and the cell output is replaced with the newest text rendered as Markdown.
for response, history in model.stream_chat(tokenizer, prompt, history=[]):
    # wait=True postpones clearing until the next output is ready to be shown.
    clear_output(wait=True)
    display(Markdown(response))

# 模型微调

In [None]:
# Download the ADGEN dataset archive and save it as AdvertiseGen.tar.gz.
!wget -O AdvertiseGen.tar.gz https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1

In [None]:
# Extract the downloaded archive; later cells read AdvertiseGen/train.json and
# AdvertiseGen/dev.json.
!tar -xzvf AdvertiseGen.tar.gz

In [None]:
import os

# Set WANDB_DISABLED in this process's environment; commands launched from the notebook
# inherit it. Presumably this stops the training script from logging to Weights & Biases —
# TODO confirm against the Trainer version in use.
os.environ.update({"WANDB_DISABLED": "true"})

In [None]:
# P-tuning v2: run the repository's ptuning/main.py training script on GPU 0.
# PRE_SEQ_LEN and LR are shell variables set at the start of the command and reused for
# --pre_seq_len, --learning_rate and the output directory name
# (output/adgen-chatglm-6b-int4-pt-128-2e-2 with the values below).
# The run stops after 100 steps, logs every 10 steps and uses a save interval of 100 steps.
!PRE_SEQ_LEN=128 && LR=2e-2 && CUDA_VISIBLE_DEVICES=0 python3 ChatGLM-6B/ptuning/main.py \
    --do_train \
    --train_file AdvertiseGen/train.json \
    --validation_file AdvertiseGen/dev.json \
    --prompt_column content \
    --response_column summary \
    --overwrite_cache \
    --model_name_or_path chatglm-6b-int4 \
    --output_dir output/adgen-chatglm-6b-int4-pt-$PRE_SEQ_LEN-$LR \
    --overwrite_output_dir \
    --max_source_length 64 \
    --max_target_length 64 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 1 \
    --gradient_accumulation_steps 4 \
    --predict_with_generate \
    --max_steps 100 \
    --logging_steps 10 \
    --save_steps 100 \
    --learning_rate $LR \
    --pre_seq_len $PRE_SEQ_LEN \
    --quantization_bit 4

# 模型推理

In [None]:
# Optional: batch prediction with the fine-tuned prefix checkpoint, using
# AdvertiseGen/dev.json as the test file. The command is intentionally left commented out;
# remove the leading "# " from every line below to run it.
# !PRE_SEQ_LEN=128 && CHECKPOINT_PATH=adgen-chatglm-6b-int4-pt-128-2e-2 && STEP=100 && CUDA_VISIBLE_DEVICES=0 python3 ChatGLM-6B/ptuning/main.py \
#     --do_predict \
#     --validation_file AdvertiseGen/dev.json \
#     --test_file AdvertiseGen/dev.json \
#     --overwrite_cache \
#     --prompt_column content \
#     --response_column summary \
#     --model_name_or_path chatglm-6b-int4 \
#     --ptuning_checkpoint ./output/$CHECKPOINT_PATH/checkpoint-$STEP \
#     --output_dir ./output/$CHECKPOINT_PATH \
#     --overwrite_output_dir \
#     --max_source_length 64 \
#     --max_target_length 64 \
#     --per_device_eval_batch_size 2 \
#     --predict_with_generate \
#     --pre_seq_len $PRE_SEQ_LEN \
#     --quantization_bit 4

# 模型部署

In [None]:
# `os` is imported here as well so this cell does not depend on an earlier cell's import
# (for example after a kernel restart).
import os

import torch
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Load the tokenizer and build the model from a config that sets pre_seq_len=128, the
# PRE_SEQ_LEN value used in the training command.
# NOTE: `model_path` is still expected to come from the model-loading cell.
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
model = AutoModel.from_pretrained(model_path, config=config, trust_remote_code=True)

# Point this at your own P-tuning output directory / checkpoint.
prefix_state_dict = torch.load(os.path.join("output/adgen-chatglm-6b-int4-pt-128-2e-2/checkpoint-100", "pytorch_model.bin"))

# Keep only the prefix-encoder entries and strip the module prefix from their keys so they
# match the parameter names of model.transformer.prefix_encoder.
encoder_prefix = "transformer.prefix_encoder."
new_prefix_state_dict = {
    key[len(encoder_prefix):]: value
    for key, value in prefix_state_dict.items()
    if key.startswith(encoder_prefix)
}
model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)

# Quantize here if needed, or use the model as loaded:
# model = model.quantize(4)
model = model.half().cuda()
# Cast the prefix encoder back to float32 after the half-precision cast above.
model.transformer.prefix_encoder.float()
model = model.eval()

In [None]:
# Query the fine-tuned model with a "key#value*key#value" style prompt, presumably matching
# the `content` field layout of the AdvertiseGen data — verify against train.json.
# The fields are separated by a plain "*": writing "\*" (a Markdown-style escape) is an
# invalid escape sequence in a Python string and would leave literal backslashes in the prompt.
response, history = model.chat(tokenizer, "类型#上衣*材质#牛仔布*颜色#白色*风格#简约*图案#刺绣*衣样式#外套*衣款式#破洞", history=[])
response