Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering #98074

Merged
merged 3 commits into from
Jul 9, 2024

Conversation

rsuderman
Copy link
Contributor

The existing fp8 lowering from arith to amdgpu bails out on the multidimensional case. We can handle this by vector.shape_cast collapsing to the 1-D case on extraction and re-casting back to the desired output shape.

The existing `fp8` lowering from `arith` to `amdgpu` bails out on the
multidimensional case. We can handle this by `vector.shape_cast`
collapsing to the 1-D case on extraction and re-casting back to the
desired output shape.
@llvmbot
Copy link
Collaborator

llvmbot commented Jul 8, 2024

@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-backend-amdgpu

@llvm/pr-subscribers-mlir-gpu

Author: Rob Suderman (rsuderman)

Changes

The existing fp8 lowering from arith to amdgpu bails out on the multidimensional case. We can handle this by vector.shape_cast collapsing to the 1-D case on extraction and re-casting back to the desired output shape.


Full diff: https://github.com/llvm/llvm-project/pull/98074.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp (+40-16)
  • (modified) mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir (+58)
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3d3ff001c541b5..7524db4da59178 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -68,9 +68,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
   if (auto inVecType = dyn_cast<VectorType>(inType)) {
     if (inVecType.isScalable())
       return failure();
-    if (inVecType.getShape().size() > 1)
-      // Multi-dimensional vectors are currently unsupported.
-      return failure();
     inType = inVecType.getElementType();
   }
   return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
@@ -81,28 +78,37 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
   Location loc = op.getLoc();
   Value in = op.getIn();
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
-  if (!isa<VectorType>(in.getType())) {
+  auto inType = dyn_cast<VectorType>(in.getType());
+  if (!inType) {
     Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
         loc, rewriter.getF32Type(), in, 0);
     Value result = castF32To(outElemType, asFloat, loc, rewriter);
     return rewriter.replaceOp(op, result);
   }
-  VectorType inType = cast<VectorType>(in.getType());
   int64_t numElements = inType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result =
-      rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
   if (inType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarExt =
         rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
-                                               ArrayRef<int64_t>{});
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+                                                     ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
+
+  VectorType flatTy =
+      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+  if (inType.getShape().size() > 1) {
+    inType = VectorType::get(SmallVector<int64_t>{numElements},
+                             inType.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+  }
+
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -114,6 +120,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
       result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
     }
   }
+
+  VectorType outType = cast<VectorType>(op.getOut().getType());
+  if (inType.getShape().size() != outType.getShape().size()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  }
+
   rewriter.replaceOp(op, result);
 }
 
@@ -182,9 +194,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
   if (auto outVecType = dyn_cast<VectorType>(outType)) {
     if (outVecType.isScalable())
       return failure();
-    if (outVecType.getShape().size() > 1)
-      // Multi-dimensional vectors are currently unsupported.
-      return failure();
     outType = outVecType.getElementType();
   }
   auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
@@ -201,8 +210,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   Type outElemType = getElementTypeOrSelf(op.getOut().getType());
   if (saturateFP8)
     in = clampInput(rewriter, loc, outElemType, in);
+  auto inVectorTy = dyn_cast<VectorType>(in.getType());
   VectorType truncResType = VectorType::get(4, outElemType);
-  if (!isa<VectorType>(in.getType())) {
+  if (!inVectorTy) {
     Value asFloat = castToF32(in, loc, rewriter);
     Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
         loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@@ -214,18 +224,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
   int64_t numElements = outType.getNumElements();
   Value zero = rewriter.create<arith::ConstantOp>(
       loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
-  Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
   if (outType.getShape().empty()) {
     Value scalarIn =
         rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
     // Recurse to send the 0-D vector case to the 1-D vector case
     Value scalarTrunc =
         rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
-    result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
-                                               ArrayRef<int64_t>{});
+    Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+                                                     ArrayRef<int64_t>{});
     return rewriter.replaceOp(op, result);
   }
 
+  VectorType flatTy =
+      VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+  Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+  if (inVectorTy.getShape().size() > 1) {
+    inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
+                                 inVectorTy.getElementType());
+    in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+  }
+
   for (int64_t i = 0; i < numElements; i += 4) {
     int64_t elemsThisOp = std::min(numElements, i + 4) - i;
     Value thisResult = nullptr;
@@ -246,6 +265,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
     result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
                                                            result, i, 1);
   }
+
+  if (inVectorTy.getShape().size() != outType.getShape().size()) {
+    result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+  }
+
   rewriter.replaceOp(op, result);
 }
 
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 159a2f02f0560e..32c069180ece58 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -115,3 +115,61 @@ func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
   %w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
   return %w : vector<9xf8E4M3FNUZ>
 }
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
+// CHECK: return [[RE]]
+func.func @vector_trunc_long(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
+  %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
+  return %w : vector<1x9xf8E4M3FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[CAST]]
+func.func @vector_ext_long(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
+  %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
+  return %w : vector<1x9xf32>
+}

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall seems reasonable, just some nits

Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

if (inType.getShape().size() > 1) {
inType = VectorType::get(SmallVector<int64_t>{numElements},
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to create the SmallVector<> here? This should take ArrayRef?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You cannot use ArrayRef as this would wrap list initializer memory, which is not safe post invoking the ArrayRef constructor. It ends up being safe 99% of the time but this highly depends on platform and exact case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Extort that this is a one element "array" so the const& T constructor kicks in, which means that you'd just be passing in a pointer to numElements

https://llvm.org/doxygen/ArrayRef_8h_source.html#l00073

mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp Outdated Show resolved Hide resolved
mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp Outdated Show resolved Hide resolved
@rsuderman rsuderman requested a review from krzysz00 July 8, 2024 23:26
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approved, thanks

vector cannot be cloned legally due to scalable vector types
@rsuderman rsuderman merged commit f35318e into llvm:main Jul 9, 2024
7 checks passed
aaryanshukla pushed a commit to aaryanshukla/llvm-project that referenced this pull request Jul 14, 2024
…ng (llvm#98074)

The existing `fp8` lowering from `arith` to `amdgpu` bails out on the
multidimensional case. We can handle this by `vector.shape_cast`
collapsing to the 1-D case on extraction and re-casting back to the
desired output shape.
qiaojbao pushed a commit to GPUOpen-Drivers/llvm-project that referenced this pull request Aug 13, 2024
…3d2db3b39

Local branch amd-gfx 0bb3d2d Merged main:63a9662f6cc41658f36ccddfb565d3dc5ad0135f into amd-gfx:b4e97c5e1144
Remote branch main f35318e [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (llvm#98074)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants