Skip to content

Commit

Permalink
[mlir][amdgpu] Add support for multi-dim arith.truncf/extf fp8 loweri…
Browse files Browse the repository at this point in the history
…ng (#98074)

The existing `fp8` lowering from `arith` to `amdgpu` bails out on the
multidimensional case. We can handle this by `vector.shape_cast`
collapsing to the 1-D case on extraction and re-casting back to the
desired output shape.
  • Loading branch information
rsuderman committed Jul 9, 2024
1 parent 9739df2 commit f35318e
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 16 deletions.
56 changes: 40 additions & 16 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@ LogicalResult ExtFOnFloat8RewritePattern::match(arith::ExtFOp op) const {
if (auto inVecType = dyn_cast<VectorType>(inType)) {
if (inVecType.isScalable())
return failure();
if (inVecType.getShape().size() > 1)
// Multi-dimensional vectors are currently unsupported.
return failure();
inType = inVecType.getElementType();
}
return success(inType.isFloat8E5M2FNUZ() || inType.isFloat8E4M3FNUZ());
Expand All @@ -80,28 +77,38 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
Location loc = op.getLoc();
Value in = op.getIn();
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (!isa<VectorType>(in.getType())) {
auto inType = dyn_cast<VectorType>(in.getType());
if (!inType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
return rewriter.replaceOp(op, result);
}
VectorType inType = cast<VectorType>(in.getType());
int64_t numElements = inType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result =
rewriter.createOrFold<vector::SplatOp>(loc, op.getOut().getType(), zero);
if (inType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
ArrayRef<int64_t>{});
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zero,
ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}

VectorType outType = cast<VectorType>(op.getOut().getType());
VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

if (inType.getRank() > 1) {
inType = VectorType::get(SmallVector<int64_t>{numElements},
inType.getElementType());
in = rewriter.create<vector::ShapeCastOp>(loc, inType, in);
}

for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
Expand All @@ -113,6 +120,11 @@ void ExtFOnFloat8RewritePattern::rewrite(arith::ExtFOp op,
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
}
}

if (inType.getRank() != outType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
}

rewriter.replaceOp(op, result);
}

Expand Down Expand Up @@ -181,9 +193,6 @@ LogicalResult TruncFToFloat8RewritePattern::match(arith::TruncFOp op) const {
if (auto outVecType = dyn_cast<VectorType>(outType)) {
if (outVecType.isScalable())
return failure();
if (outVecType.getShape().size() > 1)
// Multi-dimensional vectors are currently unsupported.
return failure();
outType = outVecType.getElementType();
}
auto inType = dyn_cast<FloatType>(getElementTypeOrSelf(op.getIn().getType()));
Expand All @@ -200,8 +209,9 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
if (saturateFP8)
in = clampInput(rewriter, loc, outElemType, in);
auto inVectorTy = dyn_cast<VectorType>(in.getType());
VectorType truncResType = VectorType::get(4, outElemType);
if (!isa<VectorType>(in.getType())) {
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
Expand All @@ -213,18 +223,27 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
int64_t numElements = outType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
if (outType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
ArrayRef<int64_t>{});
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
ArrayRef<int64_t>{});
return rewriter.replaceOp(op, result);
}

VectorType flatTy = VectorType::get(SmallVector<int64_t>{numElements},
outType.getElementType());
Value result = rewriter.createOrFold<vector::SplatOp>(loc, flatTy, zero);

if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
}

for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
Expand All @@ -245,6 +264,11 @@ void TruncFToFloat8RewritePattern::rewrite(arith::TruncFOp op,
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result, i, 1);
}

if (inVectorTy.getRank() != outType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
}

rewriter.replaceOp(op, result);
}

Expand Down
58 changes: 58 additions & 0 deletions mlir/test/Conversion/ArithToAMDGPU/8-bit-floats.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,61 @@ func.func @vector_trunc_long(%v: vector<9xf32>) -> vector<9xf8E4M3FNUZ> {
%w = arith.truncf %v : vector<9xf32> to vector<9xf8E4M3FNUZ>
return %w : vector<9xf8E4M3FNUZ>
}

// -----

// CHECK-LABEL: func.func @vector_trunc_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf32>)
// CHECK: [[ZEROES:%.+]] = arith.constant dense<0.000000e+00> : vector<9xf8E4M3FNUZ>
// CHECK: [[T0:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T1:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T0]][word 1]
// CHECK: [[W0:%.+]] = vector.insert_strided_slice [[T1]], [[ZEROES]] {offsets = [0], strides = [1]}

// CHECK: [[T2:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into undef[word 0]
// CHECK: [[T3:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, %{{.+}} into [[T2]][word 1]
// CHECK: [[W1:%.+]] = vector.insert_strided_slice [[T3]], [[W0]] {offsets = [4], strides = [1]}

// CHECK: [[T4:%.+]] = amdgpu.packed_trunc_2xfp8 %{{.+}}, undef into undef[word 0]
// CHECK: [[T4_SHORT:%.+]] = vector.extract_strided_slice [[T4]] {offsets = [0], sizes = [1], strides = [1]}
// CHECK: [[W:%.+]] = vector.insert_strided_slice [[T4_SHORT]], [[W1]] {offsets = [8], strides = [1]}
// CHECK: [[RE:%.+]] = vector.shape_cast [[W]] : vector<9xf8E4M3FNUZ> to vector<1x9xf8E4M3FNUZ>
// CHECK: return [[RE]]
func.func @vector_trunc_long_2d(%v: vector<1x9xf32>) -> vector<1x9xf8E4M3FNUZ> {
%w = arith.truncf %v : vector<1x9xf32> to vector<1x9xf8E4M3FNUZ>
return %w : vector<1x9xf8E4M3FNUZ>
}

// -----

// CHECK-LABEL: func.func @vector_ext_long_2d
// CHECK-SAME: ([[V:%.+]]: vector<1x9xf8E4M3FNUZ>)
// CHECK: [[CAST:%.+]] = vector.shape_cast [[V]] : vector<1x9xf8E4M3FNUZ> to vector<9xf8E4M3FNUZ>
// CHECK: [[V0:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [0], sizes = [4], strides = [1]}
// CHECK: [[F0:%.+]] = amdgpu.ext_packed_fp8 [[V0]][0]
// CHECK: [[W0:%.+]] = vector.insert [[F0]]
// CHECK: [[F1:%.+]] = amdgpu.ext_packed_fp8 [[V0]][1]
// CHECK: [[W1:%.+]] = vector.insert [[F1]], [[W0]]
// CHECK: [[F2:%.+]] = amdgpu.ext_packed_fp8 [[V0]][2]
// CHECK: [[W2:%.+]] = vector.insert [[F2]], [[W1]]
// CHECK: [[F3:%.+]] = amdgpu.ext_packed_fp8 [[V0]][3]
// CHECK: [[W3:%.+]] = vector.insert [[F3]], [[W2]]

// CHECK: [[V1:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [4], sizes = [4], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<4xf8E4M3FNUZ>
// CHECK: [[F4:%.+]] = amdgpu.ext_packed_fp8 [[V1]][0]
// CHECK: [[W4:%.+]] = vector.insert [[F4]], [[W3]]
// CHECK: [[F5:%.+]] = amdgpu.ext_packed_fp8 [[V1]][1]
// CHECK: [[W5:%.+]] = vector.insert [[F5]], [[W4]]
// CHECK: [[F6:%.+]] = amdgpu.ext_packed_fp8 [[V1]][2]
// CHECK: [[W6:%.+]] = vector.insert [[F6]], [[W5]]
// CHECK: [[F7:%.+]] = amdgpu.ext_packed_fp8 [[V1]][3]
// CHECK: [[W7:%.+]] = vector.insert [[F7]], [[W6]]

// CHECK: [[V2:%.+]] = vector.extract_strided_slice [[CAST]] {offsets = [8], sizes = [1], strides = [1]} : vector<9xf8E4M3FNUZ> to vector<1xf8E4M3FNUZ>
// CHECK: [[F8:%.+]] = amdgpu.ext_packed_fp8 [[V2]][0]
// CHECK: [[W8:%.+]] = vector.insert [[F8]], [[W7]]
// CHECK: [[CAST:%.+]] = vector.shape_cast [[W8]] : vector<9xf32> to vector<1x9xf32>
// CHECK: return [[CAST]]
func.func @vector_ext_long_2d(%v: vector<1x9xf8E4M3FNUZ>) -> vector<1x9xf32> {
%w = arith.extf %v : vector<1x9xf8E4M3FNUZ> to vector<1x9xf32>
return %w : vector<1x9xf32>
}

0 comments on commit f35318e

Please sign in to comment.