In [455]:
%load_ext cython

The cython extension is already loaded. To reload it, use:
  %reload_ext cython


In [456]:
NOT_GIVEN = object()
impl = add_lambda = lambda x, y: x + y

# Currying techniques
def curry2_with_exception(impl):
    def func(*args):
        try:
            return impl(*args)
        except TypeError as exc:
            if len(args) == 1:
                return lambda y: impl(x, y)
            else:
                raise
    return func


def curry2_with_exception_kw(impl):
    def func(*args, **kwargs):
        try:
            return impl(*args, **kwargs)
        except TypeError as exc:
            if len(args) == 1:
                return lambda y, **kw: impl(x, y, **kwargs, **kw)
            else:
                raise
    return func


def curry2_with_extra_args(impl):
    def func(x, *args):
        if args:
            y, = args
            return impl(x, y)
        return lambda y: impl(x, y)
    return func
    
    
def curry2_with_not_given(impl):
    def func(x, y=NOT_GIVEN):
        if y is NOT_GIVEN:
            return lambda y: impl(x, y)
        return impl(x, y)
    return func


def curry2_with_args(impl):
    def func(*args):
        if len(args) == 2:
            return impl(*args)
        elif len(args) == 1:
            x = args[0]
            return lambda y: impl(x, y)
        else:
            raise TypeError('invalid number of args')
    return func


def curry(n, func):
    def incomplete_factory(arity, used_args):
        return lambda *args: (
            func(*(used_args + args))
            if len(used_args) + len(args) >= arity
            else incomplete_factory(arity, used_args + args)
        )
    return incomplete_factory(n, ())


def curry_kw(n, func):
    def incomplete_factory(arity, used_args, used_kwargs):
        return lambda *args, **kwargs: (
            func(*(used_args + args), **used_kwargs, **kwargs)
            if len(used_args) + len(args) >= arity
            else incomplete_factory(arity, used_args + args, {**used_kwargs, **kwargs})
        )
    return incomplete_factory(n, (), {})

In [457]:
class wrapped:
    __slots__ = 'func',
    def __init__(self, func):
        self.func = func
    
    def __call__(self, *args, **kwargs):
        return self.func(*args, **kwargs)
    
class wrapped2:
    __slots__ = 'func',
    def __init__(self, func):
        self.func = func
    
    def __call__(self, x, y):
        return self.func(x, y)

In [458]:
%%cython
cimport cython

NOT_GIVEN = object()

            
@cython.freelist(8)
cdef class cy:
    cdef object func
    cdef int arity
    
    def __init__(self, arity, func):
        self.func = func
        self.arity = arity
    
    def __call__(self, *args):
        cdef int n = len(args)
        if n == 2:
            return self.func(*args)
        elif n == 1:
            f = self.func 
            x = args[0]
            return lambda y: f(x, y)
        else:
            raise TypeError('invalid number of arguments')
            
def cy_fn(int arity, impl):
    def foo(*args, **kwargs):
        cdef int n = len(args)
        if n == arity:
            return impl(*args, **kwargs)
        else:
            return lambda *xs: impl(*(xs + args))
    return foo

In [459]:
from operator import add
x = 1

# Baseline
%timeit -n 100000 -r 100 (1).__add__(2)
%timeit -n 100000 -r 100 add(1, 2)
%timeit -n 100000 -r 100 impl(1, 2)

149 ns ± 15.6 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)
89.2 ns ± 25.4 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)
175 ns ± 63.7 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)


In [329]:
# Function based currying techniques for 2-arg functions
f1 = curry2_with_extra_args(impl)
f2 = curry2_with_not_given(impl)
f3 = curry2_with_args(impl)
f4 = curry2_with_exception(impl)
f5 = curry2_with_exception_kw(impl)

%timeit -n 100000 f1(1, 2)
%timeit -n 100000 f2(1, 2)
%timeit -n 100000 f3(1, 2)
%timeit -n 100000 f4(1, 2)
%timeit -n 100000 f5(1, 2)

378 ns ± 47 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
302 ns ± 3.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
468 ns ± 43.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
273 ns ± 4.73 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
387 ns ± 26.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [318]:
# Generic currying
f1 = curry(2, impl)
f2 = curry_kw(2, impl)

%timeit -n 100000 f1(1, 2)
%timeit -n 100000 f2(1, 2)

536 ns ± 37.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
683 ns ± 63 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [344]:
# Other libs
import toolz, cytoolz, funcy

f1 = toolz.curry(impl)
f2 = cytoolz.curry(impl)
f3 = funcy.curry(impl, 2)
f4 = funcy.autocurry(impl, 2)

%timeit -n 100000 f1(1, 2)
%timeit -n 100000 f2(1, 2)
%timeit -n 100000 f3(1)(2)
%timeit -n 100000 f4(1, 2)

574 ns ± 49.8 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
281 ns ± 7.62 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
611 ns ± 36.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
1 µs ± 41.6 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [462]:
# Cython accelerators
f1 = cy(2, impl)
f2 = cy_fn(2, impl)

%timeit -n 100000 -r 100 f1(1, 2)
%timeit -n 100000 -r 100 f2(1, 2)

209 ns ± 11.2 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)
213 ns ± 13.6 ns per loop (mean ± std. dev. of 100 runs, 100000 loops each)


In [463]:
def foo(f):
    def foo1(x):
        g = f(x)
        def foo2(y):
            return g(y)
        return foo2
    return foo1


def foo2(f):
    x = yield
    g = f(x)
    y = yield
    return g(y)
    
foo(curry(2, add))(1)(2)


it = foo2(curry(2, add))
next(it)
it.send(1)
# it.send(2)
# next(it)

In [476]:
from functools import partial
from pprint import pprint
log = lambda x, y: print(x) or pprint(y) or print()

c_add = lambda x: lambda y: lambda z: x + y + z
c_add2 = lambda x: lambda y, z: lambda w: x + y + z + w


def uncurry(spec, func):
    if len(spec) == 1:
        return curry(spec[0], func)
    
    spec = list(reversed(spec))
    
    def curried(*args, **kwargs):
        log('starting', locals())
        n_args = len(args)
        argspec = spec.copy()
        call = func
        n = argspec.pop()
            
        if n > n_args:
            argspec.append(n - n_args)
            call = partial(call, *args, **kwargs)
        else:
            while n_args >= n:
                partial_args = args[:n]
                args = args[n:]
                call = call(*partial_args, **kwargs)
                n_args -= n
                kwargs = {}
                log('reduced', locals())
                if argspec:
                    n = argspec.pop()
                else:
                    break
            else:
                argspec.append(n)

        log('post_reduction', locals())
        
        if argspec:
            return uncurry(argspec[::-1], call)
        else:
            return call

    return curried


# (
#     uncurry([1, 1, 1], c_add)(1, 2, 3),
#     uncurry([1, 1, 1], c_add)(1)(2)(3), 
#     uncurry([1, 1, 1], c_add)(1)(2, 3),
#     uncurry([1, 1, 1], c_add)(1, 2)(3),
# )


(
#     uncurry([1, 2, 1], c_add2)(1, 2, 3, 4),
#    uncurry([1, 2, 1], c_add2)(1)(2)(3)(4), 
#     uncurry([1, 2, 1], c_add2)(1)(2, 3)(4),
     uncurry([1, 2, 1], c_add2)(1, 2)(3)(4),
)

starting
{'args': (1, 2),
 'func': <function <lambda> at 0x7f14aea0df28>,
 'kwargs': {},
 'spec': [1, 2, 1]}

reduced
{'args': (2,),
 'argspec': [1, 2],
 'call': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,
 'func': <function <lambda> at 0x7f14aea0df28>,
 'kwargs': {},
 'n': 1,
 'n_args': 1,
 'partial_args': (1,),
 'spec': [1, 2, 1]}

post_reduction
{'args': (2,),
 'argspec': [1, 2],
 'call': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,
 'func': <function <lambda> at 0x7f14aea0df28>,
 'kwargs': {},
 'n': 2,
 'n_args': 1,
 'partial_args': (1,),
 'spec': [1, 2, 1]}

starting
{'args': (3,),
 'func': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,
 'kwargs': {},
 'spec': [1, 2]}

post_reduction
{'args': (3,),
 'argspec': [1, 1],
 'call': functools.partial(<function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>, 3),
 'func': <function <lambda>.<locals>.<lambda> at 0x7f14ae5ead90>,
 'kwargs': {},
 'n': 2,
 'n_args': 1,
 'spec': [1, 2]}

starting
{'args': (4,)

(<function __main__.curry.<locals>.incomplete_factory.<locals>.<lambda>(*args)>,)

In [452]:
r =  lambda x: lambda y: x + y
cadd = uncurry([1, 1], r)
%timeit -n 10000 cadd(1, 2)
%timeit -n 10000 cadd(1)(2)
%timeit -n 10000 r(1)(2)

2.13 µs ± 366 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
3.55 µs ± 75.4 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
406 ns ± 10.7 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
