Skip to content

Commit 458ede8

Browse files
bondhugulatensorflower-gardener
authored andcommitted
Introduce splat op + provide its LLVM lowering
- introduce splat op in standard dialect (currently for int/float/index input type, output type can be vector or statically shaped tensor) - implement LLVM lowering (when result type is 1-d vector) - add constant folding hook for it - while on Ops.cpp, fix some stale names Signed-off-by: Uday Bondhugula <uday@polymagelabs.com> Closes tensorflow/mlir#141 COPYBARA_INTEGRATE_REVIEW=tensorflow/mlir#141 from bondhugula:splat 48976a6aa0a75be6d91187db6418de989e03eb51 PiperOrigin-RevId: 270965304
1 parent 42d8fa6 commit 458ede8

File tree

10 files changed

+212
-27
lines changed

10 files changed

+212
-27
lines changed

mlir/g3doc/Dialects/Standard.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,32 @@ because of the
352352
[restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols)
353353
in these contexts.
354354

355+
### 'splat' operation
356+
357+
Syntax:
358+
359+
``` {.ebnf}
360+
operation ::= `splat` ssa-use `:` ( vector-type | tensor-type )
361+
```
362+
363+
Broadcast the operand to all elements of the result vector or tensor. The
364+
operand has to be of either integer or float type. When the result is a tensor,
365+
it has to be statically shaped.
366+
367+
Example:
368+
369+
```mlir {.mlir}
370+
%s = load %A[%i] : memref<128xf32>
371+
%v = splat %s : vector<4xf32>
372+
%t = splat %s : tensor<8x16xi32>
373+
```
374+
375+
TODO: This operation is easy to extend to broadcast to dynamically shaped
376+
tensors in the same way dynamically shaped memrefs are handled. `mlir {.mlir} //
377+
Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding // to the
378+
sizes of the two dynamic dimensions. %m = "foo"() : () -> (index) %n = "bar"() :
379+
() -> (index) %t = splat %s [%m, %n] : tensor<?x?xi32>`
380+
355381
### 'store' operation
356382

357383
Syntax:

mlir/g3doc/LangRef.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -692,7 +692,7 @@ index-type ::= `index`
692692

693693
The `index` type is a signless integer whose size is equal to the natural
694694
machine word of the target ([rationale](Rationale.md#signless-types)) and is
695-
used by the affine constructs in MLIR. Unlike fixed-size integers. It cannot be
695+
used by the affine constructs in MLIR. Unlike fixed-size integers, it cannot be
696696
used as an element of vector, tensor or memref type
697697
([rationale](Rationale.md#index-type-disallowed-in-vectortensormemref-types)).
698698

mlir/include/mlir/Dialect/StandardOps/Ops.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,32 @@ def ShlISOp : IntArithmeticOp<"shlis"> {
881881
let summary = "signed integer shift left";
882882
}
883883

884+
def SplatOp : Std_Op<"splat", [NoSideEffect]> {
885+
let summary = "splat or broadcast operation";
886+
let description = [{
887+
The "splat" op reads a value of integer or float type and broadcasts it into
888+
a vector or a tensor. The output of splat is thus a new value of either
889+
vector or tensor type with elemental type being its operand's type.
890+
When the result is a tensor, it has to be statically shaped.
891+
892+
%1 = splat %0 : vector<8xi32>
893+
%2 = splat %0 : tensor<4x8xi32>
894+
895+
// TODO: handle broadcast to dynamically shaped tensors.
896+
}];
897+
898+
let arguments = (ins AnyTypeOf<[AnyInteger, AnyFloat],
899+
"integer or float type">:$input);
900+
let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate);
901+
902+
let builders =
903+
[OpBuilder<"Builder *builder, OperationState &result, Value *element, "
904+
"Type aggregateType",
905+
[{ build(builder, result, aggregateType, element); }]>];
906+
907+
let hasFolder = 1;
908+
}
909+
884910
def SubFOp : FloatArithmeticOp<"subf"> {
885911
let summary = "floating point subtraction operation";
886912
let hasFolder = 1;

mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -248,17 +248,6 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
248248
return builder.create<LLVM::ConstantOp>(loc, getIndexType(), attr);
249249
}
250250

251-
// Get the array attribute named "position" containing the given list of
252-
// integers as integer attribute elements.
253-
static ArrayAttr getIntegerArrayAttr(ConversionPatternRewriter &builder,
254-
ArrayRef<int64_t> values) {
255-
SmallVector<Attribute, 4> attrs;
256-
attrs.reserve(values.size());
257-
for (int64_t pos : values)
258-
attrs.push_back(builder.getIntegerAttr(builder.getIndexType(), pos));
259-
return builder.getArrayAttr(attrs);
260-
}
261-
262251
// Extract raw data pointer value from a value representing a memref.
263252
static Value *extractMemRefElementPtr(ConversionPatternRewriter &builder,
264253
Location loc,
@@ -269,9 +258,9 @@ class LLVMLegalizationPattern : public LLVMOpLowering {
269258
if (hasStaticShape)
270259
return convertedMemRefValue;
271260
else
272-
return builder.create<LLVM::ExtractValueOp>(
273-
loc, elementTypePtr, convertedMemRefValue,
274-
getIntegerArrayAttr(builder, 0));
261+
return builder.create<LLVM::ExtractValueOp>(loc, elementTypePtr,
262+
convertedMemRefValue,
263+
builder.getIndexArrayAttr(0));
275264
return buffer;
276265
}
277266

@@ -1028,6 +1017,39 @@ struct CondBranchOpLowering
10281017
using Super::Super;
10291018
};
10301019

1020+
// The Splat operation is lowered to an insertelement + a shufflevector
1021+
// operation. Splat to only 1-d vector result types are lowered.
1022+
struct SplatOpLowering : public LLVMLegalizationPattern<SplatOp> {
1023+
using LLVMLegalizationPattern<SplatOp>::LLVMLegalizationPattern;
1024+
1025+
PatternMatchResult
1026+
matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
1027+
ConversionPatternRewriter &rewriter) const override {
1028+
auto splatOp = cast<SplatOp>(op);
1029+
VectorType resultType = splatOp.getType().dyn_cast<VectorType>();
1030+
if (!resultType || resultType.getRank() != 1)
1031+
return matchFailure();
1032+
1033+
// First insert it into an undef vector so we can shuffle it.
1034+
auto vectorType = lowering.convertType(splatOp.getType());
1035+
Value *undef = rewriter.create<LLVM::UndefOp>(op->getLoc(), vectorType);
1036+
auto zero = rewriter.create<LLVM::ConstantOp>(
1037+
op->getLoc(), lowering.convertType(rewriter.getIntegerType(32)),
1038+
rewriter.getZeroAttr(rewriter.getIntegerType(32)));
1039+
1040+
auto v = rewriter.create<LLVM::InsertElementOp>(
1041+
op->getLoc(), vectorType, undef, splatOp.getOperand(), zero);
1042+
1043+
int64_t width = splatOp.getType().cast<VectorType>().getDimSize(0);
1044+
SmallVector<int32_t, 4> zeroValues(width, 0);
1045+
1046+
// Shuffle the value across the desired number of elements.
1047+
ArrayAttr zeroAttrs = rewriter.getI32ArrayAttr(zeroValues);
1048+
rewriter.replaceOpWithNewOp<LLVM::ShuffleVectorOp>(op, v, undef, zeroAttrs);
1049+
return matchSuccess();
1050+
}
1051+
};
1052+
10311053
} // namespace
10321054

10331055
static void ensureDistinctSuccessors(Block &bb) {
@@ -1089,9 +1111,9 @@ void mlir::populateStdToLLVMConversionPatterns(
10891111
DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
10901112
MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
10911113
RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
1092-
SelectOpLowering, SignExtendIOpLowering, SIToFPLowering, StoreOpLowering,
1093-
SubFOpLowering, SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
1094-
ZeroExtendIOpLowering>(*converter.getDialect(), converter);
1114+
SelectOpLowering, SIToFPLowering, SignExtendIOpLowering, SplatOpLowering,
1115+
StoreOpLowering, SubFOpLowering, SubIOpLowering, TruncateIOpLowering,
1116+
XOrOpLowering, ZeroExtendIOpLowering>(*converter.getDialect(), converter);
10951117
}
10961118

10971119
// Convert types using the stored LLVM IR module.

mlir/lib/Dialect/StandardOps/Ops.cpp

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,10 @@ ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
202202
numDims = opInfos.size();
203203

204204
// Parse the optional symbol operands.
205-
auto affineIntTy = parser.getBuilder().getIndexType();
205+
auto indexTy = parser.getBuilder().getIndexType();
206206
if (parser.parseOperandList(opInfos,
207207
OpAsmParser::Delimiter::OptionalSquare) ||
208-
parser.resolveOperands(opInfos, affineIntTy, operands))
208+
parser.resolveOperands(opInfos, indexTy, operands))
209209
return failure();
210210
return success();
211211
}
@@ -1658,14 +1658,14 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
16581658
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
16591659
ShapedType type;
16601660

1661-
auto affineIntTy = parser.getBuilder().getIndexType();
1661+
auto indexTy = parser.getBuilder().getIndexType();
16621662
return failure(
16631663
parser.parseOperand(aggregateInfo) ||
16641664
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
16651665
parser.parseOptionalAttributeDict(result.attributes) ||
16661666
parser.parseColonType(type) ||
16671667
parser.resolveOperand(aggregateInfo, type, result.operands) ||
1668-
parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
1668+
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
16691669
parser.addTypeToList(type.getElementType(), result.types));
16701670
}
16711671

@@ -1739,14 +1739,14 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
17391739
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
17401740
MemRefType type;
17411741

1742-
auto affineIntTy = parser.getBuilder().getIndexType();
1742+
auto indexTy = parser.getBuilder().getIndexType();
17431743
return failure(
17441744
parser.parseOperand(memrefInfo) ||
17451745
parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
17461746
parser.parseOptionalAttributeDict(result.attributes) ||
17471747
parser.parseColonType(type) ||
17481748
parser.resolveOperand(memrefInfo, type, result.operands) ||
1749-
parser.resolveOperands(indexInfo, affineIntTy, result.operands) ||
1749+
parser.resolveOperands(indexInfo, indexTy, result.operands) ||
17501750
parser.addTypeToList(type.getElementType(), result.types));
17511751
}
17521752

@@ -2043,6 +2043,55 @@ static LogicalResult verify(SignExtendIOp op) {
20432043
return success();
20442044
}
20452045

2046+
//===----------------------------------------------------------------------===//
2047+
// SplatOp
2048+
//===----------------------------------------------------------------------===//
2049+
2050+
static void print(OpAsmPrinter &p, SplatOp op) {
2051+
p << "splat " << *op.getOperand();
2052+
p.printOptionalAttrDict(op.getAttrs());
2053+
p << " : " << op.getType();
2054+
}
2055+
2056+
static ParseResult parseSplatOp(OpAsmParser &parser, OperationState &result) {
2057+
OpAsmParser::OperandType splatValueInfo;
2058+
ShapedType shapedType;
2059+
2060+
return failure(parser.parseOperand(splatValueInfo) ||
2061+
parser.parseOptionalAttributeDict(result.attributes) ||
2062+
parser.parseColonType(shapedType) ||
2063+
parser.resolveOperand(splatValueInfo,
2064+
shapedType.getElementType(),
2065+
result.operands) ||
2066+
parser.addTypeToList(shapedType, result.types));
2067+
}
2068+
2069+
static LogicalResult verify(SplatOp op) {
2070+
// TODO: we could replace this by a trait.
2071+
if (op.getOperand()->getType() !=
2072+
op.getType().cast<ShapedType>().getElementType())
2073+
return op.emitError("operand should be of elemental type of result type");
2074+
2075+
return success();
2076+
}
2077+
2078+
// Constant folding hook for SplatOp.
2079+
OpFoldResult SplatOp::fold(ArrayRef<Attribute> operands) {
2080+
assert(operands.size() == 1 && "splat takes one operand");
2081+
2082+
auto constOperand = operands.front();
2083+
if (!constOperand ||
2084+
(!constOperand.isa<IntegerAttr>() && !constOperand.isa<FloatAttr>()))
2085+
return {};
2086+
2087+
auto shapedType = getType().cast<ShapedType>();
2088+
assert(shapedType.getElementType() == constOperand.getType() &&
2089+
"incorrect input attribute type for folding");
2090+
2091+
// SplatElementsAttr::get treats single value for second arg as being a splat.
2092+
return SplatElementsAttr::get(shapedType, {constOperand});
2093+
}
2094+
20462095
//===----------------------------------------------------------------------===//
20472096
// StoreOp
20482097
//===----------------------------------------------------------------------===//
@@ -2062,7 +2111,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
20622111
SmallVector<OpAsmParser::OperandType, 4> indexInfo;
20632112
MemRefType memrefType;
20642113

2065-
auto affineIntTy = parser.getBuilder().getIndexType();
2114+
auto indexTy = parser.getBuilder().getIndexType();
20662115
return failure(
20672116
parser.parseOperand(storeValueInfo) || parser.parseComma() ||
20682117
parser.parseOperand(memrefInfo) ||
@@ -2072,7 +2121,7 @@ static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
20722121
parser.resolveOperand(storeValueInfo, memrefType.getElementType(),
20732122
result.operands) ||
20742123
parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2075-
parser.resolveOperands(indexInfo, affineIntTy, result.operands));
2124+
parser.resolveOperands(indexInfo, indexTy, result.operands));
20762125
}
20772126

20782127
static LogicalResult verify(StoreOp op) {

mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -552,3 +552,18 @@ func @vec_bin(%arg0: vector<2x2x2xf32>) -> vector<2x2x2xf32> {
552552
// And we're done
553553
// CHECK-NEXT: return
554554
}
555+
556+
// CHECK-LABEL: @splat
557+
// CHECK-SAME: [[A:%arg[0-9]+]]: !llvm<"<4 x float>">
558+
// CHECK-SAME: [[ELT:%arg[0-9]+]]: !llvm.float
559+
func @splat(%a: vector<4xf32>, %b: f32) -> vector<4xf32> {
560+
%vb = splat %b : vector<4xf32>
561+
%r = mulf %a, %vb : vector<4xf32>
562+
return %r : vector<4xf32>
563+
}
564+
// CHECK-NEXT: [[UNDEF:%[0-9]+]] = llvm.mlir.undef : !llvm<"<4 x float>">
565+
// CHECK-NEXT: [[ZERO:%[0-9]+]] = llvm.mlir.constant(0 : i32) : !llvm.i32
566+
// CHECK-NEXT: [[V:%[0-9]+]] = llvm.insertelement [[UNDEF]], [[ELT]], [[ZERO]] : !llvm<"<4 x float>">
567+
// CHECK-NEXT: [[SPLAT:%[0-9]+]] = llvm.shufflevector [[V]], [[UNDEF]] [0 : i32, 0 : i32, 0 : i32, 0 : i32]
568+
// CHECK-NEXT: [[SCALE:%[0-9]+]] = llvm.fmul [[A]], [[SPLAT]] : !llvm<"<4 x float>">
569+
// CHECK-NEXT: llvm.return [[SCALE]] : !llvm<"<4 x float>">

mlir/test/IR/core-ops.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,17 @@ func @test_dimop(%arg0: tensor<4x4x?xf32>) {
467467
return
468468
}
469469

470+
// CHECK-LABEL: func @test_splat_op
471+
// CHECK-SAME: [[S:%arg[0-9]+]]: f32
472+
func @test_splat_op(%s : f32) {
473+
%v = splat %s : vector<8xf32>
474+
// CHECK: splat [[S]] : vector<8xf32>
475+
%t = splat %s : tensor<8xf32>
476+
// CHECK: splat [[S]] : tensor<8xf32>
477+
%u = "std.splat"(%s) : (f32) -> vector<4xf32>
478+
// CHECK: splat [[S]] : vector<4xf32>
479+
return
480+
}
470481

471482
// CHECK-LABEL: func @test_vector.transfer_ops(%arg0
472483
func @test_vector.transfer_ops(%arg0: memref<?x?xf32>) {

mlir/test/IR/invalid-ops.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -821,3 +821,27 @@ func @return_not_in_function() {
821821
}): () -> ()
822822
return
823823
}
824+
825+
// -----
826+
827+
func @invalid_splat(%v : f32) {
828+
splat %v : memref<8xf32>
829+
// expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}}
830+
return
831+
}
832+
833+
// -----
834+
835+
func @invalid_splat(%v : vector<8xf32>) {
836+
%w = splat %v : tensor<8xvector<8xf32>>
837+
// expected-error@-1 {{must be integer or float type}}
838+
return
839+
}
840+
841+
// -----
842+
843+
func @invalid_splat(%v : f32) { // expected-note {{prior use here}}
844+
splat %v : vector<8xf64>
845+
// expected-error@-1 {{expects different type than prior uses}}
846+
return
847+
}

mlir/test/Transforms/constant-fold.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,3 +540,15 @@ func @custom_insertion_position() {
540540
}) : () -> ()
541541
return
542542
}
543+
544+
// CHECK-LABEL: func @splat_fold
545+
func @splat_fold() -> (vector<4xf32>, tensor<4xf32>) {
546+
%c = constant 1.0 : f32
547+
%v = splat %c : vector<4xf32>
548+
%t = splat %c : tensor<4xf32>
549+
return %v, %t : vector<4xf32>, tensor<4xf32>
550+
551+
// CHECK-NEXT: [[V:%.*]] = constant dense<1.000000e+00> : vector<4xf32>
552+
// CHECK-NEXT: [[T:%.*]] = constant dense<1.000000e+00> : tensor<4xf32>
553+
// CHECK-NEXT: return [[V]], [[T]] : vector<4xf32>, tensor<4xf32>
554+
}

mlir/utils/vim/syntax/mlir.vim

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ syn match mlirType /x\s*\zsvector/
3131
" Operations.
3232
" Core ops (not exhaustive yet).
3333
" TODO: the list is not exhaustive.
34-
syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli store select subf subi tensor_cast
34+
syn keyword mlirOps alloc addf addi call call_indirect cmpi constant dealloc dma_start dma_wait dim extract_element for getTensor if load memref_cast mulf muli splat store select subf subi tensor_cast
3535

3636
" Affine ops.
3737
syn match mlirOps /\<affine\.apply\>/

0 commit comments

Comments
 (0)