Skip to content

Commit

Permalink
Snapshot
Browse files Browse the repository at this point in the history
Signed-off-by: Justin Chu <justinchu@microsoft.com>
  • Loading branch information
justinchuby committed Sep 19, 2023
1 parent 4d63e63 commit b2a07da
Showing 1 changed file with 22 additions and 5 deletions.
27 changes: 22 additions & 5 deletions onnx/version_converter/adapters/axis_input_to_attribute.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,28 @@ namespace version_conversion {

class AxisInputToAttribute : public Adapter {
public:
explicit AxisInputToAttribute(const std::string& op_name, const OpSetID& initial, const OpSetID& target)
: Adapter(op_name, initial, target) {}
// Convert axis from input to attribute.
// axis_index: index of the axis input
// default_axis: default value of axis
explicit AxisInputToAttribute(
const std::string& op_name,
const OpSetID& initial,
const OpSetID& target,
int64_t axis_index,
int64_t default_axis)
: Adapter(op_name, initial, target) {
this->axis_index = axis_index;
this->default_axis = default_axis;
}

Node* adapt(std::shared_ptr<Graph> graph, Node* node) const override {
// Identify if axis is statically determined; if so, feed as attribute
const ArrayRef<Value*>& inputs = node->inputs();
// Get axis from initializer or constant operator
// Identify whether we have a Constant Op or an Initializer
Value* const_val = inputs[1];
// TODO(justinchuby): Avoid segfault
Value* const_val = inputs[this->axis_index];
// TODO(justinchuby): How do I check if axis is empty? In which cases it should take default
Node* node_ptr = const_val->node();
if (node_ptr->kind() == kConstant) {
// Get value attribute of kConstant
Expand All @@ -44,7 +57,7 @@ class AxisInputToAttribute : public Adapter {
node->i_(kaxis, static_cast<int64_t>(int64s.at(0)));
}
// If Constant node isn't used anywhere else, remove it
node->removeInput(1);
node->removeInput(this->axis_index);
if (const_val->uses().size() < 1) {
node_ptr->destroy();
}
Expand All @@ -53,7 +66,7 @@ class AxisInputToAttribute : public Adapter {
for (const auto& initializer : graph->initializers()) {
if (initializer.name() == inputs[1]->uniqueName()) {
node->i_(kaxis, static_cast<int64_t>(initializer.int64s().at(0)));
node->removeInput(1);
node->removeInput(this->axis_index);
// Remove initializer
if (const_val->uses().size() < 1)
graph->eraseInitializerAndInput(const_val);
Expand All @@ -64,6 +77,10 @@ class AxisInputToAttribute : public Adapter {
ONNX_ASSERTM(node->hasAttribute(kaxis), "No initializer or constant input to node found");
return node;
}

private:
int64_t axis_index;
int64_t default_axis;
};

} // namespace version_conversion
Expand Down

0 comments on commit b2a07da

Please sign in to comment.