Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][GPU] Support bf16 and i1 gpu::shuffles to LLVMSPIRV conversion #119675

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

PietroGhg
Copy link
Contributor

This PR adds support for the bf16 and i1 data types when converting gpu::shuffle to the LLVMSPV dialect, by inserting bitcasts to/from i16 (for bf16) and by zero-extending to / truncating from i8 (for i1).

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir

Author: Pietro Ghiglio (PietroGhg)

Changes

This PR adds support for the bf16 and i1 data types when converting gpu::shuffle to the LLVMSPV dialect, by inserting bitcasts to/from i16 (for bf16) and by zero-extending to / truncating from i8 (for i1).


Full diff: https://github.com/llvm/llvm-project/pull/119675.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (+76-7)
  • (modified) mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir (+15-4)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         .Default([](auto) { return std::nullopt; });
   }
 
-  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
-    StringRef baseName = getBaseName(op.getMode());
-    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+                                                Type type) {
+    StringRef baseName = getBaseName(mode);
+    std::optional<StringRef> typeMangling = getTypeMangling(type);
     if (!typeMangling)
       return std::nullopt;
     return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
                          typeMangling.value());
   }
 
+  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+    return getFuncName(op.getMode(), op.getType(0));
+  }
+
   /// Get the subgroup size from the target or return a default.
   static std::optional<int> getSubgroupSize(Operation *op) {
     auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+    Type type = op.getType(0);
+    return isa<BFloat16Type>(type) || type.isInteger(1);
+  }
+
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Type>(oldTy)
+        .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+        .Case<IntegerType>([&](auto intTy) -> Type {
+          if (intTy.getWidth() == 1)
+            return rewriter.getIntegerType(8);
+          return Type{};
+        })
+        .Default([](auto) { return Type{}; });
+  }
+
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(oldVal.getType())
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(newTy)
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
     if (!funcName)
       return rewriter.notifyMatchFailure(op, "unsupported value type");
 
     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     assert(moduleOp && "Expecting module");
-    Type valueType = adaptor.getValue().getType();
+    Type valueType = inValue.getType();
     Type offsetType = adaptor.getOffset().getType();
     Type resultType = valueType;
     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
         /*isMemNone=*/false, /*isConvergent=*/true);
 
-    Location loc = op->getLoc();
-    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    std::array<Value, 2> args{inValue, adaptor.getOffset()};
     Value result =
         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastortrunc");
+      result = newVal;
+    }
+
     Value trueVal =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
     rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
   // CHECK-SAME:              (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
   // CHECK-SAME:               %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
   // CHECK-SAME:               %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
-  // CHECK-SAME:               %[[F64_VAL:.*]]: f64,  %[[OFFSET:.*]]: i32)
+  // CHECK-SAME:               %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+  // CHECK-SAME:               %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
   llvm.func @gpu_shuffles(%i8_val: i8,
                           %i16_val: i16,
                           %i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
                           %f16_val: f16,
                           %f32_val: f32,
                           %f64_val: f64,
+                          %bf16_val: bf16,
+                          %i1_val: i1,
                           %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
     %width = arith.constant 16 : i32
     // CHECK:         llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
     // CHECK:         llvm.mlir.constant(true) : i1
     // CHECK:         llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
     // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+    // CHECK:         %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+    // CHECK:         llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+    // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+    // CHECK:         %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%[[I1_ZEXT]], %[[OFFSET]])
+    // CHECK:         llvm.trunc %[[I1_CALL]] : i8 to i1
+    // CHECK:         llvm.mlir.constant(true) : i1
     %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
     %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
     %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
     %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
     %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
     %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+    %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+    %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
     llvm.return
   }
 }
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
 // Cannot convert due to value type not being supported by the conversion
 
 gpu.module @not_supported_lowering {
-  llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+  llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
     %width = arith.constant 32 : i32
-    // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
-    %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
     llvm.return
   }
 }

@llvmbot
Copy link
Member

llvmbot commented Dec 12, 2024

@llvm/pr-subscribers-mlir-gpu

Author: Pietro Ghiglio (PietroGhg)

Changes

This PR adds support for the bf16 and i1 data types when converting gpu::shuffle to the LLVMSPV dialect, by inserting bitcasts to/from i16 (for bf16) and by zero-extending to / truncating from i8 (for i1).


Full diff: https://github.com/llvm/llvm-project/pull/119675.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp (+76-7)
  • (modified) mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir (+15-4)
diff --git a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
index 03745f4537e99e..415e67aebab978 100644
--- a/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
+++ b/mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp
@@ -262,15 +262,20 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
         .Default([](auto) { return std::nullopt; });
   }
 
-  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
-    StringRef baseName = getBaseName(op.getMode());
-    std::optional<StringRef> typeMangling = getTypeMangling(op.getType(0));
+  static std::optional<std::string> getFuncName(gpu::ShuffleMode mode,
+                                                Type type) {
+    StringRef baseName = getBaseName(mode);
+    std::optional<StringRef> typeMangling = getTypeMangling(type);
     if (!typeMangling)
       return std::nullopt;
     return llvm::formatv("_Z{0}{1}{2}", baseName.size(), baseName,
                          typeMangling.value());
   }
 
+  static std::optional<std::string> getFuncName(gpu::ShuffleOp op) {
+    return getFuncName(op.getMode(), op.getType(0));
+  }
+
   /// Get the subgroup size from the target or return a default.
   static std::optional<int> getSubgroupSize(Operation *op) {
     auto parentFunc = op->getParentOfType<LLVM::LLVMFuncOp>();
@@ -286,6 +291,51 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
            val == getSubgroupSize(op);
   }
 
+  static bool needsBitCastOrExt(gpu::ShuffleOp op) {
+    Type type = op.getType(0);
+    return isa<BFloat16Type>(type) || type.isInteger(1);
+  }
+
+  static Type getBitCastOrExtTy(Type oldTy,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Type>(oldTy)
+        .Case<BFloat16Type>([&](auto) { return rewriter.getIntegerType(16); })
+        .Case<IntegerType>([&](auto intTy) -> Type {
+          if (intTy.getWidth() == 1)
+            return rewriter.getIntegerType(8);
+          return Type{};
+        })
+        .Default([](auto) { return Type{}; });
+  }
+
+  static Value doBitcastOrExt(Value oldVal, Type newTy, Location loc,
+                              ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(oldVal.getType())
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::ZExtOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
+  static Value doBitcastOrTrunc(Value oldVal, Type newTy, Location loc,
+                                ConversionPatternRewriter &rewriter) {
+    return TypeSwitch<Type, Value>(newTy)
+        .Case<BFloat16Type>([&](auto) {
+          return rewriter.create<LLVM::BitcastOp>(loc, newTy, oldVal);
+        })
+        .Case<IntegerType>([&](auto intTy) -> Value {
+          if (intTy.getWidth() == 1)
+            return rewriter.create<LLVM::TruncOp>(loc, newTy, oldVal);
+          return Value{};
+        })
+        .Default([](auto) { return Value{}; });
+  }
+
   LogicalResult
   matchAndRewrite(gpu::ShuffleOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const final {
@@ -293,23 +343,42 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
       return rewriter.notifyMatchFailure(
           op, "shuffle width and subgroup size mismatch");
 
-    std::optional<std::string> funcName = getFuncName(op);
+    Location loc = op->getLoc();
+    Type bitcastOrExtDestTy = getBitCastOrExtTy(op.getType(0), rewriter);
+    std::optional<std::string> funcName;
+    Value inValue;
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrExt(adaptor.getValue(), bitcastOrExtDestTy, loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastorext");
+      funcName = getFuncName(op.getMode(), bitcastOrExtDestTy);
+      inValue = newVal;
+    } else {
+      funcName = getFuncName(op);
+      inValue = adaptor.getValue();
+    }
     if (!funcName)
       return rewriter.notifyMatchFailure(op, "unsupported value type");
 
     Operation *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
     assert(moduleOp && "Expecting module");
-    Type valueType = adaptor.getValue().getType();
+    Type valueType = inValue.getType();
     Type offsetType = adaptor.getOffset().getType();
     Type resultType = valueType;
     LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
         moduleOp, funcName.value(), {valueType, offsetType}, resultType,
         /*isMemNone=*/false, /*isConvergent=*/true);
 
-    Location loc = op->getLoc();
-    std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};
+    std::array<Value, 2> args{inValue, adaptor.getOffset()};
     Value result =
         createSPIRVBuiltinCall(loc, rewriter, func, args).getResult();
+    if (bitcastOrExtDestTy) {
+      Value newVal =
+          doBitcastOrTrunc(result, adaptor.getValue().getType(), loc, rewriter);
+      assert(newVal && "Unhandled op type in bitcastortrunc");
+      result = newVal;
+    }
+
     Value trueVal =
         rewriter.create<LLVM::ConstantOp>(loc, rewriter.getI1Type(), true);
     rewriter.replaceOp(op, {result, trueVal});
diff --git a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
index 16b692b9689398..6fab647cb35681 100644
--- a/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
+++ b/mlir/test/Conversion/GPUToLLVMSPV/gpu-to-llvm-spv.mlir
@@ -277,7 +277,8 @@ gpu.module @shuffles {
   // CHECK-SAME:              (%[[I8_VAL:.*]]: i8, %[[I16_VAL:.*]]: i16,
   // CHECK-SAME:               %[[I32_VAL:.*]]: i32, %[[I64_VAL:.*]]: i64,
   // CHECK-SAME:               %[[F16_VAL:.*]]: f16, %[[F32_VAL:.*]]: f32,
-  // CHECK-SAME:               %[[F64_VAL:.*]]: f64,  %[[OFFSET:.*]]: i32)
+  // CHECK-SAME:               %[[F64_VAL:.*]]: f64, %[[BF16_VAL:.*]]: bf16,
+  // CHECK-SAME:               %[[I1_VAL:.*]]: i1, %[[OFFSET:.*]]: i32)
   llvm.func @gpu_shuffles(%i8_val: i8,
                           %i16_val: i16,
                           %i32_val: i32,
@@ -285,6 +286,8 @@ gpu.module @shuffles {
                           %f16_val: f16,
                           %f32_val: f32,
                           %f64_val: f64,
+                          %bf16_val: bf16,
+                          %i1_val: i1,
                           %offset: i32) attributes {intel_reqd_sub_group_size = 16 : i32} {
     %width = arith.constant 16 : i32
     // CHECK:         llvm.call spir_funccc @_Z17sub_group_shufflecj(%[[I8_VAL]], %[[OFFSET]])
@@ -301,6 +304,14 @@ gpu.module @shuffles {
     // CHECK:         llvm.mlir.constant(true) : i1
     // CHECK:         llvm.call spir_funccc @_Z22sub_group_shuffle_downdj(%[[F64_VAL]], %[[OFFSET]])
     // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[BF16_INBC:.*]] = llvm.bitcast %[[BF16_VAL]] : bf16 to i16
+    // CHECK:         %[[BF16_CALL:.*]] = llvm.call spir_funccc @_Z22sub_group_shuffle_downsj(%[[BF16_INBC]], %[[OFFSET]])
+    // CHECK:         llvm.bitcast %[[BF16_CALL]] : i16 to bf16
+    // CHECK:         llvm.mlir.constant(true) : i1
+    // CHECK:         %[[I1_ZEXT:.*]] = llvm.zext %[[I1_VAL]] : i1 to i8
+    // CHECK:         %[[I1_CALL:.*]] = llvm.call spir_funccc @_Z21sub_group_shuffle_xorcj(%[[I1_ZEXT]], %[[OFFSET]])
+    // CHECK:         llvm.trunc %[[I1_CALL]] : i8 to i1
+    // CHECK:         llvm.mlir.constant(true) : i1
     %shuffleResult0, %valid0 = gpu.shuffle idx %i8_val, %offset, %width : i8
     %shuffleResult1, %valid1 = gpu.shuffle xor %i16_val, %offset, %width : i16
     %shuffleResult2, %valid2 = gpu.shuffle idx %i32_val, %offset, %width : i32
@@ -308,6 +319,8 @@ gpu.module @shuffles {
     %shuffleResult4, %valid4 = gpu.shuffle up %f16_val, %offset, %width : f16
     %shuffleResult5, %valid5 = gpu.shuffle up %f32_val, %offset, %width : f32
     %shuffleResult6, %valid6 = gpu.shuffle down %f64_val, %offset, %width : f64
+    %shuffleResult7, %valid7 = gpu.shuffle down %bf16_val, %offset, %width : bf16
+    %shuffleResult8, %valid8 = gpu.shuffle xor %i1_val, %offset, %width : i1
     llvm.return
   }
 }
@@ -342,10 +355,8 @@ gpu.module @shuffles_mismatch {
 // Cannot convert due to value type not being supported by the conversion
 
 gpu.module @not_supported_lowering {
-  llvm.func @gpu_shuffles(%val: i1, %id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
+  llvm.func @gpu_shuffles(%id: i32) attributes {intel_reqd_sub_group_size = 32 : i32} {
     %width = arith.constant 32 : i32
-    // expected-error@below {{failed to legalize operation 'gpu.shuffle' that was explicitly marked illegal}}
-    %shuffleResult, %valid = gpu.shuffle xor %val, %id, %width : i1
     llvm.return
   }
 }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants