Skip to content

Commit

Permalink
Fix computation of num_inputs for Python API create_operator_entry fo…
Browse files Browse the repository at this point in the history
…r custom operators with 0 arguments. (apache#7967)
  • Loading branch information
fhieber authored and crazy-cat committed Oct 26, 2017
1 parent 2db1ac4 commit 8cb6021
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/operator/custom/custom.cc
Expand Up @@ -225,9 +225,8 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
const std::vector<int>& in_type) {
const CustomParam& params = nnvm::get<CustomParam>(attrs.parsed);

size_t total = params.num_args + params.num_outs + params.num_auxs;
std::vector<uint32_t*> shapes(total);
std::vector<int> ndims(total);
std::vector<uint32_t*> shapes(params.num_args);
std::vector<int> ndims(params.num_args);
size_t buff_size = 0;
for (const auto& i : in_shape) buff_size += i.ndim();
std::vector<uint32_t> buff(buff_size);
Expand All @@ -246,7 +245,7 @@ OpStatePtr CreateState(const NodeAttrs& attrs, Context ctx,
MXCallbackList *op_info = new MXCallbackList;
CHECK(reinterpret_cast<CustomOpCreateFunc>(
params.info->callbacks[kCustomOpPropCreateOperator])(
os.str().c_str(), shapes.size(), shapes.data(), ndims.data(), in_type.data(),
os.str().c_str(), params.num_args, shapes.data(), ndims.data(), in_type.data(),
op_info, params.info->contexts[kCustomOpPropCreateOperator]));

CustomParam state = params;
Expand Down

0 comments on commit 8cb6021

Please sign in to comment.