Skip to content

Commit

Permalink
gpu: jit: conv: always upscale f16/f32 mad
Browse files: browse the repository at this point in the history
  • Loading branch information
kealan-barbieri authored and karturov committed Dec 4, 2023
1 parent c9c0b09 commit 79bc6cc
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions src/gpu/jit/conv/plan.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -1152,6 +1152,7 @@ struct fma_context_t {
fma = cfg.fma_kind();
a_type = type_t(cfg.prb().a_data_type);
b_type = type_t(cfg.prb().b_data_type);
c_type = type_t(cfg.prb().c_data_type);
is_src1_broadcast = !cfg.prb().is_dw;
ab_swap_transpose_ = cfg.prb().ab_swap_transpose;
}
Expand All @@ -1171,6 +1172,9 @@ struct fma_context_t {
if (layout.type().is_x8())
return layout.retype(type_t::s16()).make_strided(2);

if (a_type.is_f16() && b_type.is_f16() && c_type.is_f32())
return layout.retype(type_t::f32()).make_dense();

// mad with f16 requires aligned regioning for src1/src2.
if (a_type.is_f16()) return layout.make_dense();

Expand Down Expand Up @@ -1338,6 +1342,7 @@ struct fma_context_t {
fma_kind_t fma;
type_t a_type;
type_t b_type;
type_t c_type;
bool is_src1_broadcast;
bool ab_swap_transpose_;
fma_layout_hint_t a_layout_hint;
Expand Down

0 comments on commit 79bc6cc

Please sign in to comment.