Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(siu): allow pipe func to accept args and kwargs #413

Merged
merged 12 commits into from
Aug 26, 2022
59 changes: 36 additions & 23 deletions examples/examples-siu.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,25 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/machow/.virtualenvs/siuba/lib/python3.8/site-packages/pandas/compat/__init__.py:124: UserWarning: Could not import the lzma module. Your installed Python is incomplete. Attempting to use lzma compression will result in a RuntimeError.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"from siuba.siu import _, explain"
"from siuba.siu import _, explain, strip_symbolic"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"_.somecol.min()\n"
]
},
{
"data": {
"text/plain": [
Expand Down Expand Up @@ -301,15 +303,18 @@
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"((_.a + (_.b / 2) + _.c**_.d) >> _) & _\n"
]
"data": {
"text/plain": [
"'_.a + _.b / 2 + _.c**_.d << _ & _'"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"f = _.a + _.b / 2 + _.c**_.d >> _ & _\n",
"f = _.a + _.b / 2 + _.c**_.d << _ & _\n",
"\n",
"explain(f)"
]
Expand Down Expand Up @@ -364,7 +369,7 @@
{
"data": {
"text/plain": [
"{'a', 'b', 'c'}"
"{'a', 'b'}"
]
},
"execution_count": 12,
Expand All @@ -376,7 +381,7 @@
"symbol = _.a[_.b + 1] + _['c']\n",
"\n",
"# hacky way to go from symbol to call for now\n",
"call = symbol.source\n",
"call = strip_symbolic(symbol)\n",
"\n",
"call.op_vars()"
]
Expand Down Expand Up @@ -415,7 +420,15 @@
"name": "stdout",
"output_type": "stream",
"text": [
"<built-in function add>(1,_('a') + _('b'))\n"
"<built-in function add>(1,_['a'] + _['b'])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/machow/repos/siuba/siuba/meta_hook.py:20: UserWarning: The siuba.meta_hook module is DEPRECATED and will be removed in a future release.\n",
" warnings.warn(\n"
]
},
{
Expand All @@ -435,7 +448,7 @@
"from siuba.meta_hook.pandas import DataFrame\n",
"\n",
"f = add(1, _['a'] + _['b'])\n",
"explain(f)\n",
"print(explain(f))\n",
"\n",
"f({'a': 1, 'b': 2})"
]
Expand Down Expand Up @@ -549,7 +562,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"212 µs ± 50.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
"154 µs ± 177 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)\n"
]
}
],
Expand All @@ -569,7 +582,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"7.29 µs ± 199 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)\n"
"2.74 µs ± 10 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)\n"
]
}
],
Expand Down Expand Up @@ -672,7 +685,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
Expand All @@ -686,7 +699,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
"version": "3.8.12"
},
"toc": {
"base_numbering": 1,
Expand Down
3 changes: 1 addition & 2 deletions siuba/siu/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
)
from .symbolic import Symbolic, strip_symbolic, create_sym_call, explain
from .visitors import CallTreeLocal, CallVisitor, FunctionLookupBound, FunctionLookupError, ExecutionValidatorVisitor
from .dispatchers import symbolic_dispatch, singledispatch2, pipe_no_args, Pipeable
from .dispatchers import symbolic_dispatch, singledispatch2, pipe_no_args, Pipeable, pipe, call

Lam = Lazy

_ = Symbolic()

pipe = Pipeable

60 changes: 60 additions & 0 deletions siuba/siu/calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ class Call:


"""


def __init__(self, func, *args, **kwargs):
self.func = func
self.args = args
Expand Down Expand Up @@ -189,6 +191,27 @@ def __call__(self, x):
f_op = getattr(operator, self.func)
return f_op(inst, *rest, **kwargs)

# TODO: type checks will be very useful here. Will need to import symbolic.
# Let's do this once types are in a _typing.py submodule.
def __rshift__(self, x):
"""Create a"""
from .symbolic import strip_symbolic

stripped = strip_symbolic(x)

if isinstance(stripped, Call):
return self._construct_pipe(MetaArg("_"), self, x)

raise TypeError()

def __rrshift__(self, x):
from .symbolic import strip_symbolic
if isinstance(strip_symbolic(x), (Call)):
# only allow non-calls (i.e. data) on the left.
raise TypeError()

return self(x)

@staticmethod
def evaluate_calls(arg, x):
if isinstance(arg, Call): return arg(x)
Expand Down Expand Up @@ -284,6 +307,10 @@ def obj_name(self):

return None

@classmethod
def _construct_pipe(cls, *args):
return PipeCall(*args)


class Lazy(Call):
"""Lazily return calls rather than evaluating them."""
Expand Down Expand Up @@ -586,4 +613,37 @@ def __call__(self, x):
return self.args[0]


# Pipe ===================================================================================

class PipeCall(Call):
"""
pipe(df, a, b, c)
pipe(_, a, b, c)

should options for first arg be only MetaArg or a non-call?
"""

def __init__(self, func, *args, **kwargs):
self.func = "__siu_pipe_call__"
self.args = (func, *args)
if kwargs:
raise ValueError("Keyword arguments are not allowed.")
self.kwargs = {}

def __call__(self, x=None):
# Note that most calls map_subcalls to pass in the same data for each argument.
# In contrast, PipeCall passes data from the prev step to the next.
crnt_data, *calls = self.args

if isinstance(crnt_data, MetaArg):
crnt_data = crnt_data(x)

for call in calls:
new_data = call(crnt_data)
crnt_data = new_data

return crnt_data

def __repr__(self):
args_repr = ",".join(map(repr, self.args))
return f"{type(self).__name__}({args_repr})"