Skip to content

Commit

Permalink
Update axis_input_to_attribute.h
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Sep 22, 2023
1 parent 74bb39c commit 482144b
Showing 1 changed file with 26 additions and 24 deletions.
50 changes: 26 additions & 24 deletions onnx/version_converter/adapters/axis_input_to_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,60 +27,62 @@ class AxisInputToAttribute : public Adapter {
const OpSetID& target,
int64_t axis_index,
int64_t default_axis)
: Adapter(op_name, initial, target) {
this->axis_index = axis_index;
this->default_axis = default_axis;
}
: Adapter(op_name, initial, target), axis_index(axis_index), default_axis(default_axis) {}

Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
// Identify if axis is statically determined; if so, feed as attribute
const ArrayRef<Value*>& inputs = node->inputs();

// Handle when axis is not given
if !(inputs.size() > this->axis_index && inputs[this->axis_index]->node()->kind() != kUndefined) {
node->i_(kaxis, this->default_axis);
return node;
}

// Get axis from initializer or constant operator
// Identify whether we have a Constant Op or an Initializer
// TODO(justinchuby): Avoid segfault
Value* const_val = inputs[this->axis_index];
// TODO(justinchuby): How do I check if axis is empty? In which cases it should take default
Node* node_ptr = const_val->node();
if (node_ptr->kind() == kConstant) {
Value* index_val = inputs[this->axis_index];
Node* node = index_val->node();
if (node->kind() == kConstant) {
// Get value attribute of kConstant
const std::vector<int64_t>& int64s = node_ptr->t(kvalue).int64s();
const std::vector<int64_t>& int64s = node->t(kvalue).int64s();
if (int64s.empty()) {
// Also handle raw data
std::string raw_data = node_ptr->t(kvalue).raw();
std::string raw_data = node->t(kvalue).raw();
ONNX_ASSERTM(
raw_data.size() != 0 && raw_data.size() % 8 == 0,
"Raw Data must be non-empty and size must be a multiple of 8");
int64_t* raw = (int64_t*)const_cast<char*>(raw_data.c_str());
// FIXME(justinchuby): Make sure this logic is correct
// TODO(justinchuby): Why cast to char* first?
int64_t* raw = const_cast<int64_t*>(const_cast<char*>(raw_data.c_str()));
node->i_(kaxis, static_cast<int64_t>(raw[0]));
} else {
node->i_(kaxis, static_cast<int64_t>(int64s.at(0)));
node->i_(kaxis, int64s.at(0));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(this->axis_index);
if (const_val->uses().size() < 1) {
node_ptr->destroy();
if (index_val->uses().size() < 1) {
node->destroy();
}
} else {
// Get Value name, find Initializer with same name
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == inputs[1]->uniqueName()) {
node->i_(kaxis, static_cast<int64_t>(initializer.int64s().at(0)));
if (initializer.name() == inputs[this->axis_index]->uniqueName()) {
node->i_(kaxis, initializer.int64s().at(0));
node->removeInput(this->axis_index);
// Remove initializer
if (const_val->uses().size() < 1)
graph->eraseInitializerAndInput(const_val);
if (index_val->uses().size() < 1)
graph->eraseInitializerAndInput(index_val);
break;
}
}
}
ONNX_ASSERTM(node->hasAttribute(kaxis), "No initializer or constant input to node found");
ONNX_ASSERTM(node->hasAttribute(kaxis), "Axis attribute not created. This may be a bug.");
return node;
}

private:
int64_t axis_index;
int64_t default_axis;
private:
int64_t axis_index;
int64_t default_axis;
};

} // namespace version_conversion
Expand Down

0 comments on commit 482144b

Please sign in to comment.