Skip to content

Commit

Permalink
gpu: jit: reorder: enable any float to hf8, fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
kealan-barbieri authored and karturov committed Apr 10, 2024
1 parent c3972ef commit 668abae
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 26 deletions.
63 changes: 48 additions & 15 deletions src/gpu/jit/codegen/reorder.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
int dst_type_size = ngen::getBytes(dst_type);
int src_stride_bytes = src_stride * src_type_size;
int dst_stride_bytes = dst_stride * dst_type_size;
int max_type_size = std::max(src_type_size, dst_type_size);
bool dst_b = ngen_is_b(dst_type);
bool dst_d = ngen_is_dw(dst_type);
bool dst_q = ngen_is_qw(dst_type);
Expand Down Expand Up @@ -408,12 +409,22 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
return;
}
// hf8 -> f16
if (src_hf8 && dst_hf) {
if (src_hf8) {
int step = get_step();
const int src_stride_bytes = src_stride;
const int dst_stride_bytes = 2 * dst_stride;
const int step_nregs
= utils::div_up(step * ((int)sizeof(ngen::half)), grf_size);
const bool do_post_reorder = !dst_hf;
const int nregs = utils::div_up(width
* std::max((int)sizeof(ngen::half), max_type_size)
* std::max(src_stride, dst_stride),
grf_size);
if (do_post_reorder) {
auto tmp_dst = lex_scope.alloc_reg_buf_data(nregs).format(
0, ngen::DataType::hf);
dst = std::move(tmp_dst);
}
auto tmp1 = lex_scope.alloc_reg_buf_data(step_nregs);
auto tmp2 = lex_scope.alloc_reg_buf_data(step_nregs);
for (int i = 0; i < width; i += step) {
Expand Down Expand Up @@ -451,30 +462,54 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
host->mov(esize, d.reinterpret(0, ngen::DataType::uw)(dst_stride),
tmp2.subregister(0, ngen::DataType::uw)(dst_stride));
}
if (do_post_reorder) {
emit_reorder_1d_tile(
hw, host, scope, width, dst, dst_stride, _dst, dst_stride);
}
return;
}

if (src_hf && dst_hf8) {
if (dst_hf8) {
int step = get_step();
const int src_stride_bytes = 2 * src_stride;
const int dst_stride_bytes = dst_stride;
const int step_nregs
= utils::div_up(step * ((int)sizeof(ngen::half)), grf_size);
auto tmp1 = lex_scope.alloc_reg_buf_data(step_nregs);
auto tmp2 = lex_scope.alloc_reg_buf_data(step_nregs);
const bool do_pre_reorder = !src_hf;
const int nregs = utils::div_up(width

* std::max((int)sizeof(ngen::half), max_type_size)
* std::max(src_stride, dst_stride),
grf_size);
if (do_pre_reorder) {
auto tmp_src = lex_scope.alloc_reg_buf_data(nregs).format(
0, ngen::DataType::hf);
emit_reorder_1d_tile(hw, host, scope, width, src, src_stride,
tmp_src, src_stride);
src = std::move(tmp_src);
}
for (int i = 0; i < width; i += step) {
step = std::min(step, width - i);
step = utils::rnd_down_pow2(step);
int esize = step;

auto s = src.subregister(i, esize, src_stride_bytes);
auto d = dst.subregister(i, esize, dst_stride_bytes);

host->mov(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
s.reinterpret(0, ngen::DataType::uw)(src_stride));
if (src_stride > 1 && s.getByteOffset() > 1) {
host->mov(esize,
tmp1.subregister(0, ngen::DataType::uw)(src_stride),
s.reinterpret(0, ngen::DataType::uw)(src_stride));
host->mov(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
tmp1.subregister(0, ngen::DataType::uw)(src_stride));
} else {
host->mov(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
s.reinterpret(0, ngen::DataType::uw)(src_stride));
}
// get sign bits
host->and_(esize | host->nz | host->f1[1], host->null.uw(),
s.reinterpret(0, ngen::DataType::uw)(1), 0x8000);
host->and_(esize | host->nz | host->f2[0], host->null.uw(),
tmp1.subregister(0, ngen::DataType::uw)(1), 0x8000);
// multiply by hf 128 to force overflow of exponent
host->mul(esize, tmp1.subregister(0, ngen::DataType::hf)(1),
tmp1.subregister(0, ngen::DataType::hf)(1),
Expand All @@ -487,22 +522,21 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
// check for NaN, inf.
host->and_(esize | host->ze | host->f0[0], host->null.uw(),
~tmp1.subregister(0, ngen::DataType::uw)(1), 0x7C00);
// check for zero mantissa.
host->and_(esize | host->ze | host->f1[0], host->null.uw(),
tmp1.subregister(0, ngen::DataType::uw)(1), 0x7F);
// round.
host->add(esize | host->f1[0],
tmp1.subregister(0, ngen::DataType::uw)(1),
host->add(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
tmp1.subregister(0, ngen::DataType::uw)(1), -0x40);
// check for zero mantissa.
host->and_(esize | host->nz | host->f1[0], host->null.uw(),
tmp1.subregister(0, ngen::DataType::uw)(1), 0x3FF);
host->eshr(esize, tmp1.subregister(0, ngen::DataType::uw)(1),
tmp1.subregister(0, ngen::DataType::uw)(src_stride), 7);
tmp1.subregister(0, ngen::DataType::uw)(1), 7);
host->add(esize | host->f1[0],
tmp1.subregister(0, ngen::DataType::uw)(1),
tmp1.subregister(0, ngen::DataType::uw)(1), 1);
host->mov(esize | host->f0[0],
tmp1.subregister(0, ngen::DataType::uw)(1), 0x7F);
// handle sign.
host->or_(esize | host->f1[1],
host->or_(esize | host->f2[0],
tmp1.subregister(0, ngen::DataType::uw)(1),
tmp1.subregister(0, ngen::DataType::uw)(1), 0x80);

Expand All @@ -519,7 +553,6 @@ void emit_reorder_1d_tile(ngen::HW hw, GeneratorT *host,
// x <-> bf8
if (src_bf8 || dst_bf8) {
int step = get_step();
int max_type_size = std::max(src_type_size, dst_type_size);
ngen::DataType src_raw
= src_bf8 ? ngen::DataType::ub : ngen::DataType::w;
ngen::DataType dst_raw
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/jit/gemm/gen_gemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ struct gen_gemm_t : public gpu_gemm_t {
ok = ok && d->b_type() == bf16
&& utils::one_of(d->c_type(), bf16, f32)
&& utils::one_of(d->acc_type, bf16, f32);
} else if (!wei_decomp_) {
} else if (!wei_decomp) {
ok = ok
&& utils::one_of(
d->a_type(), f32, f16, f8_e5m2, f8_e4m3)
Expand Down
2 changes: 1 addition & 1 deletion src/gpu/jit/gemm/gen_gemm_kernel_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25470,7 +25470,7 @@ bool gemm_kernel_generator_t<hw>::copyRegisters(Type Ts, Type Td,

const int nphases = 2, qCXMin = -1, qCXMax = -1;

Subregister saveF0;
Subregister saveF0, saveF1, saveF2;
bool releaseEmuFlag = false;
bool preswizzle = (hw >= HW::XeHP);
GRFRange copyTemp;
Expand Down
26 changes: 17 additions & 9 deletions src/gpu/jit/reorder/gen_reorder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ status_t gen_reorder_t::pd_t::init(
auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
auto *device_info = compute_engine->device_info();
zero_points_config_t zp_cfg(this);
using namespace data_type;

auto post_ops_ok = [&]() {
const auto &po = attr()->post_ops_;
Expand All @@ -65,30 +66,37 @@ status_t gen_reorder_t::pd_t::init(
&& (!zp_cfg.do_dst_compensation
|| zp_cfg.is_common_dst_zero_point);
};
auto is_bf16_or_f32_or_bf8 = [](data_type_t dt) {
return utils::one_of(dt, data_type::bf16, data_type::f32,
data_type::f8_e5m2, data_type::f8_e4m3);
auto is_bf16_or_f32_or_f8 = [](data_type_t dt) {
return utils::one_of(dt, bf16, f32, f8_e5m2, f8_e4m3);
};
auto hf8_ok = [&]() {
bool any_hf8 = utils::one_of(f8_e4m3, dst_dt, src_dt);
return IMPLICATION(any_hf8,
utils::everyone_is(f8_e4m3, dst_dt, src_dt)
|| utils::one_of(src_dt, bf16, f16, f32)
|| utils::one_of(dst_dt, bf16, f16, f32));
};
bool any_hf8 = utils::one_of(data_type::f8_e4m3, dst_dt, src_dt);
auto skip_mask = dnnl_primitive_attr::skip_mask_t::post_ops
| dnnl_primitive_attr::skip_mask_t::zero_points_runtime
| dnnl_primitive_attr::skip_mask_t::scales_runtime;
using namespace data_type;
bool ok = src_engine == dst_engine && src_engine->kind() == engine_kind::gpu
&& utils::one_of(src_dt, f32, f16, bf16, f8_e5m2, s32, s8, u8, f64)
&& utils::one_of(dst_dt, f32, f16, bf16, f8_e5m2, s32, s8, u8, f64)
&& utils::one_of(
src_dt, f32, f16, bf16, f8_e5m2, f8_e4m3, s32, s8, u8, f64)
&& utils::one_of(
dst_dt, f32, f16, bf16, f8_e5m2, f8_e4m3, s32, s8, u8, f64)
&& IMPLICATION(src_dt == data_type::f16 || dst_dt == data_type::f16,
device_info->has_native(data_type::f16))
&& IMPLICATION(
src_dt == data_type::bf16, is_bf16_or_f32_or_bf8(dst_dt))
src_dt == data_type::bf16, is_bf16_or_f32_or_f8(dst_dt))
&& IMPLICATION(
dst_dt == data_type::bf16, is_bf16_or_f32_or_bf8(src_dt))
dst_dt == data_type::bf16, is_bf16_or_f32_or_f8(src_dt))
&& IMPLICATION(utils::one_of(data_type::f8_e5m2, src_dt, dst_dt),
device_info->has_native(data_type::f8_e5m2))
&& IMPLICATION(src_dt == data_type::f64 || dst_dt == data_type::f64,
device_info->has_native(data_type::f64))
&& attr()->has_default_values(skip_mask) && extra_ok()
&& post_ops_ok() && scales_ok() && zps_ok();
&& post_ops_ok() && scales_ok() && zps_ok() && hf8_ok();
if (!ok) return status::unimplemented;

memory_desc_wrapper src_mdw {src_md()};
Expand Down

0 comments on commit 668abae

Please sign in to comment.