Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions clang/lib/DPCT/Asm/AsmNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
bool isInt() const { return isSigned() || isUnsigned(); }
bool isFloat() const { return isOneOf(f16, f32, f64); }
bool isScalar() const { return isInt() || isFloat(); }
bool isVector() const { return isOneOf(f16x2, bf16x2, s16x2, u16x2); }
unsigned getWidth() const;

static bool classof(const InlineAsmType *T) {
Expand Down
50 changes: 49 additions & 1 deletion clang/lib/DPCT/AsmMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,48 @@ class SYCLGenBase {
return SYCLGenSuccess();
}

bool needBitCast(const InlineAsmType *From, const InlineAsmType *To) {
if (From == To)
return false;
if (const auto *BIFrom = dyn_cast<InlineAsmBuiltinType>(From),
*BITo = dyn_cast<InlineAsmBuiltinType>(To);
BIFrom->isScalar() && BITo->isScalar())
return false;
return true;
}

bool emitBitCast(const InlineAsmType *From, const InlineAsmType *To,
std::string &Val) {
assert(needBitCast(From, To) && "Bit cast is unnecessary");
std::string Buffer;
llvm::raw_string_ostream TmpOS(Buffer);
llvm::SaveAndRestore<llvm::raw_ostream *> OutStream(Stream);
switchOutStream(TmpOS);
std::string FromT, ToT;
if (tryEmitType(FromT, From))
return SYCLGenError();
if (tryEmitType(ToT, To))
return SYCLGenError();
auto isVecTy = [&](const InlineAsmType *Ty) {
if (isa<InlineAsmVectorType>(Ty))
return true;
const auto *BI = dyn_cast<InlineAsmBuiltinType>(Ty);
return BI && BI->isVector();
};
if (isVecTy(From))
OS() << Val;
else
OS() << MapNames::getClNamespace() << "vec<" << FromT << ", 1>(" << Val
<< ')';
OS() << ".template as<";
if (isVecTy(To))
OS() << ToT << ">()";
else
OS() << MapNames::getClNamespace() << "vec<" << ToT << ", 1>>().x()";
Val = std::move(Buffer);
return SYCLGenSuccess();
}

// Types
bool emitType(const InlineAsmType *T);
bool emitBuiltinType(const InlineAsmBuiltinType *T);
Expand Down Expand Up @@ -1496,13 +1538,19 @@ class SYCLGen : public SYCLGenBase {
std::string Op;
if (tryEmitStmt(Op, Inst->getInputOperand(0)))
return SYCLGenError();

if (needBitCast(Inst->getInputOperand(0)->getType(), Inst->getType(0)) &&
emitBitCast(Inst->getInputOperand(0)->getType(), Inst->getType(0), Op))
return SYCLGenError();
std::string ReplaceString = MapNames::getClNamespace() + MathFn.str() + '(';
if (Inst->getOpcode() == asmtok::op_ex2)
ReplaceString += "2, ";
ReplaceString += Op + ")";
if (Inst->hasAttr(InstAttr::rn, InstAttr::rz, InstAttr::rm, InstAttr::rp))
report(Diagnostics::ROUNDING_MODE_UNSUPPORTED, true);
if (needBitCast(Inst->getType(0), Inst->getOutputOperand()->getType()) &&
emitBitCast(Inst->getType(0), Inst->getOutputOperand()->getType(),
ReplaceString))
return SYCLGenError();
OS() << ReplaceString;
endstmt();
return SYCLGenSuccess();
Expand Down
6 changes: 6 additions & 0 deletions clang/test/dpct/asm/tanh.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,10 @@ __global__ void tanh() {
asm("tanh.approx.f32 %0, %1;" : "=f"(f32) : "f"(1.0f));
}

__global__ void f(unsigned *out) {
unsigned const in = 0;
// CHECK: *out = sycl::tanh(sycl::vec<uint32_t, 1>(in).template as<sycl::half2>()).template as<sycl::vec<uint32_t, 1>>().x();
asm volatile("tanh.approx.f16x2 %0, %1;" : "=r"(*out) : "r"(in));
}

// clang-format on