In [1]:
import os

import torch
from transformers import AutoModel, AutoTokenizer, BertModel
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.utils import LogitsProcessorList, StoppingCriteriaList, GenerationConfig, ModelOutput

from conversation import postprocess_text, preprocess_text, Conversation, Role
from tool_registry import dispatch_tool, get_tools

[registered tool] {'description': 'Generates a random number x, s.t. range[0] <= x < range[1]',
 'name': 'random_number_generator',
 'params': [{'description': 'The random seed used by the generator',
             'name': 'seed',
             'required': True,
             'type': 'int'},
            {'description': 'The range of the generated numbers',
             'name': 'range',
             'required': True,
             'type': 'tuple[int, int]'}]}
[registered tool] {'description': 'Get the current weather for `city_name`',
 'name': 'get_weather',
 'params': [{'description': 'The name of the city to be queried',
             'name': 'city_name',
             'required': True,
             'type': 'str'}]}
[registered tool] {'description': '获取经营指定产品的公司',
 'name': 'get_company_operating_the_given_product',
 'params': [{'description': '产品',
             'name': 'product',
             'required': True,
             'type': 'str'}]}


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
role = Role.USER
history = []

In [4]:
tools = get_tools()

In [5]:
prompt_text = '经营电池的公司有哪些'
history.append(Conversation(role, prompt_text))

In [6]:
history

[Conversation(role=<Role.USER: 2>, content='经营电池的公司有哪些', tool=None, image=None)]

In [None]:
input_text = preprocess_text(
    None,
    tools,
    history,
)

In [7]:
TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'

In [8]:
# The chat history starts with a single system message carrying the tool prompt.
chat_history = [dict(role='system', content=TOOL_PROMPT)]

In [9]:
chat_history[0]['tools'] = tools

In [10]:
chat_history

[{'role': 'system',
  'content': 'Answer the following questions as best as you can. You have access to the following tools:',
  'tools': {'random_number_generator': {'name': 'random_number_generator',
    'description': 'Generates a random number x, s.t. range[0] <= x < range[1]',
    'params': [{'name': 'seed',
      'description': 'The random seed used by the generator',
      'type': 'int',
      'required': True},
     {'name': 'range',
      'description': 'The range of the generated numbers',
      'type': 'tuple[int, int]',
      'required': True}]},
   'get_weather': {'name': 'get_weather',
    'description': 'Get the current weather for `city_name`',
    'params': [{'name': 'city_name',
      'description': 'The name of the city to be queried',
      'type': 'str',
      'required': True}]},
   'get_company_operating_the_given_product': {'name': 'get_company_operating_the_given_product',
    'description': 'Get the companies that operates the given product',
    'params': [{'

In [11]:
# The latest turn in `history` supplies both the query text and the speaker role.
query = history[-1].content
# Note: this rebinds `role` (a Role member up to this point) to a plain string.
# Presumably str(Role.X) is wrapped like '<|user|>'; the wrapper is stripped here.
role = str(history[-1].role)
role = role.removeprefix('<|').removesuffix('|>')

In [14]:
# Checkpoint location and target device.
# Both can be overridden through environment variables so the notebook is not
# tied to one machine's filesystem or GPU layout; the defaults reproduce the
# previously hard-coded values.
MODEL_PATH = os.environ.get('CHATGLM3_MODEL_PATH', '/home/lc/projects/pretrained_models/chatglm3-6b')
DEVICE = os.environ.get('CHATGLM3_DEVICE', 'cuda:1')

In [15]:
# Tokenizer and weights come from the same checkpoint directory.
# trust_remote_code=True is passed to both loaders, as in the original cell.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
# Load the weights, then move the model to the configured device in one step.
model = AutoModel.from_pretrained(MODEL_PATH, trust_remote_code=True).to(DEVICE)

Loading checkpoint shards: 100%|██████████| 7/7 [00:08<00:00,  1.27s/it]


In [19]:
resp, his = model.chat(tokenizer=tokenizer, query=query, history=chat_history, role=role)

In [20]:
resp

{'name': 'get_company_operating_the_given_product',
 'parameters': {'product': '电池'}}

In [21]:
observation = dispatch_tool(resp['name'], resp['parameters'])

In [23]:
from tool_registry import _TOOL_HOOKS

In [25]:
tool_call = _TOOL_HOOKS[resp['name']]

In [28]:
resp['parameters']

{'product': '电池'}

In [30]:
tool_call(**resp['parameters'])

['电池', '新能源']

In [22]:
observation

"['电池', '新能源']"

In [None]:
class InvalidScoreLogitsProcessor(LogitsProcessor):
    """Logits processor that repairs score tensors containing NaN or +/-inf.

    When any entry of ``scores`` is non-finite, the tensor is overwritten in
    place: every entry is zeroed and index 5 along the last axis is set to
    5e4. Fully finite scores are returned untouched.
    """

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        # isfinite is False exactly for NaN and +/-inf, so "not all finite"
        # is the same condition as "any NaN or any inf".
        if not torch.isfinite(scores).all():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores

In [None]:
# Decoding settings; they are collected into `gen_kwargs` further down and
# forwarded to model.generate().
max_length: int = 8192
num_beams: int = 1
do_sample: bool = True
top_p: float = 0.8
temperature: float = 0.8

In [None]:
chat_history

In [None]:
# Processor chain with a single entry: the NaN/inf guard defined above.
logits_processor = LogitsProcessorList([InvalidScoreLogitsProcessor()])

# Keyword arguments forwarded to model.generate().
gen_kwargs = dict(
    max_length=max_length,
    num_beams=num_beams,
    do_sample=do_sample,
    top_p=top_p,
    temperature=temperature,
    logits_processor=logits_processor,
)

# Build the model inputs from the current query, role and chat history, then
# move them onto the same device as the model.
inputs = tokenizer.build_chat_input(query, history=chat_history, role=role).to(model.device)

In [None]:
tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True)

In [None]:
# Token ids that end generation: the tokenizer's EOS id plus the ids of the
# "<|user|>" and "<|observation|>" command tokens.
eos_token_id = [tokenizer.eos_token_id] + [
    tokenizer.get_command(tag) for tag in ("<|user|>", "<|observation|>")
]

In [None]:
eos_token_id

In [None]:
outputs = model.generate(**inputs, **gen_kwargs, eos_token_id=eos_token_id)

In [None]:
tokenizer.decode(outputs[0], skip_special_tokens=True)

In [None]:
outputs

In [None]:
# Keep only the newly generated token ids: drop the prompt prefix and, when
# present, the trailing stop token.
# The isinstance guard makes the cell safe to re-run: after the first run
# `outputs` is already a plain list and must not be sliced a second time.
if isinstance(outputs, torch.Tensor):
    prompt_length = len(inputs["input_ids"][0])
    outputs = outputs.tolist()[0][prompt_length:]
    # generate() may also stop because max_length was reached; in that case the
    # last id is real content, so only strip it when it is one of the stop ids.
    if outputs and outputs[-1] in eos_token_id:
        outputs = outputs[:-1]

In [None]:
response = tokenizer.decode(outputs)

In [None]:
response

In [None]:
chat_history.append({"role": role, "content": query})

In [None]:
chat_history

In [None]:
# Split the decoded response into its first line (metadata, used later as the
# tool name) and the remaining text (content).
# A response without any newline has no metadata line; treat the whole text as
# content instead of failing to unpack.
if "\n" in response:
    metadata, content = response.split("\n", maxsplit=1)
else:
    metadata, content = "", response

In [None]:
chat_history.append({"role": "assistant", "metadata": metadata, "content": content})

In [None]:
chat_history

In [None]:
content.split("\n")

In [None]:
# Strip the surrounding code fence (first and last line) from the tool-call
# snippet. The fence check keeps the cell idempotent: re-running it, or running
# it on text that was never fenced, no longer discards real lines.
_content_lines = content.split("\n")
if _content_lines[0].lstrip().startswith("```"):
    content = "\n".join(_content_lines[1:-1])

In [None]:
content

In [None]:
def tool_call(**kwargs):
    """Echo back the keyword arguments this function was called with.

    Returns:
        dict: the received keyword arguments, keyed by parameter name.
    """
    return dict(kwargs)

In [None]:
# Evaluate the `tool_call(...)` snippet so that the stub `tool_call(**kwargs)`
# returns its keyword arguments as the parameter dict.
# NOTE(review): `content` is derived from model-generated text, and eval() will
# execute any Python expression it contains. Treat this as unsafe outside of a
# scratch notebook; consider parsing the call with `ast` and reading only
# literal keyword values instead.
# NOTE(review): the name `tool_call` is also bound earlier in this notebook to
# the real hook taken from `_TOOL_HOOKS`; if the stub cell has not been run,
# this eval will invoke that hook rather than just collecting the arguments.
parameters = eval(content)

In [None]:
parameters

In [None]:
content = {"name": metadata.strip(), "parameters": parameters}

In [None]:
content

In [None]:
observation = dispatch_tool(content['name'], content['parameters'])

In [None]:
observation

In [None]:
chat_history

In [None]:
del history[-1]

In [None]:
history.append(
    Conversation(Role.TOOL, "```python\ntool_call(city_name='北京')\n```",
                 'get_weather'))

In [None]:
history

In [None]:
history.append(Conversation(Role.OBSERVATION, observation))

In [None]:
history

In [None]:
history[-1].content