3 changes: 2 additions & 1 deletion applications/dual_gemm/collective/xe_dual_gemm_mma.hpp
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@@ -169,7 +170,7 @@ struct DualGemmMma<MainloopIntelXeXMX16<Stages, Schedule>, TileShape_, ElementA_
TiledMma tiled_mma;
// TODO(Codeplay): see if we can make this nicer
// To make all work items in a subgroup have the same global tensors pass in the index of work item 0 in each subgroup
auto sg = syclcompat::get_nd_item<1>().get_sub_group();
auto sg = compat::get_nd_item<1>().get_sub_group();
auto first_thread_in_sg_idx = sg.get_group_linear_id() * DispatchPolicy::SubgroupSize;
auto thr_mma = tiled_mma.get_slice(first_thread_in_sg_idx);

4 changes: 3 additions & 1 deletion applications/flash_attention_v2/collective/fmha_fusion.hpp
@@ -1,5 +1,6 @@
/***************************************************************************************************
* Copyright (c) 2025 - 2025 Codeplay Software Ltd. All rights reserved.
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
@@ -39,6 +40,7 @@ using namespace cute;

struct VariableLength {
int max_length;
int total_length = 0;
int* cumulative_length = nullptr;
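// max_length:        longest sequence in the batch
// total_length:      presumably the total number of tokens across the batch
//                    (assumption: it is not set or read in the hunks shown here)
// cumulative_length: per-batch prefix sums used to locate each sequence's rows
//                    in the packed variable-length tensor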

CUTE_HOST_DEVICE operator int() const {
@@ -0,0 +1,255 @@
/***************************************************************************************************
* Copyright (C) 2025 Intel Corporation, All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/

#pragma once

#include <sycl/sycl.hpp>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/detail/layout.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace flash_attention {
namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class... Args> class FlashChunkPrefillEpilogue {
static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

template <class MMAOperation_, class TileShapeOutput_, class SubgroupLayout_, class ElementCompute_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class FlashChunkPrefillEpilogue<epilogue::IntelXeXMX16, MMAOperation_, TileShapeOutput_, SubgroupLayout_, ElementCompute_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
public:
//
// Type Aliases
//
using DispatchPolicy = epilogue::IntelXeXMX16;
using ElementO = ElementO_;
using StrideO = StrideO_;
using ElementLSE = ElementLSE_;
using CopyOpO = CopyOpO_;
using SubgroupLayout = SubgroupLayout_;
using TileShapeOutput = TileShapeOutput_;
using TiledMmaOutput = typename TiledMMAHelper<MMA_Atom<MMAOperation_>, Layout<TileShapeOutput>, SubgroupLayout>::TiledMMA;
using GmemTiledCopyO = CopyOpO;
using ElementOutput = ElementO_;
using ElementCompute = ElementCompute_;
using ElementAccumulator = ElementCompute_;
using SubgroupTileShape = decltype(cute::shape_div(TileShapeOutput{}, (SubgroupLayout{}.shape())));

static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

static_assert(cute::rank(TileShapeOutput{}) == 3, "TileShapeOutput must be rank-3: [CTA_M_QO, CTA_N_VO, CTA_K_PV]");
static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]");

using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;

using traits_store_O = Copy_Traits<GmemTiledCopyO, StrideO>;
using atom_load_O = Copy_Atom<traits_store_O, ElementO>;
using val_layout_load_O = decltype(make_layout(shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{})));
using XE_Copy_O = decltype(make_tiled_copy(atom_load_O{}, Layout<CopyThreadShape>{}, val_layout_load_O{}));
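// Tiled copy used to store O: the copy atom's block shape is split across the
// SubgroupSize work-items along the contiguous dimension, one value chunk per work-item.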

private:
constexpr static bool is_destination_supported = not cute::is_void_v<ElementO>;

public:
using EmptyType = cute::tuple<>;

struct TensorStorageImpl : cute::tuple<EmptyType, EmptyType> {};

struct SharedStorage {
using TensorStorage = TensorStorageImpl;

TensorStorage tensors;
};
using TensorStorage = typename SharedStorage::TensorStorage;

// Host side epilogue arguments
struct Arguments {
ElementO const *ptr_O;
StrideO dO;
};

// Device side epilogue params
struct Params {
XE_Copy_O xe_store_o;
};

//
// Methods
//
template <typename To_type, typename Engine, typename Layout>
CUTLASS_DEVICE auto convert_type(Tensor<Engine, Layout> const &tensor) {
using From_type = typename Engine::value_type;
constexpr int numel = decltype(size(tensor))::value;
cutlass::NumericArrayConverter<To_type, From_type, numel> convert_op;
auto frag =
convert_op(*reinterpret_cast<const cutlass::Array<From_type, numel> *>(
tensor.data()));
return make_tensor(make_rmem_ptr<To_type>(&frag), tensor.layout());
}

template <class ProblemShape>
static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args,
[[maybe_unused]] void *workspace) {
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
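// View O as (seq_len_qo, num_heads_q * head_size_vo, batch): all q heads are folded
// into the column dimension, one 2D slice per batch entry.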
auto tensorO = make_tensor(make_gmem_ptr(static_cast<ElementO const*>(args.ptr_O)),
make_layout(make_shape(seq_len_qo, num_heads_q * head_size_vo, batch),
args.dO));
XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)};
return {
xe_store_o,
};
}

template <class ProblemShape>
static size_t get_workspace_size(ProblemShape const &problem_shape, Arguments const &args) {
return 0;
}

template <class ProblemShape>
static cutlass::Status initialize_workspace(ProblemShape const &problem_shape, Arguments const &args, void *workspace,
cudaStream_t stream, CudaHostAdapter *cuda_adapter = nullptr) {
return Status::kSuccess;
}

template <class ProblemShape>
CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape const &problem_shape,
[[maybe_unused]] Arguments const &args) {
return true;
}

CUTLASS_HOST_DEVICE
FlashChunkPrefillEpilogue(Params const &params_, TensorStorage const &) : params(params_) {}

template <class ProblemShape, class SequenceLengthShape, class TileCoord, class FragOut, class FragMax, class FragSum>
CUTLASS_DEVICE void operator()(ProblemShape problem_shape, SequenceLengthShape sequence_length_shape, TileCoord tile_coord, FragOut &out,
FragMax const &max, FragSum &sum) {

using namespace cute;

static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v<tuple_element_t<2, ProblemShape>>;

using FragOutLayout = typename FragOut::layout_type;

constexpr int Vec = shape<0>(FragOutLayout{});
constexpr int FragsM = shape<1>(FragOutLayout{});
constexpr int FragsN = size(select<2,3>(shape(FragOutLayout{})));

auto sg = compat::get_nd_item<1>().get_sub_group();
auto out_reg = make_tensor(static_cast<decltype(out) &&>(out).data(), Shape<Int<Vec>, Int<FragsM>, Int<FragsN>>{});

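// Finish the softmax normalization: each work-item holds a partial row sum, so
// reduce across the subgroup to obtain the full denominator, then scale the
// accumulated output by its reciprocal; rows whose sum is zero or NaN are left unscaled.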
CUTLASS_PRAGMA_UNROLL
for (int y = 0; y < FragsM; y++) {
CUTLASS_PRAGMA_UNROLL
for (int x = 0; x < Vec; x++) {
int indx = y * Vec + x;
auto cur_sum = reduce_over_group(sg, sum(indx), sycl::plus<>());
auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) ? 1.0f : sycl::native::recip(cur_sum);
CUTLASS_PRAGMA_UNROLL
for (int z = 0; z < FragsN; z++) {
out_reg(x, y, z) *= cur_scale;
}
}
}

// Indexing variables
auto [batch, num_heads_q, num_heads_kv, head_size_vo] = select<0, 1, 2, 7>(problem_shape);
auto [seq_len_qo] = select<0>(sequence_length_shape);
// Represent the full output tensor
Tensor mO_mnl = cute::get_xe_tensor(make_shape(seq_len_qo, head_size_vo, 1));

auto [m_coord, n_coord, k_coord, l_coord] = tile_coord;
// Tile the output tensor per WG
Tensor g_wg_O = local_tile(mO_mnl, select<0,1>(TileShapeOutput{}), make_coord(m_coord,n_coord,0)); // (BLK_M,BLK_N,m,n,l)
static constexpr auto ATOM_N = get<2>(typename TiledMmaOutput::ThrLayoutVMNK{}.shape());
auto m_sg = get_sub_group_id() / ATOM_N;
auto n_sg = get_sub_group_id() % ATOM_N;
// Tile the output tensor per SG
Tensor gO = local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg,n_sg,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
Tensor tOgO = thread_xe_store_o.partition_D(gO);

Tensor final_out_reg = make_fragment_like<ElementOutput>(out_reg);
// If ElementOutput is the same type as ElementCompute (the accumulator type), convert_type
// doesn't perform the right conversion, so we call copy(), which internally applies a
// static_cast to the data.
// For ElementOutput == bf16 | fp16, convert_type calls the relevant NumericConverter specialization.
if constexpr (cute::is_same_v<ElementOutput, ElementCompute>) {
copy(out_reg, final_out_reg);
} else {
Tensor temp = convert_type<ElementOutput>(out_reg);
copy(temp, final_out_reg);
}
copy(params.xe_store_o, final_out_reg, tOgO);
}

// SequenceLengthShapeType = Shape<int, int, int>
// For Fixed Sequence Length, ProblemShapeType = Shape<int, int, int, int, int, int, int, int>
// For Variable Sequence Length, ProblemShapeType = Shape<int, int, int, VariableSeqlen, VariableSeqlen, VariableSeqlen, int, int>
template <bool VarLen, class ProblemShapeType, class SequenceLengthShapeType>
CUTLASS_DEVICE static constexpr Params get_updated_copies(Params const& params, ProblemShapeType const& problem_shape,
SequenceLengthShapeType const& sequence_length_shape, int const& l_coord, int const& q_head_coord) {
auto [num_heads_q, num_heads_kv, head_size_vo] = select<1, 2, 7>(problem_shape);
auto [seq_len_qo] = select<0>(sequence_length_shape);
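// Element offset of this (batch, q-head) slice in O: rows start at the sequence's
// (cumulative or fixed) start offset, columns at q_head_coord * head_size_vo.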
int offset_o = 0;
if constexpr (VarLen) {
auto qo_cumulative_length = get<3>(problem_shape).cumulative_length;
offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord] + q_head_coord * head_size_vo;
} else {
offset_o = num_heads_q * head_size_vo * seq_len_qo * l_coord + q_head_coord * head_size_vo;
}
auto store_traits = static_cast<traits_store_O const&>(params.xe_store_o);
ElementO* base_ptr = (ElementO*)store_traits.base_ptr;
auto shape_o = make_shape(static_cast<int>(seq_len_qo), num_heads_q * head_size_vo, 1);
StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o);
auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o));
XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)};
return Params{xe_store_o};
}

private:
Params const &params;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace flash_attention
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////