diff --git a/docs/changelog.md b/docs/changelog.md index 8b938fdb..377c79cd 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -7,6 +7,9 @@ _This project uses semantic versioning_ - Fix pretty printing of lambda functions - Add support for subsuming rewrite generated by default function and method definitions - Add better error message when using @function in class (thanks @shinawy) +- Add error method if `@method` decorator is in wrong place +- Subsumes lambda functions after replacing +- Add working loopnest test ## 8.0.1 (2024-10-24) diff --git a/pyproject.toml b/pyproject.toml index d083677e..69c6c39e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -241,6 +241,8 @@ filterwarnings = [ "error", "ignore::numba.core.errors.NumbaPerformanceWarning", "ignore::pytest_benchmark.logger.PytestBenchmarkWarning", + # https://github.com/manzt/anywidget/blob/d38bb3f5f9cfc7e49e2ff1aa1ba994d66327cb02/pyproject.toml#L120 + "ignore:Deprecated in traitlets 4.1, use the instance .metadata:DeprecationWarning", ] [tool.coverage.report] diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index bbb687bd..41d97eb2 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -532,7 +532,8 @@ def _convert_function(a: FunctionType) -> UnstableFn: transformed_fn = functionalize(a, value_to_annotation) assert isinstance(transformed_fn, partial) return UnstableFn( - function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args + function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func), + *transformed_fn.args, ) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index c237d36e..c0532fd3 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -570,6 +570,10 @@ def _generate_class_decls( # noqa: C901,PLR0912 fn = fn.fget case _: ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name) + if isinstance(fn, _WrappedMethod): + msg = f"{cls_name}.{method_name} Add the @method(...) decorator above @classmethod or @property" + + raise ValueError(msg) # noqa: TRY004 special_function_name: SpecialFunctions | None = ( "fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None ) @@ -1373,10 +1377,14 @@ def saturate( """ Saturate the egraph, running the given schedule until the egraph is saturated. It serializes the egraph at each step and returns a widget to visualize the egraph. + + If an `expr` is passed, it's also extracted after each run and printed """ from .visualizer_widget import VisualizerWidget def to_json() -> str: + if expr is not None: + print(self.extract(expr), "\n") return self._serialize(**kwargs).to_json() egraphs = [to_json()] diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 47b9ecc7..dec32941 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -278,7 +278,7 @@ def single(cls, i: Int) -> TupleInt: return TupleInt(Int(1), lambda _: i) @classmethod - def range(cls, stop: Int) -> TupleInt: + def range(cls, stop: IntLike) -> TupleInt: return TupleInt(stop, lambda i: i) @classmethod @@ -346,7 +346,6 @@ def _tuple_int( ti: TupleInt, ti2: TupleInt, ): - remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f) return [ rewrite(TupleInt(i, idx_fn).length()).to(i), rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)), @@ -367,7 +366,11 @@ def _tuple_int( # filter TODO: could be written as fold w/ generic types rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)), rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to( - TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining), + TupleInt.if_( + filter_f(value := idx_fn(Int(k - 1))), + (remaining := TupleInt(k - 1, idx_fn).filter(filter_f)) + TupleInt.single(value), + remaining, + ), ne(k).to(i64(0)), ), # Empty @@ -386,13 +389,13 @@ def var(cls, name: StringLike) -> TupleTupleInt: ... def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ... - @classmethod @method(subsume=True) + @classmethod def single(cls, i: TupleInt) -> TupleTupleInt: return TupleTupleInt(Int(1), lambda _: i) - @classmethod @method(subsume=True) + @classmethod def from_vec(cls, vec: Vec[Int]) -> TupleInt: return TupleInt(vec.length(), partial(index_vec_int, vec)) diff --git a/python/egglog/exp/array_api_loopnest.py b/python/egglog/exp/array_api_loopnest.py index c12c67e5..c35fda51 100644 --- a/python/egglog/exp/array_api_loopnest.py +++ b/python/egglog/exp/array_api_loopnest.py @@ -12,6 +12,8 @@ from egglog import * from egglog.exp.array_api import * +__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"] + class ShapeAPI(Expr): def __init__(self, dims: TupleIntLike) -> None: ... @@ -105,76 +107,3 @@ def _loopnest_api_ruleset( yield rewrite(lna.indices, subsume=True).to( tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)) ) - - -@function(ruleset=array_api_ruleset, subsume=True) -def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: - # peel off the outer shape for result array - outshape = ShapeAPI(X.shape).deselect(axis).to_tuple() - # get only the inner shape for reduction - reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple() - - return NDArray( - outshape, - X.dtype, - lambda k: sqrt( - LoopNestAPI.from_tuple(reduce_axis) - .unwrap() - .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0) - ).to_value(), - ) - - -# %% -# egraph = EGraph(save_egglog_string=True) - -# egraph.register(val.shape) -# egraph.run(array_api_ruleset.saturate()) -# egraph.extract_multiple(val.shape, 10) - -# %% - -X = NDArray.var("X") -assume_shape(X, (3, 2, 3, 4)) -val = linalg_norm(X, (0, 1)) -egraph = EGraph() -x = egraph.let("x", val.shape[2]) -# egraph.display(n_inline_leaves=0) -# egraph.extract(x) -# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0) -# egraph.run(array_api_ruleset.saturate()) -# egraph.extract(x) -# egraph.display() - - -# %% - -# x = xs[-2] -# # %% -# decls = x.__egg_decls__ -# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]) - -# # %% -# # x.__egg_typed_expr__.expr.args[1].expr.args[1] # %% - -# # %% -# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])) - - -# from egglog import pretty - -# decl = ( -# x.__egg_typed_expr__.expr.args[1] -# .expr.args[2] -# .expr.args[0] -# .expr.args[1] -# .expr.call.args[0] -# .expr.call.args[0] -# .expr.call.args[0] -# ) - -# # pprint.pprint(decl) - -# print(pretty.pretty_decl(decls, decl.expr)) - -# # %% diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index b6daa22f..b48fb7bd 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -8,8 +8,10 @@ from sklearn import config_context, datasets from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from egglog.egraph import set_current_ruleset from egglog.exp.array_api import * from egglog.exp.array_api_jit import jit +from egglog.exp.array_api_loopnest import * from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import * @@ -68,6 +70,41 @@ def test_reshape_vec_noop(): egraph.check(eq(res).to(x)) +def test_filter(): + with set_current_ruleset(array_api_ruleset): + x = TupleInt.range(5).filter(lambda i: i < 2).length() + check_eq(x, Int(2), array_api_schedule) + + +@function(ruleset=array_api_ruleset, subsume=True) +def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: + # peel off the outer shape for result array + outshape = ShapeAPI(X.shape).deselect(axis).to_tuple() + # get only the inner shape for reduction + reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple() + + return NDArray( + outshape, + X.dtype, + lambda k: sqrt( + LoopNestAPI.from_tuple(reduce_axis) + .unwrap() + .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0) + ).to_value(), + ) + + +class TestLoopNest: + def test_shape(self): + X = NDArray.var("X") + assume_shape(X, (3, 2, 3, 4)) + val = linalg_norm(X, (0, 1)) + + check_eq(val.shape.length(), Int(2), array_api_schedule) + check_eq(val.shape[0], Int(3), array_api_schedule) + check_eq(val.shape[1], Int(4), array_api_schedule) + + # This test happens in different steps. Each will be benchmarked and saved as a snapshot. # The next step will load the old snapshot and run their test on it. @@ -80,7 +117,6 @@ def run_lda(x, y): iris = datasets.load_iris() X_np, y_np = (iris.data, iris.target) -res_np = run_lda(X_np, y_np) def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: @@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark): optimized_expr = simplify_lda(egraph, expr) fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y")) py_object = benchmark(load_source, fn_program, egraph) - assert np.allclose(py_object(X_np, y_np), res_np) + assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np)) assert egraph.eval(fn_program.statements) == snapshot_py @pytest.mark.parametrize( @@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark): ) def test_execution(self, fn, benchmark): # warmup once for numba - assert np.allclose(res_np, fn(X_np, y_np)) + assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np)) benchmark(fn, X_np, y_np)