Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CORE] Change get_input_const_data_as return type from std::unique_ptr to ov::optional #24443

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ std::vector<TRShape> shape_infer(const util::FFTBase* op,

util::fft_common_validation::shape_validation(op,
input_shapes,
axes.get(),
axes,
util::fft_common_validation::FFTKind::ComplexInput);

output_shape = input_shape;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ void validate_signal_size(const ov::op::util::FFTBase* op,
template <class T>
void shape_validation(const ov::op::util::FFTBase* op,
const std::vector<T>& input_shapes,
std::vector<int64_t>* axes,
ov::optional<std::vector<int64_t>>& axes,
FFTKind fft_kind) {
const auto& input_shape = input_shapes[0];
const auto& axes_shape = input_shapes[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -153,20 +153,16 @@ TRShape make_padded_shape(const TShape& input, TInputIter pads_begin, TInputIter
* @return Not null pointer with axes values or null pointer if can't get axes from input.
*/
template <class TShape, class TRes = std::vector<int64_t>>
std::unique_ptr<TRes> get_axes(const Node* const op,
size_t port,
bool has_axes,
size_t rank,
const ITensorAccessor& ta) {
std::unique_ptr<TRes> axes;
ov::optional<TRes> get_axes(const Node* const op, size_t port, bool has_axes, size_t rank, const ITensorAccessor& ta) {
ov::optional<TRes> axes;
if (has_axes) {
using TAxis = typename TRes::value_type;
axes = std::move(get_input_const_data_as<TShape, TAxis, TRes>(op, port, ta));
if (axes) {
validate::axes_values(op, *axes, rank);
}
} else {
axes.reset(new TRes(rank));
axes.emplace(rank);
std::iota(axes->begin(), axes->end(), 0);
}
return axes;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ std::vector<TRShape> shape_infer(const IRDFT* op,

util::fft_common_validation::shape_validation(op,
input_shapes,
axes.get(),
axes,
util::fft_common_validation::FFTKind::ComplexInput);

if (input_shape.rank().is_dynamic()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::vector<TRShape> shape_infer(const RDFT* op,

util::fft_common_validation::shape_validation(op,
input_shapes,
axes.get(),
axes,
util::fft_common_validation::FFTKind::RealInput);

if (input_shape.rank().is_dynamic()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ std::vector<TRShape> reduce_shape_infer(const util::ReductionBase* op,
"Axes input must be a scalar or 1D input. Got: ",
axes_shape);

const auto axes_val = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor);
auto axes_val = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor);

if (data_rank.is_static() && axes_val) {
ov::util::normalize_axes(op, data_rank.get_length(), *axes_val);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ std::vector<TRShape> shape_infer(const RMSNorm* op,
}

// Axes values validation
if (const auto axes_val = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor)) {
if (auto axes_val = ov::op::get_input_const_data_as<TRShape, int64_t>(op, 1, tensor_accessor)) {
ov::util::normalize_axes(op, data_rank.get_length(), *axes_val);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ std::vector<TRShape> shape_infer(const Roll* op,
"Axes must be a scalar or 1D tensor.");

if (data_pshape.rank().is_static()) {
if (const auto axes = get_input_const_data_as<TRShape, int64_t>(op, 2, ta)) {
if (auto axes = get_input_const_data_as<TRShape, int64_t>(op, 2, ta)) {
ov::util::normalize_axes(op, data_pshape.size(), *axes);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@ std::vector<TRShape> shape_infer(const StridedSlice* op,
const auto begin = get_input_bounds<TRShape, int64_t>(op, 1, ta);
const auto end = get_input_bounds<TRShape, int64_t>(op, 2, ta);

std::unique_ptr<std::vector<int64_t>> strides;
ov::optional<std::vector<int64_t>> strides;
if (input_shapes.size() > 3) {
strides = get_input_const_data_as<TRShape, int64_t>(op, 3, ta);
} else if (begin) {
// generate default strides
strides.reset(new std::vector<int64_t>(begin->size(), 1));
strides.emplace(begin->size(), 1);
}

// compute and check a number of axes for which begin, end, and strides are defined
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ std::vector<TRShape> shape_infer(const Transpose* op,
"Input order must have shape [n], where n is the rank of arg.");
}

const auto axes = get_input_const_data_as<TShape, int64_t>(op, Transpose::ORDER, tensor_accessor);
auto axes = get_input_const_data_as<TShape, int64_t>(op, Transpose::ORDER, tensor_accessor);

auto output_shapes = std::vector<TRShape>();
if (axes && input_rank.is_static()) {
Expand Down
36 changes: 18 additions & 18 deletions src/core/shape_inference/include/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,21 +204,21 @@ template <class TShape,
class TRes = std::vector<TData>,
class UnaryOperation = ov::util::Cast<TData>,
typename std::enable_if<!std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const ITensorAccessor& tensor_accessor,
UnaryOperation&& func = ov::util::Cast<TData>()) {
ov::optional<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const ITensorAccessor& tensor_accessor,
UnaryOperation&& func = ov::util::Cast<TData>()) {
if (auto t = tensor_accessor(idx)) {
return std::unique_ptr<TRes>(new TRes(get_tensor_data_as<TData, TRes>(t, std::forward<UnaryOperation>(func))));
return {get_tensor_data_as<TData, TRes>(t, std::forward<UnaryOperation>(func))};
} else {
const auto& constant = ov::as_type_ptr<ov::opset1::Constant>(op->get_input_node_shared_ptr(idx));
NODE_VALIDATION_CHECK(op, constant != nullptr, "Static shape inference lacks constant data on port ", idx);
const auto& et = constant->get_element_type();
const auto& shape = constant->get_shape();
return std::unique_ptr<TRes>(new TRes(get_raw_data_as<TData, TRes>(et,
constant->get_data_ptr(),
shape_size(shape),
std::forward<UnaryOperation>(func))));
return {get_raw_data_as<TData, TRes>(et,
constant->get_data_ptr(),
shape_size(shape),
std::forward<UnaryOperation>(func))};
}
}

Expand All @@ -245,20 +245,20 @@ template <class TShape,
class TRes = std::vector<TData>,
class UnaryOperation = ov::util::Cast<TData>,
typename std::enable_if<std::is_same<TShape, ov::PartialShape>::value>::type* = nullptr>
std::unique_ptr<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const ITensorAccessor& tensor_accessor,
UnaryOperation&& func = ov::util::Cast<TData>()) {
ov::optional<TRes> get_input_const_data_as(const ov::Node* op,
size_t idx,
const ITensorAccessor& tensor_accessor,
UnaryOperation&& func = ov::util::Cast<TData>()) {
if (auto t = tensor_accessor(idx)) {
return std::unique_ptr<TRes>(new TRes(get_tensor_data_as<TData, TRes>(t, std::forward<UnaryOperation>(func))));
return {get_tensor_data_as<TData, TRes>(t, std::forward<UnaryOperation>(func))};
} else if (const auto& constant =
(idx < op->get_input_size()) ? ov::util::get_constant_from_source(op->input_value(idx)) : nullptr) {
const auto& et = constant->get_element_type();
const auto& shape = constant->get_shape();
return std::unique_ptr<TRes>(new TRes(get_raw_data_as<TData, TRes>(et,
constant->get_data_ptr(),
shape_size(shape),
std::forward<UnaryOperation>(func))));
return {get_raw_data_as<TData, TRes>(et,
constant->get_data_ptr(),
shape_size(shape),
std::forward<UnaryOperation>(func))};
} else {
return {};
}
Expand Down
Loading