243 changes: 148 additions & 95 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Large diffs are not rendered by default.
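The LinalgStructuredOps.td diff above is collapsed, so for orientation: the patch drops the args_in/args_out operand-count attributes and splits the flat operand list into explicit ins/outs (and init tensor) operand groups. A minimal before/after sketch, borrowing %lhs/%rhs/%sum and the identity map #map0 from the parallel_loops.mlir test further down:

// Old form: a flat operand list segmented by args_in/args_out attributes.
linalg.generic {
  args_in = 2 : i64,
  args_out = 1 : i64,
  indexing_maps = [#map0, #map0, #map0],
  iterator_types = ["parallel", "parallel"]
} %lhs, %rhs, %sum {
^bb0(%a: f32, %b: f32, %c: f32):
  %0 = addf %a, %b : f32
  linalg.yield %0 : f32
}: memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>

// New form: inputs and outputs are explicit operand groups.
linalg.generic {
  indexing_maps = [#map0, #map0, #map0],
  iterator_types = ["parallel", "parallel"]}
  ins(%lhs, %rhs : memref<2x2xf32>, memref<2x2xf32>)
  outs(%sum : memref<2x2xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  %0 = addf %a, %b : f32
  linalg.yield %0 : f32
}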

2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -90,7 +90,7 @@ class NamedStructuredOpTrait
unsigned getNumOutputs() {
ConcreteType concreteOp = cast<ConcreteType>(this->getOperation());
return concreteOp.output_buffers().size() +
concreteOp.output_tensors().size();
concreteOp.result_tensors().size();
}
static LogicalResult verifyTrait(Operation *op) {
ConcreteType concreteOp = cast<ConcreteType>(op);
2 changes: 2 additions & 0 deletions mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h
@@ -38,10 +38,12 @@ using std_mulf = ValueBuilder<MulFOp>;
using std_memref_cast = ValueBuilder<MemRefCastOp>;
using std_re = ValueBuilder<ReOp>;
using std_ret = OperationBuilder<ReturnOp>;
using std_rsqrt = ValueBuilder<RsqrtOp>;
using std_select = ValueBuilder<SelectOp>;
using std_load = ValueBuilder<LoadOp>;
using std_splat = ValueBuilder<SplatOp>;
using std_store = OperationBuilder<StoreOp>;
using std_subf = ValueBuilder<SubFOp>;
using std_subi = ValueBuilder<SubIOp>;
using std_sub_view = ValueBuilder<SubViewOp>;
using std_tanh = ValueBuilder<TanhOp>;
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/Utils/StructuredOpsUtils.h
@@ -58,14 +58,6 @@ constexpr StringRef getIndexingMapsAttrName() { return "indexing_maps"; }
/// op's iterators.
constexpr StringRef getIteratorTypesAttrName() { return "iterator_types"; }

/// Attribute name for the IntegerAttr which encodes the number of input buffer
/// arguments.
constexpr StringRef getArgsInAttrName() { return "args_in"; }

/// Attribute name for the IntegerAttr which encodes the number of input buffer
/// arguments.
constexpr StringRef getArgsOutAttrName() { return "args_out"; }

/// Attribute name for the StringAttr which encodes an optional documentation
/// string of the structured op.
constexpr StringRef getDocAttrName() { return "doc"; }
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/LinalgToSPIRV/LinalgToSPIRV.cpp
@@ -71,7 +71,7 @@ SingleWorkgroupReduction::matchAsPerformingReduction(
return llvm::None;

// Make sure this is reduction with one input and one output.
if (genericOp.args_in() != 1 || genericOp.args_out() != 1)
if (genericOp.getNumInputs() != 1 || genericOp.getNumOutputs() != 1)
return llvm::None;

auto originalInputType = op->getOperand(0).getType().cast<MemRefType>();
143 changes: 67 additions & 76 deletions mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
@@ -23,37 +23,36 @@ using namespace mlir::scf;

Operation *mlir::edsc::makeGenericLinalgOp(
ArrayRef<IteratorType> iteratorTypes, ArrayRef<StructuredIndexed> inputs,
ArrayRef<StructuredIndexed> outputs,
ArrayRef<StructuredIndexed> outputBuffers, ArrayRef<Value> initTensors,
ArrayRef<StructuredIndexed> resultTensorTypes,
function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
ArrayRef<Attribute> otherAttributes) {
for (unsigned i = 0, e = outputs.size(); i + 1 < e; ++i)
assert(!(outputs[i].getType().isa<RankedTensorType>() &&
outputs[i + 1].getType().isa<MemRefType>()) &&
"output tensors must be passed after output buffers");
auto &builder = edsc::ScopedContext::getBuilderRef();
auto *ctx = builder.getContext();
unsigned nInputs = inputs.size();
unsigned nOutputs = outputs.size();
OpBuilder &builder = edsc::ScopedContext::getBuilderRef();

// Build maps
SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
exprsList.reserve(nInputs + nOutputs);
for (auto structuredIndexed : inputs)
exprsList.emplace_back(structuredIndexed.getExprs().begin(),
structuredIndexed.getExprs().end());
for (auto structuredIndexed : outputs)
exprsList.emplace_back(structuredIndexed.getExprs().begin(),
structuredIndexed.getExprs().end());
exprsList.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
for (auto container : {inputs, outputBuffers, resultTensorTypes})
for (const StructuredIndexed &s : container)
exprsList.emplace_back(s.getExprs().begin(), s.getExprs().end());
auto maps = AffineMap::inferFromExprList(exprsList);

unsigned nViews = nInputs + nOutputs;
SmallVector<Value, 4> values;
values.reserve(nViews);
values.append(inputs.begin(), inputs.end());
std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(values),
[](StructuredIndexed s) { return s.hasValue(); });
SmallVector<Type, 4> types;
std::copy_if(outputs.begin(), outputs.end(), std::back_inserter(types),
[](StructuredIndexed s) { return !s.hasValue(); });
assert(llvm::all_of(resultTensorTypes, [](const StructuredIndexed &s) {
return !s.hasValue();
}));
std::copy(resultTensorTypes.begin(), resultTensorTypes.end(),
std::back_inserter(types));

SmallVector<Value, 4> inputValues, outputBufferValues, initTensorValues;
inputValues.reserve(inputs.size());
outputBufferValues.reserve(outputBuffers.size());
initTensorValues.reserve(initTensors.size());
std::copy(inputs.begin(), inputs.end(), std::back_inserter(inputValues));
std::copy(outputBuffers.begin(), outputBuffers.end(),
std::back_inserter(outputBufferValues));
std::copy(initTensors.begin(), initTensors.end(),
std::back_inserter(initTensorValues));

auto iteratorStrTypes =
llvm::to_vector<8>(llvm::map_range(iteratorTypes, toString));
@@ -63,9 +62,9 @@ Operation *mlir::edsc::makeGenericLinalgOp(
.create<linalg::GenericOp>(
edsc::ScopedContext::getLocation(),
types,
values,
IntegerAttr::get(IntegerType::get(64, ctx), nInputs),
IntegerAttr::get(IntegerType::get(64, ctx), nOutputs),
inputValues,
outputBufferValues,
initTensorValues,
builder.getAffineMapArrayAttr(maps),
builder.getStrArrayAttr(iteratorStrTypes),
StringAttr() /*doc*/,
@@ -78,11 +77,12 @@

using namespace edsc;
SmallVector<Type, 4> blockTypes;
blockTypes.reserve(values.size());
for (auto it : llvm::enumerate(values))
blockTypes.push_back((it.index() < nViews)
? getElementTypeOrSelf(it.value())
: it.value().getType());
blockTypes.reserve(inputs.size() + outputBuffers.size() + initTensors.size());
for (auto container : {inputs, outputBuffers})
for (const StructuredIndexed &s : container)
blockTypes.push_back(getElementTypeOrSelf(s.getType()));
for (Value v : initTensors)
blockTypes.push_back(getElementTypeOrSelf(v.getType()));

assert(op->getNumRegions() == 1);
assert(op->getRegion(0).empty());
@@ -113,20 +113,17 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O) {
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
IteratorType::Parallel);
if (O.getType().isa<RankedTensorType>()) {
auto fun = [&unaryOp](ValueRange args) {
assert(args.size() == 1 && "expected 1 block arguments");
Value a(args[0]);
linalg_yield(unaryOp(a));
};
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
}
auto fun = [&unaryOp](ValueRange args) {
assert(args.size() == 2 && "expected 2 block arguments");
assert(!args.empty() && "expected >= 1 block arguments");
Value a(args[0]);
linalg_yield(unaryOp(a));
};
return makeGenericLinalgOp(iterTypes, {I}, {O}, fun);
if (O.getType().isa<RankedTensorType>())
return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{},
/*initTensors=*/{}, /*resultTensorTypes=*/{O},
fun);
return makeGenericLinalgOp(iterTypes, /*inputs=*/{I}, /*outputBuffers=*/{O},
/*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
}

Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I,
@@ -141,20 +138,18 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise(
StructuredIndexed I2, StructuredIndexed O) {
SmallVector<IteratorType, 4> iterTypes(O.getExprs().size(),
IteratorType::Parallel);
if (O.getType().isa<RankedTensorType>()) {
auto fun = [&binaryOp](ValueRange args) {
assert(args.size() == 2 && "expected 2 block arguments");
Value a(args[0]), b(args[1]);
linalg_yield(binaryOp(a, b));
};
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
}
auto fun = [&binaryOp](ValueRange args) {
assert(args.size() == 3 && "expected 3 block arguments");
assert(args.size() >= 2 && "expected >= 2 block arguments");
Value a(args[0]), b(args[1]);
linalg_yield(binaryOp(a, b));
};
return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun);
if (O.getType().isa<RankedTensorType>())
return makeGenericLinalgOp(
iterTypes, /*inputs=*/{I1, I2}, /*outputBuffers=*/{},
/*initTensors=*/{}, /*resultTensorTypes=*/{O}, fun);
return makeGenericLinalgOp(iterTypes, /*inputs=*/{I1, I2},
/*outputBuffers=*/{O},
/*initTensors=*/{}, /*resultTensorTypes=*/{}, fun);
}

Operation *mlir::edsc::ops::linalg_generic_pointwise_add(StructuredIndexed I1,
@@ -185,23 +180,10 @@ mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
StructuredIndexed A(vA), B(vB), C(vC);
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
{A({m, k}), B({k, n})},
{C({m, n})},
regionBuilder);
// clang-format on
}

Operation *
mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC,
MatmulRegionBuilder regionBuilder) {
// clang-format off
AffineExpr m, n, k;
bindDims(ScopedContext::getContext(), m, n, k);
StructuredIndexed A(vA), B(vB), C(tC);
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
{A({m, k}), B({k, n})},
{C({m, n})},
/*inputs=*/{A({m, k}), B({k, n})},
/*outputBuffers=*/{C({m, n})},
/*initTensors=*/{},
/*resultTensorTypes=*/{},
regionBuilder);
// clang-format on
}
@@ -216,8 +198,10 @@ mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC,
StructuredIndexed A(vA), B(vB), C(vC), D(tD);
return makeGenericLinalgOp(
{IteratorType::Parallel, IteratorType::Parallel, IteratorType::Reduction},
{A({m, k}), B({k, n}), C({m, n})},
{D({m, n})},
/*inputs=*/{A({m, k}), B({k, n})},
/*outputBuffers=*/{},
/*initTensors=*/{C({m, n})},
/*resultTensorTypes=*/{D({m, n})},
regionBuilder);
// clang-format on
}
@@ -243,15 +227,18 @@ Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW,
StructuredIndexed I(vI), W(vW), O(vO);
// clang-format off
return makeGenericLinalgOp(
{par, par, par, par, red, red, red}, {
{par, par, par, par, red, red, red},
/*inputs=*/{
I({b,
// Roundtrip to flattened form to serve as canonicalization and ensure
// consistent ordering of subexpressions.
simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
c}),
W({kh, kw, c, f})}, {
O({b, h, w, f})},
W({kh, kw, c, f}) },
/*outputBuffers=*/{ O({b, h, w, f}) },
/*initTensors=*/{},
/*resultTensorTypes=*/{},
macRegionBuilder);
// clang-format on
}
@@ -276,15 +263,19 @@ Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc(
unsigned numDims = kw.cast<AffineDimExpr>().getPosition() + 1;
StructuredIndexed I(vI), W(vW), O(vO);
return makeGenericLinalgOp(
{par, par, par, par, par, red, red}, {
{par, par, par, par, par, red, red},
/*inputs=*/{
I({b,
// Roundtrip to flattened form to serve as canonicalization and ensure
// consistent ordering of subexpressions.
simplifyAffineExpr(s[0] * h + d[0] * kh, numDims, 0),
simplifyAffineExpr(s[1] * w + d[1] * kw, numDims, 0),
c}),
W({kh, kw, c, dm})}, {
W({kh, kw, c, dm})},
/*outputBuffers=*/{
O({b, h, w, simplifyAffineExpr(c * depth_multiplier + dm, numDims, 0)})},
/*initTensors=*/{},
/*resultTensorTypes=*/{},
macRegionBuilder);
// clang-format on
}
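As a sketch of the IR the updated buffer-form matmul builder above is expected to emit (value names and the f32 element type are assumptions; the region is the mul-accumulate body supplied by regionBuilder):

#matmul_accesses = [
  affine_map<(m, n, k) -> (m, k)>,
  affine_map<(m, n, k) -> (k, n)>,
  affine_map<(m, n, k) -> (m, n)>
]
linalg.generic {
  indexing_maps = #matmul_accesses,
  iterator_types = ["parallel", "parallel", "reduction"]}
  ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
  outs(%C : memref<?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  // Mul-accumulate region: C(m, n) += A(m, k) * B(k, n).
  %d = mulf %a, %b : f32
  %e = addf %c, %d : f32
  linalg.yield %e : f32
}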
336 changes: 232 additions & 104 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Large diffs are not rendered by default.

56 changes: 32 additions & 24 deletions mlir/lib/Dialect/Linalg/Transforms/DropUnitDims.cpp
@@ -261,12 +261,14 @@ static UnitExtentReplacementInfo replaceUnitExtents(AffineMap indexMap,
}

namespace {

/// Pattern to replace tensors operands/results that are unit extents.
struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;
LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override {
if (!genericOp.hasTensorSemantics())
// TODO: support init_tensors and reductions.
if (!genericOp.hasTensorSemantics() || !genericOp.init_tensors().empty())
return failure();

MLIRContext *context = rewriter.getContext();
@@ -283,8 +285,7 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
reassociationMaps.push_back(replacementInfo.reassociation);
newIndexingMaps.push_back(replacementInfo.indexMap);
newInputOutputTypes.push_back(replacementInfo.type);
doCanonicalization =
doCanonicalization || replacementInfo.type != std::get<1>(it);
doCanonicalization |= replacementInfo.type != std::get<1>(it);
}

// If the indexing maps of the result operation are not invertible (i.e. not
@@ -295,32 +296,40 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {

// If any operand type change, insert a reshape to convert from the original
// type to the new type.
SmallVector<Value, 4> newOperands;
newOperands.reserve(genericOp.getNumOperands());
for (auto operand : llvm::enumerate(genericOp.getOperands())) {
if (operand.value().getType() == newInputOutputTypes[operand.index()]) {
newOperands.push_back(operand.value());
} else {
newOperands.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, newInputOutputTypes[operand.index()], operand.value(),
reassociationMaps[operand.index()]));
// TODO: get rid of flattenedIdx which assumes operand order and contiguity.
unsigned flattenedIdx = 0;
auto insertReshapes = [&](ValueRange values) {
SmallVector<Value, 4> res;
res.reserve(values.size());
for (auto operand : llvm::enumerate(values)) {
if (operand.value().getType() == newInputOutputTypes[flattenedIdx])
res.push_back(operand.value());
else
res.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, newInputOutputTypes[flattenedIdx], operand.value(),
reassociationMaps[flattenedIdx]));
++flattenedIdx;
}
}
return res;
};

SmallVector<Value, 4> newInputs = insertReshapes(genericOp.inputs());
SmallVector<Value, 4> newOutputBuffers =
insertReshapes(genericOp.output_buffers());
SmallVector<Value, 4> newInitTensors =
insertReshapes(genericOp.init_tensors());

// If any result type change, insert a reshape to convert from the original
// type to the new type.
SmallVector<Type, 4> resultTypes;
resultTypes.reserve(genericOp.getNumResults());
for (unsigned i : llvm::seq<unsigned>(0, genericOp.getNumResults()))
resultTypes.push_back(
newInputOutputTypes[i + genericOp.getNumOperands()]);
resultTypes.push_back(newInputOutputTypes[i + genericOp.getNumInputs()]);
GenericOp replacementOp = rewriter.create<GenericOp>(
loc, resultTypes, newOperands, genericOp.args_in(),
genericOp.args_out(), rewriter.getAffineMapArrayAttr(newIndexingMaps),
genericOp.iterator_types(),
/*doc = */ nullptr,
/*library_call = */ nullptr,
/*symbol_source = */ nullptr);
loc, resultTypes, newInputs, newOutputBuffers, newInitTensors,
newIndexingMaps,
llvm::to_vector<4>(
genericOp.iterator_types().getAsValueRange<StringAttr>()));
rewriter.inlineRegionBefore(genericOp.region(), replacementOp.region(),
replacementOp.region().begin());

@@ -332,12 +341,11 @@ struct ReplaceUnitExtentTensors : public OpRewritePattern<GenericOp> {
RankedTensorType origResultType = genericOp.getResult(result.index())
.getType()
.cast<RankedTensorType>();
if (origResultType != result.value().getType()) {
if (origResultType != result.value().getType())
resultReplacements.push_back(rewriter.create<linalg::TensorReshapeOp>(
loc, origResultType, result.value(), reassociationMaps[index]));
} else {
else
resultReplacements.push_back(result.value());
}
}
rewriter.replaceOp(genericOp, resultReplacements);
return success();
59 changes: 30 additions & 29 deletions mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -499,32 +499,31 @@ struct FuseGenericOpsOnTensors {
consumerIndexMaps.end());

// Generate the fused op.
// Tensor-level fusion is only on ops without initTensors and outputBuffers.
LinalgOp fusedOp;
if (isa<GenericOp>(producer.getOperation()) &&
isa<GenericOp>(consumer.getOperation())) {
fusedOp =
rewriter
.create<GenericOp>(
rewriter.getUnknownLoc(),
consumer.getOperation()->getResultTypes(), fusedOperands,
rewriter.getI64IntegerAttr(fusedOperands.size()),
rewriter.getI64IntegerAttr(
consumer.getOperation()->getNumResults()),
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
/*symbol_source=*/nullptr)
.create<GenericOp>(consumer.getLoc(),
consumer.getOperation()->getResultTypes(),
/*inputs=*/fusedOperands,
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{},
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
/*symbol_source=*/nullptr)
.getOperation();
} else {
fusedOp =
rewriter
.create<IndexedGenericOp>(
rewriter.getUnknownLoc(),
consumer.getOperation()->getResultTypes(), fusedOperands,
rewriter.getI64IntegerAttr(fusedOperands.size()),
rewriter.getI64IntegerAttr(
consumer.getOperation()->getNumResults()),
consumer.getLoc(), consumer.getOperation()->getResultTypes(),
/*inputs=*/fusedOperands,
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{},
rewriter.getArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
@@ -812,9 +811,10 @@ struct FuseTensorReshapeOpAsProducer {
}));
LinalgOp fusedOp = createLinalgOpOfSameType(
consumer, rewriter, rewriter.getUnknownLoc(),
consumerOp->getResultTypes(), fusedOperands,
rewriter.getI64IntegerAttr(fusedOperands.size()),
rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
consumerOp->getResultTypes(),
/*inputs=*/fusedOperands,
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{}, // no init tensors for now.
rewriter.getArrayAttr(indexMapAttrs), consumer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
@@ -871,10 +871,10 @@ struct FuseTensorReshapeOpAsConsumer {
Operation *producerOp = producer.getOperation();
LinalgOp fusedOp = createLinalgOpOfSameType(
producer, rewriter, rewriter.getUnknownLoc(), consumer.getResultType(),
producerOp->getOperands(),
rewriter.getI64IntegerAttr(producerOp->getNumOperands()),
rewriter.getI64IntegerAttr(1), rewriter.getArrayAttr(indexMapAttrs),
producer.iterator_types(),
/*inputs=*/producerOp->getOperands(),
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{}, // no init tensors for now.
rewriter.getArrayAttr(indexMapAttrs), producer.iterator_types(),
/*doc=*/nullptr,
/*library_call=*/nullptr,
/*symbol_source=*/nullptr);
@@ -932,10 +932,10 @@ struct FuseTensorReshapeOpAsConsumer {
}

int rank = dstShape.size();
int numArgsIn = producer.getNumInputs();
int numArgsOut = producer.getNumOutputs();
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, resultTypes, args, numArgsIn, numArgsOut,
loc, resultTypes, /*inputs=*/args,
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{},
SmallVector<AffineMap, 3>(args.size() + resultTypes.size(),
rewriter.getMultiDimIdentityMap(rank)),
SmallVector<StringRef, 3>(rank, getParallelIteratorTypeName()));
@@ -995,9 +995,10 @@ struct FuseConstantOpAsProducer {

LinalgOp fusedOp = createLinalgOpOfSameType(
consumer, rewriter, rewriter.getUnknownLoc(),
consumerOp->getResultTypes(), fusedOperands,
rewriter.getI64IntegerAttr(consumerOp->getNumOperands() - 1),
rewriter.getI64IntegerAttr(consumerOp->getNumResults()),
consumerOp->getResultTypes(),
/*inputs=*/fusedOperands,
/*outputBuffers=*/ValueRange{},
/*initTensors=*/ValueRange{}, // no init tensors for now.
rewriter.getAffineMapArrayAttr(fusedIndexMaps),
consumer.iterator_types(),
/*doc=*/nullptr,
53 changes: 33 additions & 20 deletions mlir/lib/Dialect/Linalg/Transforms/TensorsToBuffers.cpp
@@ -36,32 +36,45 @@ class GenericOpConverter
LogicalResult
matchAndRewrite(linalg::GenericOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
linalg::GenericOpAdaptor adaptor(operands,
op.getOperation()->getAttrDictionary());

// TODO: support ops with reduction.
if (!op.init_tensors().empty())
return failure();

// All inputs need to be turned into buffers first. Until then, bail out.
if (llvm::any_of(adaptor.inputs(),
[](Value in) { return !in.getType().isa<MemRefType>(); }))
return failure();

Location loc = op.getLoc();
ResultRange results = op.getOperation()->getResults();
SmallVector<Value, 2> newArgs, newResults;
newArgs.reserve(operands.size() + results.size());
newArgs.append(operands.begin(), operands.end());
newResults.reserve(results.size());
SmallVector<Value, 2> newOutputBuffers;
newOutputBuffers.reserve(op.getNumOutputs());
newOutputBuffers.append(adaptor.output_buffers().begin(),
adaptor.output_buffers().end());

// Update all types to memref types.
for (auto result : results) {
auto type = result.getType().cast<ShapedType>();
assert(type && "tensor to buffer conversion expects ranked results");
for (Type t : op.getResultTypes()) {
auto type = t.cast<ShapedType>();
if (!type.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "dynamic shapes not currently supported");
auto memrefType = MemRefType::get(type.getShape(), type.getElementType());
auto alloc = rewriter.create<AllocOp>(loc, memrefType);
newArgs.push_back(alloc);
newResults.push_back(alloc);
newOutputBuffers.push_back(alloc);
}

// Generate a new linalg operation that works on buffers.
auto linalgOp = rewriter.create<linalg::GenericOp>(
loc, llvm::None, newArgs, rewriter.getI64IntegerAttr(operands.size()),
rewriter.getI64IntegerAttr(results.size()), op.indexing_maps(),
op.iterator_types(), op.docAttr(), op.library_callAttr(),
op.symbol_sourceAttr());
loc,
/*resultTensorTypes=*/ArrayRef<Type>{},
/*inputs=*/adaptor.inputs(),
/*outputBuffers=*/newOutputBuffers,
/*initTensors=*/ValueRange{}, op.indexing_maps(), op.iterator_types(),
op.docAttr(), op.library_callAttr(), op.symbol_sourceAttr());

// Create a new block in the region of the new Generic Op.
Block &oldBlock = op.getRegion().front();
@@ -70,23 +83,23 @@ class GenericOpConverter
oldBlock.getArgumentTypes());

// Add the result arguments to the new block.
for (auto result : newResults)
newBlock->addArgument(
result.getType().cast<ShapedType>().getElementType());
for (Value v : newOutputBuffers)
newBlock->addArgument(v.getType().cast<MemRefType>().getElementType());

// Clone the body of the old block to the new block.
BlockAndValueMapping mapping;
for (unsigned i = 0; i < oldBlock.getNumArguments(); i++)
mapping.map(oldBlock.getArgument(i), newBlock->getArgument(i));

OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToEnd(newBlock);
for (auto &op : oldBlock.getOperations()) {
Operation *clonedOp = rewriter.clone(op, mapping);
mapping.map(op.getResults(), clonedOp->getResults());
}

// Replace the results of the old Generic Op with the results of the new
// one.
rewriter.replaceOp(op, newResults);
// Replace the results of the old op with the new output buffers.
rewriter.replaceOp(op, newOutputBuffers);
return success();
}
};
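A before/after sketch of this conversion on a hypothetical unary pointwise op with a static shape (dynamic shapes bail out above):

// Before: generic op with tensor semantics.
%res = linalg.generic #trait
  ins(%in : tensor<4xf32>) {
^bb0(%a: f32):
  linalg.yield %a : f32
} -> tensor<4xf32>

// After: the result is materialized into an allocated buffer, and the
// region gains one block argument per output buffer element.
%buf = alloc() : memref<4xf32>
linalg.generic #trait
  ins(%in : memref<4xf32>)
  outs(%buf : memref<4xf32>) {
^bb0(%a: f32, %out: f32):
  linalg.yield %a : f32
}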
32 changes: 16 additions & 16 deletions mlir/test/Conversion/LinalgToSPIRV/linalg-to-spirv.mlir
@@ -5,8 +5,6 @@
//===----------------------------------------------------------------------===//

#single_workgroup_reduction_trait = {
args_in = 1,
args_out = 1,
iterator_types = ["reduction"],
indexing_maps = [
affine_map<(i) -> (i)>,
@@ -49,11 +47,13 @@ module attributes {
func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) attributes {
spv.entry_point_abi = {local_size = dense<[16, 1, 1]>: vector<3xi32>}
} {
linalg.generic #single_workgroup_reduction_trait %input, %output {
linalg.generic #single_workgroup_reduction_trait
ins(%input : memref<16xi32>)
outs(%output : memref<1xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
} : memref<16xi32>, memref<1xi32>
}
spv.Return
}
}
@@ -63,8 +63,6 @@
// Missing shader entry point ABI

#single_workgroup_reduction_trait = {
args_in = 1,
args_out = 1,
iterator_types = ["reduction"],
indexing_maps = [
affine_map<(i) -> (i)>,
@@ -78,11 +76,13 @@ module attributes {
} {
func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>) {
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
linalg.generic #single_workgroup_reduction_trait %input, %output {
linalg.generic #single_workgroup_reduction_trait
ins(%input : memref<16xi32>)
outs(%output : memref<1xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
} : memref<16xi32>, memref<1xi32>
}
return
}
}
@@ -92,8 +92,6 @@
// Mismatch between shader entry point ABI and input memref shape

#single_workgroup_reduction_trait = {
args_in = 1,
args_out = 1,
iterator_types = ["reduction"],
indexing_maps = [
affine_map<(i) -> (i)>,
@@ -109,11 +107,13 @@ func @single_workgroup_reduction(%input: memref<16xi32>, %output: memref<1xi32>)
spv.entry_point_abi = {local_size = dense<[32, 1, 1]>: vector<3xi32>}
} {
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
linalg.generic #single_workgroup_reduction_trait %input, %output {
linalg.generic #single_workgroup_reduction_trait
ins(%input : memref<16xi32>)
outs(%output : memref<1xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
} : memref<16xi32>, memref<1xi32>
}
spv.Return
}
}
@@ -123,8 +123,6 @@
// Unsupported multi-dimension input memref

#single_workgroup_reduction_trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "reduction"],
indexing_maps = [
affine_map<(i, j) -> (i, j)>,
@@ -140,11 +138,13 @@ func @single_workgroup_reduction(%input: memref<16x8xi32>, %output: memref<16xi3
spv.entry_point_abi = {local_size = dense<[16, 8, 1]>: vector<3xi32>}
} {
// expected-error @+1 {{failed to legalize operation 'linalg.generic'}}
linalg.generic #single_workgroup_reduction_trait %input, %output {
linalg.generic #single_workgroup_reduction_trait
ins(%input : memref<16x8xi32>)
outs(%output : memref<16xi32>) {
^bb(%in: i32, %out: i32):
%sum = addi %in, %out : i32
linalg.yield %sum : i32
} : memref<16x8xi32>, memref<16xi32>
}
spv.Return
}
}
6 changes: 2 additions & 4 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -182,8 +182,6 @@ func @no_fold_memref_reshape(%arg0 : memref<?x?xf32>) -> memref<?x?xf32>
]

#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel"]
}
@@ -193,10 +191,10 @@ func @dce_zero_memref(%arg0 : memref<0xf32>, %arg1: tensor<0xf32>) -> tensor<0xf
linalg.copy(%arg0, %arg0): memref<0xf32>, memref<0xf32>

// tensor<0xf32> cannot be dce'ed
%1 = linalg.generic #trait %arg1 {
%1 = linalg.generic #trait ins(%arg1 : tensor<0xf32>) {
^bb(%0: f32) :
linalg.yield %0 : f32
} : tensor<0xf32> -> tensor<0xf32>
} -> tensor<0xf32>

return %1: tensor<0xf32>
}
39 changes: 17 additions & 22 deletions mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -6,19 +6,18 @@
]

#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_func"
}

func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
{
%0 = linalg.generic #trait %arg0 {
%0 = linalg.generic #trait
ins(%arg0 : tensor<?x1x?xf32>) {
^bb0(%arg1 : f32) :
linalg.yield %arg1 : f32
} : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
} -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, d1)>
@@ -40,19 +39,18 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel"],
indexing_maps = #access,
library_call = "some_external_func"
}

func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
%0 = linalg.generic #trait %arg0 {
%0 = linalg.generic #trait
ins(%arg0 : tensor<1x1xf32>) {
^bb0(%arg1: f32) :
linalg.yield %arg1 : f32
} : tensor<1x1xf32> -> tensor<1x1xf32>
} -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> ()>
@@ -70,18 +68,17 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
]

#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel"],
library_call = "some_external_fn"
}

func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
%0 = linalg.generic #trait %arg0 {
%0 = linalg.generic #trait
ins(%arg0 : tensor<1x5xf32>) {
^bb0(%arg2: f32): // no predecessors
linalg.yield %arg2 : f32
} : tensor<1x5xf32> -> tensor<5xf32>
} -> tensor<5xf32>
return %0 : tensor<5xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0, d1)>
@@ -100,8 +97,6 @@ func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
]

#trait = {
args_in = 2,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel"],
library_call = "some_external_fn"
Expand All @@ -113,11 +108,12 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5
tensor<5xf32> into tensor<1x5xf32>
%1 = linalg.tensor_reshape %arg1 [affine_map<(d0, d1) -> (d0, d1)>] :
tensor<5xf32> into tensor<5x1xf32>
%2 = linalg.generic #trait %0, %1 {
%2 = linalg.generic #trait
ins(%0, %1 : tensor<1x5xf32>, tensor<5x1xf32>) {
^bb0(%arg2: f32, %arg3: f32):
%3 = addf %arg2, %arg3 : f32
linalg.yield %3 : f32
} : tensor<1x5xf32>, tensor<5x1xf32> -> tensor<5x5xf32>
} -> tensor<5x5xf32>
return %2 : tensor<5x5xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d1)>
@@ -138,19 +134,18 @@ func @broadcast_test(%arg0 : tensor<5xf32>, %arg1 : tensor<5xf32>) -> tensor<5x5
]

#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel"],
library_call = "some_external_fn"
}

func @broadcast_scalar(%arg0 : tensor<1x1xf32>) -> tensor<?x?xf32>
{
%0 = linalg.generic #trait %arg0 {
^bb0(%arg1 : f32):
linalg.yield %arg1 : f32
} : tensor<1x1xf32> -> tensor<?x?xf32>
%0 = linalg.generic #trait
ins(%arg0 : tensor<1x1xf32>) {
^bb0(%arg1 : f32):
linalg.yield %arg1 : f32
} -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> ()>
33 changes: 15 additions & 18 deletions mlir/test/Dialect/Linalg/fold-unit-trip-loops.mlir
@@ -6,19 +6,18 @@
]

#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"],
indexing_maps = #accesses,
library_call = "some_external_func"
}

func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
{
%0 = linalg.generic #trait %arg0 {
%0 = linalg.generic #trait
ins(%arg0 : tensor<?x1x?xf32>) {
^bb0(%arg1 : f32) :
linalg.yield %arg1 : f32
} : tensor<?x1x?xf32> -> tensor<?x1x?x1x?xf32>
} -> tensor<?x1x?x1x?xf32>
return %0 : tensor<?x1x?x1x?xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1, d2) -> (d0, 0, d2)>
@@ -33,19 +32,18 @@ func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>) -> tensor<?x1x?x1x?xf32>
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel"],
indexing_maps = #access,
library_call = "some_external_func"
}

func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
{
%0 = linalg.generic #trait %arg0 {
%0 = linalg.generic #trait
ins(%arg0 : tensor<1x1xf32>) {
^bb0(%arg1: f32) :
linalg.yield %arg1 : f32
} : tensor<1x1xf32> -> tensor<1x1xf32>
} -> tensor<1x1xf32>
return %0 : tensor<1x1xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
@@ -59,19 +57,19 @@ func @drop_all_loops(%arg0 : tensor<1x1xf32>) -> tensor<1x1xf32>
#map0 = affine_map<(i, j) -> (i, j)>
#access = [#map0, #map0]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel"],
indexing_maps = #access,
library_call = "some_external_func"
}

func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
{
linalg.generic #trait %arg0, %arg1 {
linalg.generic #trait
ins(%arg0 : memref<1x1xf32>)
outs(%arg1 : memref<1x1xf32>) {
^bb0(%arg2: f32, %arg3 : f32) :
linalg.yield %arg2 : f32
} : memref<1x1xf32>, memref<1x1xf32>
}
return
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<() -> (0, 0)>
@@ -88,18 +86,17 @@ func @drop_all_loops(%arg0 : memref<1x1xf32>, %arg1 : memref<1x1xf32>)
]

#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel"],
library_call = "some_external_fn"
}

func @leading_dim_1_canonicalization(%arg0: tensor<1x5xf32>) -> tensor<5xf32> {
%0 = linalg.generic #trait %arg0 {
^bb0(%arg2: f32): // no predecessors
linalg.yield %arg2 : f32
} : tensor<1x5xf32> -> tensor<5xf32>
%0 = linalg.generic #trait
ins(%arg0 : tensor<1x5xf32>) {
^bb0(%arg2: f32): // no predecessors
linalg.yield %arg2 : f32
} -> tensor<5xf32>
return %0 : tensor<5xf32>
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (0, d0)>
198 changes: 86 additions & 112 deletions mlir/test/Dialect/Linalg/fusion-tensor.mlir

Large diffs are not rendered by default.
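The fusion-tensor.mlir diff is collapsed; a minimal sketch of the tensor fusion that FuseGenericOpsOnTensors performs, assuming pointwise traits with identity indexing maps (#trait2 with three maps for two inputs and one result, #trait3 with four):

// Before: an elementwise producer feeding an elementwise consumer.
%0 = linalg.generic #trait2
  ins(%A, %B : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%a: f32, %b: f32):
  %s = addf %a, %b : f32
  linalg.yield %s : f32
} -> tensor<?x?xf32>
%1 = linalg.generic #trait2
  ins(%0, %C : tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%s: f32, %c: f32):
  %m = mulf %s, %c : f32
  linalg.yield %m : f32
} -> tensor<?x?xf32>

// After fusion: one generic op with the producer body inlined; no
// output buffers or init tensors, the restriction checked in Fusion.cpp.
%1 = linalg.generic #trait3
  ins(%A, %B, %C : tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32):
  %s = addf %a, %b : f32
  %m = mulf %s, %c : f32
  linalg.yield %m : f32
} -> tensor<?x?xf32>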

67 changes: 31 additions & 36 deletions mlir/test/Dialect/Linalg/fusion.mlir
@@ -470,8 +470,6 @@ func @f8(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,

#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
Expand All @@ -483,13 +481,14 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%c0 = constant 0 : index
%c3 = constant 3 : index
%c2 = constant 2 : index
linalg.generic #pointwise_2d_trait %A, %A, %B {
linalg.generic #pointwise_2d_trait
ins(%A, %A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
memref<?x?xf32, offset: 0, strides: [?, ?]>)
outs(%B : memref<?x?xf32, offset: 0, strides: [?, ?]>) {
^bb0(%E: f32, %arg5: f32, %arg6: f32): // no predecessors
%2 = addf %E, %arg5 : f32
linalg.yield %2 : f32
}: memref<?x?xf32, offset: 0, strides: [?, ?]>,
memref<?x?xf32, offset: 0, strides: [?, ?]>,
memref<?x?xf32, offset: 0, strides: [?, ?]>
}
%0 = dim %B, %c0 : memref<?x?xf32, offset: 0, strides: [?, ?]>
%1 = dim %B, %c1 : memref<?x?xf32, offset: 0, strides: [?, ?]>
scf.for %arg4 = %c0 to %0 step %c2 {
Expand All @@ -503,13 +502,14 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,
%6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32, offset: 0, strides: [?, ?]> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
linalg.generic #pointwise_2d_trait %4, %5, %6 {
linalg.generic #pointwise_2d_trait
ins(%4, %5: memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>)
outs(%6 : memref<?x?xf32, offset: ?, strides: [?, ?]>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
%7 = mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
}: memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
return
@@ -527,8 +527,6 @@ func @pointwise(%A: memref<?x?xf32, offset: 0, strides: [?, ?]>,

#id_2d = affine_map<(i, j) -> (i, j)>
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
Expand All @@ -542,13 +540,13 @@ func @pointwise_no_view(%M: index, %N: index) {
%C = alloc (%M, %N): memref<?x?xf32>
%D = alloc (%M, %N): memref<?x?xf32>
%E = alloc (%M, %N): memref<?x?xf32>
linalg.generic #pointwise_2d_trait %A, %A, %B {
linalg.generic #pointwise_2d_trait
ins(%A, %A : memref<?x?xf32>, memref<?x?xf32>)
outs(%B : memref<?x?xf32>) {
^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors
%2 = addf %e, %arg5 : f32
linalg.yield %2 : f32
}: memref<?x?xf32>,
memref<?x?xf32>,
memref<?x?xf32>
}
%0 = dim %B, %c0 : memref<?x?xf32>
%1 = dim %B, %c1 : memref<?x?xf32>
scf.for %arg4 = %c0 to %0 step %c2 {
Expand All @@ -562,13 +560,14 @@ func @pointwise_no_view(%M: index, %N: index) {
%6 = std.subview %D[%arg4, %arg5][%c2, %c3][%c1, %c1] :
memref<?x?xf32> to
memref<?x?xf32, offset: ?, strides: [?, ?]>
linalg.generic #pointwise_2d_trait %4, %5, %6 {
linalg.generic #pointwise_2d_trait
ins(%4, %5: memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>)
outs(%6 : memref<?x?xf32, offset: ?, strides: [?, ?]>) {
^bb0(%arg6: f32, %arg7: f32, %arg8: f32): // no predecessors
%7 = mulf %arg6, %arg7 : f32
linalg.yield %7 : f32
}: memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>,
memref<?x?xf32, offset: ?, strides: [?, ?]>
}
}
}
return
@@ -596,25 +595,23 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%c1 = constant 1 : index
%0 = alloc() {temp = true} : memref<100x10xf32>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map1],
iterator_types = ["parallel", "parallel"]
} %arg1, %0 {
iterator_types = ["parallel", "parallel"]}
ins(%arg1 : memref<100xf32>)
outs(%0 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
linalg.yield %arg3 : f32
}: memref<100xf32>, memref<100x10xf32>
}
%1 = alloc() {temp = true} : memref<100x10xf32>
linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
indexing_maps = [#map1, #map1, #map1],
iterator_types = ["parallel", "parallel"]
} %arg0, %0, %1 {
iterator_types = ["parallel", "parallel"]}
ins(%arg0, %0: memref<100x10xf32>, memref<100x10xf32>)
outs(%1 : memref<100x10xf32>) {
^bb0(%arg3: f32, %arg4: f32, %arg5: f32): // no predecessors
%2 = subf %arg3, %arg4 : f32
linalg.yield %2 : f32
}: memref<100x10xf32>, memref<100x10xf32>, memref<100x10xf32>
}
dealloc %0 : memref<100x10xf32>
%2 = dim %1, %c0 : memref<100x10xf32>
%3 = dim %1, %c1 : memref<100x10xf32>
@@ -627,16 +624,14 @@ func @fusion_of_three(%arg0: memref<100x10xf32>,
%7 = std.subview %arg2[%i, %j][%c1, %c1][%c1, %c1] :
memref<100x10xf32> to memref<?x?xf32, #map2>
linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map1, #map1],
iterator_types = ["parallel", "parallel"]
} %6, %7 {
iterator_types = ["parallel", "parallel"]}
ins(%6 : memref<?x?xf32, #map2>)
outs(%7 : memref<?x?xf32, #map2>) {
^bb0(%arg3: f32, %arg4: f32): // no predecessors
%8 = exp %arg3 : f32
linalg.yield %8 : f32
}: memref<?x?xf32, #map2>,
memref<?x?xf32, #map2>
}
}
}
dealloc %1 : memref<100x10xf32>
51 changes: 24 additions & 27 deletions mlir/test/Dialect/Linalg/fusion_indexed_generic.mlir
Original file line number Diff line number Diff line change
@@ -3,20 +3,20 @@
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
%B: memref<?x?xf32>,
%C: memref<?x?xf32>,
%D: memref<?x?xf32>) {
linalg.generic #pointwise_2d_trait %A, %B, %C {
linalg.generic #pointwise_2d_trait
ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C : memref<?x?xf32>) {
^bb0(%e: f32, %arg5: f32, %arg6: f32): // no predecessors
%2 = addf %e, %arg5 : f32
linalg.yield %2 : f32
}: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
}
%c1 = constant 1 : index
%c0 = constant 0 : index
%c25 = constant 25 : index
@@ -33,10 +33,9 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
memref<?x?xf32> to memref<?x?xf32, #map>
linalg.indexed_generic {
indexing_maps = [#id_2d, #id_2d],
iterator_types = ["parallel", "parallel"],
args_in = 1,
args_out = 1
} %4, %5 {
iterator_types = ["parallel", "parallel"]}
ins(%4 : memref<?x?xf32, #map>)
outs(%5 : memref<?x?xf32, #map>) {
^bb0(%arg4: index, %arg5: index, %arg6: f32, %arg7: f32):
%6 = addi %arg4, %arg2 : index
%7 = addi %arg5, %arg3 : index
Expand All @@ -46,7 +45,7 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
%11 = sitofp %10 : i32 to f32
%12 = addf %9, %11 : f32
linalg.yield %12 : f32
}: memref<?x?xf32, #map>, memref<?x?xf32, #map>
}
}
}
return
Expand All @@ -66,8 +65,6 @@ func @fuse_indexed_generic_consumer(%A: memref<?x?xf32>,
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
Expand All @@ -79,14 +76,16 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
%c0 = constant 0 : index
%c25 = constant 25 : index
%c10 = constant 10 : index
linalg.indexed_generic #pointwise_2d_trait %A, %B, %C {
linalg.indexed_generic #pointwise_2d_trait
ins(%A, %B : memref<?x?xf32>, memref<?x?xf32>)
outs(%C : memref<?x?xf32>) {
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%ab = addf %a, %b : f32
%out = addf %ab, %i_float : f32
linalg.yield %out : f32
}: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
}
%C_X = dim %C, %c0 : memref<?x?xf32>
%C_Y = dim %C, %c1 : memref<?x?xf32>
%D_X = dim %D, %c0 : memref<?x?xf32>
Expand All @@ -98,14 +97,13 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
memref<?x?xf32> to memref<?x?xf32, #map>
linalg.generic {
indexing_maps = [#id_2d, #id_2d],
iterator_types = ["parallel", "parallel"],
args_in = 1,
args_out = 1
} %C_view, %D_view {
iterator_types = ["parallel", "parallel"]}
ins(%C_view : memref<?x?xf32, #map>)
outs(%D_view : memref<?x?xf32, #map>) {
^bb0( %a: f32, %b: f32):
%ab = addf %a, %b : f32
linalg.yield %ab : f32
}: memref<?x?xf32, #map>, memref<?x?xf32, #map>
}
}
return
}
Expand All @@ -125,8 +123,6 @@ func @fuse_indexed_generic_producer(%A: memref<?x?xf32>,
#map = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>
#id_2d = affine_map<(d0, d1) -> (d0, d1)>
#pointwise_2d_trait = {
args_in = 2,
args_out = 1,
indexing_maps = [#id_2d, #id_2d, #id_2d],
iterator_types = ["parallel", "parallel"]
}
Expand All @@ -137,14 +133,16 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,
%c1 = constant 1 : index
%c3 = constant 3 : index
%c0 = constant 0 : index
linalg.indexed_generic #pointwise_2d_trait %A, %B, %C {
linalg.indexed_generic #pointwise_2d_trait
ins(%A, %B: memref<?x?xf32>, memref<?x?xf32>)
outs(%C : memref<?x?xf32>) {
^bb0(%i: index, %j: index, %a: f32, %b: f32, %c: f32): // no predecessors
%j_int = index_cast %j: index to i32
%j_float = sitofp %j_int : i32 to f32
%ab = addf %a, %b : f32
%out = addf %ab, %j_float : f32
linalg.yield %out : f32
}: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
}
%C_X = dim %C, %c0 : memref<?x?xf32>
%C_Y = dim %C, %c1 : memref<?x?xf32>
%D_X = dim %D, %c0 : memref<?x?xf32>
Expand All @@ -161,14 +159,13 @@ func @fuse_indexed_generic_producer_tile_second_dim_only(%A: memref<?x?xf32>,

linalg.generic {
indexing_maps = [#id_2d, #id_2d],
iterator_types = ["parallel", "parallel"],
args_in = 1,
args_out = 1
} %C_view, %D_view {
iterator_types = ["parallel", "parallel"]}
ins(%C_view : memref<?x?xf32, #map>)
outs(%D_view : memref<?x?xf32, #map>) {
^bb0( %a: f32, %b: f32):
%ab = addf %a, %b : f32
linalg.yield %ab : f32
}: memref<?x?xf32, #map>, memref<?x?xf32, #map>
}
scf.yield
}
return
8 changes: 4 additions & 4 deletions mlir/test/Dialect/Linalg/inlining.mlir
@@ -9,8 +9,6 @@
]

#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel"]
}
@@ -23,9 +21,11 @@ func @inline_into(%arg0: memref<?xf32>) {

func @inlined_fn(%arg0: memref<?xf32>) {
// CHECK: linalg.generic
linalg.generic #trait %arg0, %arg0 {
linalg.generic #trait
ins(%arg0 : memref<?xf32>)
outs(%arg0 : memref<?xf32>) {
^bb(%0 : f32, %1 : f32) :
linalg.yield %0 : f32
} : memref<?xf32>, memref<?xf32>
}
return
}
240 changes: 72 additions & 168 deletions mlir/test/Dialect/Linalg/invalid.mlir

Large diffs are not rendered by default.

82 changes: 48 additions & 34 deletions mlir/test/Dialect/Linalg/loops.mlir
@@ -560,12 +560,15 @@ func @pooling_sum(%arg0: memref<?x?xf32>,
doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
}
func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait2 %arg0, %arg1, %arg2 {
linalg.generic #trait2
ins(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
^bb0(%a: f32, %b: f32, %c: f32):
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %d, %e : f32, f32
}: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
}
return
}
// CHECKLOOP-LABEL: @generic_region
@@ -602,7 +605,10 @@ func @indexed_generic_region(
%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
%arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.indexed_generic #trait4 %arg0, %arg1, %arg2 {
linalg.indexed_generic #trait4
ins(%arg0 : memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%arg1, %arg2 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%result_1 = mulf %a, %b : f32

Expand All @@ -613,9 +619,7 @@ func @indexed_generic_region(

%result_2 = addf %c, %ijk_float : f32
linalg.yield %result_1, %result_2 : f32, f32
}: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
}
return
}

@@ -666,10 +670,12 @@

func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)
{
linalg.generic #trait_broadcast %arg0, %arg1 {
linalg.generic #trait_broadcast
ins(%arg0 : memref<f32>)
outs(%arg1 : memref<3x4xf32>) {
^bb(%a: f32, %b: f32) :
linalg.yield %a : f32
} : memref<f32>, memref<3x4xf32>
}
return
}

@@ -690,13 +696,15 @@ func @generic_op_zero_rank(%arg0: memref<f32>, %arg1: memref<3x4xf32>)

func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)
{
linalg.indexed_generic #trait_broadcast %arg0, %arg1 {
linalg.indexed_generic #trait_broadcast
ins(%arg0 : memref<i32>)
outs(%arg1 : memref<3x4xi32>) {
^bb(%i: index, %j: index, %a: i32, %b: i32) :
%ij = addi %i, %j : index
%ij_int = index_cast %ij : index to i32
%result = addi %a, %ij_int : i32
linalg.yield %result : i32
} : memref<i32>, memref<3x4xi32>
}
return
}

@@ -736,11 +744,13 @@ func @indexed_generic_op_zero_rank(%arg0: memref<i32>, %arg1: memref<3x4xi32>)

func @generic_op_1D_reduce(%arg0: memref<?xf32>, %arg1: memref<f32>)
{
linalg.generic #trait_reduce_1D %arg0, %arg1 {
linalg.generic #trait_reduce_1D
ins(%arg0 : memref<?xf32>)
outs(%arg1 : memref<f32>) {
^bb(%a: f32, %b: f32) :
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} : memref<?xf32>, memref<f32>
}
return
}
// CHECKLOOP-LABEL: @generic_op_1D_reduce
@@ -780,14 +790,16 @@ func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
%arg1: memref<f32>,
%arg2: memref<f32>)
{
linalg.indexed_generic #trait_reduce_init_1D %arg0, %arg1, %arg2 {
linalg.indexed_generic #trait_reduce_init_1D
ins(%arg0, %arg1 : memref<?xf32>, memref<f32>)
outs(%arg2 : memref<f32>) {
^bb(%i : index, %a: f32, %b: f32, %c: f32) :
%0 = constant 0 : index
%1 = cmpi "eq", %0, %i : index
%2 = select %1, %b, %c : f32
%3 = addf %a, %2 : f32
linalg.yield %3 : f32
} : memref<?xf32>, memref<f32>, memref<f32>
}
return
}
// CHECKLOOP-LABEL: @indexed_generic_op_1D_reduce
@@ -823,10 +835,10 @@ func @indexed_generic_op_1D_reduce(%arg0: memref<?xf32>,
}
func @generic_const_init(%arg0: memref<?xf32>) {
%cst = constant 1.0 : f32
linalg.generic #trait_const_fill %arg0 {
linalg.generic #trait_const_fill outs(%arg0 : memref<?xf32>) {
^bb0(%arg1: f32): // no predecessors
linalg.yield %cst : f32
}: memref<?xf32>
}
return
}
// CHECKLOOP-LABEL: @generic_const_init
@@ -855,11 +867,13 @@ func @generic_const_init(%arg0: memref<?xf32>) {
}
func @scalar_code(%arg0: memref<f32>, %arg1 : memref<f32>, %arg2 : memref<f32>)
{
linalg.generic #scalar_trait %arg0, %arg1, %arg2 {
linalg.generic #scalar_trait
ins(%arg0, %arg1 : memref<f32>, memref<f32>)
outs(%arg2 : memref<f32>) {
^bb(%a : f32, %b : f32, %c : f32) :
%0 = addf %a, %b : f32
linalg.yield %0 : f32
} : memref<f32>, memref<f32>, memref<f32>
}
return
}
// CHECKLOOP-LABEL: @scalar_code
Expand Down Expand Up @@ -944,14 +958,14 @@ func @named_batch_matmul(%A: memref<?x?x?xf32>, %B: memref<?x?x?xf32>, %C: memre
}

func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>) -> () {
linalg.generic #conv_1d_trait %in, %filter, %out {
linalg.generic #conv_1d_trait
ins(%in, %filter : memref<?xf32>, memref<?xf32>)
outs(%out : memref<?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
} : memref<?xf32>,
memref<?xf32>,
memref<?xf32>
}
return
}

@@ -1012,14 +1026,14 @@ func @conv1d(%in : memref<?xf32>, %filter : memref<?xf32>, %out : memref<?xf32>
}

func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x?xf32>) -> () {
linalg.generic #conv_2d_trait %in, %filter, %out {
linalg.generic #conv_2d_trait
ins(%in, %filter : memref<?x?xf32>, memref<?x?xf32>)
outs(%out : memref<?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
} : memref<?x?xf32>,
memref<?x?xf32>,
memref<?x?xf32>
}
return
}

@@ -1096,14 +1110,14 @@ func @conv2d(%in : memref<?x?xf32>, %filter : memref<?x?xf32>, %out : memref<?x
}

func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memref<?x?x?xf32>) -> () {
linalg.generic #conv_3d_trait %in, %filter, %out {
linalg.generic #conv_3d_trait
ins(%in, %filter : memref<?x?x?xf32>, memref<?x?x?xf32>)
outs(%out : memref<?x?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
} : memref<?x?x?xf32>,
memref<?x?x?xf32>,
memref<?x?x?xf32>
}
return
}

@@ -1196,14 +1210,14 @@ func @conv3d(%in : memref<?x?x?xf32>, %filter : memref<?x?x?xf32>, %out : memre
}

func @conv4d(%in : memref<?x?x?x?xf32>, %filter : memref<?x?x?x?xf32>, %out : memref<?x?x?x?xf32>) -> () {
linalg.generic #conv_4d_trait %in, %filter, %out {
linalg.generic #conv_4d_trait
ins(%in, %filter : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
outs(%out : memref<?x?x?x?xf32>) {
^bb0(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b : f32
%e = addf %c, %d : f32
linalg.yield %e : f32
} : memref<?x?x?x?xf32>,
memref<?x?x?x?xf32>,
memref<?x?x?x?xf32>
}
return
}

25 changes: 12 additions & 13 deletions mlir/test/Dialect/Linalg/parallel_loops.mlir
@@ -5,15 +5,14 @@ func @linalg_generic_sum(%lhs: memref<2x2xf32>,
%rhs: memref<2x2xf32>,
%sum: memref<2x2xf32>) {
linalg.generic {
args_in = 2 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0, #map0],
iterator_types = ["parallel", "parallel"]
} %lhs, %rhs, %sum {
iterator_types = ["parallel", "parallel"]}
ins(%lhs, %rhs : memref<2x2xf32>, memref<2x2xf32>)
outs(%sum : memref<2x2xf32>) {
^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32): // no predecessors
%0 = addf %lhs_in, %rhs_in : f32
linalg.yield %0 : f32
}: memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>
}
return
}
// CHECK-LABEL: @linalg_generic_sum
@@ -35,17 +34,17 @@ func @linalg_generic_sum(%lhs: memref<2x2xf32>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
indexing_maps = #accesses
}

func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
linalg.generic #trait %A, %B {
linalg.generic #trait
ins(%A : memref<?x?x?x?xf32>)
outs(%B : memref<?x?x?xf32>) {
^bb0(%a: f32, %b: f32):
linalg.yield %a: f32
} : memref<?x?x?x?xf32>, memref<?x?x?xf32>
}
return
}
// CHECK-LABEL: @lower_outer_parallel
@@ -68,17 +67,17 @@ func @lower_outer_parallel(%A: memref<?x?x?x?xf32>, %B: memref<?x?x?xf32>) {
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d4, d5)>
]
#trait = {
args_in = 1,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"],
indexing_maps = #accesses
}

func @lower_mixed_parallel(%A: memref<?x?x?x?x?x?xf32>, %B: memref<?x?x?x?xf32>) {
linalg.generic #trait %A, %B {
linalg.generic #trait
ins(%A : memref<?x?x?x?x?x?xf32>)
outs(%B : memref<?x?x?x?xf32>) {
^bb0(%a: f32, %b: f32):
linalg.yield %a: f32
} : memref<?x?x?x?x?x?xf32>, memref<?x?x?x?xf32>
}
return
}
// CHECK-LABEL: @lower_mixed_parallel
118 changes: 64 additions & 54 deletions mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -293,46 +293,51 @@ func @pooling_sum(%arg0: memref<?x?x?xf32>,
]

#trait = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
}

func @generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait {foo = 1} %arg0, %arg1 {
linalg.generic #trait
ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
}
return
}
// CHECK-LABEL: func @generic
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: memref<?x?xvector<3x4xi4>, #[[$strided2D]]>, memref<?x?x?xf32, #[[$strided3D]]>

func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait {foo = 1} %arg0, %arg1 {
linalg.generic #trait
ins(%arg0 : tensor<?x?xvector<3x4xi4>>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} : tensor<?x?xvector<3x4xi4>>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
}
return
}
// CHECK-LABEL: func @generic_with_tensor_input
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: tensor<?x?xvector<3x4xi4>>, memref<?x?x?xf32, #[[$strided3D]]>

// -----

@@ -342,8 +347,6 @@ func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
]

#trait2 = {
args_in = 2,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
@@ -352,20 +355,22 @@ func @generic_with_tensor_input(%arg0: tensor<?x?xvector<3x4xi4>>,
func @generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
%0 = linalg.generic #trait2 {foo = 1} %arg0, %arg1 {
%0 = linalg.generic #trait2
ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
attrs = {foo = 1} {
^bb(%0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
} -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @generic_with_tensor_input_and_output
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK-SAME: %{{.*}}, %{{.*}}
// CHECK: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
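When the generic produces tensors, the old trailing ": operand-types -> result-types" signature likewise disappears: result types now follow the region directly as "-> tensor<...>", and extra op attributes ride in an attrs = {...} clause between the operand lists and the region. A minimal sketch, again with illustrative names and shapes:

#map = affine_map<(d0) -> (d0)>
%r = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
       ins(%t : tensor<?xf32>)
       attrs = {foo = 1} {
^bb0(%a: f32):
  linalg.yield %a : f32
} -> tensor<?xf32>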

// -----
@@ -376,8 +381,6 @@ func @generic_with_tensor_input_and_output(
]

#trait2 = {
args_in = 2,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_1"
@@ -386,20 +389,22 @@ func @generic_with_tensor_input_and_output(
func @indexed_generic_with_tensor_input_and_output(
%arg0: tensor<?x?xvector<3x4xi4>>, %arg1: tensor<?x?x?xf32>)
-> (tensor<?x?x?xf32>) {
%0 = linalg.indexed_generic #trait2 {foo = 1} %arg0, %arg1 {
%0 = linalg.indexed_generic #trait2
ins(%arg0, %arg1 : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
attrs = {foo = 1} {
^bb(%i: index, %j: index, %k: index, %0: vector<3x4xi4>, %1: f32) :
%f0 = constant 0.0 : f32
linalg.yield %f0 : f32
} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
} -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @indexed_generic_with_tensor_input_and_output
// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_1"}
// CHECK-SAME: ins({{.*}} : tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK-SAME: %{{.*}}, %{{.*}}
// CHECK: tensor<?x?xvector<3x4xi4>>, tensor<?x?x?xf32> -> tensor<?x?x?xf32>
// CHECK: -> tensor<?x?x?xf32>
// CHECK: return {{.*}} : tensor<?x?x?xf32>
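linalg.indexed_generic migrates the same way; the only difference is that the region keeps its leading index block arguments ahead of the element arguments. A 1-D sketch with illustrative names:

#map = affine_map<(d0) -> (d0)>
linalg.indexed_generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
  ins(%in : memref<?xf32>)
  outs(%out : memref<?xf32>) {
^bb0(%i: index, %a: f32, %b: f32):
  linalg.yield %a : f32
}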

// -----
@@ -410,28 +415,28 @@ func @indexed_generic_with_tensor_input_and_output(
]

#trait_broadcast = {
args_in = 1,
args_out = 1,
indexing_maps = #broadcast_access,
iterator_types = ["parallel", "parallel"],
library_call = "some_broadcast_external_fn"
}

func @generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
%0 = linalg.generic #trait_broadcast %arg0 {
%0 = linalg.generic #trait_broadcast
ins(%arg0 : tensor<f32>) {
^bb(%a: f32) :
linalg.yield %a : f32
} : tensor<f32> -> tensor<3x4xf32>
} -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}

func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
{
%0 = linalg.indexed_generic #trait_broadcast %arg0 {
%0 = linalg.indexed_generic #trait_broadcast
ins(%arg0 : tensor<f32>) {
^bb(%i: index, %j: index, %a: f32) :
linalg.yield %a : f32
} : tensor<f32> -> tensor<3x4xf32>
} -> tensor<3x4xf32>
return %0 : tensor<3x4xf32>
}

@@ -446,50 +451,55 @@ func @indexed_generic_op_zero_rank(%arg0: tensor<f32>) -> (tensor<3x4xf32>)
]

#trait3 = {
args_in = 1,
args_out = 1,
indexing_maps = #accesses,
iterator_types = ["parallel", "parallel", "parallel"],
library_call = "some_external_function_name_2"
}

func @generic_region(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.generic #trait3 {foo = 1} %arg0, %arg1 {
linalg.generic #trait3
ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%a: vector<3x4xi4>, %b: f32) :
linalg.yield %b : f32
} : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
}
return
}
// CHECK-LABEL: func @generic_region
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_2"
// CHECK-SAME: {foo = 1 : i64}
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: memref<?x?xvector<3x4xi4>, #[[$strided2D]]>,
// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]>
// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: attrs = {foo = 1 : i64} {
// CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32

func @indexed_generic(%arg0: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
%arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
linalg.indexed_generic #trait3 {foo = 1} %arg0, %arg1 {
^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
linalg.indexed_generic #trait3
ins(%arg0 : memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>)
outs(%arg1 : memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>)
attrs = {foo = 1} {
^bb(%i: index, %j: index, %k: index, %a: vector<3x4xi4>, %b: f32) :
linalg.yield %b : f32
}: memref<?x?xvector<3x4xi4>, offset: ?, strides: [?, 1]>,
memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
}
return
}
// CHECK-LABEL: func @indexed_generic
// CHECK: linalg.indexed_generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK-SAME: indexing_maps = [#{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel"],
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#{{[0-9a-z]*}}, #{{[0-9a-z]*}}],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"],
// CHECK-SAME: library_call = "some_external_function_name_2"
// CHECK-SAME: ins({{.*}} : memref<?x?xvector<3x4xi4>, #[[$strided2D]]>)
// CHECK-SAME: outs({{.*}} : memref<?x?x?xf32, #[[$strided3D]]>)
// CHECK-SAME: {foo = 1 : i64}
// CHECK: ^{{.*}}(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: vector<3x4xi4>, %{{.*}}: f32):
// CHECK: linalg.yield %{{.*}} : f32
// CHECK: }: memref<?x?xvector<3x4xi4>, #[[$strided2D]]>,
// CHECK-SAME: memref<?x?x?xf32, #[[$strided3D]]>
// CHECK: }

// -----

17 changes: 8 additions & 9 deletions mlir/test/Dialect/Linalg/standard.mlir
@@ -72,8 +72,6 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
affine_map<(m, n, k) -> (m, n)>
]
#matmul_trait = {
args_in = 2,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_outerproduct_matmul"
@@ -88,33 +86,34 @@ func @copy_transpose(%arg0: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>, %a
!matrix_type_C = type memref<?x?x!vector_type_C>

func @matmul_vec_impl(%A: !matrix_type_A, %B: !matrix_type_B, %C: !matrix_type_C) {
linalg.generic #matmul_trait %A, %B, %C {
linalg.generic #matmul_trait
ins(%A, %B : !matrix_type_A, !matrix_type_B)
outs(%C : !matrix_type_C) {
^bb0(%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
linalg.yield %d: !vector_type_C
} : !matrix_type_A, !matrix_type_B, !matrix_type_C

}
return
}
// CHECK-LABEL: func @matmul_vec_impl(
// CHECK: call @external_outerproduct_matmul(%{{.*}}) :

#indexed_matmul_trait = {
args_in = 2,
args_out = 1,
iterator_types = ["parallel", "parallel", "reduction"],
indexing_maps = #matmul_accesses,
library_call = "external_indexed_outerproduct_matmul"
}
func @matmul_vec_indexed(%A: !matrix_type_A,
%B: !matrix_type_B,
%C: !matrix_type_C) {
linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
linalg.indexed_generic #indexed_matmul_trait
ins(%A, %B : !matrix_type_A, !matrix_type_B)
outs(%C : !matrix_type_C) {
^bb0(%i: index, %j: index, %k: index,
%a: !vector_type_A, %b: !vector_type_B, %c: !vector_type_C):
%d = vector.outerproduct %a, %b, %c: !vector_type_A, !vector_type_B
linalg.yield %d: !vector_type_C
} : !matrix_type_A, !matrix_type_B, !matrix_type_C
}
return
}
// CHECK-LABEL: func @matmul_vec_indexed(
27 changes: 15 additions & 12 deletions mlir/test/Dialect/Linalg/tensors-to-buffers.mlir
@@ -4,23 +4,24 @@

// CHECK-LABEL: func @multiple_results_generic_op
func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>) {
%0, %1 = linalg.generic {args_in = 1 : i64, args_out = 2 : i64, indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]} %arg0 {
%0, %1 = linalg.generic {indexing_maps = [#map0, #map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<4xf32>) {
^bb0(%gen_arg1: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1, %tmp1 : f32, f32
}: tensor<4xf32> -> (tensor<4xf32>, tensor<4xf32>)
} -> tensor<4xf32>, tensor<4xf32>
return %0, %1 : tensor<4xf32>, tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]], %[[ARG2_RESULT:.*]]: [[TYPE]])
// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
// CHECK-SAME: %[[NEW_ARG0]], %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
// CHECK-SAME: outs(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]] : [[TYPE]], [[TYPE]]
// CHECK-NEXT: ^{{[a-z0-9_]*}}
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32, %{{.*}}: f32
// CHECK-NEXT: %{{.*}} = exp
// CHECK-NEXT: linalg.yield
// CHECK-NEXT: [[TYPE]], [[TYPE]], [[TYPE]]
// CHECK: linalg.copy(%[[FIRST_ALLOC]], %[[ARG1_RESULT]])
// CHECK: dealloc %[[FIRST_ALLOC]]
// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG2_RESULT]])
@@ -33,31 +34,33 @@ func @multiple_results_generic_op(%arg0: tensor<4xf32>) -> (tensor<4xf32>, tenso

// CHECK-LABEL: func @chained_operations
func @chained_operations(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %arg0 {
%0 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%arg0 : tensor<4xf32>) {
^bb0(%gen_arg1: f32):
%tmp1 = exp %gen_arg1 : f32
linalg.yield %tmp1 : f32
}: tensor<4xf32> -> tensor<4xf32>
%1 = linalg.generic {args_in = 1 : i64, args_out = 1 : i64, indexing_maps = [#map0, #map0], iterator_types = ["parallel"]} %0 {
} -> tensor<4xf32>
%1 = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
ins(%0 : tensor<4xf32>) {
^bb0(%gen_arg2: f32):
%tmp2 = exp %gen_arg2 : f32
linalg.yield %tmp2 : f32
}: tensor<4xf32> -> tensor<4xf32>
} -> tensor<4xf32>
return %1 : tensor<4xf32>
}
// CHECK: (%[[NEW_ARG0:.*]]: [[TYPE:.*]], %[[ARG1_RESULT:.*]]: [[TYPE]])
// CHECK: %[[FIRST_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
// CHECK-SAME: %[[NEW_ARG0]], %[[FIRST_ALLOC]]
// CHECK-SAME: ins(%[[NEW_ARG0]] : [[TYPE]]
// CHECK-SAME: outs(%[[FIRST_ALLOC]] : [[TYPE]]
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
// CHECK: [[TYPE]], [[TYPE]]
// CHECK: %[[SECOND_ALLOC:.*]] = alloc() : [[TYPE]]
// CHECK: linalg.generic
// CHECK-SAME: %[[FIRST_ALLOC]], %[[SECOND_ALLOC]]
// CHECK-SAME: ins(%[[FIRST_ALLOC]] : [[TYPE]]
// CHECK-SAME: outs(%[[SECOND_ALLOC]] : [[TYPE]]
// CHECK: ^{{[a-z0-9_]*}}
// CHECK-SAME: %{{.*}}: f32, %{{.*}}: f32
// CHECK: [[TYPE]], [[TYPE]]
// CHECK: dealloc %[[FIRST_ALLOC]]
// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[ARG1_RESULT]])
// CHECK: dealloc %[[SECOND_ALLOC]]
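As the CHECK lines spell out, the tensor-to-buffer conversion materializes each tensor result as an allocation that the rewritten generic writes through outs, then copies into the matching result argument and frees the temporary. Schematically, for a single-result op (a sketch with #map0 the identity map used above; %arg0/%res are illustrative, not verbatim expected output):

%buf = alloc() : memref<4xf32>
linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel"]}
  ins(%arg0 : memref<4xf32>)
  outs(%buf : memref<4xf32>) {
^bb0(%a: f32, %b: f32):
  %e = exp %a : f32
  linalg.yield %e : f32
}
linalg.copy(%buf, %res) : memref<4xf32>, memref<4xf32>
dealloc %buf : memref<4xf32>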
6 changes: 4 additions & 2 deletions mlir/test/Dialect/Linalg/tile.mlir
@@ -349,11 +349,13 @@ func @fill(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: f32) {

func @pointwise(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%arg2: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.generic #pointwise_2d_trait %arg0, %arg1, %arg2 {
linalg.generic #pointwise_2d_trait
ins(%arg0, %arg1 : memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%arg2 : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb0(%arg4: f32, %arg5: f32, %arg6: f32): // no predecessors
%4 = addf %arg4, %arg5 : f32
linalg.yield %4 : f32
}: memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>, memref<?x?xf32, offset: ?, strides: [?, 1]>
}
return
}
// TILE-2-LABEL: func @pointwise
12 changes: 8 additions & 4 deletions mlir/test/Dialect/Linalg/tile_indexed_generic.mlir
@@ -10,13 +10,15 @@
iterator_types = ["parallel"]
}
func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>) {
linalg.indexed_generic #pointwise_1d_trait %operand, %result {
linalg.indexed_generic #pointwise_1d_trait
ins(%operand :memref<50xf32>)
outs(%result : memref<50xf32>) {
^bb0(%i: index, %operand_in: f32, %result_in: f32):
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%out = addf %operand_in, %i_float : f32
linalg.yield %out : f32
}: memref<50xf32>, memref<50xf32>
}
return
}
// TILE-10n25-LABEL: func @indexed_generic_vector
@@ -53,15 +55,17 @@ func @indexed_generic_vector(%operand: memref<50xf32>, %result: memref<50xf32>)
iterator_types = ["parallel", "parallel"]
}
func @indexed_generic_matrix(%operand: memref<50x100xf32>, %result: memref<50x100xf32>) {
linalg.indexed_generic #combined_indices_trait %operand, %result {
linalg.indexed_generic #combined_indices_trait
ins(%operand : memref<50x100xf32>)
outs(%result : memref<50x100xf32>) {
^bb0(%i: index, %j: index, %operand_in: f32, %result_in: f32):
%i_int = index_cast %i: index to i32
%i_float = sitofp %i_int : i32 to f32
%j_int = index_cast %j: index to i32
%j_float = sitofp %j_int : i32 to f32
%out = addf %i_float, %j_float : f32
linalg.yield %out : f32
}: memref<50x100xf32>, memref<50x100xf32>
}
return
}
// TILE-10n25-LABEL: func @indexed_generic_matrix
17 changes: 9 additions & 8 deletions mlir/test/Dialect/Linalg/tile_parallel.mlir
@@ -14,13 +14,14 @@
func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%sum: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.generic #pointwise_2d_trait %lhs, %rhs, %sum {
linalg.generic #pointwise_2d_trait
ins(%lhs, %rhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%sum : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb0(%lhs_in: f32, %rhs_in: f32, %sum_out: f32):
%result = addf %lhs_in, %rhs_in : f32
linalg.yield %result : f32
}: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
}
return
}
// TILE-2-LABEL: func @sum(
Expand All @@ -33,7 +34,7 @@ func @sum(%lhs: memref<?x?xf32, offset: ?, strides: [?, 1]>,
// TILE-2: [[LHS_SUBVIEW:%.*]] = subview [[LHS]]
// TILE-2: [[RHS_SUBVIEW:%.*]] = subview [[RHS]]
// TILE-2: [[SUM_SUBVIEW:%.*]] = subview [[SUM]]
// TILE-2: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
// TILE-2: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]

// TILE-02-LABEL: func @sum(
// TILE-02-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
@@ -45,12 +46,12 @@ func @sum(
// TILE-02: [[LHS_SUBVIEW:%.*]] = subview [[LHS]]
// TILE-02: [[RHS_SUBVIEW:%.*]] = subview [[RHS]]
// TILE-02: [[SUM_SUBVIEW:%.*]] = subview [[SUM]]
// TILE-02: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
// TILE-02: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]

// TILE-002-LABEL: func @sum(
// TILE-002-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
// TILE-002-NOT: scf.parallel
// TILE-002: linalg.generic {{.*}} [[LHS]], [[RHS]], [[SUM]] {
// TILE-002: linalg.generic {{.*}} ins([[LHS]], [[RHS]]{{.*}} outs([[SUM]]

// TILE-234-LABEL: func @sum(
// TILE-234-SAME: [[LHS:%.*]]: {{.*}}, [[RHS:%.*]]: {{.*}}, [[SUM:%.*]]: {{.*}}) {
@@ -64,4 +65,4 @@ func @sum(
// TILE-234: [[LHS_SUBVIEW:%.*]] = subview [[LHS]]
// TILE-234: [[RHS_SUBVIEW:%.*]] = subview [[RHS]]
// TILE-234: [[SUM_SUBVIEW:%.*]] = subview [[SUM]]
// TILE-234: linalg.generic {{.*}} [[LHS_SUBVIEW]], [[RHS_SUBVIEW]], [[SUM_SUBVIEW]] {
// TILE-234: linalg.generic {{.*}} ins([[LHS_SUBVIEW]], [[RHS_SUBVIEW]]{{.*}} outs([[SUM_SUBVIEW]]
15 changes: 10 additions & 5 deletions mlir/test/Dialect/Linalg/tile_parallel_reduce.mlir
@@ -59,12 +59,14 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
%arg1 : memref<?x?xf32>,
%arg2 : memref<?xf32>)
{
linalg.generic #trait %arg0, %arg1, %arg2 {
linalg.generic #trait
ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?xf32>)
outs(%arg2 : memref<?xf32>) {
^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32):
%0 = addf %arg3, %arg4 : f32
%1 = addf %0, %arg5 : f32
linalg.yield %1 : f32
} : memref<?x?x?xf32>, memref<?x?xf32>, memref<?xf32>
}
return
}

@@ -82,7 +84,8 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
// CHECK: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], %[[ARG5]]]
// CHECK: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
// CHECK: linalg.generic
// CHECK-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
// CHECK-SAME: ins(%[[SV1]], %[[SV2]]
// CHECK-SAME: outs(%[[SV3]]

// TILE1-LABEL: func @reduction
// TILE1-DAG: %[[C2:.*]] = constant 2 : index
@@ -92,7 +95,8 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
// TILE1: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE1-NOT: subview
// TILE1: linalg.generic
// TILE1-SAME: %[[SV1]], %[[SV2]], %{{.*}}
// TILE1-SAME: ins(%[[SV1]], %[[SV2]]
// TILE1-SAME: outs(%{{.*}}

// TILE2-LABEL: func @reduction
// TILE2-DAG: %[[C2:.*]] = constant 2 : index
@@ -105,4 +109,5 @@ func @reduction(%arg0 : memref<?x?x?xf32>,
// TILE2: %[[SV2:.*]] = subview %{{.*}}[%[[ARG3]], 0]
// TILE2: %[[SV3:.*]] = subview %{{.*}}[%[[ARG4]]]
// TILE2: linalg.generic
// TILE2-SAME: %[[SV1]], %[[SV2]], %[[SV3]]
// TILE2-SAME: ins(%[[SV1]], %[[SV2]]
// TILE2-SAME: outs(%[[SV3]]
42 changes: 24 additions & 18 deletions mlir/test/Dialect/Linalg/transform-patterns.mlir
@@ -105,12 +105,14 @@ func @matmul(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
}
func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,
%C: memref<8x32xf32>) {
linalg.generic #matmul_trait %A, %B, %C {
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xf32>, memref<16x32xf32>)
outs(%C : memref<8x32xf32>) {
^bb(%a: f32, %b: f32, %c: f32) :
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e : f32
} : memref<8x16xf32>, memref<16x32xf32>, memref<8x32xf32>
}
return
}
// CHECK-LABEL: func @vectorization_test
@@ -122,12 +124,14 @@ func @vectorization_test(%A: memref<8x16xf32>, %B: memref<16x32xf32>,

func @vectorization_test_integer(%A: memref<8x16xi32>, %B: memref<16x32xi32>,
%C: memref<8x32xi32>) {
linalg.generic #matmul_trait %A, %B, %C {
linalg.generic #matmul_trait
ins(%A, %B : memref<8x16xi32>, memref<16x32xi32>)
outs(%C : memref<8x32xi32>) {
^bb(%a: i32, %b: i32, %c: i32) :
%d = muli %a, %b: i32
%e = addi %c, %d: i32
linalg.yield %e : i32
} : memref<8x16xi32>, memref<16x32xi32>, memref<8x32xi32>
}
return
}
// CHECK-LABEL: func @vectorization_test_integer
@@ -187,23 +191,24 @@ func @test_vectorize_copy_scalar(%A : memref<f32>, %B : memref<f32>) {
func @permute_generic(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.generic #generic_matmul_trait %A, %B, %C {
linalg.generic #generic_matmul_trait
ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb(%a: f32, %b: f32, %c: f32):
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e: f32
}: memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
}
return
}
// CHECK-LABEL: func @permute_generic
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
// CHECK-SAME: library_call = "linalg_matmul"} %{{.*}}, %{{.*}}, %{{.*}}
// CHECK-SAME: library_call = "linalg_matmul"}
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>

#indexed_matmul_trait = {
@@ -217,23 +222,24 @@ func @permute_generic_indexed(
%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%B: memref<?x?xf32, offset: ?, strides: [?, 1]>,
%C: memref<?x?xf32, offset: ?, strides: [?, 1]>) {
linalg.indexed_generic #indexed_matmul_trait %A, %B, %C {
linalg.indexed_generic #indexed_matmul_trait
ins(%A, %B : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>)
outs(%C : memref<?x?xf32, offset: ?, strides: [?, 1]>) {
^bb(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
%d = mulf %a, %b: f32
%e = addf %c, %d: f32
linalg.yield %e: f32
} : memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>,
memref<?x?xf32, offset: ?, strides: [?, 1]>
}
return
}
// CHECK-LABEL: func @permute_generic_indexed
// CHECK: linalg.indexed_generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.indexed_generic {
// CHECK-SAME: indexing_maps = [#[[$kn]], #[[$nm]], #[[$km]]],
// CHECK-SAME: iterator_types = ["parallel", "reduction", "parallel"],
// CHECK-SAME: library_call = "linalg_matmul_indexed"} %{{.*}}, %{{.*}}, %{{.*}}
// CHECK-SAME: library_call = "linalg_matmul_indexed"}
// CHECK: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>,
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>
// CHECK-SAME: memref<?x?xf32, #[[$STRIDED_2D_u_1]]>

func @matvec_perm(%A: memref<?x?xf32, offset: ?, strides: [?, 1]>,
91 changes: 58 additions & 33 deletions mlir/test/EDSC/builder-api-test.cpp
@@ -344,6 +344,7 @@ TEST_FUNC(builder_helpers) {
});
});

// clang-format off
// CHECK-LABEL: @builder_helpers
// CHECK: affine.for %{{.*}} = affine_map<(d0) -> (d0)>({{.*}}) to affine_map<(d0) -> (d0)>({{.*}}) {
// CHECK-NEXT: affine.for %{{.*}} = affine_map<(d0) -> (d0)>({{.*}}) to affine_map<(d0) -> (d0)>({{.*}}) {
@@ -424,9 +425,11 @@ TEST_FUNC(operator_or) {
Value rhs(f.getArgument(1));
lhs || rhs;

// clang-format off
// CHECK-LABEL: @operator_or
// CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1
// CHECK: or [[ARG0]], [[ARG1]]
// clang-format on
f.print(llvm::outs());
f.erase();
}
@@ -444,11 +447,13 @@ TEST_FUNC(operator_and) {
Value rhs(f.getArgument(1));
negate(lhs && rhs);

// clang-format off
// CHECK-LABEL: @operator_and
// CHECK: [[ARG0:%.*]]: i1, [[ARG1:%.*]]: i1
// CHECK: [[AND:%.*]] = and [[ARG0]], [[ARG1]]
// CHECK: [[TRUE:%.*]] = constant true
// CHECK: subi [[TRUE]], [[AND]] : i1
// clang-format on
f.print(llvm::outs());
f.erase();
}
@@ -632,6 +637,7 @@ TEST_FUNC(select_op_f32) {
std_select(ugt(B(i, j), B(i + one, j)), A(zero, zero), A(i, j));
});

// clang-format off
// CHECK-LABEL: @select_op
// CHECK: affine.for %{{.*}} = 0 to 1 {
// CHECK-NEXT: affine.for %{{.*}} = 0 to 1 {
@@ -886,22 +892,25 @@ TEST_FUNC(affine_if_op) {

// clang-format off
// CHECK-LABEL: func @linalg_generic_pointwise
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins({{.*}}memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs({{.*}}memref<?x?xf32>)
// CHECK: addf
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins({{.*}}memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs({{.*}}memref<?x?xf32>)
// CHECK: cmpf "ogt"
// CHECK: select
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}} : memref<?x?xf32>)
// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?xf32>)
// CHECK: tanh
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>
// clang-format on
TEST_FUNC(linalg_generic_pointwise_test) {
using namespace edsc;
@@ -929,14 +938,16 @@ TEST_FUNC(linalg_generic_pointwise_test) {

// clang-format off
// CHECK-LABEL: func @linalg_generic_matmul
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?xf32>)
/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32
// CHECK: }: memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>
// CHECK: }
// clang-format on
TEST_FUNC(linalg_generic_matmul_test) {
using namespace edsc;
@@ -958,16 +969,18 @@ TEST_FUNC(linalg_generic_matmul_test) {

// clang-format off
// CHECK-LABEL: func @linalg_generic_conv_nhwc
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2 * 3 + d4 * 5, d3 * 4 + d5 * 6, d6)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d4, d5, d6, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d2, d3, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?x?x?xf32>)
/// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32
// CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
// CHECK: }
// clang-format on
TEST_FUNC(linalg_generic_conv_nhwc) {
using namespace edsc;
@@ -992,16 +1005,18 @@ TEST_FUNC(linalg_generic_conv_nhwc) {

// clang-format off
// CHECK-LABEL: func @linalg_generic_dilated_conv_nhwc
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3 * 3 + d5 * 5, d4 * 4 + d6 * 6, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d5, d6, d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d3, d4, d1 + d2 * 7)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : memref<?x?x?x?xf32>, memref<?x?x?x?xf32>)
// CHECK-SAME: outs(%{{[a-z0-9]*}} : memref<?x?x?x?xf32>)
// CHECK: ^bb0(%[[a0:.*]]: f32, %[[a1:.*]]: f32, %[[a2:.*]]: f32):
// CHECK: %[[a3:.*]] = mulf %[[a0]], %[[a1]] : f32
// CHECK: %[[a4:.*]] = addf %[[a2]], %[[a3]] : f32
// CHECK: linalg.yield %[[a4]] : f32
// CHECK: }: memref<?x?x?x?xf32>, memref<?x?x?x?xf32>, memref<?x?x?x?xf32>
// CHECK: }
// clang-format on
TEST_FUNC(linalg_generic_dilated_conv_nhwc) {
using namespace edsc;
@@ -1053,38 +1068,43 @@ TEST_FUNC(linalg_metadata_ops) {

// clang-format off
// CHECK-LABEL: func @linalg_tensors
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
// CHECK: addf
// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: } -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, tensor<?x?xf32>)
// CHECK: cmpf "ogt"
// CHECK: select
// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
// CHECK: linalg.generic {args_in = 1 : i64, args_out = 1 : i64,
// CHECK: } -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}} : tensor<?x?xf32>)
// CHECK: tanh
// CHECK: }: tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: linalg.generic {args_in = 2 : i64, args_out = 1 : i64,
// CHECK: } -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]}
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
// CHECK: mulf
// CHECK: }: tensor<?x?xf32>, memref<?x?xf32> -> tensor<?x?xf32>
// CHECK: linalg.generic {args_in = 3 : i64, args_out = 1 : i64,
// CHECK: } -> tensor<?x?xf32>
// CHECK: linalg.generic {
// CHECK-SAME: indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
// CHECK-SAME: ins(%{{[a-z0-9]*}}, %{{[a-z0-9]*}} : tensor<?x?xf32>, memref<?x?xf32>)
// CHECK-SAME: init(%{{[a-z0-9]*}} : tensor<?x?xf32>)
// CHECK: mulf
// CHECK: addf
// CHECK: }: tensor<?x?xf32>, memref<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
// CHECK: } -> tensor<?x?xf32>
// clang-format on
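// For reference, a sketch of the IR the init-tensor path above is checking
// for: a matmul whose tensor accumulator flows in through init(...) and whose
// result type follows the region (#matmul_trait stands in for the printed
// indexing maps and iterator types; shapes and names are illustrative, not
// verbatim expected output):
//   %r = linalg.generic #matmul_trait
//          ins(%A, %B : tensor<?x?xf32>, memref<?x?xf32>)
//          init(%C : tensor<?x?xf32>) {
//   ^bb0(%a: f32, %b: f32, %c: f32):
//     %d = mulf %a, %b : f32
//     %e = addf %c, %d : f32
//     linalg.yield %e : f32
//   } -> tensor<?x?xf32>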
TEST_FUNC(linalg_tensors_test) {
using namespace edsc;
@@ -1103,10 +1123,15 @@ TEST_FUNC(linalg_tensors_test) {
AffineExpr i, j;
bindDims(&globalContext(), i, j);
StructuredIndexed SA(A), SB(B), SC(tensorType);
linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j}));
linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j}));
Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0);
Value added = linalg_generic_pointwise_add(SA({i, j}), SB({i, j}), SC({i, j}))
->getResult(0);
Value maxed = linalg_generic_pointwise_max(
SA({i, j}), StructuredIndexed(added)({i, j}), SC({i, j}))
->getResult(0);
Value tanhed = linalg_generic_pointwise_tanh(StructuredIndexed(maxed)({i, j}),
SC({i, j}))
->getResult(0);
Value o1 = linalg_generic_matmul(A, B, tanhed, tensorType)->getResult(0);
linalg_generic_matmul(A, B, o1, tensorType);

f.print(llvm::outs());
Expand Down Expand Up @@ -1135,19 +1160,19 @@ TEST_FUNC(vector_extractelement_op_i32) {
f.erase();
}

// clang-format off
// CHECK-LABEL: func @memref_vector_matmul_test(
// CHECK-SAME: %[[A:.*]]: memref<?x?xvector<4x16xf32>>,
// CHECK-SAME: %[[B:.*]]: memref<?x?xvector<16x8xf32>>,
// CHECK-SAME: %[[C:.*]]: memref<?x?xvector<4x8xf32>>)
// CHECK: linalg.generic {{.*}} %[[A]], %[[B]], %[[C]]
// CHECK: vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0,
// d2)>,
// CHECK: linalg.generic {{{.*}}}
// CHECK-SAME: ins(%[[A]], %[[B]] : memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>)
// CHECK-SAME: outs(%[[C]] : memref<?x?xvector<4x8xf32>>)
// CHECK: vector.contract{{.*}}[affine_map<(d0, d1, d2) -> (d0, d2)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d2, d1)>,
// CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1)>],
// CHECK-SAME: {{.*}}["parallel", "parallel", "reduction"]
// CHECK-SAME: vector<4x16xf32>, vector<16x8xf32> into vector<4x8xf32>
// CHECK: memref<?x?xvector<4x16xf32>>, memref<?x?xvector<16x8xf32>>,
// CHECK-SAME: memref<?x?xvector<4x8xf32>>
// clang-format on
TEST_FUNC(memref_vector_matmul_test) {
using namespace edsc;
using namespace edsc::ops;
@@ -19,15 +19,13 @@ func @void_function_signature_conversion(%arg0: tensor<4x8xf32>) {
func @complex_signature_conversion(%arg0: tensor<5xf32>, %arg1: memref<10xf32>, %arg2: i1, %arg3: f16) -> (i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16) {
%0 = alloc() : memref<15xf32>
%1 = linalg.generic {
args_in = 1 : i64,
args_out = 1 : i64,
indexing_maps = [#map0, #map0],
iterator_types = ["parallel"]
} %arg0 {
iterator_types = ["parallel"]}
ins(%arg0 : tensor<5xf32>) {
^bb0(%gen1_arg0: f32):
%tmp1 = exp %gen1_arg0 : f32
linalg.yield %tmp1 : f32
}: tensor<5xf32> -> tensor<5xf32>
} -> tensor<5xf32>
return %arg2, %1, %arg1, %0, %arg3 : i1, tensor<5xf32>, memref<10xf32>, memref<15xf32>, f16
}
// CHECK: (%[[ARG0:.*]]: memref<5xf32>, %[[ARG1:.*]]: memref<10xf32>, %[[ARG2:.*]]: i1, %[[ARG3:.*]]: f16)