diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index 20cb2e47343c0..f19500e1957c7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -51,21 +51,21 @@ class NVVM_Op traits = []> : // NVVM intrinsic operations //===----------------------------------------------------------------------===// -class NVVM_IntrOp traits, +class NVVM_IntrOp overloadedResults, + list overloadedOperands, list traits, int numResults> : LLVM_IntrOpBase overloadedResults=*/[], - /*list overloadedOperands=*/[], - traits, numResults>; + overloadedResults, overloadedOperands, traits, numResults>; //===----------------------------------------------------------------------===// // NVVM special register op definitions //===----------------------------------------------------------------------===// -class NVVM_SpecialRegisterOp traits = []> : - NVVM_IntrOp { - let arguments = (ins); +class NVVM_SpecialRegisterOp traits = []> : + NVVM_IntrOp, + Arguments<(ins)> { let assemblyFormat = "attr-dict `:` type($res)"; } @@ -92,16 +92,6 @@ def NVVM_GridDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">; def NVVM_GridDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">; def NVVM_GridDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">; -//===----------------------------------------------------------------------===// -// NVVM approximate op definitions -//===----------------------------------------------------------------------===// - -def NVVM_RcpApproxFtzF32Op : NVVM_IntrOp<"rcp.approx.ftz.f", [NoSideEffect], 1> { - let arguments = (ins F32:$arg); - let results = (outs F32:$res); - let assemblyFormat = "$arg attr-dict `:` type($res)"; -} - //===----------------------------------------------------------------------===// // NVVM synchronization op definitions //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h deleted file mode 100644 index af0c4ea4e568c..0000000000000 --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h +++ /dev/null @@ -1,25 +0,0 @@ -//===- OptimizeForNVVM.h - Optimize LLVM IR for NVVM -*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H -#define MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H - -#include - -namespace mlir { -class Pass; - -namespace NVVM { - -/// Creates a pass that optimizes LLVM IR for the NVVM target. -std::unique_ptr createOptimizeForTargetPass(); - -} // namespace NVVM -} // namespace mlir - -#endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_OPTIMIZENVVM_H diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h index 39948557b55a6..868a0e5635105 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.h @@ -10,7 +10,6 @@ #define MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES_H #include "mlir/Dialect/LLVMIR/Transforms/LegalizeForExport.h" -#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td index 060822603bc20..0dc193e794f52 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/Transforms/Passes.td @@ -16,9 +16,4 @@ def LLVMLegalizeForExport : Pass<"llvm-legalize-for-export"> { let constructor = "mlir::LLVM::createLegalizeForExportPass()"; } -def NVVMOptimizeForTarget : Pass<"llvm-optimize-for-nvvm-target"> { - let summary = "Optimize NVVM IR"; - let constructor = "mlir::NVVM::createOptimizeForTargetPass()"; -} - #endif // MLIR_DIALECT_LLVMIR_TRANSFORMS_PASSES diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt index e27d83e4426db..3e1342dcf2c9c 100644 --- a/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/LLVMIR/Transforms/CMakeLists.txt @@ -1,6 +1,5 @@ add_mlir_dialect_library(MLIRLLVMIRTransforms LegalizeForExport.cpp - OptimizeForNVVM.cpp DEPENDS MLIRLLVMPassIncGen diff --git a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp b/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp deleted file mode 100644 index d269aa82ecec5..0000000000000 --- a/mlir/lib/Dialect/LLVMIR/Transforms/OptimizeForNVVM.cpp +++ /dev/null @@ -1,97 +0,0 @@ -//===- OptimizeForNVVM.cpp - Optimize LLVM IR for NVVM ---------===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/LLVMIR/Transforms/OptimizeForNVVM.h" -#include "PassDetail.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; - -namespace { -// Replaces fdiv on fp16 with fp32 multiplication with reciprocal plus one -// (conditional) Newton iteration. -// -// This as accurate as promoting the division to fp32 in the NVPTX backend, but -// faster because it performs less Newton iterations, avoids the slow path -// for e.g. denormals, and allows reuse of the reciprocal for multiple divisions -// by the same divisor. -struct ExpandDivF16 : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - -private: - LogicalResult matchAndRewrite(LLVM::FDivOp op, - PatternRewriter &rewriter) const override; -}; - -struct NVVMOptimizeForTarget - : public NVVMOptimizeForTargetBase { - void runOnOperation() override; - - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - } -}; -} // namespace - -LogicalResult ExpandDivF16::matchAndRewrite(LLVM::FDivOp op, - PatternRewriter &rewriter) const { - if (!op.getType().isF16()) - return rewriter.notifyMatchFailure(op, "not f16"); - Location loc = op.getLoc(); - - Type f32Type = rewriter.getF32Type(); - Type i32Type = rewriter.getI32Type(); - - // Extend lhs and rhs to fp32. - Value lhs = rewriter.create(loc, f32Type, op.getLhs()); - Value rhs = rewriter.create(loc, f32Type, op.getRhs()); - - // float rcp = rcp.approx.ftz.f32(rhs), approx = lhs * rcp. - Value rcp = rewriter.create(loc, f32Type, rhs); - Value approx = rewriter.create(loc, lhs, rcp); - - // Refine the approximation with one Newton iteration: - // float refined = approx + (lhs - approx * rhs) * rcp; - Value err = rewriter.create( - loc, approx, rewriter.create(loc, rhs), lhs); - Value refined = rewriter.create(loc, err, rcp, approx); - - // Use refined value if approx is normal (exponent neither all 0 or all 1). - Value mask = rewriter.create( - loc, i32Type, rewriter.getUI32IntegerAttr(0x7f800000)); - Value cast = rewriter.create(loc, i32Type, approx); - Value exp = rewriter.create(loc, i32Type, cast, mask); - Value zero = rewriter.create( - loc, i32Type, rewriter.getUI32IntegerAttr(0)); - Value pred = rewriter.create( - loc, - rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, zero), - rewriter.create(loc, LLVM::ICmpPredicate::eq, exp, mask)); - Value result = - rewriter.create(loc, f32Type, pred, approx, refined); - - // Replace with trucation back to fp16. - rewriter.replaceOpWithNewOp(op, op.getType(), result); - - return success(); -} - -void NVVMOptimizeForTarget::runOnOperation() { - MLIRContext *ctx = getOperation()->getContext(); - RewritePatternSet patterns(ctx); - patterns.add(ctx); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); -} - -std::unique_ptr NVVM::createOptimizeForTargetPass() { - return std::make_unique(); -} diff --git a/mlir/test/Dialect/LLVMIR/nvvm.mlir b/mlir/test/Dialect/LLVMIR/nvvm.mlir index c978d773d9591..9b28841c3c781 100644 --- a/mlir/test/Dialect/LLVMIR/nvvm.mlir +++ b/mlir/test/Dialect/LLVMIR/nvvm.mlir @@ -29,13 +29,6 @@ func.func @nvvm_special_regs() -> i32 { llvm.return %0 : i32 } -// CHECK-LABEL: @nvvm_rcp -func.func @nvvm_rcp(%arg0: f32) -> f32 { - // CHECK: nvvm.rcp.approx.ftz.f %arg0 : f32 - %0 = nvvm.rcp.approx.ftz.f %arg0 : f32 - llvm.return %0 : f32 -} - // CHECK-LABEL: @llvm_nvvm_barrier0 func.func @llvm_nvvm_barrier0() { // CHECK: nvvm.barrier0 diff --git a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir b/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir deleted file mode 100644 index e1cfd0c44f89b..0000000000000 --- a/mlir/test/Dialect/LLVMIR/optimize-for-nvvm.mlir +++ /dev/null @@ -1,24 +0,0 @@ -// RUN: mlir-opt %s -llvm-optimize-for-nvvm-target | FileCheck %s - -// CHECK-LABEL: llvm.func @fdiv_fp16 -llvm.func @fdiv_fp16(%arg0 : f16, %arg1 : f16) -> f16 { - // CHECK-DAG: %[[c0:.*]] = llvm.mlir.constant(0 : ui32) : i32 - // CHECK-DAG: %[[mask:.*]] = llvm.mlir.constant(2139095040 : ui32) : i32 - // CHECK-DAG: %[[lhs:.*]] = llvm.fpext %arg0 : f16 to f32 - // CHECK-DAG: %[[rhs:.*]] = llvm.fpext %arg1 : f16 to f32 - // CHECK-DAG: %[[rcp:.*]] = nvvm.rcp.approx.ftz.f %[[rhs]] : f32 - // CHECK-DAG: %[[approx:.*]] = llvm.fmul %[[lhs]], %[[rcp]] : f32 - // CHECK-DAG: %[[neg:.*]] = llvm.fneg %[[rhs]] : f32 - // CHECK-DAG: %[[err:.*]] = "llvm.intr.fma"(%[[approx]], %[[neg]], %[[lhs]]) : (f32, f32, f32) -> f32 - // CHECK-DAG: %[[refined:.*]] = "llvm.intr.fma"(%[[err]], %[[rcp]], %[[approx]]) : (f32, f32, f32) -> f32 - // CHECK-DAG: %[[cast:.*]] = llvm.bitcast %[[approx]] : f32 to i32 - // CHECK-DAG: %[[exp:.*]] = llvm.and %[[cast]], %[[mask]] : i32 - // CHECK-DAG: %[[is_zero:.*]] = llvm.icmp "eq" %[[exp]], %[[c0]] : i32 - // CHECK-DAG: %[[is_mask:.*]] = llvm.icmp "eq" %[[exp]], %[[mask]] : i32 - // CHECK-DAG: %[[pred:.*]] = llvm.or %[[is_zero]], %[[is_mask]] : i1 - // CHECK-DAG: %[[select:.*]] = llvm.select %[[pred]], %[[approx]], %[[refined]] : i1, f32 - // CHECK-DAG: %[[result:.*]] = llvm.fptrunc %[[select]] : f32 to f16 - %result = llvm.fdiv %arg0, %arg1 : f16 - // CHECK: llvm.return %[[result]] : f16 - llvm.return %result : f16 -} diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir index a66560d0e0da8..53af04140c38d 100644 --- a/mlir/test/Target/LLVMIR/nvvmir.mlir +++ b/mlir/test/Target/LLVMIR/nvvmir.mlir @@ -33,13 +33,6 @@ llvm.func @nvvm_special_regs() -> i32 { llvm.return %1 : i32 } -// CHECK-LABEL: @nvvm_rcp -llvm.func @nvvm_rcp(%0: f32) -> f32 { - // CHECK: call float @llvm.nvvm.rcp.approx.ftz.f - %1 = nvvm.rcp.approx.ftz.f %0 : f32 - llvm.return %1 : f32 -} - // CHECK-LABEL: @llvm_nvvm_barrier0 llvm.func @llvm_nvvm_barrier0() { // CHECK: call void @llvm.nvvm.barrier0() diff --git a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel index 3afa3101fa48a..48264e5126614 100644 --- a/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel +++ b/utils/bazel/llvm-project-overlay/mlir/BUILD.bazel @@ -3386,9 +3386,7 @@ cc_library( ":IR", ":LLVMDialect", ":LLVMPassIncGen", - ":NVVMDialect", ":Pass", - ":Transforms", ], )