[Op] Col2Im-15 reference implementation (#24548)
### Details:
- Similar in functionality to `torch.nn.Fold` (https://pytorch.org/docs/stable/generated/torch.nn.Fold.html): Col2Im is `torch.nn.Fold` restricted to two output spatial dimensions; a construction sketch is shown below
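
For illustration, a minimal construction sketch (assuming the `ov::op::v15::Col2Im` constructor from the core op — see the related PRs — takes `data`, `output_size`, `kernel_size`, followed by `strides`, `dilations`, `pads_begin` and `pads_end`):

```cpp
#include <memory>

#include "openvino/core/strides.hpp"
#include "openvino/op/col2im.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"

std::shared_ptr<ov::op::v15::Col2Im> make_col2im_example() {
    // Column input: (N, C * kernel_h * kernel_w, L) = (1, 12, 225), i.e. 3 channels with a 2x2 kernel
    const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 12, 225});
    // Exactly two output spatial dimensions, unlike the N-D torch.nn.Fold
    const auto output_size = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {16, 16});
    const auto kernel_size = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {2, 2});
    return std::make_shared<ov::op::v15::Col2Im>(data,
                                                 output_size,
                                                 kernel_size,
                                                 ov::Strides{1, 1},
                                                 ov::Strides{1, 1},
                                                 ov::Shape{0, 0},
                                                 ov::Shape{0, 0});
}
```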

### Tickets:
 - CVS-138919

### Related PRs:
- #24197
- #23947
- #24569

---------

Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
p-wysocki and mlukasze authored May 21, 2024
1 parent 1aa2e17 commit fba3ec0
Showing 10 changed files with 462 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/core/include/openvino/opsets/opset15_tbl.hpp
@@ -16,3 +16,4 @@ _OPENVINO_OP_REG(ShapeOf, ov::op::v3)
_OPENVINO_OP_REG(ScatterNDUpdate, ov::op::v15)
_OPENVINO_OP_REG(EmbeddingBagPacked, ov::op::v15)
_OPENVINO_OP_REG(EmbeddingBagOffsets, ov::op::v15)
_OPENVINO_OP_REG(Col2Im, ov::op::v15)
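
With the registration above, opset15 can report and create the new op by name. A small sketch (assuming the usual string-based `ov::OpSet` lookup helpers):

```cpp
#include <iostream>

#include "openvino/opsets/opset.hpp"

int main() {
    const auto& opset15 = ov::get_opset15();
    // Expected to print "true" once Col2Im is part of the opset15 table
    std::cout << std::boolalpha << opset15.contains_type("Col2Im") << std::endl;
    return 0;
}
```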
75 changes: 75 additions & 0 deletions src/core/reference/include/openvino/reference/col2im.hpp
@@ -0,0 +1,75 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <algorithm>

#include "openvino/core/shape.hpp"

namespace ov {
namespace reference {
template <typename T, typename T_idx>
void col2im(const T* data,
            const Shape& data_shape,
            const T_idx* output_size,
            const T_idx* kernel_size,
            T* out,
            const Strides& strides,
            const Strides& dilations,
            const Shape& pads_begin,
            const Shape& pads_end) {
    // fill output with zeros to account for values missing due to dilation and stride
    const auto kernel_product = kernel_size[0] * kernel_size[1];
    const bool is_batched = data_shape.size() == 3;
    const int64_t C_idx = is_batched ? 1 : 0;
    const int64_t channels_per_column = data_shape[C_idx];
    const int64_t channel_count = channels_per_column / kernel_product;
    const int64_t batch_count = is_batched ? data_shape[0] : 1;
    std::fill_n(out, batch_count * output_size[0] * output_size[1] * channel_count, T(0));

    // calculate the original height and width
    auto get_original_dimension = [&](const int64_t idx) {
        return (output_size[idx] + pads_begin[idx] + pads_end[idx] - (dilations[idx] * (kernel_size[idx] - 1) + 1)) /
                   strides[idx] +
               1;
    };
    const int64_t original_height = get_original_dimension(0);
    const int64_t original_width = get_original_dimension(1);

    auto get_image_dimension_index = [&](const int64_t column_dim_idx, const int64_t dim_offset, const int64_t idx) {
        return column_dim_idx * strides[idx] - pads_begin[idx] + dim_offset * dilations[idx];
    };
    for (int64_t batch = 0; batch < batch_count; ++batch) {
        for (int64_t column = 0; column < channels_per_column; ++column) {
            const auto width_offset = column % kernel_size[1];
            const auto height_offset = (column / kernel_size[1]) % kernel_size[0];
            const auto channel_idx = column / kernel_product;

            for (int64_t column_height_idx = 0; column_height_idx < original_height; ++column_height_idx) {
                const int64_t image_height_idx = get_image_dimension_index(column_height_idx, height_offset, 0);
                if (image_height_idx >= 0 && image_height_idx < output_size[0]) {
                    for (int64_t column_width_idx = 0; column_width_idx < original_width; ++column_width_idx) {
                        const int64_t image_width_idx = get_image_dimension_index(column_width_idx, width_offset, 1);
                        if (image_width_idx >= 0 && image_width_idx < output_size[1]) {
                            const int64_t img_idx =
                                ((batch * channel_count + channel_idx) * output_size[0] + image_height_idx) *
                                    output_size[1] +
                                image_width_idx;
                            const int64_t data_idx =
                                ((batch * channels_per_column + column) * original_height + column_height_idx) *
                                    original_width +
                                column_width_idx;

                            // sum the overlapping values
                            out[img_idx] += data[data_idx];
                        }
                    }
                }
            }
        }
    }
}
}  // namespace reference
}  // namespace ov
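
A minimal, hypothetical usage sketch of the reference kernel above (assuming an OpenVINO developer build where the reference headers are on the include path): an unbatched `(C * 2 * 2, L) = (4, 9)` column tensor of ones is folded back into a single-channel 4x4 image, so overlapping 2x2 patch positions are summed.

```cpp
#include <cstdint>
#include <vector>

#include "openvino/core/shape.hpp"
#include "openvino/core/strides.hpp"
#include "openvino/reference/col2im.hpp"

int main() {
    // One channel, 2x2 kernel, 4x4 output, unit strides/dilations, no padding:
    // L = ((4 + 0 + 0 - (1 * (2 - 1) + 1)) / 1 + 1)^2 = 3 * 3 = 9 sliding blocks
    const std::vector<float> columns(4 * 9, 1.0f);
    const ov::Shape data_shape{4, 9};
    const std::vector<int64_t> output_size{4, 4};
    const std::vector<int64_t> kernel_size{2, 2};
    std::vector<float> image(4 * 4, 0.0f);

    ov::reference::col2im<float, int64_t>(columns.data(),
                                          data_shape,
                                          output_size.data(),
                                          kernel_size.data(),
                                          image.data(),
                                          ov::Strides{1, 1},
                                          ov::Strides{1, 1},
                                          ov::Shape{0, 0},
                                          ov::Shape{0, 0});

    // With all-ones input, each pixel accumulates one value per 2x2 window covering it:
    // corners end up as 1, edges as 2 and the central 2x2 block as 4
    return image[5] == 4.0f ? 0 : 1;
}
```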
24 changes: 11 additions & 13 deletions src/core/src/op/col2im.cpp
@@ -8,6 +8,7 @@
#include "itt.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/op.hpp"
#include "openvino/reference/col2im.hpp"

namespace ov {
namespace op {
@@ -42,20 +43,17 @@ void Col2Im::validate_and_infer_types() {

    const auto& data_element_type = get_input_element_type(0);
    const auto& output_size_element_type = get_input_element_type(1);
    const bool is_valid_output_size_type =
        output_size_element_type == element::i32 || output_size_element_type == element::i64;
    NODE_VALIDATION_CHECK(this,
                          is_valid_output_size_type,
                          "The element type of the output_size tensor must be i32 or i64 type. Got: ",
                          output_size_element_type);

    const auto& kernel_size_element_type = get_input_element_type(2);
    const bool is_valid_kernel_size_type =
        kernel_size_element_type == element::i32 || kernel_size_element_type == element::i64;
    NODE_VALIDATION_CHECK(this,
                          is_valid_kernel_size_type,
                          "The element type of the kernel_size tensor must be i32 or i64 type. Got: ",
                          kernel_size_element_type);
    const bool is_valid_index_type =
        (output_size_element_type == element::i32 || output_size_element_type == element::i64) &&
        output_size_element_type == kernel_size_element_type;
    NODE_VALIDATION_CHECK(
        this,
        is_valid_index_type,
        "The element types of the output_size and kernel_size tensors must match and be of i32 or i64 type. Got: ",
        output_size_element_type,
        " and ",
        kernel_size_element_type);

    const auto output_shapes = shape_infer(this, ov::util::get_node_input_partial_shapes(*this));
    set_output_type(0, data_element_type, output_shapes[0]);
2 changes: 1 addition & 1 deletion src/core/tests/opset.cpp
@@ -75,7 +75,7 @@ INSTANTIATE_TEST_SUITE_P(opset,
OpsetTestParams{ov::get_opset12, 178},
OpsetTestParams{ov::get_opset13, 186},
OpsetTestParams{ov::get_opset14, 189},
OpsetTestParams{ov::get_opset15, 6}),
OpsetTestParams{ov::get_opset15, 7}),
OpsetTestNameGenerator{});

class MyOpOld : public ov::op::Op {
13 changes: 11 additions & 2 deletions src/core/tests/type_prop/col2im.cpp
@@ -68,17 +68,26 @@ TEST_F(TypePropCol2ImTest, incorrect_types) {
    const auto data = std::make_shared<Parameter>(element::i32, PartialShape{3, 12, 225});
    const auto output_size = std::make_shared<Parameter>(element::i64, PartialShape{2});
    const auto kernel_size = std::make_shared<Parameter>(element::i64, PartialShape{2});
    constexpr auto error_substring =
        "The element types of the output_size and kernel_size tensors must match and be of i32 or i64 type";
    {
        const auto output_size_i4 = std::make_shared<Parameter>(element::i4, PartialShape{16, 16});
        OV_EXPECT_THROW(std::ignore = make_op(data, output_size_i4, kernel_size),
                        ov::NodeValidationFailure,
                        HasSubstr("The element type of the output_size tensor must be i32 or i64 type"));
                        HasSubstr(error_substring));
    }
    {
        const auto kernel_size_u8 = std::make_shared<Parameter>(element::u8, PartialShape{2, 2});
        OV_EXPECT_THROW(std::ignore = make_op(data, output_size, kernel_size_u8),
                        ov::NodeValidationFailure,
                        HasSubstr("The element type of the kernel_size tensor must be i32 or i64 type"));
                        HasSubstr(error_substring));
    }
    {
        const auto output_size_i32 = std::make_shared<Parameter>(element::i32, PartialShape{16, 16});
        const auto kernel_size_i64 = std::make_shared<Parameter>(element::i64, PartialShape{2, 2});
        OV_EXPECT_THROW(std::ignore = make_op(data, output_size_i32, kernel_size_i64),
                        ov::NodeValidationFailure,
                        HasSubstr(error_substring));
    }
}

80 changes: 80 additions & 0 deletions src/plugins/template/backend/ops/col2im.cpp
@@ -0,0 +1,80 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "openvino/reference/col2im.hpp"

#include "col2im_shape_inference.hpp"
#include "element_visitor.hpp"
#include "evaluate_node.hpp"

template <ov::element::Type_t ET_data, ov::element::Type_t ET_idx>
bool evaluate_index_type(const std::shared_ptr<ov::op::v15::Col2Im>& op,
                         ov::TensorVector& outputs,
                         const ov::TensorVector& inputs) {
    using T_data = typename ov::element_type_traits<ET_data>::value_type;
    using T_idx = typename ov::element_type_traits<ET_idx>::value_type;
    const std::vector<ov::PartialShape> input_shapes{op->get_input_shape(0),
                                                     op->get_input_shape(1),
                                                     op->get_input_shape(2)};
    const auto output_shape =
        ov::op::v15::shape_infer(op.get(), input_shapes, make_tensor_accessor(inputs)).front().to_shape();
    outputs.front().set_shape(output_shape);
    ov::reference::col2im(inputs[0].data<const T_data>(),
                          inputs[0].get_shape(),
                          inputs[1].data<const T_idx>(),
                          inputs[2].data<const T_idx>(),
                          outputs[0].data<T_data>(),
                          op->get_strides(),
                          op->get_dilations(),
                          op->get_pads_begin(),
                          op->get_pads_end());
    return true;
}

template <ov::element::Type_t ET_data>
bool evaluate_data_type(const std::shared_ptr<ov::op::v15::Col2Im>& op,
                        ov::TensorVector& outputs,
                        const ov::TensorVector& inputs) {
    const auto& index_type = op->get_input_element_type(1);
    using ov::op::v15::Col2Im;
    using namespace ov::element;
    switch (index_type) {
    case i32:
        return evaluate_index_type<ET_data, i32>(ov::as_type_ptr<Col2Im>(op), outputs, inputs);
    case i64:
        return evaluate_index_type<ET_data, i64>(ov::as_type_ptr<Col2Im>(op), outputs, inputs);
    default:
        OPENVINO_THROW("Unhandled index type ", index_type, " in evaluate_node()");
    }
}

template <>
bool evaluate_node<ov::op::v15::Col2Im>(std::shared_ptr<ov::Node> node,
                                        ov::TensorVector& outputs,
                                        const ov::TensorVector& inputs) {
    const auto& element_type = node->get_output_element_type(0);

    using ov::op::v15::Col2Im;
    using namespace ov::element;
    switch (element_type) {
    case i8:
        return evaluate_data_type<i8>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case i32:
        return evaluate_data_type<i32>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case i64:
        return evaluate_data_type<i64>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case u8:
        return evaluate_data_type<u8>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case u32:
        return evaluate_data_type<u32>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case u64:
        return evaluate_data_type<u64>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case f16:
        return evaluate_data_type<f16>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    case f32:
        return evaluate_data_type<f32>(ov::as_type_ptr<Col2Im>(node), outputs, inputs);
    default:
        OPENVINO_THROW("Unhandled data type ", element_type, " in evaluate_node()");
    }
}
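
Not from the diff: a sketch of how this evaluator could be exercised end-to-end, assuming the template plugin is built and registered under the "TEMPLATE" device name and that the Col2Im constructor matches the sketch in the description above.

```cpp
#include <cstdint>
#include <memory>
#include <vector>

#include "openvino/core/model.hpp"
#include "openvino/op/col2im.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/runtime/core.hpp"
#include "openvino/runtime/tensor.hpp"

int main() {
    // Batched column input: (N, C * kernel_h * kernel_w, L) = (1, 4, 9)
    const auto data = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::Shape{1, 4, 9});
    const auto output_size = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {4, 4});
    const auto kernel_size = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{2}, {2, 2});
    const auto col2im = std::make_shared<ov::op::v15::Col2Im>(data,
                                                              output_size,
                                                              kernel_size,
                                                              ov::Strides{1, 1},
                                                              ov::Strides{1, 1},
                                                              ov::Shape{0, 0},
                                                              ov::Shape{0, 0});
    const auto model = std::make_shared<ov::Model>(ov::OutputVector{col2im}, ov::ParameterVector{data});

    ov::Core core;
    auto compiled = core.compile_model(model, "TEMPLATE");
    auto request = compiled.create_infer_request();

    std::vector<float> columns(1 * 4 * 9, 1.0f);
    request.set_input_tensor(ov::Tensor(ov::element::f32, ov::Shape{1, 4, 9}, columns.data()));
    request.infer();

    // The f32 data / i64 index combination above dispatches to evaluate_index_type<f32, i64>
    const auto output = request.get_output_tensor();
    return output.get_shape() == ov::Shape{1, 1, 4, 4} ? 0 : 1;
}
```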
4 changes: 4 additions & 0 deletions src/plugins/template/backend/ops/ops_evaluates.hpp
@@ -494,6 +494,10 @@ extern template bool evaluate_node<ov::op::v14::Inverse>(std::shared_ptr<ov::Nod
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v15::Col2Im>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);

extern template bool evaluate_node<ov::op::v14::ROIAlignRotated>(std::shared_ptr<ov::Node> node,
ov::TensorVector& outputs,
const ov::TensorVector& inputs);
1 change: 1 addition & 0 deletions src/plugins/template/backend/opset_int_tbl.hpp
@@ -166,6 +166,7 @@ _OPENVINO_OP_REG(ROIAlignRotated, ov::op::v14)

_OPENVINO_OP_REG(EmbeddingBagOffsets, op::v15)
_OPENVINO_OP_REG(EmbeddingBagPacked, op::v15)
_OPENVINO_OP_REG(Col2Im, ov::op::v15)

_OPENVINO_OP_REG(AUGRUCell, ov::op::internal)
_OPENVINO_OP_REG(AUGRUSequence, ov::op::internal)