Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import contextvars
from collections.abc import Generator
from contextlib import contextmanager
from copy import deepcopy
from typing import Any, Literal, Optional

from mellea.backends import Backend, BaseModelSubclass
Expand Down Expand Up @@ -453,13 +454,23 @@ def genslot(
Returns:
ModelOutputThunk: Output thunk
"""
generate_logs: list[GenerateLog] = []
result: ModelOutputThunk = self.backend.generate_from_context(
action=gen_slot,
ctx=self.ctx,
model_options=model_options,
format=format,
generate_logs=generate_logs,
tool_calls=tool_calls,
)
# make sure that the last and only Log is marked as the one related to result
assert len(generate_logs) == 1, "Simple call can only add one generate_log"
generate_logs[0].is_final_result = True

self.ctx.insert_turn(
Copy link
Contributor

@nrfulton nrfulton Aug 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to admit that I don't quite know if I understand how @HendrikStrobelt intends turns to be used. I am going to approve and merge this PR because we need some visibility for genslots.

@HendrikStrobelt : I'll leave it to you to review this use of insert_turn post-facto when he returns and open an issue (assigned to me and @avinash2692 ) if there are any issues. If there are issues, then we need to both fix them and also document the meaning of last_turn vs. last_prompt vs. last_output.

ContextTurn(deepcopy(gen_slot), result), generate_logs=generate_logs
)

return result

def query(
Expand Down
9 changes: 8 additions & 1 deletion test/stdlib_basics/test_genslot.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest
from typing import Literal
from mellea import generative, start_session
from mellea.stdlib.base import LinearContext


@generative
Expand All @@ -13,7 +14,7 @@ def write_me_an_email() -> str: ...

@pytest.fixture
def session():
return start_session()
return start_session(ctx=LinearContext())


@pytest.fixture
Expand All @@ -34,5 +35,11 @@ def test_sentiment_output(classify_sentiment_output):
assert classify_sentiment_output in ["positive", "negative"]


def test_gen_slot_logs(classify_sentiment_output, session):
sent = classify_sentiment_output
last_prompt = session.last_prompt()[-1]
assert isinstance(last_prompt, dict)
assert set(last_prompt.keys()) == {"role", "content"}

if __name__ == "__main__":
pytest.main([__file__])