Skip to content

Commit

Permalink
[HLSL][DXIL][SPIRV] Implementation of an abstraction for intrinsic se…
Browse files Browse the repository at this point in the history
…lection of HLSL backends (#87171)

Start of #83882
- `Builtins.td` - add the `hlsl` `all` elementwise builtin.
- `CGBuiltin.cpp` - Show a use case for CGHLSLUtils via an `all`
intrinsic codegen.
- `CGHLSLRuntime.cpp` - move `thread_id` to use CGHLSLUtils.
- `CGHLSLRuntime.h` - Create a macro to help pick the right intrinsic
for the backend.
- `hlsl_intrinsics.h` - Add the `all` api.
- `SemaChecking.cpp` - Add `all` builtin type checking
- `IntrinsicsDirectX.td` - Add the `all` `dx` intrinsic
- `IntrinsicsSPIRV.td` - Add the `all` `spv` intrinsic

Work still needed:
- `SPIRVInstructionSelector.cpp` - Add an implementation of `OpAll` for
`spv_all` intrinsic
  • Loading branch information
farzonl committed Apr 5, 2024
1 parent 1341760 commit 1cb64d7
Show file tree
Hide file tree
Showing 9 changed files with 445 additions and 15 deletions.
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/Builtins.td
Original file line number Diff line number Diff line change
Expand Up @@ -4587,6 +4587,12 @@ def GetDeviceSideMangledName : LangBuiltin<"CUDA_LANG"> {
}

// HLSL
def HLSLAll : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_all"];
let Attributes = [NoThrow, Const];
let Prototype = "bool(...)";
}

def HLSLAny : LangBuiltin<"HLSL_LANG"> {
let Spellings = ["__builtin_hlsl_elementwise_any"];
let Attributes = [NoThrow, Const];
Expand Down
8 changes: 8 additions & 0 deletions clang/lib/CodeGen/CGBuiltin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "ABIInfo.h"
#include "CGCUDARuntime.h"
#include "CGCXXABI.h"
#include "CGHLSLRuntime.h"
#include "CGObjCRuntime.h"
#include "CGOpenCLRuntime.h"
#include "CGRecordLayout.h"
Expand Down Expand Up @@ -18172,6 +18173,13 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID,
return nullptr;

switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_elementwise_all: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
return Builder.CreateIntrinsic(
/*ReturnType=*/llvm::Type::getInt1Ty(getLLVMContext()),
CGM.getHLSLRuntime().getAllIntrinsic(), ArrayRef<Value *>{Op0}, nullptr,
"hlsl.all");
}
case Builtin::BI__builtin_hlsl_elementwise_any: {
Value *Op0 = EmitScalarExpr(E->getArg(0));
return Builder.CreateIntrinsic(
Expand Down
20 changes: 6 additions & 14 deletions clang/lib/CodeGen/CGHLSLRuntime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
#include "CodeGenModule.h"
#include "clang/AST/Decl.h"
#include "clang/Basic/TargetOptions.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Metadata.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/FormatVariadic.h"
Expand Down Expand Up @@ -117,6 +115,10 @@ GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {

} // namespace

llvm::Triple::ArchType CGHLSLRuntime::getArch() {
return CGM.getTarget().getTriple().getArch();
}

void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
if (D->getStorageClass() == SC_Static) {
// For static inside cbuffer, take as global static.
Expand Down Expand Up @@ -343,18 +345,8 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B,
return B.CreateCall(FunctionCallee(DxGroupIndex));
}
if (D.hasAttr<HLSLSV_DispatchThreadIDAttr>()) {
llvm::Function *ThreadIDIntrinsic;
switch (CGM.getTarget().getTriple().getArch()) {
case llvm::Triple::dxil:
ThreadIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_thread_id);
break;
case llvm::Triple::spirv:
ThreadIDIntrinsic = CGM.getIntrinsic(Intrinsic::spv_thread_id);
break;
default:
llvm_unreachable("Input semantic not supported by target");
break;
}
llvm::Function *ThreadIDIntrinsic =
CGM.getIntrinsic(getThreadIdIntrinsic());
return buildVectorInput(B, ThreadIDIntrinsic, Ty);
}
assert(false && "Unhandled parameter attribute");
Expand Down
32 changes: 32 additions & 0 deletions clang/lib/CodeGen/CGHLSLRuntime.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
#define LLVM_CLANG_LIB_CODEGEN_CGHLSLRUNTIME_H

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/IntrinsicsDirectX.h"
#include "llvm/IR/IntrinsicsSPIRV.h"

#include "clang/Basic/Builtins.h"
#include "clang/Basic/HLSLRuntime.h"

#include "llvm/ADT/SmallVector.h"
Expand All @@ -26,6 +30,22 @@
#include <optional>
#include <vector>

// A function generator macro for picking the right intrinsic
// for the target backend
#define GENERATE_HLSL_INTRINSIC_FUNCTION(FunctionName, IntrinsicPostfix) \
llvm::Intrinsic::ID get##FunctionName##Intrinsic() { \
llvm::Triple::ArchType Arch = getArch(); \
switch (Arch) { \
case llvm::Triple::dxil: \
return llvm::Intrinsic::dx_##IntrinsicPostfix; \
case llvm::Triple::spirv: \
return llvm::Intrinsic::spv_##IntrinsicPostfix; \
default: \
llvm_unreachable("Intrinsic " #IntrinsicPostfix \
" not supported by target architecture"); \
} \
}

namespace llvm {
class GlobalVariable;
class Function;
Expand All @@ -48,6 +68,17 @@ class CodeGenModule;

class CGHLSLRuntime {
public:
//===----------------------------------------------------------------------===//
// Start of reserved area for HLSL intrinsic getters.
//===----------------------------------------------------------------------===//

GENERATE_HLSL_INTRINSIC_FUNCTION(All, all)
GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id)

//===----------------------------------------------------------------------===//
// End of reserved area for HLSL intrinsic getters.
//===----------------------------------------------------------------------===//

struct BufferResBinding {
// The ID like 2 in register(b2, space1).
std::optional<unsigned> Reg;
Expand Down Expand Up @@ -96,6 +127,7 @@ class CGHLSLRuntime {
BufferResBinding &Binding);
void addConstant(VarDecl *D, Buffer &CB);
void addBufferDecls(const DeclContext *DC, Buffer &CB);
llvm::Triple::ArchType getArch();
llvm::SmallVector<Buffer> Buffers;
};

Expand Down
112 changes: 112 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,118 @@ double3 abs(double3);
_HLSL_BUILTIN_ALIAS(__builtin_elementwise_abs)
double4 abs(double4);

//===----------------------------------------------------------------------===//
// all builtins
//===----------------------------------------------------------------------===//

/// \fn bool all(T x)
/// \brief Returns True if all components of the \a x parameter are non-zero;
/// otherwise, false. \param x The input value.

#ifdef __HLSL_ENABLE_16_BIT
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int16_t4);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint16_t);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint16_t2);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint16_t3);
_HLSL_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint16_t4);
#endif

_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(half);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(half2);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(half3);
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(half4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(bool);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(bool2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(bool3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(bool4);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(float);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(float2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(float3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(float4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(int64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint64_t);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint64_t2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint64_t3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(uint64_t4);

_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(double);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(double2);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(double3);
_HLSL_BUILTIN_ALIAS(__builtin_hlsl_elementwise_all)
bool all(double4);

//===----------------------------------------------------------------------===//
// any builtins
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/lib/Sema/SemaChecking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5563,6 +5563,7 @@ void SetElementTypeAsReturnType(Sema *S, CallExpr *TheCall,
// returning an ExprError
bool Sema::CheckHLSLBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
switch (BuiltinID) {
case Builtin::BI__builtin_hlsl_elementwise_all:
case Builtin::BI__builtin_hlsl_elementwise_any: {
if (checkArgCount(*this, TheCall, 1))
return true;
Expand Down

0 comments on commit 1cb64d7

Please sign in to comment.