In [23]:
import torch
import time
import argparse
import numpy as np
from ipex_llm.transformers import AutoModel, AutoModelForCausalLM
from modelscope import AutoTokenizer

In [24]:
# Single-turn chat template: the user message is substituted for {prompt},
# and the trailing assistant tag is where the model's reply begins.
CHATGLM_V3_PROMPT_FORMAT = (
    "<|user|>\n"
    "{prompt}\n"
    "<|assistant|>"
)

In [25]:
def load_model_and_tokenizer(model_path):
    """Load a 4-bit model and the matching tokenizer.

    Args:
        model_path: Model identifier or path, passed unchanged to both
            ``from_pretrained`` calls.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # load_in_4bit=True converts the relevant layers of the model into
    # INT4 format while loading. model_hub='modelscope' is forwarded as-is
    # (presumably selects ModelScope as the weight source -- confirm in the
    # ipex_llm docs).
    model_options = dict(
        load_in_4bit=True,
        trust_remote_code=True,
        model_hub='modelscope',
    )
    quantized_model = AutoModel.from_pretrained(model_path, **model_options)

    # The tokenizer is loaded from the same model_path.
    chat_tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
    )

    return quantized_model, chat_tokenizer

In [26]:
def generate_tokens(model, tokenizer, prompt, n_predict):
    """Generate a completion for ``prompt`` and print timing, prompt and output.

    Args:
        model: Object exposing ``generate(input_ids, max_new_tokens=...)``.
        tokenizer: Object exposing ``encode(text, return_tensors="pt")`` and
            ``decode(ids, skip_special_tokens=True)``.
        prompt: Raw user message; it is wrapped in CHATGLM_V3_PROMPT_FORMAT
            before being tokenized.
        n_predict: Passed to ``generate`` as ``max_new_tokens``.

    Returns:
        None. The elapsed time, the formatted prompt and the decoded output
        are printed to stdout.
    """
    # Keep the caller's raw prompt and the templated prompt under separate names.
    formatted_prompt = CHATGLM_V3_PROMPT_FORMAT.format(prompt=prompt)
    input_ids = tokenizer.encode(formatted_prompt, return_tensors="pt")

    with torch.no_grad():
        # perf_counter() is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments the way a
        # time.time() difference can.
        st = time.perf_counter()
        output = model.generate(input_ids, max_new_tokens=n_predict)
        end = time.perf_counter()

    # The whole first sequence is decoded, so the printed text contains the
    # prompt followed by the generated continuation.
    output_str = tokenizer.decode(output[0], skip_special_tokens=True)

    print(f'Inference time: {end-st} s')
    print('-'*20, 'Prompt', '-'*20)
    print(formatted_prompt)
    print('-'*20, 'Output', '-'*20)
    print(output_str)

In [27]:
# Run configuration shared by the cells below.
# Model identifier handed to load_model_and_tokenizer.
model_id = "ZhipuAI/chatglm3-6b"
# User question sent to the model ("What is AI?").
prompt = "AI是什么？"
# Maximum number of new tokens to generate.
n_predict = 32

In [28]:
# Load the 4-bit model and its tokenizer once; the generation cell reuses both.
model, tokenizer = load_model_and_tokenizer(model_path=model_id)

Loading checkpoint shards: 100%|██████████| 7/7 [00:08<00:00,  1.27s/it]
2024-04-09 13:53:27,812 - INFO - Converting the current model to sym_int4 format......


In [29]:
# Run one timed generation for the configured prompt and print the result.
generate_tokens(
    model=model,
    tokenizer=tokenizer,
    prompt=prompt,
    n_predict=n_predict,
)

Inference time: 442.2059109210968 s
-------------------- Prompt --------------------
<|user|>
AI是什么？
<|assistant|>
-------------------- Output --------------------
[gMASK]sop <|user|>
AI是什么？
<|assistant|> AI是人工智能（Artificial Intelligence）的缩写，指的是通过计算机程序和算法实现智能的一种技术。AI可以帮助人类完成各种任务，例如语音
