-
Notifications
You must be signed in to change notification settings - Fork 960
/
jit_uni_pool_kernel.hpp
265 lines (215 loc) · 8.5 KB
/
jit_uni_pool_kernel.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#ifndef CPU_X64_JIT_UNI_POOL_KERNEL_HPP
#define CPU_X64_JIT_UNI_POOL_KERNEL_HPP
#include <cfloat>
#include <functional>
#include <memory>
#include "common/memory_tracking.hpp"
#include "cpu/x64/injectors/jit_uni_postops_injector.hpp"
#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"
namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {
// Forward declaration: software bf16 <-> f32 conversion helper used when the
// CPU lacks native bf16 instructions (owned via bf16_emu_ below).
struct bf16_emulation_t;

/* JIT code generator for the "uni" pooling kernel, templated on the target
 * ISA (sse41 / avx / avx2 / avx512_*).  A single instance emits one kernel
 * specialized for the configuration in `jpp`: forward or backward, max or
 * average pooling, with optional post-ops and bf16 support.  This header
 * holds only declarations and the fixed register map; code emission lives
 * in the matching .cpp. */
template <cpu_isa_t isa>
struct jit_uni_pool_kernel : public jit_generator {

    // Copies the pre-computed pooling configuration; dst_md is the
    // destination memory descriptor (presumably consumed by the post-ops
    // injector setup -- confirm in the .cpp).
    jit_uni_pool_kernel(
            const jit_pool_conf_t &ajpp, const memory_desc_t *dst_md);

    jit_pool_conf_t jpp; // kernel configuration this generator was built with

    ~jit_uni_pool_kernel();

    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_uni_pool_kernel)

    // Fills `jbp` from the pooling primitive descriptor and registers any
    // scratchpad the kernel needs; returns a failure status when the
    // problem cannot be handled by this implementation.
    static status_t init_conf(jit_pool_conf_t &jbp,
            memory_tracking::registrar_t &scratchpad, const pooling_pd_t *ppd);

private:
    using Xmm = Xbyak::Xmm;
    using Ymm = Xbyak::Ymm;
    using Zmm = Xbyak::Zmm;
    using Opmask = Xbyak::Opmask;
    using Reg32 = Xbyak::Reg32;
    using Reg64 = Xbyak::Reg64;
    using Vmm = typename cpu_isa_traits<isa>::Vmm;

    // Highest usable vector register index: zmm0..31 on AVX-512,
    // (x|y)mm0..15 on the other ISAs.
    int vmm_idx_upper_bound() const noexcept {
        return utils::one_of(isa, avx512_common, avx512_core) ? 31 : 15;
    }

    // Accumulator registers are handed out from the top of the register
    // file downwards (idx 0 -> highest register), keeping the low-numbered
    // registers free for the fixed temporaries declared below.
    int reg_idx(int idx) const noexcept { return vmm_idx_upper_bound() - idx; }

    // Views of accumulator `idx` at different widths; all four alias the
    // same physical register.
    Xmm xreg(int idx) const noexcept { return Xmm(reg_idx(idx)); }
    Ymm yreg(int idx) const noexcept { return Ymm(reg_idx(idx)); }
    Zmm zreg(int idx) const noexcept { return Zmm(reg_idx(idx)); }
    Vmm vreg(int idx) const noexcept { return Vmm(reg_idx(idx)); }

    // Memory-operand address frame matching the ISA's native vector width.
    const Xbyak::AddressFrame &vmmword = (isa == sse41)
            ? xword
            : (isa == avx || isa == avx2) ? yword : zword;

    // NOTE(review): the named temporaries below deliberately alias a few
    // low registers (0..3).  The aliases are presumably never live at the
    // same time -- verify against the .cpp before re-assigning any of them.
    Xmm vmm_mask = Xmm(0);
    Xmm xmm_tmp_1 = Xmm(0);
    Ymm ymm_tmp_1 = Ymm(0);
    Vmm vmm_tmp_1 = Vmm(0);

    // Used only for avx and if c tail is present
    Vmm vmm_c_tail_mask = Vmm(2);

    Xmm xmm_ker_area_h = Xmm(2);
    Xmm xmm_one = Xmm(2);
    Xmm xmm_tmp = Xmm(3);

    Vmm vmm_ker_area_h = Vmm(2);
    Vmm vmm_one = Vmm(2);
    Vmm vmm_tmp = Vmm(3);
    Ymm ymm_tmp = Ymm(3);

    Vmm vmm_k_offset = Vmm(1);

    // Register holding the max-pooling index vector: forward inference
    // reuses Vmm(1) (aliasing vmm_k_offset); forward training and backward
    // use the dedicated Vmm(4).
    inline Vmm vmm_idx() {
        if (!jpp.is_backward) {
            return (jpp.is_training) ? Vmm(4) : Vmm(1);
        } else
            return Vmm(4);
    }

    // Used only for avx512 when bf16 is present (reserved for the
    // bf16_emulation_t helper).
    Zmm bf16_emu_reserv_1 = Zmm(5);
    Zmm bf16_emu_reserv_2 = Zmm(6);
    Zmm bf16_emu_reserv_3 = Zmm(7);
    Reg64 bf16_emu_reserv_4 = r11;
    Zmm bf16_emu_reserv_5 = Zmm(8);

    // AVX-512 opmask registers (fixed assignments; semantics per name --
    // see their uses in the .cpp).
    Opmask k_c_tail_mask = Opmask(4);
    Opmask k_mask_cvt = Opmask(5);
    Opmask k_store_mask = Opmask(6);

    // Here be some (tame) dragons. This kernel does not follow the regular
    // OS-agnostic ABI pattern because when isa is sse41 it uses maskmovdqu
    // instruction which has its destination hardcoded in rdi. Therefore:
    // - all registers are hardcoded
    // - on Windows rdi and rcx are swapped to mimic the Unix x86_64 ABI
    //
    // While this is only required by the backward pass, the quirk above
    // is applied to the forward pass as well to keep things simpler.
    using reg64_t = const Reg64;
    reg64_t reg_param = rdi; // Always mimic the Unix ABI
    reg64_t reg_input = r8;
    reg64_t aux_reg_input = r9;
    reg64_t reg_index = r10;
    reg64_t reg_output = r12;
    reg64_t reg_kd_pad_shift = r13;
    reg64_t dst_ptr = rdi; // Must be rdi due to maskmovdqu
    reg64_t kj = r14;
    reg64_t oi_iter = r15;
    reg64_t reg_kh = rax;
    reg64_t reg_k_shift = rbx;
    reg64_t tmp_gpr = rcx; // Must be rcx because rdi is used above
    reg64_t reg_ker_area_h = rdx;
    reg64_t reg_nbc = rsi;

    // NOTE(review): the zero_diff_src() registers below reuse r9/r13/r14/
    // r15/r12/r8 from the main-loop set above; presumably the zeroing pass
    // runs before the main loop so the live ranges never overlap --
    // confirm in the .cpp.
    reg64_t reg_zero_ptr = r9;
    reg64_t reg_zero_id = r13;
    reg64_t reg_zero_ih = r14;
    reg64_t aux_reg_zero_ih = r15;
    reg64_t ki = r12;
    reg64_t aux_reg_input_d = r8;

    Reg32 reg_shuf_mask = esi;

    // sse41 state: the kernel body is run twice, 4 floats apart, to cover a
    // full channel block (see step_high_half()); these track which half is
    // in flight and whether post-ops were already applied for it.
    bool sse_high_half = false;
    bool disable_postops_when_sse_high_half_processed_ = false;

    int prev_kw; // kernel width seen by the previous step -- see .cpp usage

    void prepare_tail_mask();
    void put_one_in_vmm();
    void uni_broadcast_reg_val(const int reg_idx, const int vmm_idx);
    void push_vmm_val(const int idx);
    void pop_vmm_val(const int idx);

    // Load/store one accumulator's worth of data at reg_ptr+offset, with
    // masked access when processing the channel tail.
    void load(const int idx, const reg64_t &reg_ptr, const int offset,
            const bool is_c_tail_proccessing);
    void store(const int idx, const reg64_t &reg_ptr, const int offset,
            const bool is_c_tail_proccessing);

    // Per-algorithm bodies for one unrolled step of `ur_w` output columns
    // over `ur_bc` channel blocks, with left/right padding pad_l/pad_r.
    void maybe_recalculate_divisor(int jj, int ur_w, int pad_l, int pad_r,
            bool with_c_tail_proccessing);
    void avg_step(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_proccessing);
    void max_step_fwd(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_proccessing);
    void max_step_bwd(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_proccessing);

    void zero_diff_src(int ur_bc, bool with_c_tail_proccessing);

    // Dispatches one unrolled step to the variant selected by jpp:
    // max fwd, max bwd, or average (forward and backward share avg_step).
    void step(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_proccessing) {
        if (jpp.alg == alg_kind::pooling_max) {
            if (jpp.is_backward)
                max_step_bwd(
                        ur_w, ur_bc, pad_l, pad_r, with_c_tail_proccessing);
            else
                max_step_fwd(
                        ur_w, ur_bc, pad_l, pad_r, with_c_tail_proccessing);
        } else
            avg_step(ur_w, ur_bc, pad_l, pad_r, with_c_tail_proccessing);
    }

    // sse41 only: advances input/output (and, for max pooling in training
    // or backward mode, the index buffer) by 4 elements, then runs step()
    // again to process the upper half of the channel block.
    void step_high_half(int ur_w, int ur_bc, int pad_l, int pad_r,
            bool with_c_tail_processing) {
        add(reg_input, sizeof(float) * 4);
        add(reg_output, sizeof(float) * 4);
        if (jpp.alg == alg_kind::pooling_max
                && (jpp.is_training || jpp.is_backward))
            add(reg_index, types::data_type_size(jpp.ind_dt) * 4);

        step(ur_w, ur_bc, pad_l, pad_r, with_c_tail_processing);
    }

    // Emits the complete kernel (jit_generator hook).
    void generate() override;

    // Adds the 4 dwords of x1 to *each* 128-bit lane of y0.  AVX1 has no
    // 256-bit integer add, so each lane is extracted, added with the SSE
    // paddd, and inserted back via xtmp.
    void avx_vpadd1(const Ymm &y0, const Xmm &x1, const Xmm &xtmp) {
        assert(y0.getIdx() != x1.getIdx());
        vextractf128(xtmp, y0, 0);
        vpaddd(xtmp, xtmp, x1);
        vinsertf128(y0, y0, xtmp, 0);

        vextractf128(xtmp, y0, 1);
        vpaddd(xtmp, xtmp, x1);
        vinsertf128(y0, y0, xtmp, 1);
    }

    // Xmm overload exists only so the template instantiates for sse41; it
    // must never actually be called (see the assert).
    void avx_vpadd1(const Xmm &x0, const Xmm &x1, const Xmm &) {
        assert(false /*function should not be used*/);
        paddd(x0, x1);
    }

    // Zero-extends 8 bytes of x1 into the 8 dwords of y0, one 128-bit lane
    // at a time: pshufd moves bytes 4..7 of x1 into the low dword of
    // xmm_tmp before the second pmovzxbd.
    void avx_pmovzxbd(const Ymm &y0, const Xmm &x1, const Xmm &xtmp) {
        Xmm x0(y0.getIdx());
        pshufd(xmm_tmp, x1, 1);
        pmovzxbd(x0, x1);
        pmovzxbd(xmm_tmp, xmm_tmp);
        vinsertf128(y0, y0, xmm_tmp, 1);
    }

    // Xmm overload exists only so the template instantiates for sse41; it
    // must never actually be called (see the assert).
    void avx_pmovzxbd(const Xmm &x0, const Xmm &x1, const Xmm &) {
        assert(false /*function should not be used*/);
        pmovzxbd(x0, x1);
    }

    // Lane-wise dword compare-equal of y1 against y2 with the result in y0
    // (AVX1 has no 256-bit integer compare).  The low lane is computed
    // directly in y0's low half; the high lane goes through xtmp.  y0 must
    // not alias y1 or y2.
    void avx_pcmpeqd(
            const Ymm &y0, const Ymm &y1, const Ymm &y2, const Xmm &xtmp) {
        assert(y0.getIdx() != y1.getIdx());
        assert(y0.getIdx() != y2.getIdx());
        Xmm x0(y0.getIdx());
        Xmm x2(y2.getIdx());
        vextractf128(x0, y1, 1);
        vextractf128(xtmp, y2, 1);
        pcmpeqd(xtmp, x0);
        vextractf128(x0, y1, 0);
        pcmpeqd(x0, x2);
        vinsertf128(y0, y0, xtmp, 1);
    }

    // Xmm overload exists only so the template instantiates for sse41; it
    // must never actually be called (see the assert).
    void avx_pcmpeqd(const Xmm &x0, const Xmm &x1, const Xmm &, const Xmm &) {
        assert(false /*function should not be used*/);
        pcmpeqd(x0, x1);
    }

    // Applies the attr post-ops to the current accumulators;
    // is_tail_predicate tells the injector which (block, high-half) pairs
    // are in the channel tail.
    void apply_postops(int ur_bc, int ur_w, int c_block,
            const std::function<bool(int, bool)> &is_tail_predicate);

    // Checks the primitive attributes' post-ops against what this kernel
    // supports (may also update jpp accordingly -- see the .cpp).
    static bool post_ops_ok(jit_pool_conf_t &jpp, const primitive_attr_t &attr,
            const memory_desc_wrapper &dst_d);

    std::unique_ptr<bf16_emulation_t> bf16_emu_; // set only when emulating bf16
    std::unique_ptr<injector::jit_uni_postops_injector_t<isa>>
            postops_injector_;
};
} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl
#endif
// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s