-
Notifications
You must be signed in to change notification settings - Fork 11.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering #98074
Conversation
The existing `fp8` lowering from `arith` to `amdgpu` bails out on the multidimensional case. We can handle this by `vector.shape_cast` collapsing to the 1-D case on extraction and re-casting back to the desired output shape.
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-gpu Author: Rob Suderman (rsuderman) ChangesThe existing Full diff: https://github.com/llvm/llvm-project/pull/98074.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
index 3d3ff001c541b5..7524db4da59178 100644
--- a/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
+++ b/mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -68,9 +68,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
if (auto inVecType = dyn_cast<VectorType>(inType)) {
if (inVecType.isScalable())
return failure();
- if (inVecType.getShape().size() > 1)
- // Multi-dimensional vectors are currently unsupported.
- return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
@@ -81,28 +78,37 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
- if (!isa<VectorType>(in.getType())) {
+ auto inType = dyn_cast<VectorType>(in.getType());
+ if (!inType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
- VectorType inType = cast<VectorType>(in.getType());
int64_t numElements = inType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result =
- rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
- ArrayRef<int64_t>{});
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
+
+ VectorType flatTy =
+ VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+ if (inType.getShape().size() > 1) {
+ inType = VectorType::get(SmallVector<int64_t>{numElements},
+ inType.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
+ }
+
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
@@ -114,6 +120,12 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
}
}
+
+ VectorType outType = cast<VectorType>(op.getOut().getType());
+ if (inType.getShape().size() != outType.getShape().size()) {
+ result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ }
+
rewriter.replaceOp(op, result);
}
@@ -182,9 +194,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (auto outVecType = dyn_cast<VectorType>(outType)) {
if (outVecType.isScalable())
return failure();
- if (outVecType.getShape().size() > 1)
- // Multi-dimensional vectors are currently unsupported.
- return failure();
outType = outVecType.getElementType();
}
auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
@@ -201,8 +210,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
+ auto inVectorTy = dyn_cast<VectorType>(in.getType());
VectorType truncResType = VectorType::get(4, outElemType);
- if (!isa<VectorType>(in.getType())) {
+ if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
@@ -214,18 +224,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
int64_t numElements = outType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
- Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
- result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
- ArrayRef<int64_t>{});
+ Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
+ ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}
+ VectorType flatTy =
+ VectorType::get(SmallVector<int64_t>{numElements}, outElemType);
+ Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);
+
+ if (inVectorTy.getShape().size() > 1) {
+ inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
+ inVectorTy.getElementType());
+ in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
+ }
+
for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
@@ -246,6 +265,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result, i, 1);
}
+
+ if (inVectorTy.getShape().size() != outType.getShape().size()) {
+ result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
+ }
+
rewriter.replaceOp(op, result);
}
diff --git a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
index 159a2f02f0560e..32c069180ece58 100644
--- a/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
+++ b/mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
@@ -115,3 +115,61 @@ func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
return %w : vector<9xf8E4M3FNUZ>
}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_trunc_long
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
+// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
+// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
+// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}
+
+// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
+// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
+// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}
+
+// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
+// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
+// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
+// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
+// CHECK: return [[RE]]
+func.func @vector_trunc_long(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
+ %w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
+ return %w : vector<1x9xf8E4M3FNUZ>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_ext_long
+// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
+// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
+// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
+// CHECK: [[W0:%.+]] = vector.insert [[F0]]
+// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
+// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
+// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
+// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
+// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
+// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]
+
+// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
+// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
+// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
+// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
+// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
+// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
+// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
+// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
+// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]
+
+// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
+// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
+// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
+// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
+// CHECK: return [[CAST]]
+func.func @vector_ext_long(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
+ %w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
+ return %w : vector<1x9xf32>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall seems reasonable, just some nits
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero); | ||
|
||
if (inType.getShape().size() > 1) { | ||
inType = VectorType::get(SmallVector<int64_t>{numElements}, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to create the SmallVector<> here? This should take ArrayRef
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You cannot use ArrayRef
as this would wrap list initializer memory, which is not safe post invoking the ArrayRef
constructor. It ends up being safe 99% of the time but this highly depends on platform and exact case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Extort that this is a one element "array" so the const& T
constructor kicks in, which means that you'd just be passing in a pointer to numElements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Approved, thanks
vector cannot be cloned legally due to scalable vector types
…ng (llvm#98074) The existing `fp8` lowering from `arith` to `amdgpu` bails out on the multidimensional case. We can handle this by `vector.shape_cast` collapsing to the 1-D case on extraction and re-casting back to the desired output shape.
…3d2db3b39 Local branch amd-gfx 0bb3d2d Merged main:63a9662f6cc41658f36ccddfb565d3dc5ad0135f into amd-gfx:b4e97c5e1144 Remote branch main f35318e [mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 lowering (llvm#98074)
The existing
fp8
lowering fromarith
toamdgpu
bails out on the multidimensional case. We can handle this byvector.shape_cast
collapsing to the 1-D case on extraction and re-casting back to the desired output shape.