AutoUpgrade: match NVVM intrinsic names by stripping shared prefixes.

ShouldUpgradeNVPTXBF16Intrinsic and the 'nvvm.' branch of
UpgradeIntrinsicFunction1 now peel a common prefix off Name with
StringRef::consume_front and match only the remaining suffix.

The added bf16 suffix cases use '.' separators throughout ("ftz.bf16",
"ftz.sat.bf16", "ftz.nan.bf16", ...), so the fma.rn and fmin tables agree
with the fmax table. The "ftz_bf16", "ftz_sat.*" and "ftz.nan_bf16*"
spellings of the removed table are intentionally not carried over.

diff --git a/llvm/lib/IR/AutoUpgrade.cpp b/llvm/lib/IR/AutoUpgrade.cpp
index e2a6a5494ecb7..4e10c6ce5e2bd 100644
--- a/llvm/lib/IR/AutoUpgrade.cpp
+++ b/llvm/lib/IR/AutoUpgrade.cpp
@@ -593,68 +593,83 @@ static bool UpgradeX86IntrinsicFunction(Function *F, StringRef Name,
 }
 
 static Intrinsic::ID ShouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
-  return StringSwitch<Intrinsic::ID>(Name)
-      .Case("abs.bf16", Intrinsic::nvvm_abs_bf16)
-      .Case("abs.bf16x2", Intrinsic::nvvm_abs_bf16x2)
-      .Case("fma.rn.bf16", Intrinsic::nvvm_fma_rn_bf16)
-      .Case("fma.rn.bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
-      .Case("fma.rn.ftz_bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
-      .Case("fma.rn.ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
-      .Case("fma.rn.ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
-      .Case("fma.rn.ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
-      .Case("fma.rn.ftz_sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
-      .Case("fma.rn.ftz_sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
-      .Case("fma.rn.relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
-      .Case("fma.rn.relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
-      .Case("fma.rn.sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
-      .Case("fma.rn.sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
-      .Case("fmax.bf16", Intrinsic::nvvm_fmax_bf16)
-      .Case("fmax.bf16x2", Intrinsic::nvvm_fmax_bf16x2)
-      .Case("fmax.ftz.bf16", Intrinsic::nvvm_fmax_ftz_bf16)
-      .Case("fmax.ftz.bf16x2", Intrinsic::nvvm_fmax_ftz_bf16x2)
-      .Case("fmax.ftz.nan.bf16", Intrinsic::nvvm_fmax_ftz_nan_bf16)
-      .Case("fmax.ftz.nan.bf16x2", Intrinsic::nvvm_fmax_ftz_nan_bf16x2)
-      .Case("fmax.ftz.nan.xorsign.abs.bf16",
-            Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16)
-      .Case("fmax.ftz.nan.xorsign.abs.bf16x2",
-            Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16x2)
-      .Case("fmax.ftz.xorsign.abs.bf16",
-            Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16)
-      .Case("fmax.ftz.xorsign.abs.bf16x2",
-            Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16x2)
-      .Case("fmax.nan.bf16", Intrinsic::nvvm_fmax_nan_bf16)
-      .Case("fmax.nan.bf16x2", Intrinsic::nvvm_fmax_nan_bf16x2)
-      .Case("fmax.nan.xorsign.abs.bf16",
-            Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16)
-      .Case("fmax.nan.xorsign.abs.bf16x2",
-            Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16x2)
-      .Case("fmax.xorsign.abs.bf16", Intrinsic::nvvm_fmax_xorsign_abs_bf16)
-      .Case("fmax.xorsign.abs.bf16x2", Intrinsic::nvvm_fmax_xorsign_abs_bf16x2)
-      .Case("fmin.bf16", Intrinsic::nvvm_fmin_bf16)
-      .Case("fmin.bf16x2", Intrinsic::nvvm_fmin_bf16x2)
-      .Case("fmin.ftz.bf16", Intrinsic::nvvm_fmin_ftz_bf16)
-      .Case("fmin.ftz.bf16x2", Intrinsic::nvvm_fmin_ftz_bf16x2)
-      .Case("fmin.ftz.nan_bf16", Intrinsic::nvvm_fmin_ftz_nan_bf16)
-      .Case("fmin.ftz.nan_bf16x2", Intrinsic::nvvm_fmin_ftz_nan_bf16x2)
-      .Case("fmin.ftz.nan.xorsign.abs.bf16",
-            Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16)
-      .Case("fmin.ftz.nan.xorsign.abs.bf16x2",
-            Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16x2)
-      .Case("fmin.ftz.xorsign.abs.bf16",
-            Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16)
-      .Case("fmin.ftz.xorsign.abs.bf16x2",
-            Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16x2)
-      .Case("fmin.nan.bf16", Intrinsic::nvvm_fmin_nan_bf16)
-      .Case("fmin.nan.bf16x2", Intrinsic::nvvm_fmin_nan_bf16x2)
-      .Case("fmin.nan.xorsign.abs.bf16",
-            Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16)
-      .Case("fmin.nan.xorsign.abs.bf16x2",
-            Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16x2)
-      .Case("fmin.xorsign.abs.bf16", Intrinsic::nvvm_fmin_xorsign_abs_bf16)
-      .Case("fmin.xorsign.abs.bf16x2", Intrinsic::nvvm_fmin_xorsign_abs_bf16x2)
-      .Case("neg.bf16", Intrinsic::nvvm_neg_bf16)
-      .Case("neg.bf16x2", Intrinsic::nvvm_neg_bf16x2)
-      .Default(Intrinsic::not_intrinsic);
+  if (Name.consume_front("abs."))
+    return StringSwitch<Intrinsic::ID>(Name)
+        .Case("bf16", Intrinsic::nvvm_abs_bf16)
+        .Case("bf16x2", Intrinsic::nvvm_abs_bf16x2)
+        .Default(Intrinsic::not_intrinsic);
+
+  if (Name.consume_front("fma.rn."))
+    return StringSwitch<Intrinsic::ID>(Name)
+        .Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
+        .Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
+        .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
+        .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
+        .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
+        .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
+        .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
+        .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
+        .Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
+        .Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
+        .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
+        .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
+        .Default(Intrinsic::not_intrinsic);
+
+  if (Name.consume_front("fmax."))
+    return StringSwitch<Intrinsic::ID>(Name)
+        .Case("bf16", Intrinsic::nvvm_fmax_bf16)
+        .Case("bf16x2", Intrinsic::nvvm_fmax_bf16x2)
+        .Case("ftz.bf16", Intrinsic::nvvm_fmax_ftz_bf16)
+        .Case("ftz.bf16x2", Intrinsic::nvvm_fmax_ftz_bf16x2)
+        .Case("ftz.nan.bf16", Intrinsic::nvvm_fmax_ftz_nan_bf16)
+        .Case("ftz.nan.bf16x2", Intrinsic::nvvm_fmax_ftz_nan_bf16x2)
+        .Case("ftz.nan.xorsign.abs.bf16",
+              Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16)
+        .Case("ftz.nan.xorsign.abs.bf16x2",
+              Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16x2)
+        .Case("ftz.xorsign.abs.bf16", Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16)
+        .Case("ftz.xorsign.abs.bf16x2",
+              Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16x2)
+        .Case("nan.bf16", Intrinsic::nvvm_fmax_nan_bf16)
+        .Case("nan.bf16x2", Intrinsic::nvvm_fmax_nan_bf16x2)
+        .Case("nan.xorsign.abs.bf16", Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16)
+        .Case("nan.xorsign.abs.bf16x2",
+              Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16x2)
+        .Case("xorsign.abs.bf16", Intrinsic::nvvm_fmax_xorsign_abs_bf16)
+        .Case("xorsign.abs.bf16x2", Intrinsic::nvvm_fmax_xorsign_abs_bf16x2)
+        .Default(Intrinsic::not_intrinsic);
+
+  if (Name.consume_front("fmin."))
+    return StringSwitch<Intrinsic::ID>(Name)
+        .Case("bf16", Intrinsic::nvvm_fmin_bf16)
+        .Case("bf16x2", Intrinsic::nvvm_fmin_bf16x2)
+        .Case("ftz.bf16", Intrinsic::nvvm_fmin_ftz_bf16)
+        .Case("ftz.bf16x2", Intrinsic::nvvm_fmin_ftz_bf16x2)
+        .Case("ftz.nan.bf16", Intrinsic::nvvm_fmin_ftz_nan_bf16)
+        .Case("ftz.nan.bf16x2", Intrinsic::nvvm_fmin_ftz_nan_bf16x2)
+        .Case("ftz.nan.xorsign.abs.bf16",
+              Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16)
+        .Case("ftz.nan.xorsign.abs.bf16x2",
+              Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16x2)
+        .Case("ftz.xorsign.abs.bf16", Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16)
+        .Case("ftz.xorsign.abs.bf16x2",
+              Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16x2)
+        .Case("nan.bf16", Intrinsic::nvvm_fmin_nan_bf16)
+        .Case("nan.bf16x2", Intrinsic::nvvm_fmin_nan_bf16x2)
+        .Case("nan.xorsign.abs.bf16", Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16)
+        .Case("nan.xorsign.abs.bf16x2",
+              Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16x2)
+        .Case("xorsign.abs.bf16", Intrinsic::nvvm_fmin_xorsign_abs_bf16)
+        .Case("xorsign.abs.bf16x2", Intrinsic::nvvm_fmin_xorsign_abs_bf16x2)
+        .Default(Intrinsic::not_intrinsic);
+
+  if (Name.consume_front("neg."))
+    return StringSwitch<Intrinsic::ID>(Name)
+        .Case("bf16", Intrinsic::nvvm_neg_bf16)
+        .Case("bf16x2", Intrinsic::nvvm_neg_bf16x2)
+        .Default(Intrinsic::not_intrinsic);
+
+  return Intrinsic::not_intrinsic;
 }
 
 static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
@@ -1052,42 +1067,55 @@ static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
     break;
   }
   case 'n': {
-    if (Name.startswith("nvvm.")) {
-      Name = Name.substr(5);
-
-      // The following nvvm intrinsics correspond exactly to an LLVM intrinsic.
-      Intrinsic::ID IID = StringSwitch<Intrinsic::ID>(Name)
-                              .Cases("brev32", "brev64", Intrinsic::bitreverse)
-                              .Case("clz.i", Intrinsic::ctlz)
-                              .Case("popc.i", Intrinsic::ctpop)
-                              .Default(Intrinsic::not_intrinsic);
-      if (IID != Intrinsic::not_intrinsic && F->arg_size() == 1) {
-        NewFn = Intrinsic::getDeclaration(F->getParent(), IID,
-                                          {F->getReturnType()});
-        return true;
+    if (Name.consume_front("nvvm.")) {
+      // Check for nvvm intrinsics corresponding exactly to an LLVM intrinsic.
+      if (F->arg_size() == 1) {
+        Intrinsic::ID IID =
+            StringSwitch<Intrinsic::ID>(Name)
+                .Cases("brev32", "brev64", Intrinsic::bitreverse)
+                .Case("clz.i", Intrinsic::ctlz)
+                .Case("popc.i", Intrinsic::ctpop)
+                .Default(Intrinsic::not_intrinsic);
+        if (IID != Intrinsic::not_intrinsic) {
+          NewFn = Intrinsic::getDeclaration(F->getParent(), IID,
+                                            {F->getReturnType()});
+          return true;
+        }
       }
-      IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
-      if (IID != Intrinsic::not_intrinsic &&
-          !F->getReturnType()->getScalarType()->isBFloatTy()) {
-        NewFn = nullptr;
-        return true;
+
+      // Check for nvvm intrinsics that need a return type adjustment.
+      if (!F->getReturnType()->getScalarType()->isBFloatTy()) {
+        Intrinsic::ID IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
+        if (IID != Intrinsic::not_intrinsic) {
+          NewFn = nullptr;
+          return true;
+        }
       }
+
       // The following nvvm intrinsics correspond exactly to an LLVM idiom, but
       // not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
       //
       // TODO: We could add lohi.i2d.
-      bool Expand = StringSwitch<bool>(Name)
-                        .Cases("abs.i", "abs.ll", true)
-                        .Cases("clz.ll", "popc.ll", "h2f", true)
-                        .Cases("max.i", "max.ll", "max.ui", "max.ull", true)
-                        .Cases("min.i", "min.ll", "min.ui", "min.ull", true)
-                        .StartsWith("atomic.load.add.f32.p", true)
-                        .StartsWith("atomic.load.add.f64.p", true)
-                        .Default(false);
+      bool Expand = false;
+      if (Name.consume_front("abs."))
+        // nvvm.abs.{i,ll}
+        Expand = Name == "i" || Name == "ll";
+      else if (Name == "clz.ll" || Name == "popc.ll" || Name == "h2f")
+        Expand = true;
+      else if (Name.consume_front("max.") || Name.consume_front("min."))
+        // nvvm.{min,max}.{i,ll,ui,ull}
+        Expand = Name == "i" || Name == "ll" || Name == "ui" || Name == "ull";
+      else if (Name.consume_front("atomic.load.add."))
+        // nvvm.atomic.load.add.{f32.p,f64.p}
+        Expand = Name.startswith("f32.p") || Name.startswith("f64.p");
+      else
+        Expand = false;
+
       if (Expand) {
         NewFn = nullptr;
         return true;
       }
+
       break; // No other 'nvvm.*'.
     }
     break;
   }