Skip to content

Commit

Permalink
apply review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Sep 21, 2023
1 parent ff102c3 commit 60f5d01
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/frontends/pytorch/src/transforms/prim_list_unpack_replacer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,21 +180,17 @@ PrimListUnpackReplacer::PrimListUnpackReplacer() {
"aten::broadcast_tensors: only prim::ListConstruct supported as input.");
return false;
}
auto zero = opset10::Constant::create(element::i32, Shape{}, {0});
Output<Node> final_shape_t = zero;
Output<Node> final_shape_t = opset10::Constant::create(element::f32, Shape{}, {0});
for (auto input : tensors->inputs()) {
auto tensor_shape = std::make_shared<opset10::ShapeOf>(input.get_source_output());
auto zero_broadcasted =
std::make_shared<opset10::Broadcast>(zero, tensor_shape, ov::op::BroadcastType::BIDIRECTIONAL);
final_shape_t = std::make_shared<opset10::Add>(final_shape_t, zero_broadcasted);
auto tensor = rg.make<opset10::Convert>(input.get_source_output(), element::f32);
final_shape_t = rg.make<opset10::Add>(final_shape_t, tensor);
}
auto final_shape = std::make_shared<opset10::ShapeOf>(final_shape_t, element::i32);
auto final_shape = rg.make<opset10::ShapeOf>(final_shape_t, element::i32);
OutputVector outputs;
for (auto input : tensors->inputs()) {
outputs.push_back(std::make_shared<opset10::Broadcast>(input.get_source_output(),
final_shape,
ov::op::BroadcastType::BIDIRECTIONAL));
outputs.push_back(rg.make<opset10::Broadcast>(input.get_source_output(), final_shape));
}
copy_runtime_info_and_name(list_unpack, rg.get(), {input_node});
replace_node(list_unpack, outputs);
return true;
}
Expand Down

0 comments on commit 60f5d01

Please sign in to comment.