Skip to content

Commit

Permalink
[MLIR][Linalg] More Linalg named ops (#90236)
Browse files Browse the repository at this point in the history
Adding `min` that was already implemented but not exposed.

Adding a few additional unary ops:
* Reciprocal as `arith.div(1,arg)`
* Round as `math.round(arg)`
* Sqrt as `math.sqrt(arg)`
* Rsqrt as `math.rsqrt(arg)`
* Square as `math.powf(arg, 2)`
* TanH as `math.tanh(arg)`

All with the agreed semantics at the round table: no implicit
broadcast/type cast.
  • Loading branch information
rengolin committed Apr 28, 2024
1 parent dc6ce60 commit 4cec3b3
Show file tree
Hide file tree
Showing 8 changed files with 853 additions and 3 deletions.
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"abs", 2>,
I32EnumAttrCase<"ceil", 3>,
I32EnumAttrCase<"floor", 4>,
I32EnumAttrCase<"negf", 5>
I32EnumAttrCase<"negf", 5>,
I32EnumAttrCase<"reciprocal", 6>,
I32EnumAttrCase<"round", 7>,
I32EnumAttrCase<"sqrt", 8>,
I32EnumAttrCase<"rsqrt", 9>,
I32EnumAttrCase<"square", 10>,
I32EnumAttrCase<"tanh", 11>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
Expand Down
261 changes: 260 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: reciprocal
cpp_class_name: ReciprocalOp
doc: |-
Applies reciprocal(x) elementwise.
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: reciprocal
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
doc: |-
Applies round(x) elementwise.
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: round
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: sqrt
cpp_class_name: SqrtOp
doc: |-
Applies sqrt(x) elementwise.
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: sqrt
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: rsqrt
cpp_class_name: RsqrtOp
doc: |-
Applies rsqrt(x) elementwise.
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: rsqrt
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: square
cpp_class_name: SquareOp
doc: |-
Applies square(x) elementwise.
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: square
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: tanh
cpp_class_name: TanhOp
doc: |-
Applies tanh(x) elementwise.
No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: tanh
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: elemwise_binary
cpp_class_name: ElemwiseBinaryOp
Expand Down Expand Up @@ -625,7 +835,7 @@ metadata: !LinalgOpMetadata
This means reduction/broadcast/element cast semantics is explicit. Further
passes can take that into account when lowering this code. For example,
a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
`linalg.generic` with different affine maps for the two operands.
structured_op: !LinalgStructuredOpConfig
args:
Expand Down Expand Up @@ -663,6 +873,55 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: min
cpp_class_name: MinOp
doc: |-
Takes the min (signed) between two inputs, elementwise.
The shapes and element types must be identical. The appropriate casts,
broadcasts and reductions should be done previously to calling this op.
This means reduction/broadcast/element cast semantics is explicit. Further
passes can take that into account when lowering this code. For example,
a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
`linalg.generic` with different affine maps for the two operands.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: lhs
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: rhs
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: min_signed
operands:
- !ScalarExpression
scalar_arg: lhs
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ class RegionBuilderHelper {
return builder.create<math::FloorOp>(arg.getLoc(), arg);
case UnaryFn::negf:
return builder.create<arith::NegFOp>(arg.getLoc(), arg);
case UnaryFn::reciprocal: {
Attribute oneAttr = builder.getOneAttr(arg.getType());
auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
::cast<TypedAttr>(oneAttr));
return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
}
case UnaryFn::round:
return builder.create<math::RoundOp>(arg.getLoc(), arg);
case UnaryFn::sqrt:
return builder.create<math::SqrtOp>(arg.getLoc(), arg);
case UnaryFn::rsqrt:
return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
case UnaryFn::square:
return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
case UnaryFn::tanh:
return builder.create<math::TanhOp>(arg.getLoc(), arg);
}
llvm_unreachable("unsupported unary function");
}
Expand Down
5 changes: 5 additions & 0 deletions mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ class UnaryFn:
ceil = UnaryFnType("ceil")
floor = UnaryFnType("floor")
negf = UnaryFnType("negf")
round = UnaryFnType("round")
sqrt = UnaryFnType("sqrt")
rsqrt = UnaryFnType("rsqrt")
square = UnaryFnType("square")
tanh = UnaryFnType("tanh")


class BinaryFnType:
Expand Down

0 comments on commit 4cec3b3

Please sign in to comment.