
[PT FE]: support aten::broadcast_tensors
eaidova committed Sep 21, 2023
1 parent f1823b8 commit 4c0999b
Showing 4 changed files with 69 additions and 51 deletions.
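For context, torch.broadcast_tensors expands each of its arguments to the shape they all broadcast to. A minimal PyTorch sketch of the behavior this commit adds support for (the shapes below are illustrative, not taken from the commit):

import torch

x = torch.zeros(2, 1)
y = torch.zeros(1, 2)
z = torch.zeros(2)

# Each output is a view of the corresponding input expanded to the
# common broadcast shape, here [2, 2].
x1, y1, z1 = torch.broadcast_tensors(x, y, z)
print(x1.shape, y1.shape, z1.shape)  # torch.Size([2, 2]) three times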
49 changes: 0 additions & 49 deletions src/frontends/pytorch/src/op/broadcast_tensors.cpp

This file was deleted.

2 changes: 0 additions & 2 deletions src/frontends/pytorch/src/op_table.cpp
@@ -35,7 +35,6 @@ OP_CONVERTER(translate_batch_norm);
OP_CONVERTER(translate_bitwise_and);
OP_CONVERTER(translate_bitwise_not);
OP_CONVERTER(translate_bitwise_or);
OP_CONVERTER(translate_broadcast_tensors);
OP_CONVERTER(translate_cat);
OP_CONVERTER(translate_cdist);
OP_CONVERTER(translate_clamp);
@@ -250,7 +249,6 @@ const std::map<std::string, CreatorFunction> get_supported_ops_ts() {
{"aten::batch_norm", op::translate_batch_norm},
{"aten::bitwise_not", op::translate_bitwise_not},
{"aten::bmm", op::translate_1to1_match_2_inputs<opset10::MatMul>},
{"aten::broadcast_tensors", op::translate_broadcast_tensors},
{"aten::Bool", op::translate_bool},
{"aten::cat", op::translate_cat},
{"aten::concat", op::translate_cat},
24 changes: 24 additions & 0 deletions src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
@@ -173,6 +173,30 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
}
}

if (auto broadcast_tensors = cast_fw_node(input_node, "aten::broadcast_tensors")) {
    auto tensors = cast_fw_node(broadcast_tensors->input_value(0).get_node_shared_ptr(), "prim::ListConstruct");
    if (!tensors) {
        add_exception_to_fw_node(input_node,
                                 "aten::broadcast_tensors: only prim::ListConstruct supported as input.");
        return false;
    }
    auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
    Output<Node> final_shape_t = zero;
    for (auto input : tensors->inputs()) {
        auto tensor_shape = std::make_shared<opset10::ShapeOf>(input.get_source_output());
        auto zero_broadcasted =
            std::make_shared<opset10::Broadcast>(zero, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL);
        final_shape_t = std::make_shared<opset10::Add>(final_shape_t, zero_broadcasted);
    }
    auto final_shape = std::make_shared<opset10::ShapeOf>(final_shape_t, element::i32);
    OutputVector outputs;
    for (auto input : tensors->inputs()) {
        outputs.push_back(std::make_shared<opset10::Broadcast>(input.get_source_output(),
                                                               final_shape,
                                                               ov::op::BroadcastType::BIDIRECTIONAL));
    }
    replace_node(list_unpack, outputs);
    return true;
}

if (auto unbind = cast_fw_node(input_node, "aten::unbind")) {
const auto input = unbind->get_input_source_output(0);
const auto axis = unbind->get_input_source_output(1);
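The new aten::broadcast_tensors branch above derives the common broadcast shape numerically: a scalar i32 zero constant is broadcast bidirectionally to each input's shape, the results are accumulated with Add so the accumulator ends up with the common broadcast shape, and ShapeOf of that accumulator then drives a bidirectional Broadcast of every input. A rough NumPy sketch of the same trick (variable names and shapes are illustrative, not part of the commit):

import numpy as np

tensors = [np.zeros((2, 1)), np.zeros((1, 2)), np.zeros((2,))]

# Scalar zero accumulator, analogous to the i32 zero Constant in the graph.
acc = np.zeros(())
for t in tensors:
    # Broadcasting zero to each input's shape and adding grows acc's shape
    # to the common broadcast shape of all inputs seen so far.
    acc = acc + np.broadcast_to(0, t.shape)

final_shape = acc.shape  # what ShapeOf extracts in the converted graph
outputs = [np.broadcast_to(t, final_shape) for t in tensors]
print(final_shape, [o.shape for o in outputs])  # (2, 2) for every output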
45 changes: 45 additions & 0 deletions tests/layer_tests/pytorch_tests/test_broadcast_tensors.py
@@ -0,0 +1,45 @@
# Copyright (C) 2018-2023 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

import pytest

from pytorch_layer_test_class import PytorchLayerTest


class TestBroadcastTensors(PytorchLayerTest):
    def _prepare_input(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype):
        import numpy as np
        return (
            np.random.randn(*x_shape).astype(x_dtype),
            np.random.randn(*y_shape).astype(y_dtype),
            np.random.randn(*z_shape).astype(z_dtype))

    def create_model(self):
        import torch

        class aten_broadcast_tensors(torch.nn.Module):
            def __init__(self):
                super(aten_broadcast_tensors, self).__init__()

            def forward(self, x, y, z):
                x1, y1, z1 = torch.broadcast_tensors(x, y, z)
                return x1, y1, z1

        ref_net = None

        return aten_broadcast_tensors(), ref_net, ("prim::ListConstruct", "aten::broadcast_tensors", "prim::ListUnpack")

    @pytest.mark.nightly
    @pytest.mark.precommit
    @pytest.mark.parametrize("x_shape", [[1, ], [2, 1], [2, 2, 1]])
    @pytest.mark.parametrize("y_shape", [[2, ], [1, 2], [1, 2, 1]])
    @pytest.mark.parametrize("z_shape", [[1, 2], [2, 2], [1, 2, 1, 1]])
    @pytest.mark.parametrize("x_dtype", ["float32", "int32"])
    @pytest.mark.parametrize("y_dtype", ["float32", "int32"])
    @pytest.mark.parametrize("z_dtype", ["float32", "int32"])
    def test_broadcast_tensors(self, x_shape, y_shape, z_shape, x_dtype, y_dtype, z_dtype, ie_device, precision, ir_version):
        self._test(*self.create_model(), ie_device, precision, ir_version, kwargs_to_prepare_input={
            "x_shape": x_shape, "x_dtype": x_dtype,
            "y_shape": y_shape, "y_dtype": y_dtype,
            "z_shape": z_shape, "z_dtype": z_dtype,
        })
