diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 281fd71..004a3a2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -29,3 +29,8 @@ jobs: - name: Run tests with coverage (>=50%) run: pytest --cov=pychain --cov-fail-under=50 --cov-report=term-missing -v tests/ + + - name: Type check with basedpyright + run: | + pip install basedpyright + basedpyright pychain/ tests/test_types.py diff --git a/pychain/__init__.py b/pychain/__init__.py index 055ad4b..742b4ef 100644 --- a/pychain/__init__.py +++ b/pychain/__init__.py @@ -1,7 +1,5 @@ -from .chainable import chainable -from .pipeline import pipeline +from pychain.chainable import chainable +from pychain.pipeline import pipeline -__all__ = [ - 'chainable', 'pipeline' -] \ No newline at end of file +__all__ = ["chainable", "pipeline"] \ No newline at end of file diff --git a/pychain/async_chainable.py b/pychain/async_chainable.py new file mode 100644 index 0000000..47220de --- /dev/null +++ b/pychain/async_chainable.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +import inspect +from functools import wraps +from typing import Any, Callable, TypeVar +from pychain.common import CommonChain + +T = TypeVar("T") + + +class AsyncChainableResult(CommonChain[T]): + + def __init__(self, instance: Any, awaitable_or_value: Any) -> None: + super().__init__(instance, awaitable_or_value) + + def __getattribute__(self, name: str) -> Any: + if name in ("_instance", "_value"): + return object.__getattribute__(self, name) + + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + + instance = object.__getattribute__(self, "_instance") + value = object.__getattribute__(self, "_value") + + if hasattr(instance, name): + attr = getattr(instance, name) + if callable(attr): + def method_caller(*args: Any, **kwargs: Any) -> AsyncChainableResult[Any]: + async def async_step() -> Any: + await value + inner = attr(*args, **kwargs) + inner_value = object.__getattribute__(inner, "_value") + if inspect.isawaitable(inner_value): + return await inner_value + return inner_value + return AsyncChainableResult(instance, async_step()) + return method_caller + return attr + + if hasattr(value, name): + attr = getattr(value, name) + if callable(attr): + return lambda *args, **kwargs: attr(*args, **kwargs) + return attr + + raise AttributeError( + f"'{type(self).__name__}' object has no attribute '{name}'" + ) + + def __await__(self): + return object.__getattribute__(self, "_value").__await__() + + # -- Async-aware functional composition overrides -- + # These override CommonChain methods so that _value (a coroutine) is + # awaited before the transformation is applied. + + def map(self, fn: Callable[[Any], Any]) -> AsyncChainableResult[Any]: + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncChainableResult(instance, _async_map()) + + def filter(self, fn: Callable[[Any], bool]) -> AsyncChainableResult[Any]: + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_filter() -> Any: + resolved = await prev_value + if not fn(resolved): + raise ValueError( + f"Filter predicate returned False for {resolved!r}" + ) + return resolved + + return AsyncChainableResult(instance, _async_filter()) + + def flat_map(self, fn: Callable[[Any], Any]) -> AsyncChainableResult[Any]: + """Async version: awaits _value, applies fn, returns async proxy. + + Unlike the sync CommonChain.flat_map which returns the raw unwrapped + value, the async version must return an AsyncChainableResult because + the result cannot be produced synchronously. + """ + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_flat_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncChainableResult(instance, _async_flat_map()) + + def inspect(self, fn: Callable[[Any], None]) -> AsyncChainableResult[Any]: + prev_value = object.__getattribute__(self, "_value") + instance = object.__getattribute__(self, "_instance") + + async def _async_inspect() -> Any: + resolved = await prev_value + fn(resolved) + return resolved + + return AsyncChainableResult(instance, _async_inspect()) + + def tap(self, fn: Callable[[Any], None]) -> AsyncChainableResult[Any]: + return self.inspect(fn) + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if callable(object.__getattribute__(self, "_value")): + return object.__getattribute__(self, "_value")(*args, **kwargs) + raise TypeError( + f"'{type(object.__getattribute__(self, '_value')).__name__}' object is not callable" + ) + + +def async_chainable(func: Callable[..., Any]) -> Callable[..., AsyncChainableResult[Any]]: + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncChainableResult[Any]: + coro = func(self, *args, **kwargs) + return AsyncChainableResult(self, coro) + return wrapper diff --git a/pychain/async_pipeline.py b/pychain/async_pipeline.py new file mode 100644 index 0000000..b6b704f --- /dev/null +++ b/pychain/async_pipeline.py @@ -0,0 +1,129 @@ +from __future__ import annotations + +import inspect +from functools import wraps +from typing import Any, Callable, TypeVar +from pychain.common import CommonChain +from pychain.enum import VALUE, INSTANCE + +T = TypeVar("T") + + +class AsyncPipelineResult(CommonChain[T]): + + def __init__(self, instance: Any, awaitable_or_value: Any) -> None: + super().__init__(instance, awaitable_or_value) + + def __getattribute__(self, name: str) -> Any: + if name == "__class__": + if value := object.__getattribute__(self, VALUE): + return type(value) + elif instance := object.__getattribute__(self, INSTANCE): + return type(instance) + else: + return type(self) + + if name in (INSTANCE, VALUE): + return object.__getattribute__(self, name) + + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + + instance = object.__getattribute__(self, INSTANCE) + value = object.__getattribute__(self, VALUE) + + if hasattr(instance, name): + attr = getattr(instance, name) + if callable(attr): + async def make_step(prev_value: Any, *args: Any, **kwargs: Any) -> Any: + resolved = await prev_value + if isinstance(resolved, tuple): + inner = attr(*resolved, *args, **kwargs) + else: + inner = attr(resolved, *args, **kwargs) + inner_value = object.__getattribute__(inner, VALUE) + if inspect.isawaitable(inner_value): + return await inner_value + return inner_value + + return lambda *args, **kwargs: AsyncPipelineResult( + instance, make_step(value, *args, **kwargs) + ) + return attr + + if hasattr(value, name): + attr = getattr(value, name) + if callable(attr): + return lambda *args, **kwargs: attr(*args, **kwargs) + return attr + + raise AttributeError(name) + + # -- Async-aware functional composition overrides -- + + def map(self, fn: Callable[[Any], Any]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncPipelineResult(instance, _async_map()) + + def filter(self, fn: Callable[[Any], bool]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_filter() -> Any: + resolved = await prev_value + if not fn(resolved): + raise ValueError( + f"Filter predicate returned False for {resolved!r}" + ) + return resolved + + return AsyncPipelineResult(instance, _async_filter()) + + def flat_map(self, fn: Callable[[Any], Any]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_flat_map() -> Any: + resolved = await prev_value + return fn(resolved) + + return AsyncPipelineResult(instance, _async_flat_map()) + + def inspect(self, fn: Callable[[Any], None]) -> AsyncPipelineResult[Any]: + prev_value = object.__getattribute__(self, VALUE) + instance = object.__getattribute__(self, INSTANCE) + + async def _async_inspect() -> Any: + resolved = await prev_value + fn(resolved) + return resolved + + return AsyncPipelineResult(instance, _async_inspect()) + + def tap(self, fn: Callable[[Any], None]) -> AsyncPipelineResult[Any]: + return self.inspect(fn) + + def __call__(self, *args: Any, **kwargs: Any) -> AsyncPipelineResult[T]: + return self + + def __await__(self): + return object.__getattribute__(self, VALUE).__await__() + + +def async_pipeline(func: Callable[..., Any]) -> Callable[..., AsyncPipelineResult[Any]]: + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncPipelineResult[Any]: + coro = func(self, *args, **kwargs) + return AsyncPipelineResult(self, coro) + return wrapper diff --git a/pychain/chainable.py b/pychain/chainable.py index 3314125..2a6153c 100644 --- a/pychain/chainable.py +++ b/pychain/chainable.py @@ -1,14 +1,28 @@ +from __future__ import annotations + +import inspect from functools import wraps -from typing import Any +from typing import Any, Awaitable, Callable, TypeVar, overload +from pychain.async_chainable import AsyncChainableResult from pychain.common import CommonChain +T = TypeVar("T") + -class ChainableResult(CommonChain): +class ChainableResult(CommonChain[T]): - def __getattribute__(self, name) -> Any: - if name in ["_instance", "_value"]: + def __getattribute__(self, name: str) -> Any: + if name in ("_instance", "_value"): return object.__getattribute__(self, name) + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + instance = object.__getattribute__(self, "_instance") value = object.__getattribute__(self, "_value") @@ -30,16 +44,25 @@ def __getattribute__(self, name) -> Any: f"'{type(self).__name__}' object has no attribute '{name}'" ) - def __call__(self, *args, **kwargs) -> Any: - if callable(self._value): - return self._value(*args, **kwargs) - raise TypeError(f"'{type(self._value).__name__}' object is not callable") + def __call__(self, *args: Any, **kwargs: Any) -> Any: + if callable(object.__getattribute__(self, "_value")): + return object.__getattribute__(self, "_value")(*args, **kwargs) + raise TypeError( + f"'{type(object.__getattribute__(self, '_value')).__name__}' object is not callable" + ) -def chainable(func): +@overload +def chainable(func: Callable[..., Awaitable[T]]) -> Callable[..., AsyncChainableResult[T]]: ... # pyright: ignore[reportOverlappingOverload] + +@overload +def chainable(func: Callable[..., T]) -> Callable[..., ChainableResult[T]]: ... + +def chainable(func: Callable[..., Any]) -> Callable[..., ChainableResult[Any] | AsyncChainableResult[Any]]: @wraps(func) - def wrapper(self, *args, **kwargs): + def wrapper(self: Any, *args: Any, **kwargs: Any) -> ChainableResult[Any] | AsyncChainableResult[Any]: result = func(self, *args, **kwargs) + if inspect.iscoroutine(result): + return AsyncChainableResult(self, result) return ChainableResult(self, result) - return wrapper diff --git a/pychain/common.py b/pychain/common.py index a2215c2..5dd5ec9 100644 --- a/pychain/common.py +++ b/pychain/common.py @@ -1,21 +1,26 @@ +from __future__ import annotations + import operator -from typing import Any +from typing import Any, Callable, Generic, TypeVar from pychain.enum import VALUE, INSTANCE +T = TypeVar("T") +U = TypeVar("U") + -class CommonChain: +class CommonChain(Generic[T]): - def __init__(self, instance, value) -> None: + def __init__(self, instance: Any, value: T) -> None: object.__setattr__(self, INSTANCE, instance) object.__setattr__(self, VALUE, value) - def __str__(self): + def __str__(self) -> str: return str(object.__getattribute__(self, VALUE)) def __repr__(self) -> str: return repr(object.__getattribute__(self, VALUE)) - def __dir__(self) -> list: + def __dir__(self) -> list[str]: instance = object.__getattribute__(self, INSTANCE) value = object.__getattribute__(self, VALUE) return sorted(set(dir(instance) + dir(value))) @@ -32,80 +37,110 @@ def __bool__(self) -> bool: def __index__(self) -> int: return operator.index(object.__getattribute__(self, VALUE)) - def __add__(self, other: Any): + def __add__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) + other - def __sub__(self, other: Any): + def __sub__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) - other - def __mul__(self, other: Any): + def __mul__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) * other - def __truediv__(self, other: Any): + def __truediv__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) / other - def __floordiv__(self, other: Any): + def __floordiv__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) // other - def __mod__(self, other: Any): + def __mod__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) % other - def __pow__(self, other: Any): + def __pow__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) ** other - def __radd__(self, other: Any): + def __radd__(self, other: Any) -> Any: return other + object.__getattribute__(self, VALUE) - def __rsub__(self, other: Any): + def __rsub__(self, other: Any) -> Any: return other - object.__getattribute__(self, VALUE) - def __rmul__(self, other: Any): + def __rmul__(self, other: Any) -> Any: return other * object.__getattribute__(self, VALUE) - def __rtruediv__(self, other: Any): + def __rtruediv__(self, other: Any) -> Any: return other / object.__getattribute__(self, VALUE) - def __rfloordiv__(self, other: Any): + def __rfloordiv__(self, other: Any) -> Any: return other // object.__getattribute__(self, VALUE) - def __rmod__(self, other: Any): + def __rmod__(self, other: Any) -> Any: return other % object.__getattribute__(self, VALUE) - def __rpow__(self, other: Any): + def __rpow__(self, other: Any) -> Any: return other ** object.__getattribute__(self, VALUE) - def __matmul__(self, other: Any): + def __matmul__(self, other: Any) -> Any: return object.__getattribute__(self, VALUE) @ other - def __rmatmul__(self, other: Any): + def __rmatmul__(self, other: Any) -> Any: return other @ object.__getattribute__(self, VALUE) - def __neg__(self): + def __neg__(self) -> Any: return -object.__getattribute__(self, VALUE) - def __pos__(self): + def __pos__(self) -> Any: return +object.__getattribute__(self, VALUE) - def __abs__(self): + def __abs__(self) -> Any: return abs(object.__getattribute__(self, VALUE)) - def __invert__(self): + def __invert__(self) -> Any: return ~object.__getattribute__(self, VALUE) - def __ne__(self, value: object): + def __ne__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) != value def __eq__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) == value - def __ge__(self, value: object): + def __ge__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) >= value - def __gt__(self, value: object): + def __gt__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) > value - def __le__(self, value: object): + def __le__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) <= value - def __lt__(self, value: object): + def __lt__(self, value: object) -> bool: return object.__getattribute__(self, VALUE) < value + + # -- Functional composition methods -- + + def map(self, fn: Callable[[T], U]) -> CommonChain[U]: + value = object.__getattribute__(self, VALUE) + result = fn(value) + instance = object.__getattribute__(self, INSTANCE) + return type(self)(instance, result) # pyright: ignore[reportReturnType,reportArgumentType] + + def filter(self, fn: Callable[[T], bool]) -> CommonChain[T]: + """Return self if fn(_value) is truthy, otherwise raise ValueError.""" + value = object.__getattribute__(self, VALUE) + if not fn(value): + raise ValueError(f"Filter predicate returned False for {value!r}") + return self + + def flat_map(self, fn: Callable[[T], Any]) -> Any: + """Transform _value with fn and unwrap the result (no wrapping in proxy).""" + value = object.__getattribute__(self, VALUE) + return fn(value) + + def inspect(self, fn: Callable[[T], None]) -> CommonChain[T]: + """Execute fn on _value for side effects (e.g., debugging), return self unchanged.""" + value = object.__getattribute__(self, VALUE) + fn(value) + return self + + def tap(self, fn: Callable[[T], None]) -> CommonChain[T]: + """Alias for inspect.""" + return self.inspect(fn) diff --git a/pychain/enum.py b/pychain/enum.py index 40d2679..f00c9e2 100644 --- a/pychain/enum.py +++ b/pychain/enum.py @@ -1,3 +1,4 @@ +from __future__ import annotations VALUE = "_value" INSTANCE = "_instance" \ No newline at end of file diff --git a/pychain/pipeline.py b/pychain/pipeline.py index e49cef9..96e0b93 100644 --- a/pychain/pipeline.py +++ b/pychain/pipeline.py @@ -1,9 +1,16 @@ +from __future__ import annotations + +import inspect from functools import wraps -from typing import Any, Callable +from typing import Any, Awaitable, Callable, TypeVar, overload +from pychain.async_pipeline import AsyncPipelineResult from pychain.common import CommonChain from pychain.enum import VALUE, INSTANCE -class PipelineResult(CommonChain): +T = TypeVar("T") + + +class PipelineResult(CommonChain[T]): def __getattribute__(self, name: str) -> Any: if name == "__class__": @@ -14,8 +21,17 @@ def __getattribute__(self, name: str) -> Any: else: return type(self) - if name in [INSTANCE, VALUE]: + if name in (INSTANCE, VALUE): return object.__getattribute__(self, name) + + try: + proxy_attr = object.__getattribute__(self, name) + if callable(proxy_attr): + return proxy_attr + return proxy_attr + except AttributeError: + pass + instance = object.__getattribute__(self, INSTANCE) value = object.__getattribute__(self, VALUE) @@ -23,23 +39,32 @@ def __getattribute__(self, name: str) -> Any: attr = getattr(instance, name) if callable(attr): if isinstance(value, tuple): - return lambda **kwargs: attr(*value, **kwargs) - return lambda **kwargs: attr(value, **kwargs) + return lambda *args, **kwargs: attr(*value, *args, **kwargs) + return lambda *args, **kwargs: attr(value, *args, **kwargs) return attr if hasattr(value, name): attr = getattr(value, name) if callable(attr): return lambda *args, **kwargs: attr(*args, **kwargs) - return value + return attr raise AttributeError(name) - def __call__(self, *args, **kwargs) -> Any: + def __call__(self, *args: Any, **kwargs: Any) -> PipelineResult[T]: return self -def pipeline(func: Callable) -> Callable: + +@overload +def pipeline(func: Callable[..., Awaitable[T]]) -> Callable[..., AsyncPipelineResult[T]]: ... # pyright: ignore[reportOverlappingOverload] + +@overload +def pipeline(func: Callable[..., T]) -> Callable[..., PipelineResult[T]]: ... + +def pipeline(func: Callable[..., Any]) -> Callable[..., PipelineResult[Any] | AsyncPipelineResult[Any]]: @wraps(func) - def wrapper(self, *args, **kwargs) -> PipelineResult: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> PipelineResult[Any] | AsyncPipelineResult[Any]: result = func(self, *args, **kwargs) + if inspect.iscoroutine(result): + return AsyncPipelineResult(self, result) return PipelineResult(self, result) return wrapper diff --git a/pychain/py.typed b/pychain/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index a121bfb..da744e4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,3 +25,16 @@ classifiers = [ [project.urls] Homepage = "https://github.com/mr-wuliu/PythonChainable" Issues = "https://github.com/mr-wuliu/PythonChainable/issues" + +[tool.basedpyright] +pythonVersion = "3.10" +pythonPlatform = "All" +typeCheckingMode = "strict" +# Proxy pattern legitimately uses Any +reportAny = "none" +reportExplicitAny = "none" +# Many intentional dunder overrides in proxy classes +reportImplicitOverride = "none" +# Lambda params in proxy forwarding are inherently unknown +reportUnknownLambdaType = "none" +reportUnknownVariableType = "none" diff --git a/tests/test_chainable.py b/tests/test_chainable.py index b23b6ea..d92bb8e 100644 --- a/tests/test_chainable.py +++ b/tests/test_chainable.py @@ -204,6 +204,830 @@ def scale(self, x, y, factor: float): res2 = res.scale(factor=1/2) self.assertEqual(res2, (3 , -2)) - + + +class TestCommonChain(unittest.TestCase): + def test_str(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 42) + self.assertEqual(str(proxy), "42") + + def test_repr(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", "hello") + self.assertEqual(repr(proxy), "'hello'") + + def test_int(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 7) + self.assertEqual(int(proxy), 7) + + def test_float(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 3.14) + self.assertAlmostEqual(float(proxy), 3.14) + + def test_bool(self): + from pychain.common import CommonChain + self.assertTrue(bool(CommonChain("i", 1))) + self.assertTrue(bool(CommonChain("i", "nonempty"))) + self.assertFalse(bool(CommonChain("i", 0))) + self.assertFalse(bool(CommonChain("i", ""))) + + def test_index(self): + from pychain.common import CommonChain + proxy = CommonChain("instance", 5) + self.assertEqual([0, 1, 2, 3, 4, 5][proxy], 5) + + def test_arithmetic(self): + from pychain.common import CommonChain + p = CommonChain("i", 10) + self.assertEqual(p + 5, 15) + self.assertEqual(p - 3, 7) + self.assertEqual(p * 2, 20) + self.assertEqual(p / 4, 2.5) + self.assertEqual(p // 3, 3) + self.assertEqual(p % 3, 1) + self.assertEqual(p ** 2, 100) + + def test_reverse_arithmetic(self): + from pychain.common import CommonChain + p = CommonChain("i", 10) + self.assertEqual(5 + p, 15) + self.assertEqual(20 - p, 10) + self.assertEqual(3 * p, 30) + self.assertEqual(100 / p, 10.0) + self.assertEqual(23 // p, 2) + self.assertEqual(23 % p, 3) + self.assertEqual(2 ** p, 1024) + + def test_matmul(self): + import numpy as np + from pychain.common import CommonChain + a = np.array([[1, 2]]) + b = np.array([[3], [4]]) + p = CommonChain("i", a) + result = p @ b + expected = a @ b + self.assertTrue(np.array_equal(result, expected)) + + def test_unary(self): + from pychain.common import CommonChain + p = CommonChain("i", 5) + self.assertEqual(-p, -5) + self.assertEqual(+p, 5) + self.assertEqual(abs(CommonChain("i", -7)), 7) + self.assertEqual(~CommonChain("i", 0), -1) + + def test_comparison(self): + from pychain.common import CommonChain + p = CommonChain("i", 10) + self.assertTrue(p == 10) + self.assertTrue(p != 11) + self.assertTrue(p > 5) + self.assertTrue(p < 20) + self.assertTrue(p >= 10) + self.assertTrue(p <= 10) + self.assertFalse(p > 15) + + def test_dir(self): + from pychain.common import CommonChain + d = dir(CommonChain("instance_obj", 42)) + self.assertIsInstance(d, list) + self.assertIn("__class__", d) + + +class TestFunctionalComposition(unittest.TestCase): + def test_map(self): + from pychain.chainable import ChainableResult, chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).map(lambda x: x * 2) + self.assertEqual(int(result), 14) + + def test_filter_pass(self): + from pychain.chainable import ChainableResult, chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).filter(lambda x: x > 0) + self.assertEqual(int(result), 7) + + def test_filter_fail(self): + from pychain.chainable import chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + with self.assertRaises(ValueError): + calc.add(3, 4).filter(lambda x: x < 0) + + def test_flat_map(self): + from pychain.chainable import chainable + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).flat_map(lambda x: x * 10) + self.assertEqual(result, 70) + self.assertNotIsInstance(result, type(calc.add(1, 2))) + + def test_inspect(self): + from pychain.chainable import chainable + collected = [] + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).inspect(lambda x: collected.append(x)) + self.assertEqual(collected, [7]) + self.assertEqual(int(result), 7) + + def test_tap(self): + from pychain.chainable import chainable + collected = [] + + class Calc: + @chainable + def add(self, a, b): + return a + b + + calc = Calc() + result = calc.add(3, 4).tap(lambda x: collected.append(x)) + self.assertEqual(collected, [7]) + self.assertEqual(int(result), 7) + + def test_chain_functional_methods(self): + from pychain.chainable import chainable + log = [] + + class Calc: + @chainable + def add(self, a, b): + return a + b + + @chainable + def multiply(self, a, b): + return a * b + + calc = Calc() + result = ( + calc.add(2, 3) + .map(lambda x: x + 1) + .inspect(lambda x: log.append(x)) + .filter(lambda x: x > 0) + .map(lambda x: x * 2) + ) + self.assertEqual(log, [6]) + self.assertEqual(int(result), 12) + + def test_pipeline_functional_methods(self): + from pychain.pipeline import pipeline + + class Calc: + @pipeline + def double(self, x): + return x * 2 + + calc = Calc() + result = calc.double(5).map(lambda x: x + 1) + self.assertEqual(int(result), 11) + + +class TestAsyncChainable(unittest.TestCase): + def test_basic_async_chain(self): + import asyncio + from pychain.async_chainable import AsyncChainableResult + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5) + self.assertIsInstance(result, AsyncChainableResult) + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_await_final_result(self): + import asyncio + + class AsyncCalc: + @chainable + async def multiply(self, x): + return x * 3 + + async def run(): + calc = AsyncCalc() + val = await calc.multiply(4) + self.assertEqual(val, 12) + + asyncio.run(run()) + + def test_multi_step_chain(self): + import asyncio + + class AsyncCalc: + def __init__(self): + self.value = 0 + + @chainable + async def add(self, x): + self.value += x + return self.value + + async def run(): + calc = AsyncCalc() + result = calc.add(3).add(5) + val = await result + self.assertEqual(val, 8) + + asyncio.run(run()) + + +class TestAsyncChainableFunctional(unittest.TestCase): + def test_async_chainable_map(self): + import asyncio + from pychain.async_chainable import AsyncChainableResult + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5).map(lambda x: x * 2) + self.assertIsInstance(result, AsyncChainableResult) + val = await result + self.assertEqual(val, 30) + + asyncio.run(run()) + + def test_async_chainable_filter_pass(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5).filter(lambda x: x > 0) + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_chainable_filter_fail(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + with self.assertRaises(ValueError): + await calc.add(5).filter(lambda x: x < 0) + + asyncio.run(run()) + + def test_async_chainable_flat_map(self): + import asyncio + from pychain.async_chainable import AsyncChainableResult + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + calc = AsyncCalc() + result = calc.add(5).flat_map(lambda x: x * 10) + self.assertIsInstance(result, AsyncChainableResult) + val = await result + self.assertEqual(val, 150) + + asyncio.run(run()) + + def test_async_chainable_inspect(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + collected = [] + calc = AsyncCalc() + result = calc.add(5).inspect(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [15]) + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_chainable_tap(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + collected = [] + calc = AsyncCalc() + result = calc.add(5).tap(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [15]) + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_chainable_combined_functional_chain(self): + import asyncio + + class AsyncCalc: + @chainable + async def add(self, x): + return x + 10 + + async def run(): + log = [] + calc = AsyncCalc() + result = ( + calc.add(5) + .map(lambda x: x * 2) + .inspect(lambda x: log.append(x)) + .filter(lambda x: x > 0) + .map(lambda x: x + 1) + ) + val = await result + self.assertEqual(log, [30]) + self.assertEqual(val, 31) + + asyncio.run(run()) + + +class TestAsyncPipeline(unittest.TestCase): + def test_basic_async_pipeline(self): + import asyncio + from pychain.async_pipeline import AsyncPipelineResult + + class AsyncPipe: + @pipeline + async def add_one(self, x): + return x + 1 + + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.add_one(5) + self.assertIsInstance(result, AsyncPipelineResult) + val = await result + self.assertEqual(val, 6) + + asyncio.run(run()) + + def test_async_pipeline_chaining(self): + import asyncio + + class AsyncPipe: + @pipeline + async def increment(self, x): + return x + 1 + + @pipeline + async def negate(self, x): + return -x + + async def run(): + pipe = AsyncPipe() + result = pipe.increment(10).negate() + val = await result + self.assertEqual(val, -11) + + asyncio.run(run()) + + def test_async_pipeline_tuple_unpacking(self): + import asyncio + + class AsyncPipe: + @pipeline + async def split(self, x): + return (x, x * 2) + + @pipeline + async def sum_pair(self, a, b): + return a + b + + async def run(): + pipe = AsyncPipe() + result = pipe.split(5).sum_pair() + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + +class TestAsyncPipelineFunctional(unittest.TestCase): + def test_async_pipeline_map(self): + import asyncio + from pychain.async_pipeline import AsyncPipelineResult + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.double(5).map(lambda x: x + 1) + self.assertIsInstance(result, AsyncPipelineResult) + val = await result + self.assertEqual(val, 11) + + asyncio.run(run()) + + def test_async_pipeline_filter_pass(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.double(5).filter(lambda x: x > 0) + val = await result + self.assertEqual(val, 10) + + asyncio.run(run()) + + def test_async_pipeline_filter_fail(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + with self.assertRaises(ValueError): + await pipe.double(5).filter(lambda x: x < 0) + + asyncio.run(run()) + + def test_async_pipeline_flat_map(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + pipe = AsyncPipe() + result = pipe.double(5).flat_map(lambda x: [x, x + 1]) + val = await result + self.assertEqual(val, [10, 11]) + + asyncio.run(run()) + + def test_async_pipeline_inspect(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + collected = [] + pipe = AsyncPipe() + result = pipe.double(5).inspect(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [10]) + self.assertEqual(val, 10) + + asyncio.run(run()) + + def test_async_pipeline_tap(self): + import asyncio + + class AsyncPipe: + @pipeline + async def double(self, x): + return x * 2 + + async def run(): + collected = [] + pipe = AsyncPipe() + result = pipe.double(5).tap(lambda x: collected.append(x)) + val = await result + self.assertEqual(collected, [10]) + self.assertEqual(val, 10) + + asyncio.run(run()) + + def test_async_pipeline_combined_functional_chain(self): + import asyncio + + class AsyncPipe: + @pipeline + async def increment(self, x): + return x + 1 + + @pipeline + async def negate(self, x): + return -x + + async def run(): + log = [] + pipe = AsyncPipe() + result = ( + pipe.increment(10) + .map(lambda x: x * 2) + .inspect(lambda x: log.append(x)) + .filter(lambda x: x > 0) + .negate() + ) + val = await result + self.assertEqual(log, [22]) + self.assertEqual(val, -22) + + asyncio.run(run()) + + +class TestPositionalArgPipeline(unittest.TestCase): + """Regression tests for positional args during chained pipeline calls.""" + + def test_sync_pipeline_positional_args(self): + class Pipe: + @pipeline + def add(self, x: int, y: int) -> int: + return x + y + + @pipeline + def multiply(self, x: int, factor: int) -> int: + return x * factor + + p = Pipe() + # Previous value (3) auto-injected as first arg; 5 passed as positional extra + result = p.add(1, 2).multiply(5) + self.assertEqual(int(result), 15) + + def test_sync_pipeline_positional_args_with_kwargs(self): + class Pipe: + @pipeline + def split(self, x: int) -> tuple: + return (x, x + 1) + + @pipeline + def combine(self, a: int, b: int, extra: int = 0) -> int: + return a + b + extra + + p = Pipe() + # Tuple (5, 6) unpacked as (a, b), then extra=10 via kwarg + result = p.split(5).combine(extra=10) + self.assertEqual(int(result), 21) + + def test_sync_pipeline_positional_args_tuple_unpack(self): + class Pipe: + @pipeline + def pair(self, x: int) -> tuple: + return (x, x * 2) + + @pipeline + def sum_with_extra(self, a: int, b: int, extra: int) -> int: + return a + b + extra + + p = Pipe() + # Tuple (3, 6) unpacked as (a, b), then 100 as positional extra + result = p.pair(3).sum_with_extra(100) + self.assertEqual(int(result), 109) + + def test_async_pipeline_positional_args(self): + import asyncio + + class AsyncPipe: + @pipeline + async def add(self, x: int, y: int) -> int: + return x + y + + @pipeline + async def multiply(self, x: int, factor: int) -> int: + return x * factor + + async def run(): + p = AsyncPipe() + result = p.add(1, 2).multiply(5) + val = await result + self.assertEqual(val, 15) + + asyncio.run(run()) + + def test_async_pipeline_positional_args_tuple_unpack(self): + import asyncio + + class AsyncPipe: + @pipeline + async def pair(self, x: int): + return (x, x * 2) + + @pipeline + async def sum_with_extra(self, a, b, extra): + return a + b + extra + + async def run(): + p = AsyncPipe() + result = p.pair(3).sum_with_extra(100) + val = await result + self.assertEqual(val, 109) + + asyncio.run(run()) + + +class TestClassBehavior(unittest.TestCase): + """Regression tests for __class__ proxy behavior in pipeline results.""" + + def test_sync_pipeline_class_returns_value_type(self): + from pychain.pipeline import PipelineResult + + class Pipe: + @pipeline + def double(self, x: int) -> int: + return x * 2 + + p = Pipe() + result = p.double(5) + # __class__ should return the wrapped value's type (int) + self.assertEqual(result.__class__, int) + # type() bypasses __getattribute__ and returns PipelineResult + self.assertEqual(type(result), PipelineResult) + + def test_sync_pipeline_class_with_string_value(self): + from pychain.pipeline import PipelineResult + + class StrPipe: + @pipeline + def upper(self, s: str) -> str: + return s.upper() + + p = StrPipe() + result = p.upper("hello") + self.assertEqual(result.__class__, str) + self.assertEqual(type(result), PipelineResult) + + def test_async_pipeline_class_returns_value_type(self): + from pychain.async_pipeline import AsyncPipelineResult + from pychain.enum import VALUE + + class AsyncPipe: + @pipeline + async def double(self, x: int) -> int: + return x * 2 + + p = AsyncPipe() + result = p.double(5) + # __class__ follows same logic as sync: returns type of _value (coroutine) + # since _value is a coroutine (truthy), __class__ returns coroutine type + self.assertNotEqual(result.__class__, AsyncPipelineResult) + # type() bypasses __getattribute__ and returns AsyncPipelineResult + self.assertEqual(type(result), AsyncPipelineResult) + object.__getattribute__(result, VALUE).close() + + def test_sync_pipeline_class_falsy_value(self): + from pychain.pipeline import PipelineResult + + class Pipe: + @pipeline + def zero(self, x: int) -> int: + return 0 + + p = Pipe() + result = p.zero(5) + # value is 0 (falsy), so __class__ falls through to instance type + self.assertEqual(result.__class__, Pipe) + + +class TestMixedAsyncSyncChaining(unittest.TestCase): + """Regression tests for mixed async/sync chained method calls.""" + + def test_async_pipeline_then_sync_pipeline(self): + import asyncio + + class MixedPipe: + @pipeline + async def start(self, x): + return x + 1 + + @pipeline + def double(self, x): + return x * 2 + + async def run(): + pipe = MixedPipe() + val = await pipe.start(1).double() + self.assertEqual(val, 4) + + asyncio.run(run()) + + def test_async_pipeline_then_sync_pipeline_with_args(self): + import asyncio + + class MixedPipe: + @pipeline + async def start(self, x): + return x + 1 + + @pipeline + def multiply(self, x, factor): + return x * factor + + async def run(): + pipe = MixedPipe() + val = await pipe.start(3).multiply(factor=5) + self.assertEqual(val, 20) + + asyncio.run(run()) + + def test_async_chainable_then_sync_chainable(self): + import asyncio + + class MixedCalc: + @chainable + async def compute(self, x): + return x + 10 + + @chainable + def transform(self, x): + return x * 2 + + async def run(): + calc = MixedCalc() + partial = calc.compute(5) + transformed = partial.transform(15) + val = await transformed + self.assertEqual(val, 30) + + asyncio.run(run()) + + def test_sync_then_async_then_sync_pipeline(self): + import asyncio + + class MultiPipe: + @pipeline + def step_one(self, x): + return x + 1 + + @pipeline + async def step_two(self, x): + return x * 3 + + @pipeline + def step_three(self, x): + return x - 2 + + async def run(): + pipe = MultiPipe() + val = await pipe.step_one(4).step_two().step_three() + self.assertEqual(val, 13) + + asyncio.run(run()) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_types.py b/tests/test_types.py new file mode 100644 index 0000000..2a542ff --- /dev/null +++ b/tests/test_types.py @@ -0,0 +1,56 @@ +# pyright: reportUnusedImport=none +"""Type-checking assertions for pychain public API. + +Validated by basedpyright, not executed by pytest. +Verifies that type inference works correctly for all public APIs. +""" +from __future__ import annotations + +from typing import TYPE_CHECKING + +from pychain import chainable, pipeline +from pychain.async_chainable import AsyncChainableResult +from pychain.async_pipeline import AsyncPipelineResult +from pychain.chainable import ChainableResult +from pychain.common import CommonChain +from pychain.pipeline import PipelineResult + +if TYPE_CHECKING: + + class Calc: + @chainable + def add(self, a: int, b: int) -> int: ... + + c = Calc() + r: ChainableResult[int] = c.add(1, 2) + mapped: CommonChain[int] = c.add(1, 2).map(lambda x: x * 2) + filtered: CommonChain[int] = c.add(1, 2).filter(lambda x: x > 0) + inspected: CommonChain[int] = c.add(1, 2).inspect(lambda x: None) + tapped: CommonChain[int] = c.add(1, 2).tap(lambda x: None) + + class Pipe: + @pipeline + def double(self, x: int) -> int: ... + + p = Pipe() + pr: PipelineResult[int] = p.double(5) + pmapped: CommonChain[int] = p.double(5).map(lambda x: x + 1) + + proxy: CommonChain[int] = CommonChain("i", 42) + add_result: int = proxy + 1 + eq_result: bool = proxy == 42 + str_result: str = str(proxy) + + class AsyncCalc: + @chainable + async def compute(self, x: int) -> int: ... + + ac = AsyncCalc() + ar: AsyncChainableResult[int] = ac.compute(5) + + class AsyncPipe: + @pipeline + async def step(self, x: int) -> int: ... + + ap = AsyncPipe() + apr: AsyncPipelineResult[int] = ap.step(3)