Skip to content

Commit

Permalink
Fixing fmaxf for CPU 16 bit
Browse files Browse the repository at this point in the history
  • Loading branch information
Ranvirsv authored and github-actions[bot] committed Jun 8, 2023
1 parent 088c000 commit 5852224
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 0 deletions.
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ iree_cc_library(
SRCS
"ConvertToLLVM.cpp"
"DispatchABI.cpp"
"ExpandF16MaxFToF32Pass.cpp"
"KernelDispatch.cpp"
"LLVMCPUAssignConstantOrdinals.cpp"
"LLVMCPUAssignImportOrdinals.cpp"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
// Copyright 2023 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree/compiler/Codegen/LLVMCPU/LLVMCPUPasses.h"
#include "iree/compiler/Codegen/PassDetail.h"
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/Dialect/Arith/IR/Arith.h"

namespace mlir {
namespace iree_compiler {

namespace {

/// Rewrites an `arith.maxf` whose element type is 16 bits wide so that the
/// maximum is computed in f32 and then truncated back:
///
///   %r = arith.maxf %a, %b : T
/// becomes
///   %a32 = arith.extf %a : T to T32
///   %b32 = arith.extf %b : T to T32
///   %m32 = arith.maxf %a32, %b32 : T32
///   %r   = arith.truncf %m32 : T32 to T
///
/// where T32 is T with its element type replaced by f32. The match is made on
/// the element bit width only, so every 16-bit element type is widened.
struct ExpandF16MaxFToF32Pattern : public OpRewritePattern<arith::MaxFOp> {
  using OpRewritePattern<arith::MaxFOp>::OpRewritePattern;

  /// Returns failure() (leaving `op` untouched) unless the element type of the
  /// lhs operand is 16 bits wide; otherwise replaces `op` with the
  /// extf/maxf/truncf sequence described on the struct.
  LogicalResult matchAndRewrite(arith::MaxFOp op,
                                PatternRewriter &rewriter) const override {
    Type resultType = op.getLhs().getType();
    if (getElementTypeOrSelf(resultType).getIntOrFloatBitWidth() != 16) {
      return failure();
    }

    Location loc = op.getLoc();

    // Build the f32 counterpart of the operand type. For shaped operands
    // (vectors, tensors) only the element type is swapped, so the shape and
    // any other type parameters carry over; for scalars it is plain f32.
    Type wideType = rewriter.getF32Type();
    if (auto shapedTy = resultType.dyn_cast<ShapedType>()) {
      wideType = shapedTy.clone(wideType);
    }

    Value lhsExt = rewriter.create<arith::ExtFOp>(loc, wideType, op.getLhs());
    Value rhsExt = rewriter.create<arith::ExtFOp>(loc, wideType, op.getRhs());
    // The new maxf has 32-bit elements, so the bit-width check above rejects
    // it and the pattern is not re-applied to its own output.
    Value maxExt =
        rewriter.create<arith::MaxFOp>(loc, wideType, lhsExt, rhsExt);
    Value result = rewriter.create<arith::TruncFOp>(loc, resultType, maxExt);

    rewriter.replaceOp(op, result);
    return success();
  }
};

/// Pass that greedily applies ExpandF16MaxFToF32Pattern to the operation it is
/// run on, and signals pass failure if the greedy driver reports failure.
struct ExpandF16MaxFToF32Pass
    : public ExpandF16MaxFToF32Base<ExpandF16MaxFToF32Pass> {
  void getDependentDialects(DialectRegistry &registry) const override {
    // The rewrite only creates arith ops (extf / maxf / truncf), so arith is
    // the dialect that has to be loaded before the pass runs.
    registry.insert<arith::ArithDialect>();
  }

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    RewritePatternSet patterns(context);
    patterns.add<ExpandF16MaxFToF32Pattern>(context);
    if (failed(applyPatternsAndFoldGreedily(getOperation(),
                                            std::move(patterns)))) {
      signalPassFailure();
    }
  }
};
} // namespace

/// Factory for the pass: returns a newly allocated ExpandF16MaxFToF32Pass
/// owned by the caller.
std::unique_ptr<Pass> createExpandF16MaxFToF32Pass() {
  std::unique_ptr<Pass> pass = std::make_unique<ExpandF16MaxFToF32Pass>();
  return pass;
}

} // namespace iree_compiler
} // namespace mlir
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUPasses.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ createLLVMCPUEmitVectorizationRemarksPass();
std::unique_ptr<OperationPass<IREE::HAL::ExecutableVariantOp>>
createLLVMCPULowerExecutableTargetPass();

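/// Pass to expand 16-bit float `arith.maxf` ops: the operands are extended to
/// f32, the max is computed in f32, and the result is truncated back.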
std::unique_ptr<Pass> createExpandF16MaxFToF32Pass();

/// Pass to lower a sequence of operations to a iree_codegen.ukernel.*
/// operation.
std::unique_ptr<OperationPass<>> createLLVMCPULowerToUKernelsPass();
Expand Down
1 change: 1 addition & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -710,6 +710,7 @@ static void addLowerToLLVMPasses(OpPassManager &passManager) {

void buildLLVMCPUCodegenPassPipeline(OpPassManager &passManager) {
addCommonTargetExecutablePreprocessingPasses(passManager.nest<ModuleOp>());
passManager.addPass(createExpandF16MaxFToF32Pass());
passManager.nest<ModuleOp>().addNestedPass<func::FuncOp>(
createLLVMCPUMaterializeEncodingPass());
// TODO: Remove the following pass the plumb support for #hal.descriptor_type
Expand Down
8 changes: 8 additions & 0 deletions compiler/src/iree/compiler/Codegen/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -640,6 +640,14 @@ def LLVMCPULowerExecutableTarget :
"mlir::iree_compiler::createLLVMCPULowerExecutableTargetPass()";
}

// Pass declaration for the 16-bit float `arith.maxf` expansion; the C++ pass
// is created through the constructor named below.
def ExpandF16MaxFToF32 : Pass<"iree-llvmcpu-expand-max-f16-to-f32", ""> {
  let summary = "Promote f16 to f32 for max.";
  let constructor = "mlir::iree_compiler::createExpandF16MaxFToF32Pass()";
}

def LLVMCPULowerToUKernels :
Pass<"iree-llvmcpu-lower-to-ukernels", ""> {
let summary =
Expand Down

0 comments on commit 5852224

Please sign in to comment.