Skip to content

Commit

Permalink
[MLIR][GPU] Add 16-bit version of cudaMemset in cudaRuntimeWrappers
Browse files Browse the repository at this point in the history
Add 16-bit version of cudaMemset in cudaRuntimeWrappers and update the GPU to LLVM lowering.

Reviewed By: bondhugula

Differential Revision: https://reviews.llvm.org/D151642
  • Loading branch information
navdeepkk-polymagelabs authored and bondhugula committed Jun 8, 2023
1 parent 6a3fe0f commit 18cc07a
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 8 deletions.
27 changes: 21 additions & 6 deletions mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp
Expand Up @@ -97,6 +97,7 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
Type llvmPointerPointerType =
this->getTypeConverter()->getPointerType(llvmPointerType);
Type llvmInt8Type = IntegerType::get(context, 8);
Type llvmInt16Type = IntegerType::get(context, 16);
Type llvmInt32Type = IntegerType::get(context, 32);
Type llvmInt64Type = IntegerType::get(context, 64);
Type llvmInt8PointerType =
Expand Down Expand Up @@ -186,7 +187,14 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
{llvmPointerType /* void *dst */, llvmPointerType /* void *src */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
FunctionCallBuilder memsetCallBuilder = {
FunctionCallBuilder memset16CallBuilder = {
"mgpuMemset16",
llvmVoidType,
{llvmPointerType /* void *dst */,
llvmInt16Type /* unsigned short value */,
llvmIntPtrType /* intptr_t sizeBytes */,
llvmPointerType /* void *stream */}};
FunctionCallBuilder memset32CallBuilder = {
"mgpuMemset32",
llvmVoidType,
{llvmPointerType /* void *dst */, llvmInt32Type /* unsigned int value */,
Expand Down Expand Up @@ -1365,22 +1373,29 @@ LogicalResult ConvertMemsetOpToGpuRuntimeCallPattern::matchAndRewrite(
auto loc = memsetOp.getLoc();

Type valueType = adaptor.getValue().getType();
if (!valueType.isIntOrFloat() || valueType.getIntOrFloatBitWidth() != 32) {
return rewriter.notifyMatchFailure(memsetOp,
"value must be a 32 bit scalar");
unsigned bitWidth = valueType.getIntOrFloatBitWidth();
// Ints and floats of 16 or 32 bit width are allowed.
if (!valueType.isIntOrFloat() || (bitWidth != 16 && bitWidth != 32)) {
return rewriter.notifyMatchFailure(
memsetOp, "value must be a 16 or 32 bit int or float");
}

unsigned valueTypeWidth = valueType.getIntOrFloatBitWidth();
Type bitCastType = valueTypeWidth == 32 ? llvmInt32Type : llvmInt16Type;

MemRefDescriptor dstDesc(adaptor.getDst());
Value numElements = getNumElements(rewriter, loc, memRefType, dstDesc);

auto value =
rewriter.create<LLVM::BitcastOp>(loc, llvmInt32Type, adaptor.getValue());
rewriter.create<LLVM::BitcastOp>(loc, bitCastType, adaptor.getValue());
auto dst = bitAndAddrspaceCast(loc, rewriter, llvmPointerType,
dstDesc.alignedPtr(rewriter, loc),
*getTypeConverter());

auto stream = adaptor.getAsyncDependencies().front();
memsetCallBuilder.create(loc, rewriter, {dst, value, numElements, stream});
FunctionCallBuilder builder =
valueTypeWidth == 32 ? memset32CallBuilder : memset16CallBuilder;
builder.create(loc, rewriter, {dst, value, numElements, stream});

rewriter.replaceOp(memsetOp, {stream});
return success();
Expand Down
6 changes: 6 additions & 0 deletions mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp
Expand Up @@ -176,6 +176,12 @@ extern "C" void mgpuMemset32(void *dst, unsigned int value, size_t count,
value, count, stream));
}

extern "C" void mgpuMemset16(void *dst, unsigned short value, size_t count,
CUstream stream) {
CUDA_REPORT_IF_ERROR(cuMemsetD16Async(reinterpret_cast<CUdeviceptr>(dst),
value, count, stream));
}

///
/// Helper functions for writing mlir example code
///
Expand Down
24 changes: 22 additions & 2 deletions mlir/test/Conversion/GPUCommon/typed-pointers.mlir
Expand Up @@ -42,8 +42,8 @@ module attributes {gpu.container_module} {

module attributes {gpu.container_module} {

// CHECK: func @foo
func.func @foo(%dst : memref<7xf32, 1>, %value : f32) {
// CHECK: func @memset_f32
func.func @memset_f32(%dst : memref<7xf32, 1>, %value : f32) {
// CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
%t0 = gpu.wait async
// CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
Expand All @@ -59,3 +59,23 @@ module attributes {gpu.container_module} {
}
}

// -----

module attributes {gpu.container_module} {

// CHECK: func @memset_f16
func.func @memset_f16(%dst : memref<7xf16, 1>, %value : f16) {
// CHECK: %[[t0:.*]] = llvm.call @mgpuStreamCreate
%t0 = gpu.wait async
// CHECK: %[[size_bytes:.*]] = llvm.mlir.constant
// CHECK: %[[value:.*]] = llvm.bitcast
// CHECK: %[[addr_cast:.*]] = llvm.addrspacecast
// CHECK: %[[dst:.*]] = llvm.bitcast %[[addr_cast]]
// CHECK: llvm.call @mgpuMemset16(%[[dst]], %[[value]], %[[size_bytes]], %[[t0]])
%t1 = gpu.memset async [%t0] %dst, %value : memref<7xf16, 1>, f16
// CHECK: llvm.call @mgpuStreamSynchronize(%[[t0]])
// CHECK: llvm.call @mgpuStreamDestroy(%[[t0]])
gpu.wait [%t1]
return
}
}

0 comments on commit 18cc07a

Please sign in to comment.