31 changes: 31 additions & 0 deletions examples/chat-logprobs.py
@@ -0,0 +1,31 @@
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


messages = [
  {
    'role': 'user',
    'content': 'hi! be concise.',
  },
]

response = ollama.chat(
  model='gemma3',
  messages=messages,
  logprobs=True,
  top_logprobs=3,
)
print('Chat response:', response['message']['content'])
print_logprobs(response.get('logprobs', []), 'chat logprobs')
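
The example above makes a single non-streaming call. As a rough companion sketch (not part of this diff), the same logprobs and top_logprobs options can also be combined with stream=True; this assumes each streamed chunk carries its own logprobs entries, mirroring how the ChatResponse type in this change is shaped:

import ollama

# Assumption: streamed chunks expose per-chunk logprobs when logprobs=True.
stream = ollama.chat(
  model='gemma3',
  messages=[{'role': 'user', 'content': 'hi! be concise.'}],
  logprobs=True,
  top_logprobs=3,
  stream=True,
)
for chunk in stream:
  print(chunk['message']['content'], end='', flush=True)
  for entry in chunk.get('logprobs') or []:
    print(f"\n  {entry['token']!r}: {entry['logprob']:.3f}")
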
3 changes: 2 additions & 1 deletion examples/chat-with-history.py
@@ -15,7 +15,8 @@
},
{
'role': 'assistant',
'content': 'The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.',
'content': """The weather in Tokyo is typically warm and humid during the summer months, with temperatures often exceeding 30°C (86°F). The city experiences a rainy season from June to September, with heavy rainfall and occasional typhoons. Winter is mild, with temperatures
rarely dropping below freezing. The city is known for its high-tech and vibrant culture, with many popular tourist attractions such as the Tokyo Tower, Senso-ji Temple, and the bustling Shibuya district.""",
},
]

24 changes: 24 additions & 0 deletions examples/generate-logprobs.py
@@ -0,0 +1,24 @@
from typing import Iterable

import ollama


def print_logprobs(logprobs: Iterable[dict], label: str) -> None:
  print(f'\n{label}:')
  for entry in logprobs:
    token = entry.get('token', '')
    logprob = entry.get('logprob')
    print(f'  token={token!r:<12} logprob={logprob:.3f}')
    for alt in entry.get('top_logprobs', []):
      if alt['token'] != token:
        print(f'    alt -> {alt["token"]!r:<12} ({alt["logprob"]:.3f})')


response = ollama.generate(
  model='gemma3',
  prompt='hi! be concise.',
  logprobs=True,
  top_logprobs=3,
)
print('Generate response:', response['response'])
print_logprobs(response.get('logprobs', []), 'generate logprobs')
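
Each logprob is the natural log of the token's probability, so exp(logprob) recovers the probability itself. A small follow-on sketch (not part of this diff); the 0.5 threshold is only an illustrative cutoff:

import math

import ollama

response = ollama.generate(model='gemma3', prompt='hi! be concise.', logprobs=True, top_logprobs=3)
for entry in response.get('logprobs') or []:
  prob = math.exp(entry['logprob'])  # logprob = ln(p), so p = e**logprob
  marker = '  <- low confidence' if prob < 0.5 else ''
  print(f"{entry['token']!r:<12} p={prob:.2%}{marker}")
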
32 changes: 32 additions & 0 deletions ollama/_client.py
@@ -200,6 +200,8 @@ def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
think: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -219,6 +221,8 @@ def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
think: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -237,6 +241,8 @@ def generate(
context: Optional[Sequence[int]] = None,
stream: bool = False,
think: Optional[bool] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -266,6 +272,8 @@ def generate(
context=context,
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
raw=raw,
format=format,
images=list(_copy_images(images)) if images else None,
@@ -284,6 +292,8 @@ def chat(
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -298,6 +308,8 @@ def chat(
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -311,6 +323,8 @@ def chat(
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -358,6 +372,8 @@ def add_two_numbers(a: int, b: int) -> int:
tools=list(_copy_tools(tools)),
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
format=format,
options=options,
keep_alive=keep_alive,
@@ -802,6 +818,8 @@ async def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[False] = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -821,6 +839,8 @@ async def generate(
context: Optional[Sequence[int]] = None,
stream: Literal[True] = True,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: bool = False,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -839,6 +859,8 @@ async def generate(
context: Optional[Sequence[int]] = None,
stream: bool = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
raw: Optional[bool] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
images: Optional[Sequence[Union[str, bytes, Image]]] = None,
@@ -867,6 +889,8 @@ async def generate(
context=context,
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
raw=raw,
format=format,
images=list(_copy_images(images)) if images else None,
@@ -885,6 +909,8 @@ async def chat(
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[False] = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -899,6 +925,8 @@ async def chat(
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: Literal[True] = True,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -912,6 +940,8 @@ async def chat(
tools: Optional[Sequence[Union[Mapping[str, Any], Tool, Callable]]] = None,
stream: bool = False,
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
format: Optional[Union[Literal['', 'json'], JsonSchemaValue]] = None,
options: Optional[Union[Mapping[str, Any], Options]] = None,
keep_alive: Optional[Union[float, str]] = None,
@@ -960,6 +990,8 @@ def add_two_numbers(a: int, b: int) -> int:
tools=list(_copy_tools(tools)),
stream=stream,
think=think,
logprobs=logprobs,
top_logprobs=top_logprobs,
format=format,
options=options,
keep_alive=keep_alive,
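
The hunks above show that AsyncClient.generate and AsyncClient.chat gain the same logprobs and top_logprobs parameters as their sync counterparts. A brief usage sketch (not part of this diff; model and prompt are placeholders):

import asyncio

from ollama import AsyncClient


async def main() -> None:
  response = await AsyncClient().chat(
    model='gemma3',
    messages=[{'role': 'user', 'content': 'hi! be concise.'}],
    logprobs=True,
    top_logprobs=3,
  )
  for entry in response.get('logprobs') or []:
    print(entry['token'], entry['logprob'])


asyncio.run(main())
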
31 changes: 31 additions & 0 deletions ollama/_types.py
@@ -210,6 +210,12 @@ class GenerateRequest(BaseGenerateRequest):
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
'Enable thinking mode (for thinking models).'

logprobs: Optional[bool] = None
'Return log probabilities for generated tokens.'

top_logprobs: Optional[int] = None
'Number of alternative tokens and log probabilities to include per position (0-20).'


class BaseGenerateResponse(SubscriptableBaseModel):
model: Optional[str] = None
@@ -243,6 +249,19 @@ class BaseGenerateResponse(SubscriptableBaseModel):
'Duration of evaluating inference in nanoseconds.'


class TokenLogprob(SubscriptableBaseModel):
token: str
'Token text.'

logprob: float
'Log probability for the token.'


class Logprob(TokenLogprob):
top_logprobs: Optional[Sequence[TokenLogprob]] = None
'Most likely tokens and their log probabilities.'


class GenerateResponse(BaseGenerateResponse):
"""
Response returned by generate requests.
@@ -257,6 +276,9 @@ class GenerateResponse(BaseGenerateResponse):
context: Optional[Sequence[int]] = None
'Tokenized history up to the point of the response.'

logprobs: Optional[Sequence[Logprob]] = None
'Log probabilities for generated tokens.'


class Message(SubscriptableBaseModel):
"""
@@ -360,6 +382,12 @@ def serialize_model(self, nxt):
think: Optional[Union[bool, Literal['low', 'medium', 'high']]] = None
'Enable thinking mode (for thinking models).'

logprobs: Optional[bool] = None
'Return log probabilities for generated tokens.'

top_logprobs: Optional[int] = None
'Number of alternative tokens and log probabilities to include per position (0-20).'


class ChatResponse(BaseGenerateResponse):
"""
@@ -369,6 +397,9 @@ class ChatResponse(BaseGenerateResponse):
message: Message
'Response message.'

logprobs: Optional[Sequence[Logprob]] = None
'Log probabilities for generated tokens if requested.'


class EmbedRequest(BaseRequest):
input: Union[str, Sequence[str]]
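
A minimal sketch (not part of this diff) of the new typed models in isolation, assuming they are imported from ollama._types where this change defines them:

from ollama._types import Logprob, TokenLogprob

entry = Logprob(
  token='Hello',
  logprob=-0.1,
  top_logprobs=[
    TokenLogprob(token='Hello', logprob=-0.1),
    TokenLogprob(token='Hi', logprob=-1.0),
  ],
)
# SubscriptableBaseModel allows both attribute and key access, so these agree.
assert entry.token == entry['token'] == 'Hello'
assert entry.top_logprobs[1].logprob == -1.0
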
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -37,7 +37,7 @@ dependencies = [ 'ruff>=0.9.1' ]
config-path = 'none'

[tool.ruff]
line-length = 999
line-length = 320
Member: Perhaps we make this a separate change?

Member Author: Only impacted one file so decided to just leave it in this PR

indent-width = 2

[tool.ruff.format]
72 changes: 72 additions & 0 deletions tests/test_client.py
@@ -61,6 +61,44 @@ def test_client_chat(httpserver: HTTPServer):
assert response['message']['content'] == "I don't know."


def test_client_chat_with_logprobs(httpserver: HTTPServer):
  httpserver.expect_ordered_request(
    '/api/chat',
    method='POST',
    json={
      'model': 'dummy',
      'messages': [{'role': 'user', 'content': 'Hi'}],
      'tools': [],
      'stream': False,
      'logprobs': True,
      'top_logprobs': 3,
    },
  ).respond_with_json(
    {
      'model': 'dummy',
      'message': {
        'role': 'assistant',
        'content': 'Hello',
      },
      'logprobs': [
        {
          'token': 'Hello',
          'logprob': -0.1,
          'top_logprobs': [
            {'token': 'Hello', 'logprob': -0.1},
            {'token': 'Hi', 'logprob': -1.0},
          ],
        }
      ],
    }
  )

  client = Client(httpserver.url_for('/'))
  response = client.chat('dummy', messages=[{'role': 'user', 'content': 'Hi'}], logprobs=True, top_logprobs=3)
  assert response['logprobs'][0]['token'] == 'Hello'
  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'


def test_client_chat_stream(httpserver: HTTPServer):
def stream_handler(_: Request):
def generate():
@@ -294,6 +332,40 @@ def test_client_generate(httpserver: HTTPServer):
assert response['response'] == 'Because it is.'


def test_client_generate_with_logprobs(httpserver: HTTPServer):
  httpserver.expect_ordered_request(
    '/api/generate',
    method='POST',
    json={
      'model': 'dummy',
      'prompt': 'Why',
      'stream': False,
      'logprobs': True,
      'top_logprobs': 2,
    },
  ).respond_with_json(
    {
      'model': 'dummy',
      'response': 'Hello',
      'logprobs': [
        {
          'token': 'Hello',
          'logprob': -0.2,
          'top_logprobs': [
            {'token': 'Hello', 'logprob': -0.2},
            {'token': 'Hi', 'logprob': -1.5},
          ],
        }
      ],
    }
  )

  client = Client(httpserver.url_for('/'))
  response = client.generate('dummy', 'Why', logprobs=True, top_logprobs=2)
  assert response['logprobs'][0]['token'] == 'Hello'
  assert response['logprobs'][0]['top_logprobs'][1]['token'] == 'Hi'


def test_client_generate_with_image_type(httpserver: HTTPServer):
httpserver.expect_ordered_request(
'/api/generate',