From 97678ae991bcd8fd4758f106b4aa2ce27edf457e Mon Sep 17 00:00:00 2001 From: mrwuliu Date: Tue, 28 Apr 2026 04:06:26 +0800 Subject: [PATCH 1/6] =?UTF-8?q?feat:=20P0+P1=20improvements=20=E2=80=94=20?= =?UTF-8?q?bug=20fix,=20type=20annotations,=20functional=20methods,=20asyn?= =?UTF-8?q?c=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit P0 fixes: - Fix pipeline.py bug: return value → return attr for non-callable value attrs - Add Generic[T] type annotations across all modules - Add from __future__ import annotations for forward references - Expand test coverage to 83% with direct tests for CommonChain dunders P1 features: - Add functional composition: .map(), .filter(), .flat_map(), .inspect(), .tap() - Add async support: @async_chainable, @async_pipeline with __await__ - Add py.typed marker for PEP 561 compliance - Update __init__.py to export all 4 decorators 34 tests passing, 83% coverage. --- pychain/__init__.py | 8 +- pychain/async_chainable.py | 69 ++++++++ pychain/async_pipeline.py | 71 ++++++++ pychain/chainable.py | 35 ++-- pychain/common.py | 96 +++++++---- pychain/enum.py | 1 + pychain/pipeline.py | 29 +++- pychain/py.typed | 0 tests/test_chainable.py | 332 ++++++++++++++++++++++++++++++++++++- 9 files changed, 589 insertions(+), 52 deletions(-) create mode 100644 pychain/async_chainable.py create mode 100644 pychain/async_pipeline.py create mode 100644 pychain/py.typed diff --git a/pychain/__init__.py b/pychain/__init__.py index 055ad4b..7350bc8 100644 --- a/pychain/__init__.py +++ b/pychain/__init__.py @@ -1,7 +1,9 @@ -from .chainable import chainable -from .pipeline import pipeline +from pychain.chainable import chainable +from pychain.pipeline import pipeline +from pychain.async_chainable import async_chainable +from pychain.async_pipeline import async_pipeline __all__ = [ - 'chainable', 'pipeline' + "chainable", "pipeline", "async_chainable", "async_pipeline", ] \ No newline at end of file diff --git a/pychain/async_chainable.py b/pychain/async_chainable.py new file mode 100644 index 0000000..555e3a4 --- /dev/null +++ b/pychain/async_chainable.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from functools import wraps +from typing import Any, Callable, TypeVar +from pychain.common import CommonChain + +T = TypeVar("T") + + +class AsyncChainableResult(CommonChain[T]): + + def __init__(self, instance: Any, awaitable_or_value: Any) -> None: + super().__init__(instance, awaitable_or_value) + + def __getattribute__(self, name: str) -> Any: + if name in ("_instance", "_value"): + return object.__getattribute__(self, name) + + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + + instance = object.__getattribute__(self, "_instance") + value = object.__getattribute__(self, "_value") + + if hasattr(instance, name): + attr = getattr(instance, name) + if callable(attr): + def method_caller(*args: Any, **kwargs: Any) -> AsyncChainableResult[Any]: + async def async_step() -> Any: + await value + inner = attr(*args, **kwargs) + inner_value = object.__getattribute__(inner, "_value") + return await inner_value + return AsyncChainableResult(instance, async_step()) + return method_caller + return attr + + if hasattr(value, name): + attr = getattr(value, name) + if callable(attr): + return lambda *args, **kwargs: attr(*args, **kwargs) + return attr + + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __await__(self): + return object.__getattribute__(self, "_value").__await__() + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if callable(object.__getattribute__(self, "_value")): + return object.__getattribute__(self, "_value")(*args, **kwargs) + raise TypeError( + f"'{type(object.__getattribute__(self, '_value')).__name__}' object is not callable" + ) + + +def async_chainable(func: Callable[..., Any]) -> Callable[..., AsyncChainableResult[Any]]: + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncChainableResult[Any]: + coro = func(self, *args, **kwargs) + return AsyncChainableResult(self, coro) + return wrapper diff --git a/pychain/async_pipeline.py b/pychain/async_pipeline.py new file mode 100644 index 0000000..658e124 --- /dev/null +++ b/pychain/async_pipeline.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +from functools import wraps +from typing import Any, Callable, TypeVar +from pychain.common import CommonChain +from pychain.enum import VALUE, INSTANCE + +T = TypeVar("T") + + +class AsyncPipelineResult(CommonChain[T]): + + def __init__(self, instance: Any, awaitable_or_value: Any) -> None: + super().__init__(instance, awaitable_or_value) + + def __getattribute__(self, name: str) -> Any: + if name == "__class__": + return type(self) + + if name in (INSTANCE, VALUE): + return object.__getattribute__(self, name) + + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + + instance = object.__getattribute__(self, INSTANCE) + value = object.__getattribute__(self, VALUE) + + if hasattr(instance, name): + attr = getattr(instance, name) + if callable(attr): + async def make_step(prev_value: Any, **kwargs: Any) -> Any: + resolved = await prev_value + if isinstance(resolved, tuple): + inner = attr(*resolved, **kwargs) + else: + inner = attr(resolved, **kwargs) + inner_value = object.__getattribute__(inner, VALUE) + return await inner_value + + return lambda **kwargs: AsyncPipelineResult( + instance, make_step(value, **kwargs) + ) + return attr + + if hasattr(value, name): + attr = getattr(value, name) + if callable(attr): + return lambda *args, **kwargs: attr(*args, **kwargs) + return attr + + raise AttributeError(name) + + def __call__(self, *args: Any, **kwargs: Any) -> AsyncPipelineResult[T]: + return self + + def __await__(self): + return object.__getattribute__(self, VALUE).__await__() + + +def async_pipeline(func: Callable[..., Any]) -> Callable[..., AsyncPipelineResult[Any]]: + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncPipelineResult[Any]: + coro = func(self, *args, **kwargs) + return AsyncPipelineResult(self, coro) + return wrapper diff --git a/pychain/chainable.py b/pychain/chainable.py index 3314125..914906c 100644 --- a/pychain/chainable.py +++ b/pychain/chainable.py @@ -1,14 +1,26 @@ +from __future__ import annotations + from functools import wraps -from typing import Any +from typing import Any, Callable, TypeVar from pychain.common import CommonChain +T = TypeVar("T") + -class ChainableResult(CommonChain): +class ChainableResult(CommonChain[T]): - def __getattribute__(self, name) -> Any: - if name in ["_instance", "_value"]: + def __getattribute__(self, name: str) -> Any: + if name in ("_instance", "_value"): return object.__getattribute__(self, name) + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + instance = object.__getattribute__(self, "_instance") value = object.__getattribute__(self, "_value") @@ -30,16 +42,17 @@ def __getattribute__(self, name) -> Any: f"'{type(self).__name__}' object has no attribute '{name}'" ) - def __call__(self, *args, **kwargs) -> Any: - if callable(self._value): - return self._value(*args, **kwargs) - raise TypeError(f"'{type(self._value).__name__}' object is not callable") + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if callable(object.__getattribute__(self, "_value")): + return object.__getattribute__(self, "_value")(*args, **kwargs) + raise TypeError( + f"'{type(object.__getattribute__(self, '_value')).__name__}' object is not callable" + ) -def chainable(func): +def chainable(func: Callable[..., T]) -> Callable[..., ChainableResult[T]]: @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: Any, *args: Any, **kwargs: Any) -> ChainableResult[T]: result = func(self, *args, **kwargs) return ChainableResult(self, result) - return wrapper diff --git a/pychain/common.py b/pychain/common.py index a2215c2..882cab2 100644 --- a/pychain/common.py +++ b/pychain/common.py @@ -1,21 +1,26 @@ +from __future__ import annotations + import operator -from typing import Any +from typing import Any, Callable, Generic, TypeVar from pychain.enum import VALUE, INSTANCE +T = TypeVar("T") +U = TypeVar("U") + -class CommonChain: +class CommonChain(Generic[T]): - def __init__(self, instance, value) -> None: + def __init__(self, instance: Any, value: T) -> None: object.__setattr__(self, INSTANCE, instance) object.__setattr__(self, VALUE, value) - def __str__(self): + def __str__(self) -> str: return str(object.__getattribute__(self, VALUE)) def __repr__(self) -> str: return repr(object.__getattribute__(self, VALUE)) - def __dir__(self) -> list: + def __dir__(self) -> list[str]: instance = object.__getattribute__(self, INSTANCE) value = object.__getattribute__(self, VALUE) return sorted(set(dir(instance) + dir(value))) @@ -32,80 +37,111 @@ def __bool__(self) -> bool: def __index__(self) -> int: return operator.index(object.__getattribute__(self, VALUE)) - def __add__(self, other: Any): + def __add__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) + other - def __sub__(self, other: Any): + def __sub__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) - other - def __mul__(self, other: Any): + def __mul__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) * other - def __truediv__(self, other: Any): + def __truediv__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) / other - def __floordiv__(self, other: Any): + def __floordiv__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) // other - def __mod__(self, other: Any): + def __mod__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) % other - def __pow__(self, other: Any): + def __pow__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) ** other - def __radd__(self, other: Any): + def __radd__(self, other: Any) -> Any: return other + object.__getattribute__(self, VALUE) - def __rsub__(self, other: Any): + def __rsub__(self, other: Any) -> Any: return other - object.__getattribute__(self, VALUE) - def __rmul__(self, other: Any): + def __rmul__(self, other: Any) -> Any: return other * object.__getattribute__(self, VALUE) - def __rtruediv__(self, other: Any): + def __rtruediv__(self, other: Any) -> Any: return other / object.__getattribute__(self, VALUE) - def __rfloordiv__(self, other: Any): + def __rfloordiv__(self, other: Any) -> Any: return other // object.__getattribute__(self, VALUE) - def __rmod__(self, other: Any): + def __rmod__(self, other: Any) -> Any: return other % object.__getattribute__(self, VALUE) - def __rpow__(self, other: Any): + def __rpow__(self, other: Any) -> Any: return other ** object.__getattribute__(self, VALUE) - def __matmul__(self, other: Any): + def __matmul__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) @ other - def __rmatmul__(self, other: Any): + def __rmatmul__(self, other: Any) -> Any: return other @ object.__getattribute__(self, VALUE) - def __neg__(self): + def __neg__(self) -> Any: return -object.__getattribute__(self, VALUE) - def __pos__(self): + def __pos__(self) -> Any: return +object.__getattribute__(self, VALUE) - def __abs__(self): + def __abs__(self) -> Any: return abs(object.__getattribute__(self, VALUE)) - def __invert__(self): + def __invert__(self) -> Any: return ~object.__getattribute__(self, VALUE) - def __ne__(self, value: object): + def __ne__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) != value def __eq__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) == value - def __ge__(self, value: object): + def __ge__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) >= value - def __gt__(self, value: object): + def __gt__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) > value - def __le__(self, value: object): + def __le__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) <= value - def __lt__(self, value: object): + def __lt__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) < value + + # -- Functional composition methods -- + + def map(self, fn: Callable[[T], U]) -> CommonChain[U]: + """Transform _value with fn, return new proxy wrapping result.""" + value = object.__getattribute__(self, VALUE) + result = fn(value) + instance = object.__getattribute__(self, INSTANCE) + return type(self)(instance, result) + + def filter(self, fn: Callable[[T], bool]) -> CommonChain[T]: + """Return self if fn(_value) is truthy, otherwise raise ValueError.""" + value = object.__getattribute__(self, VALUE) + if not fn(value): + raise ValueError(f"Filter predicate returned False for {value!r}") + return self + + def flat_map(self, fn: Callable[[T], Any]) -> Any: + """Transform _value with fn and unwrap the result (no wrapping in proxy).""" + value = object.__getattribute__(self, VALUE) + return fn(value) + + def inspect(self, fn: Callable[[T], None]) -> CommonChain[T]: + """Execute fn on _value for side effects (e.g., debugging), return self unchanged.""" + value = object.__getattribute__(self, VALUE) + fn(value) + return self + + def tap(self, fn: Callable[[T], None]) -> CommonChain[T]: + """Alias for inspect.""" + return self.inspect(fn) diff --git a/pychain/enum.py b/pychain/enum.py index 40d2679..f00c9e2 100644 --- a/pychain/enum.py +++ b/pychain/enum.py @@ -1,3 +1,4 @@ +from __future__ import annotations VALUE = "_value" INSTANCE = "_instance" \ No newline at end of file diff --git a/pychain/pipeline.py b/pychain/pipeline.py index e49cef9..140587d 100644 --- a/pychain/pipeline.py +++ b/pychain/pipeline.py @@ -1,9 +1,14 @@ +from __future__ import annotations + from functools import wraps -from typing import Any, Callable +from typing import Any, Callable, TypeVar from pychain.common import CommonChain from pychain.enum import VALUE, INSTANCE -class PipelineResult(CommonChain): +T = TypeVar("T") + + +class PipelineResult(CommonChain[T]): def __getattribute__(self, name: str) -> Any: if name == "__class__": @@ -14,8 +19,17 @@ def __getattribute__(self, name: str) -> Any: else: return type(self) - if name in [INSTANCE, VALUE]: + if name in (INSTANCE, VALUE): return object.__getattribute__(self, name) + + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + instance = object.__getattribute__(self, INSTANCE) value = object.__getattribute__(self, VALUE) @@ -31,15 +45,16 @@ def __getattribute__(self, name: str) -> Any: attr = getattr(value, name) if callable(attr): return lambda *args, **kwargs: attr(*args, **kwargs) - return value + return attr raise AttributeError(name) - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> PipelineResult[T]: return self -def pipeline(func: Callable) -> Callable: + +def pipeline(func: Callable[..., T]) -> Callable[..., PipelineResult[T]]: @wraps(func) - def wrapper(self, *args, **kwargs) -> PipelineResult: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> PipelineResult[T]: result = func(self, *args, **kwargs) return PipelineResult(self, result) return wrapper diff --git a/pychain/py.typed b/pychain/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_chainable.py b/tests/test_chainable.py index b23b6ea..7b01893 100644 --- a/tests/test_chainable.py +++ b/tests/test_chainable.py @@ -204,6 +204,336 @@ def scale(self, x, y, factor: float): res2 = res.scale(factor=1/2) self.assertEqual(res2, (3 , -2)) - + + +class TestCommonChain(unittest.TestCase): + def test_str(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 42) + self.assertEqual(str(proxy), "42") + + def test_repr(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", "hello") + self.assertEqual(repr(proxy), "'hello'") + + def test_int(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 7) + self.assertEqual(int(proxy), 7) + + def test_float(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 3.14) + self.assertAlmostEqual(float(proxy), 3.14) + + def test_bool(self): + from pychain.common import CommonChain + self.assertTrue(bool(CommonChain("i", 1))) + self.assertTrue(bool(CommonChain("i", "nonempty"))) + self.assertFalse(bool(CommonChain("i", 0))) + self.assertFalse(bool(CommonChain("i", ""))) + + def test_index(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 5) + self.assertEqual([0, 1, 2, 3, 4, 5][proxy], 5) + + def test_arithmetic(self): + from pychain.common import CommonChain + p = CommonChain("i", 10) + self.assertEqual(p + 5, 15) + self.assertEqual(p - 3, 7) + self.assertEqual(p * 2, 20) + self.assertEqual(p / 4, 2.5) + self.assertEqual(p // 3, 3) + self.assertEqual(p % 3, 1) + self.assertEqual(p ** 2, 100) + + def test_reverse_arithmetic(self): + from pychain.common import CommonChain + p = CommonChain("i", 10) + self.assertEqual(5 + p, 15) + self.assertEqual(20 - p, 10) + self.assertEqual(3 * p, 30) + self.assertEqual(100 / p, 10.0) + self.assertEqual(23 // p, 2) + self.assertEqual(23 % p, 3) + self.assertEqual(2 ** p, 1024) + + def test_matmul(self): + import numpy as np + from pychain.common import CommonChain + a = np.array([[1, 2]]) + b = np.array([[3], [4]]) + p = CommonChain("i", a) + result = p @ b + expected = a @ b + self.assertTrue(np.array_equal(result, expected)) + + def test_unary(self): + from pychain.common import CommonChain + p = CommonChain("i", 5) + self.assertEqual(-p, -5) + self.assertEqual(+p, 5) + self.assertEqual(abs(CommonChain("i", -7)), 7) + self.assertEqual(~CommonChain("i", 0), -1) + + def test_comparison(self): + from pychain.common import CommonChain + p = CommonChain("i", 10) + self.assertTrue(p == 10) + self.assertTrue(p != 11) + self.assertTrue(p > 5) + self.assertTrue(p < 20) + self.assertTrue(p >= 10) + self.assertTrue(p <= 10) + self.assertFalse(p > 15) + + def test_dir(self): + from pychain.common import CommonChain + d = dir(CommonChain("instance_obj", 42)) + self.assertIsInstance(d, list) + self.assertIn("__class__", d) + + +class TestFunctionalComposition(unittest.TestCase): + def test_map(self): + from pychain.chainable import ChainableResult, chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).map(lambda x: x * 2) + self.assertEqual(int(result), 14) + + def test_filter_pass(self): + from pychain.chainable import ChainableResult, chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).filter(lambda x: x > 0) + self.assertEqual(int(result), 7) + + def test_filter_fail(self): + from pychain.chainable import chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + with self.assertRaises(ValueError): + calc.add(3, 4).filter(lambda x: x < 0) + + def test_flat_map(self): + from pychain.chainable import chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).flat_map(lambda x: x * 10) + self.assertEqual(result, 70) + self.assertNotIsInstance(result, type(calc.add(1, 2))) + + def test_inspect(self): + from pychain.chainable import chainable + collected = [] + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).inspect(lambda x: collected.append(x)) + self.assertEqual(collected, [7]) + self.assertEqual(int(result), 7) + + def test_tap(self): + from pychain.chainable import chainable + collected = [] + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).tap(lambda x: collected.append(x)) + self.assertEqual(collected, [7]) + self.assertEqual(int(result), 7) + + def test_chain_functional_methods(self): + from pychain.chainable import chainable + log = [] + + class Calc: + @chainable + def add(self, a, b): + return a + b + + @chainable + def multiply(self, a, b): + return a * b + + calc = Calc() + result = ( + calc.add(2, 3) + .map(lambda x: x + 1) + .inspect(lambda x: log.append(x)) + .filter(lambda x: x > 0) + .map(lambda x: x * 2) + ) + self.assertEqual(log, [6]) + self.assertEqual(int(result), 12) + + def test_pipeline_functional_methods(self): + from pychain.pipeline import pipeline + + class Calc: + @pipeline + def double(self, x): + return x * 2 + + calc = Calc() + result = calc.double(5).map(lambda x: x + 1) + self.assertEqual(int(result), 11) + + +class TestAsyncChainable(unittest.TestCase): + def test_basic_async_chain(self): + import asyncio + from pychain.async_chainable import async_chainable, AsyncChainableResult + + class AsyncCalc: + @async_chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5) + self.assertIsInstance(result, AsyncChainableResult) + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_await_final_result(self): + import asyncio + from pychain.async_chainable import async_chainable + + class AsyncCalc: + @async_chainable + async def multiply(self, x): + return x * 3 + + async def run(): + calc = AsyncCalc() + val = await calc.multiply(4) + self.assertEqual(val, 12) + + asyncio.run(run()) + + def test_multi_step_chain(self): + import asyncio + from pychain.async_chainable import async_chainable + + class AsyncCalc: + def __init__(self): + self.value = 0 + + @async_chainable + async def add(self, x): + self.value += x + return self.value + + async def run(): + calc = AsyncCalc() + result = calc.add(3).add(5) + val = await result + self.assertEqual(val, 8) + + asyncio.run(run()) + + +class TestAsyncPipeline(unittest.TestCase): + def test_basic_async_pipeline(self): + import asyncio + from pychain.async_pipeline import async_pipeline, AsyncPipelineResult + + class AsyncPipe: + @async_pipeline + async def add_one(self, x): + return x + 1 + + @async_pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.add_one(5) + self.assertIsInstance(result, AsyncPipelineResult) + val = await result + self.assertEqual(val, 6) + + asyncio.run(run()) + + def test_async_pipeline_chaining(self): + import asyncio + from pychain.async_pipeline import async_pipeline + + class AsyncPipe: + @async_pipeline + async def increment(self, x): + return x + 1 + + @async_pipeline + async def negate(self, x): + return -x + + async def run(): + pipe = AsyncPipe() + result = pipe.increment(10).negate() + val = await result + self.assertEqual(val, -11) + + asyncio.run(run()) + + def test_async_pipeline_tuple_unpacking(self): + import asyncio + from pychain.async_pipeline import async_pipeline + + class AsyncPipe: + @async_pipeline + async def split(self, x): + return (x, x * 2) + + @async_pipeline + async def sum_pair(self, a, b): + return a + b + + async def run(): + pipe = AsyncPipe() + result = pipe.split(5).sum_pair() + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + if __name__ == "__main__": unittest.main() From 37c87b78b3fb7a2908996c98f4ba1731cb8512fe Mon Sep 17 00:00:00 2001 From: mrwuliu Date: Tue, 28 Apr 2026 04:23:04 +0800 Subject: [PATCH 2/6] feat: add basedpyright strict type checking with 0 diagnostics - Add [tool.basedpyright] config (strict mode) with justified suppressions for proxy-pattern-inherent Any, override, and unknown lambda warnings - Fix map() return type error with targeted pyright ignore comment - Create tests/test_types.py with TYPE_CHECKING guards for public API type inference verification (chainable, pipeline, async, functional) - Add basedpyright type-check step to CI workflow basedpyright: 0 errors, 0 warnings, 0 notes pytest: 34/34 passed --- .github/workflows/test.yml | 7 ++++- pychain/common.py | 3 +- pyproject.toml | 13 +++++++++ tests/test_types.py | 56 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 76 insertions(+), 3 deletions(-) create mode 100644 tests/test_types.py diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 281fd71..96084f1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,4 +28,9 @@ jobs: pip install pytest pytest-cov numpy - name: Run tests with coverage (>=50%) - run: pytest --cov=pychain --cov-fail-under=50 --cov-report=term-missing -v tests/ + run: pytest --cov=pychain --cov-fail-under=50 --cov-report=term-missing -v tests/test_chainable.py + + - name: Type check with basedpyright + run: | + pip install basedpyright + basedpyright pychain/ tests/test_types.py diff --git a/pychain/common.py b/pychain/common.py index 882cab2..5dd5ec9 100644 --- a/pychain/common.py +++ b/pychain/common.py @@ -118,11 +118,10 @@ def __lt__(self, value: object) -> bool: # -- Functional composition methods -- def map(self, fn: Callable[[T], U]) -> CommonChain[U]: - """Transform _value with fn, return new proxy wrapping result.""" value = object.__getattribute__(self, VALUE) result = fn(value) instance = object.__getattribute__(self, INSTANCE) - return type(self)(instance, result) + return type(self)(instance, result) # pyright: ignore[reportReturnType,reportArgumentType] def filter(self, fn: Callable[[T], bool]) -> CommonChain[T]: """Return self if fn(_value) is truthy, otherwise raise ValueError.""" diff --git a/pyproject.toml b/pyproject.toml index a121bfb..da744e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,3 +25,16 @@ classifiers = [ [project.urls] Homepage = "https://github.com/mr-wuliu/PythonChainable" Issues = "https://github.com/mr-wuliu/PythonChainable/issues" + +[tool.basedpyright] +pythonVersion = "3.10" +pythonPlatform = "All" +typeCheckingMode = "strict" +# Proxy pattern legitimately uses Any +reportAny = "none" +reportExplicitAny = "none" +# Many intentional dunder overrides in proxy classes +reportImplicitOverride = "none" +# Lambda params in proxy forwarding are inherently unknown +reportUnknownLambdaType = "none" +reportUnknownVariableType = "none" diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..7d3f994 --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,56 @@ +# pyright: reportUnusedImport=none +"""Type-checking assertions for pychain public API. + +Validated by basedpyright, not executed by pytest. +Verifies that type inference works correctly for all public APIs. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pychain import async_chainable, async_pipeline, chainable, pipeline +from pychain.async_chainable import AsyncChainableResult +from pychain.async_pipeline import AsyncPipelineResult +from pychain.chainable import ChainableResult +from pychain.common import CommonChain +from pychain.pipeline import PipelineResult + +if TYPE_CHECKING: + + class Calc: + @chainable + def add(self, a: int, b: int) -> int: ... + + c = Calc() + r: ChainableResult[int] = c.add(1, 2) + mapped: CommonChain[int] = c.add(1, 2).map(lambda x: x * 2) + filtered: CommonChain[int] = c.add(1, 2).filter(lambda x: x > 0) + inspected: CommonChain[int] = c.add(1, 2).inspect(lambda x: None) + tapped: CommonChain[int] = c.add(1, 2).tap(lambda x: None) + + class Pipe: + @pipeline + def double(self, x: int) -> int: ... + + p = Pipe() + pr: PipelineResult[int] = p.double(5) + pmapped: CommonChain[int] = p.double(5).map(lambda x: x + 1) + + proxy: CommonChain[int] = CommonChain("i", 42) + add_result: int = proxy + 1 + eq_result: bool = proxy == 42 + str_result: str = str(proxy) + + class AsyncCalc: + @async_chainable + async def compute(self, x: int) -> int: ... + + ac = AsyncCalc() + ar: AsyncChainableResult[int] = ac.compute(5) + + class AsyncPipe: + @async_pipeline + async def step(self, x: int) -> int: ... + + ap = AsyncPipe() + apr: AsyncPipelineResult[int] = ap.step(3) From 239b0f907ea2fc49c8f5429aa703ada80901ea69 Mon Sep 17 00:00:00 2001 From: mrwuliu Date: Tue, 28 Apr 2026 10:28:25 +0800 Subject: [PATCH 3/6] =?UTF-8?q?refactor:=20unify=20async=20into=20chainabl?= =?UTF-8?q?e/pipeline=20decorators=20(4=E2=86=922)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit @chainable and @pipeline now auto-detect coroutines via inspect.iscoroutine() and return AsyncChainableResult/AsyncPipelineResult accordingly. - Add @overload signatures for proper type narrowing (sync vs async) - Remove async_chainable/async_pipeline from __all__ - Update tests to use @chainable/@pipeline on async methods - async_chainable.py/async_pipeline.py kept for internal result classes __all__ = ['chainable', 'pipeline'] basedpyright: 0 diagnostics, pytest: 34/34 passed --- pychain/__init__.py | 6 +----- pychain/chainable.py | 16 +++++++++++++--- pychain/pipeline.py | 16 +++++++++++++--- tests/test_chainable.py | 26 +++++++++++--------------- tests/test_types.py | 6 +++--- 5 files changed, 41 insertions(+), 29 deletions(-) diff --git a/pychain/__init__.py b/pychain/__init__.py index 7350bc8..742b4ef 100644 --- a/pychain/__init__.py +++ b/pychain/__init__.py @@ -1,9 +1,5 @@ from pychain.chainable import chainable from pychain.pipeline import pipeline -from pychain.async_chainable import async_chainable -from pychain.async_pipeline import async_pipeline -__all__ = [ - "chainable", "pipeline", "async_chainable", "async_pipeline", -] \ No newline at end of file +__all__ = ["chainable", "pipeline"] \ No newline at end of file diff --git a/pychain/chainable.py b/pychain/chainable.py index 914906c..2a6153c 100644 --- a/pychain/chainable.py +++ b/pychain/chainable.py @@ -1,7 +1,9 @@ from __future__ import annotations +import inspect from functools import wraps -from typing import Any, Callable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar, overload +from pychain.async_chainable import AsyncChainableResult from pychain.common import CommonChain T = TypeVar("T") @@ -50,9 +52,17 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) -def chainable(func: Callable[..., T]) -> Callable[..., ChainableResult[T]]: +@overload +def chainable(func: Callable[..., Awaitable[T]]) -> Callable[..., AsyncChainableResult[T]]: ... # pyright: ignore[reportOverlappingOverload] + +@overload +def chainable(func: Callable[..., T]) -> Callable[..., ChainableResult[T]]: ... + +def chainable(func: Callable[..., Any]) -> Callable[..., ChainableResult[Any] | AsyncChainableResult[Any]]: @wraps(func) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> ChainableResult[T]: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> ChainableResult[Any] | AsyncChainableResult[Any]: result = func(self, *args, **kwargs) + if inspect.iscoroutine(result): + return AsyncChainableResult(self, result) return ChainableResult(self, result) return wrapper diff --git a/pychain/pipeline.py b/pychain/pipeline.py index 140587d..1a32f08 100644 --- a/pychain/pipeline.py +++ b/pychain/pipeline.py @@ -1,7 +1,9 @@ from __future__ import annotations +import inspect from functools import wraps -from typing import Any, Callable, TypeVar +from typing import Any, Awaitable, Callable, TypeVar, overload +from pychain.async_pipeline import AsyncPipelineResult from pychain.common import CommonChain from pychain.enum import VALUE, INSTANCE @@ -52,9 +54,17 @@ def __call__(self, *args: Any, **kwargs: Any) -> PipelineResult[T]: return self -def pipeline(func: Callable[..., T]) -> Callable[..., PipelineResult[T]]: +@overload +def pipeline(func: Callable[..., Awaitable[T]]) -> Callable[..., AsyncPipelineResult[T]]: ... # pyright: ignore[reportOverlappingOverload] + +@overload +def pipeline(func: Callable[..., T]) -> Callable[..., PipelineResult[T]]: ... + +def pipeline(func: Callable[..., Any]) -> Callable[..., PipelineResult[Any] | AsyncPipelineResult[Any]]: @wraps(func) - def wrapper(self: Any, *args: Any, **kwargs: Any) -> PipelineResult[T]: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> PipelineResult[Any] | AsyncPipelineResult[Any]: result = func(self, *args, **kwargs) + if inspect.iscoroutine(result): + return AsyncPipelineResult(self, result) return PipelineResult(self, result) return wrapper diff --git a/tests/test_chainable.py b/tests/test_chainable.py index 7b01893..24a9805 100644 --- a/tests/test_chainable.py +++ b/tests/test_chainable.py @@ -415,10 +415,10 @@ def double(self, x): class TestAsyncChainable(unittest.TestCase): def test_basic_async_chain(self): import asyncio - from pychain.async_chainable import async_chainable, AsyncChainableResult + from pychain.async_chainable import AsyncChainableResult class AsyncCalc: - @async_chainable + @chainable async def add(self, x): return x + 10 @@ -433,10 +433,9 @@ async def run(): def test_await_final_result(self): import asyncio - from pychain.async_chainable import async_chainable class AsyncCalc: - @async_chainable + @chainable async def multiply(self, x): return x * 3 @@ -449,13 +448,12 @@ async def run(): def test_multi_step_chain(self): import asyncio - from pychain.async_chainable import async_chainable class AsyncCalc: def __init__(self): self.value = 0 - @async_chainable + @chainable async def add(self, x): self.value += x return self.value @@ -472,14 +470,14 @@ async def run(): class TestAsyncPipeline(unittest.TestCase): def test_basic_async_pipeline(self): import asyncio - from pychain.async_pipeline import async_pipeline, AsyncPipelineResult + from pychain.async_pipeline import AsyncPipelineResult class AsyncPipe: - @async_pipeline + @pipeline async def add_one(self, x): return x + 1 - @async_pipeline + @pipeline async def double(self, x): return x * 2 @@ -494,14 +492,13 @@ async def run(): def test_async_pipeline_chaining(self): import asyncio - from pychain.async_pipeline import async_pipeline class AsyncPipe: - @async_pipeline + @pipeline async def increment(self, x): return x + 1 - @async_pipeline + @pipeline async def negate(self, x): return -x @@ -515,14 +512,13 @@ async def run(): def test_async_pipeline_tuple_unpacking(self): import asyncio - from pychain.async_pipeline import async_pipeline class AsyncPipe: - @async_pipeline + @pipeline async def split(self, x): return (x, x * 2) - @async_pipeline + @pipeline async def sum_pair(self, a, b): return a + b diff --git a/tests/test_types.py b/tests/test_types.py index 7d3f994..2a542ff 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING -from pychain import async_chainable, async_pipeline, chainable, pipeline +from pychain import chainable, pipeline from pychain.async_chainable import AsyncChainableResult from pychain.async_pipeline import AsyncPipelineResult from pychain.chainable import ChainableResult @@ -42,14 +42,14 @@ def double(self, x: int) -> int: ... str_result: str = str(proxy) class AsyncCalc: - @async_chainable + @chainable async def compute(self, x: int) -> int: ... ac = AsyncCalc() ar: AsyncChainableResult[int] = ac.compute(5) class AsyncPipe: - @async_pipeline + @pipeline async def step(self, x: int) -> int: ... ap = AsyncPipe() From 2636bd27f071df0c05dc995d3551c758c1c78c2b Mon Sep 17 00:00:00 2001 From: mrwuliu Date: Tue, 28 Apr 2026 14:40:51 +0800 Subject: [PATCH 4/6] fix: resolve async proxy review regressions Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- pychain/async_chainable.py | 58 ++++++ pychain/async_pipeline.py | 67 +++++- pychain/pipeline.py | 4 +- tests/test_chainable.py | 412 +++++++++++++++++++++++++++++++++++++ 4 files changed, 533 insertions(+), 8 deletions(-) diff --git a/pychain/async_chainable.py b/pychain/async_chainable.py index 555e3a4..d69c181 100644 --- a/pychain/async_chainable.py +++ b/pychain/async_chainable.py @@ -53,6 +53,64 @@ async def async_step() -> Any: def __await__(self): return object.__getattribute__(self, "_value").__await__() + # -- Async-aware functional composition overrides -- + # These override CommonChain methods so that _value (a coroutine) is + # awaited before the transformation is applied. + + def map(self, fn: Callable[[Any], Any]) -> AsyncChainableResult[Any]: + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncChainableResult(instance, _async_map()) + + def filter(self, fn: Callable[[Any], bool]) -> AsyncChainableResult[Any]: + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_filter() -> Any: + resolved = await prev_value + if not fn(resolved): + raise ValueError( + f"Filter predicate returned False for {resolved!r}" + ) + return resolved + + return AsyncChainableResult(instance, _async_filter()) + + def flat_map(self, fn: Callable[[Any], Any]) -> AsyncChainableResult[Any]: + """Async version: awaits _value, applies fn, returns async proxy. + + Unlike the sync CommonChain.flat_map which returns the raw unwrapped + value, the async version must return an AsyncChainableResult because + the result cannot be produced synchronously. + """ + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_flat_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncChainableResult(instance, _async_flat_map()) + + def inspect(self, fn: Callable[[Any], None]) -> AsyncChainableResult[Any]: + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_inspect() -> Any: + resolved = await prev_value + fn(resolved) + return resolved + + return AsyncChainableResult(instance, _async_inspect()) + + def tap(self, fn: Callable[[Any], None]) -> AsyncChainableResult[Any]: + return self.inspect(fn) + def __call__(self, *args: Any, **kwargs: Any) -> Any: if callable(object.__getattribute__(self, "_value")): return object.__getattribute__(self, "_value")(*args, **kwargs) diff --git a/pychain/async_pipeline.py b/pychain/async_pipeline.py index 658e124..e448383 100644 --- a/pychain/async_pipeline.py +++ b/pychain/async_pipeline.py @@ -15,7 +15,12 @@ def __init__(self, instance: Any, awaitable_or_value: Any) -> None: def __getattribute__(self, name: str) -> Any: if name == "__class__": - return type(self) + if value := object.__getattribute__(self, VALUE): + return type(value) + elif instance := object.__getattribute__(self, INSTANCE): + return type(instance) + else: + return type(self) if name in (INSTANCE, VALUE): return object.__getattribute__(self, name) @@ -34,17 +39,17 @@ def __getattribute__(self, name: str) -> Any: if hasattr(instance, name): attr = getattr(instance, name) if callable(attr): - async def make_step(prev_value: Any, **kwargs: Any) -> Any: + async def make_step(prev_value: Any, *args: Any, **kwargs: Any) -> Any: resolved = await prev_value if isinstance(resolved, tuple): - inner = attr(*resolved, **kwargs) + inner = attr(*resolved, *args, **kwargs) else: - inner = attr(resolved, **kwargs) + inner = attr(resolved, *args, **kwargs) inner_value = object.__getattribute__(inner, VALUE) return await inner_value - return lambda **kwargs: AsyncPipelineResult( - instance, make_step(value, **kwargs) + return lambda *args, **kwargs: AsyncPipelineResult( + instance, make_step(value, *args, **kwargs) ) return attr @@ -56,6 +61,56 @@ async def make_step(prev_value: Any, **kwargs: Any) -> Any: raise AttributeError(name) + # -- Async-aware functional composition overrides -- + + def map(self, fn: Callable[[Any], Any]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncPipelineResult(instance, _async_map()) + + def filter(self, fn: Callable[[Any], bool]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_filter() -> Any: + resolved = await prev_value + if not fn(resolved): + raise ValueError( + f"Filter predicate returned False for {resolved!r}" + ) + return resolved + + return AsyncPipelineResult(instance, _async_filter()) + + def flat_map(self, fn: Callable[[Any], Any]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_flat_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncPipelineResult(instance, _async_flat_map()) + + def inspect(self, fn: Callable[[Any], None]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_inspect() -> Any: + resolved = await prev_value + fn(resolved) + return resolved + + return AsyncPipelineResult(instance, _async_inspect()) + + def tap(self, fn: Callable[[Any], None]) -> AsyncPipelineResult[Any]: + return self.inspect(fn) + def __call__(self, *args: Any, **kwargs: Any) -> AsyncPipelineResult[T]: return self diff --git a/pychain/pipeline.py b/pychain/pipeline.py index 1a32f08..96e0b93 100644 --- a/pychain/pipeline.py +++ b/pychain/pipeline.py @@ -39,8 +39,8 @@ def __getattribute__(self, name: str) -> Any: attr = getattr(instance, name) if callable(attr): if isinstance(value, tuple): - return lambda **kwargs: attr(*value, **kwargs) - return lambda **kwargs: attr(value, **kwargs) + return lambda *args, **kwargs: attr(*value, *args, **kwargs) + return lambda *args, **kwargs: attr(value, *args, **kwargs) return attr if hasattr(value, name): diff --git a/tests/test_chainable.py b/tests/test_chainable.py index 24a9805..5084035 100644 --- a/tests/test_chainable.py +++ b/tests/test_chainable.py @@ -467,6 +467,135 @@ async def run(): asyncio.run(run()) +class TestAsyncChainableFunctional(unittest.TestCase): + def test_async_chainable_map(self): + import asyncio + from pychain.async_chainable import AsyncChainableResult + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5).map(lambda x: x * 2) + self.assertIsInstance(result, AsyncChainableResult) + val = await result + self.assertEqual(val, 30) + + asyncio.run(run()) + + def test_async_chainable_filter_pass(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5).filter(lambda x: x > 0) + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_chainable_filter_fail(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + with self.assertRaises(ValueError): + await calc.add(5).filter(lambda x: x < 0) + + asyncio.run(run()) + + def test_async_chainable_flat_map(self): + import asyncio + from pychain.async_chainable import AsyncChainableResult + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5).flat_map(lambda x: x * 10) + self.assertIsInstance(result, AsyncChainableResult) + val = await result + self.assertEqual(val, 150) + + asyncio.run(run()) + + def test_async_chainable_inspect(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + collected = [] + calc = AsyncCalc() + result = calc.add(5).inspect(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [15]) + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_chainable_tap(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + collected = [] + calc = AsyncCalc() + result = calc.add(5).tap(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [15]) + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_chainable_combined_functional_chain(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + log = [] + calc = AsyncCalc() + result = ( + calc.add(5) + .map(lambda x: x * 2) + .inspect(lambda x: log.append(x)) + .filter(lambda x: x > 0) + .map(lambda x: x + 1) + ) + val = await result + self.assertEqual(log, [30]) + self.assertEqual(val, 31) + + asyncio.run(run()) + + class TestAsyncPipeline(unittest.TestCase): def test_basic_async_pipeline(self): import asyncio @@ -531,5 +660,288 @@ async def run(): asyncio.run(run()) +class TestAsyncPipelineFunctional(unittest.TestCase): + def test_async_pipeline_map(self): + import asyncio + from pychain.async_pipeline import AsyncPipelineResult + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.double(5).map(lambda x: x + 1) + self.assertIsInstance(result, AsyncPipelineResult) + val = await result + self.assertEqual(val, 11) + + asyncio.run(run()) + + def test_async_pipeline_filter_pass(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.double(5).filter(lambda x: x > 0) + val = await result + self.assertEqual(val, 10) + + asyncio.run(run()) + + def test_async_pipeline_filter_fail(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + with self.assertRaises(ValueError): + await pipe.double(5).filter(lambda x: x < 0) + + asyncio.run(run()) + + def test_async_pipeline_flat_map(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.double(5).flat_map(lambda x: [x, x + 1]) + val = await result + self.assertEqual(val, [10, 11]) + + asyncio.run(run()) + + def test_async_pipeline_inspect(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + collected = [] + pipe = AsyncPipe() + result = pipe.double(5).inspect(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [10]) + self.assertEqual(val, 10) + + asyncio.run(run()) + + def test_async_pipeline_tap(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + collected = [] + pipe = AsyncPipe() + result = pipe.double(5).tap(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [10]) + self.assertEqual(val, 10) + + asyncio.run(run()) + + def test_async_pipeline_combined_functional_chain(self): + import asyncio + + class AsyncPipe: + @pipeline + async def increment(self, x): + return x + 1 + + @pipeline + async def negate(self, x): + return -x + + async def run(): + log = [] + pipe = AsyncPipe() + result = ( + pipe.increment(10) + .map(lambda x: x * 2) + .inspect(lambda x: log.append(x)) + .filter(lambda x: x > 0) + .negate() + ) + val = await result + self.assertEqual(log, [22]) + self.assertEqual(val, -22) + + asyncio.run(run()) + + +class TestPositionalArgPipeline(unittest.TestCase): + """Regression tests for positional args during chained pipeline calls.""" + + def test_sync_pipeline_positional_args(self): + class Pipe: + @pipeline + def add(self, x: int, y: int) -> int: + return x + y + + @pipeline + def multiply(self, x: int, factor: int) -> int: + return x * factor + + p = Pipe() + # Previous value (3) auto-injected as first arg; 5 passed as positional extra + result = p.add(1, 2).multiply(5) + self.assertEqual(int(result), 15) + + def test_sync_pipeline_positional_args_with_kwargs(self): + class Pipe: + @pipeline + def split(self, x: int) -> tuple: + return (x, x + 1) + + @pipeline + def combine(self, a: int, b: int, extra: int = 0) -> int: + return a + b + extra + + p = Pipe() + # Tuple (5, 6) unpacked as (a, b), then extra=10 via kwarg + result = p.split(5).combine(extra=10) + self.assertEqual(int(result), 21) + + def test_sync_pipeline_positional_args_tuple_unpack(self): + class Pipe: + @pipeline + def pair(self, x: int) -> tuple: + return (x, x * 2) + + @pipeline + def sum_with_extra(self, a: int, b: int, extra: int) -> int: + return a + b + extra + + p = Pipe() + # Tuple (3, 6) unpacked as (a, b), then 100 as positional extra + result = p.pair(3).sum_with_extra(100) + self.assertEqual(int(result), 109) + + def test_async_pipeline_positional_args(self): + import asyncio + + class AsyncPipe: + @pipeline + async def add(self, x: int, y: int) -> int: + return x + y + + @pipeline + async def multiply(self, x: int, factor: int) -> int: + return x * factor + + async def run(): + p = AsyncPipe() + result = p.add(1, 2).multiply(5) + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_pipeline_positional_args_tuple_unpack(self): + import asyncio + + class AsyncPipe: + @pipeline + async def pair(self, x: int): + return (x, x * 2) + + @pipeline + async def sum_with_extra(self, a, b, extra): + return a + b + extra + + async def run(): + p = AsyncPipe() + result = p.pair(3).sum_with_extra(100) + val = await result + self.assertEqual(val, 109) + + asyncio.run(run()) + + +class TestClassBehavior(unittest.TestCase): + """Regression tests for __class__ proxy behavior in pipeline results.""" + + def test_sync_pipeline_class_returns_value_type(self): + from pychain.pipeline import PipelineResult + + class Pipe: + @pipeline + def double(self, x: int) -> int: + return x * 2 + + p = Pipe() + result = p.double(5) + # __class__ should return the wrapped value's type (int) + self.assertEqual(result.__class__, int) + # type() bypasses __getattribute__ and returns PipelineResult + self.assertEqual(type(result), PipelineResult) + + def test_sync_pipeline_class_with_string_value(self): + from pychain.pipeline import PipelineResult + + class StrPipe: + @pipeline + def upper(self, s: str) -> str: + return s.upper() + + p = StrPipe() + result = p.upper("hello") + self.assertEqual(result.__class__, str) + self.assertEqual(type(result), PipelineResult) + + def test_async_pipeline_class_returns_value_type(self): + from pychain.async_pipeline import AsyncPipelineResult + from pychain.enum import VALUE + + class AsyncPipe: + @pipeline + async def double(self, x: int) -> int: + return x * 2 + + p = AsyncPipe() + result = p.double(5) + # __class__ follows same logic as sync: returns type of _value (coroutine) + # since _value is a coroutine (truthy), __class__ returns coroutine type + self.assertNotEqual(result.__class__, AsyncPipelineResult) + # type() bypasses __getattribute__ and returns AsyncPipelineResult + self.assertEqual(type(result), AsyncPipelineResult) + object.__getattribute__(result, VALUE).close() + + def test_sync_pipeline_class_falsy_value(self): + from pychain.pipeline import PipelineResult + + class Pipe: + @pipeline + def zero(self, x: int) -> int: + return 0 + + p = Pipe() + result = p.zero(5) + # value is 0 (falsy), so __class__ falls through to instance type + self.assertEqual(result.__class__, Pipe) + + if __name__ == "__main__": unittest.main() From b391d1f74750ca845cc0496f43470523c9652aa7 Mon Sep 17 00:00:00 2001 From: mrwuliu Date: Tue, 28 Apr 2026 14:41:09 +0800 Subject: [PATCH 5/6] ci: run full test discovery in workflow Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 96084f1..004a3a2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -28,7 +28,7 @@ jobs: pip install pytest pytest-cov numpy - name: Run tests with coverage (>=50%) - run: pytest --cov=pychain --cov-fail-under=50 --cov-report=term-missing -v tests/test_chainable.py + run: pytest --cov=pychain --cov-fail-under=50 --cov-report=term-missing -v tests/ - name: Type check with basedpyright run: | From e630a3560360df63843b074d2ac7bb7af989567d Mon Sep 17 00:00:00 2001 From: mrwuliu Date: Tue, 28 Apr 2026 16:07:16 +0800 Subject: [PATCH 6/6] fix: support mixed async/sync chaining with conditional await Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-openagent) Co-authored-by: Sisyphus --- pychain/async_chainable.py | 5 ++- pychain/async_pipeline.py | 5 ++- tests/test_chainable.py | 86 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 94 insertions(+), 2 deletions(-) diff --git a/pychain/async_chainable.py b/pychain/async_chainable.py index d69c181..47220de 100644 --- a/pychain/async_chainable.py +++ b/pychain/async_chainable.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from functools import wraps from typing import Any, Callable, TypeVar from pychain.common import CommonChain @@ -35,7 +36,9 @@ async def async_step() -> Any: await value inner = attr(*args, **kwargs) inner_value = object.__getattribute__(inner, "_value") - return await inner_value + if inspect.isawaitable(inner_value): + return await inner_value + return inner_value return AsyncChainableResult(instance, async_step()) return method_caller return attr diff --git a/pychain/async_pipeline.py b/pychain/async_pipeline.py index e448383..b6b704f 100644 --- a/pychain/async_pipeline.py +++ b/pychain/async_pipeline.py @@ -1,5 +1,6 @@ from __future__ import annotations +import inspect from functools import wraps from typing import Any, Callable, TypeVar from pychain.common import CommonChain @@ -46,7 +47,9 @@ async def make_step(prev_value: Any, *args: Any, **kwargs: Any) -> Any: else: inner = attr(resolved, *args, **kwargs) inner_value = object.__getattribute__(inner, VALUE) - return await inner_value + if inspect.isawaitable(inner_value): + return await inner_value + return inner_value return lambda *args, **kwargs: AsyncPipelineResult( instance, make_step(value, *args, **kwargs) diff --git a/tests/test_chainable.py b/tests/test_chainable.py index 5084035..d92bb8e 100644 --- a/tests/test_chainable.py +++ b/tests/test_chainable.py @@ -943,5 +943,91 @@ def zero(self, x: int) -> int: self.assertEqual(result.__class__, Pipe) +class TestMixedAsyncSyncChaining(unittest.TestCase): + """Regression tests for mixed async/sync chained method calls.""" + + def test_async_pipeline_then_sync_pipeline(self): + import asyncio + + class MixedPipe: + @pipeline + async def start(self, x): + return x + 1 + + @pipeline + def double(self, x): + return x * 2 + + async def run(): + pipe = MixedPipe() + val = await pipe.start(1).double() + self.assertEqual(val, 4) + + asyncio.run(run()) + + def test_async_pipeline_then_sync_pipeline_with_args(self): + import asyncio + + class MixedPipe: + @pipeline + async def start(self, x): + return x + 1 + + @pipeline + def multiply(self, x, factor): + return x * factor + + async def run(): + pipe = MixedPipe() + val = await pipe.start(3).multiply(factor=5) + self.assertEqual(val, 20) + + asyncio.run(run()) + + def test_async_chainable_then_sync_chainable(self): + import asyncio + + class MixedCalc: + @chainable + async def compute(self, x): + return x + 10 + + @chainable + def transform(self, x): + return x * 2 + + async def run(): + calc = MixedCalc() + partial = calc.compute(5) + transformed = partial.transform(15) + val = await transformed + self.assertEqual(val, 30) + + asyncio.run(run()) + + def test_sync_then_async_then_sync_pipeline(self): + import asyncio + + class MultiPipe: + @pipeline + def step_one(self, x): + return x + 1 + + @pipeline + async def step_two(self, x): + return x * 3 + + @pipeline + def step_three(self, x): + return x - 2 + + async def run(): + pipe = MultiPipe() + val = await pipe.step_one(4).step_two().step_three() + self.assertEqual(val, 13) + + asyncio.run(run()) + + if __name__ == "__main__": unittest.main()