sycl_ext_oneapi_matrix

Notice

Copyright (c) 2021-2023 Intel Corporation. All rights reserved.

Khronos® is a registered trademark and SYCL™ and SPIR™ are trademarks of The Khronos Group Inc. OpenCL™ is a trademark of Apple Inc. used by permission by Khronos.

Contact

To report problems with this extension, please open a new issue at: https://github.com/intel/llvm/issues

Dependencies

This extension is written against the SYCL 2020 revision 6 specification. All references below to the "core SYCL specification" or to section numbers in the SYCL specification refer to that revision.

Status

This is an experimental extension specification, intended to provide early access to features and gather community feedback. Interfaces defined in this specification are implemented in DPC++, but they are not finalized and may change incompatibly in future versions of DPC++ without prior notice. Shipping software products should not rely on APIs defined in this specification.

Backend support status

This extension is currently implemented in DPC++ only for devices that contain matrix hardware, specifically Intel® Advanced Matrix Extensions (Intel® AMX), Intel® Xe Matrix Extensions (Intel® XMX), Nvidia® Tensor Cores and AMD Matrix Cores®.

The joint_matrix type and the joint_matrix_mad function are optional kernel features as defined in section 5.7 of the core SYCL specification. Each device supports only certain values for the M, N, and K template parameters and only certain types for the Ta, Tb, and Tc template parameters. Applications can use the query API in matrix_params or get_info<ext::oneapi::experimental::info::device::matrix_combinations> to determine the set of legal parameters for each device. If the application submits a kernel using an unsupported joint_matrix type or calls joint_matrix_mad with an unsupported combination, the implementation throws a synchronous exception with the errc::kernel_not_supported error code as described in section 5.7.
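For example, an application that must also run on devices without matrix hardware can guard its kernel submission and fall back to another code path. The following is a minimal sketch, assuming the queue q, the nd_range ndr, and a suitable fallback are defined by the application:

try {
  q.submit([&](sycl::handler &cgh) {
    cgh.parallel_for(ndr, [=](sycl::nd_item<2> it) {
      // kernel body that declares a joint_matrix or calls joint_matrix_mad
    });
  });
} catch (const sycl::exception &e) {
  if (e.code() == sycl::errc::kernel_not_supported) {
    // fall back to a code path that does not use joint_matrix
  }
}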

Overview

Joint matrix is a SYCL extension for matrix hardware programming. It unifies targets like Intel AMX in CPUs, Intel XMX in Intel GPUs, Nvidia Tensor Cores and AMD Matrix Cores®. This provides a portable and performant API for users who want to build their own neural network applications, perform custom optimizations, or experiment with new operations in a timely and performant manner.

Specification

Feature test macro

This extension provides a feature-test macro as described in the core SYCL specification. An implementation supporting this extension must predefine the macro SYCL_EXT_ONEAPI_MATRIX to one of the values defined in the table below. Applications can test for the existence of this macro to determine if the implementation supports this feature, or applications can test the macro’s value to determine which of the extension’s features the implementation supports.

Value | Description
1 | The APIs of this experimental extension are not versioned, so the feature-test macro always has this value.
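For example, an application can guard its use of the extension as follows; the fallback branch is a placeholder:

#ifdef SYCL_EXT_ONEAPI_MATRIX
// joint_matrix APIs may be used, subject to the device queries described below
#else
// fall back to a code path that does not use matrix hardware
#endif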

New joint_matrix class

This extension adds a new class named joint_matrix, which represents a small 2-dimensional matrix that supports native operations in hardware. There are a number of template parameters, namely the group scope, the type of the elements, the matrix use, the shape, and the memory layout of the matrix. This results in the following description:

namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename T, use Use, size_t Rows, size_t Cols,
          layout Layout = (Use == use::accumulator) ?
                          layout::dynamic : /*unspecified*/ >
struct joint_matrix {
  joint_matrix();
  joint_matrix(const joint_matrix &) = delete;
  joint_matrix &operator=(const joint_matrix &) = delete;
};

} // namespace sycl::ext::oneapi::experimental::matrix

The constructor for the joint_matrix type is a group function as defined in section 4.17.3 of the core SYCL specification. It must be encountered in converged control flow by all work-items in the Group.

Group Memory Scope

Most operations on the joint_matrix are group functions, meaning that all work items in a group collectively perform an operation on the same matrix. The Group template parameter specifies the execution scope of the work-items in the group. The joint_matrix is shared among the work items in the group and is not private to each work item. This extension currently supports only the sub-group scope, so the Group template parameter must be sycl::sub_group, and group operations for the joint matrix must be done collectively by the work-items in a single sub-group. In this case, a matrix is declared as follows:

joint_matrix<sub_group, int8_t, use::a, tM, tN, layout::row_major> tA;

Element Type

The T template parameter specifies the type of each element in the matrix. Each device supports only certain element types, so the application must ensure that the element type is supported on the device where the kernel using this joint_matrix runs. The query functions (defined below) may be used to determine the set of element types that are supported on a device.

Matrix Use

The main operation performed by the matrix hardware is D = C + A*B. The Use template parameter specifies which of these terms (A, B, C, or D) corresponds to the joint_matrix object. The use enumeration defines the set of legal values. The A matrix must have the value use::a. The B matrix must have the value use::b. The C and D matrices must both have the value use::accumulator. This is used by backend implementations to reason about the layout of the matrix in registers.

namespace sycl::ext::oneapi::experimental::matrix {

enum class use {
  a,
  b,
  accumulator
};

} // namespace sycl::ext::oneapi::experimental::matrix

Matrix Shape

The Rows and Cols template parameters provide the number of rows and columns in the joint matrix. Each device supports only certain combinations of row and column sizes, so the application must ensure that the combination is supported on the device where the kernel using this joint_matrix runs. The query functions (defined below) may be used to determine the set of combinations that are supported on a device.

Matrix Layout

The Layout template parameter specifies the memory layout of the matrix, using one of the values in the layout enumeration. The A and B matrices can be either layout::row_major or layout::col_major (but not layout::dynamic). The C and D matrices must be layout::dynamic.

namespace sycl::ext::oneapi::experimental::matrix {

enum class layout {
  row_major,
  col_major,
  dynamic
};

} // namespace sycl::ext::oneapi::experimental::matrix

Note that the Layout template parameter defaults to layout::dynamic when Use is use::accumulator, so applications need not specify this template parameter for the C or D matrices, and it is invalid to specify any other value for Layout. When Use has any other value, there is no default for Layout, and the application must specify one explicitly.
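For example, both declarations below are valid; the accumulator relies on the layout::dynamic default, while the A matrix must name its layout explicitly (tM, tN, and tK are assumed tile-size constants):

joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
joint_matrix<sub_group, float, use::accumulator, tM, tN> tC; // Layout is layout::dynamic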

Collective matrix operations

The following operations (load, store, multiply-and-add, fill, and element-wise operations) are group functions as defined in section 4.17.3 of the core SYCL specification. As such, they must be encountered in converged control flow by the work-items in the group that performs the group operation.

Load

namespace sycl::ext::oneapi::experimental::matrix {

// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
          size_t Rows, size_t Cols,
          access::address_space Space, access::decorated IsDecorated>
void joint_matrix_load(Group g,
    joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
    multi_ptr<T2, Space, IsDecorated> src, size_t stride, layout Layout);

// Only available when Layout != layout::dynamic
// and when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
          size_t Rows, size_t Cols,
          use Use, layout Layout,
          access::address_space Space, access::decorated IsDecorated>
void joint_matrix_load(Group g,
    joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
    multi_ptr<T2, Space, IsDecorated> src, size_t stride);

// Only available when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
          size_t Rows, size_t Cols,
          typename PropertyListT>
void joint_matrix_load(Group g,
    joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
    annotated_ptr<T2, PropertyListT> src, size_t stride, layout Layout);

// Only available when Layout != layout::dynamic
// and when std::is_same_v<T1, std::remove_const_t<T2>>
template <typename Group, typename T1, typename T2,
          size_t Rows, size_t Cols, use Use, layout Layout,
          typename PropertyListT>
void joint_matrix_load(Group g,
    joint_matrix<Group, T1, Use, Rows, Cols, Layout> &res,
    annotated_ptr<T2, PropertyListT> src, size_t stride);

} // namespace sycl::ext::oneapi::experimental::matrix

joint_matrix_load loads data from memory into the registers of the matrix hardware. We define two overloads of the load function, depending on whether the memory layout was declared as part of the joint_matrix type. The first overload, which takes the memory layout as an argument, is only available for a joint_matrix type that uses the default value layout::dynamic. The second overload, which omits the memory layout, must not be used with a joint_matrix type that has layout::dynamic.

The base pointer src determines the starting address of the matrix to be loaded. Layout determines whether the data is read in row-major (row_major) or column-major (col_major) fashion. stride describes the number of elements between consecutive rows for the row-major layout, or between consecutive columns for the column-major layout.
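To illustrate the two overloads, the following sketch loads an accumulator declared with layout::dynamic, passing the layout at runtime, and then an A matrix declared with a static layout, where no layout argument is allowed (sg, pA, pC, tM, tN, tK, N, and K are assumed to be defined as in the full example later in this document):

joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
joint_matrix_load(sg, tC, pC, N, layout::row_major);

joint_matrix<sub_group, int8_t, use::a, tM, tK, layout::row_major> tA;
joint_matrix_load(sg, tA, pA, K);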

The last two overloads of joint_matrix_load take a sycl::ext::oneapi::experimental::annotated_ptr argument instead of a sycl::multi_ptr. The property list associated with the annotated_ptr argument represents the compile-time constant properties for cache control included in the SYCL extension sycl_ext_intel_cache_controls, as illustrated in the example below.

using syclex = sycl::ext::oneapi::experimental;
using syclintelex = sycl::ext::intel::experimental;

auto A_ptr = syclex::annotated_ptr{A,
               syclex::properties{syclintelex::read_hint<
                   syclintelex::cache_control<syclintelex::cache_mode::cached,
                                              syclex::cache_level::L2>>}};
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
  sub_group sg = it.get_sub_group();
  joint_matrix<sub_group, bfloat16, use::a, tM, tK, layout::row_major> tA;
  for (int k = 0; k < K; k += tileK) {
    // User specifies that this load will be cached to L2
    joint_matrix_load(sg, tA, A_ptr + sg_startx * tM * K + k, K);
    ...
  }
});

Store

namespace sycl::ext::oneapi::experimental::matrix {

// T1 must be the same as T2
template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
          access::address_space Space, access::decorated IsDecorated>
void joint_matrix_store(Group g,
   const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
   multi_ptr<T2, Space, IsDecorated> dest, size_t stride, layout Layout);

template <typename Group, typename T1, typename T2, size_t Rows, size_t Cols,
          typename PropertyListT>
void joint_matrix_store(Group g,
   const joint_matrix<Group, T1, use::accumulator, Rows, Cols, layout::dynamic> &res,
   annotated_ptr<T2, PropertyListT> dest, size_t stride, layout Layout);

} // namespace sycl::ext::oneapi::experimental::matrix

This function stores the data in the accumulator matrix from the registers back to memory.

The base pointer dest determines the starting address of the matrix to be stored. Layout determines whether the data is written in row-major (row_major) or column-major (col_major) fashion. stride describes the number of elements between consecutive rows for the row-major layout, or between consecutive columns for the column-major layout.
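For example, the following sketch stores an accumulator tile to a row-major buffer in which consecutive rows are N elements apart (tC, pC, and N are assumed to be defined as in the full example later in this document):

joint_matrix_store(sg, tC, pC, N, layout::row_major);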

The second overload of joint_matrix_store takes a sycl::ext::oneapi::experimental::annotated_ptr argument instead of a sycl::multi_ptr. The property list associated with the annotated_ptr argument represents the compile-time constant properties for cache control included in the SYCL extension sycl_ext_intel_cache_controls.

Multiply and Add

namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename Ta, typename Tb, typename Tc, typename Td,
          std::size_t M, std::size_t K, std::size_t N,
          layout LayoutA, layout LayoutB>
void joint_matrix_mad(Group g,
    joint_matrix<Group, Td, use::accumulator, M, N, layout::dynamic> &D,
    const joint_matrix<Group, Ta, use::a, M, K, LayoutA> &A,
    const joint_matrix<Group, Tb, use::b, K, N, LayoutB> &B,
    const joint_matrix<Group, Tc, use::accumulator, M, N, layout::dynamic> &C);

} // namespace sycl::ext::oneapi::experimental::matrix

The matrix multiply and add function performs the multiply operation on the matrices A and B, accumulates the result with C, and writes the result into the matrix D.

Each device supports only certain combinations of types for the A, B, and C matrices. The application must use the query operations (defined below) to ensure that the combination of types is supported on the device where the kernel calling joint_matrix_mad runs.
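For example, the following sketch uses a separate output accumulator (tA and tB are assumed to be declared and loaded as in the full example later in this document):

joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC, tD;
// tD = tC + tA * tB
joint_matrix_mad(sg, tD, tA, tB, tC);

The application may also pass the same matrix for C and D to accumulate in place, as the int8_t example later in this document does.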

Fill (Initialization)

Unlike joint_matrix_load, which assumes that the matrix data comes from memory, joint_matrix_fill makes it possible to use a matrix that is not loaded from memory but rather initialized directly in registers. Note that the value type Tv must be convertible to the matrix element type T.

namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename T, size_t Rows, size_t Cols,
          use Use, layout Layout, typename Tv>
void joint_matrix_fill(Group g, joint_matrix<Group, T, Use, Rows,
          Cols, Layout> &m, Tv v);

} // namespace sycl::ext::oneapi::experimental::matrix
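For example, the following sketch initializes an accumulator to zero before a multiply loop; the int literal 0 is converted to the element type int32_t:

joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
joint_matrix_fill(sg, tC, 0);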

Copy

namespace sycl::ext::oneapi::experimental::matrix {

template <typename Group, typename T1, typename T2, size_t Rows,
          size_t Cols, use Use1, use Use2, layout Layout1, layout Layout2>
void joint_matrix_copy(Group g,
                      joint_matrix<Group, T1, Use1, Rows, Cols, Layout1> &src,
                      joint_matrix<Group, T2, Use2, Rows, Cols, Layout2> &dest);

} // namespace sycl::ext::oneapi::experimental::matrix

This function copies Rows x Cols elements of type T1 from the joint matrix src to Rows x Cols elements of type T2 of the joint matrix dest. The two matrices must have the same scope and shape, but the use, element type, and layout may differ, so this function can convert between matrices with different uses.
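For example, the following sketch copies a float accumulator tile into a bfloat16 A matrix for use in a subsequent multiplication:

joint_matrix<sub_group, float, use::accumulator, tM, tN> tC;
joint_matrix<sub_group, bfloat16, use::a, tM, tN, layout::row_major> tA;
joint_matrix_copy(sg, tC, tA); // converts each element from float to bfloat16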

Element-Wise Operations

Besides matrix multiply and add, this extension aims to make it possible to perform element-wise operations on matrices in a SPMD manner. The joint_matrix_apply function performs an element-wise operation where the same operation is applied to every element of the joint matrix, such that the operation can be performed without knowledge of the position of the element within the matrix. Activation functions and adding a constant value to every element of the matrix are two examples of this usage. When the operation depends on the element index of the matrix, an Intel-specific extension is available as part of sycl_ext_intel_matrix.

Besides the Group and the joint_matrix arguments, joint_matrix_apply takes a C++ Callable object which is invoked once for each element of the matrix. This callable object must be invocable with a single parameter of type T&. Commonly, applications pass a lambda expression.

namespace sycl::ext::oneapi::experimental::matrix {

template<typename Group, typename T, use Use, size_t Rows, size_t Cols,
  layout Layout, typename F>
void joint_matrix_apply(Group g, joint_matrix<Group, T, Use, Rows, Cols,
  Layout>& C, F&& func);

} // namespace sycl::ext::oneapi::experimental::matrix

In the following example, every element of the matrix C is multiplied by alpha. Then, an activation function, relu in this example, is applied on each of the elements of C.

joint_matrix_apply(sg, C, [=](T &x) {
    x *= alpha;
    relu(x);
});

Prefetch

namespace sycl::ext::oneapi::experimental::matrix {

template <size_t Rows, size_t Cols, typename Group, typename T,
          typename Properties = empty_properties_t>
void joint_matrix_prefetch(Group g, T* ptr, size_t stride, layout Layout,
                           Properties properties = {});

} // namespace sycl::ext::oneapi::experimental::matrix

joint_matrix_prefetch allows groups of work-items to cooperatively prefetch Rows x Cols elements in a 2d manner. This function is a group function, as defined in Section 4.17.3 of the core SYCL specification.

The level of cache targeted by joint_matrix_prefetch is specified via the last argument, using the compile-time properties defined in the SYCL extension sycl_ext_oneapi_prefetch, as illustrated in the example below. When no cache level is specified, the default behavior is to prefetch into the lowest-level cache (i.e. L1).

using syclex = sycl::ext::oneapi::experimental;

bfloat16 *memA = malloc_shared<bfloat16>(M*K, q);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> it) {
  sub_group sg = it.get_sub_group();
  for (int k = 0; k < K; k += tileK) {
    syclex::joint_matrix_prefetch<tM, tK>(sg, memA + tM * K + tK, K,
                                  layout::row_major,
                                  syclex::properties{syclex::prefetch_hint_L2});
    ...
  }
});

Support for Machine Learning Types

Some devices support special matrix element types that are commonly used in machine learning algorithms. These types are unusual because the type of the matrix element is different from the way the data is stored in memory. As a result, each of these elements has two types. There is an abstract identifier for the element type, which is an incomplete type defined in the sycl::ext::oneapi::experimental::matrix::precision namespace, and there is a corresponding storage format type. The following synopsis lists the abstract types and the table shows the associated storage format type.

namespace sycl::ext::oneapi::experimental::matrix::precision {

class tf32;

} // namespace sycl::ext::oneapi::experimental::matrix::precision

joint_matrix element type | Storage type | Description
precision::tf32 | float | The TF32 type has a 19-bit format with one sign bit, 8 exponent bits (offering the same range as float), and 10 mantissa bits (offering the same precision as sycl::half).

In order to declare a joint_matrix with one of these element types, use the abstract type like so:

joint_matrix<sub_group, precision::tf32, use::a, tM, tK,
             layout::row_major> tA;

Operations on these matrices use the functions described above, but there are different constraints on the template parameters as described below.

load

The template parameter T2 must either be the storage format type that corresponds to the abstract type T1 or it must be a const-qualified version of that storage format type. For example:

joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;

float *buf = malloc_shared<float>(M*K, q);
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
                               sycl::access::decorated::no>(buf);

joint_matrix_load(sg, tA, pBuf + Offset, Stride);

store

The template parameter T2 must be the storage format type that corresponds to the abstract type T1. For example:

joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;

float *buf = malloc_shared<float>(M*K, q);
auto pBuf = address_space_cast<sycl::access::address_space::global_space,
                               sycl::access::decorated::no>(buf);

joint_matrix_store(sg, tC, pBuf + Offset, Stride, layout::row_major);

fill

The template parameter Tv must be implicitly convertible to the storage format type that corresponds to the abstract type T. For example:

joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
float v = 42.0;
joint_matrix_fill(sg, tA, v);

copy

There is no special constraint for the joint_matrix_copy function. The template parameters T1 and T2 correspond to the element types of the src and dest matrices.

joint_matrix<sub_group, precision::tf32, use::a, tM, tK, layout::row_major> tA;
joint_matrix<sub_group, float, use::accumulator, tM, tK> tC;
joint_matrix_copy(sg, tC, tA);

Element-wise operations

The Callable function type F must be invocable with a single argument whose type is a reference to the storage format type that corresponds to the abstract type T. For example, in the case where C is a joint matrix of type precision::tf32:

joint_matrix<sub_group, precision::tf32, use::accumulator, tM, tK> tC;
joint_matrix_apply(sg, tC, [=](float &x) {
    x *= alpha;
});

Rounding TF32 values

The functions joint_matrix_load, joint_matrix_fill, and joint_matrix_apply do not define any rounding mode when the float values are converted to TF32, and the implementation may either round or truncate these conversions. If an application wants more control over this rounding, it can use the round_to_tf32 function. This performs the round to nearest even (RTE) rounding mode.

namespace sycl::ext::oneapi::experimental::matrix {

float round_to_tf32(float elem);

} // namespace sycl::ext::oneapi::experimental::matrix
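For example, an application can apply the RTE conversion to every element of a tf32 matrix before the multiplication. This sketch reuses the tA matrix declared above; note that the callable takes the float storage type:

joint_matrix_apply(sg, tA, [=](float &x) {
  x = round_to_tf32(x);
});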

Example using int8_t type

using namespace sycl::ext::oneapi::experimental::matrix;

queue q;
range<2> G = {M/tM, N};
range<2> L = {1, SG_SIZE};
int8_t *memA = malloc_shared<int8_t>(M*K, q);
int8_t *memB = malloc_shared<int8_t>(K*N, q);
int32_t *memC = malloc_shared<int32_t>(M*N, q);
auto pA = address_space_cast<sycl::access::address_space::global_space,
                             sycl::access::decorated::no>(memA);
auto pB = address_space_cast<sycl::access::address_space::global_space,
                             sycl::access::decorated::no>(memB);
auto pC = address_space_cast<sycl::access::address_space::global_space,
                             sycl::access::decorated::no>(memC);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
  [[sycl::reqd_sub_group_size(SG_SIZE)]] {
   const auto global_idx = item.get_global_id(0);
   const auto global_idy = item.get_global_id(1);
   const auto sg_startx = global_idx - item.get_local_id(0);
   const auto sg_starty = global_idy - item.get_local_id(1);
   sub_group sg = item.get_sub_group();
   joint_matrix<sub_group, int8_t, use::a, tM, tK, layout::row_major> tA;
   joint_matrix<sub_group, int8_t, use::b, tK, tN, layout::row_major> tB;
   joint_matrix<sub_group, int32_t, use::accumulator, tM, tN> tC;
   joint_matrix_fill(sg, tC, 0);
   for (int k = 0; k < K; k += tK) {
     joint_matrix_load(sg, tA, pA + sg_startx * tM * K + k, K);
     joint_matrix_load(sg, tB, pB + k * N + sg_starty/SG_SIZE*tN, N);
     joint_matrix_mad(sg, tC, tA, tB, tC);
   }
joint_matrix_apply(sg, tC, [=](int32_t &x) {
    x *= alpha;
   });
   joint_matrix_store(sg, tC, pC + sg_startx * tM * N + sg_starty/SG_SIZE*tN,
                      N, layout::row_major);
}).wait();

Query Interface

Most devices support only certain values for the Rows and Cols template parameters and only certain types for the T template parameter. Moreover, most devices support only certain combinations of these template parameters for the A, B, C, and D matrices in the joint_matrix_mad function (see Appendix: Supported Combinations Per Hardware). This extension adds two query APIs that can be used to determine the set of legal parameters for a particular device. One form provides constexpr values for these parameters, which can be used when the application knows the specific device architecture on which it will run. The other form uses the standard information descriptor queries for the device object.

The description below uses the terms M, N, and K to identify the matrix dimensions of a multiply and add operation D = C + A*B. The D and C matrices are M rows by N columns. The A matrix is M rows by K columns, and the B matrix is K rows by N columns.

Compile-Time Query

This form returns constexpr values that can be used as joint_matrix template arguments, but it requires the application to name the matrix hardware through the architecture enumeration (see sycl::ext::oneapi::experimental::architecture in the sycl_ext_oneapi_device_architecture extension). The compile-time query interface proposed here consists of two functionalities:

  • Validation: at compile time, the validation functionality informs the user whether a specific combination is valid or not. This takes place when the user specifies all template parameters.

  • Default values: this provides a default shape if the user does not provide a specific combination. In this case, aliases to the joint_matrix type can be used, namely joint_matrix_a/b/c/d, where no additional arguments are needed. This form is used when the user specifies all template parameters except the sizes of the matrices M, N, and K.

The table below provides a description for each of the member variables and type aliases in the matrix_params class and the forms in which they are defined.

Member/type alias in matrix_params | Description
static constexpr size_t M | When no sizes are provided by the user, indicates the suggested default size for M; usually this corresponds to the maximum size the implementation supports. In validation mode, where the user does provide sizes, this is the same value M that the user provides, if M is supported by the implementation.
static constexpr size_t N | When no sizes are provided by the user, indicates the suggested default size for N; usually this corresponds to the maximum size the implementation supports. In validation mode, where the user does provide sizes, this is the same value N that the user provides, if N is supported by the implementation.
static constexpr size_t K | When no sizes are provided by the user, indicates the suggested default size for K; usually this corresponds to the maximum size the implementation supports. In validation mode, where the user does provide sizes, this is the same value K that the user provides, if K is supported by the implementation.
template <typename Group, layout Layout> using joint_matrix_a | type alias for joint_matrix for matrix A
template <typename Group, layout Layout> using joint_matrix_b | type alias for joint_matrix for matrix B
template <typename Group> using joint_matrix_c | type alias for joint_matrix for the input matrix accumulator
template <typename Group> using joint_matrix_d | type alias for joint_matrix for the output matrix accumulator

namespace sycl::ext::oneapi::experimental::matrix {

template<architecture Arch, typename Ta, typename Tb, typename Tc,
         typename Td=Tc, size_t sM=0, size_t sN=0, size_t sK=0>
struct matrix_params;

// This is the validation form, when all template parameters are
// specified.
template<architecture Arch, typename Ta, typename Tb, typename Tc,
         typename Td, size_t sM, size_t sN, size_t sK>
struct matrix_params<Arch, Ta, Tb, Tc, Td, sM, sN, sK> {
  // An implementation typically uses static_assert here to trigger a
  // compilation error when the matrix types or shapes are not
  // supported by the device identified by the architecture "Arch".

  static constexpr size_t M = sM;
  static constexpr size_t N = sN;
  static constexpr size_t K = sK;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, sM, sK, Layout>;

  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, sK, sN, Layout>;

  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, sM, sN>;

  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, sM, sN>;
};

// This is the default values form, where the matrix dimensions are
// omitted.
template<architecture Arch, typename Ta, typename Tb, typename Tc, typename Td>
struct matrix_params<Arch, Ta, Tb, Tc, Td, 0, 0, 0> {
  // An implementation typically uses static_assert here to trigger a
  // compilation error when the matrix types are not supported by the
  // device identified by the architecture "Arch".

  static constexpr size_t M = /* implementation defined */;
  static constexpr size_t N = /* implementation defined */;
  static constexpr size_t K = /* implementation defined */;

  template <typename Group, layout Layout>
  using joint_matrix_a = joint_matrix<Group, Ta, use::a, M, K, Layout>;

  template <typename Group, layout Layout>
  using joint_matrix_b = joint_matrix<Group, Tb, use::b, K, N, Layout>;

  template <typename Group>
  using joint_matrix_c = joint_matrix<Group, Tc, use::accumulator, M, N>;

  template <typename Group>
  using joint_matrix_d = joint_matrix<Group, Td, use::accumulator, M, N>;
};

} // namespace sycl::ext::oneapi::experimental::matrix
Validation Example:
// The user can provide sizes besides the types, and matrix_params can assert
// at compile time if the combination is not supported.
// In this case, an assertion will happen because 16 is not a supported size for M.
using myparams = matrix_params<architecture::intel_gpu_pvc, int8_t,
                               int8_t, int, int, 16, 16, 32>;
size_t NDRangeM = M / myparams::M;  // the assertion would happen at this line
size_t NDRangeN = N / myparams::N;
Default Values Example:
using myparams = matrix_params<architecture::intel_gpu_pvc, int8_t, int8_t, int>;
// use this to construct the ranges on the host side
size_t NDRangeM = M / myparams::M;
size_t NDRangeN = N / myparams::N;
// if M, N, K are not multiples of the default sizes, the workload must be padded
// device code: the matrices are constructed using the default dimensions
myparams::joint_matrix_a<sub_group, layout::row_major> sub_a;
myparams::joint_matrix_b<sub_group, layout::row_major> sub_b;
myparams::joint_matrix_c<sub_group> sub_c;

Runtime Query

The runtime query does not require the application to hard-code a specific device type, but the values it returns are not constexpr. It provides similar information to the compile-time query API via an extended device information descriptor.

The table below provides a description for each of the device matrix descriptors that can be queried using the get_info API.

Device descriptor | Return type | Description
ext::oneapi::experimental::info::device::matrix_combinations | std::vector<combination> | tells the set of supported matrix sizes and types on this device

The runtime query returns a vector of combination objects through the matrix_combinations descriptor. Each combination includes the sizes and the types for the matrices A, B, C, and D. Note that for each matrix hardware configuration, the query returns either max_msize, max_nsize, max_ksize or msize, nsize, ksize exclusively, depending on whether the implementation supports a continuous or a discrete set of sizes. If a device supports a continuous range of sizes, the max_* variant applies and only the maximum size is returned. If a device supports a discrete list of sizes, the msize, nsize, ksize variant applies.

namespace sycl::ext::oneapi::experimental::matrix {

enum class matrix_type {
  bf16,
  fp16,
  tf32,
  fp32,
  fp64,
  sint8,
  sint16,
  sint32,
  sint64,
  uint8,
  uint16,
  uint32,
  uint64
};
struct combination {
  size_t max_msize;
  size_t max_nsize;
  size_t max_ksize;
  size_t msize;
  size_t nsize;
  size_t ksize;
  matrix_type atype;
  matrix_type btype;
  matrix_type ctype;
  matrix_type dtype;
};

} // namespace sycl::ext::oneapi::experimental::matrix

Each combination in the matrix_combinations vector describes the types and sizes of the A, B, C, and D matrices supported by the device implementation. The table below provides a description of each member of the combination struct.

Member of combination | Description
max_msize, max_nsize, max_ksize | If the matrix implementation supports a continuous range of sizes, each of these members is non-zero, and the implementation supports all sizes from 1 up to (and including) that value. By contrast, if the implementation supports a discrete set of sizes, each of these members has the value zero.
msize, nsize, ksize | If the matrix implementation supports a discrete set of sizes, each of these members is non-zero, and the value gives one of the supported sizes. By contrast, if the implementation supports a continuous range of sizes, each of these members has the value zero.
atype, btype, ctype, dtype | Indicate the types supported in the combination. These members are of type matrix_type, whose enumerators correspond to the T template parameter of the A, B, C, and D matrices as follows:
bf16: sycl::ext::oneapi::bfloat16
fp16: sycl::half
tf32: sycl::ext::oneapi::experimental::matrix::precision::tf32
fp32: float
fp64: double
sint8: int8_t
sint16: int16_t
sint32: int32_t
sint64: int64_t
uint8: uint8_t
uint16: uint16_t
uint32: uint32_t
uint64: uint64_t
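Because the max_* members and the msize, nsize, ksize members are mutually exclusive, an application can distinguish the two cases as in the following sketch, where combinations is the vector returned by get_info as in the example below:

for (const auto &c : combinations) {
  if (c.max_msize != 0) {
    // continuous case: every size from 1 to c.max_msize is supported for M
  } else {
    // discrete case: exactly c.msize is supported for M
  }
}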

Runtime Query Example:
// Ta, Tb, Tc, and Td are the types used in applications
std::vector<combination> combinations =
           device.get_info<info::device::matrix_combinations>();
for (size_t i = 0; i < combinations.size(); i++) {
  if (Ta == combinations[i].atype &&
      Tb == combinations[i].btype &&
      Tc == combinations[i].ctype &&
      Td == combinations[i].dtype) {
    // joint matrix GEMM kernel can be called using these sizes
    joint_matrix_gemm(combinations[i].msize,
         combinations[i].nsize, combinations[i].ksize);
  }
}

Appendix: Supported Combinations Per Hardware

The table below provides a list of the combinations that joint_matrix implementations support on each of Intel AMX and Intel XMX hardware. Note that these can be returned using ext::oneapi::experimental::info::device::matrix_combinations.

Intel AMX Supported Combinations

This is currently available in devices with the architectures architecture::intel_cpu_spr and architecture::intel_cpu_gnr. In these architectures' implementations, the type of the C matrix must be the same as the type of the D matrix. Therefore, that common type is shown in a single column in the table below.

A type | B type | C and D type | M | N | K | device
matrix_type::uint8 | matrix_type::uint8 | matrix_type::sint32 | <= 16 | <= 16 | <= 64 | architecture::intel_cpu_spr, architecture::intel_cpu_gnr
matrix_type::uint8 | matrix_type::sint8 | matrix_type::sint32 | <= 16 | <= 16 | <= 64 | architecture::intel_cpu_spr, architecture::intel_cpu_gnr
matrix_type::sint8 | matrix_type::uint8 | matrix_type::sint32 | <= 16 | <= 16 | <= 64 | architecture::intel_cpu_spr, architecture::intel_cpu_gnr
matrix_type::sint8 | matrix_type::sint8 | matrix_type::sint32 | <= 16 | <= 16 | <= 64 | architecture::intel_cpu_spr, architecture::intel_cpu_gnr
matrix_type::bf16 | matrix_type::bf16 | matrix_type::fp32 | <= 16 | <= 16 | <= 32 | architecture::intel_cpu_spr, architecture::intel_cpu_gnr
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp32 | <= 16 | <= 16 | <= 32 | architecture::intel_cpu_gnr

Intel XMX Supported Combinations

This is currently available in devices with the architecture architecture::intel_gpu_pvc, architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, and architecture::intel_gpu_dg2_g12. In these architectures' implementation, the type of the C matrix must be the same as the type of the D matrix. Therefore, that common type is shown in a single column in the table below.

A type | B type | C and D type | M | N | K | device
matrix_type::uint8 | matrix_type::uint8 | matrix_type::sint32 | <= 8 | 16 | 32 | architecture::intel_gpu_pvc
matrix_type::uint8 | matrix_type::uint8 | matrix_type::sint32 | <= 8 | 8 | 32 | architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12
matrix_type::uint8 | matrix_type::sint8 | matrix_type::sint32 | <= 8 | 16 | 32 | architecture::intel_gpu_pvc
matrix_type::uint8 | matrix_type::sint8 | matrix_type::sint32 | <= 8 | 8 | 32 | architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12
matrix_type::sint8 | matrix_type::uint8 | matrix_type::sint32 | <= 8 | 16 | 32 | architecture::intel_gpu_pvc
matrix_type::sint8 | matrix_type::uint8 | matrix_type::sint32 | <= 8 | 8 | 32 | architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12
matrix_type::sint8 | matrix_type::sint8 | matrix_type::sint32 | <= 8 | 16 | 32 | architecture::intel_gpu_pvc
matrix_type::sint8 | matrix_type::sint8 | matrix_type::sint32 | <= 8 | 8 | 32 | architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp32 | <= 8 | 16 | 16 | architecture::intel_gpu_pvc
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp32 | <= 8 | 8 | 16 | architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12
matrix_type::bf16 | matrix_type::bf16 | matrix_type::fp32 | 16 | 16 | 16 | architecture::intel_gpu_pvc
matrix_type::bf16 | matrix_type::bf16 | matrix_type::fp32 | 32 | 64 | 16 | architecture::intel_gpu_pvc
matrix_type::bf16 | matrix_type::bf16 | matrix_type::fp32 | <= 8 | 16 | 16 | architecture::intel_gpu_pvc
matrix_type::bf16 | matrix_type::bf16 | matrix_type::fp32 | <= 8 | 8 | 16 | architecture::intel_gpu_dg2_g10, architecture::intel_gpu_dg2_g11, architecture::intel_gpu_dg2_g12
matrix_type::tf32 | matrix_type::tf32 | matrix_type::fp32 | <= 8 | 16 | 8 | architecture::intel_gpu_pvc

Nvidia Tensor Cores Supported Combinations

The complete set of matrix data types and shapes that are supported by the ext_oneapi_cuda backend are represented in the following table. In this architecture’s implementation, the type of the A matrix must be the same as the type of the B matrix.

Important
When compiling for the ext_oneapi_cuda backend the target arch backend flag, -fsycl-targets=nvidia_gpu_sm_xx (or equivalents, e.g. -Xsycl-target-backend --cuda-gpu-arch=sm_xx), must be used, where sm_xx must be a Compute Capability that is equal to or greater than the appropriate Minimum Compute Capability. When an executable has been compiled for sm_xx, if the executable is run on a device with compute capability less than sm_xx then an error will be thrown. The mapping to Minimum Compute Capability from each supported parameter combination is specified in the following table.

A and B type | C type | D type | M | N | K | Minimum Compute Capability
matrix_type::fp16 | matrix_type::fp32 | matrix_type::fp32 | 16 | 16 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp32 | matrix_type::fp32 | 8 | 32 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp32 | matrix_type::fp32 | 32 | 8 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp16 | 16 | 16 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp16 | 8 | 32 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp16 | 32 | 8 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp32 | matrix_type::fp16 | 16 | 16 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp32 | matrix_type::fp16 | 8 | 32 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp32 | matrix_type::fp16 | 32 | 8 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp32 | 16 | 16 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp32 | 8 | 32 | 16 | sm_70
matrix_type::fp16 | matrix_type::fp16 | matrix_type::fp32 | 32 | 8 | 16 | sm_70
matrix_type::sint8 | matrix_type::sint32 | matrix_type::sint32 | 16 | 16 | 16 | sm_72
matrix_type::sint8 | matrix_type::sint32 | matrix_type::sint32 | 8 | 32 | 16 | sm_72
matrix_type::sint8 | matrix_type::sint32 | matrix_type::sint32 | 32 | 8 | 16 | sm_72
matrix_type::uint8 | matrix_type::sint32 | matrix_type::sint32 | 16 | 16 | 16 | sm_72
matrix_type::uint8 | matrix_type::sint32 | matrix_type::sint32 | 8 | 32 | 16 | sm_72
matrix_type::uint8 | matrix_type::sint32 | matrix_type::sint32 | 32 | 8 | 16 | sm_72
matrix_type::tf32 | matrix_type::fp32 | matrix_type::fp32 | 16 | 16 | 8 | sm_80
matrix_type::bf16 | matrix_type::fp32 | matrix_type::fp32 | 16 | 16 | 16 | sm_80
matrix_type::bf16 | matrix_type::fp32 | matrix_type::fp32 | 8 | 32 | 16 | sm_80
matrix_type::bf16 | matrix_type::fp32 | matrix_type::fp32 | 32 | 8 | 16 | sm_80
matrix_type::fp64 | matrix_type::fp64 | matrix_type::fp64 | 8 | 8 | 4 | sm_80

Important
The stride argument to joint_matrix_load and joint_matrix_store must be a multiple of 8 when T is half, and a multiple of 4 when T is float, where T is the type of the joint_matrix elements. When T is neither half nor float, there are no restrictions on stride.

AMD Matrix Cores Supported Combinations

The complete set of matrix data types and dimensions that are supported by the ext_oneapi_hip backend are represented in the following table. In this architecture’s implementation, A and B matrices must have the same type. Similarly, C and D matrices must share the same type.

Important
The supported instructions may be run on GFX90A (MI200, MI210, MI250 and MI250X GPUs) architecture. When compiling for the ext_oneapi_hip backend the target arch backend flag, -fsycl-targets=amd_gpu_gfx90a, must be used. An attempt to run the compiled code on an unsupported architecture will throw an error.

A and B type | C and D type | M | N | K
matrix_type::fp16 | matrix_type::fp32 | 32 | 32 | 8
matrix_type::fp16 | matrix_type::fp32 | 16 | 16 | 16
matrix_type::sint8 | matrix_type::sint32 | 32 | 32 | 8
matrix_type::sint8 | matrix_type::sint32 | 16 | 16 | 16
matrix_type::bf16 | matrix_type::fp32 | 32 | 32 | 8
matrix_type::bf16 | matrix_type::fp32 | 16 | 16 | 16
matrix_type::fp64 | matrix_type::fp64 | 16 | 16 | 4

Revision History

Rev | Date | Author | Changes
1 | 2021-04-13 | Dounia Khaldi | Initial public working draft.
2 | 2021-10-05 | Dounia Khaldi | JIT implementation on both Intel AMX and DPAS
3 | 2022-05-16 | Dounia Khaldi | Add matrix fill and piece-wise operations support
4 | 2022-08-25 | Dounia Khaldi | Update the matrix spec by adding the new matrix use parameter and remove reference to the AOT AMX initial implementation
5 | 2022-11-07 | Dounia Khaldi | Update the matrix spec by making it portable across Intel AMX, Intel XMX and Nvidia Tensor Cores, and move the Intel-specifics to a separate extension document
6 | 2023-01-09 | Dounia Khaldi | Add joint_matrix_apply API, tf32 type, runtime query, and supported combinations appendix for Intel AMX and Intel XMX
7 | 2023-04-11 | Jack Kirk | Add Nvidia Tensor Cores supported combinations
8 | 2023-10-05 | Mahmoud Moadeli | Add AMD Matrix Core supported combinations
9 | 2023-11-13 | Dounia Khaldi | Add Granite Rapids Intel AMX supported combinations
9 | 2023-12-04 | Dounia Khaldi | Add prefetch and annotated_ptr load/store overloads