Skip to content

Commit

Permalink
[spirv] Relax SV_Position type requirements (#4275)
Browse files Browse the repository at this point in the history
A valid vertex shader output variable with SV_Position semantics may be
constructed from any HLSL BuiltinType that translates to a 32-bit
floating point type in the SPIR-V backend, so relax the requirements to
allow the use of additonal types (such as half4) when
-enable-16bit-types is false.

Fixes #4262
  • Loading branch information
sudonatalie committed Feb 22, 2022
1 parent 3f8e22c commit 3fcd83e
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 29 deletions.
38 changes: 29 additions & 9 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Expand Up @@ -517,9 +517,27 @@ bool insertSeenSemanticsForEntryPointIfNotExist(
return true;
}

// Returns whether the type is float4 or a composite type recursively including
// only float4 e.g., float4, float4[1], struct S { float4 foo[1]; }.
bool containOnlyVecWithFourFloats(QualType type) {
// Returns whether the type is translated to a 32-bit floating point type,
// depending on whether SPIR-V codegen options are configured to use 16-bit
// types when possible.
bool is32BitFloatingPointType(BuiltinType::Kind kind, bool use16Bit) {
// Always translated into 32-bit floating point types.
if (kind == BuiltinType::Float || kind == BuiltinType::LitFloat)
return true;

// Translated into 32-bit floating point types when run without
// -enable-16bit-types.
if (kind == BuiltinType::Half || kind == BuiltinType::HalfFloat ||
kind == BuiltinType::Min10Float || kind == BuiltinType::Min16Float)
return !use16Bit;

return false;
}

// Returns whether the type is a 4-component 32-bit float or a composite type
// recursively including only such a vector e.g., float4, float4[1], struct S {
// float4 foo[1]; }.
bool containOnlyVecWithFourFloats(QualType type, bool use16Bit) {
if (type->isReferenceType())
type = type->getPointeeType();

Expand All @@ -532,15 +550,15 @@ bool containOnlyVecWithFourFloats(QualType type) {
(const ConstantArrayType *)type->getAsArrayTypeUnsafe();
elemCount = hlsl::GetArraySize(type);
return elemCount == 1 &&
containOnlyVecWithFourFloats(arrayType->getElementType());
containOnlyVecWithFourFloats(arrayType->getElementType(), use16Bit);
}

if (const auto *structType = type->getAs<RecordType>()) {
uint32_t fieldCount = 0;
for (const auto *field : structType->getDecl()->fields()) {
if (fieldCount != 0)
return false;
if (!containOnlyVecWithFourFloats(field->getType()))
if (!containOnlyVecWithFourFloats(field->getType(), use16Bit))
return false;
++fieldCount;
}
Expand All @@ -550,7 +568,8 @@ bool containOnlyVecWithFourFloats(QualType type) {
QualType elemType = {};
if (isVectorType(type, &elemType, &elemCount)) {
if (const auto *builtinType = elemType->getAs<BuiltinType>()) {
return elemCount == 4 && builtinType->getKind() == BuiltinType::Float;
return elemCount == 4 &&
is32BitFloatingPointType(builtinType->getKind(), use16Bit);
}
return false;
}
Expand Down Expand Up @@ -3300,9 +3319,10 @@ SpirvVariable *DeclResultIdMapper::createSpirvStageVar(
// by VSOut, HS/DS/GS In/Out, MSOut.
case hlsl::Semantic::Kind::Position: {
if (sigPointKind == hlsl::SigPoint::Kind::VSOut &&
!containOnlyVecWithFourFloats(type)) {
emitError("semantic Position must be float4 or a composite type "
"recursively including only float4",
!containOnlyVecWithFourFloats(
type, theEmitter.getSpirvOptions().enable16BitTypes)) {
emitError("SV_Position must be a 4-component 32-bit float vector or a "
"composite which recursively contains only such a vector",
srcLoc);
}

Expand Down
83 changes: 63 additions & 20 deletions tools/clang/unittests/SPIRV/CodeGenSpirvTest.cpp
Expand Up @@ -2929,9 +2929,11 @@ float4 PSMain(PSInput input) : SV_TARGET

std::string getVertexPositionTypeTestShader(const std::string &subType,
const std::string &positionType,
const std::string &check) {
const std::string command(R"(// RUN: %dxc -T vs_6_0 -E main)");
const std::string code = command + subType + R"(
const std::string &check,
bool use16bit) {
const std::string code = std::string(R"(// RUN: %dxc -T vs_6_2 -E main)") +
(use16bit ? R"( -enable-16bit-types)" : R"()") + R"(
)" + subType + R"(
struct output {
)" + positionType + R"(
};
Expand All @@ -2946,41 +2948,46 @@ output main() : SV_Position
}

const char *kInvalidPositionTypeForVSErrorMessage =
"// CHECK: error: semantic Position must be float4 or a composite type "
"recursively including only float4";
"// CHECK: error: SV_Position must be a 4-component 32-bit float vector or "
"a composite which recursively contains only such a vector";

TEST_F(FileTest, PositionInVSWithArrayType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "float x[4];", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "float x[4];", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithDoubleType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "double4 x;", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "double4 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithIntType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "int4 x;", kInvalidPositionTypeForVSErrorMessage),
"", "int4 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithMatrixType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "float1x4 x;", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "float1x4 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithInvalidFloatVectorType) {
runCodeTest(getVertexPositionTypeTestShader(
"", "float3 x;", kInvalidPositionTypeForVSErrorMessage),
Expect::Failure);
runCodeTest(
getVertexPositionTypeTestShader(
"", "float3 x;", kInvalidPositionTypeForVSErrorMessage, false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithInvalidInnerStructType) {
runCodeTest(getVertexPositionTypeTestShader(
R"(
struct InvalidType {
float3 x;
};)",
"InvalidType x;", kInvalidPositionTypeForVSErrorMessage),
"InvalidType x;", kInvalidPositionTypeForVSErrorMessage,
false),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithValidInnerStructType) {
Expand All @@ -2991,7 +2998,43 @@ struct validType {
"validType x;", R"(
// CHECK: %validType = OpTypeStruct %v4float
// CHECK: %output = OpTypeStruct %validType
)"));
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidFloatType) {
runCodeTest(getVertexPositionTypeTestShader("", "float4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidMin10Float4Type) {
runCodeTest(getVertexPositionTypeTestShader("", "min10float4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidMin16Float4Type) {
runCodeTest(getVertexPositionTypeTestShader("", "min16float4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithValidHalf4Type) {
runCodeTest(getVertexPositionTypeTestShader("", "half4 x;", R"(
// CHECK: %output = OpTypeStruct %v4float
)",
false));
}
TEST_F(FileTest, PositionInVSWithInvalidHalf4Type) {
runCodeTest(getVertexPositionTypeTestShader(
"", "half4 x;", kInvalidPositionTypeForVSErrorMessage, true),
Expect::Failure);
}
TEST_F(FileTest, PositionInVSWithInvalidMin10Float4Type) {
runCodeTest(
getVertexPositionTypeTestShader(
"", "min10float4 x;", kInvalidPositionTypeForVSErrorMessage, true),
Expect::Failure);
}
TEST_F(FileTest, ShaderDebugInfoFunction) {
runFileTest("shader.debug.function.hlsl");
Expand Down

0 comments on commit 3fcd83e

Please sign in to comment.