Direct Convolution Issue #742

Closed
Alok-Ranjan23 opened this issue Jun 4, 2020 · 5 comments

@Alok-Ranjan23

I am using your C++ API to perform a convolution with the following dimensions:
batch = 37, channel = 104, height = 222, width = 222, no_of_filter = 40, kernel = (9,9), pad_h = 4, pad_w = 4, stride = (1,1).
Both no_of_filter and channel are multiples of 8 in this case, so I think the C++ API should take the direct convolution path. But it is taking the reference implementation path for these dimensions.
DNNL implementation path in ref_convolution_fwd_t::execute_forward [src/cpu/ref_convolution.cpp]
dnnl_primitive_execute cpu,convolution,ref:any,forward_inference,src_f32::blocked:acdb:f0 wei_f32::blocked:abcd:f0 bia_f32::blocked:a:f0 dst_f32::blocked:acdb:f0,,alg:convolution_direct,mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4 54.354

Why do these parameters (which should fit the blocked-format AVX2 path) not take the direct convolution path, and instead take the reference convolution path?

@emfomenk

emfomenk commented Jun 4, 2020

Hi @Alok-Ranjan23,

How did you create the convolution? Did you use format_tag::any, which is the recommended way to create compute-intensive primitives to get the best performance?

oneDNN also has limited support for plain formats provided as-is by the user. A user can force the convolution to use the nhwc format for the source and destination tensors, but the weights should still be created with format tag any, so that the convolution chooses the appropriate format.
That is what you attempted to do, according to the verbose output you provided (though you forced the weights format, and that is why it dispatched to the reference implementation). Please note that this support for the nhwc memory format for src and dst is fairly recent, so please use v1.5+ if you want to rely on it.
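For illustration, here is a minimal standalone sketch of that pattern (not taken from your code or the library examples): src/dst forced to nhwc, weights and bias left as format_tag::any, and the user weights reordered into whatever format the convolution picked. It uses the mb2/ic8/oc8 shape from your verbose line and assumes a CPU engine; building with g++ -std=c++11 sketch.cpp -ldnnl and running with DNNL_VERBOSE=1 shows which implementation is dispatched.

// Minimal sketch: src/dst forced to nhwc, weights/bias created with
// format_tag::any so the convolution implementation picks their format.
#include <dnnl.hpp>

int main() {
    using namespace dnnl;
    using tag = memory::format_tag;
    using dt = memory::data_type;

    engine eng(engine::kind::cpu, 0);
    stream s(eng);

    // Shape from the verbose line: mb2, ic8 -> oc8, 222x222, 9x9 kernel, pad 4.
    memory::dims src_tz = {2, 8, 222, 222}, wei_tz = {8, 8, 9, 9};
    memory::dims bia_tz = {8}, dst_tz = {2, 8, 222, 222};

    auto src_md = memory::desc(src_tz, dt::f32, tag::nhwc); // forced by user
    auto dst_md = memory::desc(dst_tz, dt::f32, tag::nhwc); // forced by user
    auto wei_md = memory::desc(wei_tz, dt::f32, tag::any);  // let conv decide
    auto bia_md = memory::desc(bia_tz, dt::f32, tag::any);

    auto conv_pd = convolution_forward::primitive_desc(
            convolution_forward::desc(prop_kind::forward_inference,
                    algorithm::convolution_direct, src_md, wei_md, bia_md,
                    dst_md, {1, 1}, {4, 4}, {4, 4}),
            eng);

    // User weights come in a plain format; reorder them into the format the
    // implementation chose. This is the step that keeps dispatching optimal.
    auto user_wei_mem = memory({wei_tz, dt::f32, tag::hwcn}, eng);
    auto conv_wei_mem = memory(conv_pd.weights_desc(), eng);
    reorder(user_wei_mem, conv_wei_mem).execute(s, user_wei_mem, conv_wei_mem);

    auto src_mem = memory(conv_pd.src_desc(), eng);
    auto dst_mem = memory(conv_pd.dst_desc(), eng);
    auto bia_mem = memory(conv_pd.bias_desc(), eng);

    convolution_forward(conv_pd).execute(s,
            {{DNNL_ARG_SRC, src_mem}, {DNNL_ARG_WEIGHTS, conv_wei_mem},
                    {DNNL_ARG_BIAS, bia_mem}, {DNNL_ARG_DST, dst_mem}});
    s.wait();
    return 0;
}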

I ran your convolution using benchdnn, and it seems to dispatch fine:

$ export DNNL_VERBOSE=1
$ # allow convolution to define memory format for all tensors: src, weights, dst
$ ./tests/benchdnn/benchdnn --conv mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4 
...
dnnl_verbose,exec,cpu,convolution,gemm:jit,forward_training,src_f32::blocked:abcd:f0 wei_f32::blocked:abcd:f0 bia_f32::blocked:a:f0 dst_f32::blocked:abcd:f0,,alg:convolution_direct,mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4,155.547


$ # allow convolution to define memory format for weights tensor only.
$ # The src and dst tensors are forced to use NHWC memory format.
$ ./tests/benchdnn/benchdnn --conv --stag=nhwc --dtag=nhwc mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4
...
dnnl_verbose,exec,cpu,convolution,gemm:jit,forward_training,src_f32::blocked:acdb:f0 wei_f32::blocked:cdba:f0 bia_f32::blocked:a:f0 dst_f32::blocked:acdb:f0,,alg:convolution_direct,mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4,156.128
...


$ # force convolution to use fixed memory formats (according to the verbose log)
$ ./tests/benchdnn/benchdnn --conv --stag=nhwc --wtag=abcd --dtag=nhwc mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4
...
dnnl_verbose,exec,cpu,convolution,ref:any,forward_training,src_f32::blocked:acdb:f0 wei_f32::blocked:abcd:f0 bia_f32::blocked:a:f0 dst_f32::blocked:acdb:f0,,alg:convolution_direct,mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4,246.89
...

@emfomenk emfomenk self-assigned this Jun 4, 2020
@Alok-Ranjan23
Author

> A user can force the convolution to use the nhwc format for the source and destination tensors, but the weights should still be created with format tag any, so that the convolution chooses the appropriate format.
> That is what you attempted to do, according to the verbose output you provided (though you forced the weights format, and that is why it dispatched to the reference implementation).

No, I set format_tag::any for the weights memory descriptor; please check the attached code snippet.
But it still takes the reference convolution path instead of the direct convolution path.

// parse_engine_kind() comes from the oneDNN example utilities
// (example_utils.hpp); init_data() is my own helper and is not shown here.
#include <cassert>
#include <chrono>
#include <unordered_map>
#include <vector>
#include "dnnl.hpp"
#include "example_utils.hpp"
using namespace dnnl;
using namespace std;

void convolution_param(engine eng, dnnl::memory user_src_memory, int batch,
        int channel, int height, int width, dnnl::memory user_weights_memory,
        int no_of_filter, int kernel_h, int kernel_w, int pad_h, int pad_w,
        int stride_h, int stride_w, dnnl::memory conv1_user_bias_memory,
        dnnl::memory conv1_dst_memory, int out_height, int out_width) {

int times = 1;
using tag = memory::format_tag;
using dt = memory::data_type;

/// Initialize an engine and stream. The last parameter in the call represents
/// the index of the engine.
/// @snippet cnn_inference_f32.cpp Initialize engine and stream
//[Initialize engine and stream]
//engine eng(engine_kind, 0);
stream s(eng);
//[Initialize engine and stream]

/// Create a vector for the primitives and a vector to hold memory
/// that will be used as arguments.
/// @snippet cnn_inference_f32.cpp Create network
//[Create network]
std::vector<primitive> net;
std::vector<std::unordered_map<int, memory>> net_args;
//[Create network]

// AlexNet: conv1
// {batch, 3, 227, 227} (x) {96, 3, 11, 11} -> {batch, 96, 55, 55}
// strides: {4, 4}
memory::dims conv1_src_tz = {batch, channel, height, width};
memory::dims conv1_weights_tz = {no_of_filter, channel, kernel_h, kernel_w};
memory::dims conv1_bias_tz = {no_of_filter};
memory::dims conv1_dst_tz = {batch, no_of_filter, out_height, out_width};
memory::dims conv1_strides = {stride_h, stride_w};
memory::dims conv1_padding = {pad_h, pad_w};


//[Create user memory]
/// Create memory descriptors with layout tag::any. The `any` format enables
/// the convolution primitive to choose the data format that will result in
/// best performance based on its input parameters (convolution kernel
/// sizes, strides, padding, and so on). If the resulting format is different
/// from `nchw`, the user data must be transformed to the format required for
/// the convolution (as explained below).
/// @snippet cnn_inference_f32.cpp Create convolution memory descriptors
//[Create convolution memory descriptors]
auto conv1_src_md = memory::desc({conv1_src_tz}, dt::f32, tag::any);
auto conv1_bias_md = memory::desc({conv1_bias_tz}, dt::f32, tag::any);
auto conv1_weights_md = memory::desc({conv1_weights_tz}, dt::f32, tag::any);
auto conv1_dst_md = memory::desc({conv1_dst_tz}, dt::f32, tag::any);
//[Create convolution memory descriptors]

/// Create a convolution descriptor by specifying propagation kind,
/// [convolution algorithm](@ref dev_guide_convolution), shapes of input,
/// weights, bias, output, convolution strides, padding, and kind of padding.
/// Propagation kind is set to prop_kind::forward_inference to optimize for
/// inference execution and omit computations that are necessary only for
/// backward propagation.
/// @snippet cnn_inference_f32.cpp Create convolution descriptor
//[Create convolution descriptor]
auto conv1_desc = convolution_forward::desc(prop_kind::forward_inference,
                  algorithm::convolution_direct, conv1_src_md, conv1_weights_md,
                  conv1_bias_md, conv1_dst_md, conv1_strides, conv1_padding,
                  conv1_padding);
//[Create convolution descriptor]

/// Create a convolution primitive descriptor. Once created, this
/// descriptor has specific formats instead of the `any` format specified
/// in the convolution descriptor.
/// @snippet cnn_inference_f32.cpp Create convolution primitive descriptor
//[Create convolution primitive descriptor]
auto conv1_prim_desc = convolution_forward::primitive_desc(conv1_desc, eng);
//[Create convolution primitive descriptor]

/// Check whether data and weights formats required by convolution is different
/// from the user format. In case it is different change the layout using
/// reorder primitive.
/// @snippet cnn_inference_f32.cpp Reorder data and weights
//[Reorder data and weights]

auto conv1_src_memory = user_src_memory;

if (conv1_prim_desc.src_desc() != user_src_memory.get_desc()) {
    conv1_src_memory = memory(conv1_prim_desc.src_desc(), eng);
    net.push_back(reorder(user_src_memory, conv1_src_memory));
    net_args.push_back({{DNNL_ARG_FROM, user_src_memory},
        {DNNL_ARG_TO, conv1_src_memory}});
}

auto conv1_weights_memory = user_weights_memory;
if (conv1_prim_desc.weights_desc() != user_weights_memory.get_desc()) {
    conv1_weights_memory = memory(conv1_prim_desc.weights_desc(), eng);
    reorder(user_weights_memory, conv1_weights_memory)
    .execute(s, user_weights_memory, conv1_weights_memory);
}

//[Reorder data and weights]

/// Create a memory primitive for output.
/// @snippet cnn_inference_f32.cpp Create memory for output
//[Create memory for output]
//auto conv1_dst_memory = memory(conv1_prim_desc.dst_desc(), eng);
//[Create memory for output]

/// Create a convolution primitive and add it to the net.
/// @snippet cnn_inference_f32.cpp Create memory for output
//[Create convolution primitive]
net.push_back(convolution_forward(conv1_prim_desc));
net_args.push_back({{DNNL_ARG_SRC, conv1_src_memory},
    {DNNL_ARG_WEIGHTS, conv1_weights_memory},
    {DNNL_ARG_BIAS, conv1_user_bias_memory},
    {DNNL_ARG_DST, conv1_dst_memory}
});
//[Create convolution primitive]


/// @page cnn_inference_f32_cpp
/// Finally, execute the primitives. For this example, the net is executed
/// multiple times and each execution is timed individually.
/// @snippet cnn_inference_f32.cpp Execute model
//[Execute model]
for (int j = 0; j < times; ++j) {
    assert(net.size() == net_args.size() && "something is missing");
    for (size_t i = 0; i < net.size(); ++i) {
        net.at(i).execute(s, net_args.at(i));
    }
}
//[Execute model]
s.wait();

}

int main(int argc, char **argv) {
//Input parameters to convolution
int times = 1; //100
int batch = 2;//;3;
int channel = 8;//3;
int height = 222;//227;
int width = 222;//227;
int no_of_filter = 8;//96;
int kernel_h = 9;//11;
int kernel_w = 9;//11;
int pad_h = 4;
int pad_w = 4;
int stride_h = 1;//4;
int stride_w = 1;//4;

    int out_height = (height + pad_h + pad_w - kernel_h) / stride_h + 1;
    int out_width = (width + pad_h + pad_w - kernel_w) / stride_w + 1;

    using tag = memory::format_tag;
    using dt = memory::data_type;
    memory::dims conv1_src_tz = {batch, channel, height, width};
    memory::dims conv1_weights_tz = {no_of_filter, channel, kernel_h, kernel_w};
    memory::dims conv1_bias_tz = {no_of_filter};
    memory::dims conv1_dst_tz = {batch, no_of_filter, out_height, out_width};

    engine::kind engine_kind = parse_engine_kind(argc, argv);
    engine eng(engine_kind, 0);
    stream s(eng);

    //memory allocation
    auto user_src_memory = memory({{conv1_src_tz}, dt::f32, tag::nhwc}, eng);
    auto user_weights_memory = memory({{conv1_weights_tz}, dt::f32, tag::hwcn}, eng); 
    auto conv1_user_bias_memory = memory({{conv1_bias_tz}, dt::f32, tag::x}, eng);
    auto conv1_dst_memory = memory({{conv1_dst_tz}, dt::f32, tag::aBcd8b }, eng);

    //data initialization
    init_data(user_src_memory);
    init_data(user_weights_memory);
    init_data(conv1_user_bias_memory);

    auto begin = chrono::duration_cast<chrono::milliseconds>(
                     chrono::steady_clock::now().time_since_epoch())
                 .count();

    convolution_param(eng, user_src_memory, batch, channel, height, width, user_weights_memory,
                      no_of_filter, kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, conv1_user_bias_memory, conv1_dst_memory, out_height, out_width);


    auto end = chrono::duration_cast<chrono::milliseconds>(
                   chrono::steady_clock::now().time_since_epoch())
               .count();

    auto conv1_dst_memory_new = memory({{conv1_dst_tz}, dt::f32, tag::nhwc}, eng);
    reorder(conv1_dst_memory, conv1_dst_memory_new).execute(s, conv1_dst_memory, conv1_dst_memory_new);
    float *dataHandle= (float *)conv1_dst_memory_new.get_data_handle();
    
return 0;

}

@emfomenk

emfomenk commented Jun 5, 2020

Thanks for the reproducer! I put it into a gist gh742.cpp.

For me it runs the gemm:jit implementation (see below).

What oneDNN version do you use? Could you please try the latest one?
Please also post the output of the example run with DNNL_VERBOSE.

$  ( g++ -std=c++11 gh742.cpp -ldnnl -lpthread && DNNL_VERBOSE=1 ./a.out )
dnnl_verbose,info,oneDNN v1.5.0 (commit f9eae92b4bd12bc96071441ee6361761e371e5b6)
dnnl_verbose,info,cpu,runtime:OpenMP
dnnl_verbose,info,cpu,isa:Intel AVX2
dnnl_verbose,info,gpu,runtime:OpenCL
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:cdba:f0 dst_f32::blocked:abcd:f0,,,8x8x9x9,4.10303
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:acdb:f0 dst_f32::blocked:abcd:f0,,,2x8x222x222,0.868164
dnnl_verbose,exec,cpu,convolution,gemm:jit,forward_inference,src_f32::blocked:abcd:f0 wei_f32::blocked:abcd:f0 bia_f32::blocked:a:f0 dst_f32::blocked:abcd:f0,,alg:convolution_direct,mb2_ic8oc8_ih222oh222kh9sh1dh0ph4_iw222ow222kw9sw1dw0pw4,153.279
dnnl_verbose,exec,cpu,reorder,jit:uni,undef,src_f32::blocked:aBcd8b:f0 dst_f32::blocked:acdb:f0,,,2x8x222x222,0.448975

@Alok-Ranjan23
Author

> Thanks for the reproducer! I put it into a gist gh742.cpp.
>
> For me it runs the gemm:jit implementation (see below).

That is my point. I think these convolution parameters fit the blocked format:
batch = 2, channel = 8, image = (222, 222), no_of_filter = 8, kernel = (9,9), pad = (4,4), stride = (1,1).
Why does it take the gemm:jit implementation to perform the convolution instead of the direct convolution implementation? Is it something related to the kernel size?

@emfomenk

emfomenk commented Jun 8, 2020

I see. The dispatch to the gemm-based implementation, which is considered a fallback, happens because the jit:avx2 direct implementation currently cannot handle problems with large (greater than 3) padding over the width dimension.

See this restriction (simplified):

bool ok = ... && jcp.l_pad <= jcp.ur_w; // here l_pad = 4 -- width padding,
                                        // ur_w -- register unroll, which is 3
if (!ok) return status::unimplemented;

And a gdb session confirming the values for this problem:

dnnl::impl::cpu::x64::jit_avx2_conv_fwd_kernel_f32::init_conf (jcp=..., cd=..., src_d=..., weights_d=..., dst_d=..., attr=...) at /nfs/pdx/disks/hal9000/emfomenk/ml/fw/ipl_mkl_dnn-master.git.tf/src/cpu/x64/jit_avx2_conv_kernel_f32.cpp:529
529         args_ok = true && jcp.oc % simd_w == 0 && jcp.l_pad <= jcp.ur_w
(gdb) p jcp.l_pad
$1 = 4
(gdb) p jcp.ur_w
$2 = 3
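As a side note, if you want to confirm from inside the application (rather than via DNNL_VERBOSE) which implementation a primitive descriptor resolved to, the descriptor can be queried for its implementation info string. A minimal sketch, assuming it is placed right after conv1_prim_desc is created in your convolution_param() and that <iostream> is included:

// Print the implementation name the library picked for this primitive
// descriptor, e.g. "jit:avx2", "gemm:jit", or "ref:any".
std::cout << "convolution impl: " << conv1_prim_desc.impl_info_str()
          << std::endl;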

For the original problem (batch=37, ic=104, oc=40, ...) I made a rough comparison of the gemm-based and jit-based implementations. To run the jit-based implementation I set the padding to 3 instead of 4.

The results are (raw data below):

Problem (benchdnn)                 Implementation  Performance
mb37_ic104oc40_ih222kh9sh1dh0ph4   gemm:jit        582 GFlops
mb37_ic104oc40_ih222kh9sh1dh0ph3   jit:avx2        981 GFlops

So even though the gemm-based implementation is slower, at least it is not that bad here; for the case with ic=oc=8 the difference is much bigger.


$ ( OMP_NUM_THREADS=16 ./tests/benchdnn/benchdnn --conv --mode=P -v5 mb37_ic104oc40_ih222kh9sh1dh0ph4 mb37_ic104oc40_ih222kh9sh1dh0ph3 )
run: --conv mb37ic104ih222oc40oh222kh9ph4
oneDNN implementation: gemm:jit
oneDNN implementation: gemm:jit
Output template: perf,%engine%,%name%,%prb%,%Gops%,%Gfreq%,%-time%,%-Gflops%,%0time%,%0Gflops%
perf,cpu,,--conv mb37ic104ih222oc40oh222kh9ph4,1204.42,0,2060.35,584.571,2069.67,581.939
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 failed:0 listed:0
total perf: min(ms):2060.35 avg(ms):2069.67
[emfomenk@jfldnn001 03:35:01 build] (0;0)
$ ( hidebrew; OMP_NUM_THREADS=16 ./tests/benchdnn/benchdnn --conv --mode=P -v5 mb37_ic104oc40_ih222kh9sh1dh0ph3 )
run: --conv mb37ic104ih222oc40oh220kh9ph3
oneDNN implementation: jit:avx2
oneDNN implementation: jit:avx2
Output template: perf,%engine%,%name%,%prb%,%Gops%,%Gfreq%,%-time%,%-Gflops%,%0time%,%0Gflops%
perf,cpu,,--conv mb37ic104ih222oc40oh220kh9ph3,1192.27,0,1192.66,999.679,1214.8,981.455
tests:1 passed:0 skipped:0 mistrusted:0 unimplemented:0 failed:0 listed:0
total perf: min(ms):1192.66 avg(ms):1214.8
