Skip to content

Commit

Permalink
feature[next]: Collect shifts for all nodes in TraceShift pass (GridT…
Browse files Browse the repository at this point in the history
…ools#1321)

- Shifts are collected not only for the closure inputs, but for every iterator expression in the tree (including iterator arguments to lambdas).
- The collected shifts are now represented as a set.

This PR is a prerequisite for the temporary extraction heuristics.
  • Loading branch information
tehrengruber committed Aug 24, 2023
1 parent 90c4121 commit c8dbffd
Show file tree
Hide file tree
Showing 2 changed files with 168 additions and 65 deletions.
111 changes: 84 additions & 27 deletions src/gt4py/next/iterator/transforms/trace_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
#
# SPDX-License-Identifier: GPL-3.0-or-later
import dataclasses
import enum
from collections.abc import Callable
from dataclasses import dataclass
from typing import Any, Final, Iterable, Literal

from gt4py.eve import NodeTranslator
Expand All @@ -26,6 +26,30 @@ class Sentinel(enum.Enum):
TYPE = object()


@dataclasses.dataclass(frozen=True)
class ShiftRecorder:
    """Store, per node, the set of offset tuples that were recorded for it.

    Nodes are keyed by their Python object identity (``id(node)``), not by
    name or by structural equality.
    """

    # Maps ``id(node)`` to the set of offset tuples recorded for that node.
    recorded_shifts: dict[int, set[tuple[ir.OffsetLiteral, ...]]] = dataclasses.field(
        default_factory=dict
    )

    def register_node(self, inp: ir.Expr | ir.Sym) -> None:
        """Ensure an entry exists for ``inp``; an existing entry is left untouched."""
        key = id(inp)
        if key not in self.recorded_shifts:
            self.recorded_shifts[key] = set()

    def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]) -> None:
        """Add ``offsets`` to the shifts of ``inp``.

        Raises ``KeyError`` when ``inp`` has no entry yet, i.e. it was not
        passed to ``register_node`` beforehand.
        """
        shifts_of_node = self.recorded_shifts[id(inp)]
        shifts_of_node.add(offsets)


@dataclasses.dataclass(frozen=True)
class ForwardingShiftRecorder:
    """Record a shift locally and replay it on a wrapped tracer."""

    # Tracer that receives the forwarded shift; it is only used through
    # ``shift(offsets)`` followed by ``deref()``.
    wrapped_tracer: Any
    # Recorder in which the shift of the node itself is stored.
    shift_recorder: ShiftRecorder

    def __call__(self, inp: ir.Expr | ir.Sym, offsets: tuple[ir.OffsetLiteral, ...]):
        """Record ``offsets`` for ``inp``, then forward them to the wrapped tracer."""
        self.shift_recorder(inp, offsets)
        # Forward the shift to the wrapped tracer so that it can record the
        # shifts of the parent nodes as well.
        shifted_tracer = self.wrapped_tracer.shift(offsets)
        shifted_tracer.deref()


# for performance reasons (`isinstance` is slow otherwise) we don't use abc here
class IteratorTracer:
def deref(self):
Expand All @@ -35,29 +59,27 @@ def shift(self, offsets: tuple[ir.OffsetLiteral, ...]):
raise NotImplementedError()


@dataclass(frozen=True)
class InputTracer(IteratorTracer):
inp: str
register_deref: Callable[[str, tuple[ir.OffsetLiteral, ...]], None]
@dataclasses.dataclass(frozen=True)
class IteratorArgTracer(IteratorTracer):
arg: ir.Expr | ir.Sym
shift_recorder: ShiftRecorder | ForwardingShiftRecorder
offsets: tuple[ir.OffsetLiteral, ...] = ()
lift_level: int = 0

def shift(self, offsets: tuple[ir.OffsetLiteral, ...]):
return InputTracer(
inp=self.inp,
register_deref=self.register_deref,
return IteratorArgTracer(
arg=self.arg,
shift_recorder=self.shift_recorder,
offsets=self.offsets + tuple(offsets),
lift_level=self.lift_level,
)

def deref(self):
self.register_deref(self.inp, self.offsets)
self.shift_recorder(self.arg, self.offsets)
return Sentinel.VALUE


# This class is only needed because we currently allow conditionals on iterators. Since this is
# not supported in the C++ backend it can likely be removed again in the future.
@dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class CombinedTracer(IteratorTracer):
its: tuple[IteratorTracer, ...]

Expand Down Expand Up @@ -98,13 +120,13 @@ def apply(arg):
return apply


@dataclass(frozen=True)
@dataclasses.dataclass(frozen=True)
class AppliedLift(IteratorTracer):
stencil: Callable
its: tuple[IteratorTracer, ...]

def shift(self, offsets):
return AppliedLift(self.stencil, tuple(_shift(it) for it in self.its))
return AppliedLift(self.stencil, tuple(_shift(*offsets)(it) for it in self.its))

def deref(self):
return self.stencil(*self.its)
Expand Down Expand Up @@ -211,7 +233,10 @@ def _tuple_get(index, tuple_val):
}


@dataclasses.dataclass(frozen=True)
class TraceShifts(NodeTranslator):
shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder)

def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any:
return Sentinel.VALUE

Expand All @@ -232,30 +257,62 @@ def visit_FunCall(self, node: ir.FunCall, *, ctx: dict[str, Any]) -> Any:
args = self.visit(node.args, ctx=ctx)
return fun(*args)

def visit(self, node, **kwargs):
    """Visit ``node`` and wrap any resulting iterator tracer.

    When the regular visit yields an ``IteratorTracer``, ``node`` is registered
    with the shift recorder and the result is replaced by a new
    ``IteratorArgTracer`` for ``node``. Its recorder forwards every recorded
    shift to the tracer that was originally returned. Any other result is
    passed through unchanged.
    """
    traced = super().visit(node, **kwargs)
    if not isinstance(traced, IteratorTracer):
        return traced

    assert isinstance(node, (ir.Sym, ir.Expr))

    self.shift_recorder.register_node(node)
    forwarding_recorder = ForwardingShiftRecorder(traced, self.shift_recorder)
    return IteratorArgTracer(arg=node, shift_recorder=forwarding_recorder)

def visit_Lambda(self, node: ir.Lambda, *, ctx: dict[str, Any]) -> Callable:
def fun(*args):
new_args = []
for param, arg in zip(node.params, args, strict=True):
if isinstance(arg, IteratorTracer):
self.shift_recorder.register_node(param)
new_args.append(
IteratorArgTracer(
arg=param,
shift_recorder=ForwardingShiftRecorder(arg, self.shift_recorder),
)
)
else:
new_args.append(arg)

return self.visit(
node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, args, strict=True)}
node.expr, ctx=ctx | {p.id: a for p, a in zip(node.params, new_args, strict=True)}
)

return fun

def visit_StencilClosure(
self, node: ir.StencilClosure, *, shifts: dict[str, list[tuple[ir.OffsetLiteral, ...]]]
):
def register_deref(inp: str, offsets: tuple[ir.OffsetLiteral, ...]):
shifts[inp].append(offsets)

def visit_StencilClosure(self, node: ir.StencilClosure):
tracers = []
for inp in node.inputs:
shifts.setdefault(inp.id, [])
tracers.append(InputTracer(inp=inp.id, register_deref=register_deref))
self.shift_recorder.register_node(inp)
tracers.append(IteratorArgTracer(arg=inp, shift_recorder=self.shift_recorder))

result = self.visit(node.stencil, ctx=_START_CTX)(*tracers)
assert all(el is Sentinel.VALUE for el in _primitive_constituents(result))

@classmethod
def apply(cls, node: ir.StencilClosure) -> dict[str, list[tuple[ir.OffsetLiteral, ...]]]:
shifts = dict[str, list[tuple[ir.OffsetLiteral, ...]]]()
cls().visit(node, shifts=shifts)
return shifts
def apply(
cls, node: ir.StencilClosure, *, inputs_only=True
) -> (
dict[int, set[tuple[ir.OffsetLiteral, ...]]] | dict[str, set[tuple[ir.OffsetLiteral, ...]]]
):
instance = cls()
instance.visit(node)

recorded_shifts = instance.shift_recorder.recorded_shifts

if inputs_only:
inputs_shifts = {}
for inp in node.inputs:
inputs_shifts[str(inp.id)] = recorded_shifts[id(inp)]
return inputs_shifts

return recorded_shifts
Loading

0 comments on commit c8dbffd

Please sign in to comment.