Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ _This project uses semantic versioning_

## UNRELEASED

- Add `back_off` scheduler [#350](https://github.com/egraphs-good/egglog-python/pull/350)
## 11.2.0 (2025-09-03)

- Add support for `set_cost` action to have row level costs for extraction [#343](https://github.com/egraphs-good/egglog-python/pull/343)
Expand Down
56 changes: 56 additions & 0 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -459,6 +459,62 @@ step_egraph.check(left(i64(10)), right(i64(9)))
step_egraph.check_fail(left(i64(11)), right(i64(10)))
```

#### Custom Schedulers

The custom backoff scheduler from egglog-experimental is supported. Create one with `bo = back_off(match_limit: None | int = None, ban_length: None | int = None)`, then pass it to a run with `run(ruleset, *facts, scheduler=bo)`:

- `match_limit`: per-rule threshold of matches allowed in a single scheduler iteration. If a rule produces more matches than the threshold, that rule is temporarily banned.
- `ban_length`: initial ban duration (in scheduler iterations). While banned, that rule is skipped.
- Exponential backoff: each time a rule is banned, both the threshold and ban length double for that rule (threshold = match_limit << times_banned; ban = ban_length << times_banned).
- Fast-forwarding: when any rule is banned, the scheduler fast-forwards by the minimum remaining ban to unban at least one rule before checking for termination again.
- Defaults: match_limit defaults to 1000; ban_length defaults to 5.

For example, this egglog code:

```
(run-schedule
(let-scheduler bo (back-off :match-limit 10))
(repeat 10 (run-with bo step_right)))
```

Is translated as:

```{code-cell} python
step_egraph.run(
run(step_right, scheduler=back_off(match_limit=10)) * 10
)
```

By default the scheduler will be created before any other schedules are run.
To control explicitly where it is instantiated, use `bo.scope(<schedule>)`; the scheduler will then be created immediately before everything in `<schedule>`.

So the previous is equivalent to:

```{code-cell} python
bo = back_off(match_limit=10)
step_egraph.run(
bo.scope(run(step_right, scheduler=bo) * 10)
)
```

If you want to create the scheduler inside the repeated schedule instead, you can do:

```{code-cell} python
bo = back_off(match_limit=10)
step_egraph.run(
bo.scope(run(step_right, scheduler=bo)) * 10
)
```

This would be equivalent to this egglog:

```
(run-schedule
(repeat 10
(let-scheduler bo (back-off :match-limit 10))
(run-with bo step_right)))
```

## Check

The `(check ...)` command, which verifies that some facts are true, can be translated to Python with the `egraph.check` function:
Expand Down
20 changes: 19 additions & 1 deletion python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypeAlias, TypeVar, Union, cast, runtime_checkable
from uuid import UUID
from weakref import WeakValueDictionary

from typing_extensions import Self, assert_never
Expand All @@ -20,6 +21,7 @@
__all__ = [
"ActionCommandDecl",
"ActionDecl",
"BackOffDecl",
"BiRewriteDecl",
"CallDecl",
"CallableDecl",
Expand Down Expand Up @@ -52,6 +54,7 @@
"JustTypeRef",
"LetDecl",
"LetRefDecl",
"LetSchedulerDecl",
"LitDecl",
"LitType",
"MethodRef",
Expand Down Expand Up @@ -790,9 +793,24 @@ class SequenceDecl:
# Schedule node that runs a single ruleset, optionally with `until` facts and a custom scheduler.
class RunDecl:
# Egg name of the ruleset to run; `run()` stores "" here when no ruleset is passed.
ruleset: str
# Facts given as the run's `until` condition, or None when there are none.
until: tuple[FactDecl, ...] | None
# Back-off scheduler to run this ruleset with, or None to run without a custom scheduler.
scheduler: BackOffDecl | None = None


@dataclass(frozen=True)
class LetSchedulerDecl:
    """
    Schedule node that creates `scheduler` and then runs `inner`.

    When translated to egg, a `let-scheduler` binding is emitted directly before the translation of `inner`.
    """

    # The back-off scheduler that is created at this point of the schedule.
    scheduler: BackOffDecl
    # The schedule that follows the scheduler's creation.
    inner: ScheduleDecl


# Every kind of schedule node. This is the only definition of the alias: the older variant without
# `LetSchedulerDecl` was stale and immediately shadowed by this one.
ScheduleDecl: TypeAlias = SaturateDecl | RepeatDecl | SequenceDecl | RunDecl | LetSchedulerDecl


@dataclass(frozen=True)
class BackOffDecl:
id: UUID
match_limit: int | None
ban_length: int | None


##
# Facts
Expand Down
45 changes: 41 additions & 4 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_type_hints,
overload,
)
from uuid import uuid4
from warnings import warn

import graphviz
Expand All @@ -45,6 +46,7 @@

__all__ = [
"Action",
"BackOff",
"BaseExpr",
"BuiltinExpr",
"Command",
Expand All @@ -63,6 +65,7 @@
"_RewriteBuilder",
"_SetBuilder",
"_UnionBuilder",
"back_off",
"birewrite",
"check",
"check_eq",
Expand Down Expand Up @@ -905,8 +908,8 @@ def run(

def _run_schedule(self, schedule: Schedule) -> bindings.RunReport:
    """
    Run `schedule` on the underlying e-graph and return its run report.

    The schedule's declarations are registered first. The schedule is then translated to a single egg command
    by `run_schedule_to_egg` and executed exactly once.
    """
    self._add_decls(schedule)
    cmd = self._state.run_schedule_to_egg(schedule.schedule)
    # Exactly one command is submitted, so exactly one output is expected back.
    (command_output,) = self._egraph.run_program(cmd)
    assert isinstance(command_output, bindings.RunScheduleOutput)
    return command_output.report

Expand Down Expand Up @@ -1786,17 +1789,51 @@ def to_runtime_expr(expr: BaseExpr) -> RuntimeExpr:
return expr


def run(ruleset: Ruleset | None = None, *until: FactLike, scheduler: BackOff | None = None) -> Schedule:
    """
    Create a run configuration.

    `ruleset` is the ruleset to run; when omitted, the empty ruleset name is used.
    `until` are facts recorded as the run's `until` condition.
    `scheduler` is an optional back-off scheduler (see `back_off`) to run the ruleset with. It is keyword-only,
    so existing positional calls keep working unchanged.
    """
    facts = _fact_likes(until)
    return Schedule(
        Thunk.fn(Declarations.create, ruleset, *facts),
        RunDecl(
            ruleset.__egg_name__ if ruleset else "",
            # An empty tuple of facts is stored as None.
            tuple(f.fact for f in facts) or None,
            scheduler.scheduler if scheduler else None,
        ),
    )


def back_off(match_limit: None | int = None, ban_length: None | int = None) -> BackOff:
    """
    Build a back-off scheduler to pass as `scheduler=` to `run`.

    Every call creates a scheduler with its own freshly generated id.

    ```python
    schedule = run(analysis_ruleset).saturate() + run(ruleset, scheduler=back_off(match_limit=1000, ban_length=5)) * 10
    ```
    This will run the `analysis_ruleset` until saturation, then run `ruleset` 10 times, using a backoff scheduler.
    """
    decl = BackOffDecl(id=uuid4(), match_limit=match_limit, ban_length=ban_length)
    return BackOff(decl)


@dataclass(frozen=True)
class BackOff:
    """
    User-facing handle around a back-off scheduler declaration.

    Pass it to `run(..., scheduler=...)`, and use `scope` to choose where the scheduler is created.
    """

    # The underlying declaration that is stored in the schedule.
    scheduler: BackOffDecl

    def scope(self, schedule: Schedule) -> Schedule:
        """
        Return `schedule` wrapped so that this scheduler is created directly before it.

        Without an explicit scope, the scheduler is created at the outermost scope instead.
        """
        decls_thunk = schedule.__egg_decls_thunk__
        scoped = LetSchedulerDecl(self.scheduler, schedule.schedule)
        return Schedule(decls_thunk, scoped)

    def __str__(self) -> str:
        # Pretty-print against an empty set of declarations.
        return pretty_decl(Declarations(), self.scheduler)

    def __repr__(self) -> str:
        # The repr is deliberately the same text as the pretty-printed form.
        return str(self)


def seq(*schedules: Schedule) -> Schedule:
"""
Run a sequence of schedules.
Expand Down
135 changes: 129 additions & 6 deletions python/egglog/egraph_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections import defaultdict
from dataclasses import dataclass, field, replace
from typing import TYPE_CHECKING, Literal, overload
from uuid import UUID

from typing_extensions import assert_never

Expand Down Expand Up @@ -89,18 +90,140 @@ def copy(self) -> EGraphState:
cost_callables=self.cost_callables.copy(),
)

def run_schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Command:
    """
    Turn a run schedule into an egg command.

    If the schedule uses any custom scheduler, it is translated to a user-defined `run-schedule` command
    built from expressions; otherwise it becomes a normal `RunSchedule` command.
    """
    processed_schedule = self._process_schedule(schedule)
    if processed_schedule is None:
        # No custom schedulers anywhere, so the built-in schedule bindings can be used directly.
        return bindings.RunSchedule(self._schedule_to_egg(schedule))
    top_level_schedules = self._schedule_with_scheduler_to_egg(processed_schedule, [])
    if len(top_level_schedules) == 1:
        schedule_expr = top_level_schedules[0]
    else:
        # Several top level expressions are combined into a single `seq` call.
        schedule_expr = bindings.Call(span(), "seq", top_level_schedules)
    return bindings.UserDefined(span(), "run-schedule", [schedule_expr])

def _process_schedule(self, schedule: ScheduleDecl) -> ScheduleDecl | None:
"""
Processes a schedule to determine if it contains any custom schedulers.

If it does, it returns a new schedule with all the required let bindings added to the other scope.
If not, returns none.

Also processes all rulesets in the schedule to make sure they are registered.
"""
bound_schedulers: list[UUID] = []
unbound_schedulers: list[BackOffDecl] = []

def helper(s: ScheduleDecl) -> None:
match s:
case LetSchedulerDecl(scheduler, inner):
bound_schedulers.append(scheduler.id)
return helper(inner)
case RunDecl(ruleset_name, _, scheduler):
self.ruleset_to_egg(ruleset_name)
if scheduler and scheduler.id not in bound_schedulers:
unbound_schedulers.append(scheduler)
case SaturateDecl(inner) | RepeatDecl(inner, _):
return helper(inner)
case SequenceDecl(schedules):
for sc in schedules:
helper(sc)
case _:
assert_never(s)
return None

helper(schedule)
if not bound_schedulers and not unbound_schedulers:
return None
for scheduler in unbound_schedulers:
schedule = LetSchedulerDecl(scheduler, schedule)
return schedule

def _schedule_to_egg(self, schedule: ScheduleDecl) -> bindings._Schedule:
msg = "Should never reach this, let schedulers should be handled by custom scheduler"
match schedule:
case SaturateDecl(schedule):
return bindings.Saturate(span(), self.schedule_to_egg(schedule))
return bindings.Saturate(span(), self._schedule_to_egg(schedule))
case RepeatDecl(schedule, times):
return bindings.Repeat(span(), times, self.schedule_to_egg(schedule))
return bindings.Repeat(span(), times, self._schedule_to_egg(schedule))
case SequenceDecl(schedules):
return bindings.Sequence(span(), [self.schedule_to_egg(s) for s in schedules])
case RunDecl(ruleset_name, until):
self.ruleset_to_egg(ruleset_name)
return bindings.Sequence(span(), [self._schedule_to_egg(s) for s in schedules])
case RunDecl(ruleset_name, until, scheduler):
if scheduler is not None:
raise ValueError(msg)
config = bindings.RunConfig(ruleset_name, None if not until else list(map(self.fact_to_egg, until)))
return bindings.Run(span(), config)
case LetSchedulerDecl():
raise ValueError(msg)
case _:
assert_never(schedule)

def _schedule_with_scheduler_to_egg( # noqa: C901, PLR0912
self, schedule: ScheduleDecl, bound_schedulers: list[UUID]
) -> list[bindings._Expr]:
"""
Turns a scheduler into an egg expression, to be used with a custom extract command.

The bound_schedulers is a list of all the schedulers that have been bound. We can lookup their name as `_scheduler_{index}`.
"""
match schedule:
case LetSchedulerDecl(BackOffDecl(id, match_limit, ban_length), inner):
name = f"_scheduler_{len(bound_schedulers)}"
bound_schedulers.append(id)
args: list[bindings._Expr] = []
if match_limit is not None:
args.append(bindings.Var(span(), ":match-limit"))
args.append(bindings.Lit(span(), bindings.Int(match_limit)))
if ban_length is not None:
args.append(bindings.Var(span(), ":ban-length"))
args.append(bindings.Lit(span(), bindings.Int(ban_length)))
back_off_decl = bindings.Call(span(), "back-off", args)
let_decl = bindings.Call(span(), "let-scheduler", [bindings.Var(span(), name), back_off_decl])
return [let_decl, *self._schedule_with_scheduler_to_egg(inner, bound_schedulers)]
case RunDecl(ruleset_name, until, scheduler):
args = [bindings.Var(span(), ruleset_name)]
if scheduler:
name = "run-with"
scheduler_name = f"_scheduler_{bound_schedulers.index(scheduler.id)}"
args.insert(0, bindings.Var(span(), scheduler_name))
else:
name = "run"
if until:
if len(until) > 1:
msg = "Can only have one until fact with custom scheduler"
raise ValueError(msg)
args.append(bindings.Var(span(), ":until"))
fact_egg = self.fact_to_egg(until[0])
if isinstance(fact_egg, bindings.Eq):
msg = "Cannot use equality fact with custom scheduler"
raise ValueError(msg)
args.append(fact_egg.expr)
return [bindings.Call(span(), name, args)]
case SaturateDecl(inner):
return [
bindings.Call(span(), "saturate", self._schedule_with_scheduler_to_egg(inner, bound_schedulers))
]
case RepeatDecl(inner, times):
return [
bindings.Call(
span(),
"repeat",
[
bindings.Lit(span(), bindings.Int(times)),
*self._schedule_with_scheduler_to_egg(inner, bound_schedulers),
],
)
]
case SequenceDecl(schedules):
res = []
for s in schedules:
res.extend(self._schedule_with_scheduler_to_egg(s, bound_schedulers))
return res
case _:
assert_never(schedule)

Expand Down
3 changes: 0 additions & 3 deletions python/egglog/examples/jointree.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,3 @@ def _rules(s: String, a: JoinTree, b: JoinTree, c: JoinTree, asize: i64, bsize:
egraph.run(1000)
print(egraph.extract(query))
print(egraph.extract(query.size))


egraph
Loading