Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions tools/clang/lib/Headers/hlsl/dx/linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,16 @@ template <ComponentEnum DstTy, ComponentEnum SrcTy, int SrcN> struct DstN {
ComponentTypeTraits<DstTy>::ElementsPerScalar;
};

template <SIZE_TYPE MVal, SIZE_TYPE NVal, bool Transposed> struct DimMN {
static const SIZE_TYPE M = MVal;
static const SIZE_TYPE N = NVal;
};

template <SIZE_TYPE MVal, SIZE_TYPE NVal> struct DimMN<MVal, NVal, true> {
static const SIZE_TYPE M = NVal;
static const SIZE_TYPE N = MVal;
};

} // namespace __detail

template <ComponentEnum ElementType, uint DimA> struct VectorRef {
Expand Down Expand Up @@ -242,8 +252,12 @@ class Matrix {

template <ComponentEnum NewCompTy, MatrixUseEnum NewUse = Use,
bool Transpose = false>
Matrix<NewCompTy, M, N, NewUse, Scope> Cast() {
Matrix<NewCompTy, M, N, NewUse, Scope> Result;
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
Cast() {
Matrix<NewCompTy, __detail::DimMN<M, N, Transpose>::M,
__detail::DimMN<M, N, Transpose>::N, NewUse, Scope>
Result;
__builtin_LinAlg_CopyConvertMatrix(Result.__handle, __handle, Transpose);
return Result;
}
Expand Down
23 changes: 15 additions & 8 deletions tools/clang/test/CodeGenDXIL/hlsl/linalg/api/matrix-class.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@ using namespace dx::linalg;

using MatrixATy = Matrix<ComponentType::F32, 4, 4, MatrixUse::A, MatrixScope::Wave>;
using MatrixBTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::B, MatrixScope::Wave>;
using MatrixBTyInt = Matrix<ComponentType::I32, 4, 4, MatrixUse::B, MatrixScope::Wave>;
using MatrixAccumTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::Accumulator, MatrixScope::Wave>;
using TSMatrixATy = Matrix<ComponentType::F32, 4, 4, MatrixUse::A, MatrixScope::Thread>;
using TSMatrixAccumTy = Matrix<ComponentType::F32, 4, 4, MatrixUse::Accumulator, MatrixScope::Thread>;

using Matrix48TyFloat = Matrix<ComponentType::F32, 4, 8, MatrixUse::A, MatrixScope::Wave>;
using Matrix48TyInt = Matrix<ComponentType::I32, 4, 8, MatrixUse::A, MatrixScope::Wave>;
using Matrix84TyInt = Matrix<ComponentType::I32, 8, 4, MatrixUse::A, MatrixScope::Wave>;


ByteAddressBuffer BAB : register(t0);
RWByteAddressBuffer RWBAB : register(u0);
groupshared float SharedArr[256];
Expand All @@ -34,16 +38,19 @@ void main(uint ID : SV_GroupID)

// Matrix::Cast
//
// CHECK: call %dx.types.LinAlgMatrixC4M4N4U1S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N4U1S1.mC9M4N4U0S1(
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N4U0S1 %[[MATA1]], i1 false)
// CHECK: %[[MAT48F:.*]] = call %dx.types.LinAlgMatrixC9M4N8U0S1 @dx.op.linAlgFillMatrix.mC9M4N8U0S1.f32(
// CHECK-SAME: i32 -2147483636, float 3.000000e+00) ; LinAlgFillMatrix(value)

// CHECK: call %dx.types.LinAlgMatrixC4M4N8U0S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N8U0S1.mC9M4N8U0S1(
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N8U0S1 %[[MAT48F]], i1 false)
// CHECK-SAME: ; LinAlgCopyConvertMatrix(srcMatrix,transpose)
MatrixBTyInt MatBInt1 = MatA1.Cast<ComponentType::I32, MatrixUse::B>();
Matrix48TyFloat Mat48F = Matrix48TyFloat::Splat(3.0f);
Matrix48TyInt Mat48I = Mat48F.Cast<ComponentType::I32>();

// CHECK: call %dx.types.LinAlgMatrixC4M4N4U1S1 @dx.op.linAlgCopyConvertMatrix.mC4M4N4U1S1.mC9M4N4U1S1(
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N4U1S1 %[[MATB1]], i1 true)
// CHECK: call %dx.types.LinAlgMatrixC4M8N4U0S1 @dx.op.linAlgCopyConvertMatrix.mC4M8N4U0S1.mC9M4N8U0S1(
// CHECK-SAME: i32 -2147483635, %dx.types.LinAlgMatrixC9M4N8U0S1 %[[MAT48F]], i1 true)
// CHECK-SAME: ; LinAlgCopyConvertMatrix(srcMatrix,transpose)
MatrixBTyInt MatBInt2;
MatBInt2 = MatB1.Cast<ComponentType::I32, MatrixUse::B, true>();
Matrix84TyInt Mat84I = Mat48F.Cast<ComponentType::I32, MatrixUse::A, true>();

// Matrix::Load from ByteAddressBuffer
//
Expand Down
Loading