@@ -6,6 +6,6 @@
 import dataclasses
 import random
 import string
-from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol
+from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union, Protocol, cast

 import jinja2
@@ -338,6 +338,7 @@ def _convert_completion_to_chat_function(
                         }
                     ],
                 },
+                "logprobs": None,
                 "finish_reason": "tool_calls",
             }
         ],
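A note on the "logprobs": None additions here and in the later hunks: the OpenAI chat-completion schema carries a logprobs field on every choice, so a typed response must include it explicitly, with None meaning log probabilities were not requested. A minimal sketch of a schema-shaped choice (field names follow the OpenAI response format; values are illustrative):

# Illustrative choice object; "logprobs" is present but None when not requested.
choice = {
    "index": 0,
    "logprobs": None,
    "message": {"role": "assistant", "content": None, "tool_calls": []},
    "finish_reason": "tool_calls",
}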
@@ -1191,7 +1192,6 @@ def format_mistral_instruct(
         elif (
             message["role"] == "assistant"
             and message["content"] is not None
-            and isinstance(message["content"], str)
         ):
             prompt += " [/INST]" + message["content"] + eos
     prompt += " [/INST]"
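For context, a sketch of the prompt this formatter builds, assuming the usual Mistral instruct tokens (bos "<s>", eos "</s>"; example data illustrative):

# Hypothetical input/output for format_mistral_instruct:
messages = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "What is 2+2?"},
]
# User turns append "[INST] <content>", assistant turns append " [/INST]<content></s>",
# and a trailing " [/INST]" cues the next assistant turn:
# "<s>[INST] Hi [/INST]Hello!</s>[INST] What is 2+2? [/INST]"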
@@ -1263,7 +1263,7 @@ def format_gemma(
     **kwargs: Any,
 ) -> ChatFormatterResponse:
     system_message = _get_system_message(messages)
-    if system_message is not None and system_message != "":
+    if system_message != "":
        logger.debug(
             "`role='system'` messages are not allowed on Google's Gemma models."
         )
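Dropping the is not None guard appears to rely on _get_system_message returning an empty string rather than None when no system message exists. A sketch of that assumed contract (inferred from this hunk, not shown in the diff):

# Assumed behavior of the helper: it returns the system message content,
# or "" when absent, never None, so comparing against "" alone suffices.
def _get_system_message(messages):
    for message in messages:
        if message["role"] == "system":
            return message["content"]
    return ""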
@@ -1628,6 +1628,7 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
                         }
                     ],
                 },
+                "logprobs": None,
                 "finish_reason": "tool_calls",
             }
         ],
@@ -1909,14 +1910,14 @@ def get_grammar(function_call):
         return grammar

     def create_completion(stop):
-        completion: llama_types.Completion = llama.create_completion(
+        completion = cast(llama_types.Completion, llama.create_completion(
             prompt=prompt,
             temperature=temperature,
             top_p=top_p,
             top_k=top_k,
             min_p=min_p,
             typical_p=typical_p,
-            stream=stream,
+            stream=False,
             stop=stop,
             max_tokens=max_tokens,
             presence_penalty=presence_penalty,
@@ -1929,7 +1930,7 @@ def create_completion(stop):
             model=model,
             logits_processor=logits_processor,
             grammar=grammar,
-        )
+        ))

         return completion

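The cast(...) plus stream=False pairing is worth spelling out: create_completion is typed to return either a full completion or a chunk iterator depending on stream, so pinning stream=False makes the non-streaming shape certain at runtime, and cast records that fact for the type checker with no runtime effect. A minimal, self-contained sketch of the pattern (names illustrative, not the library's):

from typing import Dict, Iterator, Union, cast

def fake_create_completion(stream: bool) -> Union[Dict, Iterator[Dict]]:
    # Return type depends on `stream`, mirroring llama.create_completion.
    if stream:
        return iter([{"choices": [{"text": "chunk"}]}])
    return {"choices": [{"text": "full"}]}

# With stream pinned to False the result is always a dict, but the declared
# type is still the Union; cast() narrows it for static analysis only.
completion = cast(Dict, fake_create_completion(stream=False))
print(completion["choices"][0]["text"])  # -> "full"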
@@ -2050,7 +2051,7 @@ def create_completion(stop):
         assert "usage" in completion
         assert len(function_calls) == len(function_bodies)

-        tool_calls = []
+        tool_calls: List[llama_types.ChatCompletionMessageToolCall] = []
         for function_call, function_body in zip(function_calls, function_bodies):
             tool_calls.append(
                 {
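Annotating the empty list matters because an unannotated [] gives the type checker nothing to verify appends against. A small sketch of the difference (ToolCall is a stand-in for the library's type):

from typing import List, TypedDict

class ToolCall(TypedDict):  # stand-in for llama_types.ChatCompletionMessageToolCall
    id: str
    type: str

calls = []                       # inferred with an unknown item type; appends go unchecked
tool_calls: List[ToolCall] = []  # appends are checked against ToolCall
tool_calls.append({"id": "call_0", "type": "function"})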
@@ -2070,6 +2071,12 @@ def create_completion(stop):
             )

         # TODO: support stream mode
+        function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
+            "function_call": {
+                "name": tool_calls[0]["function"]["name"],
+                "arguments": tool_calls[0]["function"]["arguments"],
+            }
+        } if len(tool_calls) == 1 else {}
         return llama_types.CreateChatCompletionResponse(
             id="chat" + completion["id"],
             object="chat.completion",
@@ -2078,14 +2085,12 @@ def create_completion(stop):
             choices=[
                 {
                     "index": 0,
+                    "logprobs": None,
                     "message": {
                         "role": "assistant",
                         "content": None if content == "" else content,
-                        "function_call": {
-                            "name": tool_calls[0]["function"]["name"],
-                            "arguments": tool_calls[0]["function"]["arguments"],
-                        } if len(tool_calls) > 0 else None,
-                        "tool_calls": tool_calls if len(tool_calls) > 0 else None,
+                        "tool_calls": tool_calls,
+                        **function_call_dict,
                     },
                     "finish_reason": "tool_calls" if len(tool_calls) > 0 else "stop",
                 }
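The rewritten message uses a conditional-spread idiom: the deprecated function_call field is built as a separate dict and splatted in, so it appears only when there is exactly one tool call. A reduced sketch of the idiom (data illustrative):

# The optional key lives in its own dict; **-unpacking an empty dict adds nothing.
tool_calls = [{"id": "call_0", "type": "function",
               "function": {"name": "get_weather", "arguments": "{}"}}]
function_call_dict = (
    {"function_call": tool_calls[0]["function"]} if len(tool_calls) == 1 else {}
)
message = {"role": "assistant", "content": None,
           "tool_calls": tool_calls, **function_call_dict}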
@@ -2565,8 +2570,8 @@ def chatml_function_calling(
         tool_name = text[len("functions.") :]
         tool = next((tool for tool in tools if tool["function"]["name"] == tool_name), None)
         if not stream:
-            completions = []
-            completions_tool_name = []
+            completions: List[llama_types.CreateCompletionResponse] = []
+            completions_tool_name: List[str] = []
             while tool is not None:
                 prompt += f"functions.{tool_name}:\n"
                 try:
@@ -2603,6 +2608,7 @@ def chatml_function_calling(
                     logits_processor=logits_processor,
                     grammar=grammar,
                 )
+                completion_or_chunks = cast(llama_types.CreateCompletionResponse, completion_or_chunks)
                 completions.append(completion_or_chunks)
                 completions_tool_name.append(tool_name)
                 prompt += completion_or_chunks["choices"][0]["text"]
@@ -2631,14 +2637,15 @@ def chatml_function_calling(
                         follow_up_gbnf_tool_grammar, verbose=llama.verbose
                     ),
                 )
+                response = cast(llama_types.CreateCompletionResponse, response)

                 tool_name = response["choices"][0]["text"][len("functions.") :]
                 tool = next(
                     (tool for tool in tools if tool["function"]["name"] == tool_name), None
                 )

             # Merge completions
-            function_call = {
+            function_call_dict: Union[Dict[str, str], Dict[Literal["function_call"], llama_types.ChatCompletionRequestAssistantMessageFunctionCall]] = {
                 "function_call": {
                     "name": tool_name,
                     "arguments": completions[0]["choices"][0]["text"],
@@ -2653,6 +2660,7 @@ def chatml_function_calling(
                 {
                     "finish_reason": "tool_calls",
                     "index": 0,
+                    "logprobs": None,
                     "message": {
                         "role": "assistant",
                         "content": None,
@@ -2673,20 +2681,22 @@ def chatml_function_calling(
                                 zip(completions_tool_name, completions)
                             )
                         ],
-                        **function_call
+                        **function_call_dict
                     },
                 }
             ],
             "usage": {
                 "completion_tokens": sum(
-                    completion["usage"]["completion_tokens"]
+                    completion["usage"]["completion_tokens"] if "usage" in completion else 0
                     for completion in completions
                 ),
                 "prompt_tokens": sum(
-                    completion["usage"]["prompt_tokens"] for completion in completions
+                    completion["usage"]["prompt_tokens"] if "usage" in completion else 0
+                    for completion in completions
                 ),
                 "total_tokens": sum(
-                    completion["usage"]["total_tokens"] for completion in completions
+                    completion["usage"]["total_tokens"] if "usage" in completion else 0
+                    for completion in completions
                 ),
             },
         }
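The usage totals now tolerate completions that omit "usage": missing entries contribute 0 instead of raising KeyError. A minimal sketch of the guarded sum (data illustrative):

# Two completions, one without a "usage" block.
completions = [
    {"usage": {"completion_tokens": 5, "prompt_tokens": 12, "total_tokens": 17}},
    {},
]
completion_tokens = sum(
    c["usage"]["completion_tokens"] if "usage" in c else 0 for c in completions
)
print(completion_tokens)  # -> 5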