From 8ffcb08465ba3926fabe1d88b2ce13d76fd1a365 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 11 Nov 2025 18:04:03 -0800 Subject: [PATCH 1/3] client/types: add logprobs support --- examples/chat-logprobs.py | 32 ++++++++++++++++ examples/generate-logprobs.py | 25 ++++++++++++ ollama/_client.py | 32 ++++++++++++++++ ollama/_types.py | 31 +++++++++++++++ tests/test_client.py | 72 +++++++++++++++++++++++++++++++++++ 5 files changed, 192 insertions(+) create mode 100644 examples/chat-logprobs.py create mode 100644 examples/generate-logprobs.py diff --git a/examples/chat-logprobs.py b/examples/chat-logprobs.py new file mode 100644 index 0000000..a3a1af7 --- /dev/null +++ b/examples/chat-logprobs.py @@ -0,0 +1,32 @@ +from typing import Iterable + +import ollama + + +def print_logprobs(logprobs: Iterable[dict], label: str) -> None: + print(f"\n{label}:") + for entry in logprobs: + token = entry.get("token", "") + logprob = entry.get("logprob") + print(f" token={token!r:<12} logprob={logprob:.3f}") + for alt in entry.get("top_logprobs", []): + if alt['token'] != token: + print(f" alt -> {alt['token']!r:<12} ({alt['logprob']:.3f})") + + +messages = [ + { + "role": "user", + "content": "hi! be concise.", + }, +] + +response = ollama.chat( + model="gemma3", + messages=messages, + logprobs=True, + top_logprobs=3, +) +print("Chat response:", response["message"]["content"]) +print_logprobs(response.get("logprobs", []), "chat logprobs") + diff --git a/examples/generate-logprobs.py b/examples/generate-logprobs.py new file mode 100644 index 0000000..9002665 --- /dev/null +++ b/examples/generate-logprobs.py @@ -0,0 +1,25 @@ +from typing import Iterable + +import ollama + + +def print_logprobs(logprobs: Iterable[dict], label: str) -> None: + print(f"\n{label}:") + for entry in logprobs: + token = entry.get("token", "") + logprob = entry.get("logprob") + print(f" token={token!r:<12} logprob={logprob:.3f}") + for alt in entry.get("top_logprobs", []): + if alt['token'] != token: + print(f" alt -> {alt['token']!r:<12} ({alt['logprob']:.3f})") + + +response = ollama.generate( + model="gemma3", + prompt="hi! 
be concise.", + logprobs=True, + top_logprobs=3, +) +print("Generate response:", response["response"]) +print_logprobs(response.get("logprobs", []), "generate logprobs") + diff --git a/ollama/_client.py b/ollama/_client.py index dcd8126..e9350ee 100644 --- a/ollama/_client.py +++ b/ollama/_client.py @@ -200,6 +200,8 @@ def generate( context: Optional[Sequence[int]] = None, stream: Literal[False] = False, think: Optional[bool] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes, Image]]] = None, @@ -219,6 +221,8 @@ def generate( context: Optional[Sequence[int]] = None, stream: Literal[True] = True, think: Optional[bool] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes, Image]]] = None, @@ -237,6 +241,8 @@ def generate( context: Optional[Sequence[int]] = None, stream: bool = False, think: Optional[bool] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, raw: Optional[bool] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes, Image]]] = None, @@ -266,6 +272,8 @@ def generate( context=context, stream=stream, think=think, + logprobs=logprobs, + top_logprobs=top_logprobs, raw=raw, format=format, images=list(_copy_images(images)) if images else None, @@ -284,6 +292,8 @@ def chat( tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[False] = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -298,6 +308,8 @@ def chat( tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[True] = True, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -311,6 +323,8 @@ def chat( tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: bool = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -358,6 +372,8 @@ def add_two_numbers(a: int, b: int) -> int: tools=list(_copy_tools(tools)), stream=stream, think=think, + logprobs=logprobs, + top_logprobs=top_logprobs, format=format, options=options, keep_alive=keep_alive, @@ -802,6 +818,8 @@ async def generate( context: Optional[Sequence[int]] = None, stream: Literal[False] = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: 
Optional[Sequence[Union[str, bytes, Image]]] = None, @@ -821,6 +839,8 @@ async def generate( context: Optional[Sequence[int]] = None, stream: Literal[True] = True, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, raw: bool = False, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes, Image]]] = None, @@ -839,6 +859,8 @@ async def generate( context: Optional[Sequence[int]] = None, stream: bool = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, raw: Optional[bool] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, images: Optional[Sequence[Union[str, bytes, Image]]] = None, @@ -867,6 +889,8 @@ async def generate( context=context, stream=stream, think=think, + logprobs=logprobs, + top_logprobs=top_logprobs, raw=raw, format=format, images=list(_copy_images(images)) if images else None, @@ -885,6 +909,8 @@ async def chat( tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[False] = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -899,6 +925,8 @@ async def chat( tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: Literal[True] = True, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -912,6 +940,8 @@ async def chat( tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None, stream: bool = False, think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None, + logprobs: Optional[bool] = None, + top_logprobs: Optional[int] = None, format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None, options: Optional[Union[Mapping[str, Any], Options]] = None, keep_alive: Optional[Union[float, str]] = None, @@ -960,6 +990,8 @@ def add_two_numbers(a: int, b: int) -> int: tools=list(_copy_tools(tools)), stream=stream, think=think, + logprobs=logprobs, + top_logprobs=top_logprobs, format=format, options=options, keep_alive=keep_alive, diff --git a/ollama/_types.py b/ollama/_types.py index 638fb9e..8931cea 100644 --- a/ollama/_types.py +++ b/ollama/_types.py @@ -210,6 +210,12 @@ class GenerateRequest(BaseGenerateRequest): think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None 'Enable thinking mode (for thinking models).' + logprobs: Optional[bool] = None + 'Return log probabilities for generated tokens.' + + top_logprobs: Optional[int] = None + 'Number of alternative tokens and log probabilities to include per position (0-20).' + class BaseGenerateResponse(SubscriptableBaseModel): model: Optional[str] = None @@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel): 'Duration of evaluating inference in nanoseconds.' +class TokenLogprob(SubscriptableBaseModel): + token: str + 'Token text.' + + logprob: float + 'Log probability for the token.' 
+ + +class Logprob(TokenLogprob): + top_logprobs: Optional[Sequence[TokenLogprob]] = None + 'Most likely tokens and their log probabilities.' + + class GenerateResponse(BaseGenerateResponse): """ Response returned by generate requests. @@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse): context: Optional[Sequence[int]] = None 'Tokenized history up to the point of the response.' + logprobs: Optional[Sequence[Logprob]] = None + 'Log probabilities for generated tokens.' + class Message(SubscriptableBaseModel): """ @@ -360,6 +382,12 @@ def serialize_model(self, nxt): think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None 'Enable thinking mode (for thinking models).' + logprobs: Optional[bool] = None + 'Return log probabilities for generated tokens.' + + top_logprobs: Optional[int] = None + 'Number of alternative tokens and log probabilities to include per position (0-20).' + class ChatResponse(BaseGenerateResponse): """ @@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse): message: Message 'Response message.' + logprobs: Optional[Sequence[Logprob]] = None + 'Log probabilities for generated tokens if requested.' + class EmbedRequest(BaseRequest): input: Union[str, Sequence[str]] diff --git a/tests/test_client.py b/tests/test_client.py index 449d6ab..4c0cec1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer): assert response['message']['content'] == "I don't know." +def test_client_chat_with_logprobs(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/chat', + method='POST', + json={ + 'model': 'dummy', + 'messages': [{'role': 'user', 'content': 'Hi'}], + 'tools': [], + 'stream': False, + 'logprobs': True, + 'top_logprobs': 3, + }, + ).respond_with_json( + { + 'model': 'dummy', + 'message': { + 'role': 'assistant', + 'content': 'Hello', + }, + 'logprobs': [ + { + 'token': 'Hello', + 'logprob': -0.1, + 'top_logprobs': [ + {'token': 'Hello', 'logprob': -0.1}, + {'token': 'Hi', 'logprob': -1.0}, + ], + } + ], + } + ) + + client = Client(httpserver.url_for('/')) + response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3) + assert response['logprobs'][0]['token'] == 'Hello' + assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi' + + def test_client_chat_stream(httpserver: HTTPServer): def stream_handler(_: Request): def generate(): @@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer): assert response['response'] == 'Because it is.' 
+def test_client_generate_with_logprobs(httpserver: HTTPServer): + httpserver.expect_ordered_request( + '/api/generate', + method='POST', + json={ + 'model': 'dummy', + 'prompt': 'Why', + 'stream': False, + 'logprobs': True, + 'top_logprobs': 2, + }, + ).respond_with_json( + { + 'model': 'dummy', + 'response': 'Hello', + 'logprobs': [ + { + 'token': 'Hello', + 'logprob': -0.2, + 'top_logprobs': [ + {'token': 'Hello', 'logprob': -0.2}, + {'token': 'Hi', 'logprob': -1.5}, + ], + } + ], + } + ) + + client = Client(httpserver.url_for('/')) + response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2) + assert response['logprobs'][0]['token'] == 'Hello' + assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi' + + def test_client_generate_with_image_type(httpserver: HTTPServer): httpserver.expect_ordered_request( '/api/generate', From 522804fb0c28b09f10753518e734a989f3c040e1 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 11 Nov 2025 18:06:40 -0800 Subject: [PATCH 2/3] fix lint --- examples/chat-logprobs.py | 23 +++++++++++------------ examples/generate-logprobs.py | 21 ++++++++++----------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/examples/chat-logprobs.py b/examples/chat-logprobs.py index a3a1af7..b5eea18 100644 --- a/examples/chat-logprobs.py +++ b/examples/chat-logprobs.py @@ -4,29 +4,28 @@ def print_logprobs(logprobs: Iterable[dict], label: str) -> None: - print(f"\n{label}:") + print(f'\n{label}:') for entry in logprobs: - token = entry.get("token", "") - logprob = entry.get("logprob") - print(f" token={token!r:<12} logprob={logprob:.3f}") - for alt in entry.get("top_logprobs", []): + token = entry.get('token', '') + logprob = entry.get('logprob') + print(f' token={token!r:<12} logprob={logprob:.3f}') + for alt in entry.get('top_logprobs', []): if alt['token'] != token: - print(f" alt -> {alt['token']!r:<12} ({alt['logprob']:.3f})") + print(f' alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})') messages = [ { - "role": "user", - "content": "hi! be concise.", + 'role': 'user', + 'content': 'hi! be concise.', }, ] response = ollama.chat( - model="gemma3", + model='gemma3', messages=messages, logprobs=True, top_logprobs=3, ) -print("Chat response:", response["message"]["content"]) -print_logprobs(response.get("logprobs", []), "chat logprobs") - +print('Chat response:', response['message']['content']) +print_logprobs(response.get('logprobs', []), 'chat logprobs') diff --git a/examples/generate-logprobs.py b/examples/generate-logprobs.py index 9002665..494eb3e 100644 --- a/examples/generate-logprobs.py +++ b/examples/generate-logprobs.py @@ -4,22 +4,21 @@ def print_logprobs(logprobs: Iterable[dict], label: str) -> None: - print(f"\n{label}:") + print(f'\n{label}:') for entry in logprobs: - token = entry.get("token", "") - logprob = entry.get("logprob") - print(f" token={token!r:<12} logprob={logprob:.3f}") - for alt in entry.get("top_logprobs", []): + token = entry.get('token', '') + logprob = entry.get('logprob') + print(f' token={token!r:<12} logprob={logprob:.3f}') + for alt in entry.get('top_logprobs', []): if alt['token'] != token: - print(f" alt -> {alt['token']!r:<12} ({alt['logprob']:.3f})") + print(f' alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})') response = ollama.generate( - model="gemma3", - prompt="hi! be concise.", + model='gemma3', + prompt='hi! 
be concise.', logprobs=True, top_logprobs=3, ) -print("Generate response:", response["response"]) -print_logprobs(response.get("logprobs", []), "generate logprobs") - +print('Generate response:', response['response']) +print_logprobs(response.get('logprobs', []), 'generate logprobs') From 0d84681f37b42e33168d4e6104d3db7fcf5fc7af Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 11 Nov 2025 18:10:44 -0800 Subject: [PATCH 3/3] lint --- examples/chat-with-history.py | 3 ++- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/chat-with-history.py b/examples/chat-with-history.py index 09104ae..7275471 100644 --- a/examples/chat-with-history.py +++ b/examples/chat-with-history.py @@ -15,7 +15,8 @@ }, { 'role': 'assistant', - 'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.', + 'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures + rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""", }, ] diff --git a/pyproject.toml b/pyproject.toml index 05eafdc..c58f323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ] config-path = 'none' [tool.ruff] -line-length = 999 +line-length = 320 indent-width = 2 [tool.ruff.format]
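
Usage note: the Logprob and TokenLogprob types added in PATCH 1/3 are subscriptable Pydantic models, so the returned log probabilities can be read as attributes instead of via dict-style .get() as in the bundled examples. Below is a minimal sketch, assuming this patch series is installed and a local 'gemma3' model is available; the model name and the exp() conversion to plain probabilities are illustrative choices, not part of the patch.

import math

import ollama

response = ollama.chat(
  model='gemma3',
  messages=[{'role': 'user', 'content': 'hi! be concise.'}],
  logprobs=True,
  top_logprobs=3,
)

# response.logprobs is Optional[Sequence[Logprob]]; each entry carries the
# sampled token, its log probability, and optional top_logprobs alternatives.
for entry in response.logprobs or []:
  print(f'{entry.token!r}: p={math.exp(entry.logprob):.3f}')
  for alt in entry.top_logprobs or []:
    print(f'  alt {alt.token!r}: p={math.exp(alt.logprob):.3f}')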