In [1]:
# Logging
import datetime
import logging
import os
# Timestamp of this notebook session; used to give each run its own log file.
log_time = datetime.datetime.now()
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
# NOTE(review): %H is the 24-hour clock, so the trailing %p (AM/PM) is redundant; kept as-is
# so that existing log files stay comparable.
DATE_FORMAT = "%m/%d/%Y %H:%M:%S %p"

# logging.basicConfig does not create missing parent directories, so make sure the
# log directory exists before pointing the file handler at it.
LOG_DIR = "./logger"
os.makedirs(LOG_DIR, exist_ok=True)

# Build a filesystem-safe file name: str(datetime) contains spaces and colons, which are
# invalid on Windows and awkward to handle in shells.
LOG_PATH = os.path.join(LOG_DIR, f"{log_time:%Y%m%d-%H%M%S}.log")

# The notebook logs non-ASCII (Chinese) text, so pin the file encoding instead of relying
# on the platform default.
logging.basicConfig(
    filename=LOG_PATH,
    level=logging.INFO,
    format=LOG_FORMAT,
    datefmt=DATE_FORMAT,
    encoding="utf-8",
)
logging.info("INFO")

In [1]:
# Parameter settings
import os
from transformers import set_seed
from transformers import AutoModel, AutoTokenizer, AutoConfig
import torch
from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
from collections.abc import Iterable
from typing import Any, Protocol

import json
import re
from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools

# Seed the random number generators through transformers.set_seed.
set_seed(42)

# Route both HTTP and HTTPS traffic through the same proxy.
os.environ.update({
    "HTTP_PROXY": 'http://10.10.20.100:1089',
    "HTTPS_PROXY": 'http://10.10.20.100:1089',
})

# System prompt used whenever tools are offered to the model.
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'

# Model location: the MODEL_PATH environment variable wins, otherwise the fine-tuned checkpoint below.
# MODEL_PATH = os.environ.get('MODEL_PATH', '/share/lilin/chatglm3-6b')
MODEL_PATH = os.environ.get(
    'MODEL_PATH',
    '/share/lilin/ChatGLM3/finetune_demo/output/tool_alpaca_ft-20231110-172112-1e-4/checkpoint-200',
)
# MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')

# Optional prefix checkpoint directory; None (variable unset) means no checkpoint is used.
PT_PATH = os.environ.get('PT_PATH')
# The tokenizer is loaded from its own path rather than from MODEL_PATH.
TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", '/share/lilin/chatglm3-6b')

# tokenizer = AutoTokenizer.from_pretrained("/share/lilin/chatglm3-6b", trust_remote_code=True)
# model = AutoModel.from_pretrained("/share/lilin/chatglm3-6b", trust_remote_code=True, device='cuda')
# model = model.eval()

# response, history = model.chat(tokenizer, "你好", history=[])
# print(response)

# response, history = model.chat(tokenizer, "晚上睡不着应该怎么办", history=history)
# print(response)


[registered tool] {'description': '随机生成一个数x, 使得 `range[0]` <= x < `range[1]`， 随机数生成的种子使用 `seed`',
 'name': 'random_number_generator',
 'params': [{'description': '随机数生成器使用的种子',
             'name': 'seed',
             'required': True,
             'type': 'int'},
            {'description': '生成随机数的范围',
             'name': 'range',
             'required': True,
             'type': 'tuple[int, int]'}]}
[registered tool] {'description': '获取句子 `input_text` 的长度',
 'name': 'get_sentence_length',
 'params': [{'description': '输入的句子',
             'name': 'input_text',
             'required': True,
             'type': 'str'}]}
[registered tool] {'description': '返回指数计算的结果，底数 `base` 的指数 `power` 次方',
 'name': 'exponentiation_calculation',
 'params': [{'description': '底数',
             'name': 'base',
             'required': True,
             'type': 'int'},
            {'description': '指数',
             'name': 'power',
             'required': True,
             'type': 'int'}]}
[registe

In [5]:
# Bring up (host) the LLM
def stream_chat(model, tokenizer, query: str, history: list[dict] | None = None, role: str = "user",
                    past_key_values=None,max_length: int = 8192, do_sample=True, top_p=0.8, temperature=0.8,
                    logits_processor=None, return_past_key_values=False, **kwargs):
    """Stream the model's reply to `query`, yielding the decoded text as it grows.

    Args:
        model: Model exposing `stream_generate`, `device` and `transformer.pre_seq_len`.
        tokenizer: Tokenizer exposing `build_chat_input`, `get_command`, `decode`
            and `eos_token_id`.
        query: Content of the new message.
        history: Prior messages. The new message is appended to this list in place as
            ``{"role": role, "content": query}`` after the model inputs are built.
            A fresh list is used when None.
        role: Role attached to `query`.
        past_key_values: Optional cache from a previous call. When given, `history`
            is not passed to `build_chat_input`; position ids and the attention mask
            are shifted/extended by the cached length instead.
        max_length, do_sample, top_p, temperature: Forwarded to `model.stream_generate`.
        logits_processor: Optional processor list; an `InvalidScoreLogitsProcessor`
            is appended to it (the caller's list is mutated when one is passed).
        return_past_key_values: When true, each yielded tuple also carries the cache.
        **kwargs: Extra generation arguments forwarded unchanged.

    Yields:
        ``(response, history)`` or, when `return_past_key_values` is true,
        ``(response, history, past_key_values)``. `response` is the full decoded
        continuation so far, not a delta. Steps whose decoded text is empty or ends
        in U+FFFD are skipped.
    """
        
    # Imported lazily so the names are only needed when the function is actually called.
    from transformers.generation.logits_process import LogitsProcessor
    from transformers.generation.utils import LogitsProcessorList

    class InvalidScoreLogitsProcessor(LogitsProcessor):
        """Replace a score tensor containing NaN/inf with a fixed, valid distribution."""

        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                # Reset everything and put a large score on index 5 so sampling stays well-defined.
                # NOTE(review): which token id 5 denotes is not visible here.
                scores.zero_()
                scores[..., 5] = 5e4
            return scores

    if history is None:
        history = []
    if logits_processor is None:
        logits_processor = LogitsProcessorList()
    logits_processor.append(InvalidScoreLogitsProcessor())
    # Generation stops at EOS or when the model emits the <|user|> / <|observation|> command token.
    eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
                    tokenizer.get_command("<|observation|>")]
    gen_kwargs = {"max_length": max_length, "do_sample": do_sample, "top_p": top_p,
                    "temperature": temperature, "logits_processor": logits_processor, **kwargs}
    if past_key_values is None:
        inputs = tokenizer.build_chat_input(query, history=history, role=role)
    else:
        # With a cache, only the new message is encoded.
        inputs = tokenizer.build_chat_input(query, role=role)
    inputs = inputs.to(model.device)
    if past_key_values is not None:
        # NOTE(review): assumes dim 0 of the first cached tensor is the cached sequence
        # length - confirm against the model's KV-cache layout.
        past_length = past_key_values[0][0].shape[0]
        if model.transformer.pre_seq_len is not None:
            # Presumably the prefix occupies the front of the cache and must not shift positions - confirm.
            past_length -= model.transformer.pre_seq_len
        inputs.position_ids += past_length
        attention_mask = inputs.attention_mask
        # Prepend ones so the mask also covers the cached positions.
        attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
        inputs['attention_mask'] = attention_mask
    # Record the new message only after the inputs were built, so it is not encoded twice.
    history.append({"role": role, "content": query})
    for outputs in model.stream_generate(**inputs, past_key_values=past_key_values,
                                        eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
                                        **gen_kwargs):
        if return_past_key_values:
            outputs, past_key_values = outputs
        # Keep only the newly generated ids (drop the prompt part of the first sequence).
        outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
        response = tokenizer.decode(outputs)
        # Skip empty text and text ending in the replacement character
        # (presumably a not-yet-complete multi-byte character).
        if response and response[-1] != "�":
            # Same list object as `history`; no copy is made.
            new_history = history
            if return_past_key_values:
                yield response, new_history, past_key_values
            else:
                yield response, new_history

# class Client(Protocol):         # 协议类只用于代码静态检查
#     def generate_stream(self,
#         system: str | None,
#         tools: list[dict] | None,
#         history: list[Conversation],
#         **parameters: Any
#     ) -> Iterable[TextGenerationStreamResponse]:
#         ...

# System prompt used by HFClient.generate_stream instead of TOOL_PROMPT when the latest
# history entry is a tool observation.
OBS_PROMPT = "You have used tools and got the related information. Using the following tool results answering the previous questions: "


class HFClient:
    """Client that loads a tokenizer/model via `transformers` and streams chat output.

    The model is loaded with `AutoModel` (optionally with a prefix-encoder checkpoint),
    moved to CUDA, MPS or CPU (first one available) and put in eval mode.
    """

    def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None,):
        """Load tokenizer and model.

        Args:
            model_path: Path or hub id passed to `AutoModel.from_pretrained`.
            tokenizer_path: Path or hub id passed to `AutoTokenizer.from_pretrained`.
            pt_checkpoint: Optional directory containing `pytorch_model.bin` with
                `transformer.prefix_encoder.*` weights. When given, the model is built
                with `pre_seq_len=128` and those weights are loaded into its prefix encoder.
        """
        self.model_path = model_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)

        if pt_checkpoint is not None:
            config = AutoConfig.from_pretrained(model_path, trust_remote_code=True, pre_seq_len=128)
            self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True, config=config)
            # Load onto CPU explicitly: a checkpoint saved from CUDA tensors would otherwise
            # fail to deserialize on a CPU/MPS-only host. The model is moved to its final
            # device below, after the weights have been copied in.
            # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
            prefix_state_dict = torch.load(
                os.path.join(pt_checkpoint, "pytorch_model.bin"),
                map_location="cpu",
            )
            # Keep only the prefix-encoder weights and drop the module prefix from their keys.
            key_prefix = "transformer.prefix_encoder."
            new_prefix_state_dict = {
                k[len(key_prefix):]: v
                for k, v in prefix_state_dict.items()
                if k.startswith(key_prefix)
            }
            print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            self.model = AutoModel.from_pretrained(model_path, trust_remote_code=True)

        self.model = self.model.to(
            'cuda' if torch.cuda.is_available() else
            'mps' if torch.backends.mps.is_available() else
            'cpu'
        ).eval()

    @staticmethod
    def _role_name(role) -> str:
        """Return the role's string form without its surrounding `<|` / `|>` markers."""
        return str(role).removeprefix('<|').removesuffix('|>')

    def generate_stream(self,
        system: str | None,
        tools: list[dict] | None,
        history: list[Conversation],
        **parameters: Any
    ) -> Iterable[TextGenerationStreamResponse]:
        """Stream the model's reply to the last entry of `history`.

        All entries except the last are converted to role/content dicts. A system
        message is then appended *after* them: `OBS_PROMPT` when the last entry is an
        observation, `TOOL_PROMPT` otherwise, with `tools` attached when truthy.
        The last entry becomes the query handed to `stream_chat`.

        Args:
            system: Accepted for interface compatibility; not used by this implementation.
            tools: Tool descriptions attached to the system message when truthy.
            history: Conversation so far; must contain at least one entry.
            **parameters: Forwarded unchanged to `stream_chat`.

        Yields:
            One `TextGenerationStreamResponse` per streamed step. `generated_text` is the
            full text so far, `token.text` is the newly added part, and `token.special`
            is true when that part (stripped) looks like `<|...|>`. `token.id` and
            `token.logprob` are always 0.
        """
        chat_history = [
            {
                'role': self._role_name(conversation.role),
                'content': conversation.content,
            }
            for conversation in history[:-1]
        ]

        chat_history.append({
            'role': 'system',
            'content': OBS_PROMPT if history[-1].role==Role.OBSERVATION else TOOL_PROMPT,
        })

        if tools:
            chat_history[-1]['tools'] = tools

        query = history[-1].content
        role = self._role_name(history[-1].role)

        text = ''

        for new_text, _ in stream_chat(self.model,
            self.tokenizer,
            query,
            chat_history,
            role,
            **parameters,
        ):
            # stream_chat yields the whole reply so far; derive the increment from it.
            word = new_text.removeprefix(text)
            word_stripped = word.strip()
            text = new_text
            yield TextGenerationStreamResponse(
                generated_text=text,
                token=Token(
                    id=0,
                    logprob=0,
                    text=word,
                    special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
                )
            )

# Shared client used by the experiment loop; constructing it loads the tokenizer and the model.
client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH)

Could not locate the configuration_chatglm.py inside /share/lilin/ChatGLM3/finetune_demo/output/tool_alpaca_ft-20231110-172112-1e-4/checkpoint-200.


OSError: /share/lilin/ChatGLM3/finetune_demo/output/tool_alpaca_ft-20231110-172112-1e-4/checkpoint-200 does not appear to have a file named configuration_chatglm.py. Checkout 'https://huggingface.co//share/lilin/ChatGLM3/finetune_demo/output/tool_alpaca_ft-20231110-172112-1e-4/checkpoint-200/None' for available files.

In [4]:
# utils
# Generation settings passed to client.generate_stream by the experiment loop.
TOP_P = 0.8
TEMPERATURE = 0.01
MAX_LENGTH = 8192
# Tool observations whose len() exceeds this are cut to this length and suffixed with ' [TRUNCATED]'.
TRUNCATE_LENGTH = 1024


def append_conversation(
    conversation: Conversation,
    history: list[Conversation],
    placeholder=None,
) -> None:
    """Add `conversation` to the end of `history`, modifying the list in place.

    `placeholder` is accepted so callers can keep passing a UI slot, but it is not used here.
    """
    history.extend((conversation,))


def preprocess_text(
    system: str | None,
    tools: list[dict] | None,
    history: list[Conversation],
) -> str:
    """Build the raw prompt string for a conversation.

    The prompt starts with the system role marker, followed by either `TOOL_PROMPT`
    plus the JSON-serialized `tools` (when `tools` is truthy) or `system`. Every entry
    of `history` is then appended via its string form, and the prompt ends with the
    assistant role marker so the model continues as the assistant.

    NOTE: this definition shadows the `preprocess_text` imported from `conversation`
    at the top of the notebook.

    Args:
        system: System prompt used when no tools are given; None is treated as empty.
        tools: Tool descriptions; when truthy they replace `system` with `TOOL_PROMPT`.
        history: Conversation entries to serialize.

    Returns:
        The assembled prompt.
    """
    prompt = f"{Role.SYSTEM}\n"
    if tools:
        prompt += TOOL_PROMPT
        # A single dump suffices: the previous dumps -> loads -> dumps round-trip
        # produced the same compact JSON.
        prompt += json.dumps(tools, ensure_ascii=False)
    else:
        # `system` is declared optional; concatenating None would raise TypeError.
        prompt += system or ''
    for conversation in history:
        prompt += f'{conversation}'
    prompt += f'{Role.ASSISTANT}\n'
    return prompt

def extract_code(text: str) -> str:
    """Return the body of the last fenced code block found in `text`.

    A block is an opening triple backtick with an optional info string on the same
    line, the body, and a closing triple backtick. Raises IndexError when `text`
    contains no such block.
    """
    fence = re.compile(r'```([^\n]*)\n(.*?)```', re.DOTALL)
    bodies = [found.group(2) for found in fence.finditer(text)]
    return bodies[-1]

def tool_call(*args, **kwargs) -> dict:
    """Echo a tool invocation and hand back its keyword arguments.

    Prints a header line, then the positional arguments, then the keyword arguments,
    and returns the keyword-argument dict.
    """
    for item in ("=== Tool call:", args, kwargs):
        print(item)
    return kwargs


# Registered tool descriptions; passed as `tools=` to client.generate_stream in the experiment loop.
tools = get_tools()
print("===TOOLS ", tools)

# UI placeholder handed to append_conversation, which ignores it.
markdown_placeholder = None


===TOOLS  {'random_number_generator': {'name': 'random_number_generator', 'description': '随机生成一个数x, 使得 `range[0]` <= x < `range[1]`， 随机数生成的种子使用 `seed`', 'params': [{'name': 'seed', 'description': '随机数生成器使用的种子', 'type': 'int', 'required': True}, {'name': 'range', 'description': '生成随机数的范围', 'type': 'tuple[int, int]', 'required': True}]}, 'get_sentence_length': {'name': 'get_sentence_length', 'description': '获取句子 `input_text` 的长度', 'params': [{'name': 'input_text', 'description': '输入的句子', 'type': 'str', 'required': True}]}, 'exponentiation_calculation': {'name': 'exponentiation_calculation', 'description': '返回指数计算的结果，底数 `base` 的指数 `power` 次方', 'params': [{'name': 'base', 'description': '底数', 'type': 'int', 'required': True}, {'name': 'power', 'description': '指数', 'type': 'int', 'required': True}]}, 'web_search': {'name': 'web_search', 'description': '从网络上获得 `keyword` 的习惯内容信息。\n在你要回答你现有知识无法回答的问题时，你应该使用这个工具tool（尤其是当你需要获得最新的实时信息，或者你缺少相关信息时，在这种情况下请更倾向于使用他）。', 'params': [{'name': 'keyword', 'd

In [8]:
# Experiment
import time
history: list[Conversation] = []

for i in range(10):
# def dialogue(txt):
    print('user: ', flush=True)
    input_text = input()
    print(input_text)
    if input_text == "END":
        break
    input_text = input_text.strip()
    role = Role.USER
    append_conversation(Conversation(role, input_text), history)
    # input_text = preprocess_text(
    #     None,
    #     tools,
    #     history,
    # )
    
    for _ in range(5):
        output_text = ''
        user_mark = 0
        for response in client.generate_stream(
            system=None,
            tools=tools,
            history=history,
            do_sample=True,
            max_length=MAX_LENGTH,
            temperature=TEMPERATURE,
            top_p=TOP_P,
            stop_sequences=[str(r) for r in (Role.USER, Role.OBSERVATION)],
        ):
            token = response.token
            if response.token.special:
                # print(output_text)
                logging.info(output_text)
                logging.info(token)
                print('assistant: ', output_text)


                match token.text.strip():
                    case '<|user|>':
                        # print('assistant: ', output_text)
                        # time.sleep(1)
                        append_conversation(Conversation(
                            Role.ASSISTANT,
                            postprocess_text(output_text),
                        ), history, markdown_placeholder)
                        user_mark = 1
                        break
                    # Initiate tool call
                    case '<|assistant|>':
                        append_conversation(Conversation(
                            Role.ASSISTANT,
                            postprocess_text(output_text),
                        ), history, markdown_placeholder)
                        output_text = ''
                        # message_placeholder = placeholder.chat_message(name="tool", avatar="assistant")
                        # markdown_placeholder = message_placeholder.empty()
                        continue
                    case '<|observation|>':
                        tool, *output_text = output_text.strip().split('\n')
                        output_text = '\n'.join(output_text)
                        
                        append_conversation(Conversation(
                            Role.TOOL,
                            postprocess_text(output_text),
                            tool,
                        ), history, markdown_placeholder)
                        # message_placeholder = placeholder.chat_message(name="observation", avatar="user")
                        # markdown_placeholder = message_placeholder.empty()
                        
                        try:
                            code = extract_code(output_text)
                            logging.info(f"CODE: {code}")
                            
                            args = eval(code, {'tool_call': tool_call}, {})
                        except:
                            logging.warning('Failed to parse tool call')
                            break
                        
                        output_text = ''
                        
                        # if manual_mode:
                        #     st.info('Please provide tool call results below:')
                        #     return
                        # else:
                        #     with markdown_placeholder:
                        #         with st.spinner(f'Calling tool {tool}...'):
                        #             observation = dispatch_tool(tool, args)
                        observation = dispatch_tool(tool, args)
                        
                        if len(observation) > TRUNCATE_LENGTH:
                            observation = observation[:TRUNCATE_LENGTH] + ' [TRUNCATED]'
                        append_conversation(Conversation(
                            Role.OBSERVATION, observation
                        ), history, markdown_placeholder)
                        # message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
                        # markdown_placeholder = message_placeholder.empty()
                        # st.session_state.calling_tool = False
                        break
                    case _:
                        logging.warning(f'Unexpected special token: {token.text.strip()}')
                        break
            output_text += response.token.text
            # markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
        else:
            append_conversation(Conversation(
                Role.ASSISTANT,
                postprocess_text(output_text),
            ), history, markdown_placeholder)
            user_mark = 1
            print('assistant: ', output_text)
        if user_mark:
            break

for h in history:
    logging.info(h)

user: 


你好
assistant:  
 你好，请问有什么我可以帮助你的吗？
user: 
“大庇天下寒士俱欢颜”这句话有几个字
assistant:  
 这句话有8个字。
user: 
请介绍一下华为最新款的手机
assistant:  
 这个问题需要我调用 web_search 函数来获取相关信息。
assistant:   web_search
 ```python
tool_call(keyword='华为最新款手机')
```
=== Tool call:
()
{'keyword': '华为最新款手机'}
assistant:  
 华为最新款的手机包括：HUAWEI Mate 60 Pro、HUAWEI Mate 60 Pro+、HUAWEI Mate X5、HUAWEI Mate 60 RS 和 P 系列。这些手机均搭载了先进的芯片和摄像头技术，支持5G网络和卫星通信。其中，HUAWEI Mate 60 Pro+和HUAWEI Mate 60 RS还支持双向北斗卫星消息，使您在户外环境中也能保持联系。P 系列手机则配备了超可靠昆仑玻璃和超聚光夜视长焦，适合拍摄高质量的照片。
user: 
23的12次方是多少
assistant:  
 这个问题需要使用数学计算工具。
assistant:   exponentiation_calculation
 ```python
tool_call(base=23, power=12)
```
=== Tool call:
()
{'base': 23, 'power': 12}
assistant:  
 23的12次方等于21914624432020321。
user: 
介绍一下苹果的最新手机
assistant:  
 这个问题需要我调用 web_search 函数来获取相关信息。
assistant:   
 ```python
tool_call(keyword='苹果最新手机')
```
assistant:  
 苹果最新款的手机是iPhone 14，它拥有6.1英寸的屏幕，搭载了A16芯片，支持5G网络和iOS 15操作系统。iPhone 14还配备了双摄像头和超广角镜头，支持夜间拍摄和深度感知。此外，iPhone 14还具有 Face ID 和 Touch ID 功能，可以实现快速识别和支付。


<|user|> 你好 None
<|assistant|> 你好，请问有什么我可以帮助你的吗？ None
<|user|> 请介绍一下华为的最新款手机 None
<|assistant|> 这个问题需要我调用搜索引擎来获取相关信息。 None
<|assistant|> ```python
tool_call(keyword='华为最新款手机')
``` web_search
<|observation|> HUAWEI Mate 60 Pro 了解更多 购买 最新 HUAWEI Mate 60 Pro+ 双卫星通信 超可靠玄武架构 了解更多 购买 最新 HUAWEI Mate X5 超轻薄全能折叠 玄武钢化昆仑玻璃 ￥12999 起 了解更多 购买 最新 HUAWEI Mate 60 RS 非凡大师 双卫星通信 玄武钢化昆仑玻璃 了解更多 购买 最新 HUAWEI Mate 60 双向北斗卫星消息 超可靠玄武架构 了解更多 购买 P 系列 超可靠昆仑玻璃 超聚光夜视长焦 双向北斗卫星消息，超强灵犀通信 None
<|assistant|> 华为最新款手机包括：HUAWEI Mate 60 Pro、HUAWEI Mate 60 Pro+、HUAWEI Mate X5、HUAWEI Mate 60 RS、HUAWEI Mate 60。这些手机均搭载了超可靠的昆仑玻璃，并且支持双向北斗卫星消息。其中，HUAWEI Mate 60 Pro+和HUAWEI Mate 60 RS还支持双卫星通信。此外，HUAWEI Mate X5是一款超轻薄的全能折叠手机，HUAWEI Mate 60则是一款售价更亲民的手机，具有超聚光夜视长焦功能。 None


In [21]:
from duckduckgo_search import DDGS

# Quick check of the DuckDuckGo text search: fetch up to five hits and show them.
with DDGS() as ddgs:
    results = list(ddgs.text("python programming", max_results=5))
    print(results)

[{'title': 'Welcome to Python.org', 'href': 'https://www.python.org/', 'body': 'The core of extensible programming is defining functions. Python allows mandatory and optional arguments, keyword arguments, and even arbitrary argument lists. More about defining functions in Python 3. Python is a programming language that lets you work quickly and integrate systems more effectively. Learn More.'}, {'title': 'Python Tutorial - W3Schools', 'href': 'https://www.w3schools.com/python/default.asp', 'body': 'Python is a popular programming language. Python can be used on a server to create web applications. Start learning Python now ».'}, {'title': 'Python (programming language) - Wikipedia', 'href': 'https://en.wikipedia.org/wiki/Python_(programming_language)', 'body': 'Python is a high-level, general-purpose programming language.Its design philosophy emphasizes code readability with the use of significant indentation.. Python is dynamically typed and garbage-collected.It supports multiple prog