Skip to content

Commit

Permalink
[mlir] switch the transform loop extension to use types
Browse files Browse the repository at this point in the history
Add types to the Loop (SCF) extension of the transform dialect.

See https://discourse.llvm.org/t/rfc-type-system-for-the-transform-dialect/65702

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D135587
  • Loading branch information
ftynse committed Oct 11, 2022
1 parent 3e1f6d0 commit 59bb8af
Show file tree
Hide file tree
Showing 13 changed files with 84 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
#ifndef MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H
#define MLIR_DIALECT_SCF_TRANSFORMOPS_SCFTRANSFORMOPS_H

#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/Dialect/Transform/IR/TransformTypes.h"
#include "mlir/IR/OpImplementation.h"

namespace mlir {
Expand Down
39 changes: 24 additions & 15 deletions mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/IR/TransformEffects.td"
include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
include "mlir/Dialect/PDL/IR/PDLTypes.td"
include "mlir/Dialect/Transform/IR/TransformTypes.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/OpBase.td"

def Transform_ScfForOp : Transform_ConcreteOpType<"scf.for">;

def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
[NavigationTransformOpTrait, MemoryEffectsOpInterface,
DeclareOpInterfaceMethods<TransformOpInterface>]> {
Expand All @@ -30,12 +32,13 @@ def GetParentForOp : Op<Transform_Dialect, "loop.get_parent_for",
}];

let arguments =
(ins PDL_Operation:$target,
(ins TransformTypeInterface:$target,
DefaultValuedAttr<ConfinedAttr<I64Attr, [IntPositive]>,
"1">:$num_loops);
let results = (outs PDL_Operation:$parent);
let results = (outs TransformTypeInterface:$parent);

let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}

def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
Expand All @@ -55,11 +58,15 @@ def LoopOutlineOp : Op<Transform_Dialect, "loop.outline",
order as the operand handle.
}];

let arguments = (ins PDL_Operation:$target,
// Note that despite the name of the transform operation and related utility
// functions, the actual implementation does not require the operation to be
// a loop.
let arguments = (ins TransformTypeInterface:$target,
StrAttr:$func_name);
let results = (outs PDL_Operation:$transformed);
let results = (outs TransformTypeInterface:$transformed);

let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";
}

def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
Expand Down Expand Up @@ -90,12 +97,13 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
}];

let arguments =
(ins PDL_Operation:$target,
(ins Transform_ScfForOp:$target,
DefaultValuedAttr<BoolAttr, "false">:$fail_if_already_divisible);
// TODO: Return both the peeled loop and the remainder loop.
let results = (outs PDL_Operation:$transformed);
let results = (outs TransformTypeInterface:$transformed);

let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down Expand Up @@ -131,12 +139,13 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
pipelined loops, which can be empty.
}];

let arguments = (ins PDL_Operation:$target,
let arguments = (ins Transform_ScfForOp:$target,
DefaultValuedAttr<I64Attr, "1">:$iteration_interval,
DefaultValuedAttr<I64Attr, "10">:$read_latency);
let results = (outs PDL_Operation:$transformed);
let results = (outs TransformTypeInterface:$transformed);

let assemblyFormat = "$target attr-dict";
let assemblyFormat =
"$target attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down Expand Up @@ -165,10 +174,10 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
removed after a full unrolling.
}];

let arguments = (ins PDL_Operation:$target,
let arguments = (ins Transform_ScfForOp:$target,
ConfinedAttr<I64Attr, [IntPositive]>:$factor);

let assemblyFormat = "$target attr-dict";
let assemblyFormat = "$target attr-dict `:` type($target)";

let extraClassDeclaration = [{
::mlir::DiagnosedSilenceableFailure applyToOne(
Expand Down
1 change: 0 additions & 1 deletion mlir/lib/Dialect/SCF/TransformOps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ add_mlir_dialect_library(MLIRSCFTransformOps
MLIRAffineDialect
MLIRFuncDialect
MLIRIR
MLIRPDLDialect
MLIRSCFDialect
MLIRSCFTransforms
MLIRSCFUtils
Expand Down
3 changes: 0 additions & 3 deletions mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "mlir/Dialect/SCF/TransformOps/SCFTransformOps.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/PDL/IR/PDL.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
Expand Down Expand Up @@ -239,8 +238,6 @@ class SCFTransformDialectExtension
using Base::Base;

void init() {
declareDependentDialect<pdl::PDLDialect>();

declareGeneratedDialect<AffineDialect>();
declareGeneratedDialect<func::FuncDialect>();

Expand Down
13 changes: 8 additions & 5 deletions mlir/python/mlir/dialects/_loop_transform_ops_ext.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
try:
from ..ir import *
from ._ods_common import get_op_result_or_value as _get_op_result_or_value
from ..dialects import pdl
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

Expand All @@ -28,13 +27,14 @@ class GetParentForOp:
"""Extension for GetParentForOp."""

def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
num_loops: int = 1,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
num_loops=_get_int64_attr(num_loops, default_value=1),
ip=ip,
Expand All @@ -45,13 +45,14 @@ class LoopOutlineOp:
"""Extension for LoopOutlineOp."""

def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
func_name: Union[str, StringAttr],
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
func_name=(func_name if isinstance(func_name, StringAttr) else
StringAttr.get(func_name)),
Expand All @@ -63,13 +64,14 @@ class LoopPeelOp:
"""Extension for LoopPeelOp."""

def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
fail_if_already_divisible: Union[bool, BoolAttr] = False,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
fail_if_already_divisible=(fail_if_already_divisible if isinstance(
fail_if_already_divisible, BoolAttr) else
Expand All @@ -82,14 +84,15 @@ class LoopPipelineOp:
"""Extension for LoopPipelineOp."""

def __init__(self,
result_type: Type,
target: Union[Operation, Value],
*,
iteration_interval: Optional[Union[int, IntegerAttr]] = None,
read_latency: Optional[Union[int, IntegerAttr]] = None,
ip=None,
loc=None):
super().__init__(
pdl.OperationType.get(),
result_type,
_get_op_result_or_value(target),
iteration_interval=_get_int64_attr(iteration_interval, default_value=1),
read_latency=_get_int64_attr(read_latency, default_value=10),
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Dialect/Linalg/transform-op-fuse.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["linalg.elemwise_binary"]} in %arg1
%1, %loops:2 = transform.structured.fuse %0 {tile_sizes = [32, 32], tile_interchange = [0, 1]}
transform.loop.peel %loops#0
%loop = transform.cast %loops#0 : !pdl.operation to !transform.op<"scf.for">
transform.loop.peel %loop : (!transform.op<"scf.for">) -> !pdl.operation
}
}

Expand Down
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Linalg/transform-op-match.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ transform.with_pdl_patterns {
transform.sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match ops{["arith.constant"]} in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
transform.test_consume_operand %match_name

%match_attr = transform.structured.match ops{["arith.constant"]} attributes{my_attr} in %arg1
transform.test_print_remark_at_operand %match_attr, "matched attr name"
transform.test_print_remark_at_operand %match_attr, "matched attr name" : !pdl.operation
transform.test_consume_operand %match_attr
}
}
Expand All @@ -38,7 +38,7 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match
ops{["arith.constant"]} filter_result_type = f32 in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_print_remark_at_operand %match_name, "matched op name" : !pdl.operation
transform.test_consume_operand %match_name
}
}
Expand Down Expand Up @@ -69,7 +69,7 @@ transform.with_pdl_patterns {
ops{["linalg.generic"]}
attributes{iterator_types = ["parallel", "parallel", "parallel"]}
in %arg1
transform.test_print_remark_at_operand %match_attr, "matched complex attr"
transform.test_print_remark_at_operand %match_attr, "matched complex attr" : !pdl.operation
transform.test_consume_operand %match_attr

%no_match = transform.structured.match
Expand Down
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ transform.sequence failures(propagate) {
%0 = transform.structured.match ops{["memref.alloc"]} in %arg1
%1 = transform.memref.multibuffer %0 {factor = 2 : i64}
// Verify that the returned handle is usable.
transform.test_print_remark_at_operand %1, "transformed"
transform.test_print_remark_at_operand %1, "transformed" : !pdl.operation
}
34 changes: 17 additions & 17 deletions mlir/test/Dialect/SCF/transform-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// CHECK: = transform.loop.get_parent_for
%1 = transform.loop.get_parent_for %0
%2 = transform.loop.get_parent_for %0 { num_loops = 2 }
%3 = transform.loop.get_parent_for %0 { num_loops = 3 }
transform.test_print_remark_at_operand %1, "third loop"
transform.test_print_remark_at_operand %2, "second loop"
transform.test_print_remark_at_operand %3, "first loop"
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
%2 = transform.loop.get_parent_for %0 { num_loops = 2 } : (!pdl.operation) -> !transform.op<"scf.for">
%3 = transform.loop.get_parent_for %0 { num_loops = 3 } : (!pdl.operation) -> !transform.op<"scf.for">
transform.test_print_remark_at_operand %1, "third loop" : !transform.op<"scf.for">
transform.test_print_remark_at_operand %2, "second loop" : !transform.op<"scf.for">
transform.test_print_remark_at_operand %3, "first loop" : !transform.op<"scf.for">
}
}

Expand All @@ -44,7 +44,7 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
// expected-error @below {{could not find an 'scf.for' parent}}
%1 = transform.loop.get_parent_for %0
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
}
}

Expand Down Expand Up @@ -85,9 +85,9 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
// CHECK: = transform.loop.outline %{{.*}}
transform.loop.outline %1 {func_name = "foo"}
transform.loop.outline %1 {func_name = "foo"} : (!transform.op<"scf.for">) -> !pdl.operation
}
}

Expand Down Expand Up @@ -115,7 +115,7 @@ transform.with_pdl_patterns {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["scf.while"]} in %arg1
// expected-error @below {{failed to outline}}
transform.loop.outline %0 {func_name = "foo"}
transform.loop.outline %0 {func_name = "foo"} : (!pdl.operation) -> !pdl.operation
}
}

Expand Down Expand Up @@ -145,8 +145,8 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
transform.loop.peel %1
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
transform.loop.peel %1 : (!transform.op<"scf.for">) -> !pdl.operation
}
}

Expand Down Expand Up @@ -181,10 +181,10 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addf"]} in %arg1
%1 = transform.loop.get_parent_for %0
%2 = transform.loop.pipeline %1
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
%2 = transform.loop.pipeline %1 : (!transform.op<"scf.for">) -> !pdl.operation
// Verify that the returned handle is usable.
transform.test_print_remark_at_operand %2, "transformed"
transform.test_print_remark_at_operand %2, "transformed" : !pdl.operation
}
}

Expand All @@ -208,8 +208,8 @@ transform.with_pdl_patterns {
sequence %arg0 : !pdl.operation failures(propagate) {
^bb1(%arg1: !pdl.operation):
%0 = transform.structured.match ops{["arith.addi"]} in %arg1
%1 = transform.loop.get_parent_for %0
transform.loop.unroll %1 { factor = 4 }
%1 = transform.loop.get_parent_for %0 : (!pdl.operation) -> !transform.op<"scf.for">
transform.loop.unroll %1 { factor = 4 } : !transform.op<"scf.for">
}
}

4 changes: 2 additions & 2 deletions mlir/test/Dialect/Transform/expensive-checks.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ transform.with_pdl_patterns {
// expected-note @below {{invalidated by this transform op that consumes its operand #0}}
test_consume_operand %1
// expected-error @below {{op uses a handle invalidated by a previously executed transform op}}
test_print_remark_at_operand %0, "remark"
test_print_remark_at_operand %0, "remark" : !pdl.operation
}
}

Expand Down Expand Up @@ -57,7 +57,7 @@ transform.with_pdl_patterns {
%2 = replicate num(%0) %1 : !pdl.operation, !pdl.operation
// expected-error @below {{a handle passed as operand #0 and consumed by this operation points to a payload operation more than once}}
test_consume_operand %2
test_print_remark_at_operand %0, "remark"
test_print_remark_at_operand %0, "remark" : !pdl.operation
}
}

Expand Down

0 comments on commit 59bb8af

Please sign in to comment.