diff --git a/test/stdlib/requirements/test_reqlib_markdown.py b/test/stdlib/requirements/test_reqlib_markdown.py
index b663999f0..20ef57648 100644
--- a/test/stdlib/requirements/test_reqlib_markdown.py
+++ b/test/stdlib/requirements/test_reqlib_markdown.py
@@ -7,6 +7,7 @@
     is_markdown_list,
     is_markdown_table,
 )
+from mellea.stdlib.requirements.md import _md_list, _md_table
 
 
 def from_model(s: str) -> Context:
@@ -79,5 +80,81 @@ def test_default_output_to_bool_word_with_yes_in_it():
     )
 
 
+# --- as_markdown_list edge cases ---
+
+
+def test_as_markdown_list_paragraph():
+    """Plain paragraph is not a list — should return None."""
+    ctx = from_model("This is just a paragraph of text.")
+    assert as_markdown_list(ctx) is None
+
+
+def test_as_markdown_list_mixed_content():
+    """List followed by a paragraph should return None (not all children are lists)."""
+    ctx = from_model(
+        """- item one
+- item two
+
+This is a paragraph after the list."""
+    )
+    assert as_markdown_list(ctx) is None
+
+
+def test_as_markdown_list_empty():
+    """Empty string should return None."""
+    ctx = from_model("")
+    assert as_markdown_list(ctx) is None
+
+
+def test_as_markdown_list_single_item():
+    """A one-item list should parse into a result of length 1."""
+    ctx = from_model("- only item")
+    result = as_markdown_list(ctx)
+    assert result is not None
+    assert len(result) == 1
+
+
+# --- _md_list validation wrapper ---
+
+
+def test_md_list_valid():
+    result = _md_list(MARKDOWN_LIST_CTX)  # module-level name from elsewhere in this file
+    assert result.as_bool() is True
+
+
+def test_md_list_invalid():
+    ctx = from_model("Just a paragraph.")  # no list markup at all
+    result = _md_list(ctx)
+    assert result.as_bool() is False
+
+
+# --- _md_table edge cases ---
+
+
+def test_md_table_not_a_table():
+    ctx = from_model("This is just text, not a table.")
+    result = _md_table(ctx)
+    assert result.as_bool() is False  # prose must not validate as a table
+
+
+def test_md_table_multiple_children():
+    """A heading followed by a table = 2 children, should return False."""
+    ctx = from_model(
+        """# Title
+
+| Col A | Col B |
+|-------|-------|
+| 1 | 2 |"""
+    )
+    result = _md_table(ctx)
+    assert result.as_bool() is False
+
+
+def test_md_table_empty():
+    ctx = from_model("")  # empty model output
+    result = _md_table(ctx)
+    assert result.as_bool() is False
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
diff --git a/test/stdlib/requirements/test_reqlib_tools.py b/test/stdlib/requirements/test_reqlib_tools.py
index c20d4a783..a98cde904 100644
--- a/test/stdlib/requirements/test_reqlib_tools.py
+++ b/test/stdlib/requirements/test_reqlib_tools.py
@@ -1,6 +1,28 @@
+from unittest.mock import Mock
+
 import pytest
 
-from mellea.stdlib.requirements.tool_reqs import _name2str
+from mellea.core import ModelOutputThunk, ModelToolCall
+from mellea.stdlib.context import ChatContext
+from mellea.stdlib.requirements.tool_reqs import (
+    _name2str,
+    tool_arg_validator,
+    uses_tool,
+)
+
+
+def _ctx_with_tool_calls(tool_calls: dict[str, ModelToolCall] | None) -> ChatContext:
+    """Helper: build a ChatContext whose last output has the given tool_calls."""
+    ctx = ChatContext()  # fresh context; the thunk below is its only entry
+    return ctx.add(ModelOutputThunk(value="", tool_calls=tool_calls))
+
+
+def _make_tool_call(name: str, args: dict) -> ModelToolCall:
+    """Helper: build a ModelToolCall with a mock func."""
+    return ModelToolCall(name=name, func=Mock(), args=args)
+
+
+# --- _name2str ---
 
 
 def test_name2str():
@@ -11,3 +33,155 @@ def test123():
 
     assert _name2str(test123) == "test123"
     assert _name2str("test1234") == "test1234"
+
+
+def test_name2str_type_error():
+    with pytest.raises(TypeError, match="Expected Callable or str"):
+        _name2str(123)  # type: ignore[arg-type]
+
+
+# --- uses_tool ---
+
+
+def test_uses_tool_present():
+    ctx = _ctx_with_tool_calls({"get_weather": _make_tool_call("get_weather", {})})
+    req = uses_tool("get_weather")  # same name as the recorded call
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is True
+
+
+def test_uses_tool_absent():
+    ctx = _ctx_with_tool_calls({"get_weather": _make_tool_call("get_weather", {})})
+    req = uses_tool("send_email")  # not among the recorded calls
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
+
+
+def test_uses_tool_no_tool_calls():
+    ctx = _ctx_with_tool_calls(None)  # output carries tool_calls=None
+    req = uses_tool("get_weather")
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
+    assert "no tool calls" in result.reason.lower()  # reason names the cause
+
+
+def test_uses_tool_callable_input():
+    def my_tool():
+        pass
+
+    ctx = _ctx_with_tool_calls({"my_tool": _make_tool_call("my_tool", {})})
+    req = uses_tool(my_tool)  # the function object instead of its name
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is True
+
+
+def test_uses_tool_check_only():
+    req = uses_tool("get_weather", check_only=True)
+    assert req.check_only is True  # flag is carried onto the requirement
+
+
+# --- tool_arg_validator ---
+
+
+def test_tool_arg_validator_valid():
+    ctx = _ctx_with_tool_calls(
+        {"search": _make_tool_call("search", {"query": "hello", "limit": 10})}
+    )
+    req = tool_arg_validator(
+        description="limit must be positive",
+        tool_name="search",
+        arg_name="limit",
+        validation_fn=lambda v: v > 0,  # holds for limit=10
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is True
+
+
+def test_tool_arg_validator_failed_validation():
+    ctx = _ctx_with_tool_calls(
+        {"search": _make_tool_call("search", {"query": "hello", "limit": -1})}
+    )
+    req = tool_arg_validator(
+        description="limit must be positive",
+        tool_name="search",
+        arg_name="limit",
+        validation_fn=lambda v: v > 0,  # does not hold for limit=-1
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
+
+
+def test_tool_arg_validator_missing_tool():
+    ctx = _ctx_with_tool_calls(
+        {"search": _make_tool_call("search", {"query": "hello"})}
+    )
+    req = tool_arg_validator(
+        description="check email tool",
+        tool_name="send_email",  # only "search" was called
+        arg_name="to",
+        validation_fn=lambda v: True,  # always passes on its own
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
+    assert "send_email" in result.reason
+
+
+def test_tool_arg_validator_missing_arg():
+    ctx = _ctx_with_tool_calls(
+        {"search": _make_tool_call("search", {"query": "hello"})}
+    )
+    req = tool_arg_validator(
+        description="limit must exist",
+        tool_name="search",
+        arg_name="limit",  # the call's args only contain "query"
+        validation_fn=lambda v: True,  # always passes on its own
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
+    assert "limit" in result.reason
+
+
+def test_tool_arg_validator_no_tool_calls():
+    ctx = _ctx_with_tool_calls(None)  # output carries tool_calls=None
+    req = tool_arg_validator(
+        description="check tool",
+        tool_name="search",
+        arg_name="query",
+        validation_fn=lambda v: True,
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
+
+
+def test_tool_arg_validator_no_tool_name_all_pass():
+    ctx = _ctx_with_tool_calls(
+        {
+            "tool_a": _make_tool_call("tool_a", {"x": 5}),
+            "tool_b": _make_tool_call("tool_b", {"x": 10}),
+        }
+    )
+    req = tool_arg_validator(
+        description="x must be positive",
+        tool_name=None,  # no tool name given
+        arg_name="x",
+        validation_fn=lambda v: v > 0,
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is True
+
+
+def test_tool_arg_validator_no_tool_name_one_fails():
+    ctx = _ctx_with_tool_calls(
+        {
+            "tool_a": _make_tool_call("tool_a", {"x": 5}),
+            "tool_b": _make_tool_call("tool_b", {"x": -1}),  # the only negative x
+        }
+    )
+    req = tool_arg_validator(
+        description="x must be positive",
+        tool_name=None,  # no tool name given
+        arg_name="x",
+        validation_fn=lambda v: v > 0,
+    )
+    result = req.validation_fn(ctx)
+    assert result.as_bool() is False
diff --git a/test/stdlib/requirements/test_requirement.py b/test/stdlib/requirements/test_requirement.py
index 575e40ec7..125022e06 100644
--- a/test/stdlib/requirements/test_requirement.py
+++ b/test/stdlib/requirements/test_requirement.py
@@ -1,16 +1,26 @@
+import json
+from unittest.mock import patch
+
 import pytest
 
 from mellea.core import ModelOutputThunk, Requirement
 from mellea.stdlib.context import ChatContext
 from mellea.stdlib.requirements import LLMaJRequirement, simple_validate
+from mellea.stdlib.requirements.requirement import (
+    ALoraRequirement,
+    check,
+    req,
+    reqify,
+    requirement_check_to_bool,
+)
 from mellea.stdlib.session import start_session
 
 ctx = ChatContext()
 ctx = ctx.add(ModelOutputThunk("test"))
 
-pytestmark = [pytest.mark.ollama, pytest.mark.e2e]
-
 
+@pytest.mark.ollama
+@pytest.mark.e2e
 async def test_llmaj_validation_req_output_field():
     m = start_session(ctx=ctx)
     req = Requirement("Must output test.")
@@ -22,6 +32,8 @@ async def test_llmaj_validation_req_output_field():
     )
 
 
+@pytest.mark.ollama
+@pytest.mark.e2e
 async def test_llmaj_requirement_uses_requirement_template():
     m = start_session(ctx=ctx)
     req = LLMaJRequirement("Must output test.")
@@ -60,5 +72,113 @@ def test_simple_validate_invalid():
         validation_func(ctx)
 
 
+# --- requirement_check_to_bool ---
+
+
+def test_requirement_check_to_bool_above_threshold():
+    assert requirement_check_to_bool('{"requirement_likelihood": 0.8}') is True
+
+
+def test_requirement_check_to_bool_below_threshold():
+    assert requirement_check_to_bool('{"requirement_likelihood": 0.3}') is False
+
+
+def test_requirement_check_to_bool_at_threshold():
+    """0.5 is NOT > 0.5, so should return False."""
+    assert requirement_check_to_bool('{"requirement_likelihood": 0.5}') is False
+
+
+def test_requirement_check_to_bool_missing_key():
+    assert requirement_check_to_bool('{"other_field": 1.0}') is False  # key absent
+
+
+def test_requirement_check_to_bool_invalid_json():
+    with pytest.raises(json.JSONDecodeError):
+        requirement_check_to_bool("not json")
+
+
+# --- reqify ---
+
+
+def test_reqify_string():
+    result = reqify("must be valid")
+    assert isinstance(result, Requirement)
+    assert result.description == "must be valid"
+
+
+def test_reqify_requirement():
+    original = Requirement("must be valid")
+    result = reqify(original)
+    assert result is original  # returned as-is, not copied
+
+
+def test_reqify_invalid_type():
+    with pytest.raises(Exception, match="reqify takes a str or requirement"):
+        reqify(123)  # type: ignore[arg-type]
+
+
+# --- req / check shorthands ---
+
+
+def test_req_shorthand():
+    result = req("must be valid")
+    assert isinstance(result, Requirement)
+    assert result.description == "must be valid"
+
+
+def test_check_shorthand():
+    result = check("must be valid")
+    assert isinstance(result, Requirement)
+    assert result.check_only is True  # check() is expected to set check_only
+
+
+# --- simple_validate edge case ---
+
+
+def test_simple_validate_none_output():
+    """Context with no output should return False without calling the fn."""
+    empty_ctx = ChatContext()  # nothing has been added to it
+    validation_func = simple_validate(lambda x: True)  # would say True if consulted
+    result = validation_func(empty_ctx)
+    assert result.as_bool() is False
+
+
+# --- LLMaJRequirement ---
+
+
+def test_llmaj_requirement_use_aloras_false():
+    r = LLMaJRequirement("must be valid")
+    assert r.use_aloras is False
+
+
+# --- ALoraRequirement ---
+
+
+@patch("mellea.stdlib.requirements.requirement.Intrinsic.__init__")
+def test_alora_requirement_default_intrinsic(mock_intrinsic_init):
+    mock_intrinsic_init.return_value = None
+    r = ALoraRequirement("must be valid")
+    assert r.use_aloras is True
+    assert r.description == "must be valid"
+    # Intrinsic.__init__ is unbound; mock receives self as first positional arg.
+    mock_intrinsic_init.assert_called_once_with(
+        r,
+        intrinsic_name="requirement_check",
+        intrinsic_kwargs={"requirement": "must be valid"},
+    )
+
+
+@patch("mellea.stdlib.requirements.requirement.Intrinsic.__init__")
+def test_alora_requirement_custom_intrinsic(mock_intrinsic_init):
+    mock_intrinsic_init.return_value = None
+    r = ALoraRequirement("must be valid", intrinsic_name="custom_check")
+    assert r.use_aloras is True
+    mock_intrinsic_init.assert_called_once_with(
+        r,  # the instance, expected as the mock's first positional arg
+        intrinsic_name="custom_check",
+        intrinsic_kwargs={"requirement": "must be valid"},
+    )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])