11 changes: 11 additions & 0 deletions include/common/core/memory.hpp
@@ -500,12 +500,23 @@ __XETLA_API xetla_vector<T, N> xetla_load_global(
xetla_vector<OffsetT, N / VS> byte_offsets,
xetla_mask<N / VS> mask,
xetla_vector<T, N> pass_thru) {
#if __INTEL_LLVM_COMPILER >= 20240200
__ESIMD_NS::properties props{
__ESIMD_NS::cache_hint_L1<gpu::xetla::detail::get_cache_hint(L1H)>,
__ESIMD_NS::cache_hint_L2<gpu::xetla::detail::get_cache_hint(L2H)>,
__ESIMD_NS::alignment<alignment>};

return __ESIMD_NS::gather<T, N, VS>(p, byte_offsets, mask, pass_thru, props);
#else
constexpr data_size DS = data_size::default_size;
return __ESIMD_ENS::lsc_gather<
T,
VS,
gpu::xetla::detail::get_data_size(DS),
gpu::xetla::detail::get_cache_hint(L1H),
gpu::xetla::detail::get_cache_hint(L2H),
N / VS>(p, byte_offsets, mask, pass_thru);
#endif
}

/// template <typename T, int N, int VS, typename OffsetT,
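Context for this hunk: starting with the 2024.2 oneAPI compiler (`__INTEL_LLVM_COMPILER >= 20240200`), cache hints and alignment travel in a compile-time `properties` bag passed to the stable `__ESIMD_NS::gather`, while older compilers keep the experimental `__ESIMD_ENS::lsc_gather`. In both paths the `pass_thru` argument defines the result of masked-off lanes, so the gather never returns undefined data. A minimal caller-side sketch, assuming the usual XeTLA headers are in scope; the helper name `gather16_f32` and its arguments are illustrative, only the `<T, N, VS, L1H, L2H>` parameter order is taken from this diff:

// Gather 16 consecutive floats, one element per address (VS = 1); lanes with
// a zero mask bit come back as 0.0f from `zeros` instead of reading memory.
inline xetla_vector<float, 16> gather16_f32(float* p) {
  xetla_vector<uint32_t, 16> byte_offsets =
      xetla_vector_gen<uint32_t, 16>(0, sizeof(float)); // 0, 4, 8, ...
  xetla_mask<16> mask = 1;                              // all lanes enabled here
  xetla_vector<float, 16> zeros(0);                     // pass-thru for masked lanes
  return xetla_load_global<float, 16, 1, cache_hint::cached, cache_hint::cached>(
      p, byte_offsets, mask, zeros);
}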
38 changes: 23 additions & 15 deletions include/subgroup/tile/impl/load_xe.hpp
@@ -457,6 +457,7 @@ tile_load(tile_t& tile, payload_t& payload) {
constexpr uint32_t num_channel = payload_t::num_channel;
constexpr uint32_t load_elems = num_channel * payload_t::vector_size;
constexpr uint32_t pack_factor = payload_t::pack_factor;
const xetla_vector<load_dtype, load_elems> reg_zeros(0);

auto channel_offset = payload.channel_offset + payload.base_offset;
#pragma unroll
@@ -494,28 +495,35 @@ tile_load(tile_t& tile, payload_t& payload) {
? (xetla_vector_gen<uint32_t, num_channel>(offset_ch_dim, 1) <
size_ch_dim)
: 1;
reg_tmp = xetla_load_global<
load_dtype,
load_elems,
payload_t::vector_size,
L1,
L2>(
payload.base_ptr,
channel_offset + address_offset,
mask,
reg_zeros);
} else {
reg_tmp = xetla_load_global<
load_dtype,
load_elems,
payload_t::vector_size,
L1,
L2>(payload.base_ptr, channel_offset + address_offset, mask);
}
reg_tmp = xetla_load_global<
load_dtype,
load_elems,
payload_t::vector_size,
L1,
L2>(payload.base_ptr, channel_offset + address_offset, mask);

if constexpr (
payload_t::vector_size > 1 && payload_t::num_channel > 1) {
xetla_vector<load_dtype, load_elems> reg_tmp_trans;
#pragma unroll
for (uint32_t iii = 0; iii < payload_t::num_channel; iii++) {
if ((bool)mask[iii]) // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
iii * payload_t::vector_size) =
reg_tmp.xetla_select<
payload_t::vector_size,
payload_t::num_channel>(iii);
else // TODO (dingyi): Delete after driver fix
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
iii * payload_t::vector_size) = 0;
reg_tmp_trans.xetla_select<payload_t::vector_size, 1>(
iii * payload_t::vector_size) =
reg_tmp.xetla_select<
payload_t::vector_size,
payload_t::num_channel>(iii);
}
reg_sub
.xetla_select<load_elems * pack_factor, 1>(
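The added `reg_zeros` pass-thru is what makes the second hunk possible: masked-off channels now arrive as zeros straight from the load, so the per-channel `mask[iii]` test in the transpose loop (the "TODO (dingyi): Delete after driver fix" workaround) is no longer needed. A scalar reference model of the masked-gather semantics, illustrative only and not XeTLA code:

// Element-wise model: each lane either loads through its byte offset or takes
// the pass-thru value, so `out` is fully defined for any mask.
template <typename T, int N>
void gather_ref(const T* p, const uint32_t (&byte_offsets)[N],
                const bool (&mask)[N], const T (&pass_thru)[N], T (&out)[N]) {
  for (int i = 0; i < N; ++i)
    out[i] = mask[i]
        ? *reinterpret_cast<const T*>(
              reinterpret_cast<const char*>(p) + byte_offsets[i])
        : pass_thru[i]; // masked-off lane: no memory access, defined value
}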
18 changes: 8 additions & 10 deletions include/subgroup/tile/impl/payload_xe.hpp
@@ -1655,12 +1655,11 @@ struct prefetch_payload_t<
reg_layout_>,
num_coop_sg_,
arch_tag_,
std::enable_if_t<
(!arch_has_2d_load_store<arch_tag_>) &&
(((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
std::enable_if_t<(!arch_has_2d_load_store<arch_tag_>)&&(
((block_size_y_ != 1 || tile_size_y_ != 1) &&
mem_layout_ == mem_layout::row_major) ||
((block_size_x_ != 1 || tile_size_x_ != 1) &&
mem_layout_ == mem_layout::col_major))>> {
using dtype = native_type_t<dtype_>;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
@@ -1902,10 +1901,9 @@ struct prefetch_payload_t<
reg_layout_>,
num_coop_sg_,
arch_tag_,
std::enable_if_t<
(arch_has_2d_load_store<arch_tag_>) &&
(((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
std::enable_if_t<(arch_has_2d_load_store<arch_tag_>)&&(
((tile_size_y_ != 1) && mem_layout_ == mem_layout::row_major) ||
((tile_size_x_ != 1) && mem_layout_ == mem_layout::col_major))>> {
using dtype = dtype_;
using mem_desc_t =
mem_desc_t<dtype_, mem_layout_, mem_space::global, alignment_>;
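These two hunks only re-wrap the `std::enable_if_t` conditions (formatter output); the selection logic is unchanged. The two partial specializations remain mutually exclusive: one covers architectures without 2D block load/store, the other those with it. A minimal model of the pattern, assuming `<type_traits>` and the XeTLA arch traits are in scope; the `payload` name is a hypothetical stand-in:

// Primary template plus two disjoint enable_if specializations: exactly one
// matches for any given arch, so the choice is unambiguous at compile time.
template <gpu_arch arch, typename Enable = void>
struct payload; // never instantiated directly

template <gpu_arch arch>
struct payload<arch, std::enable_if_t<!arch_has_2d_load_store<arch>>> {
  // scattered/1D prefetch path
};

template <gpu_arch arch>
struct payload<arch, std::enable_if_t<arch_has_2d_load_store<arch>>> {
  // 2D block prefetch path
};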
69 changes: 31 additions & 38 deletions tests/integration/fmha/fmha_forward.hpp
@@ -129,12 +129,12 @@ class fmha_forward_t {
using comp_attr = group::compute_attr_t<scalar_t, scalar_t, accum_t>;
using knobs = group::perf_tuning_knob_t<accum_step, stages, sync_freq>;
using compute_policy_BrBc = std::conditional_t<
(arch_tag >= gpu_arch::XeHpg),
(arch_has_xmx<arch_tag>),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// TODO: add k slicing
using compute_policy_BrBm = std::conditional_t<
(arch_tag >= gpu_arch::XeHpg),
(arch_has_xmx<arch_tag>),
group::compute_policy_default_xmx<comp_attr, knobs, arch_tag>,
group::compute_policy_default_fpu<comp_attr, knobs, arch_tag>>;
// ---------------- // Tile shape and Threads // ---------------- //
@@ -688,7 +688,7 @@ class fmha_forward_t {
uint8_t,
mem_desc_Dp_Mask_t::layout,
mem_desc_Dp_Mask_t::space>>,
gpu_arch::XeHpc>;
arch_tag>;
load_payload_mask_t load_payload_mask(ctx.mem_desc_Dpij);
subgroup::tile_load(mask_in, load_payload_mask);
matAccSij.reg = matAccSij.reg * mask_in.reg * args.dp_scale;
@@ -771,7 +771,7 @@ class fmha_forward_t {
uint32_t height = args.uB * args.uN * args.uF;
uint32_t offset_height = b * args.uN * args.uF + f * args.uN + n;

if constexpr (arch_tag != gpu_arch::XeHpc) {
if constexpr (!arch_has_2d_load_store<arch_tag>) {
// offset for curr work item
const uint32_t O_offset = offset_height * args.uH + h;
const auto ld_c = args.uN * args.uH;
@@ -798,30 +798,30 @@
matOi_store_t matOi_store(mem_desc_Oi);
subgroup::tile_store<cache_hint::write_back, cache_hint::write_back>(
matOi, matOi_store);
return;
}

xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
transpose_tdecs.xetla_format<uint32_t>(),
args.O_ptr,
args.uH,
height,
args.uH,
h,
offset_height);

for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
// load data from matAccOi
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);

xetla_tstore_global<
scalar_t,
kSgHm,
cache_hint::write_back,
cache_hint::write_back>(transpose_tdecs, v_out);
xetla_update_tdesc_offsety(
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
} else {
xetla_fill_tdesc<scalar_t, kSgHm, 1, 1>(
transpose_tdecs.xetla_format<uint32_t>(),
args.O_ptr,
args.uH,
height,
args.uH,
h,
offset_height);

for (uint32_t i = 0; i < kSgBr && (f + i < args.uF); ++i) {
// load data from matAccOi
auto v_acc = matAccOi.reg.xetla_select<kSgHm, 1>(i * kSgHm);
v_out = xetla_cvt<scalar_t, accum_t, kSgHm>(v_acc);

xetla_tstore_global<
scalar_t,
kSgHm,
cache_hint::write_back,
cache_hint::write_back,
arch_tag>(transpose_tdecs, v_out);
xetla_update_tdesc_offsety(
transpose_tdecs.xetla_format<uint32_t>(), args.uN);
}
}
}
// ====================== // preload_Qi // ====================== //
@@ -888,16 +888,9 @@
/// @return The size of local memory required.
inline static constexpr uint32_t get_slm_size() {
constexpr uint32_t size = slm_size_Qi + slm_size_Pij + slm_size_softmax;
if constexpr (arch_tag == gpu_arch::XeHpc) {
static_assert(
size <= (128 * 1024),
"The local memory size should be less than 128KB!");

} else {
static_assert(
size <= (64 * 1024),
"The local memory size should be less than 64KB!");
}
static_assert(
size <= (arch_attr_t<arch_tag>::local_mem_size),
"The local memory size should be less than arch total local memory size");
return size;
};

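The running theme in this file is replacing hard-coded `gpu_arch::XeHpc` comparisons with trait queries: `arch_has_xmx` picks the XMX vs FPU compute policy, `arch_has_2d_load_store` gates the 2D store path, and `arch_attr_t<arch_tag>::local_mem_size` replaces the hand-written 128KB/64KB SLM limits. A compact sketch of the same pattern; `checked_slm` is a hypothetical helper, while the trait names appear in the diff itself:

// Fails at compile time if the SLM budget exceeds what the target arch has,
// without enumerating architectures at the call site.
template <gpu_arch arch_tag, uint32_t slm_bytes>
constexpr uint32_t checked_slm() {
  static_assert(
      slm_bytes <= arch_attr_t<arch_tag>::local_mem_size,
      "SLM budget must fit the target architecture's local memory");
  return slm_bytes;
}

New architectures then get correct behavior by defining their traits once, instead of by auditing every ordered enum comparison like `arch_tag >= gpu_arch::XeHpg`.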
23 changes: 13 additions & 10 deletions tests/integration/fmha/fmha_utils.h
@@ -134,7 +134,7 @@ template <
typename mat_t,
uint32_t kNumSg,
reduce_op reduce_kind,
gpu_arch arch_tag = gpu_arch::XeHpc>
gpu_arch arch_tag>
struct group_row_reduce_t {
using T = typename mat_t::dtype;
static constexpr uint32_t kNum = mat_t::tile_desc::tile_size_y;
@@ -215,7 +215,7 @@ enum class add_type : uint8_t {
/// @tparam arch_tag Is the hardware architecture tag.
template <
typename dtype_bias_,
gpu_arch arch_tag = gpu_arch::XeHpc,
gpu_arch arch_tag,
add_type add_tag = add_type::single_line>
struct bias_add_op_t {};

@@ -324,8 +324,8 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
using base_t = typename mem_desc_bias_t::base_t;

struct arguments_t {
shape_t shape;
base_t base;
shape_t shape;
inline arguments_t() = default;
inline arguments_t(base_t base_, shape_t shape_)
: base(base_), shape(shape_) {}
@@ -351,11 +351,10 @@ struct bias_add_op_t<dtype_bias_, arch_tag, add_type::single_element> {
uint32_t offset = (pos_y + pos_x * args.shape.stride) * sizeof(dtype_bias);
auto bias_data_vector = xetla_load_global<
dtype_bias,
16,
1,
data_size::default_size,
cache_hint::cached,
cache_hint::cached,
16>(ptr, offset);
cache_hint::cached>(ptr, offset);
dtype_acc bias_data =
xetla_cvt<dtype_acc, dtype_bias, 16>(bias_data_vector)[0];

@@ -418,15 +417,19 @@ template <
typename mem_desc_c_t_>
class epilogue_transp_t {};

template <typename tile_op_t_, typename tile_shape_, typename mem_desc_c_t_>
template <
typename tile_op_t_,
typename tile_shape_,
typename mem_desc_c_t_,
gpu_arch arch_tag_>
class epilogue_transp_t<
epilogue_policy_tile_op<tile_op_t_, gpu_arch::XeHpc>,
epilogue_policy_tile_op<tile_op_t_, arch_tag_>,
tile_shape_,
mem_desc_c_t_> {
public:
using tile_shape = tile_shape_;
using mem_desc_c_t = mem_desc_c_t_;
static constexpr gpu_arch arch_tag = gpu_arch::XeHpc;
static constexpr gpu_arch arch_tag = arch_tag_;
static constexpr uint32_t barrier_count = 0;
static constexpr uint32_t slm_size = 0;

@@ -505,7 +508,7 @@ class epilogue_write_back_t<
epilogue_policy_default<arch_tag_>,
tile_shape_,
mem_desc_c_t_,
std::enable_if_t<((arch_tag_ <= gpu_arch::XeHpc))>> {
std::enable_if_t<valid_xe_arch_tag<arch_tag_>>> {
public:
using epilogue_policy = epilogue_policy_default<arch_tag_>;
using tile_shape = tile_shape_;
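With the defaulted `gpu_arch::XeHpc` template arguments removed from `group_row_reduce_t` and `bias_add_op_t`, and `epilogue_transp_t` now parameterized on `arch_tag_`, every instantiation has to name its target architecture explicitly. Hypothetical instantiations under that assumption; `mat_acc_t` and the enumerator choices stand in for real caller types and are not from this PR:

// Callers now spell out the architecture instead of silently getting XeHpc.
using row_reduce_t =
    group_row_reduce_t<mat_acc_t, 8 /*kNumSg*/, reduce_op::sum, gpu_arch::XeHpg>;
using bias_op_t = bias_add_op_t<bf16, gpu_arch::XeHpg, add_type::single_line>;

This turns a silent wrong-architecture default into a compile error at each use site, which is the safer failure mode when new targets are brought up.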