@@ -52,13 +52,15 @@ static std::optional<StringRef> extractMdStringValue(MDNode *Node,
52
52
return NodeText->getString ();
53
53
}
54
54
55
- static Expected<dxbc::ShaderVisibility>
56
- extractShaderVisibility (MDNode *Node, unsigned int OpId) {
55
+ template <typename T, typename = std::enable_if_t <
56
+ std::is_enum_v<T> &&
57
+ std::is_same_v<std::underlying_type_t <T>, uint32_t >>>
58
+ Expected<T> extractEnumValue (MDNode *Node, unsigned int OpId, StringRef ErrText,
59
+ llvm::function_ref<bool (uint32_t )> VerifyFn) {
57
60
if (std::optional<uint32_t > Val = extractMdIntValue (Node, OpId)) {
58
- if (!dxbc::isValidShaderVisibility (*Val))
59
- return make_error<RootSignatureValidationError<uint32_t >>(
60
- " ShaderVisibility" , *Val);
61
- return dxbc::ShaderVisibility (*Val);
61
+ if (!VerifyFn (*Val))
62
+ return make_error<RootSignatureValidationError<uint32_t >>(ErrText, *Val);
63
+ return static_cast <T>(*Val);
62
64
}
63
65
return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
64
66
}
@@ -233,7 +235,9 @@ Error MetadataParser::parseRootConstants(mcdxbc::RootSignatureDesc &RSD,
233
235
return make_error<InvalidRSMetadataFormat>(" RootConstants Element" );
234
236
235
237
Expected<dxbc::ShaderVisibility> Visibility =
236
- extractShaderVisibility (RootConstantNode, 1 );
238
+ extractEnumValue<dxbc::ShaderVisibility>(RootConstantNode, 1 ,
239
+ " ShaderVisibility" ,
240
+ dxbc::isValidShaderVisibility);
237
241
if (auto E = Visibility.takeError ())
238
242
return Error (std::move (E));
239
243
@@ -287,7 +291,9 @@ Error MetadataParser::parseRootDescriptors(
287
291
}
288
292
289
293
Expected<dxbc::ShaderVisibility> Visibility =
290
- extractShaderVisibility (RootDescriptorNode, 1 );
294
+ extractEnumValue<dxbc::ShaderVisibility>(RootDescriptorNode, 1 ,
295
+ " ShaderVisibility" ,
296
+ dxbc::isValidShaderVisibility);
291
297
if (auto E = Visibility.takeError ())
292
298
return Error (std::move (E));
293
299
@@ -380,7 +386,9 @@ Error MetadataParser::parseDescriptorTable(mcdxbc::RootSignatureDesc &RSD,
380
386
return make_error<InvalidRSMetadataFormat>(" Descriptor Table" );
381
387
382
388
Expected<dxbc::ShaderVisibility> Visibility =
383
- extractShaderVisibility (DescriptorTableNode, 1 );
389
+ extractEnumValue<dxbc::ShaderVisibility>(DescriptorTableNode, 1 ,
390
+ " ShaderVisibility" ,
391
+ dxbc::isValidShaderVisibility);
384
392
if (auto E = Visibility.takeError ())
385
393
return Error (std::move (E));
386
394
@@ -406,26 +414,34 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
406
414
if (StaticSamplerNode->getNumOperands () != 14 )
407
415
return make_error<InvalidRSMetadataFormat>(" Static Sampler" );
408
416
409
- dxbc::RTS0::v1::StaticSampler Sampler;
410
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 1 ))
411
- Sampler.Filter = *Val;
412
- else
413
- return make_error<InvalidRSMetadataValue>(" Filter" );
417
+ mcdxbc::StaticSampler Sampler;
414
418
415
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 2 ))
416
- Sampler.AddressU = *Val;
417
- else
418
- return make_error<InvalidRSMetadataValue>(" AddressU" );
419
+ Expected<dxbc::SamplerFilter> Filter = extractEnumValue<dxbc::SamplerFilter>(
420
+ StaticSamplerNode, 1 , " Filter" , dxbc::isValidSamplerFilter);
421
+ if (auto E = Filter.takeError ())
422
+ return Error (std::move (E));
423
+ Sampler.Filter = *Filter;
419
424
420
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 3 ))
421
- Sampler.AddressV = *Val;
422
- else
423
- return make_error<InvalidRSMetadataValue>(" AddressV" );
425
+ Expected<dxbc::TextureAddressMode> AddressU =
426
+ extractEnumValue<dxbc::TextureAddressMode>(
427
+ StaticSamplerNode, 2 , " AddressU" , dxbc::isValidAddress);
428
+ if (auto E = AddressU.takeError ())
429
+ return Error (std::move (E));
430
+ Sampler.AddressU = *AddressU;
424
431
425
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 4 ))
426
- Sampler.AddressW = *Val;
427
- else
428
- return make_error<InvalidRSMetadataValue>(" AddressW" );
432
+ Expected<dxbc::TextureAddressMode> AddressV =
433
+ extractEnumValue<dxbc::TextureAddressMode>(
434
+ StaticSamplerNode, 3 , " AddressV" , dxbc::isValidAddress);
435
+ if (auto E = AddressV.takeError ())
436
+ return Error (std::move (E));
437
+ Sampler.AddressV = *AddressV;
438
+
439
+ Expected<dxbc::TextureAddressMode> AddressW =
440
+ extractEnumValue<dxbc::TextureAddressMode>(
441
+ StaticSamplerNode, 4 , " AddressW" , dxbc::isValidAddress);
442
+ if (auto E = AddressW.takeError ())
443
+ return Error (std::move (E));
444
+ Sampler.AddressW = *AddressW;
429
445
430
446
if (std::optional<float > Val = extractMdFloatValue (StaticSamplerNode, 5 ))
431
447
Sampler.MipLODBias = *Val;
@@ -437,15 +453,19 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
437
453
else
438
454
return make_error<InvalidRSMetadataValue>(" MaxAnisotropy" );
439
455
440
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 7 ))
441
- Sampler.ComparisonFunc = *Val;
442
- else
443
- return make_error<InvalidRSMetadataValue>(" ComparisonFunc" );
456
+ Expected<dxbc::ComparisonFunc> ComparisonFunc =
457
+ extractEnumValue<dxbc::ComparisonFunc>(
458
+ StaticSamplerNode, 7 , " ComparisonFunc" , dxbc::isValidComparisonFunc);
459
+ if (auto E = ComparisonFunc.takeError ())
460
+ return Error (std::move (E));
461
+ Sampler.ComparisonFunc = *ComparisonFunc;
444
462
445
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 8 ))
446
- Sampler.BorderColor = *Val;
447
- else
448
- return make_error<InvalidRSMetadataValue>(" ComparisonFunc" );
463
+ Expected<dxbc::StaticBorderColor> BorderColor =
464
+ extractEnumValue<dxbc::StaticBorderColor>(
465
+ StaticSamplerNode, 8 , " BorderColor" , dxbc::isValidBorderColor);
466
+ if (auto E = BorderColor.takeError ())
467
+ return Error (std::move (E));
468
+ Sampler.BorderColor = *BorderColor;
449
469
450
470
if (std::optional<float > Val = extractMdFloatValue (StaticSamplerNode, 9 ))
451
471
Sampler.MinLOD = *Val;
@@ -467,10 +487,13 @@ Error MetadataParser::parseStaticSampler(mcdxbc::RootSignatureDesc &RSD,
467
487
else
468
488
return make_error<InvalidRSMetadataValue>(" RegisterSpace" );
469
489
470
- if (std::optional<uint32_t > Val = extractMdIntValue (StaticSamplerNode, 13 ))
471
- Sampler.ShaderVisibility = *Val;
472
- else
473
- return make_error<InvalidRSMetadataValue>(" ShaderVisibility" );
490
+ Expected<dxbc::ShaderVisibility> Visibility =
491
+ extractEnumValue<dxbc::ShaderVisibility>(StaticSamplerNode, 13 ,
492
+ " ShaderVisibility" ,
493
+ dxbc::isValidShaderVisibility);
494
+ if (auto E = Visibility.takeError ())
495
+ return Error (std::move (E));
496
+ Sampler.ShaderVisibility = *Visibility;
474
497
475
498
RSD.StaticSamplers .push_back (Sampler);
476
499
return Error::success ();
@@ -594,30 +617,7 @@ Error MetadataParser::validateRootSignature(
594
617
}
595
618
}
596
619
597
- for (const dxbc::RTS0::v1::StaticSampler &Sampler : RSD.StaticSamplers ) {
598
- if (!hlsl::rootsig::verifySamplerFilter (Sampler.Filter ))
599
- DeferredErrs =
600
- joinErrors (std::move (DeferredErrs),
601
- make_error<RootSignatureValidationError<uint32_t >>(
602
- " Filter" , Sampler.Filter ));
603
-
604
- if (!hlsl::rootsig::verifyAddress (Sampler.AddressU ))
605
- DeferredErrs =
606
- joinErrors (std::move (DeferredErrs),
607
- make_error<RootSignatureValidationError<uint32_t >>(
608
- " AddressU" , Sampler.AddressU ));
609
-
610
- if (!hlsl::rootsig::verifyAddress (Sampler.AddressV ))
611
- DeferredErrs =
612
- joinErrors (std::move (DeferredErrs),
613
- make_error<RootSignatureValidationError<uint32_t >>(
614
- " AddressV" , Sampler.AddressV ));
615
-
616
- if (!hlsl::rootsig::verifyAddress (Sampler.AddressW ))
617
- DeferredErrs =
618
- joinErrors (std::move (DeferredErrs),
619
- make_error<RootSignatureValidationError<uint32_t >>(
620
- " AddressW" , Sampler.AddressW ));
620
+ for (const mcdxbc::StaticSampler &Sampler : RSD.StaticSamplers ) {
621
621
622
622
if (!hlsl::rootsig::verifyMipLODBias (Sampler.MipLODBias ))
623
623
DeferredErrs = joinErrors (std::move (DeferredErrs),
@@ -630,18 +630,6 @@ Error MetadataParser::validateRootSignature(
630
630
make_error<RootSignatureValidationError<uint32_t >>(
631
631
" MaxAnisotropy" , Sampler.MaxAnisotropy ));
632
632
633
- if (!hlsl::rootsig::verifyComparisonFunc (Sampler.ComparisonFunc ))
634
- DeferredErrs =
635
- joinErrors (std::move (DeferredErrs),
636
- make_error<RootSignatureValidationError<uint32_t >>(
637
- " ComparisonFunc" , Sampler.ComparisonFunc ));
638
-
639
- if (!hlsl::rootsig::verifyBorderColor (Sampler.BorderColor ))
640
- DeferredErrs =
641
- joinErrors (std::move (DeferredErrs),
642
- make_error<RootSignatureValidationError<uint32_t >>(
643
- " BorderColor" , Sampler.BorderColor ));
644
-
645
633
if (!hlsl::rootsig::verifyLOD (Sampler.MinLOD ))
646
634
DeferredErrs = joinErrors (std::move (DeferredErrs),
647
635
make_error<RootSignatureValidationError<float >>(
@@ -663,12 +651,6 @@ Error MetadataParser::validateRootSignature(
663
651
joinErrors (std::move (DeferredErrs),
664
652
make_error<RootSignatureValidationError<uint32_t >>(
665
653
" RegisterSpace" , Sampler.RegisterSpace ));
666
-
667
- if (!dxbc::isValidShaderVisibility (Sampler.ShaderVisibility ))
668
- DeferredErrs =
669
- joinErrors (std::move (DeferredErrs),
670
- make_error<RootSignatureValidationError<uint32_t >>(
671
- " ShaderVisibility" , Sampler.ShaderVisibility ));
672
654
}
673
655
674
656
return DeferredErrs;
0 commit comments