Skip to content

Conversation

@farzonl
Copy link
Member

@farzonl farzonl commented Dec 7, 2025

fixes #171049
fixes #171050

  • Allow Bools for matrix type when in HLSL mode
  • use ConvertTypeForMem to figure out the bool size
  • Add Bool matrix types to hlsl_basic_types.h

fixes llvm#171049
fixes llvm#171050

- Allow Bools for matrix type when in HLSL mode
- use ConvertTypeForMem to figure out the bool size
- Add Bool matrix types to hlsl_basic_types.h
@llvmbot llvmbot added clang Clang issues not falling into any other category backend:X86 clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang:codegen IR generation bugs: mangling, exceptions, etc. HLSL HLSL Language Support labels Dec 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Dec 7, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-hlsl

Author: Farzon Lotfi (farzonl)

Changes

fixes #171049
fixes #171050

  • Allow Bools for matrix type when in HLSL mode
  • use ConvertTypeForMem to figure out the bool size
  • Add Bool matrix types to hlsl_basic_types.h

Full diff: https://github.com/llvm/llvm-project/pull/171051.diff

9 Files Affected:

  • (modified) clang/include/clang/AST/TypeBase.h (+27-4)
  • (modified) clang/lib/AST/ASTContext.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+6-1)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+4-1)
  • (modified) clang/lib/Headers/hlsl/hlsl_basic_types.h (+17)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+2-2)
  • (modified) clang/lib/Sema/SemaType.cpp (+1-1)
  • (added) clang/test/CodeGenHLSL/BoolMatrix.hlsl (+151)
  • (modified) clang/test/CodeGenHLSL/basic_types.hlsl (+33)
diff --git a/clang/include/clang/AST/TypeBase.h b/clang/include/clang/AST/TypeBase.h
index 30b9efe5a31b7..cf6897b6e515c 100644
--- a/clang/include/clang/AST/TypeBase.h
+++ b/clang/include/clang/AST/TypeBase.h
@@ -2637,6 +2637,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
   bool isExtVectorBoolType() const;             // Extended vector type with bool element.
+  bool isConstantMatrixBoolType() const; // Matrix type with bool element.
   // Extended vector type with bool element that is packed. HLSL doesn't pack
   // its bool vectors.
   bool isPackedVectorBoolType(const ASTContext &ctx) const;
@@ -4352,12 +4353,26 @@ class MatrixType : public Type, public llvm::FoldingSetNode {
 
   /// Valid elements types are the following:
   /// * an integer type (as in C23 6.2.5p22), but excluding enumerated types
-  ///   and _Bool
+  ///   and _Bool (except that in HLSL, bool is allowed)
   /// * the standard floating types float or double
   /// * a half-precision floating point type, if one is supported on the target
-  static bool isValidElementType(QualType T) {
-    return T->isDependentType() ||
-           (T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
+  static bool isValidElementType(QualType T, const LangOptions &LangOpts) {
+    // Dependent is always okay
+    if (T->isDependentType())
+      return true;
+
+    // Enums are never okay
+    if (T->isEnumeralType())
+      return false;
+
+    // In HLSL, bool is allowed as a matrix element type.
+    // Note: isRealType includes bool so don't need to check
+    if (LangOpts.HLSL)
+      return T->isRealType();
+
+    // In non-HLSL modes, follow the existing rule:
+    // real type, but not _Bool.
+    return T->isRealType() && !T->isBooleanType();
   }
 
   bool isSugared() const { return false; }
@@ -8665,6 +8680,14 @@ inline bool Type::isExtVectorBoolType() const {
   return cast<ExtVectorType>(CanonicalType)->getElementType()->isBooleanType();
 }
 
+inline bool Type::isConstantMatrixBoolType() const {
+  if (!isConstantMatrixType())
+    return false;
+  return cast<ConstantMatrixType>(CanonicalType)
+      ->getElementType()
+      ->isBooleanType();
+}
+
 inline bool Type::isSubscriptableVectorType() const {
   return isVectorType() || isSveVLSBuiltinType();
 }
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 404ce3ffd77c7..5ca76c79df7c6 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -4712,7 +4712,7 @@ QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
   ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
                               Type::ConstantMatrix);
 
-  assert(MatrixType::isValidElementType(ElementTy) &&
+  assert(MatrixType::isValidElementType(ElementTy, getLangOpts()) &&
          "need a valid element type");
   assert(NumRows > 0 && NumRows <= LangOpts.MaxMatrixDimension &&
          NumColumns > 0 && NumColumns <= LangOpts.MaxMatrixDimension &&
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 3bde8e1fa2ac3..b44c1a7d18120 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2655,8 +2655,13 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
         MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened());
       }
       llvm::Instruction *Load = Builder.CreateLoad(Dst.getMatrixAddress());
+      llvm::Value *InsertVal = Src.getScalarVal();
+      if (getLangOpts().HLSL && InsertVal->getType()->isIntegerTy(1)) {
+        llvm::Type *StorageElmTy = Load->getType()->getScalarType();
+        InsertVal = Builder.CreateZExt(InsertVal, StorageElmTy);
+      }
       llvm::Value *Vec =
-          Builder.CreateInsertElement(Load, Src.getScalarVal(), Idx, "matins");
+          Builder.CreateInsertElement(Load, InsertVal, Idx, "matins");
       auto *I = Builder.CreateStore(Vec, Dst.getMatrixAddress(),
                                     Dst.isVolatileQualified());
       addInstToCurrentSourceAtom(I, Vec);
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index be862cf07f177..a41bf86d6f95c 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -104,7 +104,10 @@ llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
   if (T->isConstantMatrixType()) {
     const Type *Ty = Context.getCanonicalType(T).getTypePtr();
     const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
-    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+    llvm::Type *IRElemTy = ConvertType(MT->getElementType());
+    if (T->isConstantMatrixBoolType() && Context.getLangOpts().HLSL)
+      IRElemTy = ConvertTypeForMem(Context.BoolTy);
+    return llvm::ArrayType::get(IRElemTy,
                                 MT->getNumRows() * MT->getNumColumns());
   }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index fc1e265067714..b1d87c51de9bb 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -150,6 +150,23 @@ typedef matrix<uint16_t, 4, 3> uint16_t4x3;
 typedef matrix<uint16_t, 4, 4> uint16_t4x4;
 #endif
 
+typedef matrix<bool, 1, 1> bool1x1;
+typedef matrix<bool, 1, 2> bool1x2;
+typedef matrix<bool, 1, 3> bool1x3;
+typedef matrix<bool, 1, 4> bool1x4;
+typedef matrix<bool, 2, 1> bool2x1;
+typedef matrix<bool, 2, 2> bool2x2;
+typedef matrix<bool, 2, 3> bool2x3;
+typedef matrix<bool, 2, 4> bool2x4;
+typedef matrix<bool, 3, 1> bool3x1;
+typedef matrix<bool, 3, 2> bool3x2;
+typedef matrix<bool, 3, 3> bool3x3;
+typedef matrix<bool, 3, 4> bool3x4;
+typedef matrix<bool, 4, 1> bool4x1;
+typedef matrix<bool, 4, 2> bool4x2;
+typedef matrix<bool, 4, 3> bool4x3;
+typedef matrix<bool, 4, 4> bool4x4;
+
 typedef matrix<int, 1, 1> int1x1;
 typedef matrix<int, 1, 2> int1x2;
 typedef matrix<int, 1, 3> int1x3;
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 02c838bc4a862..bdf2e08400801 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2145,7 +2145,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
   switch (ArgTyRestr) {
   case Sema::EltwiseBuiltinArgTyRestriction::None:
     if (!ArgTy->getAs<VectorType>() &&
-        !ConstantMatrixType::isValidElementType(ArgTy)) {
+        !ConstantMatrixType::isValidElementType(ArgTy, S.getLangOpts())) {
       return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
              << ArgOrdinal << /* vector */ 2 << /* integer */ 1 << /* fp */ 1
              << ArgTy;
@@ -16545,7 +16545,7 @@ ExprResult Sema::BuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
   } else {
     ElementTy = PtrTy->getPointeeType().getUnqualifiedType();
 
-    if (!ConstantMatrixType::isValidElementType(ElementTy)) {
+    if (!ConstantMatrixType::isValidElementType(ElementTy, getLangOpts())) {
       Diag(PtrExpr->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << PtrArgIdx + 1 << 0 << /* pointer to element ty */ 5
           << /* no fp */ 0 << PtrExpr->getType();
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index fd64d4456cbfa..7ef83433326ed 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2467,7 +2467,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 
   // Check element type, if it is not dependent.
   if (!ElementTy->isDependentType() &&
-      !MatrixType::isValidElementType(ElementTy)) {
+      !MatrixType::isValidElementType(ElementTy, getLangOpts())) {
     Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
     return QualType();
   }
diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
new file mode 100644
index 0000000000000..da90738b68b96
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
@@ -0,0 +1,151 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+
+struct S {
+    bool2x2 bM;
+    float f;
+};
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn1v(
+// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[B:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[B]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[B]], align 4
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn1() {
+  bool2x2 B = {true,true,true,true};
+  return B[0][0];
+}
+
+// CHECK-LABEL: define hidden noundef <4 x i1> @_Z3fn2b(
+// CHECK-SAME: i1 noundef [[V:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <4 x i1>, align 4
+// CHECK-NEXT:    [[V_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[A:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[V]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[VECINIT:%.*]] = insertelement <4 x i1> poison, i1 [[LOADEDV]], i32 0
+// CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV1:%.*]] = trunc i32 [[TMP1]] to i1
+// CHECK-NEXT:    [[VECINIT2:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 [[LOADEDV1]], i32 1
+// CHECK-NEXT:    [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT2]], i1 true, i32 2
+// CHECK-NEXT:    [[VECINIT4:%.*]] = insertelement <4 x i1> [[VECINIT3]], i1 false, i32 3
+// CHECK-NEXT:    store <4 x i1> [[VECINIT4]], ptr [[A]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, ptr [[A]], align 4
+// CHECK-NEXT:    store <4 x i32> [[TMP2]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret <4 x i1> [[TMP3]]
+//
+bool2x2 fn2(bool V) {
+  bool2x2 A = {V, true, V, false};
+  return A;
+}
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn3v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr [[BM]], align 1
+// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
+// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn3() {
+  S s = {{true,true, false, false}, 1.0};
+  return s.bM[0][0];
+}
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn4v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1
+// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4
+// CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 1
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn4() {
+  bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
+  return Arr[0][1][0];
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn5v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[M]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[M]], align 4
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 3
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void fn5() {
+  bool2x2 M = {true,true,true,true};
+  M[1][1] = false;
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn6v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[V:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    store i32 0, ptr [[V]], align 4
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr [[BM]], align 1
+// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
+// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[TMP2:%.*]] = zext i1 [[LOADEDV]] to i32
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[TMP2]], i32 1
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[BM1]], align 1
+// CHECK-NEXT:    ret void
+//
+void fn6() {
+  bool V = false;
+  S s = {{true,true,false,false}, 1.0};
+  s.bM[1][0] = V;
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn7v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1
+// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4
+// CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 1
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    ret void
+//
+void fn7() {
+  bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
+  Arr[0][1][0] = false;
+}
diff --git a/clang/test/CodeGenHLSL/basic_types.hlsl b/clang/test/CodeGenHLSL/basic_types.hlsl
index 8836126934957..677a9a8f5d1de 100644
--- a/clang/test/CodeGenHLSL/basic_types.hlsl
+++ b/clang/test/CodeGenHLSL/basic_types.hlsl
@@ -38,6 +38,22 @@
 // CHECK: @double2_Val = external hidden addrspace(2) global <2 x double>, align 16
 // CHECK: @double3_Val = external hidden addrspace(2) global <3 x double>, align 32
 // CHECK: @double4_Val = external hidden addrspace(2) global <4 x double>, align 32
+// CHECK: @bool1x1_Val = external hidden addrspace(2) global [1 x i32], align 4
+// CHECK: @bool1x2_Val = external hidden addrspace(2) global [2 x i32], align 4
+// CHECK: @bool1x3_Val = external hidden addrspace(2) global [3 x i32], align 4
+// CHECK: @bool1x4_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool2x1_Val = external hidden addrspace(2) global [2 x i32], align 4
+// CHECK: @bool2x2_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool2x3_Val = external hidden addrspace(2) global [6 x i32], align 4
+// CHECK: @bool2x4_Val = external hidden addrspace(2) global [8 x i32], align 4
+// CHECK: @bool3x1_Val = external hidden addrspace(2) global [3 x i32], align 4
+// CHECK: @bool3x2_Val = external hidden addrspace(2) global [6 x i32], align 4
+// CHECK: @bool3x3_Val = external hidden addrspace(2) global [9 x i32], align 4
+// CHECK: @bool3x4_Val = external hidden addrspace(2) global [12 x i32], align 4
+// CHECK: @bool4x1_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool4x2_Val = external hidden addrspace(2) global [8 x i32], align 4
+// CHECK: @bool4x3_Val = external hidden addrspace(2) global [12 x i32], align 4
+// CHECK: @bool4x4_Val = external hidden addrspace(2) global [16 x i32], align 4
 
 #ifdef NAMESPACED
 #define TYPE_DECL(T)  hlsl::T T##_Val
@@ -93,3 +109,20 @@ TYPE_DECL( float4  );
 TYPE_DECL( double2 );
 TYPE_DECL( double3 );
 TYPE_DECL( double4 );
+
+TYPE_DECL( bool1x1 );
+TYPE_DECL( bool1x2 );
+TYPE_DECL( bool1x3 );
+TYPE_DECL( bool1x4 );
+TYPE_DECL( bool2x1 );
+TYPE_DECL( bool2x2 );
+TYPE_DECL( bool2x3 );
+TYPE_DECL( bool2x4 );
+TYPE_DECL( bool3x1 );
+TYPE_DECL( bool3x2 );
+TYPE_DECL( bool3x3 );
+TYPE_DECL( bool3x4 );
+TYPE_DECL( bool4x1 );
+TYPE_DECL( bool4x2 );
+TYPE_DECL( bool4x3 );
+TYPE_DECL( bool4x4 );

@llvmbot
Copy link
Member

llvmbot commented Dec 7, 2025

@llvm/pr-subscribers-clang-codegen

Author: Farzon Lotfi (farzonl)

Changes

fixes #171049
fixes #171050

  • Allow Bools for matrix type when in HLSL mode
  • use ConvertTypeForMem to figure out the bool size
  • Add Bool matrix types to hlsl_basic_types.h

Full diff: https://github.com/llvm/llvm-project/pull/171051.diff

9 Files Affected:

  • (modified) clang/include/clang/AST/TypeBase.h (+27-4)
  • (modified) clang/lib/AST/ASTContext.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+6-1)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+4-1)
  • (modified) clang/lib/Headers/hlsl/hlsl_basic_types.h (+17)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+2-2)
  • (modified) clang/lib/Sema/SemaType.cpp (+1-1)
  • (added) clang/test/CodeGenHLSL/BoolMatrix.hlsl (+151)
  • (modified) clang/test/CodeGenHLSL/basic_types.hlsl (+33)
diff --git a/clang/include/clang/AST/TypeBase.h b/clang/include/clang/AST/TypeBase.h
index 30b9efe5a31b7..cf6897b6e515c 100644
--- a/clang/include/clang/AST/TypeBase.h
+++ b/clang/include/clang/AST/TypeBase.h
@@ -2637,6 +2637,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
   bool isExtVectorBoolType() const;             // Extended vector type with bool element.
+  bool isConstantMatrixBoolType() const; // Matrix type with bool element.
   // Extended vector type with bool element that is packed. HLSL doesn't pack
   // its bool vectors.
   bool isPackedVectorBoolType(const ASTContext &ctx) const;
@@ -4352,12 +4353,26 @@ class MatrixType : public Type, public llvm::FoldingSetNode {
 
   /// Valid elements types are the following:
   /// * an integer type (as in C23 6.2.5p22), but excluding enumerated types
-  ///   and _Bool
+  ///   and _Bool (except that in HLSL, bool is allowed)
   /// * the standard floating types float or double
   /// * a half-precision floating point type, if one is supported on the target
-  static bool isValidElementType(QualType T) {
-    return T->isDependentType() ||
-           (T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
+  static bool isValidElementType(QualType T, const LangOptions &LangOpts) {
+    // Dependent is always okay
+    if (T->isDependentType())
+      return true;
+
+    // Enums are never okay
+    if (T->isEnumeralType())
+      return false;
+
+    // In HLSL, bool is allowed as a matrix element type.
+    // Note: isRealType includes bool so don't need to check
+    if (LangOpts.HLSL)
+      return T->isRealType();
+
+    // In non-HLSL modes, follow the existing rule:
+    // real type, but not _Bool.
+    return T->isRealType() && !T->isBooleanType();
   }
 
   bool isSugared() const { return false; }
@@ -8665,6 +8680,14 @@ inline bool Type::isExtVectorBoolType() const {
   return cast<ExtVectorType>(CanonicalType)->getElementType()->isBooleanType();
 }
 
+inline bool Type::isConstantMatrixBoolType() const {
+  if (!isConstantMatrixType())
+    return false;
+  return cast<ConstantMatrixType>(CanonicalType)
+      ->getElementType()
+      ->isBooleanType();
+}
+
 inline bool Type::isSubscriptableVectorType() const {
   return isVectorType() || isSveVLSBuiltinType();
 }
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 404ce3ffd77c7..5ca76c79df7c6 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -4712,7 +4712,7 @@ QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
   ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
                               Type::ConstantMatrix);
 
-  assert(MatrixType::isValidElementType(ElementTy) &&
+  assert(MatrixType::isValidElementType(ElementTy, getLangOpts()) &&
          "need a valid element type");
   assert(NumRows > 0 && NumRows <= LangOpts.MaxMatrixDimension &&
          NumColumns > 0 && NumColumns <= LangOpts.MaxMatrixDimension &&
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 3bde8e1fa2ac3..b44c1a7d18120 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2655,8 +2655,13 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
         MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened());
       }
       llvm::Instruction *Load = Builder.CreateLoad(Dst.getMatrixAddress());
+      llvm::Value *InsertVal = Src.getScalarVal();
+      if (getLangOpts().HLSL && InsertVal->getType()->isIntegerTy(1)) {
+        llvm::Type *StorageElmTy = Load->getType()->getScalarType();
+        InsertVal = Builder.CreateZExt(InsertVal, StorageElmTy);
+      }
       llvm::Value *Vec =
-          Builder.CreateInsertElement(Load, Src.getScalarVal(), Idx, "matins");
+          Builder.CreateInsertElement(Load, InsertVal, Idx, "matins");
       auto *I = Builder.CreateStore(Vec, Dst.getMatrixAddress(),
                                     Dst.isVolatileQualified());
       addInstToCurrentSourceAtom(I, Vec);
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index be862cf07f177..a41bf86d6f95c 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -104,7 +104,10 @@ llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
   if (T->isConstantMatrixType()) {
     const Type *Ty = Context.getCanonicalType(T).getTypePtr();
     const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
-    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+    llvm::Type *IRElemTy = ConvertType(MT->getElementType());
+    if (T->isConstantMatrixBoolType() && Context.getLangOpts().HLSL)
+      IRElemTy = ConvertTypeForMem(Context.BoolTy);
+    return llvm::ArrayType::get(IRElemTy,
                                 MT->getNumRows() * MT->getNumColumns());
   }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index fc1e265067714..b1d87c51de9bb 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -150,6 +150,23 @@ typedef matrix<uint16_t, 4, 3> uint16_t4x3;
 typedef matrix<uint16_t, 4, 4> uint16_t4x4;
 #endif
 
+typedef matrix<bool, 1, 1> bool1x1;
+typedef matrix<bool, 1, 2> bool1x2;
+typedef matrix<bool, 1, 3> bool1x3;
+typedef matrix<bool, 1, 4> bool1x4;
+typedef matrix<bool, 2, 1> bool2x1;
+typedef matrix<bool, 2, 2> bool2x2;
+typedef matrix<bool, 2, 3> bool2x3;
+typedef matrix<bool, 2, 4> bool2x4;
+typedef matrix<bool, 3, 1> bool3x1;
+typedef matrix<bool, 3, 2> bool3x2;
+typedef matrix<bool, 3, 3> bool3x3;
+typedef matrix<bool, 3, 4> bool3x4;
+typedef matrix<bool, 4, 1> bool4x1;
+typedef matrix<bool, 4, 2> bool4x2;
+typedef matrix<bool, 4, 3> bool4x3;
+typedef matrix<bool, 4, 4> bool4x4;
+
 typedef matrix<int, 1, 1> int1x1;
 typedef matrix<int, 1, 2> int1x2;
 typedef matrix<int, 1, 3> int1x3;
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 02c838bc4a862..bdf2e08400801 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2145,7 +2145,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
   switch (ArgTyRestr) {
   case Sema::EltwiseBuiltinArgTyRestriction::None:
     if (!ArgTy->getAs<VectorType>() &&
-        !ConstantMatrixType::isValidElementType(ArgTy)) {
+        !ConstantMatrixType::isValidElementType(ArgTy, S.getLangOpts())) {
       return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
              << ArgOrdinal << /* vector */ 2 << /* integer */ 1 << /* fp */ 1
              << ArgTy;
@@ -16545,7 +16545,7 @@ ExprResult Sema::BuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
   } else {
     ElementTy = PtrTy->getPointeeType().getUnqualifiedType();
 
-    if (!ConstantMatrixType::isValidElementType(ElementTy)) {
+    if (!ConstantMatrixType::isValidElementType(ElementTy, getLangOpts())) {
       Diag(PtrExpr->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << PtrArgIdx + 1 << 0 << /* pointer to element ty */ 5
           << /* no fp */ 0 << PtrExpr->getType();
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index fd64d4456cbfa..7ef83433326ed 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2467,7 +2467,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 
   // Check element type, if it is not dependent.
   if (!ElementTy->isDependentType() &&
-      !MatrixType::isValidElementType(ElementTy)) {
+      !MatrixType::isValidElementType(ElementTy, getLangOpts())) {
     Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
     return QualType();
   }
diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
new file mode 100644
index 0000000000000..da90738b68b96
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
@@ -0,0 +1,151 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+
+struct S {
+    bool2x2 bM;
+    float f;
+};
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn1v(
+// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[B:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[B]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[B]], align 4
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn1() {
+  bool2x2 B = {true,true,true,true};
+  return B[0][0];
+}
+
+// CHECK-LABEL: define hidden noundef <4 x i1> @_Z3fn2b(
+// CHECK-SAME: i1 noundef [[V:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <4 x i1>, align 4
+// CHECK-NEXT:    [[V_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[A:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[V]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[VECINIT:%.*]] = insertelement <4 x i1> poison, i1 [[LOADEDV]], i32 0
+// CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV1:%.*]] = trunc i32 [[TMP1]] to i1
+// CHECK-NEXT:    [[VECINIT2:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 [[LOADEDV1]], i32 1
+// CHECK-NEXT:    [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT2]], i1 true, i32 2
+// CHECK-NEXT:    [[VECINIT4:%.*]] = insertelement <4 x i1> [[VECINIT3]], i1 false, i32 3
+// CHECK-NEXT:    store <4 x i1> [[VECINIT4]], ptr [[A]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, ptr [[A]], align 4
+// CHECK-NEXT:    store <4 x i32> [[TMP2]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret <4 x i1> [[TMP3]]
+//
+bool2x2 fn2(bool V) {
+  bool2x2 A = {V, true, V, false};
+  return A;
+}
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn3v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr [[BM]], align 1
+// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
+// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn3() {
+  S s = {{true,true, false, false}, 1.0};
+  return s.bM[0][0];
+}
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn4v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1
+// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4
+// CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 1
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn4() {
+  bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
+  return Arr[0][1][0];
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn5v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[M]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[M]], align 4
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 3
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void fn5() {
+  bool2x2 M = {true,true,true,true};
+  M[1][1] = false;
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn6v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[V:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    store i32 0, ptr [[V]], align 4
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr [[BM]], align 1
+// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
+// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[TMP2:%.*]] = zext i1 [[LOADEDV]] to i32
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[TMP2]], i32 1
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[BM1]], align 1
+// CHECK-NEXT:    ret void
+//
+void fn6() {
+  bool V = false;
+  S s = {{true,true,false,false}, 1.0};
+  s.bM[1][0] = V;
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn7v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1
+// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4
+// CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 1
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    ret void
+//
+void fn7() {
+  bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
+  Arr[0][1][0] = false;
+}
diff --git a/clang/test/CodeGenHLSL/basic_types.hlsl b/clang/test/CodeGenHLSL/basic_types.hlsl
index 8836126934957..677a9a8f5d1de 100644
--- a/clang/test/CodeGenHLSL/basic_types.hlsl
+++ b/clang/test/CodeGenHLSL/basic_types.hlsl
@@ -38,6 +38,22 @@
 // CHECK: @double2_Val = external hidden addrspace(2) global <2 x double>, align 16
 // CHECK: @double3_Val = external hidden addrspace(2) global <3 x double>, align 32
 // CHECK: @double4_Val = external hidden addrspace(2) global <4 x double>, align 32
+// CHECK: @bool1x1_Val = external hidden addrspace(2) global [1 x i32], align 4
+// CHECK: @bool1x2_Val = external hidden addrspace(2) global [2 x i32], align 4
+// CHECK: @bool1x3_Val = external hidden addrspace(2) global [3 x i32], align 4
+// CHECK: @bool1x4_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool2x1_Val = external hidden addrspace(2) global [2 x i32], align 4
+// CHECK: @bool2x2_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool2x3_Val = external hidden addrspace(2) global [6 x i32], align 4
+// CHECK: @bool2x4_Val = external hidden addrspace(2) global [8 x i32], align 4
+// CHECK: @bool3x1_Val = external hidden addrspace(2) global [3 x i32], align 4
+// CHECK: @bool3x2_Val = external hidden addrspace(2) global [6 x i32], align 4
+// CHECK: @bool3x3_Val = external hidden addrspace(2) global [9 x i32], align 4
+// CHECK: @bool3x4_Val = external hidden addrspace(2) global [12 x i32], align 4
+// CHECK: @bool4x1_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool4x2_Val = external hidden addrspace(2) global [8 x i32], align 4
+// CHECK: @bool4x3_Val = external hidden addrspace(2) global [12 x i32], align 4
+// CHECK: @bool4x4_Val = external hidden addrspace(2) global [16 x i32], align 4
 
 #ifdef NAMESPACED
 #define TYPE_DECL(T)  hlsl::T T##_Val
@@ -93,3 +109,20 @@ TYPE_DECL( float4  );
 TYPE_DECL( double2 );
 TYPE_DECL( double3 );
 TYPE_DECL( double4 );
+
+TYPE_DECL( bool1x1 );
+TYPE_DECL( bool1x2 );
+TYPE_DECL( bool1x3 );
+TYPE_DECL( bool1x4 );
+TYPE_DECL( bool2x1 );
+TYPE_DECL( bool2x2 );
+TYPE_DECL( bool2x3 );
+TYPE_DECL( bool2x4 );
+TYPE_DECL( bool3x1 );
+TYPE_DECL( bool3x2 );
+TYPE_DECL( bool3x3 );
+TYPE_DECL( bool3x4 );
+TYPE_DECL( bool4x1 );
+TYPE_DECL( bool4x2 );
+TYPE_DECL( bool4x3 );
+TYPE_DECL( bool4x4 );

@llvmbot
Copy link
Member

llvmbot commented Dec 7, 2025

@llvm/pr-subscribers-backend-x86

Author: Farzon Lotfi (farzonl)

Changes

fixes #171049
fixes #171050

  • Allow Bools for matrix type when in HLSL mode
  • use ConvertTypeForMem to figure out the bool size
  • Add Bool matrix types to hlsl_basic_types.h

Full diff: https://github.com/llvm/llvm-project/pull/171051.diff

9 Files Affected:

  • (modified) clang/include/clang/AST/TypeBase.h (+27-4)
  • (modified) clang/lib/AST/ASTContext.cpp (+1-1)
  • (modified) clang/lib/CodeGen/CGExpr.cpp (+6-1)
  • (modified) clang/lib/CodeGen/CodeGenTypes.cpp (+4-1)
  • (modified) clang/lib/Headers/hlsl/hlsl_basic_types.h (+17)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+2-2)
  • (modified) clang/lib/Sema/SemaType.cpp (+1-1)
  • (added) clang/test/CodeGenHLSL/BoolMatrix.hlsl (+151)
  • (modified) clang/test/CodeGenHLSL/basic_types.hlsl (+33)
diff --git a/clang/include/clang/AST/TypeBase.h b/clang/include/clang/AST/TypeBase.h
index 30b9efe5a31b7..cf6897b6e515c 100644
--- a/clang/include/clang/AST/TypeBase.h
+++ b/clang/include/clang/AST/TypeBase.h
@@ -2637,6 +2637,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
   bool isVectorType() const;                    // GCC vector type.
   bool isExtVectorType() const;                 // Extended vector type.
   bool isExtVectorBoolType() const;             // Extended vector type with bool element.
+  bool isConstantMatrixBoolType() const; // Matrix type with bool element.
   // Extended vector type with bool element that is packed. HLSL doesn't pack
   // its bool vectors.
   bool isPackedVectorBoolType(const ASTContext &ctx) const;
@@ -4352,12 +4353,26 @@ class MatrixType : public Type, public llvm::FoldingSetNode {
 
   /// Valid elements types are the following:
   /// * an integer type (as in C23 6.2.5p22), but excluding enumerated types
-  ///   and _Bool
+  ///   and _Bool (except that in HLSL, bool is allowed)
   /// * the standard floating types float or double
   /// * a half-precision floating point type, if one is supported on the target
-  static bool isValidElementType(QualType T) {
-    return T->isDependentType() ||
-           (T->isRealType() && !T->isBooleanType() && !T->isEnumeralType());
+  static bool isValidElementType(QualType T, const LangOptions &LangOpts) {
+    // Dependent is always okay
+    if (T->isDependentType())
+      return true;
+
+    // Enums are never okay
+    if (T->isEnumeralType())
+      return false;
+
+    // In HLSL, bool is allowed as a matrix element type.
+    // Note: isRealType includes bool so don't need to check
+    if (LangOpts.HLSL)
+      return T->isRealType();
+
+    // In non-HLSL modes, follow the existing rule:
+    // real type, but not _Bool.
+    return T->isRealType() && !T->isBooleanType();
   }
 
   bool isSugared() const { return false; }
@@ -8665,6 +8680,14 @@ inline bool Type::isExtVectorBoolType() const {
   return cast<ExtVectorType>(CanonicalType)->getElementType()->isBooleanType();
 }
 
+inline bool Type::isConstantMatrixBoolType() const {
+  if (!isConstantMatrixType())
+    return false;
+  return cast<ConstantMatrixType>(CanonicalType)
+      ->getElementType()
+      ->isBooleanType();
+}
+
 inline bool Type::isSubscriptableVectorType() const {
   return isVectorType() || isSveVLSBuiltinType();
 }
diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp
index 404ce3ffd77c7..5ca76c79df7c6 100644
--- a/clang/lib/AST/ASTContext.cpp
+++ b/clang/lib/AST/ASTContext.cpp
@@ -4712,7 +4712,7 @@ QualType ASTContext::getConstantMatrixType(QualType ElementTy, unsigned NumRows,
   ConstantMatrixType::Profile(ID, ElementTy, NumRows, NumColumns,
                               Type::ConstantMatrix);
 
-  assert(MatrixType::isValidElementType(ElementTy) &&
+  assert(MatrixType::isValidElementType(ElementTy, getLangOpts()) &&
          "need a valid element type");
   assert(NumRows > 0 && NumRows <= LangOpts.MaxMatrixDimension &&
          NumColumns > 0 && NumColumns <= LangOpts.MaxMatrixDimension &&
diff --git a/clang/lib/CodeGen/CGExpr.cpp b/clang/lib/CodeGen/CGExpr.cpp
index 3bde8e1fa2ac3..b44c1a7d18120 100644
--- a/clang/lib/CodeGen/CGExpr.cpp
+++ b/clang/lib/CodeGen/CGExpr.cpp
@@ -2655,8 +2655,13 @@ void CodeGenFunction::EmitStoreThroughLValue(RValue Src, LValue Dst,
         MB.CreateIndexAssumption(Idx, MatTy->getNumElementsFlattened());
       }
       llvm::Instruction *Load = Builder.CreateLoad(Dst.getMatrixAddress());
+      llvm::Value *InsertVal = Src.getScalarVal();
+      if (getLangOpts().HLSL && InsertVal->getType()->isIntegerTy(1)) {
+        llvm::Type *StorageElmTy = Load->getType()->getScalarType();
+        InsertVal = Builder.CreateZExt(InsertVal, StorageElmTy);
+      }
       llvm::Value *Vec =
-          Builder.CreateInsertElement(Load, Src.getScalarVal(), Idx, "matins");
+          Builder.CreateInsertElement(Load, InsertVal, Idx, "matins");
       auto *I = Builder.CreateStore(Vec, Dst.getMatrixAddress(),
                                     Dst.isVolatileQualified());
       addInstToCurrentSourceAtom(I, Vec);
diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp
index be862cf07f177..a41bf86d6f95c 100644
--- a/clang/lib/CodeGen/CodeGenTypes.cpp
+++ b/clang/lib/CodeGen/CodeGenTypes.cpp
@@ -104,7 +104,10 @@ llvm::Type *CodeGenTypes::ConvertTypeForMem(QualType T) {
   if (T->isConstantMatrixType()) {
     const Type *Ty = Context.getCanonicalType(T).getTypePtr();
     const ConstantMatrixType *MT = cast<ConstantMatrixType>(Ty);
-    return llvm::ArrayType::get(ConvertType(MT->getElementType()),
+    llvm::Type *IRElemTy = ConvertType(MT->getElementType());
+    if (T->isConstantMatrixBoolType() && Context.getLangOpts().HLSL)
+      IRElemTy = ConvertTypeForMem(Context.BoolTy);
+    return llvm::ArrayType::get(IRElemTy,
                                 MT->getNumRows() * MT->getNumColumns());
   }
 
diff --git a/clang/lib/Headers/hlsl/hlsl_basic_types.h b/clang/lib/Headers/hlsl/hlsl_basic_types.h
index fc1e265067714..b1d87c51de9bb 100644
--- a/clang/lib/Headers/hlsl/hlsl_basic_types.h
+++ b/clang/lib/Headers/hlsl/hlsl_basic_types.h
@@ -150,6 +150,23 @@ typedef matrix<uint16_t, 4, 3> uint16_t4x3;
 typedef matrix<uint16_t, 4, 4> uint16_t4x4;
 #endif
 
+typedef matrix<bool, 1, 1> bool1x1;
+typedef matrix<bool, 1, 2> bool1x2;
+typedef matrix<bool, 1, 3> bool1x3;
+typedef matrix<bool, 1, 4> bool1x4;
+typedef matrix<bool, 2, 1> bool2x1;
+typedef matrix<bool, 2, 2> bool2x2;
+typedef matrix<bool, 2, 3> bool2x3;
+typedef matrix<bool, 2, 4> bool2x4;
+typedef matrix<bool, 3, 1> bool3x1;
+typedef matrix<bool, 3, 2> bool3x2;
+typedef matrix<bool, 3, 3> bool3x3;
+typedef matrix<bool, 3, 4> bool3x4;
+typedef matrix<bool, 4, 1> bool4x1;
+typedef matrix<bool, 4, 2> bool4x2;
+typedef matrix<bool, 4, 3> bool4x3;
+typedef matrix<bool, 4, 4> bool4x4;
+
 typedef matrix<int, 1, 1> int1x1;
 typedef matrix<int, 1, 2> int1x2;
 typedef matrix<int, 1, 3> int1x3;
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 02c838bc4a862..bdf2e08400801 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -2145,7 +2145,7 @@ checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
   switch (ArgTyRestr) {
   case Sema::EltwiseBuiltinArgTyRestriction::None:
     if (!ArgTy->getAs<VectorType>() &&
-        !ConstantMatrixType::isValidElementType(ArgTy)) {
+        !ConstantMatrixType::isValidElementType(ArgTy, S.getLangOpts())) {
       return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
              << ArgOrdinal << /* vector */ 2 << /* integer */ 1 << /* fp */ 1
              << ArgTy;
@@ -16545,7 +16545,7 @@ ExprResult Sema::BuiltinMatrixColumnMajorLoad(CallExpr *TheCall,
   } else {
     ElementTy = PtrTy->getPointeeType().getUnqualifiedType();
 
-    if (!ConstantMatrixType::isValidElementType(ElementTy)) {
+    if (!ConstantMatrixType::isValidElementType(ElementTy, getLangOpts())) {
       Diag(PtrExpr->getBeginLoc(), diag::err_builtin_invalid_arg_type)
           << PtrArgIdx + 1 << 0 << /* pointer to element ty */ 5
           << /* no fp */ 0 << PtrExpr->getType();
diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp
index fd64d4456cbfa..7ef83433326ed 100644
--- a/clang/lib/Sema/SemaType.cpp
+++ b/clang/lib/Sema/SemaType.cpp
@@ -2467,7 +2467,7 @@ QualType Sema::BuildMatrixType(QualType ElementTy, Expr *NumRows, Expr *NumCols,
 
   // Check element type, if it is not dependent.
   if (!ElementTy->isDependentType() &&
-      !MatrixType::isValidElementType(ElementTy)) {
+      !MatrixType::isValidElementType(ElementTy, getLangOpts())) {
     Diag(AttrLoc, diag::err_attribute_invalid_matrix_type) << ElementTy;
     return QualType();
   }
diff --git a/clang/test/CodeGenHLSL/BoolMatrix.hlsl b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
new file mode 100644
index 0000000000000..da90738b68b96
--- /dev/null
+++ b/clang/test/CodeGenHLSL/BoolMatrix.hlsl
@@ -0,0 +1,151 @@
+// NOTE: Assertions have been autogenerated by utils/update_cc_test_checks.py UTC_ARGS: --version 6
+// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.3-library -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s
+
+
+struct S {
+    bool2x2 bM;
+    float f;
+};
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn1v(
+// CHECK-SAME: ) #[[ATTR0:[0-9]+]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[B:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[B]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[B]], align 4
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn1() {
+  bool2x2 B = {true,true,true,true};
+  return B[0][0];
+}
+
+// CHECK-LABEL: define hidden noundef <4 x i1> @_Z3fn2b(
+// CHECK-SAME: i1 noundef [[V:%.*]]) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca <4 x i1>, align 4
+// CHECK-NEXT:    [[V_ADDR:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[A:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    [[STOREDV:%.*]] = zext i1 [[V]] to i32
+// CHECK-NEXT:    store i32 [[STOREDV]], ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[VECINIT:%.*]] = insertelement <4 x i1> poison, i1 [[LOADEDV]], i32 0
+// CHECK-NEXT:    [[TMP1:%.*]] = load i32, ptr [[V_ADDR]], align 4
+// CHECK-NEXT:    [[LOADEDV1:%.*]] = trunc i32 [[TMP1]] to i1
+// CHECK-NEXT:    [[VECINIT2:%.*]] = insertelement <4 x i1> [[VECINIT]], i1 [[LOADEDV1]], i32 1
+// CHECK-NEXT:    [[VECINIT3:%.*]] = insertelement <4 x i1> [[VECINIT2]], i1 true, i32 2
+// CHECK-NEXT:    [[VECINIT4:%.*]] = insertelement <4 x i1> [[VECINIT3]], i1 false, i32 3
+// CHECK-NEXT:    store <4 x i1> [[VECINIT4]], ptr [[A]], align 4
+// CHECK-NEXT:    [[TMP2:%.*]] = load <4 x i32>, ptr [[A]], align 4
+// CHECK-NEXT:    store <4 x i32> [[TMP2]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP3:%.*]] = load <4 x i1>, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret <4 x i1> [[TMP3]]
+//
+bool2x2 fn2(bool V) {
+  bool2x2 A = {V, true, V, false};
+  return A;
+}
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn3v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr [[BM]], align 1
+// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
+// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 0
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn3() {
+  S s = {{true,true, false, false}, 1.0};
+  return s.bM[0][0];
+}
+
+// CHECK-LABEL: define hidden noundef i1 @_Z3fn4v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[RETVAL:%.*]] = alloca i1, align 4
+// CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1
+// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4
+// CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    [[MATRIXEXT:%.*]] = extractelement <4 x i32> [[TMP0]], i32 1
+// CHECK-NEXT:    store i32 [[MATRIXEXT]], ptr [[RETVAL]], align 4
+// CHECK-NEXT:    [[TMP1:%.*]] = load i1, ptr [[RETVAL]], align 4
+// CHECK-NEXT:    ret i1 [[TMP1]]
+//
+bool fn4() {
+  bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
+  return Arr[0][1][0];
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn5v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[M:%.*]] = alloca [4 x i32], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[M]], align 4
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[M]], align 4
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 3
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[M]], align 4
+// CHECK-NEXT:    ret void
+//
+void fn5() {
+  bool2x2 M = {true,true,true,true};
+  M[1][1] = false;
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn6v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[V:%.*]] = alloca i32, align 4
+// CHECK-NEXT:    [[S:%.*]] = alloca [[STRUCT_S:%.*]], align 1
+// CHECK-NEXT:    store i32 0, ptr [[V]], align 4
+// CHECK-NEXT:    [[BM:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    store <4 x i1> <i1 true, i1 false, i1 true, i1 false>, ptr [[BM]], align 1
+// CHECK-NEXT:    [[F:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 1
+// CHECK-NEXT:    store float 1.000000e+00, ptr [[F]], align 1
+// CHECK-NEXT:    [[TMP0:%.*]] = load i32, ptr [[V]], align 4
+// CHECK-NEXT:    [[LOADEDV:%.*]] = trunc i32 [[TMP0]] to i1
+// CHECK-NEXT:    [[BM1:%.*]] = getelementptr inbounds nuw [[STRUCT_S]], ptr [[S]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP1:%.*]] = load <4 x i32>, ptr [[BM1]], align 1
+// CHECK-NEXT:    [[TMP2:%.*]] = zext i1 [[LOADEDV]] to i32
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP1]], i32 [[TMP2]], i32 1
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[BM1]], align 1
+// CHECK-NEXT:    ret void
+//
+void fn6() {
+  bool V = false;
+  S s = {{true,true,false,false}, 1.0};
+  s.bM[1][0] = V;
+}
+
+// CHECK-LABEL: define hidden void @_Z3fn7v(
+// CHECK-SAME: ) #[[ATTR0]] {
+// CHECK-NEXT:  [[ENTRY:.*:]]
+// CHECK-NEXT:    [[ARR:%.*]] = alloca [2 x [4 x i32]], align 4
+// CHECK-NEXT:    store <4 x i1> splat (i1 true), ptr [[ARR]], align 4
+// CHECK-NEXT:    [[ARRAYINIT_ELEMENT:%.*]] = getelementptr inbounds [4 x i32], ptr [[ARR]], i32 1
+// CHECK-NEXT:    store <4 x i1> zeroinitializer, ptr [[ARRAYINIT_ELEMENT]], align 4
+// CHECK-NEXT:    [[ARRAYIDX:%.*]] = getelementptr inbounds [2 x [4 x i32]], ptr [[ARR]], i32 0, i32 0
+// CHECK-NEXT:    [[TMP0:%.*]] = load <4 x i32>, ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    [[MATINS:%.*]] = insertelement <4 x i32> [[TMP0]], i32 0, i32 1
+// CHECK-NEXT:    store <4 x i32> [[MATINS]], ptr [[ARRAYIDX]], align 4
+// CHECK-NEXT:    ret void
+//
+void fn7() {
+  bool2x2 Arr[2] = {{true,true,true,true}, {false,false,false,false}};
+  Arr[0][1][0] = false;
+}
diff --git a/clang/test/CodeGenHLSL/basic_types.hlsl b/clang/test/CodeGenHLSL/basic_types.hlsl
index 8836126934957..677a9a8f5d1de 100644
--- a/clang/test/CodeGenHLSL/basic_types.hlsl
+++ b/clang/test/CodeGenHLSL/basic_types.hlsl
@@ -38,6 +38,22 @@
 // CHECK: @double2_Val = external hidden addrspace(2) global <2 x double>, align 16
 // CHECK: @double3_Val = external hidden addrspace(2) global <3 x double>, align 32
 // CHECK: @double4_Val = external hidden addrspace(2) global <4 x double>, align 32
+// CHECK: @bool1x1_Val = external hidden addrspace(2) global [1 x i32], align 4
+// CHECK: @bool1x2_Val = external hidden addrspace(2) global [2 x i32], align 4
+// CHECK: @bool1x3_Val = external hidden addrspace(2) global [3 x i32], align 4
+// CHECK: @bool1x4_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool2x1_Val = external hidden addrspace(2) global [2 x i32], align 4
+// CHECK: @bool2x2_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool2x3_Val = external hidden addrspace(2) global [6 x i32], align 4
+// CHECK: @bool2x4_Val = external hidden addrspace(2) global [8 x i32], align 4
+// CHECK: @bool3x1_Val = external hidden addrspace(2) global [3 x i32], align 4
+// CHECK: @bool3x2_Val = external hidden addrspace(2) global [6 x i32], align 4
+// CHECK: @bool3x3_Val = external hidden addrspace(2) global [9 x i32], align 4
+// CHECK: @bool3x4_Val = external hidden addrspace(2) global [12 x i32], align 4
+// CHECK: @bool4x1_Val = external hidden addrspace(2) global [4 x i32], align 4
+// CHECK: @bool4x2_Val = external hidden addrspace(2) global [8 x i32], align 4
+// CHECK: @bool4x3_Val = external hidden addrspace(2) global [12 x i32], align 4
+// CHECK: @bool4x4_Val = external hidden addrspace(2) global [16 x i32], align 4
 
 #ifdef NAMESPACED
 #define TYPE_DECL(T)  hlsl::T T##_Val
@@ -93,3 +109,20 @@ TYPE_DECL( float4  );
 TYPE_DECL( double2 );
 TYPE_DECL( double3 );
 TYPE_DECL( double4 );
+
+TYPE_DECL( bool1x1 );
+TYPE_DECL( bool1x2 );
+TYPE_DECL( bool1x3 );
+TYPE_DECL( bool1x4 );
+TYPE_DECL( bool2x1 );
+TYPE_DECL( bool2x2 );
+TYPE_DECL( bool2x3 );
+TYPE_DECL( bool2x4 );
+TYPE_DECL( bool3x1 );
+TYPE_DECL( bool3x2 );
+TYPE_DECL( bool3x3 );
+TYPE_DECL( bool3x4 );
+TYPE_DECL( bool4x1 );
+TYPE_DECL( bool4x2 );
+TYPE_DECL( bool4x3 );
+TYPE_DECL( bool4x4 );

@github-actions
Copy link

github-actions bot commented Dec 7, 2025

🐧 Linux x64 Test Results

  • 111822 tests passed
  • 4485 tests skipped

✅ The build succeeded and all tests passed.

float test_builtin_clamp_bool_type_promotion(bool p0) {
return __builtin_hlsl_elementwise_clamp(p0, p0, p0);
// expected-error@-1 {{1st argument must be a vector, integer or floating-point type (was 'bool')}}
return __builtin_hlsl_elementwise_clamp(p0, p0, p0); // note: should not error
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

backend:X86 clang:codegen IR generation bugs: mangling, exceptions, etc. clang:frontend Language frontend issues, e.g. anything involving "Sema" clang:headers Headers provided by Clang, e.g. for intrinsics clang Clang issues not falling into any other category HLSL HLSL Language Support

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Make sure matrix booleans are 32 bit see Add bool matrix types to clang/lib/Headers/hlsl/hlsl_basic_types.h

2 participants