Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
[
"test_example_workflows.py",
"test_run_state.py",
"test_sandbox_memory.py",
"sandbox/capabilities/test_filesystem_capability.py",
"sandbox/integration_tests/test_runner_pause_resume.py",
"sandbox/test_client_options.py",
"sandbox/test_exposed_ports.py",
"sandbox/test_extract.py",
"sandbox/test_memory.py",
"sandbox/test_runtime.py",
"sandbox/test_session_manager.py",
"sandbox/test_session_sinks.py",
Expand Down
45 changes: 45 additions & 0 deletions tests/extensions/experimental/codex/test_payloads.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

import pytest

from agents.extensions.experimental.codex.items import AgentMessageItem, TodoItem, TodoListItem


def test_dict_like_supports_mapping_access_for_dataclass_fields() -> None:
    """Dataclass items expose their fields through a read-only mapping interface."""
    message = AgentMessageItem(id="item-1", text="hello")

    # Subscript access works for declared fields and the `type` discriminator.
    for key, expected in (("id", "item-1"), ("text", "hello"), ("type", "agent_message")):
        assert message[key] == expected

    # dict-style get(), with and without a fallback value.
    assert message.get("text") == "hello"
    assert message.get("missing", "fallback") == "fallback"

    # Membership checks, including a non-string probe.
    assert "id" in message
    assert "missing" not in message
    assert object() not in message

    assert list(message.keys()) == ["id", "text", "type"]


def test_dict_like_raises_key_error_for_unknown_fields() -> None:
    """Subscripting a name that is not a field raises KeyError naming the key."""
    message = AgentMessageItem(id="item-1", text="hello")

    with pytest.raises(KeyError, match="missing"):
        _ = message["missing"]


def test_dict_like_as_dict_recursively_converts_nested_dataclasses() -> None:
    """as_dict() turns nested dataclass items into plain dictionaries."""
    todos = [
        TodoItem(text="write tests", completed=True),
        TodoItem(text="run tests", completed=False),
    ]
    todo_list = TodoListItem(id="todo-list-1", items=todos)

    expected = {
        "id": "todo-list-1",
        "items": [
            {"text": "write tests", "completed": True},
            {"text": "run tests", "completed": False},
        ],
        "type": "todo_list",
    }
    assert todo_list.as_dict() == expected
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@

def _load_example_module() -> Any:
path = (
Path(__file__).resolve().parents[2]
Path(__file__).resolve().parents[3]
/ "examples"
/ "sandbox"
/ "extensions"
/ "runloop"
/ "capabilities.py"
)
module_name = "tests.extensions.runloop_capabilities_example"
module_name = "tests.extensions.sandbox.runloop_capabilities_example"
spec = importlib.util.spec_from_file_location(module_name, path)
assert spec is not None
assert spec.loader is not None
Expand Down
5 changes: 2 additions & 3 deletions tests/test_session.py → tests/memory/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
import pytest

from agents import Agent, RunConfig, Runner, SQLiteSession, TResponseInputItem

from .fake_model import FakeModel
from .test_responses import get_text_message
from tests.fake_model import FakeModel
from tests.test_responses import get_text_message


# Helper functions for parametrized testing of different Runner methods
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from agents import Agent, RunConfig, SQLiteSession
from agents.memory import SessionSettings
from tests.fake_model import FakeModel
from tests.memory.test_session import run_agent_async
from tests.test_responses import get_text_message
from tests.test_session import run_agent_async


@pytest.mark.parametrize("runner_method", ["run", "run_sync", "run_streamed"])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@
)
from agents.run_internal.model_retry import get_response_with_retry, stream_response_with_retry
from agents.usage import Usage

from .test_responses import get_text_message
from tests.test_responses import get_text_message


def _connection_error(message: str = "connection error") -> APIConnectionError:
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
95 changes: 95 additions & 0 deletions tests/sandbox/test_session_state_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from pathlib import Path
from typing import Literal

import pytest
from pydantic import ValidationError

from agents.sandbox import Manifest
from agents.sandbox.session import SandboxSessionState
from agents.sandbox.snapshot import LocalSnapshot
Expand All @@ -27,6 +30,21 @@ class _StubSessionState(SandboxSessionState):
custom_field: str


class _PlainTypeSessionState(SandboxSessionState):
    """State whose `type` default is a plain str, not a Literal.

    Used to verify that such classes are skipped by subclass registration
    (see test_subclass_registration_skips_non_literal_or_empty_type_defaults).
    """

    __test__ = False  # keep pytest from collecting this helper class
    type: str = "plain-type"


class _EmptyDefaultSessionState(SandboxSessionState):
    """State with a Literal `type` whose default is the empty string.

    Used to verify that an empty `type` default is not registered.
    """

    __test__ = False  # keep pytest from collecting this helper class
    type: Literal[""] = ""


class _SimpleSessionState(SandboxSessionState):
    """Minimal registered subclass used for parse() round-trip tests."""

    __test__ = False  # keep pytest from collecting this helper class
    type: Literal["simple-roundtrip"] = "simple-roundtrip"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -93,3 +111,80 @@ def test_model_dump_preserves_snapshot_subclass_fields(self) -> None:
dumped = state.model_dump()

assert "base_path" in dumped["snapshot"]

def test_parse_returns_subclass_instances_as_is(self) -> None:
    """parse() is an identity operation for already-typed subclass instances."""
    original = _make_session_state()

    parsed = SandboxSessionState.parse(original)

    assert parsed is original

def test_parse_upgrades_base_instance_through_registry(self) -> None:
    """A base-class instance is re-materialized as the registered subclass."""
    session_id = uuid.UUID("bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb")
    original = _SimpleSessionState(
        session_id=session_id,
        snapshot=LocalSnapshot(id="snap-1", base_path=Path("/tmp/snapshots")),
        manifest=Manifest(),
    )
    # Round-tripping through the base model loses the concrete Python type...
    demoted = SandboxSessionState.model_validate(original.model_dump())

    # ...but parse() restores it via the `type`-keyed subclass registry.
    restored = SandboxSessionState.parse(demoted)

    assert type(restored) is _SimpleSessionState
    assert restored.session_id == session_id

@pytest.mark.parametrize(
    ("payload", "error_type", "message"),
    [
        # Mapping without any `type` key.
        ({}, ValueError, "must include a string `type`"),
        # `type` value that is not present in the subclass registry.
        ({"type": "missing"}, ValueError, "unknown sandbox session state type `missing`"),
        # Payload that is neither a mapping nor a session state instance.
        ("not-a-state", TypeError, "session state payload must be"),
    ],
)
def test_parse_rejects_invalid_payloads(
    self,
    payload: object,
    error_type: type[Exception],
    message: str,
) -> None:
    """parse() raises a descriptive error for each class of malformed payload."""
    with pytest.raises(error_type, match=message):
        SandboxSessionState.parse(payload)

def test_subclass_registration_skips_non_literal_or_empty_type_defaults(self) -> None:
    """Classes without a usable Literal `type` default never enter the registry."""
    registry = SandboxSessionState._subclass_registry

    # `_PlainTypeSessionState` uses a plain str; `_EmptyDefaultSessionState`
    # uses an empty Literal default — neither should be registered.
    assert "plain-type" not in registry
    assert "" not in registry

@pytest.mark.parametrize(
    ("raw_ports", "expected"),
    [
        (None, ()),
        (8080, (8080,)),
        ([8080, 9000, 8080], (8080, 9000)),
    ],
)
def test_exposed_ports_are_normalized(
    self, raw_ports: object, expected: tuple[int, ...]
) -> None:
    """Valid exposed_ports inputs normalize to a deduplicated tuple of ints."""
    base_kwargs = dict(
        snapshot=LocalSnapshot(id="snap-1", base_path=Path("/tmp/snapshots")),
        manifest=Manifest(),
        custom_field="my-value",
    )

    state = _StubSessionState(
        exposed_ports=raw_ports,  # type: ignore[arg-type]
        **base_kwargs,
    )

    assert state.exposed_ports == expected

@pytest.mark.parametrize(
    ("raw_ports", "message"),
    [
        ("8080", "exposed_ports must be an iterable"),
        ([8080, "9000"], "exposed_ports must contain integers"),
        ([0], "exposed_ports entries must be between 1 and 65535"),
        ([65536], "exposed_ports entries must be between 1 and 65535"),
    ],
)
def test_exposed_ports_reject_invalid_values(self, raw_ports: object, message: str) -> None:
    """Invalid exposed_ports inputs raise with a message naming the problem."""
    base_kwargs = dict(
        snapshot=LocalSnapshot(id="snap-1", base_path=Path("/tmp/snapshots")),
        manifest=Manifest(),
        custom_field="my-value",
    )

    with pytest.raises((TypeError, ValidationError), match=message):
        _StubSessionState(
            exposed_ports=raw_ports,  # type: ignore[arg-type]
            **base_kwargs,
        )
96 changes: 96 additions & 0 deletions tests/sandbox/test_token_truncation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from __future__ import annotations

from agents.sandbox.util.token_truncation import (
TruncationPolicy,
approx_bytes_for_tokens,
approx_token_count,
approx_tokens_from_byte_count,
format_truncation_marker,
formatted_truncate_text,
formatted_truncate_text_with_token_count,
removed_units_for_source,
split_budget,
split_string,
truncate_text,
truncate_with_byte_estimate,
truncate_with_token_budget,
)


def test_truncation_policy_clamps_negative_limits_and_converts_budgets() -> None:
    """Negative limits clamp to zero for byte and token policies alike."""
    for policy in (TruncationPolicy.bytes(-10), TruncationPolicy.tokens(-2)):
        assert policy.limit == 0
        assert policy.token_budget() == 0
        assert policy.byte_budget() == 0


def test_formatted_truncate_text_returns_short_content_unchanged() -> None:
    """Content that fits within the byte budget passes through untouched."""
    generous_policy = TruncationPolicy.bytes(20)
    assert formatted_truncate_text("short", generous_policy) == "short"


def test_formatted_truncate_text_adds_line_count_when_truncated() -> None:
    """Truncated output is prefixed with the original total line count."""
    content = "alpha\nbeta\ngamma"

    result = formatted_truncate_text(content, TruncationPolicy.bytes(8))

    assert result.startswith("Total output lines: 3\n\n")
    assert "chars truncated" in result


def test_formatted_truncate_text_with_token_count_handles_none_and_short_content() -> None:
    """Missing budget or within-budget content returns the input with no count."""
    for budget in (None, 10):
        assert formatted_truncate_text_with_token_count("short", budget) == ("short", None)


def test_formatted_truncate_text_with_token_count_reports_original_count() -> None:
    """Truncation reports the pre-truncation approximate token count."""
    content = "abcdefghi"

    result, original_token_count = formatted_truncate_text_with_token_count(content, 1)

    assert result.startswith("Total output lines: 1\n\n")
    assert "tokens truncated" in result
    assert original_token_count == approx_token_count(content)


def test_truncate_text_dispatches_byte_and_token_modes() -> None:
    """truncate_text routes to byte or token truncation based on the policy."""
    byte_result = truncate_text("abcdef", TruncationPolicy.bytes(4))
    token_result = truncate_text("abcdefghi", TruncationPolicy.tokens(1))

    assert byte_result.startswith("a")
    assert "tokens truncated" in token_result


def test_truncate_with_token_budget_handles_empty_and_short_content() -> None:
    """Empty or within-budget content is returned unchanged with no count."""
    one_token = TruncationPolicy.tokens(1)
    for content in ("", "abc"):
        assert truncate_with_token_budget(content, one_token) == (content, None)


def test_truncate_with_byte_estimate_handles_empty_zero_and_short_content() -> None:
    """Covers the empty-string, zero-budget, and within-budget branches."""
    assert truncate_with_byte_estimate("", TruncationPolicy.bytes(0)) == ""

    zero_budget_result = truncate_with_byte_estimate("abc", TruncationPolicy.bytes(0))
    assert "chars truncated" in zero_budget_result

    assert truncate_with_byte_estimate("abc", TruncationPolicy.bytes(10)) == "abc"


def test_split_string_preserves_utf8_boundaries() -> None:
    """Splitting never bisects a multi-byte UTF-8 character."""
    removed, head, tail = split_string("aあbいc", 2, 4)

    # The multi-byte "あ" falls on the cut and is removed whole along with "b".
    assert (removed, head, tail) == (2, "a", "いc")


def test_split_string_handles_empty_content() -> None:
    """An empty string splits into empty halves with nothing removed."""
    removed, head, tail = split_string("", 10, 10)

    assert removed == 0
    assert head == ""
    assert tail == ""


def test_formatting_and_estimate_helpers() -> None:
    """Smoke-checks the marker, budget-split, and token-estimate helpers."""
    byte_policy = TruncationPolicy.bytes(8)
    token_policy = TruncationPolicy.tokens(2)

    # Markers name the unit matching the policy mode.
    assert "chars truncated" in format_truncation_marker(byte_policy, 3)
    assert "tokens truncated" in format_truncation_marker(token_policy, 2)

    # An odd budget splits with the larger half second.
    assert split_budget(5) == (2, 3)

    # Byte policies report removed characters; token policies derive a token
    # count from the removed byte total.
    assert removed_units_for_source(byte_policy, removed_bytes=10, removed_chars=4) == 4
    assert removed_units_for_source(token_policy, removed_bytes=9, removed_chars=4) == 3

    assert approx_token_count("abcde") == 2
    assert approx_bytes_for_tokens(-1) == 0
    assert approx_tokens_from_byte_count(0) == 0
    assert approx_tokens_from_byte_count(5) == 2
Loading
Loading