Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions lighthouse/schedule/x86/pack_lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def lower_packs_for_vectorization(
vector_tile_sizes: Target vector shapes
vector_unroll_factors: Unroll factors for each vector loop.
"""
foreach_pack = transform.ForeachOp([], (pack_ops,))
with ir.InsertionPoint(foreach_pack.body):
pack_op = foreach_pack.bodyTargets[0]
with lh_transform.foreach(pack_ops) as pack_op:
tiled_pack = structured.TileUsingForOp(
pack_op, sizes=pack_tile_sizes
).tiled_linalg_op
Expand Down Expand Up @@ -59,9 +57,7 @@ def lower_unpacks_for_vectorization(
unpack_tile_sizes: Unpack sub-tiling sizes
vector_tile_sizes: Target vector shapes
"""
foreach_unpack = transform.ForeachOp([], (unpack_ops,))
with ir.InsertionPoint(foreach_unpack.body):
unpack_op = foreach_unpack.bodyTargets[0]
with lh_transform.foreach(unpack_ops) as unpack_op:
tiled_unpack = structured.TileUsingForOp(
unpack_op, sizes=unpack_tile_sizes
).tiled_linalg_op
Expand Down
2 changes: 2 additions & 0 deletions lighthouse/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .cleanup import cleanup
from .cleanup import simplify_vector_ops
from .cleanup import flatten_vector_ops
from .foreach import foreach
from .hoisting import loop_hoisting
from .matchers import match_op
from .tiling import tile_ops
Expand All @@ -12,6 +13,7 @@
__all__ = [
"cleanup",
"flatten_vector_ops",
"foreach",
"loop_hoisting",
"match_op",
"pack_propagation",
Expand Down
75 changes: 75 additions & 0 deletions lighthouse/transform/foreach.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from collections.abc import Sequence

from mlir import ir
from mlir.dialects import transform


class foreach(transform.ForeachOp):
"""
Context manager wrapper for foreach transform op.

Apply transforms nested under the foreach loop exactly once
per element of the payload associated to the targets handle.

The wrapper creates a new foreach operation.
On entry, the insertion point is placed in the loop's body
and the block arguments are returned.
On exit, the insertion point is restored.

Nested multiple entry is not supported.

Typical usage:

with foreach(ops_handle) as op:
transform.rewrite(op)
...
transform.yield_()

With results:

foreach_op = lh_transform.foreach(
linalg_ops, vec_ops, result_types=[type]
)
with foreach_op as (linalg_op, vec_op):
...
transform.yield_([val])
res = foreach_op.results[0]

Args:
targets: Handles to targets
result_types: Result types (default: no returns)
with_zip_shortest: limit iterations to the shortest target
kwargs: Additional arguments for the foreach operation
"""

def __init__(
self,
*targets,
result_types: Sequence[ir.Type] | None = None,
with_zip_shortest: bool = False,
**kwargs,
):
if result_types is None:
result_types = []

super().__init__(
results=result_types,
targets=targets,
with_zip_shortest=with_zip_shortest,
**kwargs,
)
self.insertion_point: ir.InsertionPoint | None = None

def __enter__(self) -> Sequence[ir.BlockArgument]:
if self.insertion_point is not None:
raise Exception("Nested re-entry is not supported")
# Set insertion point in the loop's body
self.insertion_point = ir.InsertionPoint(self.body)
self.insertion_point.__enter__()

return self.bodyTargets[0] if len(self.bodyTargets) == 1 else self.bodyTargets

def __exit__(self, *args):
# Restore insertion point
self.insertion_point.__exit__(*args)
self.insertion_point = None
9 changes: 4 additions & 5 deletions lighthouse/transform/tiling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import loop
from mlir.dialects.transform import structured

from lighthouse.transform import foreach


def tile_ops(
target,
Expand All @@ -11,7 +12,7 @@ def tile_ops(
tile_interchange: list[int] | None = None,
peel_loops: list[int] = [],
unroll_factors: list[int] = [],
) -> ir.Value:
):
"""
Apply tiling to the target.

Expand Down Expand Up @@ -40,9 +41,7 @@ def tile_ops(
"Both unrolling and peeling is not supported"
)

foreach = transform.ForeachOp([], (target,))
with ir.InsertionPoint(foreach.body):
op = foreach.bodyTargets[0]
with foreach(target) as op:
if fuse_producers:
_, *loops = structured.FuseOp(
op,
Expand Down
6 changes: 3 additions & 3 deletions lighthouse/transform/vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from mlir.dialects.transform import vector
from mlir.dialects.transform import x86

from lighthouse.transform import foreach


def vectorize_ops(
target,
Expand All @@ -18,9 +20,7 @@ def vectorize_ops(
vector_sizes: Vector sizes
vectorize_kwargs: Options passed to vectorization transform
"""
foreach = transform.ForeachOp([], (target,))
with ir.InsertionPoint(foreach.body):
op = foreach.bodyTargets[0]
with foreach(target) as op:
structured.structured_vectorize(op, vector_sizes, **vectorize_kwargs)
transform.yield_()

Expand Down
55 changes: 55 additions & 0 deletions test/transform/test_foreach.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# RUN: %PYTHON %s | FileCheck %s

from mlir import ir
from mlir.dialects import transform

from lighthouse import schedule as lh_schedule
from lighthouse import transform as lh_transform


def test_insertion_points():
with lh_schedule.schedule_boilerplate() as (sched, named_seq):
# Create empty foreach op between two ops.
funcs = lh_transform.match_op(named_seq.bodyTarget, "func.func")
foreach_op = lh_transform.foreach(funcs)
loops = lh_transform.match_op(named_seq.bodyTarget, "scf.for")

# Insert transforms in the foreach.
with foreach_op as func:
transform.apply_dce(func)

# Create another complete foreach loop with results.
foreach_with_res = lh_transform.foreach(
funcs, loops, result_types=[transform.any_op_t()]
)
with foreach_with_res as (func, loop):
transform.apply_cse(func)
transform.apply_licm(loop)
transform.yield_([func])

# Terminate the first foreach
with foreach_op:
transform.yield_()

# Insert print at the end of the schedule
transform.print_(target=foreach_with_res.results[0])
transform.yield_()
sched.body.operations[0].verify()
print(sched)


# CHECK: %[[FUNCS:.+]] = transform.structured.match ops{["func.func"]}
# CHECK: transform.foreach %[[FUNCS]]
# CHECK: ^bb0(%[[FUNC:.+]]: !transform.any_op):
# CHECK: transform.apply_dce to %[[FUNC]]
# CHECK: %[[LOOPS:.+]] = transform.structured.match ops{["scf.for"]}
# CHECK: %[[RES:.+]] = transform.foreach %[[FUNCS]], %[[LOOPS]]
# CHECK: ^bb0(%[[FUNC:.+]]: !transform.any_op, %[[LOOP:.+]]: !transform.any_op):
# CHECK: transform.apply_cse to %[[FUNC]]
# CHECK: transform.apply_licm to %[[LOOP]]
# CHECK: transform.yield %[[FUNC]]
# CHECK: transform.print %[[RES]]
# CHECK: transform.yield

with ir.Context(), ir.Location.unknown():
test_insertion_points()
Loading