Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cpex/framework/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import copy
import logging
import weakref
from collections.abc import Mapping
from typing import Any, Iterator, Optional, TypeVar

# Third-Party
Expand Down Expand Up @@ -173,6 +174,48 @@ def __repr__(self) -> str:
"""
return f"CopyOnWriteDict({dict(self.items())})"

# Instances are explicitly unhashable: hash() on one raises TypeError, as it does for dict.
__hash__ = None

def __eq__(self, other: Any) -> bool:
    """
    Report whether this mapping and ``other`` hold the same key-value pairs.

    The comparison is made on the logical view exposed by ``len(self)`` and
    ``self.items()`` rather than on the underlying base dict storage.

    Args:
        other: Candidate object to compare against.

    Returns:
        True when ``other`` is a Mapping with identical items, False when it is
        a Mapping that differs, and NotImplemented for any non-Mapping so that
        Python can try the reflected comparison on ``other``.
    """
    if isinstance(other, Mapping):
        # Cheap rejection on size before either side is materialized.
        same_size = len(self) == len(other)
        return same_size and dict(self.items()) == dict(other.items())
    return NotImplemented

def __ne__(self, other: Any) -> bool:
    """
    Report whether this mapping differs from ``other``.

    Args:
        other: Candidate object to compare against.

    Returns:
        The logical negation of ``self.__eq__(other)``. When ``__eq__`` declines
        the comparison, its NotImplemented result is passed through unchanged.
    """
    outcome = self.__eq__(other)
    return outcome if outcome is NotImplemented else not outcome

def get(self, key: Any, default: Optional[Any] = None) -> Any:
"""
Get an item with a default fallback.
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,10 @@ preview = true
fixable = ["ALL"]
unfixable = []

[tool.ruff.lint.pylint]
# Relaxed from the default of 5; existing code has wider try clauses (max observed 38).
max-statements-in-try = 50

# Ignore D1 (docstring checks) and Pylint checks in tests and other non-production code
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["D1", "PL"]
Expand Down
98 changes: 98 additions & 0 deletions tests/unit/cpex/framework/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,104 @@ def test_iter_skips_deleted_keys_in_modifications(self):
assert set(keys) == {"b", "c"}
assert "a" not in keys

def test_equality_with_empty_dict(self):
    """A populated CopyOnWriteDict is unequal to an empty dict from either side."""
    populated = CopyOnWriteDict({"a": 1, "b": 2})
    empty = {}
    # Exercise both != and ==, with the wrapper on the left and on the right.
    assert populated != empty
    assert empty != populated
    assert not (populated == empty)
    assert not (empty == populated)

def test_equality_with_matching_dict(self):
    """A CopyOnWriteDict equals a plain dict holding the same pairs, in both operand orders."""
    wrapped = CopyOnWriteDict({"a": 1, "b": 2, "c": 3})
    assert wrapped == {"a": 1, "b": 2, "c": 3}
    assert {"a": 1, "b": 2, "c": 3} == wrapped

def test_equality_with_different_dict(self):
    """A CopyOnWriteDict is unequal to dicts whose content differs in any way."""
    wrapped = CopyOnWriteDict({"a": 1, "b": 2})
    mismatches = (
        {"a": 1, "b": 3},  # same keys, different value
        {"a": 1},  # fewer keys
        {"a": 1, "b": 2, "c": 3},  # more keys
        {"a": 1, "c": 2},  # same length, different keys
    )
    for candidate in mismatches:
        assert wrapped != candidate

def test_equality_after_modifications(self):
    """Keys added after construction take part in the comparison."""
    wrapped = CopyOnWriteDict({"a": 1, "b": 2})
    wrapped["c"] = 3
    with_addition = {"a": 1, "b": 2, "c": 3}
    without_addition = {"a": 1, "b": 2}
    assert wrapped == with_addition
    assert wrapped != without_addition

def test_equality_after_deletions(self):
    """Keys deleted after construction no longer take part in the comparison."""
    wrapped = CopyOnWriteDict({"a": 1, "b": 2, "c": 3})
    del wrapped["b"]
    remaining = {"a": 1, "c": 3}
    initial = {"a": 1, "b": 2, "c": 3}
    assert wrapped == remaining
    assert wrapped != initial

def test_equality_after_override(self):
    """An overridden value, not the initial one, is what gets compared."""
    wrapped = CopyOnWriteDict({"a": 1, "b": 2})
    wrapped["a"] = 10
    overridden = {"a": 10, "b": 2}
    initial = {"a": 1, "b": 2}
    assert wrapped == overridden
    assert wrapped != initial

def test_equality_with_another_copyonwritedict(self):
    """Two independently built CopyOnWriteDicts with equal content compare equal both ways."""
    left = CopyOnWriteDict({"a": 1, "b": 2})
    right = CopyOnWriteDict({"a": 1, "b": 2})
    assert left == right
    assert right == left

def test_equality_empty_copyonwritedict(self):
    """An empty CopyOnWriteDict equals an empty dict, in both operand orders."""
    empty_wrapped = CopyOnWriteDict({})
    assert empty_wrapped == {}
    assert {} == empty_wrapped

def test_equality_with_non_mapping_returns_notimplemented(self):
    """Equality with non-Mapping types should return NotImplemented.

    Checks the dunder methods directly for the NotImplemented sentinel and
    then checks the operator-level outcome, where Python resolves the sentinel
    without raising.
    """
    cow = CopyOnWriteDict({"a": 1})
    non_mappings = ("not a dict", 123, ["a", "list"], None)

    for other in non_mappings:
        # Operators alone cannot prove this: `cow != other` would also be True
        # if __eq__ simply returned False, so inspect the sentinel itself.
        assert cow.__eq__(other) is NotImplemented
        assert cow.__ne__(other) is NotImplemented

    # These should not raise, Python will handle NotImplemented
    assert cow != "not a dict"
    assert cow != 123
    assert cow != ["a", "list"]
    assert cow != None  # noqa: E711 - deliberately exercises __ne__ rather than identity
    for other in non_mappings:
        assert not (cow == other)

def test_inequality_operator(self):
    """The != operator is True for differing dicts and False for a matching one."""
    wrapped = CopyOnWriteDict({"a": 1, "b": 2})
    for differing in ({}, {"a": 1}):
        assert wrapped != differing
    matching = {"a": 1, "b": 2}
    assert not (wrapped != matching)

def test_copyonwritedict_is_unhashable(self):
    """Hashing a CopyOnWriteDict raises TypeError, as it does for dict."""
    wrapped = CopyOnWriteDict({"a": 1})
    with pytest.raises(TypeError):
        hash(wrapped)

def test_equality_wxo_args_scenario(self):
    """Regression check for the WXO args bug: wrapped args compare by content."""
    expected = {
        "wxo_connection_id": "",
        "wxo_auth": "fake-token",
        "wxo_environment_id": "draft",
    }
    # Wrap a copy so the expected dict and the wrapper's source are distinct objects.
    wrapped = CopyOnWriteDict(dict(expected))

    # The bug report's failing assertions: a populated wrapper looked equal to {}.
    assert wrapped != {}
    assert {} != wrapped
    assert wrapped == expected


class TestCopyOnWriteFunction:
"""Test suite for copyonwrite() factory function."""
Expand Down
114 changes: 114 additions & 0 deletions tests/unit/cpex/framework/test_policies.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,65 @@ class PayloadWithModel(PluginPayload):
assert result is not None
assert result.nested.x == 99 # type: ignore[union-attr]

def test_copyonwritedict_args_empty_modification_preserved(self):
    """Regression: replacing CopyOnWriteDict args with {} must count as a change.

    The original payload carries populated args wrapped in a CopyOnWriteDict
    and the modified payload carries an empty dict. apply_policy has to keep
    that modification; it was dropped when CopyOnWriteDict({...}) == {}
    evaluated to True because only the empty base storage was compared.
    """
    from cpex.framework.memory import CopyOnWriteDict

    # What the plugin is handed: args wrapped in a CopyOnWriteDict.
    wrapped_args = CopyOnWriteDict({
        "wxo_connection_id": "",
        "wxo_auth": "fake-token",
        "wxo_environment_id": "draft",
    })
    before = SamplePayload(name="test", args=wrapped_args, secret="s")

    # What the plugin hands back: every arg stripped.
    after = SamplePayload(name="test", args={}, secret="s")

    args_writable = HookPayloadPolicy(writable_fields=frozenset({"args"}))
    outcome = apply_policy(before, after, args_writable)

    # The emptied args must survive; the other fields stay as they were.
    assert outcome is not None, "apply_policy should not return None when args changed from {...} to {}"
    assert outcome.args == {}  # type: ignore[union-attr]
    assert outcome.name == "test"  # type: ignore[union-attr]
    assert outcome.secret == "s"  # type: ignore[union-attr]

def test_copyonwritedict_args_partial_modification_preserved(self):
    """Removing only some keys from CopyOnWriteDict args is kept by apply_policy."""
    from cpex.framework.memory import CopyOnWriteDict

    wrapped_args = CopyOnWriteDict({
        "wxo_auth": "token",
        "real_arg": "value",
    })
    before = SamplePayload(name="test", args=wrapped_args, secret="s")

    # The plugin drops wxo_auth and leaves real_arg in place.
    after = SamplePayload(name="test", args={"real_arg": "value"}, secret="s")

    args_writable = HookPayloadPolicy(writable_fields=frozenset({"args"}))
    outcome = apply_policy(before, after, args_writable)

    assert outcome is not None
    assert outcome.args == {"real_arg": "value"}  # type: ignore[union-attr]


class TestPluginPayloadFrozen:
"""Tests for frozen PluginPayload base class."""
Expand Down Expand Up @@ -752,6 +811,61 @@ async def tool_pre_invoke(self, payload, context):
assert result.modified_payload.secret == "safe" # Policy filtered this out


@pytest.mark.asyncio
async def test_tool_pre_invoke_empty_args_modification_preserved_through_executor(self):
    """Regression check for the tool_pre_invoke path through PluginExecutor.

    The plugin is expected to see its args as a CopyOnWriteDict, removes every
    ``wxo_``-prefixed key, and returns a payload with args={}. The executor
    must report that emptied payload as a modification rather than treating it
    as "unchanged", and the caller's original payload must keep its args.
    """
    from cpex.framework.base import HookRef, Plugin, PluginRef
    from cpex.framework.hooks.policies import HookPayloadPolicy
    from cpex.framework.hooks.tools import ToolPreInvokePayload
    from cpex.framework.manager import PluginExecutor
    from cpex.framework.memory import CopyOnWriteDict
    from cpex.framework.models import GlobalContext, PluginConfig, PluginResult

    observed_types = []

    class StripWxoArgsPlugin(Plugin):
        async def tool_pre_invoke(self, payload, context):
            # Record the concrete type the plugin was handed for args.
            observed_types.append(type(payload.args))
            kept = {key: value for key, value in payload.args.items() if not key.startswith("wxo_")}
            stripped = payload.model_copy(update={"args": kept})
            return PluginResult(continue_processing=True, modified_payload=stripped)

    executor = PluginExecutor(
        hook_policies={
            "tool_pre_invoke": HookPayloadPolicy(writable_fields=frozenset({"args"})),
        }
    )

    plugin_config = PluginConfig(name="stripper", kind="test.Plugin", version="1.0", hooks=["tool_pre_invoke"])
    stripper = StripWxoArgsPlugin(plugin_config)
    hook_ref = HookRef("tool_pre_invoke", PluginRef(stripper))

    initial_args = {
        "wxo_connection_id": "",
        "wxo_auth": "fake-token",
        "wxo_environment_id": "draft",
    }
    # Hand the payload its own copy so the final check compares against an untouched dict.
    payload = ToolPreInvokePayload(name="list_all_secrets", args=dict(initial_args))
    global_ctx = GlobalContext(request_id="tool-pre-empty-args")

    result, _ = await executor.execute([hook_ref], payload, global_ctx, hook_type="tool_pre_invoke")

    assert observed_types == [CopyOnWriteDict]
    assert result.modified_payload is not None
    assert result.modified_payload == ToolPreInvokePayload(name="list_all_secrets", args={})
    # The caller's payload must be left exactly as it was submitted.
    assert payload.args == initial_args

class TestMultiPluginDictChain:
"""Tests for multi-plugin chains where an earlier plugin returns a dict payload."""

Expand Down
Loading