From e9cccef725fbd97d738f83220a40a3c90b2f310d Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 11:45:47 -0400 Subject: [PATCH 01/14] feat: add pytest-asyncio --- pyproject.toml | 3 +- test/backends/test_huggingface.py | 88 ++++++++++++-------------- test/backends/test_litellm_ollama.py | 64 +++++++++---------- test/backends/test_ollama.py | 74 ++++++++++------------ test/backends/test_openai_ollama.py | 76 ++++++++++------------ test/backends/test_watsonx.py | 64 +++++++++---------- test/stdlib_basics/test_requirement.py | 7 +- uv.lock | 25 ++++++++ 8 files changed, 199 insertions(+), 202 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a02f614..bd1652bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ dev = [ "ruff>=0.11.6", "pdm>=2.24.0", "pytest", + "pytest-asyncio", "mypy>=1.17.0", "python-semantic-release~=7.32", ] @@ -176,7 +177,7 @@ python_version = "3.10" markers = [ "qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. All tests marked with this will xfail in CI/CD" ] - +asyncio_mode = "auto" # Don't require explicitly marking async tests. [tool.semantic_release] # for default values check: diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 4e3ab441..3f40d4cd 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -47,7 +47,7 @@ def test_system_prompt(session): print(result) @pytest.mark.qualitative -def test_constraint_alora(session, backend): +async def test_constraint_alora(session, backend): answer = session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.", model_options={ @@ -55,18 +55,16 @@ def test_constraint_alora(session, backend): }, # Until aloras get a bit better, try not to abruptly end generation. 
) - async def alora_generate(): - alora_output = backend.get_aloras()[ - 0 - ].generate_using_strings( - input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", - response=str(answer), - constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", - force_yn=False, # make sure that the alora naturally output Y and N without constrained generation - ) - await alora_output.avalue() - assert alora_output.value in ["Y", "N"], alora_output - asyncio.run(alora_generate()) + alora_output = backend.get_aloras()[ + 0 + ].generate_using_strings( + input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", + response=str(answer), + constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", + force_yn=False, # make sure that the alora naturally output Y and N without constrained generation + ) + await alora_output.avalue() + assert alora_output.value in ["Y", "N"], alora_output @pytest.mark.qualitative def test_constraint_lora_with_requirement(session, backend): @@ -226,42 +224,38 @@ class Answer(pydantic.BaseModel): ) @pytest.mark.qualitative -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the 
first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - asyncio.run(parallel_requests()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value @pytest.mark.qualitative -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": import pytest diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 7999e4cf..a7f4879d 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -78,42 +78,38 @@ def is_happy(text: str) -> bool: assert h is True @pytest.mark.qualitative -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the 
full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - asyncio.run(parallel_requests()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value @pytest.mark.qualitative -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": import pytest diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 806747f4..1362019a 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -131,46 +131,40 @@ class Answer(pydantic.BaseModel): ) -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), - model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), - model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. 
- assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - - asyncio.run(parallel_requests()) - - -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - - asyncio.run(avalue()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), + model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), + model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value + + +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index 77487c6c..d2e24970 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -144,54 +144,48 @@ class Email(pydantic.BaseModel): # assert False, f"formatting directive failed for {random_result.value}: {e.json()}" -def test_async_parallel_requests(m_session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = m_session.backend.generate_from_context( - CBlock("Say Hello."), SimpleContext(), model_options=model_opts - ) - mot2, _ = m_session.backend.generate_from_context( - CBlock("Say Goodbye!"),SimpleContext(), model_options=model_opts - ) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" +async def test_async_parallel_requests(m_session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = m_session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext(), model_options=model_opts + ) + mot2, _ = m_session.backend.generate_from_context( + CBlock("Say Goodbye!"),SimpleContext(), model_options=model_opts + ) - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() + m1_val = None + m2_val = None + if not 
mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), ( - "final val should contain the first streamed chunk" - ) - assert m2_final_val.startswith(m2_val), ( - "final val should contain the first streamed chunk" - ) + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() - asyncio.run(parallel_requests()) + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. + assert m1_final_val.startswith(m1_val), ( + "final val should contain the first streamed chunk" + ) + assert m2_final_val.startswith(m2_val), ( + "final val should contain the first streamed chunk" + ) + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value -def test_async_avalue(m_session): - async def avalue(): - mot1, _ = m_session.backend.generate_from_context( - CBlock("Say Hello."), SimpleContext() - ) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(m_session): + mot1, _ = m_session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext() + ) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index 12ec10d3..907d4575 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -100,42 +100,38 @@ def test_generate_from_raw(session: 
MelleaSession): assert len(results) == len(prompts) @pytest.mark.qualitative -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - asyncio.run(parallel_requests()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that 
m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. + assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value @pytest.mark.qualitative -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": import pytest diff --git a/test/stdlib_basics/test_requirement.py b/test/stdlib_basics/test_requirement.py index a1bef684..f569308d 100644 --- a/test/stdlib_basics/test_requirement.py +++ b/test/stdlib_basics/test_requirement.py @@ -7,15 +7,12 @@ ctx = ChatContext() ctx = ctx.add(ModelOutputThunk("test")) -def test_llmaj_validation_req_output_field(): +async def test_llmaj_validation_req_output_field(): m = start_session(ctx=ctx) req = Requirement("Must output test.") assert req._output is None - async def val(): - _ = await req.validate(m.backend,ctx=ctx) - asyncio.run(val()) - + _ = await req.validate(m.backend,ctx=ctx) assert req._output is None, "requirement's output shouldn't be updated during/after validation" def test_simple_validate_bool(): diff --git a/uv.lock b/uv.lock index dc920ad6..f4831e93 100644 --- a/uv.lock +++ b/uv.lock @@ -394,6 +394,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = 
"sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "backports-tarfile" version = "1.2.0" @@ -2425,6 +2434,7 @@ dev = [ { name = "pre-commit" }, { name = "pylint" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "python-semantic-release" }, { name = "ruff" }, ] @@ -2483,6 +2493,7 @@ dev = [ { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pylint", specifier = ">=3.3.4" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "python-semantic-release", specifier = "~=7.32" }, { name = "ruff", specifier = ">=0.11.6" }, ] @@ -4098,6 +4109,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", 
marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, +] + [[package]] name = "python-bidi" version = "0.6.6" From ae3f7b516f9ff513c3873d24498e8b7755a98c4e Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 13:00:15 -0400 Subject: [PATCH 02/14] feat: add async funcs and session funcs --- mellea/stdlib/funcs.py | 411 ++++++++++++++++++++++++---- mellea/stdlib/sampling/base.py | 2 +- mellea/stdlib/sampling/best_of_n.py | 2 +- mellea/stdlib/session.py | 260 +++++++++++++++++- test/stdlib_basics/test_funcs.py | 47 +++- 5 files changed, 662 insertions(+), 60 deletions(-) diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index 47607b53..f3739dc2 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -40,6 +40,7 @@ def act( context: Context, backend: Backend, *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, @@ -54,6 +55,7 @@ def act( context: Context, backend: Backend, *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, @@ -88,10 +90,10 @@ def act( tool_calls: if true, tool calling is enabled. 
Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ out = _run_async_in_thread( - _act( + aact( action, context, backend, @@ -101,13 +103,347 @@ def act( format=format, model_options=model_options, tool_calls=tool_calls, + ) # type: ignore[call-overload] + # Mypy doesn't like the bool for return_sampling_results. + ) + + return out + + +@overload +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: ... + + +@overload +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> SamplingResult: ... 
+ + +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context] | SamplingResult: + """Generates from an instruction. + + Args: + description: The description of the instruction. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + requirements: A list of requirements that the instruction can be validated against. + icl_examples: A list of in-context-learning examples that the instruction can be validated against. + grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. + user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. + prefix: A prefix string or ContentBlock to use when generating the instruction. + output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. 
+ return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: If set, the BaseModel to use for constrained decoding. + model_options: Additional model options, which will upsert into the model/backend's defaults. + tool_calls: If true, tool calling is enabled. + images: A list of images to be used in the instruction or None if none. + + Returns: + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. + """ + requirements = [] if requirements is None else requirements + icl_examples = [] if icl_examples is None else icl_examples + grounding_context = dict() if grounding_context is None else grounding_context + + images = _parse_and_clean_image_args(images) + + # All instruction options are forwarded to create a new Instruction object. + i = Instruction( + description=description, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + images=images, + ) + + return act( + i, + context=context, + backend=backend, + requirements=i.requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore[call-overload] + + +def chat( + content: str, + context: Context, + backend: Backend, + *, + role: Message.Role = "user", + images: list[ImageBlock] | list[PILImage.Image] | None = None, + user_variables: dict[str, str] | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[Message, Context]: + """Sends a simple chat message and returns the response. 
Adds both messages to the Context.""" + if user_variables is not None: + content_resolved = Instruction.apply_user_dict_from_jinja( + user_variables, content + ) + else: + content_resolved = content + images = _parse_and_clean_image_args(images) + user_message = Message(role=role, content=content_resolved, images=images) + + result, new_ctx = act( + user_message, + context=context, + backend=backend, + strategy=None, # Explicitly pass `None` since this can't pass requirements. + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + parsed_assistant_message = result.parsed_repr + assert isinstance(parsed_assistant_message, Message) + + return parsed_assistant_message, new_ctx + + +def validate( + reqs: Requirement | list[Requirement], + context: Context, + backend: Backend, + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] + | None = None, # TODO: Can we get rid of gen logs here and in act? + input: CBlock | None = None, +) -> list[ValidationResult]: + """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + # Run everything in the specific event loop for this session. + + out = _run_async_in_thread( + avalidate( + reqs=reqs, + context=context, + backend=backend, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, ) ) + # Wait for and return the result. return out -async def _act( +def query( + obj: Any, + query: str, + context: Context, + backend: Backend, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: + """Query method for retrieving information from an object. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. 
+ query: The string representing the query to be executed against the object. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + format: format for output parsing. + model_options: Model options to pass to the backend. + tool_calls: If true, the model may make tool calls. Defaults to False. + + Returns: + ModelOutputThunk: The result of the query as processed by the backend. + """ + if not isinstance(obj, MObjectProtocol): + obj = mify(obj) + + assert isinstance(obj, MObjectProtocol) + q = obj.get_query_object(query) + + answer = act( + q, + context=context, + backend=backend, + strategy=None, # Explicitly pass `None` since this can't pass requirements. + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + return answer + + +def transform( + obj: Any, + transformation: str, + context: Context, + backend: Backend, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, +) -> tuple[ModelOutputThunk | Any, Context]: + """Transform method for creating a new object with the transformation applied. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + transformation: The string representing the query to be executed against the object. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + + Returns: + (ModelOutputThunk | Any, Context): The result of the transformation as processed by the backend. If no tools were called, + the return type will be always be (ModelOutputThunk, Context). If a tool was called, the return type will be the return type + of the function called, usually the type of the object passed in. 
+ """ + if not isinstance(obj, MObjectProtocol): + obj = mify(obj) + + assert isinstance(obj, MObjectProtocol) + t = obj.get_transform_object(transformation) + + # Check that your model / backend supports tool calling. + # This might throw an error when tools are provided but can't be handled by one or the other. + transformed, new_ctx = act( + t, + context=context, + backend=backend, + strategy=None, # Explicitly pass `None` since this can't pass requirements. + format=format, + model_options=model_options, + tool_calls=True, + ) + + tools = _call_tools(transformed, backend) + + # Transform only supports calling one tool call since it cannot currently synthesize multiple outputs. + # Attempt to choose the best one to call. + chosen_tool: ToolMessage | None = None + if len(tools) == 1: + # Only one function was called. Choose that one. + chosen_tool = tools[0] + + elif len(tools) > 1: + for output in tools: + if type(output._tool_output) is type(obj): + chosen_tool = output + break + + if chosen_tool is None: + chosen_tool = tools[0] + + FancyLogger.get_logger().warning( + f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" + # type: ignore + ) + + if chosen_tool: + # Tell the user the function they should've called if no generated values were added. 
+ if len(chosen_tool._tool.args.keys()) == 0: + FancyLogger.get_logger().warning( + f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" + ) + + new_ctx.add(chosen_tool) + FancyLogger.get_logger().info( + "added a tool message from transform to the context" + ) + return chosen_tool._tool_output, new_ctx + + return transformed, new_ctx + + +@overload +async def aact( + action: Component, + context: Context, + backend: Backend, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: ... + + +@overload +async def aact( + action: Component, + context: Context, + backend: Backend, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> SamplingResult: ... + + +async def aact( action: Component, context: Context, backend: Backend, @@ -133,7 +469,7 @@ async def _act( tool_calls: if true, tool calling is enabled. Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ sampling_result: SamplingResult | None = None generate_logs: list[GenerateLog] = [] @@ -165,7 +501,8 @@ async def _act( generate_logs.append(result._generate_log) else: - # if there is a reason to sample, use the sampling strategy. 
+ # Always sample if a strategy is provided, even if no requirements were provided. + # Some sampling strategies don't use requirements or set them when instantiated. sampling_result = await strategy.sample( action, @@ -200,7 +537,7 @@ async def _act( @overload -def instruct( +async def ainstruct( description: str, context: Context, backend: Backend, @@ -221,7 +558,7 @@ def instruct( @overload -def instruct( +async def ainstruct( description: str, context: Context, backend: Backend, @@ -241,7 +578,7 @@ def instruct( ) -> SamplingResult: ... -def instruct( +async def ainstruct( description: str, context: Context, backend: Backend, @@ -277,7 +614,11 @@ def instruct( model_options: Additional model options, which will upsert into the model/backend's defaults. tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. + + Returns: + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ + requirements = [] if requirements is None else requirements icl_examples = [] if icl_examples is None else icl_examples grounding_context = dict() if grounding_context is None else grounding_context @@ -296,7 +637,7 @@ def instruct( images=images, ) - return act( + return await aact( i, context=context, backend=backend, @@ -309,7 +650,7 @@ def instruct( ) # type: ignore[call-overload] -def chat( +async def achat( content: str, context: Context, backend: Backend, @@ -331,11 +672,11 @@ def chat( images = _parse_and_clean_image_args(images) user_message = Message(role=role, content=content_resolved, images=images) - result, new_ctx = act( + result, new_ctx = await aact( user_message, context=context, backend=backend, - strategy=None, + strategy=None, # Explicitly pass `None` since this can't pass requirements. 
format=format, model_options=model_options, tool_calls=tool_calls, @@ -346,39 +687,7 @@ def chat( return parsed_assistant_message, new_ctx -def validate( - reqs: Requirement | list[Requirement], - context: Context, - backend: Backend, - *, - output: CBlock | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - generate_logs: list[GenerateLog] - | None = None, # TODO: Can we get rid of gen logs here and in act? - input: CBlock | None = None, -) -> list[ValidationResult]: - """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - # Run everything in the specific event loop for this session. - - out = _run_async_in_thread( - _validate( - reqs=reqs, - context=context, - backend=backend, - output=output, - format=format, - model_options=model_options, - generate_logs=generate_logs, - input=input, - ) - ) - - # Wait for and return the result. - return out - - -async def _validate( +async def avalidate( reqs: Requirement | list[Requirement], context: Context, backend: Backend, @@ -435,7 +744,7 @@ async def _validate( return rvs -def query( +async def aquery( obj: Any, query: str, context: Context, @@ -465,11 +774,11 @@ def query( assert isinstance(obj, MObjectProtocol) q = obj.get_query_object(query) - answer = act( + answer = await aact( q, context=context, backend=backend, - strategy=None, + strategy=None, # Explicitly pass `None` since this can't pass requirements. format=format, model_options=model_options, tool_calls=tool_calls, @@ -477,7 +786,7 @@ def query( return answer -def transform( +async def atransform( obj: Any, transformation: str, context: Context, @@ -509,11 +818,11 @@ def transform( # Check that your model / backend supports tool calling. # This might throw an error when tools are provided but can't be handled by one or the other. 
- transformed, new_ctx = act( + transformed, new_ctx = await aact( t, context=context, backend=backend, - strategy=None, + strategy=None, # Explicitly pass `None` since this can't pass requirements. format=format, model_options=model_options, tool_calls=True, diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 9f374fa9..6520a7da 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -161,7 +161,7 @@ async def sample( await result.avalue() # validation pass - val_scores_co = mfuncs._validate( + val_scores_co = mfuncs.avalidate( reqs=reqs, context=result_ctx, backend=backend, diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index d0bdf341..59c402b2 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -127,7 +127,7 @@ async def sample( result = sampled_results[i] next_action = sampled_actions[i] - val_scores_co = mfuncs._validate( + val_scores_co = mfuncs.avalidate( reqs=reqs, context=result_ctx, backend=backend, diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 6aa3204b..af12a0d4 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -236,6 +236,7 @@ def act( self, action: Component, *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = None, return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, @@ -248,6 +249,7 @@ def act( self, action: Component, *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = None, return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, @@ -270,11 +272,17 @@ def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
""" - result, context = mfuncs.act( - action, context=self.ctx, backend=self.backend, **kwargs - ) - self.ctx = context - return result + + r = mfuncs.act(action, context=self.ctx, backend=self.backend, **kwargs) + + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + # It's a tuple[ModelOutputThunk, Context]. + result, context = r + self.ctx = context + return result @overload def instruct( @@ -340,6 +348,7 @@ def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingRes self.ctx = r.result_ctx return r else: + # It's a tuple[ModelOutputThunk, Context]. result, context = r self.ctx = context return result @@ -458,6 +467,247 @@ def transform( self.ctx = context return result + @overload + async def aact( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: ... + + @overload + async def aact( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult: ... + + async def aact( + self, action: Component, **kwargs + ) -> ModelOutputThunk | SamplingResult: + """Runs a generic action, and adds both the action and the result to the context. + + Args: + action: the Component from which to generate. + requirements: used as additional requirements when a sampling strategy is provided + strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. 
+ return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: if set, the BaseModel to use for constrained decoding. + model_options: additional model options, which will upsert into the model/backend's defaults. + tool_calls: if true, tool calling is enabled. + + Returns: + A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + """ + + r = await mfuncs.aact(action, context=self.ctx, backend=self.backend, **kwargs) + + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + # It's a tuple[ModelOutputThunk, Context]. + result, context = r + self.ctx = context + return result + + @overload + async def ainstruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: ... 
+ + @overload + async def ainstruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult: ... + + async def ainstruct( + self, description: str, **kwargs + ) -> ModelOutputThunk | SamplingResult: + """Generates from an instruction. + + Args: + description: The description of the instruction. + requirements: A list of requirements that the instruction can be validated against. + icl_examples: A list of in-context-learning examples that the instruction can be validated against. + grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. + user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. + prefix: A prefix string or ContentBlock to use when generating the instruction. + output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: If set, the BaseModel to use for constrained decoding. 
+ model_options: Additional model options, which will upsert into the model/backend's defaults. + tool_calls: If true, tool calling is enabled. + images: A list of images to be used in the instruction or None if none. + """ + + r = await mfuncs.ainstruct( + description, context=self.ctx, backend=self.backend, **kwargs + ) + + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + # It's a tuple[ModelOutputThunk, Context]. + result, context = r + self.ctx = context + return result + + async def achat( + self, + content: str, + role: Message.Role = "user", + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + user_variables: dict[str, str] | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> Message: + """Sends a simple chat message and returns the response. Adds both messages to the Context.""" + + result, context = await mfuncs.achat( + content=content, + context=self.ctx, + backend=self.backend, + role=role, + images=images, + user_variables=user_variables, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + + self.ctx = context + return result + + async def avalidate( + self, + reqs: Requirement | list[Requirement], + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] | None = None, + input: CBlock | None = None, + ) -> list[ValidationResult]: + """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + + return await mfuncs.avalidate( + reqs=reqs, + context=self.ctx, + backend=self.backend, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, + ) + + async def aquery( + self, + obj: Any, + query: str, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = 
None, + tool_calls: bool = False, + ) -> ModelOutputThunk: + """Query method for retrieving information from an object. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + query: The string representing the query to be executed against the object. + format: format for output parsing. + model_options: Model options to pass to the backend. + tool_calls: If true, the model may make tool calls. Defaults to False. + + Returns: + ModelOutputThunk: The result of the query as processed by the backend. + """ + result, context = await mfuncs.aquery( + obj=obj, + query=query, + context=self.ctx, + backend=self.backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + self.ctx = context + return result + + async def atransform( + self, + obj: Any, + transformation: str, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + ) -> ModelOutputThunk | Any: + """Transform method for creating a new object with the transformation applied. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + transformation: The string representing the query to be executed against the object. + + Returns: + ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, + the return type will be always be ModelOutputThunk. If a tool was called, the return type will be the return type + of the function called, usually the type of the object passed in. 
+ """ + result, context = await mfuncs.atransform( + obj=obj, + transformation=transformation, + context=self.ctx, + backend=self.backend, + format=format, + model_options=model_options, + ) + self.ctx = context + return result + # ############################### # Convenience functions # ############################### diff --git a/test/stdlib_basics/test_funcs.py b/test/stdlib_basics/test_funcs.py index f652eb98..189fb6eb 100644 --- a/test/stdlib_basics/test_funcs.py +++ b/test/stdlib_basics/test_funcs.py @@ -3,8 +3,10 @@ import pytest from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock -from mellea.stdlib.funcs import instruct +from mellea.stdlib.base import CBlock, ModelOutputThunk +from mellea.stdlib.chat import Message +from mellea.stdlib.funcs import instruct, aact, avalidate, ainstruct +from mellea.stdlib.requirement import req from mellea.stdlib.session import start_session @@ -33,5 +35,46 @@ def test_func_context(m_session): assert initial_ctx is not ctx assert ctx._data is out +async def test_aact(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + out, ctx = await aact( + Message(role="user", content="hello"), + initial_ctx, + backend + ) + + assert initial_ctx is not ctx + assert ctx._data is out + +async def test_ainstruct(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + out, ctx = await ainstruct( + "Write a sentence", + initial_ctx, + backend + ) + + assert initial_ctx is not ctx + assert ctx._data is out + +async def test_avalidate(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + val_result = await avalidate( + reqs=[req("Be formal."), req("Avoid telling jokes.")], + context=initial_ctx, + backend=backend, + output=ModelOutputThunk("Here is an output.") + ) + + assert len(val_result) == 2 + assert val_result[0] is not None + + if __name__ == "__main__": pytest.main([__file__]) \ No newline at end of file From 
ee766fe000dbc561936a9b496bb9fee10eaf0ec1 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 13:02:23 -0400 Subject: [PATCH 03/14] fix: add default RejectionSamplingStrategy to session funcs to make it explicit --- mellea/stdlib/session.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index af12a0d4..a47cf6e5 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -29,6 +29,7 @@ from mellea.stdlib.chat import Message from mellea.stdlib.requirement import Requirement, ValidationResult from mellea.stdlib.sampling import SamplingResult, SamplingStrategy +from mellea.stdlib.sampling.base import RejectionSamplingStrategy # Global context variable for the context session _context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar( @@ -237,7 +238,7 @@ def act( action: Component, *, requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -250,7 +251,7 @@ def act( action: Component, *, requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -296,7 +297,7 @@ def instruct( user_variables: dict[str, str] | None = None, prefix: str | CBlock | None = None, output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ 
-315,7 +316,7 @@ def instruct( user_variables: dict[str, str] | None = None, prefix: str | CBlock | None = None, output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -473,7 +474,7 @@ async def aact( action: Component, *, requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -486,7 +487,7 @@ async def aact( action: Component, *, requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -534,7 +535,7 @@ async def ainstruct( user_variables: dict[str, str] | None = None, prefix: str | CBlock | None = None, output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -553,7 +554,7 @@ async def ainstruct( user_variables: dict[str, str] | None = None, prefix: str | CBlock | None = None, output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, From d2033c2912f1254abaf9943e87bfcb253b82080c Mon Sep 
17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 13:28:22 -0400 Subject: [PATCH 04/14] fix: default sampling strat for session funcs --- mellea/stdlib/session.py | 114 +++++++++++++++++++++++++++++++++++---- 1 file changed, 104 insertions(+), 10 deletions(-) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index a47cf6e5..0f400b1e 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -258,7 +258,17 @@ def act( tool_calls: bool = False, ) -> SamplingResult: ... - def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: # noqa: D417 + def act( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. Args: @@ -274,13 +284,22 @@ def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - r = mfuncs.act(action, context=self.ctx, backend=self.backend, **kwargs) + r = mfuncs.act( + action, + context=self.ctx, + backend=self.backend, + requirements=requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore if isinstance(r, SamplingResult): self.ctx = r.result_ctx return r else: - # It's a tuple[ModelOutputThunk, Context]. result, context = r self.ctx = context return result @@ -323,7 +342,23 @@ def instruct( tool_calls: bool = False, ) -> SamplingResult: ... 
- def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingResult: # noqa: D417 + def instruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: """Generates from an instruction. Args: @@ -342,7 +377,21 @@ def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingRes images: A list of images to be used in the instruction or None if none. """ r = mfuncs.instruct( - description, context=self.ctx, backend=self.backend, **kwargs + description, + context=self.ctx, + backend=self.backend, + images=images, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + strategy=strategy, + return_sampling_results=return_sampling_results, # type: ignore + format=format, + model_options=model_options, + tool_calls=tool_calls, ) if isinstance(r, SamplingResult): @@ -495,7 +544,15 @@ async def aact( ) -> SamplingResult: ... 
async def aact( - self, action: Component, **kwargs + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, ) -> ModelOutputThunk | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. @@ -512,13 +569,22 @@ async def aact( A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - r = await mfuncs.aact(action, context=self.ctx, backend=self.backend, **kwargs) + r = await mfuncs.aact( + action, + context=self.ctx, + backend=self.backend, + requirements=requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore if isinstance(r, SamplingResult): self.ctx = r.result_ctx return r else: - # It's a tuple[ModelOutputThunk, Context]. result, context = r self.ctx = context return result @@ -562,7 +628,21 @@ async def ainstruct( ) -> SamplingResult: ... async def ainstruct( - self, description: str, **kwargs + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, ) -> ModelOutputThunk | SamplingResult: """Generates from an instruction. 
@@ -583,7 +663,21 @@ async def ainstruct( """ r = await mfuncs.ainstruct( - description, context=self.ctx, backend=self.backend, **kwargs + description, + context=self.ctx, + backend=self.backend, + images=images, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + strategy=strategy, + return_sampling_results=return_sampling_results, # type: ignore + format=format, + model_options=model_options, + tool_calls=tool_calls, ) if isinstance(r, SamplingResult): From ae578312f53749fe24207317fd3d52a0ac539f8c Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 13:34:15 -0400 Subject: [PATCH 05/14] test: add minimum test examples for async session funcs --- test/stdlib_basics/test_session.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 67168a38..59286899 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -2,9 +2,27 @@ import pytest +from mellea.backends.types import ModelOption from mellea.stdlib.base import ModelOutputThunk +from mellea.stdlib.chat import Message from mellea.stdlib.session import start_session +@pytest.fixture(scope="module") +def m_session(gh_run): + if gh_run == 1: + m = start_session( + "ollama", + model_id="llama3.2:1b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, + ) + else: + m = start_session( + "ollama", + model_id="granite3.3:8b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, + ) + yield m + del m def test_start_session_watsonx(gh_run): if gh_run == 1: @@ -15,7 +33,6 @@ def test_start_session_watsonx(gh_run): assert isinstance(response, ModelOutputThunk) assert response.value is not None - def test_start_session_openai_with_kwargs(gh_run): if gh_run == 1: m = start_session( @@ -37,6 +54,17 @@ def 
test_start_session_openai_with_kwargs(gh_run): assert response.value is not None assert initial_ctx is not m.ctx +async def test_aact(m_session): + initial_ctx = m_session.ctx + out = await m_session.aact(Message(role="user", content="Hello!")) + assert m_session.ctx is not initial_ctx + assert out.value is not None + +async def test_ainstruct(m_session): + initial_ctx = m_session.ctx + out = await m_session.ainstruct("Write a sentence.") + assert m_session.ctx is not initial_ctx + assert out.value is not None if __name__ == "__main__": pytest.main([__file__]) From 9338d425ee67f5b897b795a411d4c63cba49bc16 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 15:38:57 -0400 Subject: [PATCH 06/14] docs: add async to tutorial --- docs/tutorial.md | 50 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) diff --git a/docs/tutorial.md b/docs/tutorial.md index 3640300f..e96f19a2 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -21,6 +21,7 @@ - [Chapter 10: Prompt Engineering for Mellea](#chapter-10-prompt-engineering-for-m) - [Custom Templates](#custom-templates) - [Chapter 11: Tool Calling](#chapter-11-tool-calling) +- [Chapter 12: Asynchronicity](#chapter-12-asynchronicity) - [Appendix: Contributing to Melles](#appendix-contributing-to-mellea) ## Chapter 1: What Is Generative Programming @@ -1317,6 +1318,55 @@ assert "web_search" in output.tool_calls result = output.tool_calls["web_search"].call_func() ``` +## Chapter 12: Asynchronicity +Mellea supports asynchronous behavior in several ways: asynchronous functions and asynchronous event loops in synchronous functions. + +### Asynchronous Functions: +`MelleaSession`s have asynchronous functions that work just like regular async functions in python. 
These async session functions mirror their synchronous counterparts: +``` +m = start_session() +result = await m.ainstruct("Write your instruction here!") +``` + +However, if you want to run multiple async functions at the same time, you need to be careful with your context. By default, `MelleaSession`s use a `SimpleContext` that has no history. This will work just fine when running multiple async requests at once: +``` +m = start_session() +coroutines = [] + +for i in range(5): + coroutines.append(m.ainstruct(f"Write a math problem using {i}")) + +results = await asyncio.gather(*coroutines) +``` + +If you try to use a `ChatContext`, you will need to await between each request so that the context can be properly modified: +``` +m = start_session(ctx=ChatContext()) + +result = await m.ainstruct("Write a short fairy tale.") +print(result) + +main_character = await m.ainstruct("Who is the main character of the previous fairy tale?") +print(main_character) +``` + +Otherwise, your requests will use outdated contexts that don't have the messages you expect. For example, +``` +m = start_session(ctx=ChatContext()) + +co1 = m.ainstruct("Write a very long math problem.") # Start first request. +co2 = m.ainstruct("Solve the math problem.") # Start second request with an empty context. + +results = await asyncio.gather(co1, co2) +for result in results: + print(result) # Neither request had anything in its context. +``` + +### Asynchronicity in Synchronous Functions +Mellea utilizes asynchronicity internally. When you call `m.instruct`, you are using synchronous code that executes an asynchronous request to an LLM to generate the result. For a single request, this won't cause any differences in execution speed. + +When using `SamplingStrategy`s or during validation, Mellea can speed up the execution time of your program by generating multiple results and validating those results against multiple requirements simultaneously.
Whether you use `m.instruct` or the asynchronous `m.ainstruct`, Mellea will attempt to speed up your requests by dispatching those requests as quickly as possible and asynchronously awaiting the results. + ## Appendix: Contributing to Mellea ### Contributor Guide: Requirements and Verifiers From 24dfd033c2c9c29799df74b7422bdbd3e31ff805 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 15:49:22 -0400 Subject: [PATCH 07/14] feat: add warning for async with non-Simple contexts --- mellea/stdlib/funcs.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index f3739dc2..d8410b69 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -103,6 +103,7 @@ def act( format=format, model_options=model_options, tool_calls=tool_calls, + silence_context_type_warning=True, # We can safely silence this here since it's in a sync function. ) # type: ignore[call-overload] # Mypy doesn't like the bool for return_sampling_results. ) @@ -425,6 +426,7 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + silence_context_type_warning: bool = False, ) -> tuple[ModelOutputThunk, Context]: ... @@ -440,6 +442,7 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + silence_context_type_warning: bool = False, ) -> SamplingResult: ... @@ -454,6 +457,7 @@ async def aact( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + silence_context_type_warning: bool = False, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. @@ -467,10 +471,18 @@ async def aact( format: if set, the BaseModel to use for constrained decoding. 
model_options: additional model options, which will upsert into the model/backend's defaults. tool_calls: if true, tool calling is enabled. + silence_context_type_warning: if called directly from an asynchronous function, will log a warning if not using a SimpleContext Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ + + if not silence_context_type_warning and not isinstance(context, SimpleContext): + FancyLogger().get_logger().warning( + "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." + "\nSee the async section of the tutorial: https://github.com/generative-computing/mellea/blob/main/docs/tutorial.md#chapter-12-asynchronicity" + ) + sampling_result: SamplingResult | None = None generate_logs: list[GenerateLog] = [] From 03b9af87c349e6f227b823f091814dba30408356 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Tue, 30 Sep 2025 16:34:38 -0400 Subject: [PATCH 08/14] fix: add async session tests with chat context --- test/stdlib_basics/test_session.py | 38 ++++++++++++++++++++++++++++-- 1 file changed, 36 insertions(+), 2 deletions(-) diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 59286899..8f6f635c 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -1,13 +1,15 @@ +import asyncio import os import pytest from mellea.backends.types import ModelOption -from mellea.stdlib.base import ModelOutputThunk +from mellea.stdlib.base import ChatContext, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.session import start_session -@pytest.fixture(scope="module") +# We edit the context type in the async tests below. Don't change the scope here. 
+@pytest.fixture(scope="function") def m_session(gh_run): if gh_run == 1: m = start_session( @@ -66,5 +68,37 @@ async def test_ainstruct(m_session): assert m_session.ctx is not initial_ctx assert out.value is not None +async def test_async_await_with_chat_context(m_session): + m_session.ctx = ChatContext() + + m1 = Message(role="user", content="1") + m2 = Message(role="user", content="2") + r1 = await m_session.aact(m1) + r2 = await m_session.aact(m2) + + # This should be the order of these items in the session's context. + history = [r2, m2, r1, m1] + + ctx = m_session.ctx + for i in range(len(history)): + assert ctx.node_data is history[i] + ctx = ctx.previous_node + + # Ensure we made it back to the root. + assert ctx.is_root_node == True + +async def test_async_without_waiting_with_chat_context(m_session): + m_session.ctx = ChatContext() + + m1 = Message(role="user", content="1") + m2 = Message(role="user", content="2") + co1 = m_session.aact(m1) + co2 = m_session.aact(m2) + _, _ = await asyncio.gather(co2, co1) + + ctx = m_session.ctx + assert len(ctx.view_for_generation()) == 2 + + if __name__ == "__main__": pytest.main([__file__]) From 0922ee2c1aeed746c653a177ec0c94b077bd552c Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Wed, 1 Oct 2025 09:43:00 -0400 Subject: [PATCH 09/14] feat: add clone method to session; remove refs to model_opts in session --- docs/tutorial.md | 29 ++++++++++++++--- mellea/stdlib/session.py | 48 +++++++++++++++++++++------- test/stdlib_basics/test_session.py | 50 ++++++++++++++++++++++++++++++ 3 files changed, 112 insertions(+), 15 deletions(-) diff --git a/docs/tutorial.md b/docs/tutorial.md index e96f19a2..933b7b76 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -944,6 +944,23 @@ or the entire last turn (user query + assistant response): print(m.ctx.last_turn()) ``` +You can also use `session.clone()` to create a copy of a given session with its context at given point in time. 
This allows you to make multiple generation requests with the same objects in your context: +```python +m = start_session(ctx=ChatContext()) +m.instruct("Multiply 2x2.") + +m1 = m.clone() +m2 = m.clone() + +# Need to run this code in an async event loop. +co1 = m1.ainstruct("Multiply that by 3") +co2 = m2.ainstruct("Multiply that by 5") + +print(await co1) # 12 +print(await co2) # 20 +``` +In the above example, both requests have `Multiply 2x2` and the LLM's response to that (presumably `4`) in their context. By cloning the session, the new requests both operate independently on that context to get the correct answers to 4 x 3 and 4 x 5. + ## Chapter 8: Implementing Agents > **Definition:** An *agent* is a generative program in which an LLM determines the control flow of the program. @@ -1323,13 +1340,13 @@ Mellea supports asynchronous behavior in several ways: asynchronous functions an ### Asynchronous Functions: `MelleaSession`s have asynchronous functions that work just like regular async functions in python. These async session functions mirror their synchronous counterparts: -``` +```python m = start_session() result = await m.ainstruct("Write your instruction here!") ``` However, if you want to run multiple async functions at the same time, you need to be careful with your context. By default, `MelleaSession`s use a `SimpleContext` that has no history. This will work just fine when running multiple async requests at once: -``` +```python m = start_session() coroutines = [] @@ -1340,7 +1357,7 @@ results = await asyncio.gather(*coroutines) ``` If you try to use a `ChatContext`, you will need to await between each request so that the context can be properly modified: -``` +```python m = start_session(ctx=ChatContext()) result = await m.ainstruct("Write a short fairy tale.") @@ -1351,7 +1368,7 @@ print(main_character) ``` Otherwise, your requests will use outdated contexts that don't have the messages you expect.
For example, -``` +```python m = start_session(ctx=ChatContext()) co1 = m.ainstruct("Write a very long math problem.") # Start first request. @@ -1360,8 +1377,12 @@ co2 = m.ainstruct("Solve the math problem.") # Start second request with an emp results = await asyncio.gather(co1, co2) for result in results: print(result) # Neither request had anything in its context. + +print(m.ctx) # Only shows the operations from the second request. ``` +Additionally, see [Chapter 7: Context Management](#chapter-7-on-context-management) for an example of how to use `session.clone()` to avoid these context issues. + ### Asynchronicity in Synchronous Functions Mellea utilizes asynchronicity internally. When you call `m.instruct`, you are using synchronous code that executes an asynchronous request to an LLM to generate the result. For a single request, this won't cause any differences in execution speed. diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 0f400b1e..d1167adc 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextvars +from copy import copy from typing import Any, Literal, overload from PIL import Image as PILImage @@ -176,11 +177,10 @@ def __init__(self, backend: Backend, ctx: Context | None = None): Args: backend (Backend): This is always required. ctx (Context): The way in which the model's context will be managed. By default, each interaction with the model is a stand-alone interaction, so we use SimpleContext as the default. - model_options (Optional[dict]): model options, which will upsert into the model/backend's defaults. 
""" self.backend = backend self.ctx: Context = ctx if ctx is not None else SimpleContext() - self._backend_stack: list[tuple[Backend, dict | None]] = [] + self._backend_stack: list[Backend] = [] self._session_logger = FancyLogger.get_logger() self._context_token = None @@ -196,14 +196,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): _context_session.reset(self._context_token) self._context_token = None - def _push_model_state(self, new_backend: Backend, new_model_opts: dict): - """The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`. - - Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class. - """ - self._backend_stack.append((self.backend, self.model_options)) + def _push_model_state(self, new_backend: Backend): + """The backend used within a `Context` can be temporarily changed. This method changes the model's backend, while saving the current settings in the `self._backend_stack`.""" + self._backend_stack.append(self.backend) self.backend = new_backend - self.opts = new_model_opts def _pop_model_state(self) -> bool: """Pops the model state. @@ -214,13 +210,43 @@ def _pop_model_state(self) -> bool: Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class. """ try: - b, b_model_opts = self._backend_stack.pop() + b = self._backend_stack.pop() self.backend = b - self.model_options = b_model_opts return True except Exception: return False + def __copy__(self): + new = MelleaSession(backend=self.backend, ctx=self.ctx) + new._backend_stack = self._backend_stack.copy() + new._session_logger = self._session_logger + # Explicitly don't copy over the _context_token. 
+ + return new + + def clone(self): + """Useful for running multiple generation requests while keeping the context at a given point in time. + + Returns: + a copy of the current session. Keeps the context, backend, backend stack, and session logger. + + Examples: + >>> from mellea import start_session + >>> m = start_session() + >>> m.instruct("What is 2x2?") + >>> + >>> m1 = m.clone() + >>> out = m1.instruct("Multiply that by 2") + >>> print(out) + ... 8 + >>> + >>> m2 = m.clone() + >>> out = m2.instruct("Multiply that by 3") + >>> print(out) + ... 12 + """ + return copy(self) + def reset(self): """Reset the context state.""" self.ctx = self.ctx.reset_to_new() diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 8f6f635c..e55359f8 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -3,6 +3,7 @@ import pytest +from mellea.backends.ollama import OllamaModelBackend from mellea.backends.types import ModelOption from mellea.stdlib.base import ChatContext, ModelOutputThunk from mellea.stdlib.chat import Message @@ -99,6 +100,55 @@ async def test_async_without_waiting_with_chat_context(m_session): ctx = m_session.ctx assert len(ctx.view_for_generation()) == 2 +def test_session_copy_with_context_ops(m_session): + out = m_session.instruct("What is 2x2?") + main_ctx = m_session.ctx + + m1 = m_session.clone() + out1 = m1.instruct("Multiply by 3.") + + m2 = m_session.clone() + out2 = m2.instruct("Multiply by 4.") + + # Assert that each context is the correct one. + assert m_session.ctx is main_ctx + assert m_session.ctx is not m1.ctx + assert m_session.ctx is not m2.ctx + assert m1.ctx is not m2.ctx + + # Assert that node data is correct. + assert m_session.ctx.node_data is out + assert m1.ctx.node_data is out1 + assert m2.ctx.node_data is out2 + + # Assert that the new sessions still branch off the original one. 
+ assert m1.ctx.previous_node.previous_node is m_session.ctx + assert m2.ctx.previous_node.previous_node is m_session.ctx + +def test_session_copy_with_backend_stack(m_session): + # Assert expected values from cloning. + m1 = m_session.clone() + assert m1.backend is m_session.backend + assert m1._session_logger is m_session._session_logger + assert m1._backend_stack is not m_session._backend_stack + + # Assert that pushing to a backend stack doesn't change it for sessions previously cloned from it. + new_backend = OllamaModelBackend() + m_session._push_model_state(new_backend=new_backend) + assert len(m_session._backend_stack) == 1 + assert len(m1._backend_stack) == 0 + assert m1.backend is not m_session.backend + + # Assert that newly cloned sessions don't cause errors with changes to the backend stack. + m2 = m_session.clone() + assert len(m2._backend_stack) == 1 + + # They should still be different lists. + assert m2._backend_stack is not m_session._backend_stack + assert m2._pop_model_state() + assert len(m2._backend_stack) == 0 + assert len(m_session._backend_stack) == 1 + assert m2.backend is m1.backend if __name__ == "__main__": pytest.main([__file__]) From 70a78a3a7d33e0cdccf111650099f9717b417e49 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Wed, 1 Oct 2025 15:07:01 -0400 Subject: [PATCH 10/14] feat: add async generative slots --- mellea/stdlib/genslot.py | 98 ++++++++++++++++++++++++++++-- test/stdlib_basics/test_genslot.py | 18 ++++++ 2 files changed, 110 insertions(+), 6 deletions(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 1e822871..4f8f2368 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -1,8 +1,9 @@ """A method to generate outputs based on python functions and a Generative Slot function.""" +import asyncio import functools import inspect -from collections.abc import Callable +from collections.abc import Callable, Coroutine from copy import deepcopy from typing import Any, Generic, ParamSpec, 
TypedDict, TypeVar, get_type_hints @@ -168,14 +169,13 @@ def __call__( **kwargs: Additional Kwargs to be passed to the func. Returns: - ModelOutputThunk: Output with generated Thunk. + R: an object with the original return type of the function """ if m is None: m = get_session() slot_copy = deepcopy(self) arguments = bind_function_arguments(self._function._func, *args, **kwargs) if arguments: - # slot_copy._arguments = [] for key, val in arguments.items(): annotation = get_annotation(slot_copy._function._func, key, val) slot_copy._arguments.append(Argument(annotation, key, val)) @@ -207,6 +207,52 @@ def format_for_llm(self) -> TemplateRepresentation: ) +class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]): + """A generative slot component that generates asynchronously and returns a coroutine.""" + + def __call__( + self, + m: MelleaSession | None = None, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> Coroutine[Any, Any, R]: + """Call the async generative slot. + + Args: + m: MelleaSession: A mellea session (optional, uses context if None) + **kwargs: Additional Kwargs to be passed to the func + + Returns: + Coroutine[Any, Any, R]: a coroutine that returns an object with the original return type of the function + """ + if m is None: + m = get_session() + slot_copy = deepcopy(self) + arguments = bind_function_arguments(self._function._func, *args, **kwargs) + if arguments: + for key, val in arguments.items(): + annotation = get_annotation(slot_copy._function._func, key, val) + slot_copy._arguments.append(Argument(annotation, key, val)) + + response_model = create_response_format(self._function._func) + + # AsyncGenerativeSlots are used with async functions. In order to support that behavior, + # they must return a coroutine object. + async def __async_call__(): + # Use the async act func so that control flow doesn't get stuck here in async event loops. 
+ response = await m.aact( + slot_copy, format=response_model, model_options=model_options + ) + + function_response: FunctionResponse[R] = response_model.model_validate_json( + response.value # type: ignore + ) + return function_response.result + + return __async_call__() + + def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. @@ -216,6 +262,8 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: that function's behavior. The output is guaranteed to match the return type annotation using structured outputs and automatic validation. + Note: Works with async functions as well. + Tip: Write the function and docstring in the most Pythonic way possible, not like a prompt. This ensures the function is well-documented, easily understood, and familiar to any Python developer. The more natural and conventional your @@ -248,7 +296,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... estimated_hours: float >>> >>> @generative - ... def create_project_tasks(project_desc: str, count: int) -> List[Task]: + ... async def create_project_tasks(project_desc: str, count: int) -> List[Task]: ... '''Generate a list of realistic tasks for a project. ... ... Args: @@ -260,7 +308,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... ''' ... ... >>> - >>> tasks = create_project_tasks(session, "Build a web app", 5) + >>> tasks = await create_project_tasks(session, "Build a web app", 5) >>> @generative ... 
def analyze_code_quality(code: str) -> Dict[str, Any]: @@ -304,8 +352,46 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?") """ - return GenerativeSlot(func) + if inspect.iscoroutinefunction(func): + return AsyncGenerativeSlot(func) + else: + return GenerativeSlot(func) # Export the decorator as the interface __all__ = ["generative"] + + +if __name__ == "__main__": + from mellea import start_session + + with start_session(): + + async def asyncly() -> int: ... + + out = asyncly() + + @generative + async def test_async(num: int) -> bool: ... + + @generative + def test_sync(truthy: bool) -> int: ... + + print("running sync") + print(test_sync(m=None, model_options=None, truthy=False)) + + async def runmany(): + print(await test_async(m=None, model_options=None, num=6)) + print(await test_async(m=None, model_options=None, num=4)) + print(await test_async(m=None, model_options=None, num=5)) + + coros = [ + test_async(m=None, model_options=None, num=1), + test_async(m=None, model_options=None, num=2), + test_async(m=None, model_options=None, num=3), + ] + results = await asyncio.gather(*coros) + print(results) + + print("running async") + asyncio.run(runmany()) diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index ebcace55..c9695260 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -1,6 +1,8 @@ +import asyncio import pytest from typing import Literal from mellea import generative, start_session +from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot @generative @@ -10,6 +12,8 @@ def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... @generative def write_me_an_email() -> str: ... +@generative +async def async_write_short_sentence(topic: str) -> str: ... 
@pytest.fixture(scope="function") def session(): @@ -29,6 +33,7 @@ def test_gen_slot_output(classify_sentiment_output): def test_func(session): + assert isinstance(write_me_an_email, GenerativeSlot) and not isinstance(write_me_an_email, AsyncGenerativeSlot) write_email_component = write_me_an_email(session) assert isinstance(write_email_component, str) @@ -43,5 +48,18 @@ def test_gen_slot_logs(classify_sentiment_output, session): assert isinstance(last_prompt, dict) assert set(last_prompt.keys()) == {"role", "content", "images"} +async def test_async_gen_slot(session): + assert isinstance(async_write_short_sentence, AsyncGenerativeSlot) + + r1 = async_write_short_sentence(session, topic="cats") + r2 = async_write_short_sentence(session, topic="dogs") + + r3 = await async_write_short_sentence(session, topic="fish") + results = await asyncio.gather(r1, r2) + + assert isinstance(r3, str) + assert len(results) == 2 + + if __name__ == "__main__": pytest.main([__file__]) From 9522c88db43b4ea776c77e9dacf5fc0dbf86fe0f Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Wed, 1 Oct 2025 15:10:08 -0400 Subject: [PATCH 11/14] fix: remove sessions' backend stack --- mellea/stdlib/session.py | 23 ----------------------- test/stdlib_basics/test_session.py | 25 ------------------------- 2 files changed, 48 deletions(-) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index d1167adc..24098640 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -180,7 +180,6 @@ def __init__(self, backend: Backend, ctx: Context | None = None): """ self.backend = backend self.ctx: Context = ctx if ctx is not None else SimpleContext() - self._backend_stack: list[Backend] = [] self._session_logger = FancyLogger.get_logger() self._context_token = None @@ -196,29 +195,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): _context_session.reset(self._context_token) self._context_token = None - def _push_model_state(self, new_backend: Backend): - """The backend used within a 
`Context` can be temporarily changed. This method changes the model's backend, while saving the current settings in the `self._backend_stack`.""" - self._backend_stack.append(self.backend) - self.backend = new_backend - - def _pop_model_state(self) -> bool: - """Pops the model state. - - The backend and model options used within a `Context` can be temporarily changed by pushing and popping from the model state. - This function restores the model's previous backend and model_opts from the `self._backend_stack`. - - Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class. - """ - try: - b = self._backend_stack.pop() - self.backend = b - return True - except Exception: - return False - def __copy__(self): new = MelleaSession(backend=self.backend, ctx=self.ctx) - new._backend_stack = self._backend_stack.copy() new._session_logger = self._session_logger # Explicitly don't copy over the _context_token. @@ -254,7 +232,6 @@ def reset(self): def cleanup(self) -> None: """Clean up session resources.""" self.reset() - self._backend_stack.clear() if hasattr(self.backend, "close"): self.backend.close() # type: ignore diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index e55359f8..1388b15b 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -125,30 +125,5 @@ def test_session_copy_with_context_ops(m_session): assert m1.ctx.previous_node.previous_node is m_session.ctx assert m2.ctx.previous_node.previous_node is m_session.ctx -def test_session_copy_with_backend_stack(m_session): - # Assert expected values from cloning. - m1 = m_session.clone() - assert m1.backend is m_session.backend - assert m1._session_logger is m_session._session_logger - assert m1._backend_stack is not m_session._backend_stack - - # Assert that pushing to a backend stack doesn't change it for sessions previously cloned from it. 
- new_backend = OllamaModelBackend() - m_session._push_model_state(new_backend=new_backend) - assert len(m_session._backend_stack) == 1 - assert len(m1._backend_stack) == 0 - assert m1.backend is not m_session.backend - - # Assert that newly cloned sessions don't cause errors with changes to the backend stack. - m2 = m_session.clone() - assert len(m2._backend_stack) == 1 - - # They should still be different lists. - assert m2._backend_stack is not m_session._backend_stack - assert m2._pop_model_state() - assert len(m2._backend_stack) == 0 - assert len(m_session._backend_stack) == 1 - assert m2.backend is m1.backend - if __name__ == "__main__": pytest.main([__file__]) From a02b2e8a0e5bba55ae5d07a61153728874085f5f Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Mon, 6 Oct 2025 09:55:06 -0400 Subject: [PATCH 12/14] fix: remove testing code --- mellea/stdlib/genslot.py | 39 +++------------------------------------ 1 file changed, 3 insertions(+), 36 deletions(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 4f8f2368..e2b2d57c 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -221,6 +221,8 @@ def __call__( Args: m: MelleaSession: A mellea session (optional, uses context if None) + model_options: Model options to pass to the backend. + *args: Additional args to be passed to the func. **kwargs: Additional Kwargs to be passed to the func Returns: @@ -239,7 +241,7 @@ def __call__( # AsyncGenerativeSlots are used with async functions. In order to support that behavior, # they must return a coroutine object. - async def __async_call__(): + async def __async_call__() -> R: # Use the async act func so that control flow doesn't get stuck here in async event loops. 
response = await m.aact( slot_copy, format=response_model, model_options=model_options @@ -360,38 +362,3 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: # Export the decorator as the interface __all__ = ["generative"] - - -if __name__ == "__main__": - from mellea import start_session - - with start_session(): - - async def asyncly() -> int: ... - - out = asyncly() - - @generative - async def test_async(num: int) -> bool: ... - - @generative - def test_sync(truthy: bool) -> int: ... - - print("running sync") - print(test_sync(m=None, model_options=None, truthy=False)) - - async def runmany(): - print(await test_async(m=None, model_options=None, num=6)) - print(await test_async(m=None, model_options=None, num=4)) - print(await test_async(m=None, model_options=None, num=5)) - - coros = [ - test_async(m=None, model_options=None, num=1), - test_async(m=None, model_options=None, num=2), - test_async(m=None, model_options=None, num=3), - ] - results = await asyncio.gather(*coros) - print(results) - - print("running async") - asyncio.run(runmany()) From f8f3e03b68d7a85df919730e2508e7c54423a889 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Mon, 6 Oct 2025 10:49:06 -0400 Subject: [PATCH 13/14] fix: docstrings --- mellea/stdlib/funcs.py | 6 +++--- mellea/stdlib/session.py | 12 +++++------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index d8410b69..79c81d06 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -347,10 +347,12 @@ def transform( """Transform method for creating a new object with the transformation applied. Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. transformation: The string representing the query to be executed against the object. 
context: the context being used as a history from which to generate the response. backend: the backend used to generate the response. + format: format for output parsing; usually not needed with transform. + model_options: Model options to pass to the backend. Returns: (ModelOutputThunk | Any, Context): The result of the transformation as processed by the backend. If no tools were called, @@ -476,7 +478,6 @@ async def aact( Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - if not silence_context_type_warning and not isinstance(context, SimpleContext): FancyLogger().get_logger().warning( "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." @@ -630,7 +631,6 @@ async def ainstruct( Returns: A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - requirements = [] if requirements is None else requirements icl_examples = [] if icl_examples is None else icl_examples grounding_context = dict() if grounding_context is None else grounding_context diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 24098640..2a63a71a 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -196,6 +196,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self._context_token = None def __copy__(self): + """Use self.clone. Copies the current session but keeps references to the backend and context.""" new = MelleaSession(backend=self.backend, ctx=self.ctx) new._session_logger = self._session_logger # Explicitly don't copy over the _context_token. @@ -206,7 +207,7 @@ def clone(self): """Useful for running multiple generation requests while keeping the context at a given point in time. Returns: - a copy of the current session. Keeps the context, backend, backend stack, and session logger. + a copy of the current session. 
Keeps the context, backend, and session logger. Examples: >>> from mellea import start_session @@ -286,7 +287,6 @@ def act( Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - r = mfuncs.act( action, context=self.ctx, @@ -571,7 +571,6 @@ async def aact( Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - r = await mfuncs.aact( action, context=self.ctx, @@ -664,7 +663,6 @@ async def ainstruct( tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. """ - r = await mfuncs.ainstruct( description, context=self.ctx, @@ -704,7 +702,6 @@ async def achat( tool_calls: bool = False, ) -> Message: """Sends a simple chat message and returns the response. Adds both messages to the Context.""" - result, context = await mfuncs.achat( content=content, context=self.ctx, @@ -731,7 +728,6 @@ async def avalidate( input: CBlock | None = None, ) -> list[ValidationResult]: """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - return await mfuncs.avalidate( reqs=reqs, context=self.ctx, @@ -787,8 +783,10 @@ async def atransform( """Transform method for creating a new object with the transformation applied. Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. transformation: The string representing the query to be executed against the object. + format: format for output parsing; usually not needed with transform. + model_options: Model options to pass to the backend. Returns: ModelOutputThunk|Any: The result of the transformation as processed by the backend. 
If no tools were called, From c956d8e48784dc7667401aee1d32b106a5e78c3f Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Mon, 6 Oct 2025 11:11:06 -0400 Subject: [PATCH 14/14] fix: test failing due to sampling copy --- test/stdlib_basics/test_session.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 1388b15b..efcb51ed 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -74,8 +74,8 @@ async def test_async_await_with_chat_context(m_session): m1 = Message(role="user", content="1") m2 = Message(role="user", content="2") - r1 = await m_session.aact(m1) - r2 = await m_session.aact(m2) + r1 = await m_session.aact(m1, strategy=None) + r2 = await m_session.aact(m2, strategy=None) # This should be the order of these items in the session's context. history = [r2, m2, r1, m1]