From 835e4a7fc69ecf195a14eadcdb759b6d53a0029e Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 19 Nov 2025 21:58:18 -0800 Subject: [PATCH 1/2] Convert to multisets, facts as actions, and update multiset example --- python/egglog/builtins.py | 12 ++++ python/egglog/egraph.py | 23 ++++++-- python/egglog/examples/multiset.py | 92 +++++++++++++++++++----------- 3 files changed, 89 insertions(+), 38 deletions(-) diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 9b40f21e..a9c75f69 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -39,6 +39,7 @@ "Map", "MapLike", "MultiSet", + "MultiSetLike", "Primitive", "PyObject", "Rational", @@ -583,6 +584,17 @@ def __add__(self, other: MultiSet[T]) -> MultiSet[T]: ... def map(self, f: Callable[[T], T]) -> MultiSet[T]: ... +converter( + tuple, + MultiSet, + lambda t: MultiSet[get_type_args()[0]]( # type: ignore[misc,operator] + *(convert(x, get_type_args()[0]) for x in t) + ), +) + +MultiSetLike: TypeAlias = MultiSet[T] | tuple[TO, ...] + + class Rational(BuiltinExpr, egg_sort="Rational"): @method(preserve=True) @deprecated("use .value") diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index aeb43457..ed9b69ba 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -367,6 +367,8 @@ def __replace_expr__(self, new_expr: Self) -> None: Replace the current expression with the new expression in place. """ + def __hash__(self) -> int: ... # type: ignore[empty-body] + class BuiltinExpr(BaseExpr, metaclass=_ExprMetaclass): """ @@ -1933,9 +1935,6 @@ def seq(*schedules: Schedule) -> Schedule: return Schedule(Thunk.fn(Declarations.create, *schedules), SequenceDecl(tuple(s.schedule for s in schedules))) -ActionLike: TypeAlias = Action | BaseExpr - - def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: return tuple(map(_action_like, action_likes)) @@ -1943,13 +1942,25 @@ def _action_likes(action_likes: Iterable[ActionLike]) -> tuple[Action, ...]: def _action_like(action_like: ActionLike) -> Action: if isinstance(action_like, Action): return action_like + if isinstance(action_like, Fact): + match action_like.fact: + case EqDecl(tp, left, right): + return Action( + action_like.__egg_decls__, + UnionDecl(tp, left, right), + ) + case ExprFactDecl(expr): + return Action( + action_like.__egg_decls__, + ExprActionDecl(expr), + ) + case _: + assert_never(action_like.fact) return expr_action(action_like) Command: TypeAlias = Action | RewriteOrRule -CommandLike: TypeAlias = ActionLike | RewriteOrRule - def _command_like(command_like: CommandLike) -> Command: if isinstance(command_like, RewriteOrRule): @@ -1976,6 +1987,8 @@ def _rewrite_or_rule_generator(gen: RewriteOrRuleGenerator, frame: FrameType) -> FactLike = Fact | BaseExpr +ActionLike: TypeAlias = Action | BaseExpr | Fact +CommandLike: TypeAlias = ActionLike | RewriteOrRule def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]: diff --git a/python/egglog/examples/multiset.py b/python/egglog/examples/multiset.py index 8306dadc..f9e170bc 100644 --- a/python/egglog/examples/multiset.py +++ b/python/egglog/examples/multiset.py @@ -6,55 +6,81 @@ from __future__ import annotations -from collections import Counter - from egglog import * class Math(Expr): def __init__(self, x: i64Like) -> None: ... + def __add__(self, other: MathLike) -> Math: ... + def __radd__(self, other: MathLike) -> Math: ... + def __mul__(self, other: MathLike) -> Math: ... + def __rmul__(self, other: MathLike) -> Math: ... -@function -def square(x: Math) -> Math: ... - - -@ruleset -def math_ruleset(i: i64): - yield rewrite(square(Math(i))).to(Math(i * i)) - +MathLike = Math | i64Like +converter(i64, Math, Math) -egraph = EGraph() -xs = MultiSet(Math(1), Math(2), Math(3)) -egraph.register(xs) +@function +def sum(xs: MultiSetLike[Math, MathLike]) -> Math: ... -egraph.check(xs == MultiSet(Math(1), Math(3), Math(2))) -egraph.check_fail(xs == MultiSet(Math(1), Math(1), Math(2), Math(3))) -assert Counter(egraph.extract(xs).value) == Counter({Math(1): 1, Math(2): 1, Math(3): 1}) +@function +def product(xs: MultiSetLike[Math, MathLike]) -> Math: ... -inserted = MultiSet(Math(1), Math(2), Math(3), Math(4)) -egraph.register(inserted) -egraph.check(xs.insert(Math(4)) == inserted) -egraph.check(xs.contains(Math(1))) -egraph.check(xs.not_contains(Math(4))) -assert Math(1) in xs -assert Math(4) not in xs +@function +def square(x: Math) -> Math: ... -egraph.check(xs.remove(Math(1)) == MultiSet(Math(2), Math(3))) -assert egraph.extract(xs.length()).value == 3 -assert len(xs) == 3 +x = constant("x", Math) +expr1 = 2 * (x + 3) +expr2 = 6 + 2 * x -egraph.check(MultiSet(Math(1), Math(1)).length() == i64(2)) -egraph.check(MultiSet(Math(1)).pick() == Math(1)) +@ruleset +def math_ruleset(a: Math, b: Math, c: Math, i: i64, j: i64, xs: MultiSet[Math], ys: MultiSet[Math], zs: MultiSet[Math]): + yield rewrite(a + b).to(sum(MultiSet(a, b))) + yield rewrite(a * b).to(product(MultiSet(a, b))) + # 0 or 1 elements sums/products also can be extracted back to numbers + yield rule(a == sum(xs), xs.length() == i64(1)).then(a == xs.pick()) + yield rule(a == product(xs), xs.length() == i64(1)).then(a == xs.pick()) + yield rewrite(sum(MultiSet[Math]())).to(Math(0)) + yield rewrite(product(MultiSet[Math]())).to(Math(1)) + # distributive rule (a * (b + c) = a*b + a*c) + yield rule( + b == product(ys), + a == sum(xs), + ys.contains(a), + ys.length() > 1, + zs == ys.remove(a), + ).then( + b == sum(xs.map(lambda x: product(zs.insert(x)))), + ) + # constants + yield rule( + a == sum(xs), + b == Math(i), + xs.contains(b), + ys == xs.remove(b), + c == Math(j), + ys.contains(c), + ).then( + a == sum(ys.remove(c).insert(Math(i + j))), + ) + yield rule( + a == product(xs), + b == Math(i), + xs.contains(b), + ys == xs.remove(b), + c == Math(j), + ys.contains(c), + ).then( + a == product(ys.remove(c).insert(Math(i * j))), + ) -mapped = xs.map(square) -egraph.register(mapped) -egraph.run(math_ruleset) -egraph.check(mapped == MultiSet(Math(1), Math(4), Math(9))) -egraph.check(xs + xs == MultiSet(Math(1), Math(2), Math(3), Math(1), Math(2), Math(3))) +egraph = EGraph() +egraph.register(expr1, expr2) +egraph.run(math_ruleset.saturate()) +egraph.check(expr1 == expr2) From 0a88d8535f89e889db3107b12adfa4c981116263 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 19 Nov 2025 21:58:43 -0800 Subject: [PATCH 2/2] update cargo version --- Cargo.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index d4ca4592..90841891 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -468,7 +468,7 @@ dependencies = [ [[package]] name = "egglog_python" -version = "11.4.0" +version = "12.0.0" dependencies = [ "base64", "egglog",