diff --git a/sycl/include/CL/__spirv/spirv_types.hpp b/sycl/include/CL/__spirv/spirv_types.hpp index cae0f38c11748..4fb0fa8886796 100644 --- a/sycl/include/CL/__spirv/spirv_types.hpp +++ b/sycl/include/CL/__spirv/spirv_types.hpp @@ -108,11 +108,32 @@ enum class GroupOperation : uint32_t { ExclusiveScan = 2 }; -enum class MatrixLayout { RowMajor, ColumnMajor, PackedA, PackedB }; +enum class MatrixLayout : uint32_t { + RowMajor = 0, + ColumnMajor = 1, + PackedA = 2, + PackedB = 3 +}; +// TODO: replace the following W/A with a better solution when we have it. +// The following structure is used to represent the joint matrix type in the +// LLVM IR. The structure has a pointer to a multidimensional array member which +// makes the encoding of the matrix type information within the LLVM IR looks +// like this: +// %struct.__spirv_JointMatrixINTEL = type { [42 x [6 x [2 x [1 x float]]]]* } +// Note that an array cannot be of zero size but MatrixLayout and Scope +// parameters can; hence '+ 1' is added to the 3rd and 4th dimensions. +// In general, representing a matrix type information like this is a bit odd +// (especially for MatrixLayout and Scope parameters). But with the current +// tools we have in Clang, this is the only way to preserve and communicate this +// information to SPIRV translator. +// The long term solution would be to introduce a matrix type in Clang and use +// it instead of this member. template -struct __spirv_JointMatrixINTEL; +struct __spirv_JointMatrixINTEL { + T (*Value)[R][C][static_cast(U) + 1][static_cast(S) + 1]; +}; } // namespace __spv diff --git a/sycl/test/matrix/matrix-int8-test.cpp b/sycl/test/matrix/matrix-int8-test.cpp index a1fce823e12a1..ba4327a8e6cd5 100644 --- a/sycl/test/matrix/matrix-int8-test.cpp +++ b/sycl/test/matrix/matrix-int8-test.cpp @@ -1,5 +1,9 @@ -// RUN: %clangxx -fsycl -O2 %s -o %t.out -// XFAIL: * +// RUN: %clangxx -fsycl -fsycl-device-only -O2 -S -emit-llvm -o - %s | FileCheck %s + +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL" = type { [12 x [48 x [1 x [4 x i8]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [12 x [12 x [1 x [4 x i32]]]] addrspace(4)* } +// CHECK-DAG: %"struct.__spv::__spirv_JointMatrixINTEL.[[#]]" = type { [48 x [12 x [4 x [4 x i8]]]] addrspace(4)* } + #include #if (SYCL_EXT_ONEAPI_MATRIX == 2) #include