diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index f8e5af98c0a0..de404f49c6aa 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -48,6 +48,7 @@ struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
   }
 };
 
+
 /*! \brief Attributes used in convolution operators */
 struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
   Array<IndexExpr> strides;
@@ -193,6 +194,61 @@ struct Conv2DWinogradNNPACKWeightTransformAttrs
   }
 };
 
+/*! \brief Attributes used in convolution operators */
+struct Conv3DAttrs : public tvm::AttrsNode<Conv3DAttrs> {
+  Array<IndexExpr> strides;
+  Array<IndexExpr> padding;
+  Array<IndexExpr> dilation;
+  int groups;
+  IndexExpr channels;
+  Array<IndexExpr> kernel_size;
+  std::string data_layout;
+  std::string kernel_layout;
+  std::string out_layout;
+  DataType out_dtype;
+
+  TVM_DECLARE_ATTRS(Conv3DAttrs, "relay.attrs.Conv3DAttrs") {
+    TVM_ATTR_FIELD(strides).set_default(Array<IndexExpr>({1, 1, 1}))
+        .describe("Specifies the strides of the convolution.");
+    TVM_ATTR_FIELD(padding).set_default(Array<IndexExpr>({0, 0, 0}))
+        .describe("If padding is non-zero, then the input is implicitly zero-padded"
+                  "on both sides for padding number of points");
+    TVM_ATTR_FIELD(dilation).set_default(Array<IndexExpr>({1, 1, 1}))
+        .describe("Specifies the dilation rate to use for dilated convolution.");
+    TVM_ATTR_FIELD(groups).set_default(1)
+        .describe("Controls the connections between inputs and outputs."
+                  "At groups=1, all inputs are convolved to all outputs."
+                  "At groups=2, the operation becomes equivalent to having two convolution"
+                  "layers side by side, each seeing half the input channels, and producing"
+                  "half the output channels, and both subsequently concatenated.");
+    TVM_ATTR_FIELD(channels)
+        .describe("The number of output channels in the convolution."
+                  " If it is not set, inferred by shape of the weight.")
+        .set_default(NullValue<IndexExpr>());
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("Specifies the dimensions of the convolution window.")
+        .set_default(NullValue<Array<IndexExpr> >());
+    TVM_ATTR_FIELD(data_layout).set_default("NCDHW")
+        .describe("Dimension ordering of input data. Can be 'NCDHW', 'NDHWC', etc."
+                  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+                  "dimensions respectively. Convolution is applied on the 'D', 'H' and"
+                  "'W' dimensions.");
+    TVM_ATTR_FIELD(kernel_layout).set_default("OIDHW")
+        .describe("Dimension ordering of weight. Can be 'OIDHW', 'OIDHW16o16i', etc."
+                  "'O', 'I', 'D', 'H', 'W' stands for num_filter, input_channel, depth, height,"
+                  "and width dimensions respectively.");
+    TVM_ATTR_FIELD(out_layout).set_default("")
+        .describe("Dimension ordering of output. Can be 'NCDHW', 'NDHWC', etc."
+                  "'N', 'C', 'D', 'H', 'W' stands for batch, channel, depth, height, and width"
+                  "dimensions respectively. Default to be same as input layout.");
+
+    // use 0 bits to indicate none.
+    TVM_ATTR_FIELD(out_dtype)
+        .set_default(NullValue<DataType>())
+        .describe("Output data type, set to explicit type under mixed precision setting");
+  }
+};
+
 /*!
\brief Attributes used in softmax operators */ struct SoftmaxAttrs : public tvm::AttrsNode { int axis; diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 08ef1bf58a1e..cd8a1311eaba 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -142,7 +142,6 @@ def _find_conv2d_op(op): return op_ return None - @reg.register_compute("nn.conv2d") def compute_conv2d(attrs, inputs, out_type, target): """Compute definition of conv2d""" @@ -278,6 +277,48 @@ def compute_conv2d_transpose(attrs, inputs, out_dtype, target): return [out] +@reg.register_compute("nn.conv3d") +def compute_conv3d(attrs, inputs, out_type, target): + """Compute definition of conv3d""" + padding = get_const_tuple(attrs.padding) + strides = get_const_tuple(attrs.strides) + dilation = get_const_tuple(attrs.dilation) + groups = attrs.groups + layout = attrs.data_layout + out_dtype = attrs.out_dtype + out_dtype = (inputs[0].dtype if out_dtype in ("same", "") + else out_dtype) + + assert layout in ["NCDHW"] + (dilation_d, dilation_h, dilation_w) = dilation + if dilation_d < 1 or dilation_h < 1 or dilation_w < 1: + raise ValueError("dilation should be positive value") + + if groups == 1: + out = topi.nn.conv3d( + inputs[0], inputs[1], strides, padding, + dilation, layout, out_dtype) + else: + raise ValueError("not support arbitrary group number for now") + return [out] + + +@reg.register_schedule("nn.conv3d") +def schedule_conv3d(attrs, outs, target): + """Schedule definition of conv3d""" + groups = attrs.groups + layout = attrs.data_layout + + with target: + if groups == 1 and layout == "NCDHW": + return topi.generic.schedule_conv3d_ncdhw(outs) + + raise ValueError("No compatible schedule") + + +reg.register_pattern("nn.conv3d", OpPattern.OUT_ELEMWISE_FUSABLE) + + @reg.register_schedule("nn.conv2d_transpose") def schedule_conv2d_transpose(attrs, outs, target): """Schedule definition of conv2d_transpose""" diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py index 5f3f80084787..5e1c6a8c2616 100644 --- a/python/tvm/relay/op/nn/nn.py +++ b/python/tvm/relay/op/nn/nn.py @@ -106,6 +106,91 @@ def conv2d(data, kernel_layout, out_layout, out_dtype) +def conv3d(data, + weight, + strides=(1, 1, 1), + padding=(0, 0, 0), + dilation=(1, 1, 1), + groups=1, + channels=None, + kernel_size=None, + data_layout="NCDHW", + kernel_layout="OIDHW", + out_layout="", + out_dtype=""): + r"""3D convolution. + + This operator takes the weight as the convolution kernel + and convolves it with data to produce an output. + + + In the default case, where the data_layout is `NCDHW` + and kernel_layout is `OIDHW`, conv3d takes in + a data Tensor with shape `(batch_size, in_channels, depth, height, width)`, + and a weight Tensor with shape `(channels, in_channels, kernel_size[0], kernel_size[1], + kernel_size[2])` to produce an output Tensor with the following rule: + + .. math:: + + \mbox{out}[b, c, z, y, x] = \sum_{dz, dy, dx, k} + \mbox{data}[b, k, \mbox{strides}[0] * z + dz, \mbox{strides}[1] * y + dy, + \mbox{strides}[2] * x + dx] * \mbox{weight}[c, k, dz, dy, dx] + + Padding and dilation are applied to data and weight respectively before the computation. + This operator accepts data layout specification. + Semantically, the operator will convert the layout to the canonical layout + (`NCDHW` for data and `OIDHW` for weight), perform the computation, + then convert to the out_layout. + + + Parameters + ---------- + data : tvm.relay.Expr + The input data to the operator. 
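+        5-D with shape (batch_size, in_channels, depth, height, width)
+        when `data_layout` is `NCDHW`.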
+ + weight : tvm.relay.Expr + The weight expressions. + + strides : Optional[Tuple[int]] + The strides of convolution. + + padding : Optional[Tuple[int]] + The padding of convolution on both sides of inputs before convolution. + + dilation : Optional[Tuple[int]] + Specifies the dilation rate to be used for dilated convolution. + + groups : Optional[int] + Number of groups for grouped convolution. + + channels : Optional[int] + Number of output channels of this convolution. + + kernel_size : Optional[Tuple[int]] + The spatial of the convolution kernel. + + data_layout : Optional[str] + Layout of the input. + + kernel_layout : Optional[str] + Layout of the weight. + + out_layout : Optional[str] + Layout of the output, by default, out_layout is the same as data_layout + + out_dtype : Optional[str] + Specifies the output data type for mixed precision conv2d. + + Returns + ------- + result : tvm.relay.Expr + The computed result. + """ + return _make.conv3d(data, weight, strides, padding, dilation, + groups, channels, kernel_size, data_layout, + kernel_layout, out_layout, out_dtype) + + def conv2d_transpose(data, weight, strides=(1, 1), diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index 65e71bec4dea..3c9bebc1b0d0 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -106,6 +106,64 @@ with the layer input to produce a tensor of outputs. .add_type_rel("Conv2D", Conv2DRel) .set_attr("FInferCorrectLayout", Conv2DInferCorrectLayout); +// relay.nn.conv3d +TVM_REGISTER_NODE_TYPE(Conv3DAttrs); + +// Positional relay function to create conv3d operator +// used by frontend FFI. +Expr MakeConv3D(Expr data, + Expr weight, + Array strides, + Array padding, + Array dilation, + int groups, + IndexExpr channels, + Array kernel_size, + std::string data_layout, + std::string kernel_layout, + std::string out_layout, + DataType out_dtype) { + auto attrs = make_node(); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->dilation = std::move(dilation); + attrs->groups = groups; + attrs->channels = std::move(channels); + attrs->kernel_size = std::move(kernel_size); + attrs->data_layout = std::move(data_layout); + attrs->kernel_layout = std::move(kernel_layout); + attrs->out_layout = std::move(out_layout); + attrs->out_dtype = std::move(out_dtype); + static const Op& op = Op::Get("nn.conv3d"); + return CallNode::make(op, {data, weight}, Attrs(attrs), {}); +} + + +TVM_REGISTER_API("relay.op.nn._make.conv3d") +.set_body_typed(MakeConv3D); + + +RELAY_REGISTER_OP("nn.conv3d") +.describe(R"code(3D convolution layer (e.g. convolution over 3D image data, +like Magnetic Resonance Imaging (MRI) data in medicine). + +This layer creates a convolution kernel that is convolved +with the layer input to produce a tensor of outputs. + +- **data**: This depends on the `layout` parameter. Input is 5D array of shape + (batch_size, in_channels, depth, height, width) if `layout` is `NCDHW`. +- **weight**: (channels, in_channels, kernel_size[0], kernel_size[1], kernel_size[2]) +- **out**: This depends on the `layout` parameter. Output is 5D array of shape + (batch_size, channels, out_depth, out_height, out_width) if `layout` is `NCDHW`. 
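+
+When `layout` is `NCDHW`, the output spatial sizes follow the usual convolution rule
+(matching the shape relation implemented by the Conv3D type relation in this patch)::
+
+    out_depth  = (depth  + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / strides[0] + 1
+    out_height = (height + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / strides[1] + 1
+    out_width  = (width  + 2 * padding[2] - dilation[2] * (kernel_size[2] - 1) - 1) / strides[2] + 1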
+ +)code" TVM_ADD_FILELINE) +.set_attrs_type() +.set_num_inputs(2) +.add_argument("data", "Tensor", "The input tensor.") +.add_argument("weight", "Tensor", "The weight tensor.") +.set_support_level(2) +.add_type_rel("Conv3D", Conv3DRel); + // relay.nn.conv2d_transpose TVM_REGISTER_NODE_TYPE(Conv2DTransposeAttrs); diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h index 01437302bc92..efcf7dfe6906 100644 --- a/src/relay/op/nn/convolution.h +++ b/src/relay/op/nn/convolution.h @@ -138,6 +138,123 @@ bool Conv2DRel(const Array& types, int num_inputs, const Attrs& attrs, return true; } +template +bool Conv3DRel(const Array& types, int num_inputs, const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 3); + const auto* data = types[0].as(); + const auto* weight = types[1].as(); + if (data == nullptr) return false; + static const Layout kNCDHW("NCDHW"); + static const Layout kOIDHW("OIDHW"); + + const AttrType* param = attrs.as(); + CHECK(param != nullptr); + const Layout in_layout(param->data_layout); + const Layout kernel_layout(param->kernel_layout); + + const auto trans_in_layout = BijectiveLayoutNode::make(in_layout, kNCDHW); + CHECK(trans_in_layout.defined()) + << "Conv only support input layouts that are convertible from NCDHW." + << " But got " << in_layout; + + const auto trans_kernel_layout = BijectiveLayoutNode::make(kernel_layout, kOIDHW); + CHECK(trans_kernel_layout.defined()) + << "Conv only support kernel layouts that are convertible from OIDHW." + << " But got " << kernel_layout; + + Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout); + const auto trans_out_layout = BijectiveLayoutNode::make(out_layout, kNCDHW); + CHECK(trans_out_layout.defined()) + << "Conv only support output layouts that are convertible from NCDHW." + << " But got " << out_layout; + + Array dshape_ncdhw = trans_in_layout.ForwardShape(data->shape); + + IndexExpr channels, dilated_ksize_z, dilated_ksize_y, dilated_ksize_x; + // infer weight if the kernel_size and channels are defined + if (param->kernel_size.defined() && param->channels.defined()) { + CHECK_EQ(param->kernel_size.size(), 3); + CHECK_EQ(param->dilation.size(), 3); + Array wshape; + + if (tvm::ir::Equal(param->channels, param->groups) && !tvm::ir::Equal(param->channels, 1)) { + // infer weight's shape for depthwise convolution + wshape = {{dshape_ncdhw[1], indexdiv(param->groups, dshape_ncdhw[1]), param->kernel_size[0], + param->kernel_size[1], param->kernel_size[2]}}; + } else { + wshape = {{param->channels, indexdiv(dshape_ncdhw[1], param->groups), param->kernel_size[0], + param->kernel_size[1], param->kernel_size[2]}}; + } + + /*wshape = trans_kernel_layout.BackwardShape(wshape); */ + channels = param->channels; + dilated_ksize_z = 1 + (param->kernel_size[0] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (param->kernel_size[1] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (param->kernel_size[2] - 1) * param->dilation[2]; + DataType weight_dtype = data->dtype; + if (weight != nullptr) { + weight_dtype = weight->dtype; + } + // assign result to reporter + reporter->Assign(types[1], TensorTypeNode::make(wshape, weight_dtype)); + } else { + // use weight to infer the conv shape. 
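+    // Illustrative note (not in the original patch): with kernel_layout "OIDHW",
+    // a weight of shape (32, 16, 3, 3, 3) implies channels = 32 and
+    // kernel_size = (3, 3, 3); wshape[1] must match indexdiv(in_channels, groups).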
+ if (weight == nullptr) return false; + auto wshape = trans_kernel_layout.ForwardShape(weight->shape); + if (param->kernel_size.defined()) { + CHECK_EQ(param->kernel_size.size(), 3); + // check the size + CHECK(reporter->AssertEQ(param->kernel_size[0], wshape[2]) && + reporter->AssertEQ(param->kernel_size[1], wshape[3]) && + reporter->AssertEQ(param->kernel_size[2], wshape[4])) + << "Conv3D: shape of weight is inconsistent with kernel_size, " + << " kernel_size=" << param->kernel_size << " wshape=" << wshape; + } + if (param->channels.defined()) { + CHECK(reporter->AssertEQ(param->channels, wshape[0])) + << "Conv3D: shape of weight is inconsistent with channels, " + << " channels=" << param->channels << " wshape=" << wshape; + } + CHECK(reporter->AssertEQ(indexdiv(dshape_ncdhw[1], param->groups), wshape[1])); + channels = wshape[0]; + dilated_ksize_z = 1 + (wshape[2] - 1) * param->dilation[0]; + dilated_ksize_y = 1 + (wshape[3] - 1) * param->dilation[1]; + dilated_ksize_x = 1 + (wshape[4] - 1) * param->dilation[2]; + } + // dilation + Array oshape({dshape_ncdhw[0], channels, 0, 0, 0}); + + if (!dshape_ncdhw[2].as()) { + oshape.Set(2, indexdiv(dshape_ncdhw[2] + param->padding[0] * 2 - dilated_ksize_z, + param->strides[0]) + 1); + } else { + oshape.Set(2, dshape_ncdhw[2]); + } + + if (!dshape_ncdhw[3].as()) { + oshape.Set(3, indexdiv(dshape_ncdhw[3] + param->padding[1] * 2 - dilated_ksize_y, + param->strides[1]) + 1); + } else { + oshape.Set(3, dshape_ncdhw[3]); + } + + if (!dshape_ncdhw[4].as()) { + oshape.Set(4, indexdiv(dshape_ncdhw[4] + param->padding[2] * 2 - dilated_ksize_x, + param->strides[2]) + 1); + } else { + oshape.Set(4, dshape_ncdhw[4]); + } + DataType out_dtype = param->out_dtype; + if (out_dtype.bits() == 0) { + out_dtype = data->dtype; + } + oshape = trans_out_layout.BackwardShape(oshape); + // assign output type + reporter->Assign(types[2], TensorTypeNode::make(oshape, out_dtype)); + return true; +} + } // namespace relay } // namespace tvm #endif // TVM_RELAY_OP_NN_CONVOLUTION_H_ diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py index 4099d19e407e..bb16487d610b 100644 --- a/tests/python/relay/test_op_level2.py +++ b/tests/python/relay/test_op_level2.py @@ -294,6 +294,51 @@ def run_test_conv2d_cuda(dtype, out_dtype, scale, dshape, kshape, padding=(2, 2), channels=192, kernel_size=(7, 7)) +def test_conv3d_run(): + def run_test_conv3d(dtype, out_dtype, scale, dshape, kshape, + padding=(1, 1, 1), + fref=None, + groups=1, + dilation=(1, 1, 1), + except_targets=None, + **attrs): + if except_targets is None: + except_targets = [] + + x = relay.var("x", shape=dshape, dtype=dtype) + w = relay.var("w", dtype=dtype) + y = relay.nn.conv3d(x, w, + padding=padding, + dilation=dilation, + groups=groups, + **attrs) + func = relay.Function([x, w], y) + data = np.random.uniform(-scale, scale, size=dshape).astype(dtype) + kernel = np.random.uniform(-scale, scale, size=kshape).astype(dtype) + dkernel = topi.testing.dilate_python(kernel, (1, 1) + dilation) + if fref is None: + ref_res = topi.testing.conv3d_ncdhw_python( + data.astype(out_dtype), dkernel.astype(out_dtype), 1, padding, + groups=groups) + else: + ref_res = fref(data.astype(out_dtype), dkernel.astype(out_dtype)) + + + for target, ctx in ctx_list(): + if target in except_targets: + continue + + intrp1 = relay.create_executor("graph", ctx=ctx, target=target) + op_res1 = intrp1.evaluate(func)(data, kernel) + tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5) + + # 
normal conv3d + dshape = (1, 3, 5, 224, 224) + kshape = (10, 3, 3, 3, 3) + run_test_conv3d("float32", "float32", 1, dshape, kshape, + padding=(1, 1, 1), channels=10, kernel_size=(3, 3 ,3)) + + def test_conv2d_transpose_infer_type(): # symbolic in batch dimension n, c, h, w = tvm.var("n"), 10, 10, 12 @@ -850,6 +895,7 @@ def test_bitpack_infer_type(): test_conv2d_transpose_nhwc_run() test_conv2d_run() test_conv2d_winograd() + test_conv3d_run() test_bitserial_conv2d_infer_type() test_batch_flatten() test_upsampling() diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py index 58bb3a5688c0..e6a342def32b 100644 --- a/topi/python/topi/cuda/__init__.py +++ b/topi/python/topi/cuda/__init__.py @@ -21,6 +21,7 @@ from . import conv2d, depthwise_conv2d, conv2d_transpose_nchw, deformable_conv2d, \ group_conv2d_nchw, dense +from . import conv3d from .conv2d_hwcn import schedule_conv2d_hwcn from .depthwise_conv2d import schedule_depthwise_conv2d_backward_input_nhwc from .depthwise_conv2d import schedule_depthwise_conv2d_backward_weight_nhwc diff --git a/topi/python/topi/cuda/conv3d.py b/topi/python/topi/cuda/conv3d.py new file mode 100644 index 000000000000..8d3c720b6a89 --- /dev/null +++ b/topi/python/topi/cuda/conv3d.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Compute definition for conv3d with cuda backend""" +import tvm +from tvm import autotvm +from tvm.contrib import cudnn + +from .. import nn, generic +from ..util import get_const_tuple, traverse_inline + +from .conv3d_direct import schedule_direct_3d_cuda + + +@autotvm.register_topi_compute(nn.conv3d, ['cuda', 'gpu'], ['direct']) +def conv3d_cuda(cfg, data, kernel, strides, padding, dilation, layout='NCDHW', out_dtype='float32'): + """Conv3D operator for cuda backend. + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data : tvm.Tensor + 5-D with shape [batch, in_channel, in_depth, in_height, in_width] + + kernel : tvm.Tensor + 5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width] + + strides : int or a list/tuple of three ints + stride size, or [stride_depth, stride_height, stride_width] + + padding : int or a list/tuple of three ints + padding size, or [pad_depth, pad_height, pad_width] + + dilation: int or a list/tuple of three ints + dilation size, or [dilation_depth, dilation_height, dilation_width] + + layout : str + layout of data + + out_dtype: str + The output type. This is used for mixed precision. 
+
+    Returns
+    -------
+    output : tvm.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    target = tvm.target.current_target()
+
+    if "cudnn" in target.libs:
+        if layout == 'NCDHW':
+            tensor_format = 0  # CUDNN_TENSOR_NCHW
+            N, _, D, H, W = get_const_tuple(data.shape)
+        elif layout == 'NDHWC':
+            tensor_format = 1  # CUDNN_TENSOR_NHWC
+            N, D, H, W, _ = get_const_tuple(data.shape)
+        else:
+            raise ValueError("Unsupported layout %s in cudnn" % layout)
+        CO, CI, KD, KH, KW = get_const_tuple(kernel.shape)
+
+        # handle dilation
+        stride_d, stride_h, stride_w = (strides, strides, strides) if isinstance(strides, int) \
+            else strides
+        pad_d, pad_h, pad_w = (padding, padding, padding) if isinstance(padding, int) else padding
+        dilation_d, dilation_h, dilation_w = (dilation, dilation, dilation) if \
+            isinstance(dilation, int) else dilation
+
+        OD = (D + 2 * pad_d - KD) // stride_d + 1
+        OH = (H + 2 * pad_h - KH) // stride_h + 1
+        OW = (W + 2 * pad_w - KW) // stride_w + 1
+        cfg.add_flop(2 * N * OD * OH * OW * CO * CI * ((KD - 1) * dilation_d + 1) *\
+                     ((KH - 1) * dilation_h + 1) * ((KW - 1) * dilation_w + 1))
+
+        return cudnn.conv_forward(data,
+                                  kernel,
+                                  [pad_d, pad_h, pad_w],
+                                  [stride_d, stride_h, stride_w],
+                                  [dilation_d, dilation_h, dilation_w],
+                                  conv_mode=1,
+                                  tensor_format=tensor_format,
+                                  algo=-1,  # let CUDNN choose the best algo
+                                  conv_dtype=data.dtype)
+
+    if layout == 'NCDHW':
+        return nn.conv3d_ncdhw(data, kernel, strides, padding, dilation, out_dtype)
+    raise ValueError("not support this layout {} yet".format(layout))
+
+
+@autotvm.register_topi_schedule(generic.schedule_conv3d_ncdhw, ["cuda", "gpu"],
+                                ["direct"])
+def schedule_conv3d_ncdhw_cuda(cfg, outs):
+    """TOPI schedule callback of conv3d for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of conv3d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for conv3d.
+    """
+    target = tvm.target.current_target()
+    if 'cudnn' in target.libs:
+        return generic.schedule_extern(outs)
+
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'conv3d_ncdhw':
+            schedule_direct_3d_cuda(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
diff --git a/topi/python/topi/cuda/conv3d_direct.py b/topi/python/topi/cuda/conv3d_direct.py
new file mode 100644
index 000000000000..e38dbcbfa002
--- /dev/null
+++ b/topi/python/topi/cuda/conv3d_direct.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""The templates for cuda conv3d operators"""
+import tvm
+from tvm import autotvm
+from ..util import get_const_tuple
+
+def schedule_direct_3d_cuda(cfg, s, conv):
+    """schedule optimized for batch size = 1"""
+
+    ##### space definition begin #####
+    n, f, d, y, x = s[conv].op.axis
+    rc, rd, ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_f", f, num_outputs=4)
+    cfg.define_split("tile_d", d, num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_rd", rd, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
+    target = tvm.target.current_target()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    # fallback support
+    if cfg.is_fallback:
+        ref_log = autotvm.tophub.load_reference_log(
+            target.target_name, target.model, 'conv3d', 'direct')
+        cfg.fallback_with_reference_log(ref_log)
+    ##### space definition end #####
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+    if isinstance(kernel.op, tvm.tensor.ComputeOp) and 'dilate' in kernel.op.tag:
+        s[kernel].compute_inline()
+
+    if conv.op in s.outputs:
+        output = conv
+        OL = s.cache_write(conv, 'local')
+    else:
+        output = s.outputs[0].output(0)
+        s[conv].set_scope('local')
+        OL = conv
+
+    # create cache stage
+    AA = s.cache_read(pad_data, 'shared', [OL])
+    WW = s.cache_read(kernel, 'shared', [OL])
+
+    # tile and bind spatial axes
+    n, f, d, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    bd, vd, td, di = cfg["tile_d"].apply(s, output, d)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    bf = s[output].fuse(n, bf)
+    s[output].reorder(bf, bd, by, bx, vf, vd, vy, vx, tf, td, ty, tx, fi, di, yi, xi)
+
+    s[output].bind(bf, tvm.thread_axis("blockIdx.z"))
+    s[output].bind(s[output].fuse(bd, by), tvm.thread_axis("blockIdx.y"))
+    s[output].bind(bx, tvm.thread_axis("blockIdx.x"))
+    s[output].bind(vf, tvm.thread_axis("vthread"))
+    s[output].bind(vd, tvm.thread_axis("vthread"))
+    s[output].bind(vy, tvm.thread_axis("vthread"))
+    s[output].bind(vx, tvm.thread_axis("vthread"))
+    s[output].bind(s[output].fuse(td, tf), tvm.thread_axis("threadIdx.z"))
+    s[output].bind(ty, tvm.thread_axis("threadIdx.y"))
+    s[output].bind(tx, tvm.thread_axis("threadIdx.x"))
+    s[OL].compute_at(s[output], tx)
+
+    # tile reduction axes
+    n, f, d, y, x = s[OL].op.axis
+    rc, rd, ry, rx = s[OL].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, OL, rc)
+    rdo, rdi = cfg['tile_rd'].apply(s, OL, rd)
+    ryo, ryi = cfg['tile_ry'].apply(s, OL, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, OL, rx)
+    s[OL].reorder(rco, rdo, ryo, rxo, rci, rdi, ryi, rxi, n, f, d, y, x)
+
+    s[AA].compute_at(s[OL], rxo)
+    s[WW].compute_at(s[OL], rxo)
+
+    # cooperative fetching
+    for load in [AA, WW]:
+        n, f, d, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, d, y, x)
+        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
+        td, fused = s[load].split(fused, nparts=cfg["tile_d"].size[2])
+        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
+        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
+        s[load].bind(tz,
tvm.thread_axis("threadIdx.z")) + s[load].bind(s[load].fuse(td, ty), tvm.thread_axis("threadIdx.y")) + s[load].bind(tx, tvm.thread_axis("threadIdx.x")) + + # unroll + s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val) + s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val) + + N, CO, OD, OH, OW = get_const_tuple(output.shape) + _, KD, KH, KW, CI = get_const_tuple(kernel.shape) + cfg.add_flop(2 * N * OD * OH * OW * CO * CI * KD * KH * KW) diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py index 4043cb7e4606..752cb5a63401 100644 --- a/topi/python/topi/generic/nn.py +++ b/topi/python/topi/generic/nn.py @@ -225,6 +225,24 @@ def schedule_conv2d_winograd_nnpack_without_weight_transform(outs): return _default_schedule(outs, False) +@tvm.target.generic_func +def schedule_conv3d_ncdhw(outs): + """Schedule for conv3d_ncdhw + + Parameters + ---------- + outs: Array of Tensor + The computation graph description of conv2d_nchw + in the format of an array of tensors. + + Returns + ------- + sch: Schedule + The computation schedule for the op. + """ + return _default_schedule(outs, False) + + @tvm.target.generic_func def schedule_conv2d_transpose_nchw(outs): """Schedule for conv2d_transpose_nchw diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py index dc3f369be285..f42cde860ae9 100644 --- a/topi/python/topi/nn/__init__.py +++ b/topi/python/topi/nn/__init__.py @@ -20,6 +20,7 @@ from __future__ import absolute_import as _abs from .conv2d import * +from .conv3d import * from .deformable_conv2d import * from .depthwise_conv2d import * from .elemwise import * diff --git a/topi/python/topi/nn/conv3d.py b/topi/python/topi/nn/conv3d.py new file mode 100644 index 000000000000..928f32f51d75 --- /dev/null +++ b/topi/python/topi/nn/conv3d.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-variable, too-many-locals +# pylint: disable=unused-argument, redefined-builtin +"""Conv3D operators""" +from __future__ import absolute_import as _abs +import tvm + +from .pad import pad +from .util import get_pad_tuple3d +from ..util import simplify + + +@tvm.target.generic_func +def conv3d(input, filter, strides, padding, dilation, layout='NCDHW', out_dtype=None): + """Conv3D operator. 
+
+    Parameters
+    ----------
+    input : tvm.Tensor
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    filter : tvm.Tensor
+        5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
+
+    strides : int or a list/tuple of three ints
+        stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or a list/tuple of three ints
+        padding size, or [pad_depth, pad_height, pad_width]
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    layout : str
+        layout of data
+
+    out_dtype : str, optional
+        The output data type. Defaults to the input dtype.
+
+    Returns
+    -------
+    output : tvm.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    # search platform specific declaration first
+    # default declaration
+    if layout == 'NCDHW':
+        return conv3d_ncdhw(input, filter, strides, padding, dilation, out_dtype)
+    raise ValueError("not support this layout {} yet".format(layout))
+
+
+def conv3d_ncdhw(Input, Filter, stride, padding, dilation, out_dtype=None):
+    """Convolution operator in NCDHW layout.
+
+    Parameters
+    ----------
+    Input : tvm.Tensor
+        5-D with shape [batch, in_channel, in_depth, in_height, in_width]
+
+    Filter : tvm.Tensor
+        5-D with shape [num_filter, in_channel, filter_depth, filter_height, filter_width]
+
+    stride : int or a list/tuple of three ints
+        Stride size, or [stride_depth, stride_height, stride_width]
+
+    padding : int or str or a list/tuple of three ints
+        Padding size, or ['VALID', 'SAME']
+
+    dilation: int or a list/tuple of three ints
+        dilation size, or [dilation_depth, dilation_height, dilation_width]
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        5-D with shape [batch, out_channel, out_depth, out_height, out_width]
+    """
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 3
+    assert isinstance(dilation, int) or len(dilation) == 3
+    if isinstance(stride, int):
+        stride_d = stride_h = stride_w = stride
+    else:
+        stride_d, stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_d = dilation_h = dilation_w = dilation
+    else:
+        dilation_d, dilation_h, dilation_w = dilation
+
+    batch, in_channel, in_depth, in_height, in_width = Input.shape
+    num_filter, channel, kernel_d, kernel_h, kernel_w = Filter.shape
+    # compute the output shape
+    dilated_kernel_d = (kernel_d - 1) * dilation_d + 1
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_front, pad_top, pad_left, pad_back, pad_down, pad_right = get_pad_tuple3d(
+        padding, (dilated_kernel_d, dilated_kernel_h, dilated_kernel_w))
+    out_channel = num_filter
+    out_depth = simplify((in_depth - dilated_kernel_d + pad_front + pad_back) // stride_d + 1)
+    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
+    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
+    # compute graph
+    pad_before = [0, 0, pad_front, pad_top, pad_left]
+    pad_after = [0, 0, pad_back, pad_down, pad_right]
+    temp = pad(Input, pad_before, pad_after, name="pad_temp")
+    rc = tvm.reduce_axis((0, in_channel), name='rc')
+    rz = tvm.reduce_axis((0, kernel_d), name='rz')
+    ry = tvm.reduce_axis((0, kernel_h), name='ry')
+    rx = tvm.reduce_axis((0, kernel_w), name='rx')
+
+    return tvm.compute(
+        (batch, out_channel, out_depth, out_height, out_width),
+        lambda nn, ff, zz, yy, xx: tvm.sum(
+            temp[nn, rc, zz * stride_d + rz * dilation_d, yy * stride_h + ry * dilation_h,
+                 xx * stride_w + rx * dilation_w].astype(out_dtype) *
Filter[ff, rc, rz, ry, rx].astype(out_dtype), + axis=[rc, rz, ry, rx]), tag="conv3d_ncdhw") diff --git a/topi/python/topi/nn/util.py b/topi/python/topi/nn/util.py index 7c0957cc7fc4..463edaa463dc 100644 --- a/topi/python/topi/nn/util.py +++ b/topi/python/topi/nn/util.py @@ -118,3 +118,57 @@ def get_pad_tuple(padding, kernel): pad_top = (pad_h + 1) // 2 pad_left = (pad_w + 1) // 2 return pad_top, pad_left, pad_h - pad_top, pad_w - pad_left + + +def get_pad_tuple3d(padding, kernel): + """Common code to get the pad option + + Parameters + ---------- + padding : int or str + Padding size, or ['VALID', 'SAME'] + + kernel : tuple of int + Conv kernel size + + Returns + ------- + pad_front : int + Padding size on front. + + pad_top : int + Padding size on top + + pad_left : int + Padding size on left + + pad_back : int + Padding size on back. + + pad_down : int + Padding size on down. + + pad_right : int + Padding size on right. + """ + # compute the padding size + if isinstance(padding, (tuple, list)): + pad_h = padding[0] * 2 + pad_w = padding[1] * 2 + pad_d = padding[2] * 2 + elif isinstance(padding, int): + pad_d = pad_w = pad_h = padding * 2 + elif padding == "VALID": + pad_h = 0 + pad_w = 0 + pad_d = 0 + elif padding == "SAME": + pad_h = kernel[0] - 1 + pad_w = kernel[1] - 1 + pad_d = kernel[2] - 1 + else: + raise ValueError("Unknown padding option %s" % padding) + pad_top = (pad_h + 1) // 2 + pad_left = (pad_w + 1) // 2 + pad_front = (pad_d + 1) // 2 + return pad_front, pad_top, pad_left, pad_d - pad_front, pad_h - pad_top, pad_w - pad_left diff --git a/topi/tests/python/test_topi_conv3d_ncdhw.py b/topi/tests/python/test_topi_conv3d_ncdhw.py new file mode 100644 index 000000000000..78827e4ca9d1 --- /dev/null +++ b/topi/tests/python/test_topi_conv3d_ncdhw.py @@ -0,0 +1,115 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
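+# Note on padding: topi.nn.conv3d resolves its padding argument through
+# get_pad_tuple3d (added above); an integer padding p becomes a symmetric
+# (p, p, p) pad on depth/height/width, so a kernel of size k at stride 1
+# preserves the spatial extent whenever k == 2 * p + 1 (e.g. the kernel=3,
+# padding=1 workloads below).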
+"""Example code to do convolution.""" + +import numpy as np +import tvm +from tvm import autotvm +import topi +import topi.testing +from tvm.contrib.pickle_memoize import memoize +from topi.util import get_const_tuple + +from common import get_all_backend + +def verify_conv3d_ncdhw(batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation=1, add_bias=False, add_relu=False): + print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + + in_depth = in_height = in_width = in_size + + A = tvm.placeholder((batch, in_channel, in_depth, in_height, in_width), name='A') + W = tvm.placeholder((num_filter, in_channel, kernel, kernel, kernel), name='W') + bias = tvm.placeholder((num_filter, 1, 1, 1), name='bias') + + a_shape = get_const_tuple(A.shape) + w_shape = get_const_tuple(W.shape) + bias_shape = get_const_tuple(bias.shape) + dtype = A.dtype + + @memoize("topi.tests.test_topi_conv3d_ncdhw.verify_conv3d_ncdhw") + def get_ref_data(): + a_np = np.random.uniform(size=a_shape).astype(dtype) + w_np = np.random.uniform(size=w_shape).astype(dtype) + b_np = np.random.uniform(size=bias_shape).astype(dtype) + dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation, dilation)) + c_np = topi.testing.conv3d_ncdhw_python(a_np, dw_np, stride, padding) + if add_bias: + c_np += b_np + if add_relu: + c_np = np.maximum(c_np, 0) + return a_np, w_np, b_np, c_np + + a_np, w_np, b_np, c_np = get_ref_data() + + def check_device(device): + ctx = tvm.context(device, 0) + if not ctx.exist: + print("Skip because %s is not enabled" % device) + return + print("Running on target: %s" % device) + with tvm.target.create(device): + C = topi.nn.conv3d(A, W, (stride, stride, stride), (padding, padding, padding), + (dilation, dilation, dilation), layout='NCDHW', out_dtype=dtype) + if add_bias: + C = topi.add(C, bias) + if add_relu: + C = topi.nn.relu(C) + s = topi.generic.schedule_conv3d_ncdhw([C]) + + a = tvm.nd.array(a_np, ctx) + w = tvm.nd.array(w_np, ctx) + b = tvm.nd.array(b_np, ctx) + c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx) + if add_bias: + func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func(a, w, b, c) + else: + func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation)) + func(a, w, c) + tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4) + + for device in get_all_backend(): + with autotvm.tophub.context(device): # load tophub pre-tuned parameters + check_device(device) + + +def test_conv3d_ncdhw(): + #3DCNN workloads + verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 0) + verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, 0) + verify_conv3d_ncdhw(1, 32, 32, 5, 1, 1, 1) + verify_conv3d_ncdhw(1, 32, 32, 1, 1, 1, 1) + + # bias, relu + verify_conv3d_ncdhw(1, 64, 56, 3, 1, 1, 1, add_relu=True) + verify_conv3d_ncdhw(1, 64, 56, 3, 1, 1, 1, add_bias=True) + verify_conv3d_ncdhw(1, 64, 56, 3, 1, 1, 1, add_bias=True, add_relu=True) + + # dilation = 2 + verify_conv3d_ncdhw(1, 64, 56, 3, 3, 1, 1, dilation=2) + + # batch size + verify_conv3d_ncdhw(4, 64, 56, 5, 3, 1, 1) + + # weird workloads + verify_conv3d_ncdhw(2, 2, 2, 2, 2, 2, 2) + verify_conv3d_ncdhw(3, 3, 3, 3, 3, 3, 3) + + + +if __name__ == "__main__": + test_conv3d_ncdhw()