120 changes: 0 additions & 120 deletions libcxx/test/std/atomics/atomics.lockfree/isalwayslockfree.pass.cpp

This file was deleted.

34 changes: 30 additions & 4 deletions libcxx/test/std/atomics/atomics.ref/is_always_lock_free.pass.cpp
@@ -9,7 +9,10 @@
// UNSUPPORTED: c++03, c++11, c++14, c++17

// <atomic>

//
// template <class T>
// class atomic_ref;
//
// static constexpr bool is_always_lock_free;
// bool is_lock_free() const noexcept;

@@ -18,10 +21,29 @@
#include <concepts>

#include "test_macros.h"
#include "atomic_helpers.h"

template <typename T>
void check_always_lock_free(std::atomic_ref<T> const a) {
std::same_as<const bool> decltype(auto) is_always_lock_free = std::atomic_ref<T>::is_always_lock_free;
void check_always_lock_free(std::atomic_ref<T> const& a) {
using InfoT = LockFreeStatusInfo<T>;

constexpr std::same_as<const bool> decltype(auto) is_always_lock_free = std::atomic_ref<T>::is_always_lock_free;

// If we know the status of T for sure, validate the exact result of the function.
if constexpr (InfoT::status_known) {
constexpr LockFreeStatus known_status = InfoT::value;
if constexpr (known_status == LockFreeStatus::always) {
static_assert(is_always_lock_free, "is_always_lock_free is inconsistent with known lock-free status");
assert(a.is_lock_free() && "is_lock_free() is inconsistent with known lock-free status");
} else if constexpr (known_status == LockFreeStatus::never) {
static_assert(!is_always_lock_free, "is_always_lock_free is inconsistent with known lock-free status");
assert(!a.is_lock_free() && "is_lock_free() is inconsistent with known lock-free status");
} else {
assert(a.is_lock_free() || !a.is_lock_free()); // This is kinda dumb, but we might as well call the function once.
}
}

// In all cases, also sanity-check it based on the implication always-lock-free => lock-free.
if (is_always_lock_free) {
std::same_as<bool> decltype(auto) is_lock_free = a.is_lock_free();
assert(is_lock_free);
@@ -33,10 +55,14 @@ void check_always_lock_free(std::atomic_ref<T> const a) {
do { \
typedef T type; \
type obj{}; \
check_always_lock_free(std::atomic_ref<type>(obj)); \
std::atomic_ref<type> a(obj); \
check_always_lock_free(a); \
} while (0)

void test() {
char c = 'x';
check_always_lock_free(std::atomic_ref<char>(c));

int i = 0;
check_always_lock_free(std::atomic_ref<int>(i));

@@ -12,10 +12,7 @@

// iter_type put(iter_type s, ios_base& iob, char_type fill, double v) const;

// With the Microsoft UCRT, printf("%a", 0.0) produces "0x0.0000000000000p+0"
// while other C runtimes produce just "0x0p+0".
// https://developercommunity.visualstudio.com/t/Printf-formatting-of-float-as-hex-prints/1660844
// XFAIL: msvc
// XFAIL: win32-broken-printf-a-precision

// XFAIL: LIBCXX-AIX-FIXME

@@ -12,10 +12,7 @@

// iter_type put(iter_type s, ios_base& iob, char_type fill, long double v) const;

// With the Microsoft UCRT, printf("%a", 0.0) produces "0x0.0000000000000p+0"
// while other C runtimes produce just "0x0p+0".
// https://developercommunity.visualstudio.com/t/Printf-formatting-of-float-as-hex-prints/1660844
// XFAIL: msvc
// XFAIL: win32-broken-printf-a-precision

// XFAIL: LIBCXX-AIX-FIXME

103 changes: 103 additions & 0 deletions libcxx/test/support/atomic_helpers.h
@@ -11,9 +11,112 @@

#include <cassert>
#include <cstdint>
#include <cstddef>
#include <type_traits>

#include "test_macros.h"

#if defined(TEST_COMPILER_CLANG)
# define TEST_ATOMIC_CHAR_LOCK_FREE __CLANG_ATOMIC_CHAR_LOCK_FREE
# define TEST_ATOMIC_SHORT_LOCK_FREE __CLANG_ATOMIC_SHORT_LOCK_FREE
# define TEST_ATOMIC_INT_LOCK_FREE __CLANG_ATOMIC_INT_LOCK_FREE
# define TEST_ATOMIC_LONG_LOCK_FREE __CLANG_ATOMIC_LONG_LOCK_FREE
# define TEST_ATOMIC_LLONG_LOCK_FREE __CLANG_ATOMIC_LLONG_LOCK_FREE
# define TEST_ATOMIC_POINTER_LOCK_FREE __CLANG_ATOMIC_POINTER_LOCK_FREE
#elif defined(TEST_COMPILER_GCC)
# define TEST_ATOMIC_CHAR_LOCK_FREE __GCC_ATOMIC_CHAR_LOCK_FREE
# define TEST_ATOMIC_SHORT_LOCK_FREE __GCC_ATOMIC_SHORT_LOCK_FREE
# define TEST_ATOMIC_INT_LOCK_FREE __GCC_ATOMIC_INT_LOCK_FREE
# define TEST_ATOMIC_LONG_LOCK_FREE __GCC_ATOMIC_LONG_LOCK_FREE
# define TEST_ATOMIC_LLONG_LOCK_FREE __GCC_ATOMIC_LLONG_LOCK_FREE
# define TEST_ATOMIC_POINTER_LOCK_FREE __GCC_ATOMIC_POINTER_LOCK_FREE
#elif defined(TEST_COMPILER_MSVC)
// This is lifted from STL/stl/inc/atomic on GitHub for the purposes of
// keeping the tests compiling for MSVC's STL. It's not a perfect solution
// but at least the tests will keep running.
//
// Note MSVC's STL never produces a type that is sometimes lock-free, but not always lock-free.
template <class T, size_t Size = sizeof(T)>
constexpr int msvc_is_lock_free_macro_value() {
  return (Size <= 8 && (Size & (Size - 1)) == 0) ? 2 : 0;
}
# define TEST_ATOMIC_CHAR_LOCK_FREE ::msvc_is_lock_free_macro_value<char>()
# define TEST_ATOMIC_SHORT_LOCK_FREE ::msvc_is_lock_free_macro_value<short>()
# define TEST_ATOMIC_INT_LOCK_FREE ::msvc_is_lock_free_macro_value<int>()
# define TEST_ATOMIC_LONG_LOCK_FREE ::msvc_is_lock_free_macro_value<long>()
# define TEST_ATOMIC_LLONG_LOCK_FREE ::msvc_is_lock_free_macro_value<long long>()
# define TEST_ATOMIC_POINTER_LOCK_FREE ::msvc_is_lock_free_macro_value<void*>()
#else
# error "Unknown compiler"
#endif
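
For orientation: these TEST_ATOMIC_*_LOCK_FREE macros follow the 0/1/2 convention of C11's ATOMIC_*_LOCK_FREE (0 = never lock-free, 1 = sometimes, 2 = always). A minimal sanity sketch — not part of the patch, and it assumes C++11 or later on a mainstream 64-bit target:

#if TEST_STD_VER >= 11
static_assert(TEST_ATOMIC_INT_LOCK_FREE == 2, "sketch: int is expected to be always lock-free");
static_assert(TEST_ATOMIC_POINTER_LOCK_FREE == 2, "sketch: void* is expected to be always lock-free");
#endif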

#ifdef TEST_COMPILER_CLANG
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wc++11-extensions"
#endif

enum class LockFreeStatus : int { unknown = -1, never = 0, sometimes = 1, always = 2 };

// We should really be checking whether the alignment of T is greater-than-or-equal-to the alignment required
// for T to be atomic, but this is basically impossible to implement portably. Instead, we assume that any type
// aligned to at least its size is going to be atomic if atomic operations exist for that size at all,
// which is true on most platforms. This technically reduces our test coverage in the sense that if a type has
// an alignment requirement less than its size but could still be made lock-free, LockFreeStatusInfo will report
// that we don't know whether it is lock-free or not.
#define COMPARE_TYPES(T, FundamentalT) (sizeof(T) == sizeof(FundamentalT) && TEST_ALIGNOF(T) >= sizeof(T))

template <class T>
struct LockFreeStatusInfo {
static const LockFreeStatus value = LockFreeStatus(
COMPARE_TYPES(T, char)
? TEST_ATOMIC_CHAR_LOCK_FREE
: (COMPARE_TYPES(T, short)
? TEST_ATOMIC_SHORT_LOCK_FREE
: (COMPARE_TYPES(T, int)
? TEST_ATOMIC_INT_LOCK_FREE
: (COMPARE_TYPES(T, long)
? TEST_ATOMIC_LONG_LOCK_FREE
: (COMPARE_TYPES(T, long long)
? TEST_ATOMIC_LLONG_LOCK_FREE
: (COMPARE_TYPES(T, void*) ? TEST_ATOMIC_POINTER_LOCK_FREE : -1))))));

static const bool status_known = LockFreeStatusInfo::value != LockFreeStatus::unknown;
};

#undef COMPARE_TYPES

// This doesn't work in C++03 due to issues with scoped enumerations. Just disable the test.
#if TEST_STD_VER >= 11
static_assert(LockFreeStatusInfo<char>::status_known, "");
static_assert(LockFreeStatusInfo<short>::status_known, "");
static_assert(LockFreeStatusInfo<int>::status_known, "");
static_assert(LockFreeStatusInfo<long>::status_known, "");
static_assert(LockFreeStatusInfo<void*>::status_known, "");

// long long is a bit funky: on some platforms, its alignment is 4 bytes but its size is
// 8 bytes. In that case, atomics may or may not be lockfree based on their address.
static_assert(alignof(long long) == sizeof(long long) ? LockFreeStatusInfo<long long>::status_known : true, "");

// Those should always be lock free: hardcode some expected values to make sure our tests are actually
// testing something meaningful.
static_assert(LockFreeStatusInfo<char>::value == LockFreeStatus::always, "");
static_assert(LockFreeStatusInfo<short>::value == LockFreeStatus::always, "");
static_assert(LockFreeStatusInfo<int>::value == LockFreeStatus::always, "");
#endif
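
The heuristic above also classifies user-defined types; an illustrative sketch — not part of the header, assuming a typical 64-bit ABI where the preceding asserts hold:

#if TEST_STD_VER >= 11
struct WrappedInt { int v; };     // sizeof == alignof == 4, so it maps onto int's status.
struct ThreeBytes { char c[3]; }; // size 3 matches no fundamental type, so it is unknown.
static_assert(LockFreeStatusInfo<WrappedInt>::status_known, "");
static_assert(LockFreeStatusInfo<WrappedInt>::value == LockFreeStatus::always, "");
static_assert(!LockFreeStatusInfo<ThreeBytes>::status_known, "");
#endif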

// These macros are somewhat surprising to use, since they take the values 0, 1, or 2.
// To make the tests clearer, get rid of them in favor of LockFreeStatusInfo.
#undef TEST_ATOMIC_CHAR_LOCK_FREE
#undef TEST_ATOMIC_SHORT_LOCK_FREE
#undef TEST_ATOMIC_INT_LOCK_FREE
#undef TEST_ATOMIC_LONG_LOCK_FREE
#undef TEST_ATOMIC_LLONG_LOCK_FREE
#undef TEST_ATOMIC_POINTER_LOCK_FREE

#ifdef TEST_COMPILER_CLANG
# pragma clang diagnostic pop
#endif

struct UserAtomicType {
int i;

20 changes: 20 additions & 0 deletions libcxx/utils/libcxx/test/features.py
@@ -263,6 +263,26 @@ def _mingwSupportsModules(cfg):
""",
),
),
# Check for a Windows UCRT bug (not fixed upstream yet).
# With UCRT, printf("%a", 0.0) produces "0x0.0000000000000p+0",
# while other C runtimes produce just "0x0p+0".
# https://developercommunity.visualstudio.com/t/Printf-formatting-of-float-as-hex-prints/1660844
Feature(
name="win32-broken-printf-a-precision",
when=lambda cfg: "_WIN32" in compilerMacros(cfg)
and not programSucceeds(
cfg,
"""
#include <stdio.h>
#include <string.h>
int main(int, char**) {
char buf[100];
snprintf(buf, sizeof(buf), "%a", 0.0);
return strcmp(buf, "0x0p+0");
}
""",
),
),
# Check for Glibc < 2.27, where the ru_RU.UTF-8 locale had
# mon_decimal_point == ".", which our tests don't handle.
Feature(
2 changes: 2 additions & 0 deletions libcxxabi/test/test_demangle.pass.cpp
@@ -17,6 +17,8 @@
// 80-bit format, and this demangling test is failing on it.
// XFAIL: LIBCXX-ANDROID-FIXME && target={{i686|x86_64}}-{{.+}}-android{{.*}}

// XFAIL: win32-broken-printf-a-precision

#include "support/timer.h"
#include <algorithm>
#include <cassert>
10 changes: 3 additions & 7 deletions lldb/packages/Python/lldbsuite/test/lldbplatformutil.py
@@ -266,17 +266,13 @@ def getCompiler():
return module.getCompiler()


def getCompilerBinary():
"""Returns the compiler binary the test suite is running with."""
return getCompiler().split()[0]


def getCompilerVersion():
"""Returns a string that represents the compiler version.
Supports: llvm, clang.
"""
compiler = getCompilerBinary()
version_output = subprocess.check_output([compiler, "--version"], errors="replace")
version_output = subprocess.check_output(
[getCompiler(), "--version"], errors="replace"
)
m = re.search("version ([0-9.]+)", version_output)
if m:
return m.group(1)
4 changes: 0 additions & 4 deletions lldb/packages/Python/lldbsuite/test/lldbtest.py
@@ -1379,10 +1379,6 @@ def getCompiler(self):
"""Returns the compiler in effect the test suite is running with."""
return lldbplatformutil.getCompiler()

def getCompilerBinary(self):
"""Returns the compiler binary the test suite is running with."""
return lldbplatformutil.getCompilerBinary()

def getCompilerVersion(self):
"""Returns a string that represents the compiler version.
Supports: llvm, clang.
7 changes: 0 additions & 7 deletions llvm/include/llvm/Analysis/Loads.h
@@ -104,13 +104,6 @@ bool isSafeToLoadUnconditionally(Value *V, Type *Ty, Align Alignment,
const DominatorTree *DT = nullptr,
const TargetLibraryInfo *TLI = nullptr);

/// Return true if speculation of the given load must be suppressed to avoid
/// ordering or interfering with an active sanitizer. If not suppressed,
/// dereferenceability and alignment must be proven separately. Note: This
/// is only needed for raw reasoning; if you use the interface below
/// (isSafeToSpeculativelyExecute), this is handled internally.
bool mustSuppressSpeculation(const LoadInst &LI);

/// The default number of maximum instructions to scan in the block, used by
/// FindAvailableLoadedValue().
extern cl::opt<unsigned> DefMaxInstsToScan;
7 changes: 7 additions & 0 deletions llvm/include/llvm/Analysis/ValueTracking.h
@@ -792,6 +792,13 @@ bool onlyUsedByLifetimeMarkers(const Value *V);
/// droppable instructions.
bool onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V);

/// Return true if speculation of the given load must be suppressed to avoid
/// ordering or interfering with an active sanitizer. If not suppressed,
/// dereferenceability and alignment must be proven separately. Note: This
/// is only needed for raw reasoning; if you use the interface below
/// (isSafeToSpeculativelyExecute), this is handled internally.
bool mustSuppressSpeculation(const LoadInst &LI);

/// Return true if the instruction does not have any effects besides
/// calculating the result and does not have undefined behavior.
///
221 changes: 210 additions & 11 deletions llvm/include/llvm/SandboxIR/SandboxIR.h
@@ -18,13 +18,19 @@
//
// namespace sandboxir {
//
//        +- Argument                    +- BinaryOperator
//        |                              |
// Value -+- BasicBlock                  +- BranchInst
//        |                              |
//        +- Function    +- Constant     +- CastInst
//        |              |               |
//        +- User ------+- Instruction -+- CallInst
// Value -+- Argument
//        |
//        +- BasicBlock
//        |
//        +- User ------+- Constant ------ Function
//                      |
//                      +- Instruction -+- BinaryOperator
//                                      |
//                                      +- BranchInst
//                                      |
//                                      +- CastInst
//                                      |
//                                      +- CallBase ----- CallInst
//                                      |
//                                      +- CmpInst
//                                      |
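
One visible consequence of rerooting Function under Constant, sketched in the gtest idiom of the unit tests below (F is a sandboxir::Function *):

EXPECT_TRUE(isa<sandboxir::User>(F));     // flips from EXPECT_FALSE in SandboxIRTest.cpp below
EXPECT_TRUE(isa<sandboxir::Constant>(F)); // a Function is now a Constant, and hence a User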
@@ -82,6 +88,8 @@ class ReturnInst;
class StoreInst;
class User;
class Value;
class CallBase;
class CallInst;

/// Iterator for the `Use` edges of a User's operands.
/// \Returns the operand `Use` when dereferenced.
@@ -103,12 +111,20 @@ class OperandUseIterator {
OperandUseIterator() = default;
value_type operator*() const;
OperandUseIterator &operator++();
OperandUseIterator operator++(int) {
auto Copy = *this;
this->operator++();
return Copy;
}
bool operator==(const OperandUseIterator &Other) const {
return Use == Other.Use;
}
bool operator!=(const OperandUseIterator &Other) const {
return !(*this == Other);
}
OperandUseIterator operator+(unsigned Num) const;
OperandUseIterator operator-(unsigned Num) const;
int operator-(const OperandUseIterator &Other) const;
};

/// Iterator for the `Use` edges of a Value's users.
@@ -135,6 +151,7 @@ class UserUseIterator {
bool operator!=(const UserUseIterator &Other) const {
return !(*this == Other);
}
const sandboxir::Use &getUse() const { return Use; }
};

/// A SandboxIR Value has users. This is the base class.
@@ -184,6 +201,8 @@ class Value {
friend class LoadInst; // For getting `Val`.
friend class StoreInst; // For getting `Val`.
friend class ReturnInst; // For getting `Val`.
friend class CallBase; // For getting `Val`.
friend class CallInst; // For getting `Val`.

/// All values point to the context.
Context &Ctx;
@@ -417,7 +436,10 @@ class User : public Value {
class Constant : public sandboxir::User {
Constant(llvm::Constant *C, sandboxir::Context &SBCtx)
: sandboxir::User(ClassID::Constant, C, SBCtx) {}
friend class Context; // For constructor.
Constant(ClassID ID, llvm::Constant *C, sandboxir::Context &SBCtx)
: sandboxir::User(ID, C, SBCtx) {}
friend class Function; // For constructor
friend class Context; // For constructor.
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
@@ -435,7 +457,7 @@ class Constant : public sandboxir::User {
return getUseOperandNoDefault(Use);
}
#ifndef NDEBUG
void verify() const final {
void verify() const override {
assert(isa<llvm::Constant>(Val) && "Expected Constant!");
}
friend raw_ostream &operator<<(raw_ostream &OS,
@@ -518,6 +540,7 @@ class Instruction : public sandboxir::User {
friend class LoadInst; // For getTopmostLLVMInstruction().
friend class StoreInst; // For getTopmostLLVMInstruction().
friend class ReturnInst; // For getTopmostLLVMInstruction().
friend class CallInst; // For getTopmostLLVMInstruction().

/// \Returns the LLVM IR Instructions that this SandboxIR maps to in program
/// order.
@@ -835,6 +858,177 @@ class ReturnInst final : public Instruction {
#endif
};

class CallBase : public Instruction {
CallBase(ClassID ID, Opcode Opc, llvm::Instruction *I, Context &Ctx)
: Instruction(ID, Opc, I, Ctx) {}
friend class CallInst; // For constructor.

public:
static bool classof(const Value *From) {
auto Opc = From->getSubclassID();
return Opc == Instruction::ClassID::Call ||
Opc == Instruction::ClassID::Invoke ||
Opc == Instruction::ClassID::CallBr;
}

FunctionType *getFunctionType() const {
return cast<llvm::CallBase>(Val)->getFunctionType();
}

op_iterator data_operands_begin() { return op_begin(); }
const_op_iterator data_operands_begin() const {
return const_cast<CallBase *>(this)->data_operands_begin();
}
op_iterator data_operands_end() {
auto *LLVMCB = cast<llvm::CallBase>(Val);
auto Dist = LLVMCB->data_operands_end() - LLVMCB->data_operands_begin();
return op_begin() + Dist;
}
const_op_iterator data_operands_end() const {
auto *LLVMCB = cast<llvm::CallBase>(Val);
auto Dist = LLVMCB->data_operands_end() - LLVMCB->data_operands_begin();
return op_begin() + Dist;
}
iterator_range<op_iterator> data_ops() {
return make_range(data_operands_begin(), data_operands_end());
}
iterator_range<const_op_iterator> data_ops() const {
return make_range(data_operands_begin(), data_operands_end());
}
bool data_operands_empty() const {
return data_operands_end() == data_operands_begin();
}
unsigned data_operands_size() const {
return std::distance(data_operands_begin(), data_operands_end());
}
bool isDataOperand(Use U) const {
assert(this == U.getUser() &&
"Only valid to query with a use of this instruction!");
return cast<llvm::CallBase>(Val)->isDataOperand(U.LLVMUse);
}
unsigned getDataOperandNo(Use U) const {
assert(isDataOperand(U) && "Data operand # out of range!");
return cast<llvm::CallBase>(Val)->getDataOperandNo(U.LLVMUse);
}

/// Return the total number of operands (not operand bundles) used by
/// every operand bundle in this OperandBundleUser.
unsigned getNumTotalBundleOperands() const {
return cast<llvm::CallBase>(Val)->getNumTotalBundleOperands();
}

op_iterator arg_begin() { return op_begin(); }
const_op_iterator arg_begin() const { return op_begin(); }
op_iterator arg_end() {
return data_operands_end() - getNumTotalBundleOperands();
}
const_op_iterator arg_end() const {
return const_cast<CallBase *>(this)->arg_end();
}
iterator_range<op_iterator> args() {
return make_range(arg_begin(), arg_end());
}
iterator_range<const_op_iterator> args() const {
return make_range(arg_begin(), arg_end());
}
bool arg_empty() const { return arg_end() == arg_begin(); }
unsigned arg_size() const { return arg_end() - arg_begin(); }

Value *getArgOperand(unsigned OpIdx) const {
assert(OpIdx < arg_size() && "Out of bounds!");
return getOperand(OpIdx);
}
void setArgOperand(unsigned OpIdx, Value *NewOp) {
assert(OpIdx < arg_size() && "Out of bounds!");
setOperand(OpIdx, NewOp);
}

Use getArgOperandUse(unsigned Idx) const {
assert(Idx < arg_size() && "Out of bounds!");
return getOperandUse(Idx);
}
Use getArgOperandUse(unsigned Idx) {
assert(Idx < arg_size() && "Out of bounds!");
return getOperandUse(Idx);
}

bool isArgOperand(Use U) const {
return cast<llvm::CallBase>(Val)->isArgOperand(U.LLVMUse);
}
unsigned getArgOperandNo(Use U) const {
return cast<llvm::CallBase>(Val)->getArgOperandNo(U.LLVMUse);
}
bool hasArgument(const Value *V) const { return is_contained(args(), V); }

Value *getCalledOperand() const;
Use getCalledOperandUse() const;

Function *getCalledFunction() const;
bool isIndirectCall() const {
return cast<llvm::CallBase>(Val)->isIndirectCall();
}
bool isCallee(Use U) const {
return cast<llvm::CallBase>(Val)->isCallee(U.LLVMUse);
}
Function *getCaller();
const Function *getCaller() const {
return const_cast<CallBase *>(this)->getCaller();
}
bool isMustTailCall() const {
return cast<llvm::CallBase>(Val)->isMustTailCall();
}
bool isTailCall() const { return cast<llvm::CallBase>(Val)->isTailCall(); }
Intrinsic::ID getIntrinsicID() const {
return cast<llvm::CallBase>(Val)->getIntrinsicID();
}
void setCalledOperand(Value *V) { getCalledOperandUse().set(V); }
void setCalledFunction(Function *F);
CallingConv::ID getCallingConv() const {
return cast<llvm::CallBase>(Val)->getCallingConv();
}
bool isInlineAsm() const { return cast<llvm::CallBase>(Val)->isInlineAsm(); }
};
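
As in llvm::CallBase, the data operands are the call arguments followed by any operand-bundle inputs, with the callee excluded, so the following invariant holds for any sandboxir::CallBase *Call (sketch only):

assert(Call->data_operands_size() ==
       Call->arg_size() + Call->getNumTotalBundleOperands());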

class CallInst final : public CallBase {
/// Use Context::createCallInst(). Don't call the
/// constructor directly.
CallInst(llvm::Instruction *I, Context &Ctx)
: CallBase(ClassID::Call, Opcode::Call, I, Ctx) {}
friend class Context; // For accessing the constructor in
// create*()
Use getOperandUseInternal(unsigned OpIdx, bool Verify) const final {
return getOperandUseDefault(OpIdx, Verify);
}
SmallVector<llvm::Instruction *, 1> getLLVMInstrs() const final {
return {cast<llvm::Instruction>(Val)};
}

public:
static CallInst *create(FunctionType *FTy, Value *Func,
ArrayRef<Value *> Args, BBIterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &NameStr = "");
static CallInst *create(FunctionType *FTy, Value *Func,
ArrayRef<Value *> Args, Instruction *InsertBefore,
Context &Ctx, const Twine &NameStr = "");
static CallInst *create(FunctionType *FTy, Value *Func,
ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &NameStr = "");

static bool classof(const Value *From) {
return From->getSubclassID() == ClassID::Call;
}
unsigned getUseOperandNo(const Use &Use) const final {
return getUseOperandNoDefault(Use);
}
unsigned getNumOfIRInstrs() const final { return 1u; }
#ifndef NDEBUG
void verify() const final {}
void dump(raw_ostream &OS) const override;
LLVM_DUMP_METHOD void dump() const override;
#endif
};
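
A minimal usage sketch of the new create() API (names are illustrative, mirroring the unit tests further down; assumes sandboxir::Function *F, sandboxir::BasicBlock *BB, sandboxir::Value *Arg and a sandboxir::Context Ctx):

sandboxir::CallInst *CI = sandboxir::CallInst::create(
    F->getFunctionType(), F, /*Args=*/{Arg}, /*InsertAtEnd=*/BB, Ctx, "c");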

/// An LLVM Instruction that has no SandboxIR equivalent class gets mapped to
/// an OpaqueInst.
class OpaqueInst : public sandboxir::Instruction {
@@ -983,6 +1177,8 @@ class Context {
friend StoreInst; // For createStoreInst()
ReturnInst *createReturnInst(llvm::ReturnInst *I);
friend ReturnInst; // For createReturnInst()
CallInst *createCallInst(llvm::CallInst *I);
friend CallInst; // For createCallInst()

public:
Context(LLVMContext &LLVMCtx)
@@ -1010,7 +1206,7 @@ class Context {
size_t getNumValues() const { return LLVMValueToValueMap.size(); }
};

class Function : public sandboxir::Value {
class Function : public Constant {
/// Helper for mapped_iterator.
struct LLVMBBToBB {
Context &Ctx;
@@ -1021,7 +1217,7 @@ class Function : public sandboxir::Value {
};
/// Use Context::createFunction() instead.
Function(llvm::Function *F, sandboxir::Context &Ctx)
: sandboxir::Value(ClassID::Function, F, Ctx) {}
: Constant(ClassID::Function, F, Ctx) {}
friend class Context; // For constructor.

public:
@@ -1047,6 +1243,9 @@ class Function : public sandboxir::Value {
LLVMBBToBB BBGetter(Ctx);
return iterator(cast<llvm::Function>(Val)->end(), BBGetter);
}
FunctionType *getFunctionType() const {
return cast<llvm::Function>(Val)->getFunctionType();
}

#ifndef NDEBUG
void verify() const final {
3 changes: 3 additions & 0 deletions llvm/include/llvm/SandboxIR/SandboxIRValues.def
@@ -30,6 +30,9 @@ DEF_INSTR(Br, OP(Br), BranchInst)
DEF_INSTR(Load, OP(Load), LoadInst)
DEF_INSTR(Store, OP(Store), StoreInst)
DEF_INSTR(Ret, OP(Ret), ReturnInst)
DEF_INSTR(Call, OP(Call), CallInst)
DEF_INSTR(Invoke, OP(Invoke), InvokeInst)
DEF_INSTR(CallBr, OP(CallBr), CallBrInst)

#ifdef DEF_VALUE
#undef DEF_VALUE
2 changes: 2 additions & 0 deletions llvm/include/llvm/SandboxIR/Use.h
@@ -21,6 +21,7 @@ namespace llvm::sandboxir {
class Context;
class Value;
class User;
class CallBase;

/// Represents a Def-use/Use-def edge in SandboxIR.
/// NOTE: Unlike llvm::Use, this is not an integral part of the use-def chains.
@@ -40,6 +41,7 @@ class Use {
friend class User; // For constructor
friend class OperandUseIterator; // For constructor
friend class UserUseIterator; // For accessing members
friend class CallBase; // For LLVMUse

public:
operator Value *() const { return get(); }
15 changes: 0 additions & 15 deletions llvm/lib/Analysis/Loads.cpp
@@ -345,21 +345,6 @@ bool llvm::isDereferenceableAndAlignedInLoop(LoadInst *LI, Loop *L,
HeaderFirstNonPHI, AC, &DT);
}

static bool suppressSpeculativeLoadForSanitizers(const Function &F) {
// Speculative load may create a race that did not exist in the source.
return F.hasFnAttribute(Attribute::SanitizeThread) ||
// Speculative load may load data from dirty regions.
F.hasFnAttribute(Attribute::SanitizeAddress) ||
F.hasFnAttribute(Attribute::SanitizeHWAddress);
}

bool llvm::mustSuppressSpeculation(const LoadInst &LI) {
if (!LI.isUnordered())
return true;
const Function &F = *LI.getFunction();
return suppressSpeculativeLoadForSanitizers(F);
}

/// Check if executing a load of this pointer value cannot trap.
///
/// If DT and ScanFrom are specified this method performs context-sensitive
11 changes: 11 additions & 0 deletions llvm/lib/Analysis/ValueTracking.cpp
@@ -6798,6 +6798,17 @@ bool llvm::onlyUsedByLifetimeMarkersOrDroppableInsts(const Value *V) {
V, /* AllowLifetime */ true, /* AllowDroppable */ true);
}

bool llvm::mustSuppressSpeculation(const LoadInst &LI) {
if (!LI.isUnordered())
return true;
const Function &F = *LI.getFunction();
// Speculative load may create a race that did not exist in the source.
return F.hasFnAttribute(Attribute::SanitizeThread) ||
// Speculative load may load data from dirty regions.
F.hasFnAttribute(Attribute::SanitizeAddress) ||
F.hasFnAttribute(Attribute::SanitizeHWAddress);
}

bool llvm::isSafeToSpeculativelyExecute(const Instruction *Inst,
const Instruction *CtxI,
AssumptionCache *AC,
115 changes: 112 additions & 3 deletions llvm/lib/SandboxIR/SandboxIR.cpp
@@ -16,7 +16,12 @@ using namespace llvm::sandboxir;

Value *Use::get() const { return Ctx->getValue(LLVMUse->get()); }

void Use::set(Value *V) { LLVMUse->set(V->Val); }
void Use::set(Value *V) {
auto &Tracker = Ctx->getTracker();
if (Tracker.isTracking())
Tracker.track(std::make_unique<UseSet>(*this, Tracker));
LLVMUse->set(V->Val);
}
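
With tracking enabled, operand rewrites through Use::set() are now revertible; a minimal sketch of the intended flow (same shape as TrackerTest.SetUse below; assumes a sandboxir::Context Ctx, a Use U and a Value *NewV):

Ctx.save();
U.set(NewV);   // records a UseSet change before mutating the underlying LLVM use
Ctx.revert();  // restores the original operand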

unsigned Use::getOperandNo() const { return Usr->getUseOperandNo(*this); }

@@ -84,6 +89,25 @@ UserUseIterator &UserUseIterator::operator++() {
return *this;
}

OperandUseIterator OperandUseIterator::operator+(unsigned Num) const {
sandboxir::Use U = Use.getUser()->getOperandUseInternal(
Use.getOperandNo() + Num, /*Verify=*/true);
return OperandUseIterator(U);
}

OperandUseIterator OperandUseIterator::operator-(unsigned Num) const {
assert(Use.getOperandNo() >= Num && "Out of bounds!");
sandboxir::Use U = Use.getUser()->getOperandUseInternal(
Use.getOperandNo() - Num, /*Verify=*/true);
return OperandUseIterator(U);
}

int OperandUseIterator::operator-(const OperandUseIterator &Other) const {
int ThisOpNo = Use.getOperandNo();
int OtherOpNo = Other.Use.getOperandNo();
return ThisOpNo - OtherOpNo;
}

Value::Value(ClassID SubclassID, llvm::Value *Val, Context &Ctx)
: SubclassID(SubclassID), Val(Val), Ctx(Ctx) {
#ifndef NDEBUG
@@ -713,6 +737,78 @@ void ReturnInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}
#endif // NDEBUG

Value *CallBase::getCalledOperand() const {
return Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledOperand());
}

Use CallBase::getCalledOperandUse() const {
llvm::Use *LLVMUse = &cast<llvm::CallBase>(Val)->getCalledOperandUse();
return Use(LLVMUse, cast<User>(Ctx.getValue(LLVMUse->getUser())), Ctx);
}

Function *CallBase::getCalledFunction() const {
return cast_or_null<Function>(
Ctx.getValue(cast<llvm::CallBase>(Val)->getCalledFunction()));
}
Function *CallBase::getCaller() {
return cast<Function>(Ctx.getValue(cast<llvm::CallBase>(Val)->getCaller()));
}

void CallBase::setCalledFunction(Function *F) {
// F's function type is private, so we rely on `setCalledFunction()` to update
// it. But even though we are calling `setCalledFunction()` we also need to
// track this change at the SandboxIR level, which is why we call
// `setCalledOperand()` here.
// Note: This may break if `setCalledFunction()` ever early-returns when `F`
// is already set, but we have a unit test covering it.
setCalledOperand(F);
cast<llvm::CallBase>(Val)->setCalledFunction(F->getFunctionType(),
cast<llvm::Function>(F->Val));
}
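
Because the change is routed through setCalledOperand(), a tracked context can roll a callee swap back; a sketch, with Call and NewF as in TrackerTest.CallBaseSetters:

Ctx.save();
Call->setCalledFunction(NewF); // tracked via the setCalledOperand() call above
Ctx.revert();                  // the original callee is restored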

CallInst *CallInst::create(FunctionType *FTy, Value *Func,
ArrayRef<Value *> Args, BasicBlock::iterator WhereIt,
BasicBlock *WhereBB, Context &Ctx,
const Twine &NameStr) {
auto &Builder = Ctx.getLLVMIRBuilder();
if (WhereIt != WhereBB->end())
Builder.SetInsertPoint((*WhereIt).getTopmostLLVMInstruction());
else
Builder.SetInsertPoint(cast<llvm::BasicBlock>(WhereBB->Val));
SmallVector<llvm::Value *> LLVMArgs;
LLVMArgs.reserve(Args.size());
for (Value *Arg : Args)
LLVMArgs.push_back(Arg->Val);
llvm::CallInst *NewCI = Builder.CreateCall(FTy, Func->Val, LLVMArgs, NameStr);
return Ctx.createCallInst(NewCI);
}

CallInst *CallInst::create(FunctionType *FTy, Value *Func,
ArrayRef<Value *> Args, Instruction *InsertBefore,
Context &Ctx, const Twine &NameStr) {
return CallInst::create(FTy, Func, Args, InsertBefore->getIterator(),
InsertBefore->getParent(), Ctx, NameStr);
}

CallInst *CallInst::create(FunctionType *FTy, Value *Func,
ArrayRef<Value *> Args, BasicBlock *InsertAtEnd,
Context &Ctx, const Twine &NameStr) {
return CallInst::create(FTy, Func, Args, InsertAtEnd->end(), InsertAtEnd, Ctx,
NameStr);
}

#ifndef NDEBUG
void CallInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
dumpCommonSuffix(OS);
}

void CallInst::dump() const {
dump(dbgs());
dbgs() << "\n";
}

void OpaqueInst::dump(raw_ostream &OS) const {
dumpCommonPrefix(OS);
@@ -819,7 +915,10 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
return It->second.get();

if (auto *C = dyn_cast<llvm::Constant>(LLVMV)) {
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
if (auto *F = dyn_cast<llvm::Function>(LLVMV))
It->second = std::unique_ptr<Function>(new Function(F, *this));
else
It->second = std::unique_ptr<Constant>(new Constant(C, *this));
auto *NewC = It->second.get();
for (llvm::Value *COp : C->operands())
getOrCreateValueInternal(COp, C);
@@ -864,6 +963,11 @@ Value *Context::getOrCreateValueInternal(llvm::Value *LLVMV, llvm::User *U) {
It->second = std::unique_ptr<ReturnInst>(new ReturnInst(LLVMRet, *this));
return It->second.get();
}
case llvm::Instruction::Call: {
auto *LLVMCall = cast<llvm::CallInst>(LLVMV);
It->second = std::unique_ptr<CallInst>(new CallInst(LLVMCall, *this));
return It->second.get();
}
default:
break;
}
@@ -907,6 +1011,11 @@ ReturnInst *Context::createReturnInst(llvm::ReturnInst *I) {
return cast<ReturnInst>(registerValue(std::move(NewPtr)));
}

CallInst *Context::createCallInst(llvm::CallInst *I) {
auto NewPtr = std::unique_ptr<CallInst>(new CallInst(I, *this));
return cast<CallInst>(registerValue(std::move(NewPtr)));
}

Value *Context::getValue(llvm::Value *V) const {
auto It = LLVMValueToValueMap.find(V);
if (It != LLVMValueToValueMap.end())
@@ -917,13 +1026,13 @@ Value *Context::getValue(llvm::Value *V) const {
Function *Context::createFunction(llvm::Function *F) {
assert(getValue(F) == nullptr && "Already exists!");
auto NewFPtr = std::unique_ptr<Function>(new Function(F, *this));
auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
// Create arguments.
for (auto &Arg : F->args())
getOrCreateArgument(&Arg);
// Create BBs.
for (auto &BB : *F)
createBasicBlock(&BB);
auto *SBF = cast<Function>(registerValue(std::move(NewFPtr)));
return SBF;
}

191 changes: 191 additions & 0 deletions llvm/test/CodeGen/AMDGPU/GlobalISel/inst-select-constant.mir
@@ -406,6 +406,197 @@ body: |
S_ENDPGM 0, implicit %0 , implicit %1 , implicit %2, implicit %3, implicit %4, implicit %5, implicit %6, implicit %7
...

---
name: constant_s_p0
legalized: true
regBankSelected: true
tracksRegLiveness: true

body: |
bb.0:
; WAVE64-LABEL: name: constant_s_p0
; WAVE64: [[S_MOV_B:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 0
; WAVE64-NEXT: [[S_MOV_B1:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 1
; WAVE64-NEXT: [[S_MOV_B2:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -1
; WAVE64-NEXT: [[S_MOV_B3:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -54
; WAVE64-NEXT: [[S_MOV_B4:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 27
; WAVE64-NEXT: [[S_MOV_B5:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 4294967295
; WAVE64-NEXT: [[S_MOV_B32_:%[0-9]+]]:sreg_32 = S_MOV_B32 0
; WAVE64-NEXT: [[S_MOV_B32_1:%[0-9]+]]:sreg_32 = S_MOV_B32 1
; WAVE64-NEXT: [[REG_SEQUENCE:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_]], %subreg.sub0, [[S_MOV_B32_1]], %subreg.sub1
; WAVE64-NEXT: [[S_MOV_B32_2:%[0-9]+]]:sreg_32 = S_MOV_B32 23255
; WAVE64-NEXT: [[S_MOV_B32_3:%[0-9]+]]:sreg_32 = S_MOV_B32 -16
; WAVE64-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_2]], %subreg.sub0, [[S_MOV_B32_3]], %subreg.sub1
; WAVE64-NEXT: S_ENDPGM 0, implicit [[S_MOV_B]], implicit [[S_MOV_B1]], implicit [[S_MOV_B2]], implicit [[S_MOV_B3]], implicit [[S_MOV_B4]], implicit [[S_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
;
; WAVE32-LABEL: name: constant_s_p0
; WAVE32: [[S_MOV_B:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 0
; WAVE32-NEXT: [[S_MOV_B1:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 1
; WAVE32-NEXT: [[S_MOV_B2:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -1
; WAVE32-NEXT: [[S_MOV_B3:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -54
; WAVE32-NEXT: [[S_MOV_B4:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 27
; WAVE32-NEXT: [[S_MOV_B5:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 4294967295
; WAVE32-NEXT: [[S_MOV_B32_:%[0-9]+]]:sreg_32 = S_MOV_B32 0
; WAVE32-NEXT: [[S_MOV_B32_1:%[0-9]+]]:sreg_32 = S_MOV_B32 1
; WAVE32-NEXT: [[REG_SEQUENCE:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_]], %subreg.sub0, [[S_MOV_B32_1]], %subreg.sub1
; WAVE32-NEXT: [[S_MOV_B32_2:%[0-9]+]]:sreg_32 = S_MOV_B32 23255
; WAVE32-NEXT: [[S_MOV_B32_3:%[0-9]+]]:sreg_32 = S_MOV_B32 -16
; WAVE32-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_2]], %subreg.sub0, [[S_MOV_B32_3]], %subreg.sub1
; WAVE32-NEXT: S_ENDPGM 0, implicit [[S_MOV_B]], implicit [[S_MOV_B1]], implicit [[S_MOV_B2]], implicit [[S_MOV_B3]], implicit [[S_MOV_B4]], implicit [[S_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
%0:sgpr(p0) = G_CONSTANT i64 0
%1:sgpr(p0) = G_CONSTANT i64 1
%2:sgpr(p0) = G_CONSTANT i64 -1
%3:sgpr(p0) = G_CONSTANT i64 -54
%4:sgpr(p0) = G_CONSTANT i64 27
%5:sgpr(p0) = G_CONSTANT i64 4294967295
%6:sgpr(p0) = G_CONSTANT i64 4294967296
%7:sgpr(p0) = G_CONSTANT i64 18446744004990098135
S_ENDPGM 0, implicit %0 , implicit %1 , implicit %2, implicit %3, implicit %4, implicit %5, implicit %6, implicit %7
...

---
name: constant_v_p0
legalized: true
regBankSelected: true
tracksRegLiveness: true

body: |
bb.0:
; WAVE64-LABEL: name: constant_v_p0
; WAVE64: [[V_MOV_B:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 0, implicit $exec
; WAVE64-NEXT: [[V_MOV_B1:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 1, implicit $exec
; WAVE64-NEXT: [[V_MOV_B2:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -1, implicit $exec
; WAVE64-NEXT: [[V_MOV_B3:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -54, implicit $exec
; WAVE64-NEXT: [[V_MOV_B4:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 27, implicit $exec
; WAVE64-NEXT: [[V_MOV_B5:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 4294967295, implicit $exec
; WAVE64-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
; WAVE64-NEXT: [[V_MOV_B32_e32_1:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec
; WAVE64-NEXT: [[REG_SEQUENCE:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_]], %subreg.sub0, [[V_MOV_B32_e32_1]], %subreg.sub1
; WAVE64-NEXT: [[V_MOV_B32_e32_2:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 23255, implicit $exec
; WAVE64-NEXT: [[V_MOV_B32_e32_3:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 -16, implicit $exec
; WAVE64-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_2]], %subreg.sub0, [[V_MOV_B32_e32_3]], %subreg.sub1
; WAVE64-NEXT: S_ENDPGM 0, implicit [[V_MOV_B]], implicit [[V_MOV_B1]], implicit [[V_MOV_B2]], implicit [[V_MOV_B3]], implicit [[V_MOV_B4]], implicit [[V_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
;
; WAVE32-LABEL: name: constant_v_p0
; WAVE32: [[V_MOV_B:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 0, implicit $exec
; WAVE32-NEXT: [[V_MOV_B1:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 1, implicit $exec
; WAVE32-NEXT: [[V_MOV_B2:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -1, implicit $exec
; WAVE32-NEXT: [[V_MOV_B3:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -54, implicit $exec
; WAVE32-NEXT: [[V_MOV_B4:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 27, implicit $exec
; WAVE32-NEXT: [[V_MOV_B5:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 4294967295, implicit $exec
; WAVE32-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
; WAVE32-NEXT: [[V_MOV_B32_e32_1:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec
; WAVE32-NEXT: [[REG_SEQUENCE:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_]], %subreg.sub0, [[V_MOV_B32_e32_1]], %subreg.sub1
; WAVE32-NEXT: [[V_MOV_B32_e32_2:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 23255, implicit $exec
; WAVE32-NEXT: [[V_MOV_B32_e32_3:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 -16, implicit $exec
; WAVE32-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_2]], %subreg.sub0, [[V_MOV_B32_e32_3]], %subreg.sub1
; WAVE32-NEXT: S_ENDPGM 0, implicit [[V_MOV_B]], implicit [[V_MOV_B1]], implicit [[V_MOV_B2]], implicit [[V_MOV_B3]], implicit [[V_MOV_B4]], implicit [[V_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
%0:vgpr(p0) = G_CONSTANT i64 0
%1:vgpr(p0) = G_CONSTANT i64 1
%2:vgpr(p0) = G_CONSTANT i64 -1
%3:vgpr(p0) = G_CONSTANT i64 -54
%4:vgpr(p0) = G_CONSTANT i64 27
%5:vgpr(p0) = G_CONSTANT i64 4294967295
%6:vgpr(p0) = G_CONSTANT i64 4294967296
%7:vgpr(p0) = G_CONSTANT i64 18446744004990098135
S_ENDPGM 0, implicit %0 , implicit %1 , implicit %2, implicit %3, implicit %4, implicit %5, implicit %6, implicit %7
...
---
name: constant_s_p4
legalized: true
regBankSelected: true
tracksRegLiveness: true

body: |
bb.0:
; WAVE64-LABEL: name: constant_s_p4
; WAVE64: [[S_MOV_B:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 0
; WAVE64-NEXT: [[S_MOV_B1:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 1
; WAVE64-NEXT: [[S_MOV_B2:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -1
; WAVE64-NEXT: [[S_MOV_B3:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -54
; WAVE64-NEXT: [[S_MOV_B4:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 27
; WAVE64-NEXT: [[S_MOV_B5:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 4294967295
; WAVE64-NEXT: [[S_MOV_B32_:%[0-9]+]]:sreg_32 = S_MOV_B32 0
; WAVE64-NEXT: [[S_MOV_B32_1:%[0-9]+]]:sreg_32 = S_MOV_B32 1
; WAVE64-NEXT: [[REG_SEQUENCE:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_]], %subreg.sub0, [[S_MOV_B32_1]], %subreg.sub1
; WAVE64-NEXT: [[S_MOV_B32_2:%[0-9]+]]:sreg_32 = S_MOV_B32 23255
; WAVE64-NEXT: [[S_MOV_B32_3:%[0-9]+]]:sreg_32 = S_MOV_B32 -16
; WAVE64-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_2]], %subreg.sub0, [[S_MOV_B32_3]], %subreg.sub1
; WAVE64-NEXT: S_ENDPGM 0, implicit [[S_MOV_B]], implicit [[S_MOV_B1]], implicit [[S_MOV_B2]], implicit [[S_MOV_B3]], implicit [[S_MOV_B4]], implicit [[S_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
;
; WAVE32-LABEL: name: constant_s_p4
; WAVE32: [[S_MOV_B:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 0
; WAVE32-NEXT: [[S_MOV_B1:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 1
; WAVE32-NEXT: [[S_MOV_B2:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -1
; WAVE32-NEXT: [[S_MOV_B3:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO -54
; WAVE32-NEXT: [[S_MOV_B4:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 27
; WAVE32-NEXT: [[S_MOV_B5:%[0-9]+]]:sreg_64 = S_MOV_B64_IMM_PSEUDO 4294967295
; WAVE32-NEXT: [[S_MOV_B32_:%[0-9]+]]:sreg_32 = S_MOV_B32 0
; WAVE32-NEXT: [[S_MOV_B32_1:%[0-9]+]]:sreg_32 = S_MOV_B32 1
; WAVE32-NEXT: [[REG_SEQUENCE:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_]], %subreg.sub0, [[S_MOV_B32_1]], %subreg.sub1
; WAVE32-NEXT: [[S_MOV_B32_2:%[0-9]+]]:sreg_32 = S_MOV_B32 23255
; WAVE32-NEXT: [[S_MOV_B32_3:%[0-9]+]]:sreg_32 = S_MOV_B32 -16
; WAVE32-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:sreg_64 = REG_SEQUENCE [[S_MOV_B32_2]], %subreg.sub0, [[S_MOV_B32_3]], %subreg.sub1
; WAVE32-NEXT: S_ENDPGM 0, implicit [[S_MOV_B]], implicit [[S_MOV_B1]], implicit [[S_MOV_B2]], implicit [[S_MOV_B3]], implicit [[S_MOV_B4]], implicit [[S_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
%0:sgpr(p4) = G_CONSTANT i64 0
%1:sgpr(p4) = G_CONSTANT i64 1
%2:sgpr(p4) = G_CONSTANT i64 -1
%3:sgpr(p4) = G_CONSTANT i64 -54
%4:sgpr(p4) = G_CONSTANT i64 27
%5:sgpr(p4) = G_CONSTANT i64 4294967295
%6:sgpr(p4) = G_CONSTANT i64 4294967296
%7:sgpr(p4) = G_CONSTANT i64 18446744004990098135
S_ENDPGM 0, implicit %0 , implicit %1 , implicit %2, implicit %3, implicit %4, implicit %5, implicit %6, implicit %7
...

---
name: constant_v_p4
legalized: true
regBankSelected: true
tracksRegLiveness: true

body: |
bb.0:
; WAVE64-LABEL: name: constant_v_p4
; WAVE64: [[V_MOV_B:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 0, implicit $exec
; WAVE64-NEXT: [[V_MOV_B1:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 1, implicit $exec
; WAVE64-NEXT: [[V_MOV_B2:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -1, implicit $exec
; WAVE64-NEXT: [[V_MOV_B3:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -54, implicit $exec
; WAVE64-NEXT: [[V_MOV_B4:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 27, implicit $exec
; WAVE64-NEXT: [[V_MOV_B5:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 4294967295, implicit $exec
; WAVE64-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
; WAVE64-NEXT: [[V_MOV_B32_e32_1:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec
; WAVE64-NEXT: [[REG_SEQUENCE:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_]], %subreg.sub0, [[V_MOV_B32_e32_1]], %subreg.sub1
; WAVE64-NEXT: [[V_MOV_B32_e32_2:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 23255, implicit $exec
; WAVE64-NEXT: [[V_MOV_B32_e32_3:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 -16, implicit $exec
; WAVE64-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_2]], %subreg.sub0, [[V_MOV_B32_e32_3]], %subreg.sub1
; WAVE64-NEXT: S_ENDPGM 0, implicit [[V_MOV_B]], implicit [[V_MOV_B1]], implicit [[V_MOV_B2]], implicit [[V_MOV_B3]], implicit [[V_MOV_B4]], implicit [[V_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
;
; WAVE32-LABEL: name: constant_v_p4
; WAVE32: [[V_MOV_B:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 0, implicit $exec
; WAVE32-NEXT: [[V_MOV_B1:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 1, implicit $exec
; WAVE32-NEXT: [[V_MOV_B2:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -1, implicit $exec
; WAVE32-NEXT: [[V_MOV_B3:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO -54, implicit $exec
; WAVE32-NEXT: [[V_MOV_B4:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 27, implicit $exec
; WAVE32-NEXT: [[V_MOV_B5:%[0-9]+]]:vreg_64 = V_MOV_B64_PSEUDO 4294967295, implicit $exec
; WAVE32-NEXT: [[V_MOV_B32_e32_:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 0, implicit $exec
; WAVE32-NEXT: [[V_MOV_B32_e32_1:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 1, implicit $exec
; WAVE32-NEXT: [[REG_SEQUENCE:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_]], %subreg.sub0, [[V_MOV_B32_e32_1]], %subreg.sub1
; WAVE32-NEXT: [[V_MOV_B32_e32_2:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 23255, implicit $exec
; WAVE32-NEXT: [[V_MOV_B32_e32_3:%[0-9]+]]:vgpr_32 = V_MOV_B32_e32 -16, implicit $exec
; WAVE32-NEXT: [[REG_SEQUENCE1:%[0-9]+]]:vreg_64 = REG_SEQUENCE [[V_MOV_B32_e32_2]], %subreg.sub0, [[V_MOV_B32_e32_3]], %subreg.sub1
; WAVE32-NEXT: S_ENDPGM 0, implicit [[V_MOV_B]], implicit [[V_MOV_B1]], implicit [[V_MOV_B2]], implicit [[V_MOV_B3]], implicit [[V_MOV_B4]], implicit [[V_MOV_B5]], implicit [[REG_SEQUENCE]], implicit [[REG_SEQUENCE1]]
%0:vgpr(p4) = G_CONSTANT i64 0
%1:vgpr(p4) = G_CONSTANT i64 1
%2:vgpr(p4) = G_CONSTANT i64 -1
%3:vgpr(p4) = G_CONSTANT i64 -54
%4:vgpr(p4) = G_CONSTANT i64 27
%5:vgpr(p4) = G_CONSTANT i64 4294967295
%6:vgpr(p4) = G_CONSTANT i64 4294967296
%7:vgpr(p4) = G_CONSTANT i64 18446744004990098135
S_ENDPGM 0, implicit %0 , implicit %1 , implicit %2, implicit %3, implicit %4, implicit %5, implicit %6, implicit %7
...

---
name: constant_s_p999
legalized: true
212 changes: 206 additions & 6 deletions llvm/unittests/SandboxIR/SandboxIRTest.cpp
@@ -90,7 +90,7 @@
EXPECT_FALSE(isa<sandboxir::Instruction>(Const0));
EXPECT_TRUE(isa<sandboxir::Instruction>(OpaqueI));

EXPECT_FALSE(isa<sandboxir::User>(F));
EXPECT_TRUE(isa<sandboxir::User>(F));
EXPECT_FALSE(isa<sandboxir::User>(Arg0));
EXPECT_FALSE(isa<sandboxir::User>(BB));
EXPECT_TRUE(isa<sandboxir::User>(AddI));
@@ -180,8 +180,8 @@ define i32 @foo(i32 %v0, i32 %v1) {
BS << "\n";
I0->getOperandUse(0).dump(BS);
EXPECT_EQ(Buff, R"IR(
Def: i32 %v0 ; SB1. (Argument)
User: %add0 = add i32 %v0, %v1 ; SB4. (Opaque)
Def: i32 %v0 ; SB2. (Argument)
User: %add0 = add i32 %v0, %v1 ; SB5. (Opaque)
OperandNo: 0
)IR");
#endif // NDEBUG
@@ -398,10 +398,10 @@ define void @foo(i32 %arg0, i32 %arg1) {
EXPECT_EQ(Buff, R"IR(
void @foo(i32 %arg0, i32 %arg1) {
bb0:
br label %bb1 ; SB3. (Br)
br label %bb1 ; SB4. (Br)
bb1:
ret void ; SB5. (Ret)
ret void ; SB6. (Ret)
}
)IR");
}
@@ -466,7 +466,7 @@ define void @foo(i32 %v1) {
BB0.dump(BS);
EXPECT_EQ(Buff, R"IR(
bb0:
br label %bb1 ; SB2. (Br)
br label %bb1 ; SB3. (Br)
)IR");
}
#endif // NDEBUG
@@ -836,3 +836,203 @@ define i8 @foo(i8 %val) {
sandboxir::ReturnInst::create(Val, /*InsertAtEnd=*/BB, Ctx));
EXPECT_EQ(NewRet4->getReturnValue(), Val);
}

TEST_F(SandboxIRTest, CallBase) {
parseIR(C, R"IR(
declare void @bar1(i8)
declare void @bar2()
declare void @bar3()
declare void @variadic(ptr, ...)
define i8 @foo(i8 %arg0, i32 %arg1, ptr %indirectFoo) {
%call = call i8 @foo(i8 %arg0, i32 %arg1)
call void @bar1(i8 %arg0)
call void @bar2()
call void %indirectFoo()
call void @bar2() noreturn
tail call fastcc void @bar2()
call void (ptr, ...) @variadic(ptr %indirectFoo, i32 1)
ret i8 %call
}
)IR");
llvm::Function &LLVMF = *M->getFunction("foo");
unsigned ArgIdx = 0;
llvm::Argument *LLVMArg0 = LLVMF.getArg(ArgIdx++);
llvm::Argument *LLVMArg1 = LLVMF.getArg(ArgIdx++);
llvm::BasicBlock *LLVMBB = &*LLVMF.begin();
SmallVector<llvm::CallBase *, 8> LLVMCalls;
auto LLVMIt = LLVMBB->begin();
while (isa<llvm::CallBase>(&*LLVMIt))
LLVMCalls.push_back(cast<llvm::CallBase>(&*LLVMIt++));

sandboxir::Context Ctx(C);
sandboxir::Function &F = *Ctx.createFunction(&LLVMF);

for (llvm::CallBase *LLVMCall : LLVMCalls) {
// Check classof(Instruction *).
auto *Call = cast<sandboxir::CallBase>(Ctx.getValue(LLVMCall));
// Check classof(Value *).
EXPECT_TRUE(isa<sandboxir::CallBase>((sandboxir::Value *)Call));
// Check getFunctionType().
EXPECT_EQ(Call->getFunctionType(), LLVMCall->getFunctionType());
// Check data_ops().
EXPECT_EQ(range_size(Call->data_ops()), range_size(LLVMCall->data_ops()));
auto DataOpIt = Call->data_operands_begin();
for (llvm::Use &LLVMUse : LLVMCall->data_ops()) {
Value *LLVMOp = LLVMUse.get();
sandboxir::Use Use = *DataOpIt++;
EXPECT_EQ(Ctx.getValue(LLVMOp), Use.get());
// Check isDataOperand().
EXPECT_EQ(Call->isDataOperand(Use), LLVMCall->isDataOperand(&LLVMUse));
// Check getDataOperandNo().
EXPECT_EQ(Call->getDataOperandNo(Use),
LLVMCall->getDataOperandNo(&LLVMUse));
// Check isArgOperand().
EXPECT_EQ(Call->isArgOperand(Use), LLVMCall->isArgOperand(&LLVMUse));
// Check isCallee().
EXPECT_EQ(Call->isCallee(Use), LLVMCall->isCallee(&LLVMUse));
}
// Check data_operands_empty().
EXPECT_EQ(Call->data_operands_empty(), LLVMCall->data_operands_empty());
// Check data_operands_size().
EXPECT_EQ(Call->data_operands_size(), LLVMCall->data_operands_size());
// Check getNumTotalBundleOperands().
EXPECT_EQ(Call->getNumTotalBundleOperands(),
LLVMCall->getNumTotalBundleOperands());
// Check args().
EXPECT_EQ(range_size(Call->args()), range_size(LLVMCall->args()));
auto ArgIt = Call->arg_begin();
for (llvm::Use &LLVMUse : LLVMCall->args()) {
Value *LLVMArg = LLVMUse.get();
sandboxir::Use Use = *ArgIt++;
EXPECT_EQ(Ctx.getValue(LLVMArg), Use.get());
}
// Check arg_empty().
EXPECT_EQ(Call->arg_empty(), LLVMCall->arg_empty());
// Check arg_size().
EXPECT_EQ(Call->arg_size(), LLVMCall->arg_size());
for (unsigned ArgIdx = 0, E = Call->arg_size(); ArgIdx != E; ++ArgIdx) {
// Check getArgOperand().
EXPECT_EQ(Call->getArgOperand(ArgIdx),
Ctx.getValue(LLVMCall->getArgOperand(ArgIdx)));
// Check getArgOperandUse().
sandboxir::Use Use = Call->getArgOperandUse(ArgIdx);
llvm::Use &LLVMUse = LLVMCall->getArgOperandUse(ArgIdx);
EXPECT_EQ(Use.get(), Ctx.getValue(LLVMUse.get()));
// Check getArgOperandNo().
EXPECT_EQ(Call->getArgOperandNo(Use),
LLVMCall->getArgOperandNo(&LLVMUse));
}
// Check hasArgument().
SmallVector<llvm::Value *> TestArgs(
{LLVMArg0, LLVMArg1, &LLVMF, LLVMBB, LLVMCall});
for (llvm::Value *LLVMV : TestArgs) {
sandboxir::Value *V = Ctx.getValue(LLVMV);
EXPECT_EQ(Call->hasArgument(V), LLVMCall->hasArgument(LLVMV));
}
// Check getCalledOperand().
EXPECT_EQ(Call->getCalledOperand(),
Ctx.getValue(LLVMCall->getCalledOperand()));
// Check getCalledOperandUse().
EXPECT_EQ(Call->getCalledOperandUse().get(),
Ctx.getValue(LLVMCall->getCalledOperandUse()));
// Check getCalledFunction().
if (LLVMCall->getCalledFunction() == nullptr)
EXPECT_EQ(Call->getCalledFunction(), nullptr);
else {
auto *LLVMCF = cast<llvm::Function>(LLVMCall->getCalledFunction());
(void)LLVMCF;
EXPECT_EQ(Call->getCalledFunction(),
cast<sandboxir::Function>(
Ctx.getValue(LLVMCall->getCalledFunction())));
}
// Check isIndirectCall().
EXPECT_EQ(Call->isIndirectCall(), LLVMCall->isIndirectCall());
// Check getCaller().
EXPECT_EQ(Call->getCaller(), Ctx.getValue(LLVMCall->getCaller()));
// Check isMustTailCall().
EXPECT_EQ(Call->isMustTailCall(), LLVMCall->isMustTailCall());
// Check isTailCall().
EXPECT_EQ(Call->isTailCall(), LLVMCall->isTailCall());
// Check getIntrinsicID().
EXPECT_EQ(Call->getIntrinsicID(), LLVMCall->getIntrinsicID());
// Check getCallingConv().
EXPECT_EQ(Call->getCallingConv(), LLVMCall->getCallingConv());
// Check isInlineAsm().
EXPECT_EQ(Call->isInlineAsm(), LLVMCall->isInlineAsm());
}

auto *Arg0 = F.getArg(0);
auto *Arg1 = F.getArg(1);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Call0 = cast<sandboxir::CallBase>(&*It++);
[[maybe_unused]] auto *Call1 = cast<sandboxir::CallBase>(&*It++);
auto *Call2 = cast<sandboxir::CallBase>(&*It++);
// Check setArgOperand
Call0->setArgOperand(0, Arg1);
EXPECT_EQ(Call0->getArgOperand(0), Arg1);
Call0->setArgOperand(0, Arg0);
EXPECT_EQ(Call0->getArgOperand(0), Arg0);

auto *Bar3F = Ctx.createFunction(M->getFunction("bar3"));

// Check setCalledOperand
auto *SvOp = Call0->getCalledOperand();
Call0->setCalledOperand(Bar3F);
EXPECT_EQ(Call0->getCalledOperand(), Bar3F);
Call0->setCalledOperand(SvOp);
// Check setCalledFunction
Call2->setCalledFunction(Bar3F);
EXPECT_EQ(Call2->getCalledFunction(), Bar3F);
}

TEST_F(SandboxIRTest, CallInst) {
parseIR(C, R"IR(
define i8 @foo(i8 %arg) {
%call = call i8 @foo(i8 %arg)
ret i8 %call
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto &F = *Ctx.createFunction(&LLVMF);
unsigned ArgIdx = 0;
auto *Arg0 = F.getArg(ArgIdx++);
auto *BB = &*F.begin();
auto It = BB->begin();
auto *Call = cast<sandboxir::CallInst>(&*It++);
auto *Ret = cast<sandboxir::ReturnInst>(&*It++);
EXPECT_EQ(Call->getNumOperands(), 2u);
EXPECT_EQ(Ret->getOpcode(), sandboxir::Instruction::Opcode::Ret);
FunctionType *FTy = F.getFunctionType();
SmallVector<sandboxir::Value *, 1> Args;
Args.push_back(Arg0);
{
// Check create() WhereIt.
auto *Call = cast<sandboxir::CallInst>(sandboxir::CallInst::create(
FTy, &F, Args, /*WhereIt=*/Ret->getIterator(), BB, Ctx));
EXPECT_EQ(Call->getNextNode(), Ret);
EXPECT_EQ(Call->getCalledFunction(), &F);
EXPECT_EQ(range_size(Call->args()), 1u);
EXPECT_EQ(Call->getArgOperand(0), Arg0);
}
{
// Check create() InsertBefore.
auto *Call = cast<sandboxir::CallInst>(
sandboxir::CallInst::create(FTy, &F, Args, /*InsertBefore=*/Ret, Ctx));
EXPECT_EQ(Call->getNextNode(), Ret);
EXPECT_EQ(Call->getCalledFunction(), &F);
EXPECT_EQ(range_size(Call->args()), 1u);
EXPECT_EQ(Call->getArgOperand(0), Arg0);
}
{
// Check create() InsertAtEnd.
auto *Call = cast<sandboxir::CallInst>(
sandboxir::CallInst::create(FTy, &F, Args, /*InsertAtEnd=*/BB, Ctx));
EXPECT_EQ(Call->getPrevNode(), Ret);
EXPECT_EQ(Call->getCalledFunction(), &F);
EXPECT_EQ(range_size(Call->args()), 1u);
EXPECT_EQ(Call->getArgOperand(0), Arg0);
}
}
75 changes: 75 additions & 0 deletions llvm/unittests/SandboxIR/TrackerTest.cpp
@@ -69,6 +69,34 @@ define void @foo(ptr %ptr) {
EXPECT_EQ(Ld->getOperand(0), Gep0);
}

TEST_F(TrackerTest, SetUse) {
parseIR(C, R"IR(
define void @foo(ptr %ptr, i8 %arg) {
%ld = load i8, ptr %ptr
%add = add i8 %ld, %arg
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);
auto *F = Ctx.createFunction(&LLVMF);
unsigned ArgIdx = 0;
auto *Arg0 = F->getArg(ArgIdx++);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Ld = &*It++;
auto *Add = &*It++;

Ctx.save();
sandboxir::Use Use = Add->getOperandUse(0);
Use.set(Arg0);
EXPECT_EQ(Add->getOperand(0), Arg0);
Ctx.revert();
EXPECT_EQ(Add->getOperand(0), Ld);
}

TEST_F(TrackerTest, SwapOperands) {
parseIR(C, R"IR(
define void @foo(i1 %cond) {
@@ -413,3 +441,50 @@ define i32 @foo(i32 %arg) {
EXPECT_EQ(&*It++, Ret);
EXPECT_EQ(It, BB->end());
}

TEST_F(TrackerTest, CallBaseSetters) {
parseIR(C, R"IR(
declare void @bar1(i8)
declare void @bar2(i8)
define void @foo(i8 %arg0, i8 %arg1) {
call void @bar1(i8 %arg0)
ret void
}
)IR");
Function &LLVMF = *M->getFunction("foo");
sandboxir::Context Ctx(C);

auto *F = Ctx.createFunction(&LLVMF);
unsigned ArgIdx = 0;
auto *Arg0 = F->getArg(ArgIdx++);
auto *Arg1 = F->getArg(ArgIdx++);
auto *BB = &*F->begin();
auto It = BB->begin();
auto *Call = cast<sandboxir::CallBase>(&*It++);
[[maybe_unused]] auto *Ret = cast<sandboxir::ReturnInst>(&*It++);

// Check setArgOperand().
Ctx.save();
Call->setArgOperand(0, Arg1);
EXPECT_EQ(Call->getArgOperand(0), Arg1);
Ctx.revert();
EXPECT_EQ(Call->getArgOperand(0), Arg0);

auto *Bar1F = Call->getCalledFunction();
auto *Bar2F = Ctx.createFunction(M->getFunction("bar2"));

// Check setCalledOperand().
Ctx.save();
Call->setCalledOperand(Bar2F);
EXPECT_EQ(Call->getCalledOperand(), Bar2F);
Ctx.revert();
EXPECT_EQ(Call->getCalledOperand(), Bar1F);

// Check setCalledFunction().
Ctx.save();
Call->setCalledFunction(Bar2F);
EXPECT_EQ(Call->getCalledFunction(), Bar2F);
Ctx.revert();
EXPECT_EQ(Call->getCalledFunction(), Bar1F);
}
42 changes: 40 additions & 2 deletions mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -906,6 +906,43 @@
}
};

struct VectorStepOpConvert final : OpConversionPattern<vector::StepOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::StepOp stepOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
const auto &typeConverter = *getTypeConverter<SPIRVTypeConverter>();
Type dstType = typeConverter.convertType(stepOp.getType());
if (!dstType)
return failure();

Location loc = stepOp.getLoc();
int64_t numElements = stepOp.getType().getNumElements();
auto intType =
rewriter.getIntegerType(typeConverter.getIndexTypeBitwidth());

// Input vectors of size 1 are converted to scalars by the type converter.
// We just create a constant in this case.
if (numElements == 1) {
Value zero = spirv::ConstantOp::getZero(intType, loc, rewriter);
rewriter.replaceOp(stepOp, zero);
return success();
}

SmallVector<Value> source;
source.reserve(numElements);
for (int64_t i = 0; i < numElements; ++i) {
Attribute intAttr = rewriter.getIntegerAttr(intType, i);
Value constOp = rewriter.create<spirv::ConstantOp>(loc, intType, intAttr);
source.push_back(constOp);
}
rewriter.replaceOpWithNewOp<spirv::CompositeConstructOp>(stepOp, dstType,
source);
return success();
}
};

} // namespace
#define CL_INT_MAX_MIN_OPS \
spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
@@ -929,8 +966,9 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
VectorInterleaveOpConvert, VectorDeinterleaveOpConvert,
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
typeConverter, patterns.getContext(), PatternBenefit(1));
VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter,
VectorStepOpConvert>(typeConverter, patterns.getContext(),
PatternBenefit(1));

// Make sure that the more specialized dot product pattern has higher benefit
// than the generic one that extracts all elements.
26 changes: 26 additions & 0 deletions mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -794,6 +794,32 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {

// -----

// CHECK-LABEL: @step()
// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
// CHECK: %[[CST1:.*]] = spirv.Constant 1 : i32
// CHECK: %[[CST2:.*]] = spirv.Constant 2 : i32
// CHECK: %[[CST3:.*]] = spirv.Constant 3 : i32
// CHECK: %[[CONSTRUCT:.*]] = spirv.CompositeConstruct %[[CST0]], %[[CST1]], %[[CST2]], %[[CST3]] : (i32, i32, i32, i32) -> vector<4xi32>
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CONSTRUCT]] : vector<4xi32> to vector<4xindex>
// CHECK: return %[[CAST]] : vector<4xindex>
func.func @step() -> vector<4xindex> {
%0 = vector.step : vector<4xindex>
return %0 : vector<4xindex>
}

// -----

// CHECK-LABEL: @step_size1()
// CHECK: %[[CST0:.*]] = spirv.Constant 0 : i32
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CST0]] : i32 to vector<1xindex>
// CHECK: return %[[CAST]] : vector<1xindex>
func.func @step_size1() -> vector<1xindex> {
%0 = vector.step : vector<1xindex>
return %0 : vector<1xindex>
}

// -----

module attributes {
spirv.target_env = #spirv.target_env<
#spirv.vce<v1.0, [Shader], [SPV_KHR_storage_buffer_storage_class]>, #spirv.resource_limits<>>