diff --git a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td index 5e68f75ee08bf..6ef7c72d305ee 100644 --- a/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td +++ b/mlir/include/mlir/Dialect/Shard/IR/ShardOps.td @@ -530,11 +530,11 @@ def Shard_AllGatherOp : Shard_CollectiveCommunicationOpBase<"all_gather", [ ``` }]; let arguments = !con(commonArgs, (ins - AnyNon0RankedTensor:$input, + AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$input, IndexAttr:$gather_axis )); let results = (outs - AnyNon0RankedTensor:$result + AnyTypeOf<[AnyMemRef, AnyRankedTensor]>:$result ); let assemblyFormat = [{ $input `on` $grid (`grid_axes` `=` $grid_axes^)? `gather_axis` `=` $gather_axis diff --git a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h index 57d65e687ea35..1ddd1985389bc 100644 --- a/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Shard/Transforms/Transforms.h @@ -39,14 +39,14 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef axes, ImplicitLocOpBuilder &builder); // Get process linear index along the given grid axes. -TypedValue createProcessLinearIndex(StringRef grid, - ArrayRef gridAxes, - ImplicitLocOpBuilder &builder); +TypedValue +createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, + ArrayRef gridAxes = {}); // Get process linear index from a multi-index along the given grid axes . 
TypedValue -createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex, - ArrayRef gridAxes, - ImplicitLocOpBuilder &builder); +createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, + ValueRange processInGroupMultiIndex, + ArrayRef gridAxes = {}); } // namespace shard } // namespace mlir diff --git a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp index b0831dc05abb1..87ae28892fcf7 100644 --- a/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp +++ b/mlir/lib/Conversion/ShardToMPI/ShardToMPI.cpp @@ -22,6 +22,7 @@ #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/MPI/IR/MPI.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/Dialect/Shard/IR/ShardDialect.h" #include "mlir/Dialect/Shard/IR/ShardOps.h" @@ -507,103 +508,152 @@ static mpi::MPI_ReductionOpEnumAttr getMPIReductionOp(ReductionKindAttr kind) { } } -struct ConvertAllReduceOp : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; +template +struct CommOpPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - SymbolTableCollection symbolTableCollection; - auto grid = adaptor.getGrid(); - mlir::shard::GridOp gridOp = getGrid(op, symbolTableCollection); - if (!gridOp) - return op->emitError() << "No grid found for AllReduceOp"; - if (ShapedType::isDynamicShape(gridOp.getShape())) - return op->emitError() - << "Dynamic grid shape not supported in AllReduceOp"; - - ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); - Value input = adaptor.getInput(); - auto inputShape = cast(input.getType()).getShape(); + MemRefType getMemrefType(ShapedType tensorType) const { + return MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + } 
+ Value getAsMemref(Value input, ImplicitLocOpBuilder &iBuilder) const { + auto itype = input.getType(); // If the source is a memref, cast it to a tensor. - if (isa(input.getType())) { - auto memrefType = MemRefType::get( - inputShape, cast(input.getType()).getElementType()); + if (isa(itype)) { + auto memrefType = getMemrefType(cast(itype)); input = bufferization::ToBufferOp::create(iBuilder, memrefType, input); + } else { + assert(isa(itype) && + "expected input to be of MemRefType or TensorType"); } - MemRefType inType = cast(input.getType()); + return input; + } - // Get the actual shape to allocate the buffer. - SmallVector shape(inType.getRank()); - for (auto i = 0; i < inType.getRank(); ++i) { - auto s = inputShape[i]; - if (ShapedType::isDynamic(s)) - shape[i] = memref::DimOp::create(iBuilder, input, s).getResult(); - else - shape[i] = iBuilder.getIndexAttr(s); - } + FailureOr checkGrid(CommOp op, + SymbolTableCollection &symbolTableCollection, + bool allowDynamic = false) const { + GridOp gridOp = getGrid(op, symbolTableCollection); + if (!gridOp) + return op->emitError() << "Missing grid symbol."; + if (!allowDynamic && ShapedType::isDynamicShape(gridOp.getShape())) + return op->emitError() << "Dynamic grid shape not supported."; + return gridOp; + } - // Allocate buffer and copy input to buffer. - Value buffer = memref::AllocOp::create( - iBuilder, shape, cast(op.getType()).getElementType()); - linalg::CopyOp::create(iBuilder, input, buffer); + // Get an MPI_Comm_split for a given grid and axes. + // The color is the linear index of the process in the grid along the + // non-'grid-axes'. The key is the linear index of the process in the grid + // along the grid-axes. 
+ Value getComm(GridOp &gridOp, ::llvm::ArrayRef gridAxes, + ImplicitLocOpBuilder &iBuilder) const { + size_t gridDims = gridOp.getShape().size(); + auto commType = mpi::CommType::get(gridOp->getContext()); + Value commWorld = mpi::CommWorldOp::create(iBuilder, commType); - // Get an MPI_Comm_split for the AllReduce operation. - // The color is the linear index of the process in the grid along the - // non-reduced axes. The key is the linear index of the process in the grid - // along the reduced axes. - SmallVector indexResultTypes(gridOp.getShape().size(), - iBuilder.getIndexType()); - SmallVector myMultiIndex = - ProcessMultiIndexOp::create(iBuilder, indexResultTypes, grid) - .getResult(); - Value zero = arith::ConstantIndexOp::create(iBuilder, 0); - SmallVector multiKey(myMultiIndex.size(), zero); + if (gridAxes.empty() || gridAxes.size() >= gridDims) { + return commWorld; + } - auto redAxes = adaptor.getGridAxes(); - for (auto axis : redAxes) { - multiKey[axis] = myMultiIndex[axis]; - myMultiIndex[axis] = zero; + SmallVector otherAxes; + for (GridAxis i = 0; i < static_cast(gridDims); ++i) { + if (!llvm::is_contained(gridAxes, i)) + otherAxes.emplace_back(i); } + SmallVector indexResultTypes(otherAxes.size(), + iBuilder.getIndexType()); + Value color = - createProcessLinearIndex(grid, myMultiIndex, redAxes, iBuilder); + createProcessLinearIndex(iBuilder, gridOp.getSymName(), otherAxes); color = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), color); - Value key = createProcessLinearIndex(grid, multiKey, redAxes, iBuilder); + + Value key = + createProcessLinearIndex(iBuilder, gridOp.getSymName(), gridAxes); key = arith::IndexCastOp::create(iBuilder, iBuilder.getI32Type(), key); // Finally split the communicator - auto commType = mpi::CommType::get(op->getContext()); - Value commWorld = mpi::CommWorldOp::create(iBuilder, commType); - auto comm = - mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key) - .getNewcomm(); - - Value buffer1d 
= buffer; - // Collapse shape to 1d if needed - if (inType.getRank() > 1) { - ReassociationIndices reassociation(inType.getRank()); - std::iota(reassociation.begin(), reassociation.end(), 0); - buffer1d = memref::CollapseShapeOp::create( - iBuilder, buffer, ArrayRef(reassociation)); - } + return mpi::CommSplitOp::create(iBuilder, commType, commWorld, color, key) + .getNewcomm(); + } +}; +struct ConvertAllReduceOp : public CommOpPattern { + using CommOpPattern::CommOpPattern; + + LogicalResult + matchAndRewrite(AllReduceOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SymbolTableCollection symbolTableCollection; + FailureOr gridOp = checkGrid(op, symbolTableCollection); + if (failed(gridOp)) + return failure(); + ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); + Value input = getAsMemref(adaptor.getInput(), iBuilder); + MemRefType inType = cast(input.getType()); + if (!memref::isStaticShapeAndContiguousRowMajor(inType)) + return op.emitError( + "Expected static shaped memref in contiguous row-major layout."); + MemRefType outType = getMemrefType(cast(op.getType())); + if (!memref::isStaticShapeAndContiguousRowMajor(outType)) + return op.emitError( + "Expected static shaped memref in contiguous row-major layout."); + + // Allocate buffer and copy input to buffer. + Value buffer = memref::AllocOp::create(iBuilder, outType); + linalg::CopyOp::create(iBuilder, input, buffer); + // Get the right communicator + Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder); // Create the MPI AllReduce operation. 
- mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer1d, buffer1d, + mpi::AllReduceOp::create(iBuilder, TypeRange(), buffer, buffer, getMPIReductionOp(adaptor.getReductionAttr()), comm); - // If the destination is a memref, cast it to a tensor + // If the destination is a tensor, cast the buffer to a tensor if (isa(op.getType())) buffer = bufferization::ToTensorOp::create(iBuilder, op.getType(), buffer, true); - rewriter.replaceOp(op, buffer); return success(); } }; +struct ConvertAllGatherOp : public CommOpPattern { + using CommOpPattern::CommOpPattern; + + LogicalResult + matchAndRewrite(AllGatherOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + SymbolTableCollection symbolTableCollection; + FailureOr gridOp = checkGrid(op, symbolTableCollection); + if (failed(gridOp)) + return failure(); + ImplicitLocOpBuilder iBuilder(op.getLoc(), rewriter); + Value input = getAsMemref(adaptor.getInput(), iBuilder); + MemRefType inType = cast(input.getType()); + if (!memref::isStaticShapeAndContiguousRowMajor(inType)) + return op.emitError( + "Expected static shaped memref in contiguous row-major layout."); + MemRefType outType = getMemrefType(cast(op.getType())); + if (!memref::isStaticShapeAndContiguousRowMajor(outType)) + return op.emitError( + "Expected static shaped memref in contiguous row-major layout."); + + // Get the right communicator + Value comm = getComm(*gridOp, adaptor.getGridAxes(), iBuilder); + // Allocate output buffer + Value output = memref::AllocOp::create(iBuilder, outType); + // Create the MPI AllGather operation. 
+ mpi::AllGatherOp::create(iBuilder, TypeRange(), input, output, comm); + + // If the destination is a tensor, cast the output buffer to a tensor + if (isa(op.getType())) + output = bufferization::ToTensorOp::create(iBuilder, op.getType(), output, + true); + rewriter.replaceOp(op, output); + return success(); + } +}; + struct ConvertUpdateHaloOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; @@ -895,8 +945,8 @@ struct ConvertShardToMPIPass patterns.add(typeConverter, - ctxt); + ConvertAllGatherOp, ConvertAllReduceOp, + ConvertProcessLinearIndexOp>(typeConverter, ctxt); SymbolTableCollection stc; populateProcessMultiIndexOpLoweringPatterns(patterns, stc); populateAllSliceOpLoweringPatterns(patterns, stc); diff --git a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp index 0ae2a9cc0318c..d0165595f9fb6 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/ShardingInterfaceImpl.cpp @@ -128,7 +128,7 @@ static Value createDestinationPassingStyleInitOperand( ArrayRef reductionGridAxes, GridOp gridOp, ImplicitLocOpBuilder &builder) { Value processLinearIndexInReductionGroup = shard::createProcessLinearIndex( - gridOp.getSymName(), reductionGridAxes, builder); + builder, gridOp.getSymName(), reductionGridAxes); Value zero = arith::ConstantIndexOp::create(builder, 0); Value isLeadProcess = arith::CmpIOp::create( builder, builder.getI1Type(), arith::CmpIPredicate::eq, diff --git a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp index b433b8b0be7b2..835bc443d4b2a 100644 --- a/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Shard/Transforms/Transforms.cpp @@ -208,9 +208,9 @@ createCollectiveProcessGroupSize(GridOp grid, ArrayRef axes, } TypedValue -createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex, - ArrayRef gridAxes, - 
ImplicitLocOpBuilder &builder) { +createProcessLinearIndex(ImplicitLocOpBuilder &builder, StringRef grid, + ValueRange processInGroupMultiIndex, + ArrayRef gridAxes) { Operation::result_range processGroupShape = GridShapeOp::create(builder, grid, gridAxes).getResult(); OpFoldResult processInGroupLinearIndex = affine::linearizeIndex( @@ -224,11 +224,12 @@ createProcessLinearIndex(StringRef grid, ValueRange processInGroupMultiIndex, return cast>(res); } -TypedValue createProcessLinearIndex(StringRef grid, - ArrayRef gridAxes, - ImplicitLocOpBuilder &builder) { +TypedValue createProcessLinearIndex(ImplicitLocOpBuilder &builder, + StringRef grid, + ArrayRef gridAxes) { return createProcessLinearIndex( - grid, ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(), - gridAxes, builder); + builder, grid, + ProcessMultiIndexOp::create(builder, grid, gridAxes).getResults(), + gridAxes); } } // namespace mlir::shard diff --git a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir index a0b6bfaf6fd3d..9a8ad5eea1c7b 100644 --- a/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir +++ b/mlir/test/Conversion/ShardToMPI/convert-shard-to-mpi.mlir @@ -102,15 +102,14 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { func.func @allreduce_tensor( // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> %arg0 : tensor<3x4xf32>) -> tensor<3x4xf32> { - // CHECK-DAG: [[vc4_i32:%.*]] = arith.constant 4 : i32 + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32> // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32> // CHECK: linalg.copy ins([[v0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>) // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm - // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm 
- // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> - // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm + // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32> // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x4xf32> to tensor<3x4xf32> %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : tensor<3x4xf32> -> tensor<3x4xf32> // CHECK: return [[v2]] : tensor<3x4xf32> @@ -121,14 +120,13 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { func.func @allreduce_memref( // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32> %arg0 : memref<3x4xf32>) -> memref<3x4xf32> { - // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32 - // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf32> // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf32>) // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm - // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm - // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf32> into memref<12xf32> - // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf32>, memref<12xf32> + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm + // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf32>, memref<3x4xf32> %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf32> // CHECK: 
return [[valloc]] : memref<3x4xf32> return %0 : memref<3x4xf32> @@ -138,18 +136,51 @@ module attributes { mpi.dlti = #dlti.map<"MPI:comm_world_rank" = 7> } { func.func @allreduce_new_type( // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32> %arg0 : memref<3x4xf32>) -> memref<3x4xf64> { - // CHECK: [[vc4_i32:%.*]] = arith.constant 4 : i32 - // CHECK: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x4xf64> // CHECK: linalg.copy ins([[varg0]] : memref<3x4xf32>) outs([[valloc]] : memref<3x4xf64>) // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm - // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc4_i32]]) : !mpi.comm - // CHECK: [[vcollapse_shape:%.*]] = memref.collapse_shape [[valloc]] {{\[\[}}0, 1]] : memref<3x4xf64> into memref<12xf64> - // CHECK: mpi.allreduce([[vcollapse_shape]], [[vcollapse_shape]], MPI_MAX, [[vnewcomm]]) : memref<12xf64>, memref<12xf64> + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc2_i32]], [[vc1_i32]]) : !mpi.comm + // CHECK: mpi.allreduce([[valloc]], [[valloc]], MPI_MAX, [[vnewcomm]]) : memref<3x4xf64>, memref<3x4xf64> %0 = shard.all_reduce %arg0 on @grid0 grid_axes = [0, 1] reduction = max : memref<3x4xf32> -> memref<3x4xf64> // CHECK: return [[valloc]] : memref<3x4xf64> return %0 : memref<3x4xf64> } + + // CHECK-LABEL: func @allgather_tensor + func.func @allgather_tensor( + // CHECK-SAME: [[varg0:%.*]]: tensor<3x4xf32> + // CHECK-SAME: -> tensor<3x20xf32> + %arg0 : tensor<3x4xf32>) -> tensor<3x20xf32> { + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 + // CHECK: [[v0:%.*]] = bufferization.to_buffer [[varg0]] : tensor<3x4xf32> to memref<3x4xf32> + // CHECK: [[v1:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v1]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm + // 
CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32> + // CHECK: mpi.allgather([[v0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32> + // CHECK: [[v2:%.*]] = bufferization.to_tensor [[valloc]] restrict : memref<3x20xf32> to tensor<3x20xf32> + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : tensor<3x4xf32> -> tensor<3x20xf32> + // CHECK: return [[v2]] : tensor<3x20xf32> + return %0 : tensor<3x20xf32> + } + + // CHECK-LABEL: func @allgather_memref + func.func @allgather_memref( + // CHECK-SAME: [[varg0:%.*]]: memref<3x4xf32> + // CHECK-SAME: -> memref<3x20xf32> + %arg0 : memref<3x4xf32>) -> memref<3x20xf32> { + // CHECK-DAG: [[vc1_i32:%.*]] = arith.constant 1 : i32 + // CHECK-DAG: [[vc2_i32:%.*]] = arith.constant 2 : i32 + // CHECK: [[v0:%.*]] = mpi.comm_world : !mpi.comm + // CHECK: [[vnewcomm:%.*]] = mpi.comm_split([[v0]], [[vc1_i32]], [[vc2_i32]]) : !mpi.comm + // CHECK: [[valloc:%.*]] = memref.alloc() : memref<3x20xf32> + // CHECK: mpi.allgather([[varg0]], [[valloc]], [[vnewcomm]]) : memref<3x4xf32>, memref<3x20xf32> + %0 = shard.all_gather %arg0 on @grid0 grid_axes = [2] gather_axis = 1 : memref<3x4xf32> -> memref<3x20xf32> + // CHECK: return [[valloc]] : memref<3x20xf32> + return %0 : memref<3x20xf32> + } } // -----