Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ _This project uses semantic versioning_
- Fix pretty printing of lambda functions
- Add support for subsuming rewrite generated by default function and method definitions
- Add better error message when using @function in class (thanks @shinawy)
- Add error method if `@method` decorator is in wrong place
- Subsumes lambda functions after replacing
- Add working loopnest test

## 8.0.1 (2024-10-24)

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ filterwarnings = [
"error",
"ignore::numba.core.errors.NumbaPerformanceWarning",
"ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
# https://github.com/manzt/anywidget/blob/d38bb3f5f9cfc7e49e2ff1aa1ba994d66327cb02/pyproject.toml#L120
"ignore:Deprecated in traitlets 4.1, use the instance .metadata:DeprecationWarning",
]

[tool.coverage.report]
Expand Down
3 changes: 2 additions & 1 deletion python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ def _convert_function(a: FunctionType) -> UnstableFn:
transformed_fn = functionalize(a, value_to_annotation)
assert isinstance(transformed_fn, partial)
return UnstableFn(
function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
*transformed_fn.args,
)


Expand Down
8 changes: 8 additions & 0 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,10 @@ def _generate_class_decls( # noqa: C901,PLR0912
fn = fn.fget
case _:
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
if isinstance(fn, _WrappedMethod):
msg = f"{cls_name}.{method_name} Add the @method(...) decorator above @classmethod or @property"

raise ValueError(msg) # noqa: TRY004
special_function_name: SpecialFunctions | None = (
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
)
Expand Down Expand Up @@ -1373,10 +1377,14 @@ def saturate(
"""
Saturate the egraph, running the given schedule until the egraph is saturated.
It serializes the egraph at each step and returns a widget to visualize the egraph.

If an `expr` is passed, it's also extracted after each run and printed
"""
from .visualizer_widget import VisualizerWidget

def to_json() -> str:
if expr is not None:
print(self.extract(expr), "\n")
return self._serialize(**kwargs).to_json()

egraphs = [to_json()]
Expand Down
13 changes: 8 additions & 5 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def single(cls, i: Int) -> TupleInt:
return TupleInt(Int(1), lambda _: i)

@classmethod
def range(cls, stop: Int) -> TupleInt:
def range(cls, stop: IntLike) -> TupleInt:
return TupleInt(stop, lambda i: i)

@classmethod
Expand Down Expand Up @@ -346,7 +346,6 @@ def _tuple_int(
ti: TupleInt,
ti2: TupleInt,
):
remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f)
return [
rewrite(TupleInt(i, idx_fn).length()).to(i),
rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
Expand All @@ -367,7 +366,11 @@ def _tuple_int(
# filter TODO: could be written as fold w/ generic types
rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)),
rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to(
TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining),
TupleInt.if_(
filter_f(value := idx_fn(Int(k - 1))),
(remaining := TupleInt(k - 1, idx_fn).filter(filter_f)) + TupleInt.single(value),
remaining,
),
ne(k).to(i64(0)),
),
# Empty
Expand All @@ -386,13 +389,13 @@ def var(cls, name: StringLike) -> TupleTupleInt: ...

def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

@classmethod
@method(subsume=True)
@classmethod
def single(cls, i: TupleInt) -> TupleTupleInt:
return TupleTupleInt(Int(1), lambda _: i)

@classmethod
@method(subsume=True)
@classmethod
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
return TupleInt(vec.length(), partial(index_vec_int, vec))

Expand Down
75 changes: 2 additions & 73 deletions python/egglog/exp/array_api_loopnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from egglog import *
from egglog.exp.array_api import *

__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]


class ShapeAPI(Expr):
def __init__(self, dims: TupleIntLike) -> None: ...
Expand Down Expand Up @@ -105,76 +107,3 @@ def _loopnest_api_ruleset(
yield rewrite(lna.indices, subsume=True).to(
tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range))
)


@function(ruleset=array_api_ruleset, subsume=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
# peel off the outer shape for result array
outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
# get only the inner shape for reduction
reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

return NDArray(
outshape,
X.dtype,
lambda k: sqrt(
LoopNestAPI.from_tuple(reduce_axis)
.unwrap()
.fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
).to_value(),
)


# %%
# egraph = EGraph(save_egglog_string=True)

# egraph.register(val.shape)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract_multiple(val.shape, 10)

# %%

X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
val = linalg_norm(X, (0, 1))
egraph = EGraph()
x = egraph.let("x", val.shape[2])
# egraph.display(n_inline_leaves=0)
# egraph.extract(x)
# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract(x)
# egraph.display()


# %%

# x = xs[-2]
# # %%
# decls = x.__egg_decls__
# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])

# # %%
# # x.__egg_typed_expr__.expr.args[1].expr.args[1] # %%

# # %%
# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]))


# from egglog import pretty

# decl = (
# x.__egg_typed_expr__.expr.args[1]
# .expr.args[2]
# .expr.args[0]
# .expr.args[1]
# .expr.call.args[0]
# .expr.call.args[0]
# .expr.call.args[0]
# )

# # pprint.pprint(decl)

# print(pretty.pretty_decl(decls, decl.expr))

# # %%
42 changes: 39 additions & 3 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from sklearn import config_context, datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from egglog.egraph import set_current_ruleset
from egglog.exp.array_api import *
from egglog.exp.array_api_jit import jit
from egglog.exp.array_api_loopnest import *
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *

Expand Down Expand Up @@ -68,6 +70,41 @@ def test_reshape_vec_noop():
egraph.check(eq(res).to(x))


def test_filter():
with set_current_ruleset(array_api_ruleset):
x = TupleInt.range(5).filter(lambda i: i < 2).length()
check_eq(x, Int(2), array_api_schedule)


@function(ruleset=array_api_ruleset, subsume=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
# peel off the outer shape for result array
outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
# get only the inner shape for reduction
reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

return NDArray(
outshape,
X.dtype,
lambda k: sqrt(
LoopNestAPI.from_tuple(reduce_axis)
.unwrap()
.fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
).to_value(),
)


class TestLoopNest:
def test_shape(self):
X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
val = linalg_norm(X, (0, 1))

check_eq(val.shape.length(), Int(2), array_api_schedule)
check_eq(val.shape[0], Int(3), array_api_schedule)
check_eq(val.shape[1], Int(4), array_api_schedule)


# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
# The next step will load the old snapshot and run their test on it.

Expand All @@ -80,7 +117,6 @@ def run_lda(x, y):

iris = datasets.load_iris()
X_np, y_np = (iris.data, iris.target)
res_np = run_lda(X_np, y_np)


def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
Expand Down Expand Up @@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
optimized_expr = simplify_lda(egraph, expr)
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
py_object = benchmark(load_source, fn_program, egraph)
assert np.allclose(py_object(X_np, y_np), res_np)
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
assert egraph.eval(fn_program.statements) == snapshot_py

@pytest.mark.parametrize(
Expand All @@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
)
def test_execution(self, fn, benchmark):
# warmup once for numba
assert np.allclose(res_np, fn(X_np, y_np))
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))
benchmark(fn, X_np, y_np)


Expand Down
Loading