272 changes: 272 additions & 0 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
@@ -14,7 +14,10 @@
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@@ -32,6 +35,7 @@ using namespace mlir::amdgpu;
namespace {
// Define commonly used chipset versions for convenience.
constexpr Chipset kGfx942 = Chipset(9, 4, 2);
constexpr Chipset kGfx950 = Chipset(9, 5, 0);

struct ArithToAMDGPUConversionPass final
: impl::ArithToAMDGPUConversionPassBase<ArithToAMDGPUConversionPass> {
@@ -73,6 +77,28 @@ struct TruncfToFloat16RewritePattern final
PatternRewriter &rewriter) const override;
};

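// Lower arith.scaling_extf / arith.scaling_truncf to the packed AMDGPU scaled
// conversion ops (amdgpu.scaled_ext_packed / amdgpu.packed_scaled_trunc),
// which consume at most two input elements per op and share a single f32
// scale.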
struct ScalingExtFRewritePattern final
: OpRewritePattern<arith::ScalingExtFOp> {
using OpRewritePattern::OpRewritePattern;

ScalingExtFRewritePattern(MLIRContext *ctx)
: OpRewritePattern::OpRewritePattern(ctx) {}

LogicalResult matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const override;
};

struct ScalingTruncFRewritePattern final
: OpRewritePattern<arith::ScalingTruncFOp> {
using OpRewritePattern::OpRewritePattern;

ScalingTruncFRewritePattern(MLIRContext *ctx)
: OpRewritePattern::OpRewritePattern(ctx) {}

LogicalResult matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const override;
};

} // end namespace

static bool isSupportedF8(Type elementType, Chipset chipset) {
Expand Down Expand Up @@ -395,6 +421,247 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
return success();
}

/// Look through a chain of vector.shape_cast ops and return the value that was
/// broadcast or splatted to produce `value` (or `value` itself if no broadcast
/// or splat is found).
static Value getOriginalVectorValue(Value value) {
Value current = value;
while (Operation *definingOp = current.getDefiningOp()) {
bool skipOp = llvm::TypeSwitch<Operation *, bool>(definingOp)
.Case<vector::ShapeCastOp>([&current](auto op) {
current = op.getSource();
return true;
})
.Case<vector::BroadcastOp>([&current](auto op) {
current = op.getSource();
return false;
})
.Case<vector::SplatOp>([&current](auto op) {
current = op.getInput();
return false;
})
.Default([](Operation *) { return false; });

if (!skipOp) {
break;
}
}
return current;
}

LogicalResult
ScalingExtFRewritePattern::matchAndRewrite(arith::ScalingExtFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
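// The packed scaled-extend op produces two results per invocation, so the
// input is processed at most two elements at a time.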
constexpr int64_t opWidth = 2;

Value in = op.getIn();
Value scale = op.getScale();
Value out = op.getOut();

Type f32 = rewriter.getF32Type();
Type inType = getElementTypeOrSelf(in);
Type scaleType = getElementTypeOrSelf(scale);
Type outType = getElementTypeOrSelf(out);

VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());

if (outVecType && outVecType.isScalable())
return failure();

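// The scaled conversion ops take an f32 scale, so widen or narrow the scale
// operand to f32 if it has a different width.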
Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);

VectorType extScaleResultType = VectorType::get(opWidth, outType);

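// Scalar case: wrap the input in a one-element vector, run a single packed
// op, and extract lane 0 as the result.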
if (!outVecType) {
Value inCast =
rewriter.create<vector::SplatOp>(loc, VectorType::get(1, inType), in);
// TODO: replace this with non-packed ScaledExtOp
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, inCast, scale, 0);
scaleExt = rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleExt, 0);
return success();
}

VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());

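// Each element of the original (pre-broadcast) scale applies to a contiguous
// block of the input. Recover that block shape as the elementwise ratio
// between the input shape and the scale shape padded with trailing 1s.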
ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
llvm::append_range(originalScaleShape, origScaleVecType.getShape());

originalScaleShape.insert(originalScaleShape.end(),
inShape.size() - originalScaleShape.size(), 1);

auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");

SmallVector<int64_t> ratio =
maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

int64_t blockSize = computeProduct(ratio);

Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);

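// Process the input block by block: each block shares one scale value, is
// flattened to 1-D, converted opWidth elements at a time, reshaped back to
// the block shape, and inserted into the final result.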
for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
Value block = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
Value uniformScale =
rewriter.create<vector::ExtractOp>(loc, scale, offsets);

VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);

for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, block1D, i, sliceWidth, 1);
// TODO: replace this with non-packed ScaledExtOp for sliceWidth == 1
Value scaleExt = rewriter.create<amdgpu::ScaledExtPackedOp>(
loc, extScaleResultType, slice, uniformScale, 0);
if (sliceWidth != opWidth)
scaleExt = rewriter.create<vector::ExtractStridedSliceOp>(
loc, scaleExt, 0, sliceWidth, 1);
blockResult = rewriter.create<vector::InsertStridedSliceOp>(
loc, scaleExt, blockResult, i, 1);
}

VectorType resultType = VectorType::get(ratio, outType);
Value cast =
rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
offsets, strides);
}

rewriter.replaceOp(op, result);

return success();
}

LogicalResult
ScalingTruncFRewritePattern::matchAndRewrite(arith::ScalingTruncFOp op,
PatternRewriter &rewriter) const {
Location loc = op.getLoc();
constexpr int64_t opWidth = 2;

Value in = op.getIn();
Value scale = op.getScale();
Value out = op.getOut();

Type f32 = rewriter.getF32Type();
Type inType = getElementTypeOrSelf(in);
Type scaleType = getElementTypeOrSelf(scale);
Type outType = getElementTypeOrSelf(out);

VectorType outVecType = dyn_cast<VectorType>(out.getType());
VectorType scaleVecType = dyn_cast<VectorType>(scale.getType());

if (outVecType && outVecType.isScalable())
return failure();

Type scaleF32Type =
scaleVecType ? VectorType::get(scaleVecType.getShape(), f32) : f32;
if (scaleType.getIntOrFloatBitWidth() < 32)
scale = rewriter.create<arith::ExtFOp>(loc, scaleF32Type, scale);
else if (scaleType.getIntOrFloatBitWidth() > 32)
scale = rewriter.create<arith::TruncFOp>(loc, scaleF32Type, scale);

Value zero = rewriter.create<arith::ConstantOp>(
loc, outType, rewriter.getFloatAttr(outType, 0.0));
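// The packed scaled-truncate op always yields a packed 32-bit result, i.e.
// 32 / bitwidth(outType) elements per op.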
unsigned numPackedElem = 32 / outType.getIntOrFloatBitWidth();
VectorType truncScaleResultType = VectorType::get(numPackedElem, outType);

if (!outVecType) {
Type inVecType = VectorType::get(1, inType);
Value inCast = rewriter.create<vector::SplatOp>(loc, inVecType, in);
// TODO: replace this with non-packed ScaledTruncOp
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, inCast, scale, 0, /*existing=*/nullptr);
scaleTrunc =
rewriter.replaceOpWithNewOp<vector::ExtractOp>(op, scaleTrunc, 0);
return success();
}

VectorType inVecType = cast<VectorType>(in.getType());
Value origScale = getOriginalVectorValue(op.getScale());

ArrayRef<int64_t> inShape = inVecType.getShape();
SmallVector<int64_t> originalScaleShape;
if (auto origScaleVecType = dyn_cast<VectorType>(origScale.getType()))
llvm::append_range(originalScaleShape, origScaleVecType.getShape());

originalScaleShape.insert(originalScaleShape.end(),
inShape.size() - originalScaleShape.size(), 1);

auto maybeRatio = computeShapeRatio(inShape, originalScaleShape);
assert(maybeRatio &&
"failed to derive block size from broadcast or splat operation");

SmallVector<int64_t> ratio =
maybeRatio.value_or(SmallVector<int64_t>(inShape.size(), 1));

int64_t blockSize = computeProduct(ratio);

Value result = rewriter.createOrFold<vector::SplatOp>(loc, outVecType, zero);

for (SmallVector<int64_t> offsets : StaticTileOffsetRange(inShape, ratio)) {
SmallVector<int64_t> strides(offsets.size(), 1);
Value block = rewriter.create<vector::ExtractStridedSliceOp>(
loc, in, offsets, ratio, strides);
VectorType block1DType = VectorType::get(blockSize, inType);
Value block1D =
rewriter.create<vector::ShapeCastOp>(loc, block1DType, block);
Value uniformScale =
rewriter.create<vector::ExtractOp>(loc, scale, offsets);

VectorType blockResultType = VectorType::get(blockSize, outType);
Value blockResult =
rewriter.createOrFold<vector::SplatOp>(loc, blockResultType, zero);

for (int64_t i = 0, sliceWidth = std::min(opWidth, blockSize - i);
i < blockSize;
i += sliceWidth, sliceWidth = std::min(opWidth, blockSize - i)) {
Value slice = rewriter.create<vector::ExtractStridedSliceOp>(
loc, block1D, i, sliceWidth, 1);
// TODO: replace this with non-packed ScaledTruncOp for sliceWidth == 1
Value scaleTrunc = rewriter.create<amdgpu::PackedScaledTruncOp>(
loc, truncScaleResultType, slice, uniformScale, 0,
/*existing=*/nullptr);
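// The packed result carries numPackedElem lanes; when that differs from
// opWidth, keep only the sliceWidth lanes produced for this slice.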
int64_t packedWidth =
cast<VectorType>(scaleTrunc.getType()).getNumElements();
if (packedWidth != opWidth)
scaleTrunc = rewriter.create<vector::ExtractStridedSliceOp>(
loc, scaleTrunc, 0, sliceWidth, 1);
blockResult = rewriter.create<vector::InsertStridedSliceOp>(
loc, scaleTrunc, blockResult, i, 1);
}

VectorType resultType = VectorType::get(ratio, outType);
Value cast =
rewriter.create<vector::ShapeCastOp>(loc, resultType, blockResult);
result = rewriter.create<vector::InsertStridedSliceOp>(loc, cast, result,
offsets, strides);
}

rewriter.replaceOp(op, result);

return success();
}

void mlir::arith::populateArithToAMDGPUConversionPatterns(
RewritePatternSet &patterns, bool convertFP8Arithmetic,
bool saturateFP8Truncf, bool allowPackedF16Rtz, Chipset chipset) {
@@ -406,6 +673,11 @@ void mlir::arith::populateArithToAMDGPUConversionPatterns(
}
if (allowPackedF16Rtz)
patterns.add<TruncfToFloat16RewritePattern>(patterns.getContext());

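// The packed scaled conversion ops are only available on gfx950 and newer.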
if (chipset >= kGfx950) {
patterns.add<ScalingExtFRewritePattern>(patterns.getContext());
patterns.add<ScalingTruncFRewritePattern>(patterns.getContext());
}
}

void ArithToAMDGPUConversionPass::runOnOperation() {