-
Notifications
You must be signed in to change notification settings - Fork 12.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion #119675
base: main
Are you sure you want to change the base?
Conversation
|
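@llvm/pr-subscribers-mlir

Author: Pietro Ghiglio (PietroGhg)

Changes

This PR adds support for the `bf16` and `i1` data types when converting `gpu::shuffle` to the `LLVMSPV` dialect, by inserting a `bitcast` to/from `i16` (for `bf16`) and extending/truncating to/from `i8` (for `i1`).

Full diff: https://github.com/llvm/llvm-project/pull/119675.diff

2 Files Affected: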
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
.Default([](auto) { return std::nullopt; });
}
- static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
- StringRef baseName = getBaseName(op.getMode());
- std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+ static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+ Type type) {
+ StringRef baseName = getBaseName(mode);
+ std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
typeMangling.value());
}
+ static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+ return getFuncName(op.getMode(), op.getType(0));
+ }
+
/// Get the subgroup size from the target or return a default.
static std::optional<int> getSubgroupSize(Operation *op) {
auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+ Type type = op.getType(0);
+ return isa<BFloat16Type>(type) || type.isInteger(1);
+ }
+
+ static Type getBitCastOrExtTy(Type oldTy,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Type>(oldTy)
+ .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+ .Case<IntegerType>([&](auto intTy) -> Type {
+ if (intTy.getWidth() == 1)
+ return rewriter.getIntegerType(8);
+ return Type{};
+ })
+ .Default([](auto) { return Type{}; });
+ }
+
+ static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(oldVal.getType())
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
+ static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(newTy)
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+ std::optional<std::string> funcName;
+ Value inValue;
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastorext");
+ funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+ inValue = newVal;
+ } else {
+ funcName = getFuncName(op);
+ inValue = adaptor.getValue();
+ }
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastortrunc");
+ result = newVal;
+ }
+
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
- // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
+ // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+ // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
llvm.func @gpu_shuffles(%i8_val: i8,
%i16_val: i16,
%i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
%f16_val: f16,
%f32_val: f32,
%f64_val: f64,
+ %bf16_val: bf16,
+ %i1_val: i1,
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
%width = arith.constant 16 : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
// CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+ // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+ // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+ // CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+ // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
+ // CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1
+ // CHECK: llvm.mlir.constant(true) : i1
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+ %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+ %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
llvm.return
}
}
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
// Cannot convert due to value type not being supported by the conversion
gpu.module @not_supported_lowering {
- llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+ llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
- // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
- %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
llvm.return
}
}
|
|
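@llvm/pr-subscribers-mlir-gpu

Author: Pietro Ghiglio (PietroGhg)

Changes

This PR adds support for the `bf16` and `i1` data types when converting `gpu::shuffle` to the `LLVMSPV` dialect, by inserting a `bitcast` to/from `i16` (for `bf16`) and extending/truncating to/from `i8` (for `i1`).

Full diff: https://github.com/llvm/llvm-project/pull/119675.diff

2 Files Affected: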
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
.Default([](auto) { return std::nullopt; });
}
- static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
- StringRef baseName = getBaseName(op.getMode());
- std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+ static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+ Type type) {
+ StringRef baseName = getBaseName(mode);
+ std::optional<StringRef> typeMangling = getTypeMangling(type);
if (!typeMangling)
return std::nullopt;
return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
typeMangling.value());
}
+ static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+ return getFuncName(op.getMode(), op.getType(0));
+ }
+
/// Get the subgroup size from the target or return a default.
static std::optional<int> getSubgroupSize(Operation *op) {
auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
val == getSubgroupSize(op);
}
+ static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+ Type type = op.getType(0);
+ return isa<BFloat16Type>(type) || type.isInteger(1);
+ }
+
+ static Type getBitCastOrExtTy(Type oldTy,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Type>(oldTy)
+ .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+ .Case<IntegerType>([&](auto intTy) -> Type {
+ if (intTy.getWidth() == 1)
+ return rewriter.getIntegerType(8);
+ return Type{};
+ })
+ .Default([](auto) { return Type{}; });
+ }
+
+ static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(oldVal.getType())
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
+ static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+ ConversionPatternRewriter &rewriter) {
+ return TypeSwitch<Type, Value>(newTy)
+ .Case<BFloat16Type>([&](auto) {
+ return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+ })
+ .Case<IntegerType>([&](auto intTy) -> Value {
+ if (intTy.getWidth() == 1)
+ return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+ return Value{};
+ })
+ .Default([](auto) { return Value{}; });
+ }
+
LogicalResult
matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
return rewriter.notifyMatchFailure(
op, "shuffle width and subgroup size mismatch");
- std::optional<std::string> funcName = getFuncName(op);
+ Location loc = op->getLoc();
+ Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+ std::optional<std::string> funcName;
+ Value inValue;
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastorext");
+ funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+ inValue = newVal;
+ } else {
+ funcName = getFuncName(op);
+ inValue = adaptor.getValue();
+ }
if (!funcName)
return rewriter.notifyMatchFailure(op, "unsupported value type");
Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
assert(moduleOp && "Expecting module");
- Type valueType = adaptor.getValue().getType();
+ Type valueType = inValue.getType();
Type offsetType = adaptor.getOffset().getType();
Type resultType = valueType;
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
moduleOp, funcName.value(), {valueType, offsetType}, resultType,
/*isMemNone=*/false, /*isConvergent=*/true);
- Location loc = op->getLoc();
- std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+ std::array<Value, 2> args{inValue, adaptor.getOffset()};
Value result =
createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+ if (bitcastOrExtDestTy) {
+ Value newVal =
+ doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+ assert(newVal && "Unhandled op type in bitcastortrunc");
+ result = newVal;
+ }
+
Value trueVal =
rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
// CHECK-SAME: (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
// CHECK-SAME: %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
// CHECK-SAME: %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
- // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[OFFSET:.*]]: i32)
+ // CHECK-SAME: %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+ // CHECK-SAME: %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
llvm.func @gpu_shuffles(%i8_val: i8,
%i16_val: i16,
%i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
%f16_val: f16,
%f32_val: f32,
%f64_val: f64,
+ %bf16_val: bf16,
+ %i1_val: i1,
%offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
%width = arith.constant 16 : i32
// CHECK: llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
// CHECK: llvm.mlir.constant(true) : i1
// CHECK: llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
// CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+ // CHECK: %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+ // CHECK: llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+ // CHECK: llvm.mlir.constant(true) : i1
+ // CHECK: %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+ // CHECK: %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%18, %arg9)
+ // CHECK: llvm.trunc %[[I1_CALL:.*]] : i8 to i1
+ // CHECK: llvm.mlir.constant(true) : i1
%shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
%shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
%shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
%shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
%shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
%shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+ %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+ %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
llvm.return
}
}
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
// Cannot convert due to value type not being supported by the conversion
gpu.module @not_supported_lowering {
- llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+ llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
%width = arith.constant 32 : i32
- // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
- %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
llvm.return
}
}
|
This PR adds support for the `bf16` and `i1` data types when converting `gpu::shuffle` to the `LLVMSPV` dialect, by inserting a `bitcast` to/from `i16` (for `bf16`) and extending/truncating to/from `i8` (for `i1`).