Skip to content

Commit

Permalink
src: cpu: x64: fix offsets overflow in f32 AVX2 jit conv kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
akharito committed Jan 13, 2021
1 parent d8d6807 commit cb8ef4e
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/cpu/x64/jit_avx2_conv_kernel_f32.cpp
@@ -1,5 +1,5 @@
/*******************************************************************************
* Copyright 2016-2020 Intel Corporation
* Copyright 2016-2021 Intel Corporation
* Copyright 2018 YANDEX LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -94,8 +94,9 @@ void jit_avx2_conv_fwd_kernel_f32::oh_step_unroll_kw(

for (int ii = 0; ii < oc_blocks; ii++) {
vmovups(ymm15,
ptr[aux_reg_kernel
+ get_kernel_offset(ii, ki, ifm2)]);
make_safe_addr(aux_reg_kernel,
get_kernel_offset(ii, ki, ifm2),
reg_long_offt));
for (int jj = jj_start; jj < jj_end; jj++)
if (mayiuse(avx2))
vfmadd231ps(Ymm(ur_w * ii + jj),
Expand Down Expand Up @@ -152,7 +153,8 @@ void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(
}
for (int ii = 0; ii < oc_blocks; ii++) {
vmovups(ymm15,
ptr[aux_reg_kernel + get_kernel_offset(ii, 0, ifm2)]);
make_safe_addr(aux_reg_kernel,
get_kernel_offset(ii, 0, ifm2), reg_long_offt));
for (int jj = jj_start; jj < jj_end; jj++)
if (mayiuse(avx2))
vfmadd231ps(Ymm(ur_w * ii + jj),
Expand All @@ -163,8 +165,9 @@ void jit_avx2_conv_fwd_kernel_f32::oh_step_nopad(
}
}
}
add(aux_reg_kernel, get_kernel_offset(0, 1, 0));
add(aux_reg_input, get_input_offset(0, filter_w_to_input(1)));
safe_add(aux_reg_kernel, get_kernel_offset(0, 1, 0), reg_long_offt);
safe_add(aux_reg_input, get_input_offset(0, filter_w_to_input(1)),
reg_long_offt);

inc(ki_iter);
cmp(ki_iter, kw);
Expand Down Expand Up @@ -396,8 +399,10 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(
- get_input_offset(0, filter_w_to_input(kw)));
} else {
oh_step_unroll_kw(ur_w, pad_l, pad_r, oc_blocks);
add(aux_reg_kernel, get_kernel_offset(0, kw, 0));
add(aux_reg_input, get_input_offset(0, filter_h_to_input(1)));
safe_add(
aux_reg_kernel, get_kernel_offset(0, kw, 0), reg_long_offt);
safe_add(aux_reg_input, get_input_offset(0, filter_h_to_input(1)),
reg_long_offt);
}

dec(kj);
Expand All @@ -408,8 +413,10 @@ void jit_avx2_conv_fwd_kernel_f32::width_blk_step(
L(skip_kh_loop);

if (jcp.ndims == 5) {
add(aux_reg_inp_d, get_input_offset(0, filter_d_to_input(1)));
add(aux_reg_ker_d, get_kernel_offset(0, jcp.kw * jcp.kh, 0));
safe_add(aux_reg_inp_d, get_input_offset(0, filter_d_to_input(1)),
reg_long_offt);
safe_add(aux_reg_ker_d, get_kernel_offset(0, jcp.kw * jcp.kh, 0),
reg_long_offt);

dec(reg_ki);
cmp(reg_ki, 0);
Expand Down

0 comments on commit cb8ef4e

Please sign in to comment.