Skip to content

Commit

Permalink
[llvm][NFC]Refactor AutoUpgrader case 'n'.
Browse files Browse the repository at this point in the history
The NVPTX intrinsics are under 'n'. That case does not yet use the
consume_front API, so fix that. Refactor the helper function to group
matchers on the first component and check that first. Do similarly with
the final set of intrinsics, which have a lot of commonality in the
matching. Finally, reorder the argument/return type checking wrt name
checking -- the former is going to be cheaper, so do that first before
checking the name.

Reviewed By: tra

Differential Revision: https://reviews.llvm.org/D158445
  • Loading branch information
urnathan committed Aug 22, 2023
1 parent f2b150b commit b045c36
Showing 1 changed file with 116 additions and 88 deletions.
204 changes: 116 additions & 88 deletions llvm/lib/IR/AutoUpgrade.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -593,68 +593,83 @@ static bool UpgradeX86IntrinsicFunction(Function *F, StringRef Name,
}

/// Map the name of an NVVM bf16 intrinsic to its Intrinsic::ID.
///
/// \param Name the intrinsic name to classify. The entries below are written
///        without the "nvvm." prefix (e.g. "fmax.ftz.nan.bf16"), so callers
///        are expected to have stripped it already.
/// \returns the matching Intrinsic::nvvm_* ID, or Intrinsic::not_intrinsic if
///          \p Name is not one of the listed bf16 intrinsics.
///
/// Matching is grouped on the leading component ("abs.", "fma.rn.", "fmax.",
/// "fmin.", "neg."): the prefix is consumed first and only the remaining
/// suffix is compared. Once a prefix has been consumed the function returns
/// from that group, so a name never falls through to a later group.
static Intrinsic::ID ShouldUpgradeNVPTXBF16Intrinsic(StringRef Name) {
  if (Name.consume_front("abs."))
    return StringSwitch<Intrinsic::ID>(Name)
        .Case("bf16", Intrinsic::nvvm_abs_bf16)
        .Case("bf16x2", Intrinsic::nvvm_abs_bf16x2)
        .Default(Intrinsic::not_intrinsic);

  if (Name.consume_front("fma.rn."))
    // Name components are separated by '.', exactly as in the IDs' underscore
    // positions ("ftz.bf16", "ftz.sat.bf16"); a '_' in the string could never
    // match a real intrinsic name.
    return StringSwitch<Intrinsic::ID>(Name)
        .Case("bf16", Intrinsic::nvvm_fma_rn_bf16)
        .Case("bf16x2", Intrinsic::nvvm_fma_rn_bf16x2)
        .Case("ftz.bf16", Intrinsic::nvvm_fma_rn_ftz_bf16)
        .Case("ftz.bf16x2", Intrinsic::nvvm_fma_rn_ftz_bf16x2)
        .Case("ftz.relu.bf16", Intrinsic::nvvm_fma_rn_ftz_relu_bf16)
        .Case("ftz.relu.bf16x2", Intrinsic::nvvm_fma_rn_ftz_relu_bf16x2)
        .Case("ftz.sat.bf16", Intrinsic::nvvm_fma_rn_ftz_sat_bf16)
        .Case("ftz.sat.bf16x2", Intrinsic::nvvm_fma_rn_ftz_sat_bf16x2)
        .Case("relu.bf16", Intrinsic::nvvm_fma_rn_relu_bf16)
        .Case("relu.bf16x2", Intrinsic::nvvm_fma_rn_relu_bf16x2)
        .Case("sat.bf16", Intrinsic::nvvm_fma_rn_sat_bf16)
        .Case("sat.bf16x2", Intrinsic::nvvm_fma_rn_sat_bf16x2)
        .Default(Intrinsic::not_intrinsic);

  if (Name.consume_front("fmax."))
    return StringSwitch<Intrinsic::ID>(Name)
        .Case("bf16", Intrinsic::nvvm_fmax_bf16)
        .Case("bf16x2", Intrinsic::nvvm_fmax_bf16x2)
        .Case("ftz.bf16", Intrinsic::nvvm_fmax_ftz_bf16)
        .Case("ftz.bf16x2", Intrinsic::nvvm_fmax_ftz_bf16x2)
        .Case("ftz.nan.bf16", Intrinsic::nvvm_fmax_ftz_nan_bf16)
        .Case("ftz.nan.bf16x2", Intrinsic::nvvm_fmax_ftz_nan_bf16x2)
        .Case("ftz.nan.xorsign.abs.bf16",
              Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16)
        .Case("ftz.nan.xorsign.abs.bf16x2",
              Intrinsic::nvvm_fmax_ftz_nan_xorsign_abs_bf16x2)
        .Case("ftz.xorsign.abs.bf16", Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16)
        .Case("ftz.xorsign.abs.bf16x2",
              Intrinsic::nvvm_fmax_ftz_xorsign_abs_bf16x2)
        .Case("nan.bf16", Intrinsic::nvvm_fmax_nan_bf16)
        .Case("nan.bf16x2", Intrinsic::nvvm_fmax_nan_bf16x2)
        .Case("nan.xorsign.abs.bf16", Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16)
        .Case("nan.xorsign.abs.bf16x2",
              Intrinsic::nvvm_fmax_nan_xorsign_abs_bf16x2)
        .Case("xorsign.abs.bf16", Intrinsic::nvvm_fmax_xorsign_abs_bf16)
        .Case("xorsign.abs.bf16x2", Intrinsic::nvvm_fmax_xorsign_abs_bf16x2)
        .Default(Intrinsic::not_intrinsic);

  if (Name.consume_front("fmin."))
    // Mirrors the "fmax." table entry for entry, including "ftz.nan.bf16".
    return StringSwitch<Intrinsic::ID>(Name)
        .Case("bf16", Intrinsic::nvvm_fmin_bf16)
        .Case("bf16x2", Intrinsic::nvvm_fmin_bf16x2)
        .Case("ftz.bf16", Intrinsic::nvvm_fmin_ftz_bf16)
        .Case("ftz.bf16x2", Intrinsic::nvvm_fmin_ftz_bf16x2)
        .Case("ftz.nan.bf16", Intrinsic::nvvm_fmin_ftz_nan_bf16)
        .Case("ftz.nan.bf16x2", Intrinsic::nvvm_fmin_ftz_nan_bf16x2)
        .Case("ftz.nan.xorsign.abs.bf16",
              Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16)
        .Case("ftz.nan.xorsign.abs.bf16x2",
              Intrinsic::nvvm_fmin_ftz_nan_xorsign_abs_bf16x2)
        .Case("ftz.xorsign.abs.bf16", Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16)
        .Case("ftz.xorsign.abs.bf16x2",
              Intrinsic::nvvm_fmin_ftz_xorsign_abs_bf16x2)
        .Case("nan.bf16", Intrinsic::nvvm_fmin_nan_bf16)
        .Case("nan.bf16x2", Intrinsic::nvvm_fmin_nan_bf16x2)
        .Case("nan.xorsign.abs.bf16", Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16)
        .Case("nan.xorsign.abs.bf16x2",
              Intrinsic::nvvm_fmin_nan_xorsign_abs_bf16x2)
        .Case("xorsign.abs.bf16", Intrinsic::nvvm_fmin_xorsign_abs_bf16)
        .Case("xorsign.abs.bf16x2", Intrinsic::nvvm_fmin_xorsign_abs_bf16x2)
        .Default(Intrinsic::not_intrinsic);

  if (Name.consume_front("neg."))
    return StringSwitch<Intrinsic::ID>(Name)
        .Case("bf16", Intrinsic::nvvm_neg_bf16)
        .Case("bf16x2", Intrinsic::nvvm_neg_bf16x2)
        .Default(Intrinsic::not_intrinsic);

  // No known leading component matched.
  return Intrinsic::not_intrinsic;
}

static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
Expand Down Expand Up @@ -1052,42 +1067,55 @@ static bool UpgradeIntrinsicFunction1(Function *F, Function *&NewFn) {
break;
}
case 'n': {
if (Name.startswith("nvvm.")) {
Name = Name.substr(5);

// The following nvvm intrinsics correspond exactly to an LLVM intrinsic.
Intrinsic::ID IID = StringSwitch<Intrinsic::ID>(Name)
.Cases("brev32", "brev64", Intrinsic::bitreverse)
.Case("clz.i", Intrinsic::ctlz)
.Case("popc.i", Intrinsic::ctpop)
.Default(Intrinsic::not_intrinsic);
if (IID != Intrinsic::not_intrinsic && F->arg_size() == 1) {
NewFn = Intrinsic::getDeclaration(F->getParent(), IID,
{F->getReturnType()});
return true;
if (Name.consume_front("nvvm.")) {
// Check for nvvm intrinsics corresponding exactly to an LLVM intrinsic.
if (F->arg_size() == 1) {
Intrinsic::ID IID =
StringSwitch<Intrinsic::ID>(Name)
.Cases("brev32", "brev64", Intrinsic::bitreverse)
.Case("clz.i", Intrinsic::ctlz)
.Case("popc.i", Intrinsic::ctpop)
.Default(Intrinsic::not_intrinsic);
if (IID != Intrinsic::not_intrinsic) {
NewFn = Intrinsic::getDeclaration(F->getParent(), IID,
{F->getReturnType()});
return true;
}
}
IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
if (IID != Intrinsic::not_intrinsic &&
!F->getReturnType()->getScalarType()->isBFloatTy()) {
NewFn = nullptr;
return true;

// Check for nvvm intrinsics that need a return type adjustment.
if (!F->getReturnType()->getScalarType()->isBFloatTy()) {
Intrinsic::ID IID = ShouldUpgradeNVPTXBF16Intrinsic(Name);
if (IID != Intrinsic::not_intrinsic) {
NewFn = nullptr;
return true;
}
}

// The following nvvm intrinsics correspond exactly to an LLVM idiom, but
// not to an intrinsic alone. We expand them in UpgradeIntrinsicCall.
//
// TODO: We could add lohi.i2d.
bool Expand = StringSwitch<bool>(Name)
.Cases("abs.i", "abs.ll", true)
.Cases("clz.ll", "popc.ll", "h2f", true)
.Cases("max.i", "max.ll", "max.ui", "max.ull", true)
.Cases("min.i", "min.ll", "min.ui", "min.ull", true)
.StartsWith("atomic.load.add.f32.p", true)
.StartsWith("atomic.load.add.f64.p", true)
.Default(false);
bool Expand = false;
if (Name.consume_front("abs."))
// nvvm.abs.{i,ll}
Expand = Name == "i" || Name == "ll";
else if (Name == "clz.ll" || Name == "popc.ll" || Name == "h2f")
Expand = true;
else if (Name.consume_front("max.") || Name.consume_front("min."))
// nvvm.{min,max}.{i,ll,ui,ull}
Expand = Name == "i" || Name == "ll" || Name == "ui" || Name == "ull";
else if (Name.consume_front("atomic.load.add."))
// nvvm.atomic.load.add.{f32.p,f64.p}
Expand = Name.startswith("f32.p") || Name.startswith("f64.p");
else
Expand = false;

if (Expand) {
NewFn = nullptr;
return true;
}
break; // No other 'nvvm.*'.
}
break;
}
Expand Down

0 comments on commit b045c36

Please sign in to comment.