Skip to content

Commit

Permalink
all: weights layouts with oc blocks 32 and 64 for 0D spatial case
Browse files Browse the repository at this point in the history
  • Loading branch information
akharito committed Oct 10, 2020
1 parent ac0ae4b commit dcb5c69
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 7 deletions.
18 changes: 18 additions & 0 deletions include/dnnl.hpp
Expand Up @@ -1527,6 +1527,15 @@ struct memory : public handle<dnnl_memory_t> {

// Opaque blocked formats

AB16b16a = dnnl_AB16b16a,
AB16b32a = dnnl_AB16b32a,
AB16b64a = dnnl_AB16b64a,
AB8b16a2b = dnnl_AB8b16a2b,
AB8b32a2b = dnnl_AB8b32a2b,
AB8b64a2b = dnnl_AB8b64a2b,
AB4b16a4b = dnnl_AB4b16a4b,
AB4b32a4b = dnnl_AB4b32a4b,
AB4b64a4b = dnnl_AB4b64a4b,
Abc16a = dnnl_Abc16a,
ABc16a16b = dnnl_ABc16a16b,
ABc4a4b = dnnl_ABc4a4b,
Expand Down Expand Up @@ -1698,6 +1707,15 @@ struct memory : public handle<dnnl_memory_t> {
NCdhw32n32c = dnnl_NCdhw32n32c,
NChw32n32c = dnnl_NChw32n32c,
IOhw16i16o = dnnl_IOhw16i16o,
OI16i16o = dnnl_OI16i16o,
OI16i32o = dnnl_OI16i32o,
OI16i64o = dnnl_OI16i64o,
OI8i16o2i = dnnl_OI8i16o2i,
OI8i32o2i = dnnl_OI8i32o2i,
OI8i64o2i = dnnl_OI8i64o2i,
OI4i16o4i = dnnl_OI4i16o4i,
OI4i32o4i = dnnl_OI4i32o4i,
OI4i64o4i = dnnl_OI4i64o4i,
Ohwi32o = dnnl_Ohwi32o,
IOdhw16i16o = dnnl_IOdhw16i16o,
gIOhw16i16o = dnnl_gIOhw16i16o,
Expand Down
19 changes: 19 additions & 0 deletions include/dnnl_types.h
Expand Up @@ -242,6 +242,15 @@ typedef enum {
dnnl_ABc8a16b2a,
dnnl_ABc8a8b,
dnnl_ABc8a4b,
dnnl_AB16b16a,
dnnl_AB16b32a,
dnnl_AB16b64a,
dnnl_AB8b16a2b,
dnnl_AB8b32a2b,
dnnl_AB8b64a2b,
dnnl_AB4b16a4b,
dnnl_AB4b32a4b,
dnnl_AB4b64a4b,
/// 3D tensor blocked by 2nd dimension with block size 8
dnnl_aBc8b,
dnnl_ABc8b16a2b,
Expand Down Expand Up @@ -589,6 +598,16 @@ typedef enum {
dnnl_NChw32n32c = dnnl_ABcd32a32b,
dnnl_NCdhw32n32c = dnnl_ABcde32a32b,

// weights, 2D
dnnl_OI16i16o = dnnl_AB16b16a,
dnnl_OI16i32o = dnnl_AB16b32a,
dnnl_OI16i64o = dnnl_AB16b64a,
dnnl_OI8i16o2i = dnnl_AB8b16a2b,
dnnl_OI8i32o2i = dnnl_AB8b32a2b,
dnnl_OI8i64o2i = dnnl_AB8b64a2b,
dnnl_OI4i16o4i = dnnl_AB4b16a4b,
dnnl_OI4i32o4i = dnnl_AB4b32a4b,
dnnl_OI4i64o4i = dnnl_AB4b64a4b,
// weights, 3D
dnnl_IOw16o16i = dnnl_BAc16a16b,
dnnl_IOw16i16o = dnnl_BAc16b16a,
Expand Down
18 changes: 18 additions & 0 deletions src/common/c_types_map.hpp
Expand Up @@ -208,6 +208,15 @@ const format_tag_t dcab = dnnl_dcab;
const format_tag_t cdeba = dnnl_cdeba;
const format_tag_t decab = dnnl_decab;
const format_tag_t defcab = dnnl_defcab;
const format_tag_t AB16b16a = dnnl_AB16b16a;
const format_tag_t AB16b32a = dnnl_AB16b32a;
const format_tag_t AB16b64a = dnnl_AB16b64a;
const format_tag_t AB8b16a2b = dnnl_AB8b16a2b;
const format_tag_t AB8b32a2b = dnnl_AB8b32a2b;
const format_tag_t AB8b64a2b = dnnl_AB8b64a2b;
const format_tag_t AB4b16a4b = dnnl_AB4b16a4b;
const format_tag_t AB4b32a4b = dnnl_AB4b32a4b;
const format_tag_t AB4b64a4b = dnnl_AB4b64a4b;
const format_tag_t Abc16a = dnnl_Abc16a;
const format_tag_t ABc16a16b = dnnl_ABc16a16b;
const format_tag_t ABc4a4b = dnnl_ABc4a4b;
Expand Down Expand Up @@ -454,6 +463,15 @@ const format_tag_t NCdhw16n16c = dnnl_NCdhw16n16c;
const format_tag_t NCw32n32c = dnnl_NCw32n32c;
const format_tag_t NChw32n32c = dnnl_NChw32n32c;
const format_tag_t NCdhw32n32c = dnnl_NCdhw32n32c;
const format_tag_t OI16i16o = dnnl_OI16i16o;
const format_tag_t OI16i32o = dnnl_OI16i32o;
const format_tag_t OI16i64o = dnnl_OI16i64o;
const format_tag_t OI8i16o2i = dnnl_OI8i16o2i;
const format_tag_t OI8i32o2i = dnnl_OI8i32o2i;
const format_tag_t OI8i64o2i = dnnl_OI8i64o2i;
const format_tag_t OI4i16o4i = dnnl_OI4i16o4i;
const format_tag_t OI4i32o4i = dnnl_OI4i32o4i;
const format_tag_t OI4i64o4i = dnnl_OI4i64o4i;
const format_tag_t IOdhw16i16o = dnnl_IOdhw16i16o;
const format_tag_t IOhw16i16o = dnnl_IOhw16i16o;
const format_tag_t Ohwi32o = dnnl_Ohwi32o;
Expand Down
18 changes: 18 additions & 0 deletions src/common/dnnl_debug_autogenerated.cpp
Expand Up @@ -119,6 +119,15 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_ABc8a16b2a) return "ABc8a16b2a";
if (v == dnnl_ABc8a8b) return "ABc8a8b";
if (v == dnnl_ABc8a4b) return "ABc8a4b";
if (v == dnnl_AB16b16a) return "AB16b16a";
if (v == dnnl_AB16b32a) return "AB16b32a";
if (v == dnnl_AB16b64a) return "AB16b64a";
if (v == dnnl_AB8b16a2b) return "AB8b16a2b";
if (v == dnnl_AB8b32a2b) return "AB8b32a2b";
if (v == dnnl_AB8b64a2b) return "AB8b64a2b";
if (v == dnnl_AB4b16a4b) return "AB4b16a4b";
if (v == dnnl_AB4b32a4b) return "AB4b32a4b";
if (v == dnnl_AB4b64a4b) return "AB4b64a4b";
if (v == dnnl_aBc8b) return "aBc8b";
if (v == dnnl_ABc8b16a2b) return "ABc8b16a2b";
if (v == dnnl_BAc8a16b2a) return "BAc8a16b2a";
Expand Down Expand Up @@ -348,6 +357,15 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
if (v == dnnl_NCw32n32c) return "NCw32n32c";
if (v == dnnl_NChw32n32c) return "NChw32n32c";
if (v == dnnl_NCdhw32n32c) return "NCdhw32n32c";
if (v == dnnl_OI16i16o) return "OI16i16o";
if (v == dnnl_OI16i32o) return "OI16i32o";
if (v == dnnl_OI16i64o) return "OI16i64o";
if (v == dnnl_OI8i16o2i) return "OI8i16o2i";
if (v == dnnl_OI8i32o2i) return "OI8i32o2i";
if (v == dnnl_OI8i64o2i) return "OI8i64o2i";
if (v == dnnl_OI4i16o4i) return "OI4i16o4i";
if (v == dnnl_OI4i32o4i) return "OI4i32o4i";
if (v == dnnl_OI4i64o4i) return "OI4i64o4i";
if (v == dnnl_IOw16o16i) return "IOw16o16i";
if (v == dnnl_IOw16i16o) return "IOw16i16o";
if (v == dnnl_OIw16i16o) return "OIw16i16o";
Expand Down
9 changes: 9 additions & 0 deletions src/common/memory_desc_wrapper.cpp
Expand Up @@ -208,6 +208,15 @@ status_t memory_desc_wrapper::compute_blocking(
C(Acdb4a, {0, 2, 3, 1}, {4}, {0});
C(Acdeb4a, {0, 2, 3, 4, 1}, {4}, {0});

C(AB16b16a, {0, 1}, {16, 16}, {1, 0});
C(AB16b32a, {0, 1}, {16, 32}, {1, 0});
C(AB16b64a, {0, 1}, {16, 64}, {1, 0});
C(AB8b16a2b, {0, 1}, {8, 16, 2}, {1, 0, 1});
C(AB8b32a2b, {0, 1}, {8, 32, 2}, {1, 0, 1});
C(AB8b64a2b, {0, 1}, {8, 64, 2}, {1, 0, 1});
C(AB4b16a4b, {0, 1}, {4, 16, 4}, {1, 0, 1});
C(AB4b32a4b, {0, 1}, {4, 32, 4}, {1, 0, 1});
C(AB4b64a4b, {0, 1}, {4, 64, 4}, {1, 0, 1});
C(Abc16a, {0, 1, 2}, {16}, {0});
C(ABc16a16b, {0, 1, 2}, {16, 16}, {0, 1});
C(ABc4a4b, {0, 1, 2}, {4, 4}, {0, 1});
Expand Down
37 changes: 30 additions & 7 deletions src/common/tag_traits.hpp
Expand Up @@ -53,6 +53,8 @@ enum class inner_blk_t {
_8b8c,
_8c8b,
_16a16b,
_16b64a,
_16b32a,
_16b16a,
_16b16c,
_16c16b,
Expand All @@ -65,8 +67,12 @@ enum class inner_blk_t {
_2c8b4c,
_8a16b2a,
_4b16a4b,
_4b32a4b,
_4b64a4b,
_2b8a4b,
_8b16a2b,
_8b32a2b,
_8b64a2b,
_8b16c2b,
_4c16b4c,
_8c16b2c,
Expand All @@ -92,13 +98,15 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
static_assert(
utils::one_of(f, ib::_4a4b, ib::_4b4a, ib::_4b4c, ib::_4c4b,
ib::_8a8b, ib::_8b8a, ib::_8b8c, ib::_8c8b, ib::_16a16b,
ib::_16b16a, ib::_16b16c, ib::_16c16b, ib::_32a32b,
ib::_16a2b, ib::_16a4b, ib::_16b2c, ib::_16b4c, ib::_2c8b4c,
ib::_8a16b2a, ib::_4b16a4b, ib::_2b8a4b, ib::_8b16a2b,
ib::_8b16c2b, ib::_4c16b4c, ib::_8c16b2c, ib::_2a8b8a2b,
ib::_2b8c8b2c, ib::_4a8b8a4b, ib::_4b8c8b4c, ib::_2b4c2b,
ib::_2c4b2c, ib::_4b8c2b, ib::_4c8b2c, ib::_16c16b4c,
ib::_16b16a4b, ib::_16c16b2c, ib::_16b16a2b),
ib::_16b64a, ib::_16b32a, ib::_16b16a, ib::_16b16c,
ib::_16c16b, ib::_32a32b, ib::_16a2b, ib::_16a4b,
ib::_16b2c, ib::_16b4c, ib::_2c8b4c, ib::_8a16b2a,
ib::_4b64a4b, ib::_4b32a4b, ib::_4b16a4b, ib::_2b8a4b,
ib::_8b64a2b, ib::_8b32a2b, ib::_8b16a2b, ib::_8b16c2b,
ib::_4c16b4c, ib::_8c16b2c, ib::_2a8b8a2b, ib::_2b8c8b2c,
ib::_4a8b8a4b, ib::_4b8c8b4c, ib::_2b4c2b, ib::_2c4b2c,
ib::_4b8c2b, ib::_4c8b2c, ib::_16c16b4c, ib::_16b16a4b,
ib::_16c16b2c, ib::_16b16a2b),
"unexpected inner_blk format");

// clang-format off
Expand All @@ -108,16 +116,22 @@ constexpr int AB_or_BC_blk_off(int x0, int x1) {
: (f == ib::_8a8b || f == ib::_8b8c) ? 8 * x0 + x1
: (f == ib::_8b8a || f == ib::_8c8b) ? 8 * x1 + x0
: (f == ib::_16a16b || f == ib::_16b16c) ? 16 * x0 + x1
: (f == ib::_16b64a) ? 64 * x1 + x0
: (f == ib::_16b32a) ? 32 * x1 + x0
: (f == ib::_16b16a || f == ib::_16c16b) ? 16 * x1 + x0
: (f == ib::_16a2b || f == ib::_16b2c) ? 2 * x0 + x1
: (f == ib::_16a4b || f == ib::_16b4c) ? 4 * x0 + x1
: (f == ib::_32a32b) ? 32 * x0 + x1
: (f == ib::_8a16b2a || f == ib::_8b16c2b) ? (x0 / 2) * 32 + x1 * 2 + x0 % 2
: (f == ib::_4b16a4b || f == ib::_4c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
: (f == ib::_4b32a4b) ? (x1 / 4) * 128 + x0 * 4 + x1 % 4
: (f == ib::_4b64a4b) ? (x1 / 4) * 256 + x0 * 4 + x1 % 4
: (f == ib::_2b8a4b || f == ib::_2c8b4c) ? (x1 / 4) * 32 + x0 * 4 + x1 % 4
: (f == ib::_16b16a4b || f == ib::_16c16b4c) ? (x1 / 4) * 64 + x0 * 4 + x1 % 4
: (f == ib::_16b16a2b || f == ib::_16c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
: (f == ib::_8b16a2b || f == ib::_8c16b2c) ? (x1 / 2) * 32 + x0 * 2 + x1 % 2
: (f == ib::_8b32a2b) ? (x1 / 2) * 64 + x0 * 2 + x1 % 2
: (f == ib::_8b64a2b) ? (x1 / 2) * 128 + x0 * 2 + x1 % 2
: (f == ib::_2b4c2b || f == ib::_2c4b2c) ? (x0 / 2) * 8 + x1 * 2 + x0 % 2
: (f == ib::_4b8c2b || f == ib::_4c8b2c) ? (x0 / 2) * 16 + x1 * 2 + x0 % 2
: (f == ib::_2a8b8a2b || f == ib::_2b8c8b2c) ? (x0 / 8) * 128 + (x1 / 2) * 16 + (x0 % 8) * 2 + x1 % 2
Expand Down Expand Up @@ -220,6 +234,15 @@ DECL_TRAITS(Acb4a, _A, _4a, 3);
DECL_TRAITS(Acdb4a, _A, _4a, 4);
DECL_TRAITS(Acdeb4a, _A, _4a, 5);

DECL_TRAITS(AB16b16a, _AB, _16b16a, 2);
DECL_TRAITS(AB16b32a, _AB, _16b32a, 2);
DECL_TRAITS(AB16b64a, _AB, _16b64a, 2);
DECL_TRAITS(AB8b16a2b, _AB, _8b16a2b, 2);
DECL_TRAITS(AB8b32a2b, _AB, _8b32a2b, 2);
DECL_TRAITS(AB8b64a2b, _AB, _8b64a2b, 2);
DECL_TRAITS(AB4b16a4b, _AB, _4b16a4b, 2);
DECL_TRAITS(AB4b32a4b, _AB, _4b32a4b, 2);
DECL_TRAITS(AB4b64a4b, _AB, _4b64a4b, 2);
DECL_TRAITS(Abc16a, _A, _16a, 3);
DECL_TRAITS(ABc16a16b, _AB, _16a16b, 3);
DECL_TRAITS(ABc4a4b, _AB, _4a4b, 3);
Expand Down
18 changes: 18 additions & 0 deletions tests/benchdnn/dnnl_debug_autogenerated.cpp
Expand Up @@ -110,6 +110,15 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(ABc8a16b2a);
CASE(ABc8a8b);
CASE(ABc8a4b);
CASE(AB16b16a);
CASE(AB16b32a);
CASE(AB16b64a);
CASE(AB8b16a2b);
CASE(AB8b32a2b);
CASE(AB8b64a2b);
CASE(AB4b16a4b);
CASE(AB4b32a4b);
CASE(AB4b64a4b);
CASE(aBc8b);
CASE(ABc8b16a2b);
CASE(BAc8a16b2a);
Expand Down Expand Up @@ -338,6 +347,15 @@ dnnl_format_tag_t str2fmt_tag(const char *str) {
CASE(NCw32n32c);
CASE(NChw32n32c);
CASE(NCdhw32n32c);
CASE(OI16i16o);
CASE(OI16i32o);
CASE(OI16i64o);
CASE(OI8i16o2i);
CASE(OI8i32o2i);
CASE(OI8i64o2i);
CASE(OI4i16o4i);
CASE(OI4i32o4i);
CASE(OI4i64o4i);
CASE(IOw16o16i);
CASE(IOw16i16o);
CASE(OIw16i16o);
Expand Down

0 comments on commit dcb5c69

Please sign in to comment.