diff --git a/currying/__init__.py b/currying/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/currying/__init__.py @@ -0,0 +1 @@ + diff --git a/currying/src/__init__.py b/currying/src/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/currying/src/__init__.py @@ -0,0 +1 @@ + diff --git a/currying/src/lib.py b/currying/src/lib.py new file mode 100644 index 0000000..b28e50a --- /dev/null +++ b/currying/src/lib.py @@ -0,0 +1,36 @@ +# Curry a given function. +# If the definition uses *args, specify the amount given in a specific case +def curry(f: callable, arity=None): + if arity is None: + arity = f.__code__.co_argcount + elif arity < 0: + raise Exception("negative arity") + elif arity < f.__code__.co_argcount: + raise Exception("specified arity is lesser than required") + + if arity == 0: + return lambda: f() + + def inner(*args): + if len(args) >= arity: + return f(*args) + + return lambda arg: inner(*args, arg) + + return lambda arg: inner(arg) + + +# Uncurry a curried function +def uncurry(f: callable): + def inner(*args): + value = f + + for arg in args: + try: + value = value(arg) + except Exception: + raise Exception("incorrect amount of arguments provided") + + return value + + return inner diff --git a/currying/tests/__init__.py b/currying/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/currying/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/currying/tests/test.py b/currying/tests/test.py new file mode 100644 index 0000000..97c1a1d --- /dev/null +++ b/currying/tests/test.py @@ -0,0 +1,45 @@ +import pytest + +from ..src.lib import curry, uncurry + + +def add_args(*args): + return sum(args) + + +def add2(a, b): + return a + b + + +def test_specified(): + assert curry(add_args, 3)(1)(2)(3) == 6 + + +def test_negative(): + with pytest.raises(Exception) as e: + curry(add2, -1) + assert e == "negative arity" + + +def test_less(): + with pytest.raises(Exception) as e: + curry(add2, 1) + assert e == "specified arity is lesser than required" + + +def test_unspecified(): + assert curry(add2)(1)(2) == 3 + + +def test_zero(): + assert curry(add_args, 0)() == 0 + + +def test_uncurry(): + assert uncurry(curry(add2))(1, 2) == 3 + + +def test_uncurry_incorrect(): + with pytest.raises(Exception) as e: + uncurry(curry(add2))(1, 2, 3) + assert e == "incorrect amount of arguments provided"