Commit
Include changes for Intel PVC pipeline (#51)
Add an Intel PVC pipeline to compute GEMM.
muhammad-tanvir-1211 committed May 16, 2024
1 parent 9d12ff6 commit 7e7db70
Showing 13 changed files with 559 additions and 10 deletions.
2 changes: 1 addition & 1 deletion include/cute/atom/copy_atom.hpp
@@ -769,7 +769,7 @@ print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and
 #include <cute/atom/copy_traits_sm90_tma.hpp>
 #endif
 
-#if defined(CUTLASS_ENABLE_SYCL)
+#if defined(SYCL_INTEL_TARGET)
 #include <cute/atom/copy_traits_xe.hpp>
 #endif
 
26 changes: 22 additions & 4 deletions include/cute/atom/copy_traits_xe.hpp
@@ -53,8 +53,17 @@ struct XE_2D_LD_Unpack
     static_assert(is_rmem<TD>::value);
     int H = size<0>(traits.tensor);
     int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType);
-    auto [y, x] = src.data().coord_;
-    CopyOp::copy(traits.tensor.data().get(), W, H, W, intel::coord_t{static_cast<int>(x), static_cast<int>(y)}, &*dst.data());
+    auto [y, x, z] = src.data().coord_;
+    CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*dst.data());
   }
 
+  template <class GCoord, class GShape, class GStride>
+  CUTE_HOST_DEVICE constexpr auto
+  get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const
+  {
+    return make_tensor(make_inttuple_iter(coord),
+                       make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)),
+                                   make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor)))));
+  }
 };

@@ -274,8 +283,17 @@ struct XE_2D_ST_Unpack
     static_assert(is_rmem<TS>::value);
     int H = size<0>(traits.tensor);
     int W = size<1>(traits.tensor) * sizeof(typename Copy_Traits::CopyInternalType);
-    auto [y, x] = dst.data().coord_;
-    CopyOp::copy(traits.tensor.data().get(), W, H, W, intel::coord_t{static_cast<int>(x), static_cast<int>(y)}, &*src.data());
+    auto [y, x, z] = dst.data().coord_;
+    CopyOp::copy(traits.tensor.data() + z, W, H, W, intel::coord_t{x, y}, &*src.data());
   }
 
+  template <class GCoord, class GShape, class GStride>
+  CUTE_HOST_DEVICE constexpr auto
+  get_pvc_tensor(GCoord const& coord, GShape const& shape, GStride const& stride_mul) const
+  {
+    return make_tensor(make_inttuple_iter(coord),
+                       make_layout(make_shape(_1{}, get<0>(shape), get<1>(shape), get<2>(shape)),
+                                   make_stride(_1{}, E<0>{} * get<0>(stride_mul), E<1>{} * get<1>(stride_mul), E<2>{} * get<2>(stride(tensor)))));
+  }
 };

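A note on the new helper: get_pvc_tensor, added to both the load and store traits above, builds a tensor of copy coordinates rather than data. Mode 0 is a unit mode, and the scaled basis elements E<0>, E<1>, E<2> make each remaining mode advance one component of the (y, x, z) coordinate tuple. A minimal sketch of such a coordinate tensor follows; the shapes, strides, and origin are hypothetical, not taken from the patch:

#include <cute/tensor.hpp>

void coordinate_tensor_sketch() {
  using namespace cute;
  // Block origin (y, x, z) = (0, 0, 0); mode 1 steps y by 8, mode 2 steps x
  // by 16, and mode 3 steps z by a (hypothetical) batch stride of 64.
  auto coords = make_tensor(make_inttuple_iter(make_coord(0, 0, 0)),
                            make_layout(make_shape(_1{}, _2{}, _4{}, _1{}),
                                        make_stride(_1{}, E<0>{} * 8, E<1>{} * 16, E<2>{} * 64)));
  // coords(0, 1, 2, 0) evaluates to the tuple (8, 32, 0): each mode moves one
  // component of the (y, x, z) coordinate that the XE 2D block copies consume.
}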
4 changes: 4 additions & 0 deletions include/cute/atom/mma_atom.hpp
@@ -38,6 +38,10 @@
 #include <cute/tensor.hpp>
 #include <cute/util/type_traits.hpp>
 
+#if defined(CUTLASS_ENABLE_SYCL)
+#include <cute/atom/mma_traits_xe.hpp>
+#endif
+
 namespace cute {
 
 template <class... Args>
4 changes: 2 additions & 2 deletions include/cute/atom/mma_traits_xe.hpp
@@ -41,8 +41,8 @@ template <>
 struct MMA_Traits<XE_8x16x16_BF16BF16F32F32_NN>
 {
   using ValTypeD = float;
-  using ValTypeA = sycl::ext::oneapi::bfloat16;
-  using ValTypeB = sycl::ext::oneapi::bfloat16;
+  using ValTypeA = bfloat16_t;
+  using ValTypeB = bfloat16_t;
   using ValTypeC = float;
 
   using Shape_MNK = Shape<_8,_16,_16>;
6 changes: 3 additions & 3 deletions include/cute/config.hpp
@@ -34,9 +34,9 @@
 # define CUTE_HOST_DEVICE __forceinline__ __host__ __device__
 # define CUTE_DEVICE __forceinline__ __device__
 # define CUTE_HOST __forceinline__ __host__
-#elif defined(__SYCL_CUDA_ARCH__)
-# define CUTE_HOST_DEVICE __attribute__((always_inline))
-# define CUTE_DEVICE __attribute__((always_inline))
+#elif defined(__SYCL_DEVICE_ONLY__)
+# define CUTE_HOST_DEVICE __attribute__((always_inline)) inline
+# define CUTE_DEVICE __attribute__((always_inline)) inline
 # define CUTE_HOST inline
 #else
 # define CUTE_HOST_DEVICE inline
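The added inline matters beyond retargeting the macro guard: __attribute__((always_inline)) affects inlining only, not linkage, so a header-defined function compiled into several translation units would otherwise emit duplicate strong symbols. A minimal illustration with a hypothetical function:

// Without the trailing `inline`, two translation units including this header
// would each emit a strong symbol for twice() and fail to link.
__attribute__((always_inline)) inline int twice(int x) { return x + x; }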
1 change: 1 addition & 0 deletions include/cute/tensor.hpp
@@ -42,6 +42,7 @@
 
 #include <cute/pointer.hpp>
 #include <cute/layout.hpp>
+#include <iomanip>
 
 namespace cute
 {
7 changes: 7 additions & 0 deletions include/cutlass/arch/arch.h
@@ -97,6 +97,13 @@ struct Sm90 {
   static int const kMinComputeCapability = 90;
 };
 
+#if defined(CUTLASS_ENABLE_SYCL)
+struct IntelPVC {
+  static int const kMinComputeCapability = 0;
+};
+
+#endif
+
 /// Triggers a breakpoint on the device
 CUTLASS_DEVICE
 void device_breakpoint() {
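PVC does not fit the CUDA compute-capability scale, hence kMinComputeCapability = 0; the tag exists so downstream code can name the architecture for dispatch, as the new mainloop policy in this commit does:

using ArchTag = cutlass::arch::IntelPVC;   // as in MainloopIntelPVCBase below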
4 changes: 4 additions & 0 deletions include/cutlass/gemm/collective/collective_mma.hpp
@@ -75,4 +75,8 @@ struct CollectiveMma {
 #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp"
 #include "cutlass/gemm/collective/sm90_mma_array_tma_gmma_ss_warpspecialized.hpp"
 #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp"
+
+#if defined(SYCL_INTEL_TARGET)
+#include "cutlass/gemm/collective/intel_pvc_mma.hpp"
+#endif
 /////////////////////////////////////////////////////////////////////////////////////////////////
220 changes: 220 additions & 0 deletions include/cutlass/gemm/collective/intel_pvc_mma.hpp
@@ -0,0 +1,220 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once

#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"

#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/tensor_predicate.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass::gemm::collective {
using namespace cute;
/////////////////////////////////////////////////////////////////////////////////////////////////

template <
class TileShape_,
class ElementA_,
class StrideA_,
class ElementB_,
class StrideB_,
class TiledMma_,
class GmemTiledCopyA_,
class SmemLayoutAtomA_,
class SmemCopyAtomA_,
class TransformA_,
class GmemTiledCopyB_,
class SmemLayoutAtomB_,
class SmemCopyAtomB_,
class TransformB_>
struct CollectiveMma<
MainloopIntelPVCUnpredicated,
TileShape_,
ElementA_,
StrideA_,
ElementB_,
StrideB_,
TiledMma_,
GmemTiledCopyA_,
SmemLayoutAtomA_,
SmemCopyAtomA_,
TransformA_,
GmemTiledCopyB_,
SmemLayoutAtomB_,
SmemCopyAtomB_,
TransformB_>
{
//
// Type Aliases
//
using DispatchPolicy = MainloopIntelPVCUnpredicated;
using TileShape = TileShape_;
using ElementA = ElementA_;
using StrideA = StrideA_;
using ElementB = ElementB_;
using StrideB = StrideB_;
using TiledMma = TiledMma_;
using ElementAccumulator = typename TiledMma::ValTypeC;
using GmemTiledCopyA = GmemTiledCopyA_;
using GmemTiledCopyB = GmemTiledCopyB_;
using SmemLayoutAtomA = SmemLayoutAtomA_;
using SmemLayoutAtomB = SmemLayoutAtomB_;
using SmemCopyAtomA = SmemCopyAtomA_;
using SmemCopyAtomB = SmemCopyAtomB_;
using TransformA = TransformA_;
using TransformB = TransformB_;
using ArchTag = typename DispatchPolicy::ArchTag;

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

static constexpr int DpasM = get<0>(shape(typename TiledMma::LayoutA_TV{})); // rows per dpas operation per sub_group for Matrix A
static constexpr int DpasN = get<1>(shape(typename TiledMma::LayoutB_TV{})); // cols per dpas operation per sub_group for Matrix B
static constexpr int DpasK = get<1>(shape(typename TiledMma::LayoutA_TV{})); // cols per dpas operation per sub_group for Matrix A

static constexpr uint32_t MaxThreadsPerBlock = DpasM * DpasN;
static constexpr uint32_t MinBlocksPerMultiprocessor = 1;

static constexpr int FragsM = get<0>(TileShape{}) / DpasM; // A frags per sub_group
static constexpr int FragsN = get<1>(TileShape{}) / DpasN; // B frags per sub_group
static constexpr int FragsK = get<2>(TileShape{}) / DpasK;

// Calculate the vector width based on the number of registers
// required per work item, dividing the total fragment size by
// the sub_group size.
static constexpr int VecC = (DpasN * DpasM) / SubgroupSize;
static constexpr int VecA = (DpasM * DpasK) / SubgroupSize;
static constexpr int VecB = (DpasN * DpasK) / SubgroupSize;
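// For example, with the XE_8x16x16_BF16BF16F32F32_NN atom shown earlier in
// this commit (Shape_MNK = 8x16x16) and SubgroupSize = 16, each work item
// holds VecC = (16 * 8) / 16 = 8 accumulator values, VecA = (8 * 16) / 16 = 8
// elements of A, and VecB = (16 * 16) / 16 = 16 elements of B.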

// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A;
StrideA dA;
ElementB const* ptr_B;
StrideB dB;
};
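// Host-side sketch (hypothetical names): callers fill the struct above
// directly, e.g. Arguments args{d_A, stride_A, d_B, stride_B}, and
// to_underlying_arguments() below wraps each pointer in an XE 2D copy.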

struct Params {
using XE_Copy_A = decltype(make_xe_2d_copy<GmemTiledCopyA>(make_tensor(static_cast<ElementA const*>(nullptr),
repeat_like(StrideA{}, int32_t(0)), StrideA{})));
using XE_Copy_B = decltype(make_xe_2d_copy<GmemTiledCopyB>(make_tensor(static_cast<ElementB const*>(nullptr),
repeat_like(StrideB{}, int32_t(0)), StrideB{})));
XE_Copy_A gmem_tiled_copy_a;
XE_Copy_B gmem_tiled_copy_b;
};

//
// Methods
//

CollectiveMma() = default;

template <class ProblemShape>
static constexpr Params
to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) {
(void) workspace;

auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;

Tensor tensorA = make_tensor(args.ptr_A, make_layout(make_shape(M,K,L), args.dA));
Tensor tensorB = make_tensor(args.ptr_B, make_layout(make_shape(K,N,L), args.dB));

typename Params::XE_Copy_A copyA = make_xe_2d_copy<GmemTiledCopyA>(tensorA);
typename Params::XE_Copy_B copyB = make_xe_2d_copy<GmemTiledCopyB>(tensorB);
return Params{copyA, copyB};
}

/// Perform a subgroup-scoped matrix multiply-accumulate
template <
class FrgTensorD,
class TensorA,
class TensorB,
class FrgTensorC,
class KTileIterator,
class ResidueMNK
>
CUTLASS_DEVICE void
operator() (
FrgTensorD &accum,
TensorA gA,
TensorB gB,
FrgTensorC const &src_accum,
KTileIterator k_tile_iter, int k_tile_count,
ResidueMNK residue_mnk,
int thread_idx,
char *smem_buf,
Params const& mainloop)
{
(void)residue_mnk;
(void)thread_idx;
(void)smem_buf;

static_assert(is_rmem<FrgTensorD>::value, "D tensor must be rmem resident.");
static_assert(is_tuple<typename TensorA::engine_type::iterator::value_type>::value, "A tensor must be a tuple iterator.");
static_assert(is_tuple<typename TensorB::engine_type::iterator::value_type>::value, "B tensor must be a tuple iterator.");
static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");

// Tensor to hold input data
Tensor tAr = make_tensor<typename TiledMma::ValTypeA>(Shape<Int<get<0>(TileShape{}) * FragsK>, Int<1>>{});
Tensor tBr = make_tensor<typename TiledMma::ValTypeB>(Shape<Int<FragsK * get<1>(TileShape{}) / FragsN>, Int<FragsN>>{});

Tensor tAr_view = make_tensor(static_cast<decltype(tAr) &&>(tAr).data(),
Shape<Int<VecA>, Int<FragsM>, Int<FragsK>>{});
Tensor tBr_view = make_tensor(static_cast<decltype(tBr) &&>(tBr).data(),
Shape<Int<VecB>, Int<FragsK>, Int<FragsN>>{});
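// The flat load destinations above are re-viewed as (vector, M-frags, K-frags)
// and (vector, K-frags, N-frags) so the inner loop below can feed one A and
// one B fragment per dpas call.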

// Instantiate the MMA object
TiledMma tiled_mma;

//
// Mainloop
//
for (int k_tile = 0, k = 0; k_tile < k_tile_count; ++k_tile, k += DpasK * FragsK)
{
// Copy gmem to rmem for this k_tile
copy(mainloop.gmem_tiled_copy_a, gA(_,_,k), tAr);
copy(mainloop.gmem_tiled_copy_b, gB(_,k/2,_), tBr);

for (int kl = 0; kl < FragsK; kl++) {
cute::gemm(tiled_mma, accum, tAr_view(_, _, kl), tBr_view(_, kl, _), src_accum);
}
}
}
};

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
4 changes: 4 additions & 0 deletions include/cutlass/gemm/device/gemm_universal_adapter.h
@@ -402,7 +402,11 @@ class GemmUniversalAdapter<
     const auto sycl_block = syclcompat::dim3(block.x, block.y, block.z);
     const auto sycl_grid = syclcompat::dim3(grid.x, grid.y, grid.z);
 
+#if defined (SYCL_INTEL_TARGET)
+    syclcompat::experimental::launch<device_kernel<GemmKernel>, DispatchPolicy::SubgroupSize>(sycl_grid, sycl_block, smem_size, params);
+#else
     syclcompat::launch<device_kernel<GemmKernel>>(sycl_grid, sycl_block, smem_size, params);
+#endif
 #else
     device_kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params);
 #endif
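On the SYCL_INTEL_TARGET path the experimental launcher additionally pins the kernel's sub-group size to DispatchPolicy::SubgroupSize (16 for PVC), which the XE copy and dpas atoms assume. A rough sketch of the plain-SYCL equivalent; this is an assumption about what the launcher arranges, not its actual implementation:

// Sketch only: pinning the sub-group size on a SYCL kernel launch.
queue.parallel_for(nd_range, [=](sycl::nd_item<3>) [[sycl::reqd_sub_group_size(16)]] {
  device_kernel<GemmKernel>(params);
});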
14 changes: 14 additions & 0 deletions include/cutlass/gemm/dispatch_policy.hpp
@@ -71,6 +71,7 @@ enum class KernelInputTransformType {
 //
 // Kernel schedule policies (the base class tags, one for each kernel layer file)
 //
+struct KernelSinglestage { };
 struct KernelMultistage { };
 struct KernelCpAsyncWarpSpecialized { };
 struct KernelCpAsyncWarpSpecializedPingpong { };

@@ -269,6 +270,19 @@ struct MainloopSm90ArrayTmaGmmaWarpSpecialized {
     "KernelSchedule must be one of the Ptr-Array or Grouped Gemm TMA Warp Specialized Cooperative policies");
 };
 
+
+#if defined(SYCL_INTEL_TARGET)
+struct MainloopIntelPVCBase {
+  constexpr static int Stages = 1;
+  using ArchTag = arch::IntelPVC;
+  using Schedule = KernelSinglestage;
+  using ClusterShape = Shape<_1,_1,_1>;
+  static constexpr int SubgroupSize = 16;
+};
+
+struct MainloopIntelPVCUnpredicated : MainloopIntelPVCBase{};
+#endif

//////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm
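Stages = 1 and KernelSinglestage reflect that this mainloop issues 2D block loads straight from global memory on every k-tile, with no shared-memory pipeline (the collective above ignores its smem_buf). A hedged sketch of how the policy ties into a kernel type; the epilogue type here is a placeholder, not part of this commit:

// Sketch only; CollectiveEpilogue stands in for a real epilogue choice.
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    cute::Shape<int, int, int, int>,   // ProblemShape (M, N, K, L)
    CollectiveMainloop,                // CollectiveMma<MainloopIntelPVCUnpredicated, ...>
    CollectiveEpilogue>;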
4 changes: 4 additions & 0 deletions include/cutlass/gemm/kernel/gemm_universal.hpp
@@ -89,4 +89,8 @@ struct IsCutlass3ArrayKernel<ProblemShape, cute::void_t<typename ProblemShape::U
 #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp"
 #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp"
 #include "cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_cooperative.hpp"
+
+#if defined(SYCL_INTEL_TARGET)
+#include "cutlass/gemm/kernel/intel_pvc_gemm.hpp"
+#endif
 ////////////////////////////////////////////////////////////////////////////////