Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 67 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8740,6 +8740,11 @@ SpirvEmitter::processIntrinsicCallExpr(const CallExpr *callExpr) {
}
break;
}
case hlsl::IntrinsicOp::IOP_dot4add_i8packed:
case hlsl::IntrinsicOp::IOP_dot4add_u8packed: {
retVal = processIntrinsicDP4a(callExpr, hlslOpcode);
break;
}
case hlsl::IntrinsicOp::IOP_pack_s8:
case hlsl::IntrinsicOp::IOP_pack_u8:
case hlsl::IntrinsicOp::IOP_pack_clamp_s8:
Expand Down Expand Up @@ -11370,6 +11375,68 @@ SpirvEmitter::processIntrinsicLog10(const CallExpr *callExpr) {
range);
}

SpirvInstruction *SpirvEmitter::processIntrinsicDP4a(const CallExpr *callExpr,
hlsl::IntrinsicOp op) {
// Processing the `dot4add_i8packed` and `dot4add_u8packed` intrinsics.
// There is no direct substitution for them in SPIR-V, but the combination
// of OpSDot / OpUDot and OpIAdd works. Note that the OpSDotAccSat and
// OpUDotAccSat operations are not matching the HLSL intrinsics as there
// should not be any saturation.
//
// int32 dot4add_i8packed(uint32 a, uint32 b, int32 acc);
// A 4-dimensional signed integer dot-product with add. Multiplies together
// each corresponding pair of signed 8-bit int bytes in the two input
// DWORDs, and sums the results into the 32-bit signed integer accumulator.
//
// uint32 dot4add_u8packed(uint32 a, uint32 b, uint32 acc);
// A 4-dimensional unsigned integer dot-product with add. Multiplies
// together each corresponding pair of unsigned 8-bit int bytes in the two
// input DWORDs, and sums the results into the 32-bit unsigned integer
// accumulator.

auto loc = callExpr->getExprLoc();
auto range = callExpr->getSourceRange();
assert(op == hlsl::IntrinsicOp::IOP_dot4add_i8packed ||
op == hlsl::IntrinsicOp::IOP_dot4add_u8packed);

// Validate the argument count - if it's wrong, the compiler won't get
// here anyway, so an assert should be fine.
assert(callExpr->getNumArgs() == 3u);

// Prepare the three arguments.
const Expr *arg0 = callExpr->getArg(0);
const Expr *arg1 = callExpr->getArg(1);
const Expr *arg2 = callExpr->getArg(2);
auto *arg0Instr = doExpr(arg0);
auto *arg1Instr = doExpr(arg1);
auto *arg2Instr = doExpr(arg2);

// Prepare the array inputs for createSpirvIntrInstExt below.
// Need to use this function because the OpSDot/OpUDot operations require
// two capabilities and an extension to be declared in the module.
SpirvInstruction *operands[]{arg0Instr, arg1Instr};
uint32_t capabilities[]{
uint32_t(spv::Capability::DotProduct),
uint32_t(spv::Capability::DotProductInput4x8BitPacked)};
llvm::StringRef extensions[]{"SPV_KHR_integer_dot_product"};
llvm::StringRef instSet = "";

// Pick the opcode based on the instruction.
const bool isSigned = op == hlsl::IntrinsicOp::IOP_dot4add_i8packed;
const spv::Op spirvOp = isSigned ? spv::Op::OpSDot : spv::Op::OpUDot;

const auto returnType = callExpr->getType();

// Create the dot product instruction.
auto *dotResult =
spvBuilder.createSpirvIntrInstExt(uint32_t(spirvOp), returnType, operands,
extensions, instSet, capabilities, loc);

// Create and return the integer addition instruction.
return spvBuilder.createBinaryOp(spv::Op::OpIAdd, returnType, dotResult,
arg2Instr, loc, range);
}

SpirvInstruction *
SpirvEmitter::processIntrinsic8BitPack(const CallExpr *callExpr,
hlsl::IntrinsicOp op) {
Expand Down
4 changes: 4 additions & 0 deletions tools/clang/lib/SPIRV/SpirvEmitter.h
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,10 @@ class SpirvEmitter : public ASTConsumer {
/// Processes the NonUniformResourceIndex intrinsic function.
SpirvInstruction *processIntrinsicNonUniformResourceIndex(const CallExpr *);

/// Processes the SM 6.4 dot4add_{i|u}8packed intrinsic functions.
SpirvInstruction *processIntrinsicDP4a(const CallExpr *callExpr,
hlsl::IntrinsicOp op);

/// Processes the SM 6.6 pack_{s|u}8 and pack_clamp_{s|u}8 intrinsic
/// functions.
SpirvInstruction *processIntrinsic8BitPack(const CallExpr *,
Expand Down
34 changes: 18 additions & 16 deletions tools/clang/test/CodeGenSPIRV_Lit/intrinsics.dot4add.hlsl
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
// RUN: not %dxc -E frag -T ps_6_4 -fspv-target-env=vulkan1.1 -fcgl %s -spirv 2>&1 | FileCheck %s
// RUN: %dxc -E main -T ps_6_4 -fspv-target-env=vulkan1.1 -fcgl %s -spirv 2>&1 | FileCheck %s

uint frag(float4 vertex
: SV_POSITION) : SV_Target {
float2 main(uint4 inputs : Inputs0, uint acc0 : Acc0, int acc1 : Acc1) : SV_Target {
uint acc = 0;

// CHECK: 8:14: error: dot4add_u8packed intrinsic function unimplemented
acc += 1 + dot4add_u8packed(vertex.x, vertex.y, uint(vertex.y));
// CHECK: [[input_x_ref:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %inputs %int_0
// CHECK-NEXT: [[input_x_val:%[0-9]+]] = OpLoad %uint [[input_x_ref]]
// CHECK-NEXT: [[input_y_ref:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %inputs %int_1
// CHECK-NEXT: [[input_y_val:%[0-9]+]] = OpLoad %uint [[input_y_ref]]
// CHECK-NEXT: [[a0:%[0-9]+]] = OpLoad %uint %acc0
// CHECK-NEXT: [[t0:%[0-9]+]] = OpUDot %uint [[input_x_val]] [[input_y_val]]
// CHECK-NEXT: [[t1:%[0-9]+]] = OpIAdd %uint [[t0]] [[a0]]
acc += dot4add_u8packed(inputs.x, inputs.y, acc0);

// CHECK: 11:14: error: dot4add_i8packed intrinsic function unimplemented
acc += 2 + dot4add_i8packed(vertex.z, vertex.w, int(vertex.x));

// CHECK: 14:10: error: dot4add_u8packed intrinsic function unimplemented
acc += dot4add_u8packed(vertex.x, vertex.y, uint(vertex.y)) + 1;

// CHECK: 17:13: error: dot4add_u8packed intrinsic function unimplemented
acc = 1 + dot4add_u8packed(vertex.x, vertex.y, uint(vertex.y));

// CHECK: 20:9: error: dot4add_u8packed intrinsic function unimplemented
acc = dot4add_u8packed(vertex.x, vertex.y, uint(vertex.y)) + 1;
// CHECK: [[input_z_ref:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %inputs %int_2
// CHECK-NEXT: [[input_z_val:%[0-9]+]] = OpLoad %uint [[input_z_ref]]
// CHECK-NEXT: [[input_w_ref:%[0-9]+]] = OpAccessChain %_ptr_Function_uint %inputs %int_3
// CHECK-NEXT: [[input_w_val:%[0-9]+]] = OpLoad %uint [[input_w_ref]]
// CHECK-NEXT: [[a1:%[0-9]+]] = OpLoad %int %acc1
// CHECK-NEXT: [[t2:%[0-9]+]] = OpSDot %int [[input_z_val]] [[input_w_val]]
// CHECK-NEXT: [[t3:%[0-9]+]] = OpIAdd %int [[t2]] [[a1]]
acc += dot4add_i8packed(inputs.z, inputs.w, acc1);

return acc;
}