Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ jobs:

- name: Run tests with coverage (>=50%)
run: pytest --cov=pychain --cov-fail-under=50 --cov-report=term-missing -v tests/

- name: Type check with basedpyright
run: |
pip install basedpyright
basedpyright pychain/ tests/test_types.py
8 changes: 3 additions & 5 deletions pychain/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from .chainable import chainable
from .pipeline import pipeline
from pychain.chainable import chainable
from pychain.pipeline import pipeline


__all__ = [
'chainable', 'pipeline'
]
__all__ = ["chainable", "pipeline"]
130 changes: 130 additions & 0 deletions pychain/async_chainable.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any, Callable, TypeVar
from pychain.common import CommonChain

T = TypeVar("T")


class AsyncChainableResult(CommonChain[T]):

def __init__(self, instance: Any, awaitable_or_value: Any) -> None:
super().__init__(instance, awaitable_or_value)

def __getattribute__(self, name: str) -> Any:
if name in ("_instance", "_value"):
return object.__getattribute__(self, name)

try:
proxy_attr = object.__getattribute__(self, name)
if callable(proxy_attr):
return proxy_attr
return proxy_attr
except AttributeError:
pass

instance = object.__getattribute__(self, "_instance")
value = object.__getattribute__(self, "_value")

if hasattr(instance, name):
attr = getattr(instance, name)
if callable(attr):
def method_caller(*args: Any, **kwargs: Any) -> AsyncChainableResult[Any]:
async def async_step() -> Any:
await value
inner = attr(*args, **kwargs)
inner_value = object.__getattribute__(inner, "_value")
if inspect.isawaitable(inner_value):
return await inner_value
return inner_value
return AsyncChainableResult(instance, async_step())
return method_caller
return attr

if hasattr(value, name):
attr = getattr(value, name)
if callable(attr):
return lambda *args, **kwargs: attr(*args, **kwargs)
return attr

raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)

def __await__(self):
return object.__getattribute__(self, "_value").__await__()

# -- Async-aware functional composition overrides --
# These override CommonChain methods so that _value (a coroutine) is
# awaited before the transformation is applied.

def map(self, fn: Callable[[Any], Any]) -> AsyncChainableResult[Any]:
prev_value = object.__getattribute__(self, "_value")
instance = object.__getattribute__(self, "_instance")

async def _async_map() -> Any:
resolved = await prev_value
return fn(resolved)

return AsyncChainableResult(instance, _async_map())

def filter(self, fn: Callable[[Any], bool]) -> AsyncChainableResult[Any]:
prev_value = object.__getattribute__(self, "_value")
instance = object.__getattribute__(self, "_instance")

async def _async_filter() -> Any:
resolved = await prev_value
if not fn(resolved):
raise ValueError(
f"Filter predicate returned False for {resolved!r}"
)
return resolved

return AsyncChainableResult(instance, _async_filter())

def flat_map(self, fn: Callable[[Any], Any]) -> AsyncChainableResult[Any]:
"""Async version: awaits _value, applies fn, returns async proxy.

Unlike the sync CommonChain.flat_map which returns the raw unwrapped
value, the async version must return an AsyncChainableResult because
the result cannot be produced synchronously.
"""
prev_value = object.__getattribute__(self, "_value")
instance = object.__getattribute__(self, "_instance")

async def _async_flat_map() -> Any:
resolved = await prev_value
return fn(resolved)

return AsyncChainableResult(instance, _async_flat_map())

def inspect(self, fn: Callable[[Any], None]) -> AsyncChainableResult[Any]:
prev_value = object.__getattribute__(self, "_value")
instance = object.__getattribute__(self, "_instance")

async def _async_inspect() -> Any:
resolved = await prev_value
fn(resolved)
return resolved

return AsyncChainableResult(instance, _async_inspect())

def tap(self, fn: Callable[[Any], None]) -> AsyncChainableResult[Any]:
return self.inspect(fn)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
if callable(object.__getattribute__(self, "_value")):
return object.__getattribute__(self, "_value")(*args, **kwargs)
raise TypeError(
f"'{type(object.__getattribute__(self, '_value')).__name__}' object is not callable"
)


def async_chainable(func: Callable[..., Any]) -> Callable[..., AsyncChainableResult[Any]]:
@wraps(func)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncChainableResult[Any]:
coro = func(self, *args, **kwargs)
return AsyncChainableResult(self, coro)
return wrapper
129 changes: 129 additions & 0 deletions pychain/async_pipeline.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any, Callable, TypeVar
from pychain.common import CommonChain
from pychain.enum import VALUE, INSTANCE

T = TypeVar("T")


class AsyncPipelineResult(CommonChain[T]):

def __init__(self, instance: Any, awaitable_or_value: Any) -> None:
super().__init__(instance, awaitable_or_value)

def __getattribute__(self, name: str) -> Any:
if name == "__class__":
if value := object.__getattribute__(self, VALUE):
return type(value)
elif instance := object.__getattribute__(self, INSTANCE):
return type(instance)
else:
return type(self)

if name in (INSTANCE, VALUE):
return object.__getattribute__(self, name)

try:
proxy_attr = object.__getattribute__(self, name)
if callable(proxy_attr):
return proxy_attr
return proxy_attr
except AttributeError:
pass

instance = object.__getattribute__(self, INSTANCE)
value = object.__getattribute__(self, VALUE)

if hasattr(instance, name):
attr = getattr(instance, name)
if callable(attr):
async def make_step(prev_value: Any, *args: Any, **kwargs: Any) -> Any:
resolved = await prev_value
if isinstance(resolved, tuple):
inner = attr(*resolved, *args, **kwargs)
else:
inner = attr(resolved, *args, **kwargs)
inner_value = object.__getattribute__(inner, VALUE)
if inspect.isawaitable(inner_value):
return await inner_value
return inner_value

return lambda *args, **kwargs: AsyncPipelineResult(
instance, make_step(value, *args, **kwargs)
)
return attr

if hasattr(value, name):
attr = getattr(value, name)
if callable(attr):
return lambda *args, **kwargs: attr(*args, **kwargs)
return attr

raise AttributeError(name)

# -- Async-aware functional composition overrides --

def map(self, fn: Callable[[Any], Any]) -> AsyncPipelineResult[Any]:
prev_value = object.__getattribute__(self, VALUE)
instance = object.__getattribute__(self, INSTANCE)

async def _async_map() -> Any:
resolved = await prev_value
return fn(resolved)

return AsyncPipelineResult(instance, _async_map())

def filter(self, fn: Callable[[Any], bool]) -> AsyncPipelineResult[Any]:
prev_value = object.__getattribute__(self, VALUE)
instance = object.__getattribute__(self, INSTANCE)

async def _async_filter() -> Any:
resolved = await prev_value
if not fn(resolved):
raise ValueError(
f"Filter predicate returned False for {resolved!r}"
)
return resolved

return AsyncPipelineResult(instance, _async_filter())

def flat_map(self, fn: Callable[[Any], Any]) -> AsyncPipelineResult[Any]:
prev_value = object.__getattribute__(self, VALUE)
instance = object.__getattribute__(self, INSTANCE)

async def _async_flat_map() -> Any:
resolved = await prev_value
return fn(resolved)

return AsyncPipelineResult(instance, _async_flat_map())

def inspect(self, fn: Callable[[Any], None]) -> AsyncPipelineResult[Any]:
prev_value = object.__getattribute__(self, VALUE)
instance = object.__getattribute__(self, INSTANCE)

async def _async_inspect() -> Any:
resolved = await prev_value
fn(resolved)
return resolved

return AsyncPipelineResult(instance, _async_inspect())

def tap(self, fn: Callable[[Any], None]) -> AsyncPipelineResult[Any]:
return self.inspect(fn)

def __call__(self, *args: Any, **kwargs: Any) -> AsyncPipelineResult[T]:
return self

def __await__(self):
return object.__getattribute__(self, VALUE).__await__()


def async_pipeline(func: Callable[..., Any]) -> Callable[..., AsyncPipelineResult[Any]]:
@wraps(func)
def wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncPipelineResult[Any]:
coro = func(self, *args, **kwargs)
return AsyncPipelineResult(self, coro)
return wrapper
45 changes: 34 additions & 11 deletions pychain/chainable.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,28 @@
from __future__ import annotations

import inspect
from functools import wraps
from typing import Any
from typing import Any, Awaitable, Callable, TypeVar, overload
from pychain.async_chainable import AsyncChainableResult
from pychain.common import CommonChain

T = TypeVar("T")


class ChainableResult(CommonChain):
class ChainableResult(CommonChain[T]):

def __getattribute__(self, name) -> Any:
if name in ["_instance", "_value"]:
def __getattribute__(self, name: str) -> Any:
if name in ("_instance", "_value"):
return object.__getattribute__(self, name)

try:
proxy_attr = object.__getattribute__(self, name)
if callable(proxy_attr):
return proxy_attr
return proxy_attr
except AttributeError:
pass

instance = object.__getattribute__(self, "_instance")
value = object.__getattribute__(self, "_value")

Expand All @@ -30,16 +44,25 @@ def __getattribute__(self, name) -> Any:
f"'{type(self).__name__}' object has no attribute '{name}'"
)

def __call__(self, *args, **kwargs) -> Any:
if callable(self._value):
return self._value(*args, **kwargs)
raise TypeError(f"'{type(self._value).__name__}' object is not callable")
def __call__(self, *args: Any, **kwargs: Any) -> Any:
if callable(object.__getattribute__(self, "_value")):
return object.__getattribute__(self, "_value")(*args, **kwargs)
raise TypeError(
f"'{type(object.__getattribute__(self, '_value')).__name__}' object is not callable"
)


def chainable(func):
@overload
def chainable(func: Callable[..., Awaitable[T]]) -> Callable[..., AsyncChainableResult[T]]: ... # pyright: ignore[reportOverlappingOverload]

@overload
def chainable(func: Callable[..., T]) -> Callable[..., ChainableResult[T]]: ...

def chainable(func: Callable[..., Any]) -> Callable[..., ChainableResult[Any] | AsyncChainableResult[Any]]:
@wraps(func)
def wrapper(self, *args, **kwargs):
def wrapper(self: Any, *args: Any, **kwargs: Any) -> ChainableResult[Any] | AsyncChainableResult[Any]:
result = func(self, *args, **kwargs)
if inspect.iscoroutine(result):
return AsyncChainableResult(self, result)
return ChainableResult(self, result)

return wrapper
Loading
Loading