diff --git a/Cargo.lock b/Cargo.lock index d00f21b5..75a0acb0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -357,7 +357,7 @@ dependencies = [ [[package]] name = "egglog_python" -version = "11.1.0" +version = "11.2.0" dependencies = [ "core-relations 1.0.0 (git+https://github.com/egraphs-good/egglog.git?branch=main)", "egglog", diff --git a/docs/changelog.md b/docs/changelog.md index bcce6577..d1de268b 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,7 @@ _This project uses semantic versioning_ ## UNRELEASED +- Add `back_off` scheduler [#350](https://github.com/egraphs-good/egglog-python/pull/350) ## 11.2.0 (2025-09-03) - Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343) diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 5162233a..c169aabb 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -459,6 +459,62 @@ step_egraph.check(left(i64(10)), right(i64(9))) step_egraph.check_fail(left(i64(11)), right(i64(10))) ``` +#### Custom Schedulers + +Custom backoff scheduler from egglog-experimental is supported. Create a custom backoff scheduler with `bo = BackOff(match_limit: None | int=None, ban_length: None | int=None)`, then run using `run(ruleset, *facts, scheduler=bo)`: + +- `match_limit`: per-rule threshold of matches allowed in a single scheduler iteration. If a rule produces more matches than the threshold, that rule is temporarily banned. +- `ban_length`: initial ban duration (in scheduler iterations). While banned, that rule is skipped. +- Exponential backoff: each time a rule is banned, both the threshold and ban length double for that rule (threshold = match_limit << times_banned; ban = ban_length << times_banned). +- Fast-forwarding: when any rule is banned, the scheduler fast-forwards by the minimum remaining ban to unban at least one rule before checking for termination again. +- Defaults: match_limit defaults to 1000; ban_length defaults to 5. + +For example, this egglog code: + +``` +(run-schedule + (let-scheduler bo (back-off :match-limit 10)) + (repeat 10 (run-with bo step_right))) +``` + +Is translated as: + +```{code-cell} python +step_egraph.run( + run(step_right, scheduler=back_off(match_limit=10)) * 10 +) +``` + +By default the scheduler will be created before any other schedules are run. +To control where is instantiated explicitly, use `bo.scope()`, where it will be created before everything in ``. + +So the previous is equivalent to: + +```{code-cell} python +bo = back_off(match_limit=10) +step_egraph.run( + bo.scope(run(step_right, scheduler=bo) * 10) +) +``` + +If you wanted to create the scheduler inside the repeated schedule, you can do: + +```{code-cell} python +bo = back_off(match_limit=10) +step_egraph.run( + bo.scope(run(step_right, scheduler=bo)) * 10 +) +``` + +This would be equivalent to this egglog: + +``` +(run-schedule + (repeat 10 + (let-scheduler bo (back-off :match-limit 10)) + (run-with bo step_right))) +``` + ## Check The `(check ...)` command to verify that some facts are true, can be translated to Python with the `egraph.check` function: diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 1e2f54cd..c657930f 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -9,6 +9,7 @@ from dataclasses import dataclass, field from functools import cached_property from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable +from uuid import UUID from weakref import WeakValueDictionary from typing_extensions import Self, assert_never @@ -20,6 +21,7 @@ __all__ = [ "ActionCommandDecl", "ActionDecl", + "BackOffDecl", "BiRewriteDecl", "CallDecl", "CallableDecl", @@ -52,6 +54,7 @@ "JustTypeRef", "LetDecl", "LetRefDecl", + "LetSchedulerDecl", "LitDecl", "LitType", "MethodRef", @@ -790,9 +793,24 @@ class SequenceDecl: class RunDecl: ruleset: str until: tuple[FactDecl, ...] | None + scheduler: BackOffDecl | None = None -ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl +@dataclass(frozen=True) +class LetSchedulerDecl: + scheduler: BackOffDecl + inner: ScheduleDecl + + +ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl + + +@dataclass(frozen=True) +class BackOffDecl: + id: UUID + match_limit: int | None + ban_length: int | None + ## # Facts diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 00f6eb59..d0f451cf 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -23,6 +23,7 @@ get_type_hints, overload, ) +from uuid import uuid4 from warnings import warn import graphviz @@ -45,6 +46,7 @@ __all__ = [ "Action", + "BackOff", "BaseExpr", "BuiltinExpr", "Command", @@ -63,6 +65,7 @@ "_RewriteBuilder", "_SetBuilder", "_UnionBuilder", + "back_off", "birewrite", "check", "check_eq", @@ -905,8 +908,8 @@ def run( def _run_schedule(self, schedule: Schedule) -> bindings.RunReport: self._add_decls(schedule) - egg_schedule = self._state.schedule_to_egg(schedule.schedule) - (command_output,) = self._egraph.run_program(bindings.RunSchedule(egg_schedule)) + cmd = self._state.run_schedule_to_egg(schedule.schedule) + (command_output,) = self._egraph.run_program(cmd) assert isinstance(command_output, bindings.RunScheduleOutput) return command_output.report @@ -1786,17 +1789,51 @@ def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr: return expr -def run(ruleset: Ruleset | None = None, *until: FactLike) -> Schedule: +def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | None = None) -> Schedule: """ Create a run configuration. """ facts = _fact_likes(until) return Schedule( Thunk.fn(Declarations.create, ruleset, *facts), - RunDecl(ruleset.__egg_name__ if ruleset else "", tuple(f.fact for f in facts) or None), + RunDecl( + ruleset.__egg_name__ if ruleset else "", + tuple(f.fact for f in facts) or None, + scheduler.scheduler if scheduler else None, + ), ) +def back_off(match_limit: None | int = None, ban_length: None | int = None) -> BackOff: + """ + Create a backoff scheduler configuration. + + ```python + schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10 + ``` + This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times, using a backoff scheduler. + """ + return BackOff(BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length)) + + +@dataclass(frozen=True) +class BackOff: + scheduler: BackOffDecl + + def scope(self, schedule: Schedule) -> Schedule: + """ + Defines the scheduler to be created directly before the inner schedule, instead of the default which is at the + most outer scope. + """ + return Schedule(schedule.__egg_decls_thunk__, LetSchedulerDecl(self.scheduler, schedule.schedule)) + + def __str__(self) -> str: + return pretty_decl(Declarations(), self.scheduler) + + def __repr__(self) -> str: + return str(self) + + def seq(*schedules: Schedule) -> Schedule: """ Run a sequence of schedules. diff --git a/python/egglog/egraph_state.py b/python/egglog/egraph_state.py index 52b0e291..acca0778 100644 --- a/python/egglog/egraph_state.py +++ b/python/egglog/egraph_state.py @@ -8,6 +8,7 @@ from collections import defaultdict from dataclasses import dataclass, field, replace from typing import TYPE_CHECKING, Literal, overload +from uuid import UUID from typing_extensions import assert_never @@ -89,18 +90,140 @@ def copy(self) -> EGraphState: cost_callables=self.cost_callables.copy(), ) - def schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: + def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command: + """ + Turn a run schedule into an egg command. + + If there exists any custom schedulers in the schedule, it will be turned into a custom extract command otherwise + will be a normal run command. + """ + processed_schedule = self._process_schedule(schedule) + if processed_schedule is None: + return bindings.RunSchedule(self._schedule_to_egg(schedule)) + top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, []) + if len(top_level_schedules) == 1: + schedule_expr = top_level_schedules[0] + else: + schedule_expr = bindings.Call(span(), "seq", top_level_schedules) + return bindings.UserDefined(span(), "run-schedule", [schedule_expr]) + + def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None: + """ + Processes a schedule to determine if it contains any custom schedulers. + + If it does, it returns a new schedule with all the required let bindings added to the other scope. + If not, returns none. + + Also processes all rulesets in the schedule to make sure they are registered. + """ + bound_schedulers: list[UUID] = [] + unbound_schedulers: list[BackOffDecl] = [] + + def helper(s: ScheduleDecl) -> None: + match s: + case LetSchedulerDecl(scheduler, inner): + bound_schedulers.append(scheduler.id) + return helper(inner) + case RunDecl(ruleset_name, _, scheduler): + self.ruleset_to_egg(ruleset_name) + if scheduler and scheduler.id not in bound_schedulers: + unbound_schedulers.append(scheduler) + case SaturateDecl(inner) | RepeatDecl(inner, _): + return helper(inner) + case SequenceDecl(schedules): + for sc in schedules: + helper(sc) + case _: + assert_never(s) + return None + + helper(schedule) + if not bound_schedulers and not unbound_schedulers: + return None + for scheduler in unbound_schedulers: + schedule = LetSchedulerDecl(scheduler, schedule) + return schedule + + def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule: + msg = "Should never reach this, let schedulers should be handled by custom scheduler" match schedule: case SaturateDecl(schedule): - return bindings.Saturate(span(), self.schedule_to_egg(schedule)) + return bindings.Saturate(span(), self._schedule_to_egg(schedule)) case RepeatDecl(schedule, times): - return bindings.Repeat(span(), times, self.schedule_to_egg(schedule)) + return bindings.Repeat(span(), times, self._schedule_to_egg(schedule)) case SequenceDecl(schedules): - return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules]) - case RunDecl(ruleset_name, until): - self.ruleset_to_egg(ruleset_name) + return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules]) + case RunDecl(ruleset_name, until, scheduler): + if scheduler is not None: + raise ValueError(msg) config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until))) return bindings.Run(span(), config) + case LetSchedulerDecl(): + raise ValueError(msg) + case _: + assert_never(schedule) + + def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912 + self, schedule: ScheduleDecl, bound_schedulers: list[UUID] + ) -> list[bindings._Expr]: + """ + Turns a scheduler into an egg expression, to be used with a custom extract command. + + The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`. + """ + match schedule: + case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner): + name = f"_scheduler_{len(bound_schedulers)}" + bound_schedulers.append(id) + args: list[bindings._Expr] = [] + if match_limit is not None: + args.append(bindings.Var(span(), ":match-limit")) + args.append(bindings.Lit(span(), bindings.Int(match_limit))) + if ban_length is not None: + args.append(bindings.Var(span(), ":ban-length")) + args.append(bindings.Lit(span(), bindings.Int(ban_length))) + back_off_decl = bindings.Call(span(), "back-off", args) + let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl]) + return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)] + case RunDecl(ruleset_name, until, scheduler): + args = [bindings.Var(span(), ruleset_name)] + if scheduler: + name = "run-with" + scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}" + args.insert(0, bindings.Var(span(), scheduler_name)) + else: + name = "run" + if until: + if len(until) > 1: + msg = "Can only have one until fact with custom scheduler" + raise ValueError(msg) + args.append(bindings.Var(span(), ":until")) + fact_egg = self.fact_to_egg(until[0]) + if isinstance(fact_egg, bindings.Eq): + msg = "Cannot use equality fact with custom scheduler" + raise ValueError(msg) + args.append(fact_egg.expr) + return [bindings.Call(span(), name, args)] + case SaturateDecl(inner): + return [ + bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers)) + ] + case RepeatDecl(inner, times): + return [ + bindings.Call( + span(), + "repeat", + [ + bindings.Lit(span(), bindings.Int(times)), + *self._schedule_with_scheduler_to_egg(inner, bound_schedulers), + ], + ) + ] + case SequenceDecl(schedules): + res = [] + for s in schedules: + res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers)) + return res case _: assert_never(schedule) diff --git a/python/egglog/examples/jointree.py b/python/egglog/examples/jointree.py index e596f95e..fe4683f2 100644 --- a/python/egglog/examples/jointree.py +++ b/python/egglog/examples/jointree.py @@ -62,6 +62,3 @@ def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize: egraph.run(1000) print(egraph.extract(query)) print(egraph.extract(query.size)) - - -egraph diff --git a/python/egglog/pretty.py b/python/egglog/pretty.py index 8c8c7130..62531acf 100644 --- a/python/egglog/pretty.py +++ b/python/egglog/pretty.py @@ -67,7 +67,9 @@ "__invert__": "~", } -AllDecls: TypeAlias = RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl +AllDecls: TypeAlias = ( + RulesetDecl | CombinedRulesetDecl | CommandDecl | ActionDecl | FactDecl | ExprDecl | ScheduleDecl | BackOffDecl +) def pretty_decl( @@ -188,10 +190,12 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 case _: for e in exprs: self(e.expr) - case RunDecl(_, until): + case RunDecl(_, until, scheduler): if until: for f in until: self(f) + if scheduler: + self(scheduler) case PartialCallDecl(c): self(c) case CombinedRulesetDecl(_): @@ -201,6 +205,12 @@ def __call__(self, decl: AllDecls, toplevel: bool = False) -> None: # noqa: C90 case SetCostDecl(_, e, c): self(e) self(c) + case BackOffDecl(): + pass + case LetSchedulerDecl(scheduler, schedule): + self(scheduler) + self(schedule) + case _: assert_never(decl) @@ -238,7 +248,11 @@ def __call__( # it would take up is > than some constant (~ line length). line_diff: int = len(expr) - LINE_DIFFERENCE n_parents = self.parents[decl] - if n_parents > 1 and n_parents * line_diff > MAX_LINE_LENGTH: + if n_parents > 1 and ( + n_parents * line_diff > MAX_LINE_LENGTH + # Schedulers with multiple parents need to be the same object, b/c are created with hidden UUIDs + or tp_name == "scheduler" + ): self.names[decl] = expr_name = self._name_expr(tp_name, expr, copy_identifier=False) return expr_name return expr @@ -318,16 +332,27 @@ def uncached(self, decl: AllDecls, *, unwrap_lit: bool, parens: bool, ruleset_na return f"{self(schedules[0], parens=True)} + {self(schedules[1], parens=True)}", "schedule" args = ", ".join(map(self, schedules)) return f"seq({args})", "schedule" - case RunDecl(ruleset_name, until): + case LetSchedulerDecl(scheduler, schedule): + return f"{self(scheduler, parens=True)}.scope({self(schedule, parens=True)})", "schedule" + case RunDecl(ruleset_name, until, scheduler): ruleset = self.decls._rulesets[ruleset_name] ruleset_str = self(ruleset, ruleset_name=ruleset_name) - if not until: + if not until and not scheduler: return ruleset_str, "schedule" - args = ", ".join(map(self, until)) - return f"run({ruleset_str}, {args})", "schedule" + arg_lst = list(map(self, until or [])) + if scheduler: + arg_lst.append(f"scheduler={self(scheduler)}") + return f"run({ruleset_str}, {', '.join(arg_lst)})", "schedule" case DefaultRewriteDecl(): msg = "default rewrites should not be pretty printed" raise TypeError(msg) + case BackOffDecl(_, match_limit, ban_length): + list_args: list[str] = [] + if match_limit is not None: + list_args.append(f"match_limit={match_limit}") + if ban_length is not None: + list_args.append(f"ban_length={ban_length}") + return f"back_off({', '.join(list_args)})", "scheduler" assert_never(decl) def _call( diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 28cf61c9..c6e8731b 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -1093,3 +1093,94 @@ def __sub__(self, other: E) -> E: ... assert egraph.extract(E(2), include_cost=True) == (E(1) + E(1), 203) egraph.register(set_cost(E(5) - E(3), 198)) assert egraph.extract(E(2), include_cost=True) == (E(5) - E(3), 202) + + +class TestScheduler: + def test_sequence_repeat_saturate(self): + """ + Mirrors the scheduling example: alternate step-right and step-left, + saturating each, repeated 10 times. Verifies final facts. + """ + egraph = EGraph() + + left = relation("left", i64) + right = relation("right", i64) + + x, y = vars_("x y", i64) + + # Name rulesets to make schedule translation stable and explicit + step_left = ruleset( + rule( + left(x), + right(x), + ).then(left(x + 1)), + name="step-left", + ) + step_right = ruleset( + rule( + left(x), + right(y), + eq(x).to(y + 1), + ).then(right(x)), + name="step-right", + ) + + # Initial facts + egraph.register(left(i64(0)), right(i64(0))) + + # (repeat 10 (seq (run step-right) (saturate step-left))) + egraph.run(seq(step_right, step_left) * 10) + + # We took 10 left steps, but only 9 right steps (first can't move) + egraph.check(left(i64(10)), right(i64(9))) + egraph.check_fail(left(i64(11)), right(i64(10))) + + def test_backoff_scheduler(self): + """ + Passing `scheduler=...` to run(...) hoists the scheduler to the + outer scope. This is equivalent to an explicit outer `bo.scope(...)`. + + https://egraphs.zulipchat.com/#narrow/channel/375765-egg.2Fegglog/topic/.E2.9C.94.20Backoff.20Scheduler.20Example/with/538745863 + """ + includes = relation("includes", i64) + x = var("x", i64) + grow = ruleset(rule(includes(x)).then(includes(x + 1))) + shrink = ruleset(rule(includes(x)).then(includes(x - 1))) + + e1 = EGraph() + e1.register(includes(i64(0))) + # default scheduler + with e1: + e1.run((grow + shrink) * 3) + e1.check(includes(i64(3)), includes(i64(-3))) + # back-off implicit outer hoisting + bo = back_off(match_limit=1) + with e1: + e1.run((run(grow, scheduler=bo) + shrink) * 3) + e1.check(includes(i64(2)), includes(i64(-3))) + e1.check_fail(includes(i64(3))) + # back off inner hoisting + with e1: + e1.run(bo.scope(run(grow, scheduler=bo) + shrink) * 3) + e1.check(includes(i64(1)), includes(i64(-3))) + e1.check_fail(includes(i64(2))) + + def test_custom_scheduler_invalid_until(self): + """ + Custom schedulers do not support equality facts in :until, + and only allow a single non-equality fact. + """ + egraph = EGraph() + + rel = relation("rel", i64) + x = var("x", i64) + r = ruleset(name="r") + bo = back_off(match_limit=1) + + # Equality in until should error via high-level run + with pytest.raises(ValueError, match="Cannot use equality fact with custom scheduler"): + egraph.run(run(r, eq(x).to(i64(1)), scheduler=bo)) + + # Multiple until facts should error via high-level run + with pytest.raises(ValueError, match="Can only have one until fact with custom scheduler"): + egraph.run(run(r, rel(i64(0)), rel(i64(1)), scheduler=bo)) diff --git a/python/tests/test_pretty.py b/python/tests/test_pretty.py index ad2db1bd..a6f682d8 100644 --- a/python/tests/test_pretty.py +++ b/python/tests/test_pretty.py @@ -82,6 +82,8 @@ def my_very_long_function_name() -> A: ... r = ruleset(name="r") +bo = back_off(ban_length=5) + PARAMS = [ # expression function calls pytest.param(A(), "A()", id="init"), @@ -148,6 +150,16 @@ def my_very_long_function_name() -> A: ... pytest.param(r + r, 'ruleset(name="r") + ruleset(name="r")', id="sequence"), pytest.param(seq(r, r, r), 'seq(ruleset(name="r"), ruleset(name="r"), ruleset(name="r"))', id="seq"), pytest.param(run(r, h()), 'run(ruleset(name="r"), h())', id="run"), + pytest.param( + run(r, h(), scheduler=bo), + 'run(ruleset(name="r"), h(), scheduler=back_off(ban_length=5))', + id="run with scheduler", + ), + pytest.param( + bo.scope(run(r, scheduler=bo)), + '_scheduler_1 = back_off(ban_length=5)\n_scheduler_1.scope(run(ruleset(name="r"), scheduler=_scheduler_1))', + id="scoped scheduler", + ), # Functions pytest.param(f, "f", id="function"), pytest.param(A().method, "A().method", id="method"),