Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 77 additions & 0 deletions test/stdlib/requirements/test_reqlib_markdown.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
is_markdown_list,
is_markdown_table,
)
from mellea.stdlib.requirements.md import _md_list, _md_table


def from_model(s: str) -> Context:
Expand Down Expand Up @@ -79,5 +80,81 @@ def test_default_output_to_bool_word_with_yes_in_it():
)


# --- as_markdown_list edge cases ---


def test_as_markdown_list_paragraph():
    """A lone prose paragraph must yield ``None`` from ``as_markdown_list``."""
    paragraph_ctx = from_model("This is just a paragraph of text.")
    parsed = as_markdown_list(paragraph_ctx)
    assert parsed is None


def test_as_markdown_list_mixed_content():
    """A list with a trailing paragraph must yield ``None`` from ``as_markdown_list``."""
    # The literal's continuation lines are kept flush-left so the markdown
    # text is reproduced exactly.
    mixed_markdown = """- item one
- item two

This is a paragraph after the list."""
    mixed_ctx = from_model(mixed_markdown)
    parsed = as_markdown_list(mixed_ctx)
    assert parsed is None


def test_as_markdown_list_empty():
    """An empty model output must yield ``None`` from ``as_markdown_list``."""
    empty_ctx = from_model("")
    parsed = as_markdown_list(empty_ctx)
    assert parsed is None


def test_as_markdown_list_single_item():
    """A one-bullet list is parsed into a result of length one."""
    single_item_ctx = from_model("- only item")
    items = as_markdown_list(single_item_ctx)
    # Check for None first so a parse failure is reported as such rather than
    # as a TypeError from len().
    assert items is not None
    item_count = len(items)
    assert item_count == 1


# --- _md_list validation wrapper ---


def test_md_list_valid():
    """``_md_list`` reports success for the shared markdown-list context."""
    verdict = _md_list(MARKDOWN_LIST_CTX).as_bool()
    assert verdict is True


def test_md_list_invalid():
    """``_md_list`` reports failure when the output is a plain paragraph."""
    paragraph_ctx = from_model("Just a paragraph.")
    verdict = _md_list(paragraph_ctx).as_bool()
    assert verdict is False


# --- _md_table edge cases ---


def test_md_table_not_a_table():
    """``_md_table`` reports failure when the output is plain text."""
    text_ctx = from_model("This is just text, not a table.")
    verdict = _md_table(text_ctx).as_bool()
    assert verdict is False


def test_md_table_multiple_children():
    """``_md_table`` reports failure when a heading precedes the table."""
    # Heading plus table: the table is not the only top-level element.
    # Continuation lines stay flush-left so the markdown is reproduced exactly.
    heading_then_table = """# Title

| Col A | Col B |
|-------|-------|
| 1 | 2 |"""
    doc_ctx = from_model(heading_then_table)
    verdict = _md_table(doc_ctx).as_bool()
    assert verdict is False


def test_md_table_empty():
    """``_md_table`` reports failure for an empty model output."""
    empty_ctx = from_model("")
    verdict = _md_table(empty_ctx).as_bool()
    assert verdict is False


# Allow running this test module directly as a script: hand only this file
# to pytest.
if __name__ == "__main__":
    pytest.main(args=[__file__])
176 changes: 175 additions & 1 deletion test/stdlib/requirements/test_reqlib_tools.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,28 @@
from unittest.mock import Mock

import pytest

from mellea.stdlib.requirements.tool_reqs import _name2str
from mellea.core import ModelOutputThunk, ModelToolCall
from mellea.stdlib.context import ChatContext
from mellea.stdlib.requirements.tool_reqs import (
_name2str,
tool_arg_validator,
uses_tool,
)


def _ctx_with_tool_calls(tool_calls: dict[str, ModelToolCall] | None) -> ChatContext:
    """Helper: return a ChatContext extended with one output carrying ``tool_calls``.

    Args:
        tool_calls: Mapping passed straight through as the thunk's
            ``tool_calls``; ``None`` is forwarded unchanged.

    Returns:
        The context produced by adding a ``ModelOutputThunk`` (with an empty
        string value) to a fresh ``ChatContext``.
    """
    base_ctx = ChatContext()
    output_thunk = ModelOutputThunk(value="", tool_calls=tool_calls)
    return base_ctx.add(output_thunk)


def _make_tool_call(name: str, args: dict) -> ModelToolCall:
    """Helper: create a ModelToolCall for ``name``/``args`` backed by a fresh Mock.

    Args:
        name: Tool name stored on the call.
        args: Argument mapping stored on the call.

    Returns:
        A ``ModelToolCall`` whose ``func`` is a new ``Mock`` instance.
    """
    stub_func = Mock()
    return ModelToolCall(name=name, func=stub_func, args=args)


# --- _name2str ---


def test_name2str():
Expand All @@ -11,3 +33,155 @@ def test123():

assert _name2str(test123) == "test123"
assert _name2str("test1234") == "test1234"


def test_name2str_type_error():
    """``_name2str`` raises ``TypeError`` for an input that is neither callable nor str."""
    not_a_name = 123
    expected_message = "Expected Callable or str"
    with pytest.raises(TypeError, match=expected_message):
        _name2str(not_a_name)  # type: ignore[arg-type]


# --- uses_tool ---


def test_uses_tool_present():
    """``uses_tool`` validates successfully when the named tool was called."""
    weather_call = _make_tool_call("get_weather", {})
    ctx = _ctx_with_tool_calls({"get_weather": weather_call})
    requirement = uses_tool("get_weather")
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is True


def test_uses_tool_absent():
    """``uses_tool`` fails validation when only a different tool was called."""
    weather_call = _make_tool_call("get_weather", {})
    ctx = _ctx_with_tool_calls({"get_weather": weather_call})
    requirement = uses_tool("send_email")
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is False


def test_uses_tool_no_tool_calls():
    """``uses_tool`` fails, with an explanatory reason, when there are no tool calls."""
    ctx = _ctx_with_tool_calls(None)
    req = uses_tool("get_weather")
    result = req.validation_fn(ctx)
    assert result.as_bool() is False
    # Assert the reason exists before dereferencing it, so a missing reason
    # shows up as a failed assertion instead of an AttributeError on `.lower()`.
    assert result.reason is not None
    assert "no tool calls" in result.reason.lower()


def test_uses_tool_callable_input():
    """``uses_tool`` accepts a callable and matches the tool call keyed ``"my_tool"``."""

    def my_tool():
        return None

    recorded_call = _make_tool_call("my_tool", {})
    ctx = _ctx_with_tool_calls({"my_tool": recorded_call})
    requirement = uses_tool(my_tool)
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is True


def test_uses_tool_check_only():
    """``check_only=True`` passed to ``uses_tool`` is exposed on the requirement."""
    requirement = uses_tool("get_weather", check_only=True)
    flag = requirement.check_only
    assert flag is True


# --- tool_arg_validator ---


def test_tool_arg_validator_valid():
    """Validation passes when the named tool's argument satisfies the predicate."""

    def is_positive(v):
        return v > 0

    search_call = _make_tool_call("search", {"query": "hello", "limit": 10})
    ctx = _ctx_with_tool_calls({"search": search_call})
    requirement = tool_arg_validator(
        description="limit must be positive",
        tool_name="search",
        arg_name="limit",
        validation_fn=is_positive,
    )
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is True


def test_tool_arg_validator_failed_validation():
    """Validation fails when the named tool's argument violates the predicate."""

    def is_positive(v):
        return v > 0

    search_call = _make_tool_call("search", {"query": "hello", "limit": -1})
    ctx = _ctx_with_tool_calls({"search": search_call})
    requirement = tool_arg_validator(
        description="limit must be positive",
        tool_name="search",
        arg_name="limit",
        validation_fn=is_positive,
    )
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is False


def test_tool_arg_validator_missing_tool():
    """Validation fails, naming the tool in the reason, when the tool was never called."""
    ctx = _ctx_with_tool_calls(
        {"search": _make_tool_call("search", {"query": "hello"})}
    )
    req = tool_arg_validator(
        description="check email tool",
        tool_name="send_email",
        arg_name="to",
        validation_fn=lambda v: True,
    )
    result = req.validation_fn(ctx)
    assert result.as_bool() is False
    # Assert the reason exists before the membership test, so a missing reason
    # shows up as a failed assertion instead of a TypeError from `in None`.
    assert result.reason is not None
    assert "send_email" in result.reason


def test_tool_arg_validator_missing_arg():
    """Validation fails, naming the argument in the reason, when the tool call lacks it."""
    ctx = _ctx_with_tool_calls(
        {"search": _make_tool_call("search", {"query": "hello"})}
    )
    req = tool_arg_validator(
        description="limit must exist",
        tool_name="search",
        arg_name="limit",
        validation_fn=lambda v: True,
    )
    result = req.validation_fn(ctx)
    assert result.as_bool() is False
    # Assert the reason exists before the membership test, so a missing reason
    # shows up as a failed assertion instead of a TypeError from `in None`.
    assert result.reason is not None
    assert "limit" in result.reason


def test_tool_arg_validator_no_tool_calls():
    """Validation fails when the output carries no tool calls at all."""

    def accept_anything(v):
        return True

    ctx = _ctx_with_tool_calls(None)
    requirement = tool_arg_validator(
        description="check tool",
        tool_name="search",
        arg_name="query",
        validation_fn=accept_anything,
    )
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is False


def test_tool_arg_validator_no_tool_name_all_pass():
    """With ``tool_name=None``, validation passes when both calls have a positive ``x``."""

    def is_positive(v):
        return v > 0

    call_a = _make_tool_call("tool_a", {"x": 5})
    call_b = _make_tool_call("tool_b", {"x": 10})
    ctx = _ctx_with_tool_calls({"tool_a": call_a, "tool_b": call_b})
    requirement = tool_arg_validator(
        description="x must be positive",
        tool_name=None,
        arg_name="x",
        validation_fn=is_positive,
    )
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is True


def test_tool_arg_validator_no_tool_name_one_fails():
    """With ``tool_name=None``, validation fails when one call has a non-positive ``x``."""

    def is_positive(v):
        return v > 0

    call_a = _make_tool_call("tool_a", {"x": 5})
    call_b = _make_tool_call("tool_b", {"x": -1})
    ctx = _ctx_with_tool_calls({"tool_a": call_a, "tool_b": call_b})
    requirement = tool_arg_validator(
        description="x must be positive",
        tool_name=None,
        arg_name="x",
        validation_fn=is_positive,
    )
    outcome = requirement.validation_fn(ctx)
    assert outcome.as_bool() is False
Loading
Loading