Skip to content

Commit

Permalink
Merge 4833b98 into 34cd8c2
Browse files Browse the repository at this point in the history
  • Loading branch information
jakeret committed Dec 16, 2014
2 parents 34cd8c2 + 4833b98 commit 13aeb21
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 3 deletions.
4 changes: 2 additions & 2 deletions hope/_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def get_dtype(self, args, keywords):
self.__check(args, keywords)
if self.name in ["empty", "zeros", "ones"]:
return getattr(np, keywords["dtype"].name) if len(keywords) == 1 else np.float64
elif self.name in "interp" or self.name in NPY_UNARY_FUNCTIONS:
elif self.name in ["interp", "sign"] or self.name in NPY_UNARY_FUNCTIONS:
return args[0].dtype
elif self.name in NPY_CAST_FUNCTIONS:
return NPY_CAST_FUNCTIONS[self.name]
Expand All @@ -315,7 +315,7 @@ def get_shape(self, args, keywords):
self.__check(args, keywords)
if self.name in ["empty", "zeros", "ones"]:
return [(None, arg) for arg in args[0]] if isinstance(args[0], list) else [(None, args[0])]
elif self.name == "interp" or self.name in NPY_UNARY_FUNCTIONS or self.name in NPY_CAST_FUNCTIONS:
elif self.name in ["interp", "sign"] or self.name in NPY_UNARY_FUNCTIONS or self.name in NPY_CAST_FUNCTIONS:
return args[0].shape


Expand Down
2 changes: 1 addition & 1 deletion hope/_const.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@
NPY_UNARY_FUNCTIONS["abs"] = "abs"
# NPY_UNARY_FUNCTIONS["absolute"] = "absolute"
NPY_UNARY_FUNCTIONS["fabs"] = "std::fabs"
# NPY_UNARY_FUNCTIONS["sign"] = "sign"
NPY_UNARY_FUNCTIONS["sign"] = "sign"

NPY_BINARY_FUNCTIONS = {}
# NPY_BINARY_FUNCTIONS["hypot"] = "hypot"
Expand Down
3 changes: 3 additions & 0 deletions hope/_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,9 @@ def visit_Call(self, node):
if "right" in node.keywords:
ret = "{0} > {1} ? {2} : ({3})".format(self.visit(node.args[0]), right, self.visit(node.keywords["right"]), ret)
return "({0})".format(ret)
elif isinstance(node.name, NumpyAttr) and node.name.name == "sign":
self.library["native_sign"] = LIBRARY_NATIVE_SIGN
return "native_sign({0})".format(self.visit(node.args[0]))
elif isinstance(node.name, NumpyAttr) and node.name.name in NPY_UNARY_FUNCTIONS:
return "{0}({1})".format(NPY_UNARY_FUNCTIONS[node.name.name], self.visit(node.args[0]))
elif isinstance(node.name, NumpyAttr) and node.name.name in NPY_CAST_FUNCTIONS:
Expand Down
6 changes: 6 additions & 0 deletions hope/_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,12 @@
}
"""

LIBRARY_NATIVE_SIGN = """
template<typename T> inline T native_sign(T arg) {
return T(T(0) < arg) - T(arg < T(0));
}
"""

LIBRARY_NATIVE_RANGECHECK = """
#include <string>
inline int native_rangecheck(int x, int u, int l, std::string idxname, std::string varname) {
Expand Down
4 changes: 4 additions & 0 deletions test/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ def test_func_cos(a, b, c): return np.cos(a)
@make_test
def test_func_tan(a, b, c): return np.tan(a)

@pytest.mark.parametrize("dtype,shape", itertools.product([np.float32, np.float64, float], shapes))
@make_test
def test_func_sign(a, b, c): return np.sign(a)

@pytest.mark.parametrize("dtype,shape", itertools.product([np.float32, np.float64, float], shapes))
def test_func_arcsin(dtype, shape):
def fkt(a):
Expand Down

0 comments on commit 13aeb21

Please sign in to comment.