Matrix Programming Extension for DPC++: sycl_ext_oneapi_matrix

Notice

Copyright (c) 2021-2022 Intel Corporation. All rights reserved.

Note
Khronos® is a registered trademark and SYCL™ and SPIR™ are trademarks of The Khronos Group Inc. OpenCL™ is a trademark of Apple Inc. used by permission by Khronos.

This extension is written against the SYCL 2020 revision 5 specification. All references below to the "core SYCL specification" or to section numbers in the SYCL specification refer to that revision.

NOTE: This document describes the current design and API for the matrix extension to DPC++. This is an initial experimental version to try out functionality and performance, and future versions of this API may change in ways that are incompatible with this experimental version. The current implementation provides support for the matrix interface on Intel® Advanced Matrix Extensions (AMX), DPAS, and Nvidia® Tensor Cores.

Introduction

This document presents ongoing work towards defining a unified matrix interface. This interface is intended to unify different tensor hardware: Intel AMX in CPUs, DPAS in Intel GPUs, Habana Gaudi and Goya tensor and GEMM cores, Nvidia TPUs, and IBM Power MMA. All of this hardware provides low-level intrinsics or assembly to access and perform matrix operations. The goal is to provide a unified interface that is portable but also benefits from the maximum performance these different kinds of hardware can offer.

Feature test macro

This extension provides a feature-test macro as described in the core SYCL specification section 6.3.3 "Feature test macros". Therefore, an implementation supporting this extension must predefine the macro SYCL_EXT_ONEAPI_MATRIX to one of the values defined in the table below. Applications can test for the existence of this macro to determine if the implementation supports this feature, or applications can test the macro’s value to determine which of the extension’s APIs the implementation supports.

Value | Description
1 | The APIs of this experimental extension are not versioned, so the feature-test macro always has this value.

Matrix API Versions

While this document presents the core API that unifies Intel AMX, DPAS, and Nvidia Tensor Cores, the implementations support slightly different versions of the API. For this reason, we introduce a new macro, namely SYCL_EXT_ONEAPI_MATRIX_VERSION, to distinguish between these different implementations. The goal over the next few months is to retire this implementation-versioning macro. The current values for this macro are listed below.

Value | Description
1 | Initial extension JIT implementation on Intel AMX and DPAS. load, store, mad, fill, piece-wise operations, and the query interface are supported. The old API used for this implementation is detailed in the deprecated matrix extension document (doc/extensions/deprecated/sycl_ext_oneapi_deprecated_matrix_no_use.asciidoc).
2 | JIT implementation on Intel AMX and DPAS. load, store, mad, fill, piece-wise operations, and the query interface are supported.
3 | Implementation on Nvidia Tensor Cores.

New joint_matrix class

We introduce a new class called joint_matrix. The user needs to specify the type of the elements, shape, the matrix use, the memory layout, and the memory scope of the matrix. This results in the following description:

namespace sycl::ext::oneapi::experimental::matrix {
template <typename T, use Use,
          size_t Rows=sycl::dynamic_extent, size_t Cols=sycl::dynamic_extent,
          layout Layout = layout::dynamic, typename Group = sub_group>
struct joint_matrix {
    joint_matrix(Group g) {}
};
}
Important
Matrix layout defaulting to layout::dynamic applies only to matrices with use::accumulator.
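For illustration, assuming tile sizes tM, tN, tK and a sub_group sg are in scope (as in the examples later in this document), the A and B matrices carry an explicit layout in their type, while the accumulator can rely on the layout::dynamic default:

joint_matrix<int8_t, use::a, tM, tK, layout::row_major> tA(sg);
joint_matrix<int8_t, use::b, tK, tN, layout::packed> tB(sg);
joint_matrix<int32_t, use::accumulator, tM, tN> tC(sg); // layout defaults to layout::dynamic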

Shape

The same class, joint_matrix, should handle both cases where sizes are constant (GPU case) and where sizes are variable (CPU case). Note that an Intel AMX 2D tile register permits sizes up to 1024 bytes (16 rows x 64 bytes per row). Defining only one interface for both cases gives the user a way to exploit the flexibility offered by the CPU while saving resources on the GPU. We use sycl::dynamic_extent to differentiate between static and dynamic sizes.

Important
In the current implementation, only the static extent is supported

Use

The usage of the matrix must be specified: left matrix (A), right matrix (B), or accumulator (C). Backend implementations require this information to reason about the layout of the matrix in registers.

namespace sycl::ext::oneapi::experimental::matrix {
enum class use {
  a,
  b,
  accumulator
};
}

Layout

Besides the row-major and column-major layouts, layout is flexible enough to introduce custom layouts such as the packed layout.

namespace sycl::ext::oneapi::experimental::matrix {
enum class layout {
  row_major,
  col_major,
  packed,
  dynamic
};
}

Memory Scope

In this experimental API version, we use the name joint_matrix, rather than plain matrix, to emphasize that the matrix is shared among a group of work items and is not private to each work item. The memory scope is added as an additional template parameter and is also part of the constructor arguments.

Important
In the current implementation, only the subgroup scope is supported

When the group is a sycl::sub_group, a matrix is declared as follows:

joint_matrix<int8_t, use::a, tM, tN, layout::row_major> tA(sg);

Matrix Operations and their Execution Scope

We define three new functions needed to perform the main and common operations on matrices, namely load, store, and the actual multiply and add operation. This set of functions can be easily extended if the matrix hardware implements new features.

The base pointer determines the starting address of the matrix to be loaded/stored. layout determines whether the data is being read/written in a row-major (row_major) or column-major (col_major) fashion, or whether the data has already been transformed into VNNI format (packed). stride describes the number of elements between consecutive rows for the row-major and packed layouts, or between consecutive columns for the column-major layout.

Note that in order to get maximum performance on Intel AMX and DPAS, prepacking the data in memory is necessary. If the user does not specify the packed layout, the transformations done by the implementation will be slow due to extra scatter/gather operations. Hence, we expose the packed layout so the user can specify that A or B have already been converted to VNNI format. The packed (VNNI) layout is introduced in the VNNI layout section below.

Important
In the current AMX and DPAS implementation, the layout in the load of matrix B (provided by the layout memL parameter below) must be packed or row_major. Automatic VNNI transform is supported on AMX. The layout in the load of matrices A and C must be row_major, and the layout in the store of matrix C (provided by the layout memL parameter below) must also be row_major.

Since the matrix functions are group operations (as defined in Section 4.17.3 of the SYCL specification), the matrix API has to be accessed by all the work-items in the group in a convergent control flow. The Group template argument can be a work-group or a subgroup. These functions will be called once by each work item in the group.

To be aligned with the SYCL 2020 group algorithms, an additional group argument is added to the matrix operations to designate that these functions are collective operations. The DPC++ syntax is the following:

Important
In the current implementation, only the subgroup scope is supported.

Load

namespace sycl::ext::oneapi::experimental::matrix {
  template <typename Group, typename T, size_t NumRows, size_t NumCols,
            access::address_space Space, access::decorated IsDecorated>
  void joint_matrix_load(Group sg,
    joint_matrix<T, use::accumulator, NumRows, NumCols, layout::dynamic, Group> &res,
    multi_ptr<T, Space, IsDecorated> src, size_t stride, layout memL);

  template <typename Group, typename T, size_t NumRows, size_t NumCols,
            use Use, layout Layout, access::address_space Space,
            access::decorated IsDecorated>
  void joint_matrix_load(Group sg,
    joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &res,
    multi_ptr<T, Space, IsDecorated> src, size_t stride);
}

joint_matrix_load loads data from memory to the 2d tiles/registers of Intel AMX/DPAS. We define two overloads of the load function depending on whether the memory layout was declared as part of the joint_matrix type or not. The first overload that takes memory layout as an argument is only available for a joint_matrix type that was declared with layout::dynamic. The second overload without a memory layout must not be used with a joint_matrix type that was declared with layout::dynamic.
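A minimal sketch of the two overloads, assuming a matrix tA declared with layout::row_major, an accumulator tC declared with the default layout::dynamic, and pointers, strides, and indices as in the complete example later in this document:

// tA carries layout::row_major in its type, so no layout argument is passed.
joint_matrix_load(sg, tA, memA + sg_startx * tM * K + k, K);
// tC was declared with layout::dynamic, so the memory layout must be passed explicitly.
joint_matrix_load(sg, tC, memC + sg_startx * tM * N + sg_starty / SG_SIZE * tN, N, layout::row_major);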

Store

namespace sycl::ext::oneapi::experimental::matrix {
  template <typename Group, typename T, size_t NumRows, size_t NumCols,
            access::address_space Space, access::decorated IsDecorated>
  void joint_matrix_store(Group sg,
    joint_matrix<T, use::accumulator, NumRows, NumCols, layout::dynamic, Group> &res,
    multi_ptr<T, Space, IsDecorated> dest, size_t stride, layout memL);
}

This function stores the data of the accumulator matrix from the 2D tiles/registers back to memory.
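For example, the accumulator can be written back in row-major order as follows (a sketch reusing the names from the load sketch above):

joint_matrix_store(sg, tC, memC + sg_startx * tM * N + sg_starty / SG_SIZE * tN, N, layout::row_major);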

Multiply and Add

namespace sycl::ext::oneapi::experimental::matrix {
  template <typename Group, typename Ta, typename Tb, typename Tc,
            std::size_t M, std::size_t K, std::size_t N,
            layout LayoutA, layout LayoutB>
  joint_matrix<Tc, use::accumulator, M, N, layout::dynamic, Group>
  joint_matrix_mad(Group sg,
    joint_matrix<Ta, use::a, M, K, LayoutA, Group> A,
    joint_matrix<Tb, use::b, K, N, LayoutB, Group> B,
    joint_matrix<Tc, use::accumulator, M, N, layout::dynamic, Group> C);
}

The matrix multiply-and-add function performs the multiply operation on matrices A and B, accumulates the result with C, and returns the result.
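In a typical GEMM loop, the returned matrix is assigned back to the accumulator so that C accumulates across the K dimension, for example (a sketch using the names from the complete example later in this document):

tC = joint_matrix_mad(sg, tA, tB, tC);  // tC = tA * tB + tC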

Matrix Initialization: joint_matrix_fill

The interface presented above assumes that all the matrices are loaded directly from memory. The joint_matrix_fill function makes it possible to multiply a matrix that is not loaded from memory but is instead initialized directly in registers. On Intel AMX, if the initialization constant is zero, this would map to the _tile_zero intrinsic:

namespace sycl::ext::oneapi::experimental::matrix {
  template <typename Group, typename T, size_t NumRows, size_t NumCols,
           use Use, layout Layout, typename Tv>
  void joint_matrix_fill(Group sg, joint_matrix<T, Use, NumRows, NumCols, Layout, Group> &m, Tv v);
}
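A minimal usage sketch, assuming an accumulator matrix tC as in the other examples in this document; a fill value of zero can map to _tile_zero on Intel AMX:

joint_matrix<int32_t, use::accumulator, tM, tN> tC(sg);
joint_matrix_fill(sg, tC, 0); // initialize the accumulator in registers; no memory load needed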
Important
In the current implementation, only the subgroup scope is supported.

Element Indexing and Piece-Wise Operations

Background

Besides matrix multiply and add, this extension aims to make it possible to perform piece-wise operations on matrices in a SPMD manner. The mechanisms that are recommended to perform such piece-wise operations depend upon which of the following classes the operation falls into:

Class 1- Element-wise operations where the same operation is performed on every element of the matrix, such that the operation can be performed without knowledge of the position of the element within the matrix. Activation functions or adding a constant value to every element of the matrix are two examples.

Class 2- Piece-wise operations where the operation depends on the element index of the matrix or the operation takes multiple elements as operands (such as a sum of all elements in a row for example). Quantization that is needed for conversion between low precision types like int8_t and fp32 uses piece-wise operations.

Explicit conversion with mapping from SIMD to SPMD

The data elements in a joint_matrix are distributed or shared across the work-items in the Group in an implementation-defined way. There is no fixed allocation of matrix elements owned by a joint_matrix instance to the WIs comprising the group used to instantiate it. For instance, the matrix is a shared entity among the work items in the case of the AMX backend because the AMX tile that holds the matrix data is a 2d register that is shared among the work items. Therefore the partitioning among the WIs is implementation defined. However, it is necessary to allocate WIs to specific elements of the matrix in order to perform element-wise operations. In order to be able to perform element-wise operations in a general and efficient way, we provide a conversion function from the joint_matrix domain that is owned by a group of work items to the portion that is owned by each work item. This enables the WI to perform piece-wise operations on the matrix within the SYCL SPMD programming model.

We introduce a new function get_wi_data that provides a view of the portion of the matrix that is owned by the current WI. The indexing provided inside the wi_data class accesses only the portion of the current WI and returns wi_element. This latter holds a reference to the original joint_matrix that wi_data was constructed from. This means that modifying wi_data also modifies the corresponding joint matrix elements. Users can use the = operator to update the element of the joint_matrix represented by the wi_element after the element-wise operation.

Using get_wi_data, it is not possible to know which portions of data are owned by each thread in the group as this is implementation defined and changes from one backend to the other. For general piece-wise operations such as summing the rows of a matrix, the WI data to joint matrix mapping coordinates information must be known in order to reason about the matrix view and extract the relevant piece. However, for element-wise operations where the same operation is performed on all the elements of the matrix, having all the WIs in the group apply the operation inside a loop iterating over the length of wi_data guarantees the whole matrix element-wise operation.

Therefore, this extension currently only supports class 1 operations, because the mapping between get_wi_data and joint_matrix elements does not need to be known for these operations. General piece-wise operations will be supported in the future, when a new API is provided to convey the mapping from the joint_matrix domain to the WI domain (see the section "WI data to joint matrix mapping coordinates information for piece-wise operations" below for more information).

Also, note that get_wi_data cannot return a fixed-size array, because the length of the WI portion is a runtime variable for the following reasons:

1- The main compilation mode of SYCL is JIT compilation and partitioning among WIs is implementation defined.

2- SG size is not generally fixed.

3- AMX has the flexibility of allowing variable sizes on the matrix (dynamic_extent).

In the case of the CUDA backend, which is AOT compiled and where the SG size is fixed and known to be 32, an additional marray-based capability will be provided.

The code listing below shows a synopsis of these new APIs.

namespace sycl::ext::oneapi::experimental::matrix {
template <typename T, use Use, size_t NumRows, size_t NumCols,
          layout Layout, typename Group = sycl::sub_group>
struct joint_matrix {
   wi_data<T, Use, NumRows, NumCols, Layout, Group> get_wi_data();
};
template <typename T, use Use, size_t NumRows, size_t NumCols,
          layout Layout, typename Group>
class wi_data {
  size_t length();
  wi_element<T, Use, NumRows, NumCols, Layout, Group> operator[](size_t i);
};
template <typename T, use Use, size_t NumRows, size_t NumCols,
          layout Layout, typename Group = sycl::sub_group>
class wi_element {
  operator T();
  wi_element &operator=(const T &rhs);
…
};
}

In the following example, wi_data_c is a reference to the WI-owned portion of the joint matrix matC. As such, wi_data_c[i] OP rhs updates the corresponding matrix element in the joint_matrix matC. Vectorization along the subgroup dimension is enabled automatically for the contiguous portion of the matrix.

auto wi_data_c = matC.get_wi_data();
for (int i = 0; i < wi_data_c.length(); i++)
        wi_data_c[i] *= alpha;    // Note that the indexing here "i" is in the vector owned by a WI, not in the matrix C
Important
In the current implementation, only the subgroup scope is supported.
Important
The WI data to joint matrix mapping coordinates information is not implemented yet.
Important
In the Tensor Cores implementation, it is possible to know how many elements are owned by each WI at compile time. In this case, wi_data can be of type marray. An additional interface will be provided for the Tensor Cores backend.

VNNI/Packed Layout

Intel AMX and DPAS compute units assume that the B tile register (src1) is in VNNI format, because they need 32 bits of data along the K dimension of A and B to be contiguous in memory. The VNNI blocking factor is 2 for 16-bit types and 4 for 8-bit types. While the current implementation assumes that the matrix has already been packed by the user for performance reasons, the layout information is needed to inform the implementation about this transformation. The following examples illustrate how a matrix in row_major layout is transformed into the packed layout for 16-bit and 8-bit types.

Example 1: 16-bit elements

// Example of a 4 row x 4 column matrix using a 16-bit data element, in row-major layout.
// Element a1 is contiguous in memory with element b1, etc.
// ---------------------------------
// a1, b1, c1, d1
// a2, b2, c2, d2
// a3, b3, c3, d3
// a4, b4, c4, d4
// ---------------------------------
// The same matrix reformatted in packed layout.
// Here, packing of 2 elements is needed to form 32 bits.
// Element a1 is contiguous in memory with element a2, etc.
// ---------------------------------
// a1, a2, b1, b2, c1, c2, d1, d2
// a3, a4, b3, b4, c3, c4, d3, d4

Example 2: 8-bit elements

// Example of a 4 row x 4 column matrix using an 8-bit data element, in row-major layout.
// Element a1 is contiguous in memory with element b1, etc.
// ---------------------------------
// a1, b1, c1, d1
// a2, b2, c2, d2
// a3, b3, c3, d3
// a4, b4, c4, d4
// ---------------------------------
// The same matrix reformatted in packed layout.
// Here, packing of 4 elements is needed to form 32 bits.
// Elements a1, a2, a3, a4 are contiguous in memory, etc.
// ---------------------------------
// a1, a2, a3, a4, b1, b2, b3, b4, c1, c2, c3, c4, d1, d2, d3, d4
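The packing itself is done by the user (or by a framework) ahead of time. The helper below is a hypothetical host-side sketch, not part of this extension, that reorders a row-major K x N matrix B into the packed (VNNI) layout with a blocking factor of 4 / sizeof(T), i.e., 2 for 16-bit types and 4 for 8-bit types:

// Hypothetical host-side helper (not part of this extension): repack a
// row-major K x N matrix into the packed/VNNI layout. vnni_factor rows of
// the K dimension are interleaved so that 32 bits per column are contiguous.
template <typename T>
void pack_vnni(const T *src, T *dst, size_t K, size_t N) {
  const size_t vnni_factor = 4 / sizeof(T);
  for (size_t k = 0; k < K; k += vnni_factor)
    for (size_t n = 0; n < N; ++n)
      for (size_t v = 0; v < vnni_factor; ++v)
        // Packed row (k / vnni_factor) holds vnni_factor consecutive
        // K-elements for each column n.
        dst[(k / vnni_factor) * N * vnni_factor + n * vnni_factor + v] =
            src[(k + v) * N + n];
}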

Example using int8_t type

using namespace sycl::ext::oneapi::experimental::matrix;

queue q;
range<2> G = {M/tM, N};
range<2> L = {1, SG_SIZE};
int8_t *memA = malloc_shared<int8_t>(M*K, q);
int8_t *memB = malloc_shared<int8_t>(K*N, q);
int32_t *memC = malloc_shared<int32_t>(M*N, q);
q.parallel_for(nd_range<2>(G, L), [=](nd_item<2> item)
  [[sycl::reqd_sub_group_size(SG_SIZE)]] {
   const auto global_idx = item.get_global_id(0);
   const auto global_idy = item.get_global_id(1);
   const auto sg_startx = global_idx - item.get_local_id(0);
   const auto sg_starty = global_idy - item.get_local_id(1);
   sub_group sg = item.get_sub_group();
   joint_matrix<int8_t, use::a, tM, tK, layout::row_major> tA(sg);
   joint_matrix<int8_t, use::b, tK, tN, layout::row_major> tB(sg);
   joint_matrix<int32_t, use::accumulator, tM, tN> tC(sg);
   joint_matrix_fill(sg, tC, 0);
   for (int k = 0; k < K; k += tK) {
     joint_matrix_load(sg, tA, memA + sg_startx * tM * K + k, K);
     joint_matrix_load(sg, tB, memB + k * N + sg_starty/SG_SIZE*tN, N);
     tC = joint_matrix_mad(sg, tA, tB, tC);
   }
   auto wi_data_c = tC.get_wi_data();
   for (int i = 0; i < wi_data_c.length(); i++)
     wi_data_c[i] *= alpha; // The indexing here "i" is in the vector owned by a WI, not in the matrix C
   joint_matrix_store(sg, tC, memC + sg_startx * tM * N + sg_starty/SG_SIZE*tN, N, layout::row_major);
}).wait();

Query Interface

Intel AMX, DPAS, and Nvidia TPUs support different sizes and types. The query interface is used to validate user code and to inform the user about the types, sizes, scopes, and layouts supported by the implementation. It also improves development and tuning productivity for both scientists and library developers. The query interface proposed here is a compile-time query, so there will be no runtime errors. It consists of three functionalities:

  • Validation: at compile time, the validation functionality informs the user whether a specific combination is valid or not. This takes place when the user specifies all template parameters.

  • Default values: this provides a default shape if the user does not provide a specific combination. In this case, aliases to the joint_matrix type can be used, namely joint_matrix_a/b/accumulator where no additional argument is needed. This form happens when the user specifies all template parameters except the sizes of the matrices (tiles) M, N, and K.

  • General query: the general query interface provides information about sizes, types, static/dynamic, and scopes that are supported by a specific TPU implementation. This is needed to avoid padding by the user, for tuning, and efficient code generation if used by a library. The general query returns an array of combinations of combination type. Each combination includes the sizes and the types for the matrices A, B, and accumulator. Note that for each TPU, the query returns max_msize, max_nsize, max_ksize or msize, nsize, ksize exclusively, depending on whether the implementation supports a continuous or discrete number of sizes. For example, the Intel AMX implementation supports a continuous number of sizes, so the max_* variant is applied and only the maximum number is returned. The DPAS implementation, on the other hand, supports a discrete list of numbers so the msize, nsize, ksize variant is applied. This form takes place when users only specify the TPU they are interested in using.
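The three forms differ only in which template arguments are supplied. A brief sketch, using the Intel AMX specializations shown later in this section:

// Validation form: types and sizes are all given.
tpu_params<tpu::amx, int8_t, int8_t, int, 16, 16, 64> validated;
// Default values form: only the types are given; M, N, K get default values.
tpu_params<tpu::amx, int8_t, int8_t, int> defaults;
// General query form: only the TPU is given; inspect the combinations array.
tpu_params<tpu::amx> query;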

The list below describes each member variable and type alias in the tpu_params class; the forms in which each is defined are shown in parentheses.

type_a (validation, default values): type alias for the type of matrix A

type_b (validation, default values): type alias for the type of matrix B

type_accumulator (validation, default values): type alias for the type of the accumulator matrix

M (validation, default values): when no sizes are provided by the user, indicates the suggested default size for M; usually this corresponds to the maximum size the implementation supports. In validation mode, where the user does provide sizes, this is the same value M that the user provides, if M is supported by the implementation

N (validation, default values): when no sizes are provided by the user, indicates the suggested default size for N; usually this corresponds to the maximum size the implementation supports. In validation mode, where the user does provide sizes, this is the same value N that the user provides, if N is supported by the implementation

K (validation, default values): when no sizes are provided by the user, indicates the suggested default size for K; usually this corresponds to the maximum size the implementation supports. In validation mode, where the user does provide sizes, this is the same value K that the user provides, if K is supported by the implementation

joint_matrix_a (validation, default values): type alias for joint_matrix for matrix A

joint_matrix_b (validation, default values): type alias for joint_matrix for matrix B

joint_matrix_accumulator (validation, default values): type alias for joint_matrix for the accumulator matrix

dynamic_p (validation, default values, general query): a boolean that indicates whether the implementation supports dynamic sizes (true) or not (false)

numtiles (validation, default values, general query): indicates the number of tiles in Intel AMX (does not apply to DPAS)

scope (validation, default values, general query): indicates the memory and execution scope supported by the TPU implementation

combination (validation, default values, general query): composes the types and sizes of the A, B, and accumulator matrices allowed in one combination

max_msize, max_nsize, max_ksize (validation, default values, general query): if the TPU implementation supports a continuous range of sizes, each of these members is non-zero and the implementation supports all sizes from 1 up to (and including) that number. By contrast, if the TPU implementation supports a discrete set of sizes, each of these members has the value zero

msize, nsize, ksize (validation, default values, general query): if the TPU implementation supports a discrete set of sizes, each of these members is non-zero and its value is one of the supported sizes. By contrast, if the TPU implementation supports a continuous range of sizes, each of these members has the value zero

atype, btype, accumulatortype (validation, default values, general query): indicate the types supported in the combination

combinations (validation, default values, general query): tells the set of supported matrix sizes and types according to the template parameters that are provided. In the "general query" form, the user provides only the TPU type, so the combinations array contains all supported tile sizes and element types for that TPU. In the "default values" form, the user provides the TPU type and element types, so the combinations array contains only those supported matrix sizes and element types that match those element types on that TPU. In the "validation" form, the user provides the TPU type, element types, and element sizes, so only this specific combination is returned in the combinations array

num_combinations (validation, default values, general query): indicates the number of combinations supported by the TPU implementation, which corresponds to the size of the combinations array

namespace sycl::ext::oneapi::experimental::matrix {
template <tpu u, typename Ta = void, typename Tb = void, typename Tc = void,
          int sM = 0, int sN = 0, int sK = 0, typename Enabled = void>
struct tpu_params;

// Validation form: Valid or not
// Specialization when both types and sizes are given
template <typename Ta, typename Tb, typename Tc, int sM, int sN, int sK>
struct tpu_params<
    tpu::amx, Ta, Tb, Tc, sM, sN, sK,
    typename std::enable_if<(
        !std::is_same_v<Ta, void> && !std::is_same_v<Tb, void> &&
        !std::is_same_v<Tc, void> && sM != 0 && sN != 0 && sK != 0)>::type> {
  // Validate that parameters are supported
  static_assert(
      (sM == 0 && sN == 0 && sK == 0) ||
          (is_combination_valid_amx<Ta, Tb, Tc>(sM, sN, sK)),
      "Invalid parameters for Intel AMX, query valid types and maximum sizes "
      "using: "
      "tpu_params<tpu::amx> myparams; and then check out myparams.combinations array");


  using type_a = Ta; // this type alias is not available in the current implementation
  using type_b = Tb; // this type alias is not available in the current implementation
  using type_accumulator = Tc; // this type alias is not available in the current implementation

  // if combination is valid, construct the matrices

  static constexpr std::size_t M = (sM != 0) ? sM : 16;
  static constexpr std::size_t N = (sN != 0) ? sN : 16;
  static constexpr std::size_t K =
      (sK != 0) ? sK : ((sizeof(Ta) == 1) ? 64 : 32);

  template <layout Layout, typename Group = sub_group>
  using joint_matrix_a = joint_matrix<Ta, use::a, M, K, Layout, Group>;
  template <layout Layout, typename Group = sub_group>
  using joint_matrix_b = joint_matrix<Tb, use::b, K, N, Layout, Group>;
  template <layout Layout, typename Group = sub_group>
  using joint_matrix_accumulator = joint_matrix<Tc, use::accumulator, M, N, Layout, Group>;

  static constexpr bool dynamic_p = false; // should be true in future implementations
                          // because Intel AMX hardware supports dynamic sizes
  static constexpr uint32_t numtiles = 8;
  static constexpr scope_t scope = scope_t::sub_group;
  struct combination {
    uint32_t max_msize;
    uint32_t max_nsize;
    uint32_t max_ksize;
    uint32_t msize;
    uint32_t nsize;
    uint32_t ksize;
    matrix_type atype;
    matrix_type btype;
    matrix_type accumulatortype;
  };
  // In this case, the combinations array contains only the combination that the user provided
  static constexpr combination combinations[] = {
      {16, 16, (sizeof(Ta) == 1) ? 64 : 32, sM, sN, sK}};
  static constexpr int num_combinations =
      sizeof(combinations) / sizeof(combination);
};

// Default values form: Sizes-only query
// Specialization for when only types are given, need to query only sizes
template <typename Ta, typename Tb, typename Tc>
struct tpu_params<tpu::amx, Ta, Tb, Tc, 0, 0, 0,
                  typename std::enable_if<(!std::is_same_v<Ta, void> &&
                                           !std::is_same_v<Tb, void> &&
                                           !std::is_same_v<Tc, void>)>::type> {
  static_assert((are_types_valid_amx<Ta, Tb, Tc>()),
                "Invalid types for Intel AMX, supported types are int8_t, uint8_t, "
                "and bf16 (Note that unsigned short should be used in the "
                "DPC++ code to implement bf16)");

  using type_a = Ta; // this type alias is not available in the current implementation
  using type_b = Tb; // this type alias is not available in the current implementation
  using type_accumulator = Tc; // this type alias is not available in the current implementation

  // construct the matrices using the default sizes
  static constexpr std::size_t M = 16;
  static constexpr std::size_t N = 16;
  static constexpr std::size_t K = ((sizeof(Ta) == 1) ? 64 : 32);

  template <layout Layout, typename Group = sub_group>
  using joint_matrix_a = joint_matrix<Ta, use::a, M, K, Layout, Group>;
  template <layout Layout, typename Group = sub_group>
  using joint_matrix_b = joint_matrix<Tb, use::b, K, N, Layout, Group>;
  template <layout Layout, typename Group = sub_group>
  using joint_matrix_accumulator = joint_matrix<Tc, use::accumulator, M, N, Layout, Group>;

  static constexpr bool dynamic_p = false; // should be true in future implementations because
                          // Intel AMX hardware supports dynamic sizes
  static constexpr uint32_t numtiles = 8;
  static constexpr scope_t scope = scope_t::sub_group;
  struct combination {
    uint32_t max_msize;
    uint32_t max_nsize;
    uint32_t max_ksize;
    uint32_t msize;
    uint32_t nsize;
    uint32_t ksize;
    matrix_type atype;
    matrix_type btype;
    matrix_type accumulatortype;
  };
  // In this case, the combinations array contains only the combinations that correspond to the
  // Ta, Tb, and Tc types that the user provided
  static constexpr combination combinations[] = {
      {16, 16, (sizeof(Ta) == 1) ? 64 : 32}};
  static constexpr int num_combinations =
      sizeof(combinations) / sizeof(combination);
};

// General query form:
// types are not given, no default sizes and no implicit matrix construction
template <int sM, int sN, int sK>
struct tpu_params<tpu::amx, void, void, void, sM, sN, sK> {
  static constexpr bool dynamic_p = false; // should be true in future implementations because
                          // Intel AMX hardware supports dynamic sizes
  static constexpr uint32_t numtiles = 8;
  static constexpr scope_t scope = scope_t::sub_group;
  struct combination {
    uint32_t max_msize;
    uint32_t max_nsize;
    uint32_t max_ksize;
    uint32_t msize;
    uint32_t nsize;
    uint32_t ksize;
    matrix_type atype;
    matrix_type btype;
    matrix_type accumulatortype;
  };

  static constexpr combination combinations[] = {
      {16, 16, 64, 0, 0, 0, matrix_type::sint8, matrix_type::sint8, matrix_type::sint32},
      {16, 16, 64, 0, 0, 0, matrix_type::sint8, matrix_type::uint8, matrix_type::sint32},
      {16, 16, 64, 0, 0, 0, matrix_type::uint8, matrix_type::sint8, matrix_type::sint32},
      {16, 16, 64, 0, 0, 0, matrix_type::uint8, matrix_type::uint8, matrix_type::sint32},
      {16, 16, 32, 0, 0, 0, matrix_type::bf16, matrix_type::bf16, matrix_type::fp32}};
  static constexpr int num_combinations =
      sizeof(combinations) / sizeof(combination);
};


enum class tpu {
  dpas,
  amx
};

enum class matrix_type {
  bf16,
  fp16,
  fp19,  // tfloat32
  fp32,
  fp64,
  sint2,
  sint4,
  sint8,
  sint16,
  sint32,
  sint64,
  uint2,
  uint4,
  uint8,
  uint16,
  uint32,
  uint64
};

enum class scope_t {
  sub_group,
  work_group
};
}

Validation Example:

// The user can provide sizes in addition to the types, and tpu_params can assert if they are supported or not.
// In this case, an assertion will happen because 16 is not a supported size for M on DPAS.
using myparams = tpu_params<tpu::dpas, int8_t, int8_t, int, 16, 8, 32>;
size_t NDRangeM = M / myparams::M;  //Assertion would happen at this line
size_t NDRangeN = N / myparams::N;

Default Values Example:

using myparams = tpu_params<tpu::dpas, int8_t, int8_t, int>;
// use this to construct the ranges on the host side
size_t NDRangeM = M / myparams::M;
size_t NDRangeN = N / myparams::N;
// if M, N, K are not multiples of the default sizes, padding has to be done
// device code: the matrices are constructed using the default dimensions
myparams::joint_matrix_a sub_a(sg);
myparams::joint_matrix_b sub_b(sg);
myparams::joint_matrix_accumulator sub_c(sg);

General Query Example:

constexpr int M = 1500; // with msize = 8 and msize = 4,
          // M can be broken into 125 iterations of size 8 (1000 rows) plus 125 iterations of size 4 for the remaining 500 rows
tpu_params<tpu::dpas> params;
constexpr int msize = break_dimension(params, M);
constexpr int msize_remainder = break_dimension_remainder(params, M);
constexpr int nsize = params.combinations[0].nsize;
constexpr int ksize = params.combinations[0].ksize;
// device code:
joint_matrix<int8_t, use::a, msize, ksize> sub_a(sg);
joint_matrix<int8_t, use::b, ksize, nsize> sub_b(sg);
joint_matrix<int, use::accumulator, msize, nsize> sub_c(sg);
//Remainder handling

Future-looking API

Memory scope

The current experimental API uses joint_ semantics to define the memory scope of the matrix. The long term solution is to use the proposed group_local_memory extension to allocate the matrix in local memory associated with a SYCL group as shown in the example below.

multi_ptr<matrix<T>, address_space::local_space> tA_ptr = group_local_memory<matrix<sub_group, int8_t, tM, tN, use::a>>(sg);

We did not utilize this extension for this matrix API version because sub-group local memory is not yet well defined in DPC++. Moreover, the representation of this notion in LLVM IR and SPIR-V is not clear yet.

WI data to joint matrix mapping coordinates information for piece-wise operations

The indexing provided inside the wi_data class accesses only the portion of the matrix held by the current WI. It is not possible to know the location of this portion within the original matrix. This coordinate mapping is implementation defined and changes from one backend to another. For general piece-wise operations like a sum of the rows of a matrix, the WI data to joint matrix mapping information is needed to reason about the matrix view. Within the joint matrix extension, we want to write, as much as possible, a single code that runs on different backends. If backend X states that a WI owns exactly one row of the matrix, for instance, writing the following code will work only on that backend for that version of the hardware. If different hardware or a different implementation is used, the same WI may own only half of the row if, for example, the SG size increases.

auto data = C.get_wi_data();
for (int i = 0; i < data.length(); ++i) {
  sum_of_local_rows[row] += data[i];
}

We want to keep backward compatibility in the joint matrix code when implementations or hardware change. To that end, instead of hard-coding this mapping, we use general backend and target-agnostic functionality, especially in the JIT compilation mode of SYCL. For this reason we would like to be able to query this mapping so that code does not have to change from one version to the other.

Since this mapping is implementation defined, one proposal for the mapping problem is to add runtime functions like the following:

auto data = C.get_wi_data();
for (int i = 0; i < data.length(); ++i) {
  auto [row, col] = data[i].get_coord();
  sum_of_local_rows[row] += data[i];
}

Open Questions

  • Ronan Keryell: "It would be interesting to investigate whether providing also member functions would simplify the API. Provide both so it is possible to use the best one for each use case, while waiting for https://en.wikipedia.org/wiki/Uniform_Function_Call_Syntax to land into C++?"

  • In the future-looking APIs, get_wi_data (which is currently under design) returns an owned object. Should this instead return a view object, to make sure the original matrix C is changed after its slices are modified?

  • dynamic_extent on the shape of joint_matrix is only available on Intel AMX. Should this be part of the API?

  • This document still contains code that is not portable across Intel AMX, DPAS, and Nvidia Tensor Cores, such as the packed layout and dynamic_extent. Currently, these restrictions are explained in the spec text, but we might decide to move them to a separate Intel-specific additional matrix API document.

TODO List

  • Add WI data to joint matrix mapping coordinates information for piece-wise operations. This will be added as part of the query interface or as new methods of the wi_data class.

  • Change the type of scope in the query interface to be able to return more than one value. This will be useful in the event we support other scopes, such as work-group, besides sub-group.

  • Add a more realistic and complete example that shows the value of the general query

Revision History

Rev | Date | Author | Changes
1 | 2021-04-13 | Dounia Khaldi | Initial public working draft.
2 | 2021-10-05 | Dounia Khaldi | JIT implementation on both Intel AMX and DPAS
3 | 2022-05-16 | Dounia Khaldi | Add matrix fill and piece-wise operations support
4 | 2022-08-25 | Dounia Khaldi | Update the matrix spec by adding the new matrix use parameter and remove reference to the AOT AMX initial implementation