diff --git a/examples/chat-logprobs.py b/examples/chat-logprobs.py
new file mode 100644
index 0000000..b5eea18
--- /dev/null
+++ b/examples/chat-logprobs.py
@@ -0,0 +1,31 @@
+from typing import Iterable
+
+import ollama
+
+
+def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
+  print(f'\n{label}:')
+  for entry in logprobs:
+    token = entry.get('token', '')
+    logprob = entry.get('logprob')
+    print(f'  token={token!r:<12} logprob={logprob:.3f}')
+    for alt in entry.get('top_logprobs', []):
+      if alt['token'] != token:
+        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')
+
+
+messages = [
+  {
+    'role': 'user',
+    'content': 'hi! be concise.',
+  },
+]
+
+response = ollama.chat(
+  model='gemma3',
+  messages=messages,
+  logprobs=True,
+  top_logprobs=3,
+)
+print('Chat response:', response['message']['content'])
+print_logprobs(response.get('logprobs', []), 'chat logprobs')
diff --git a/examples/chat-with-history.py b/examples/chat-with-history.py
index 09104ae..7275471 100644
--- a/examples/chat-with-history.py
+++ b/examples/chat-with-history.py
@@ -15,7 +15,8 @@
   },
   {
     'role': 'assistant',
-    'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
+    'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
+    rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
   },
 ]
 
diff --git a/examples/generate-logprobs.py b/examples/generate-logprobs.py
new file mode 100644
index 0000000..494eb3e
--- /dev/null
+++ b/examples/generate-logprobs.py
@@ -0,0 +1,24 @@
+from typing import Iterable
+
+import ollama
+
+
+def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
+  print(f'\n{label}:')
+  for entry in logprobs:
+    token = entry.get('token', '')
+    logprob = entry.get('logprob')
+    print(f'  token={token!r:<12} logprob={logprob:.3f}')
+    for alt in entry.get('top_logprobs', []):
+      if alt['token'] != token:
+        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')
+
+
+response = ollama.generate(
+  model='gemma3',
+  prompt='hi! be concise.',
+  logprobs=True,
+  top_logprobs=3,
+)
+print('Generate response:', response['response'])
+print_logprobs(response.get('logprobs', []), 'generate logprobs')
diff --git a/ollama/_client.py b/ollama/_client.py
index dcd8126..e9350ee 100644
--- a/ollama/_client.py
+++ b/ollama/_client.py
@@ -200,6 +200,8 @@ def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -219,6 +221,8 @@ def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -237,6 +241,8 @@ def generate(
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[bool] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -266,6 +272,8 @@ def generate(
         context=context,
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         raw=raw,
         format=format,
         images=list(_copy_images(images)) if images else None,
@@ -284,6 +292,8 @@ def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -298,6 +308,8 @@ def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -311,6 +323,8 @@ def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -358,6 +372,8 @@ def add_two_numbers(a: int, b: int) -> int:
         tools=list(_copy_tools(tools)),
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         format=format,
         options=options,
         keep_alive=keep_alive,
@@ -802,6 +818,8 @@ async def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -821,6 +839,8 @@ async def generate(
     context: Optional[Sequence[int]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: bool = False,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -839,6 +859,8 @@ async def generate(
     context: Optional[Sequence[int]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     raw: Optional[bool] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -867,6 +889,8 @@ async def generate(
         context=context,
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         raw=raw,
         format=format,
         images=list(_copy_images(images)) if images else None,
@@ -885,6 +909,8 @@ async def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[False] = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -899,6 +925,8 @@ async def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: Literal[True] = True,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -912,6 +940,8 @@ async def chat(
     tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
     stream: bool = False,
     think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
     format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
     options: Optional[Union[Mapping[str, Any], Options]] = None,
     keep_alive: Optional[Union[float, str]] = None,
@@ -960,6 +990,8 @@ def add_two_numbers(a: int, b: int) -> int:
         tools=list(_copy_tools(tools)),
         stream=stream,
         think=think,
+        logprobs=logprobs,
+        top_logprobs=top_logprobs,
         format=format,
         options=options,
         keep_alive=keep_alive,
diff --git a/ollama/_types.py b/ollama/_types.py
index 638fb9e..8931cea 100644
--- a/ollama/_types.py
+++ b/ollama/_types.py
@@ -210,6 +210,12 @@ class GenerateRequest(BaseGenerateRequest):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'
 
+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+
 
 class BaseGenerateResponse(SubscriptableBaseModel):
   model: Optional[str] = None
@@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel):
   'Duration of evaluating inference in nanoseconds.'
 
 
+class TokenLogprob(SubscriptableBaseModel):
+  token: str
+  'Token text.'
+
+  logprob: float
+  'Log probability for the token.'
+
+
+class Logprob(TokenLogprob):
+  top_logprobs: Optional[Sequence[TokenLogprob]] = None
+  'Most likely tokens and their log probabilities.'
+
+
 class GenerateResponse(BaseGenerateResponse):
   """
   Response returned by generate requests.
@@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse):
   context: Optional[Sequence[int]] = None
   'Tokenized history up to the point of the response.'
 
+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens.'
+
 
 class Message(SubscriptableBaseModel):
   """
@@ -360,6 +382,12 @@ def serialize_model(self, nxt):
   think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
   'Enable thinking mode (for thinking models).'
 
+  logprobs: Optional[bool] = None
+  'Return log probabilities for generated tokens.'
+
+  top_logprobs: Optional[int] = None
+  'Number of alternative tokens and log probabilities to include per position (0-20).'
+
 
 class ChatResponse(BaseGenerateResponse):
   """
@@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse):
   message: Message
   'Response message.'
 
+  logprobs: Optional[Sequence[Logprob]] = None
+  'Log probabilities for generated tokens if requested.'
+
 
 class EmbedRequest(BaseRequest):
   input: Union[str, Sequence[str]]
diff --git a/pyproject.toml b/pyproject.toml
index 05eafdc..c58f323 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
 config-path = 'none'
 
 [tool.ruff]
-line-length = 999
+line-length = 320
 indent-width = 2
 
 [tool.ruff.format]
diff --git a/tests/test_client.py b/tests/test_client.py
index 449d6ab..4c0cec1 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
   assert response['message']['content'] == "I don't know."
 
 
+def test_client_chat_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/chat',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'messages': [{'role': 'user', 'content': 'Hi'}],
+      'tools': [],
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 3,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'message': {
+        'role': 'assistant',
+        'content': 'Hello',
+      },
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.1,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.1},
+            {'token': 'Hi', 'logprob': -1.0},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_chat_stream(httpserver: HTTPServer):
   def stream_handler(_: Request):
     def generate():
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
   assert response['response'] == 'Because it is.'
 
 
+def test_client_generate_with_logprobs(httpserver: HTTPServer):
+  httpserver.expect_ordered_request(
+    '/api/generate',
+    method='POST',
+    json={
+      'model': 'dummy',
+      'prompt': 'Why',
+      'stream': False,
+      'logprobs': True,
+      'top_logprobs': 2,
+    },
+  ).respond_with_json(
+    {
+      'model': 'dummy',
+      'response': 'Hello',
+      'logprobs': [
+        {
+          'token': 'Hello',
+          'logprob': -0.2,
+          'top_logprobs': [
+            {'token': 'Hello', 'logprob': -0.2},
+            {'token': 'Hi', 'logprob': -1.5},
+          ],
+        }
+      ],
+    }
+  )
+
+  client = Client(httpserver.url_for('/'))
+  response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
+  assert response['logprobs'][0]['token'] == 'Hello'
+  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'
+
+
 def test_client_generate_with_image_type(httpserver: HTTPServer):
   httpserver.expect_ordered_request(
     '/api/generate',