Commit 3fe7fe4
Author: Tobias Gysi
[mlir][linalg] Add unsigned min/max/cast function to OpDSL.
Update OpDSL to support unsigned integers by adding unsigned min/max/cast signatures. Add tests in OpDSL and on the C++ side to verify the proper signed and unsigned operations are emitted. The patch addresses an issue brought up in https://reviews.llvm.org/D111170.

Reviewed By: rsuderman

Differential Revision: https://reviews.llvm.org/D111230
1 parent 06404d5 commit 3fe7fe4
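For a feel of how the new primitives read at the OpDSL level, here is a minimal sketch. The op `row_max_unsigned` and its shapes are hypothetical, modeled on the definitions this patch adds to core_named_ops.py and imported the way the OpDSL tests do: `cast_unsigned` zero-extends instead of sign-extending when widening, and `ReduceFn.max_unsigned` reduces with unsigned integer comparisons.

```python
from mlir.dialects.linalg.opdsl.lang import *


# Hypothetical example op, not part of this patch: a rowwise unsigned max
# that zero-extends the input to the output element type before reducing.
@linalg_structured_op
def row_max_unsigned(
    I=TensorDef(T1, S.M, S.K),
    O=TensorDef(U, S.M, output=True)):
  domain(D.m, D.k)
  O[D.m] = ReduceFn.max_unsigned(D.k)(cast_unsigned(U, I[D.m, D.k]))
```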

File tree

10 files changed: +601 additions, −158 deletions


mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

277 additions, 0 deletions (large diff not rendered by default)

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

32 additions, 5 deletions
@@ -196,31 +196,40 @@ class RegionBuilderHelper {
   // If the cast cannot be performed, a warning will be issued and the
   // operand returned as-is (which will presumably yield a verification
   // issue downstream).
-  Value cast(Type toType, Value operand) {
+  Value cast(Type toType, Value operand, bool isUnsignedCast) {
     OpBuilder builder = getBuilder();
     auto loc = operand.getLoc();
 
     if (operand.getType() == toType)
       return operand;
     if (auto toIntType = toType.dyn_cast<IntegerType>()) {
       // If operand is floating point, cast directly to the int type.
-      if (operand.getType().isa<FloatType>())
+      if (operand.getType().isa<FloatType>()) {
+        if (isUnsignedCast)
+          return builder.create<FPToUIOp>(loc, toType, operand);
         return builder.create<FPToSIOp>(loc, toType, operand);
+      }
       // Cast index operands directly to the int type.
       if (operand.getType().isIndex())
         return builder.create<IndexCastOp>(loc, toType, operand);
       if (auto fromIntType = operand.getType().dyn_cast<IntegerType>()) {
-        // Either sign extend or truncate.
-        if (toIntType.getWidth() > fromIntType.getWidth())
+        // Either extend or truncate.
+        if (toIntType.getWidth() > fromIntType.getWidth()) {
+          if (isUnsignedCast)
+            return builder.create<ZeroExtendIOp>(loc, toType, operand);
           return builder.create<SignExtendIOp>(loc, toType, operand);
+        }
         if (toIntType.getWidth() < fromIntType.getWidth())
           return builder.create<TruncateIOp>(loc, toType, operand);
       }
     } else if (auto toFloatType = toType.dyn_cast<FloatType>()) {
       // If operand is integer, cast directly to the float type.
       // Note that it is unclear how to cast from BF16<->FP16.
-      if (operand.getType().isa<IntegerType>())
+      if (operand.getType().isa<IntegerType>()) {
+        if (isUnsignedCast)
+          return builder.create<UIToFPOp>(loc, toFloatType, operand);
         return builder.create<SIToFPOp>(loc, toFloatType, operand);
+      }
       if (auto fromFloatType = operand.getType().dyn_cast<FloatType>()) {
         if (toFloatType.getWidth() > fromFloatType.getWidth())
           return builder.create<FPExtOp>(loc, toFloatType, operand);
@@ -284,6 +293,15 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__max_unsigned(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(lhs))
+      return builder.create<MaxFOp>(lhs.getLoc(), lhs, rhs);
+    if (isInteger(lhs))
+      return builder.create<MaxUIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   Value applyfn__min(Value lhs, Value rhs) {
     OpBuilder builder = getBuilder();
     if (isFloatingPoint(lhs))
@@ -293,6 +311,15 @@ class RegionBuilderHelper {
     llvm_unreachable("unsupported non numeric type");
   }
 
+  Value applyfn__min_unsigned(Value lhs, Value rhs) {
+    OpBuilder builder = getBuilder();
+    if (isFloatingPoint(lhs))
+      return builder.create<MinFOp>(lhs.getLoc(), lhs, rhs);
+    if (isInteger(lhs))
+      return builder.create<MinUIOp>(lhs.getLoc(), lhs, rhs);
+    llvm_unreachable("unsupported non numeric type");
+  }
+
   void yieldOutputs(ValueRange values) {
     assert(!values.empty() && "linalg ops must yield outputs");
     if (values.empty())
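Taken together, the single `isUnsignedCast` flag swaps each conversion for its unsigned counterpart (FPToSIOp→FPToUIOp, SignExtendIOp→ZeroExtendIOp, SIToFPOp→UIToFPOp), while truncation and index casts behave the same either way. A minimal pure-Python sketch of that decision table (the return values are standard-dialect mnemonics for illustration, not a real API):

```python
def select_cast_op(src_kind, src_width, dst_kind, dst_width, is_unsigned):
  """Mirrors the dispatch in RegionBuilderHelper::cast above."""
  if dst_kind == "int":
    if src_kind == "float":
      return "fptoui" if is_unsigned else "fptosi"
    if src_kind == "index":
      return "index_cast"
    if dst_width > src_width:  # Widening: zero- vs. sign-extension.
      return "zexti" if is_unsigned else "sexti"
    if dst_width < src_width:  # Narrowing ignores signedness.
      return "trunci"
  elif dst_kind == "float":
    if src_kind == "int":
      return "uitofp" if is_unsigned else "sitofp"
    if dst_width > src_width:
      return "fpext"
    if dst_width < src_width:
      return "fptrunc"
  return None  # Identity or unsupported combination.


assert select_cast_op("int", 8, "int", 32, True) == "zexti"
assert select_cast_op("int", 8, "int", 32, False) == "sexti"
assert select_cast_op("float", 32, "int", 32, True) == "fptoui"
```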

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

17 additions, 2 deletions
@@ -340,6 +340,8 @@ class PrimFn:
   max = PrimFnType("max")
   min = PrimFnType("min")
   sub = PrimFnType("sub")
+  max_unsigned = PrimFnType("max_unsigned")
+  min_unsigned = PrimFnType("min_unsigned")
 
 
 class ReduceFnType:
@@ -365,6 +367,8 @@ class ReduceFn:
   mul = PrimFn.mul.reduce
   max = PrimFn.max.reduce
   min = PrimFn.min.reduce
+  max_unsigned = PrimFn.max_unsigned.reduce
+  min_unsigned = PrimFn.min_unsigned.reduce
 
 
 class PrimApply(TensorExpression):
@@ -438,8 +442,8 @@ def __init__(self, to_type: TypeVar, operand: TensorExpression):
     self.operand = operand
 
   def to_scalar_expression(self) -> ScalarExpression:
-    return ScalarSymbolicCast(self.to_type,
-                              self.operand.to_scalar_expression()).expr()
+    return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
+                              False).expr()
 
   def visit_tensor_exprs(self, callback):
     super().visit_tensor_exprs(callback)
@@ -449,6 +453,17 @@ def __repr__(self):
     return f"cast({self.to_type}, {repr(self.operand)})"
 
 
+class cast_unsigned(cast):
+  """Casts the element type to an unsigned type (typically symbolic TypeVar)."""
+
+  def to_scalar_expression(self) -> ScalarExpression:
+    return ScalarSymbolicCast(self.to_type, self.operand.to_scalar_expression(),
+                              True).expr()
+
+  def __repr__(self):
+    return f"cast_unsigned({self.to_type}, {repr(self.operand)})"
+
+
 class ReduceApply(TensorExpression):
   """Application of a reduction.

mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py

29 additions, 6 deletions
@@ -230,10 +230,12 @@ def expression(self, expr: ScalarExpression) -> Value:
       return fn(*operand_values)
     elif expr.symbolic_cast:
       operand_value = self.expression(expr.symbolic_cast.operand)
-      return self.cast(expr.symbolic_cast.to_type.name, operand_value)
+      return self.cast(expr.symbolic_cast.to_type.name, operand_value,
+                       expr.symbolic_cast.is_unsigned_cast)
     raise NotImplementedError(f"Unimplemented scalar body expression: {expr}")
 
-  def cast(self, type_var_name: str, operand: Value) -> Value:
+  def cast(self, type_var_name: str, operand: Value,
+           is_unsigned_cast: bool) -> Value:
     try:
       to_type = self.type_mapping[type_var_name]
     except KeyError:
@@ -242,29 +244,37 @@ def cast(self, type_var_name: str, operand: Value) -> Value:
     if operand.type == to_type:
       return operand
     if _is_integer_type(to_type):
-      return self._cast_to_integer(to_type, operand)
+      return self._cast_to_integer(to_type, operand, is_unsigned_cast)
     elif _is_floating_point_type(to_type):
-      return self._cast_to_floating_point(to_type, operand)
+      return self._cast_to_floating_point(to_type, operand, is_unsigned_cast)
 
-  def _cast_to_integer(self, to_type: Type, operand: Value) -> Value:
+  def _cast_to_integer(self, to_type: Type, operand: Value,
+                       is_unsigned_cast: bool) -> Value:
     to_width = IntegerType(to_type).width
     operand_type = operand.type
     if _is_floating_point_type(operand_type):
+      if is_unsigned_cast:
+        return std.FPToUIOp(to_type, operand).result
       return std.FPToSIOp(to_type, operand).result
     if _is_index_type(operand_type):
       return std.IndexCastOp(to_type, operand).result
     # Assume integer.
     from_width = IntegerType(operand_type).width
     if to_width > from_width:
+      if is_unsigned_cast:
+        return std.ZeroExtendIOp(to_type, operand).result
       return std.SignExtendIOp(to_type, operand).result
     elif to_width < from_width:
       return std.TruncateIOp(to_type, operand).result
     raise ValueError(f"Unable to cast body expression from {operand_type} to "
                      f"{to_type}")
 
-  def _cast_to_floating_point(self, to_type: Type, operand: Value) -> Value:
+  def _cast_to_floating_point(self, to_type: Type, operand: Value,
+                              is_unsigned_cast: bool) -> Value:
     operand_type = operand.type
     if _is_integer_type(operand_type):
+      if is_unsigned_cast:
+        return std.UIToFPOp(to_type, operand).result
       return std.SIToFPOp(to_type, operand).result
     # Assume FloatType.
     to_width = _get_floating_point_width(to_type)
@@ -324,13 +334,26 @@ def _eval_max(self, lhs: Value, rhs: Value) -> Value:
       return std.MaxSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'max' operand: {lhs}")
 
+  def _eval_max_unsigned(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return std.MaxFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return std.MaxUIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError("Unsupported 'max_unsigned' operand: {lhs}")
+
   def _eval_min(self, lhs: Value, rhs: Value) -> Value:
     if _is_floating_point_type(lhs.type):
       return std.MinFOp(lhs.type, lhs, rhs).result
    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
       return std.MinSIOp(lhs.type, lhs, rhs).result
     raise NotImplementedError("Unsupported 'min' operand: {lhs}")
 
+  def _eval_min_unsigned(self, lhs: Value, rhs: Value) -> Value:
+    if _is_floating_point_type(lhs.type):
+      return std.MinFOp(lhs.type, lhs, rhs).result
+    if _is_integer_type(lhs.type) or _is_index_type(lhs.type):
+      return std.MinUIOp(lhs.type, lhs, rhs).result
+    raise NotImplementedError("Unsupported 'min_unsigned' operand: {lhs}")
 
 def _infer_structured_outs(op_config: LinalgStructuredOpConfig,
                            in_arg_defs: Sequence[OperandDefConfig],
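One subtlety in `_eval_max_unsigned`/`_eval_min_unsigned` above (and in their C++ twins): for floating-point operands the unsigned variants intentionally emit the same `MaxFOp`/`MinFOp` as the signed ones, since signedness only exists for integers. A tiny illustrative sketch of the rule (plain Python, not the bindings API):

```python
def select_min_op(elem_kind, is_unsigned):
  # Floats carry no signedness, so both variants share one op.
  if elem_kind == "float":
    return "minf"
  # Integer (and index) operands pick a signed or unsigned comparison.
  return "minui" if is_unsigned else "minsi"


assert select_min_op("float", True) == select_min_op("float", False) == "minf"
assert select_min_op("int", True) == "minui"
assert select_min_op("int", False) == "minsi"
```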

mlir/python/mlir/dialects/linalg/opdsl/lang/scalar_expr.py

6 additions, 3 deletions
@@ -85,15 +85,17 @@ def __repr__(self):
 class ScalarSymbolicCast:
   """A type of ScalarExpression that symbolically casts an operand to a TypeVar."""
 
-  def __init__(self, to_type: TypeVar, operand: "ScalarExpression"):
+  def __init__(self, to_type: TypeVar, operand: "ScalarExpression",
+               is_unsigned_cast: bool):
     self.to_type = to_type
     self.operand = operand
+    self.is_unsigned_cast = is_unsigned_cast
 
   def expr(self) -> "ScalarExpression":
     return ScalarExpression(symbolic_cast=self)
 
   def __repr__(self):
-    return f"ScalarSymbolicCast({self.to_type}, {self.operand})"
+    return f"ScalarSymbolicCast({self.to_type}, {self.operand}, {self.is_unsigned_cast})"
 
 
 class ScalarExpression(YAMLObject):
@@ -144,7 +146,8 @@ def to_yaml_custom_dict(self):
       return dict(
           symbolic_cast=dict(
               type_var=self.symbolic_cast.to_type.name,
-              operands=[self.symbolic_cast.operand]))
+              operands=[self.symbolic_cast.operand],
+              is_unsigned_cast=self.symbolic_cast.is_unsigned_cast))
     else:
       raise ValueError(f"Unexpected ScalarExpression type: {self}")
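As a quick illustration, the new flag rides along on the cast node and surfaces in both `__repr__` and the YAML dict. A sketch under the assumption that the module's existing `ScalarArg` leaf class and the predefined type variable `U` are importable as shown (paths per this patch's file layout):

```python
from mlir.dialects.linalg.opdsl.lang.scalar_expr import (ScalarArg,
                                                         ScalarSymbolicCast)
from mlir.dialects.linalg.opdsl.lang.types import U

# The same cast node, once signed and once unsigned; only the flag differs.
print(ScalarSymbolicCast(U, ScalarArg("a").expr(), False))
print(ScalarSymbolicCast(U, ScalarArg("a").expr(), True))
# -> ScalarSymbolicCast(U, ..., False)
# -> ScalarSymbolicCast(U, ..., True)
```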

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

49 additions, 0 deletions
@@ -20,6 +20,20 @@ def matmul(
   implements(ContractionOpInterface)
   C[D.m, D.n] += cast(U, A[D.m, D.k]) * cast(U, B[D.k, D.n])
 
+@linalg_structured_op
+def matmul_unsigned(
+    A=TensorDef(T1, S.M, S.K),
+    B=TensorDef(T2, S.K, S.N),
+    C=TensorDef(U, S.M, S.N, output=True)):
+  """Performs an unsigned matrix multiplication of two 2D inputs.
+
+  Numeric casting is performed on the operands to the inner multiply, promoting
+  them to the same data type as the accumulator/output.
+  """
+  domain(D.m, D.n, D.k)
+  implements(ContractionOpInterface)
+  C[D.m, D.n] += cast_unsigned(U, A[D.m, D.k]) * cast_unsigned(U, B[D.k, D.n])
+
 @linalg_structured_op
 def quantized_matmul(
     A=TensorDef(T1, S.M, S.K),
@@ -411,6 +425,24 @@ def pooling_nhwc_max(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_nhwc_max_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs unsigned max pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.max_unsigned(D.kh, D.kw)(
+      cast_unsigned(
+          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
+
 @linalg_structured_op
 def pooling_nchw_max(
     I=TensorDef(T1, S.N, S.C, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW),
@@ -447,6 +479,23 @@ def pooling_nhwc_min(
       cast(U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW,
                 D.c]))
 
+@linalg_structured_op
+def pooling_nhwc_min_unsigned(
+    I=TensorDef(T1, S.N, S.OH * S.SH + S.KH * S.DH, S.OW * S.SW + S.KW * S.DW, S.C),
+    K=TensorDef(T2, S.KH, S.KW, index_dims=[D.kh, D.kw]),
+    O=TensorDef(U, S.N, S.OH, S.OW, S.C, output=True),
+    strides=AttributeDef(S.SH, S.SW),
+    dilations=AttributeDef(S.DH, S.DW)):
+  """Performs unsigned min pooling.
+
+  Numeric casting is performed on the input operand, promoting it to the same
+  data type as the accumulator/output.
+  """
+  implements(ConvolutionOpInterface)
+  domain(D.n, D.oh, D.ow, D.kh, D.kw, D.c)
+  O[D.n, D.oh, D.ow, D.c] = ReduceFn.min_unsigned(D.kh, D.kw)(
+      cast_unsigned(
+          U, I[D.n, D.oh * S.SH + D.kh * S.DH, D.ow * S.SW + D.kw * S.DW, D.c]))
 
 @linalg_structured_op
 def pooling_ndhwc_sum(
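Finally, a hedged sketch of exercising one of the new named ops through the Python bindings, modeled on the OpDSL emission tests of this era (exact binding imports and helpers may differ by version): with `i8` inputs accumulated into `i32`, the body of `matmul_unsigned` should zero-extend (`zexti`) where plain `matmul` would sign-extend (`sexti`).

```python
from mlir.ir import *
from mlir.dialects import builtin, linalg

with Context(), Location.unknown():
  module = Module.create()
  i8 = IntegerType.get_signless(8)
  i32 = IntegerType.get_signless(32)
  with InsertionPoint(module.body):

    @builtin.FuncOp.from_py_func(
        RankedTensorType.get((4, 16), i8), RankedTensorType.get((16, 8), i8),
        RankedTensorType.get((4, 8), i32))
    def matmul_unsigned_on_tensors(lhs, rhs, init):
      # The generated region casts with zexti (not sexti) before mul/add.
      return linalg.matmul_unsigned(lhs, rhs, outs=[init])

  print(module)
```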
