diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 43b02f16aa829..c0f9132de3db4 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -535,17 +535,26 @@ parseAttributions(OpAsmParser &parser, StringRef keyword, /*allowType=*/true); } -/// Prints a GPU function memory attribution. static void printAttributions(OpAsmPrinter &p, StringRef keyword, - ArrayRef values) { + ArrayRef values, + ArrayAttr attributes = {}) { if (values.empty()) return; - auto printBlockArg = [](BlockArgument v) { - return llvm::formatv("{} : {}", v, v.getType()); - }; - p << ' ' << keyword << '(' - << llvm::interleaved(llvm::map_range(values, printBlockArg)) << ')'; + p << ' ' << keyword << '('; + llvm::interleaveComma( + llvm::enumerate(values), p, [&p, attributes](auto pair) { + BlockArgument v = pair.value(); + p << v << " : " << v.getType(); + + size_t attributionIndex = pair.index(); + DictionaryAttr attrs; + if (attributes && attributionIndex < attributes.size()) + attrs = llvm::cast(attributes[attributionIndex]); + if (attrs) + p.printOptionalAttrDict(attrs.getValue()); + }); + p << ')'; } /// Verifies a GPU function memory attribution. @@ -1649,28 +1658,6 @@ ParseResult GPUFuncOp::parse(OpAsmParser &parser, OperationState &result) { return parser.parseRegion(*body, entryArgs); } -static void printAttributions(OpAsmPrinter &p, StringRef keyword, - ArrayRef values, - ArrayAttr attributes) { - if (values.empty()) - return; - - p << ' ' << keyword << '('; - llvm::interleaveComma( - llvm::enumerate(values), p, [&p, attributes](auto pair) { - BlockArgument v = pair.value(); - p << v << " : " << v.getType(); - - size_t attributionIndex = pair.index(); - DictionaryAttr attrs; - if (attributes && attributionIndex < attributes.size()) - attrs = llvm::cast(attributes[attributionIndex]); - if (attrs) - p.printOptionalAttrDict(attrs.getValue()); - }); - p << ')'; -} - void GPUFuncOp::print(OpAsmPrinter &p) { p << ' '; p.printSymbolName(getName()); diff --git a/mlir/test/Dialect/GPU/ops.mlir b/mlir/test/Dialect/GPU/ops.mlir index e3e2474d917c8..7772e7a1681c4 100644 --- a/mlir/test/Dialect/GPU/ops.mlir +++ b/mlir/test/Dialect/GPU/ops.mlir @@ -68,6 +68,31 @@ module attributes {gpu.container_module} { return } + // CHECK-LABEL: func @launch_with_attributions( + func.func @launch_with_attributions(%blk : index, %thrd : index, %float : f32, %data : memref) { + // CHECK: gpu.launch + gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %blk, %grid_y = %blk, %grid_z = %blk) + threads(%tx, %ty, %tz) in (%block_x = %thrd, %block_y = %thrd, %block_z = %thrd) + // CHECK-SAME: workgroup(%[[WGROUP1:.*]] : memref<42xf32, 3>, %[[WGROUP2:.*]] : memref<2xf32, 3>) + workgroup(%arg1: memref<42xf32, 3>, %arg2: memref<2xf32, 3>) + // CHECK-SAME: private(%[[PRIVATE1:.*]] : memref<2xf32, 5>, %[[PRIVATE2:.*]] : memref<1xf32, 5>) + private(%arg3: memref<2xf32, 5>, %arg4: memref<1xf32, 5>) + { + "use"(%float) : (f32) -> () + "use"(%data) : (memref) -> () + // CHECK: "use"(%[[WGROUP1]], %[[WGROUP2]]) + "use"(%arg1, %arg2) : (memref<42xf32, 3>, memref<2xf32, 3>) -> () + // CHECK: "use"(%[[PRIVATE1]]) + "use"(%arg3) : (memref<2xf32, 5>) -> () + // CHECK: "use"(%[[PRIVATE2]]) + "use"(%arg4) : (memref<1xf32, 5>) -> () + // CHECK: gpu.terminator + gpu.terminator + } + return + } + + gpu.module @kernels { gpu.func @kernel_1(%arg0 : f32, %arg1 : memref) kernel { %tIdX = gpu.thread_id x @@ -228,17 +253,20 @@ module attributes {gpu.container_module} { gpu.module @gpu_funcs { // CHECK-LABEL: gpu.func @kernel_1({{.*}}: f32) - // CHECK: workgroup - // CHECK: private - // CHECK: attributes gpu.func @kernel_1(%arg0: f32) - workgroup(%arg1: memref<42xf32, 3>) - private(%arg2: memref<2xf32, 5>, %arg3: memref<1xf32, 5>) + // CHECK: workgroup(%[[WGROUP1:.*]] : memref<42xf32, 3>, %[[WGROUP2:.*]] : memref<2xf32, 3>) + workgroup(%arg1: memref<42xf32, 3>, %arg2: memref<2xf32, 3>) + // CHECK: private(%[[PRIVATE1:.*]] : memref<2xf32, 5>, %[[PRIVATE2:.*]] : memref<1xf32, 5>) + private(%arg3: memref<2xf32, 5>, %arg4: memref<1xf32, 5>) kernel - attributes {foo="bar"} { - "use"(%arg1) : (memref<42xf32, 3>) -> () - "use"(%arg2) : (memref<2xf32, 5>) -> () - "use"(%arg3) : (memref<1xf32, 5>) -> () + // CHECK: attributes {foo = "bar"} + attributes {foo = "bar"} { + // CHECK: "use"(%[[WGROUP1]], %[[WGROUP2]]) + "use"(%arg1, %arg2) : (memref<42xf32, 3>, memref<2xf32, 3>) -> () + // CHECK: "use"(%[[PRIVATE1]]) + "use"(%arg3) : (memref<2xf32, 5>) -> () + // CHECK: "use"(%[[PRIVATE2]]) + "use"(%arg4) : (memref<1xf32, 5>) -> () gpu.return }