Skip to content

Commit

Permalink
Compile
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Sep 22, 2023
1 parent 38b29c6 commit 0969eea
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 15 deletions.
1 change: 0 additions & 1 deletion onnx/test/automatic_upgrade_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,6 @@ def test_HannWindow(self) -> None:
def test_HammingWindow(self) -> None:
self._test_window_function("HammingWindow")

@pytest.mark.xfail(reason="FIXME(#5613): Implement version converters for DFT")
def test_DFT(self) -> None:
self._test_op_upgrade("DFT", 17, [[2, 16, 1], []], [[2, 16, 2]])
self._test_op_upgrade("DFT", 17, [[2, 16, 2], []], [[2, 16, 2]])
Expand Down
2 changes: 1 addition & 1 deletion onnx/version_converter/adapters/axis_attribute_to_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class AxisAttributeToInput : public Adapter {

// Add the optional inputs if they don't exist
while (inputs.size() < axis_index) {
const empty_input = graph->create(kUndefined);
Node* empty_input = graph->create(kUndefined);
empty_input->insertBefore(node);
node->addInput(empty_input->output());
}
Expand Down
21 changes: 10 additions & 11 deletions onnx/version_converter/adapters/axis_input_to_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,33 +34,32 @@ class AxisInputToAttribute : public Adapter {
const ArrayRef<Value*>& inputs = node->inputs();

// 1. Handle when axis is not given
if !(inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined) {
if (!(inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined)) {
node->i_(kaxis, this->default_axis);
return EnsureAndReturnNode(node);
}

// 2. Get axis from constant operator
Value* index_val = inputs[this->axis_index];
Node* node = index_val->node();
Value* axis_val = inputs[this->axis_index];
Node* axis_node = axis_val->node();
// Identify whether we have a Constant Op or an Initializer
if (node->kind() == kConstant) {
if (axis_node->kind() == kConstant) {
// Get value attribute of kConstant
const std::vector<int64_t>& int64s = node->t(kvalue).int64s();
const std::vector<int64_t>& int64s = axis_node->t(kvalue).int64s();
if (int64s.empty()) {
// Also handle raw data
std::string raw_data = node->t(kvalue).raw();
std::string raw_data = axis_node->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
// TODO(justinchuby): Why cast to char* first?
int64_t* raw = const_cast<int64_t*>(const_cast<char*>(raw_data.c_str()));
const int64_t* raw = reinterpret_cast<int64_t*>(const_cast<char*>(raw_data.c_str()));
node->i_(kaxis, static_cast<int64_t>(raw[0]));
} else {
node->i_(kaxis, int64s.at(0));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(this->axis_index);
if (index_val->uses().size() < 1) {
if (axis_val->uses().size() < 1) {
node->destroy();
}
return EnsureAndReturnNode(node);
Expand All @@ -73,8 +72,8 @@ class AxisInputToAttribute : public Adapter {
node->i_(kaxis, initializer.int64s().at(0));
node->removeInput(this->axis_index);
// Remove initializer
if (index_val->uses().size() < 1)
graph->eraseInitializerAndInput(index_val);
if (axis_val->uses().size() < 1)
graph->eraseInitializerAndInput(axis_val);
break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions onnx/version_converter/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -566,14 +566,14 @@ class DefaultVersionConverter : public BaseVersionConverter {
registerAdapter(std::make_unique<CompatibleAdapter>("Size", OpSetID(18), OpSetID(19)));

/******** 19 -> 20 ********/
registerAdapter(std::make_unique<AxisAttributeToInput>("DFT", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<AxisAttributeToInput>("DFT", OpSetID(19), OpSetID(20), 2, 1));
registerAdapter(std::make_unique<CompatibleAdapter>("ConstantOfShape", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMax", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<CompatibleAdapter>("ReduceMin", OpSetID(19), OpSetID(20)));
registerAdapter(std::make_unique<GridSample_19_20>());

/******** 20 -> 19 ********/
registerAdapter(std::make_unique<AxisInputToAttribute>("DFT", OpSetID(20), OpSetID(19)));
registerAdapter(std::make_unique<AxisInputToAttribute>("DFT", OpSetID(20), OpSetID(19), 2, -2));
const std::vector<TensorProto_DataType> reduce_min_max_18_unallowed_types = {TensorProto_DataType_BOOL};
registerAdapter(
std::make_unique<TypeRestriction>("ReduceMax", OpSetID(20), OpSetID(19), reduce_min_max_18_unallowed_types));
Expand Down

0 comments on commit 0969eea

Please sign in to comment.