Skip to content

Commit

Permalink
Simplify OpInferInputListAttrs to only pass dtypes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 288713492
Change-Id: I16b78a0cdb8919e45450f1c1d9c4d9e09e56b97f
  • Loading branch information
jaingaurav authored and tensorflower-gardener committed Jan 8, 2020
1 parent 2886154 commit 37e3630
Showing 1 changed file with 14 additions and 15 deletions.
29 changes: 14 additions & 15 deletions tensorflow/c/eager/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -638,34 +638,28 @@ tensorflow::Status OpInferSingleInputAttrs(TFE_Op* op,

void OpInferSingleTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs,
const tensorflow::DataType dtype,
int num_inputs) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.number_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.number_attr(), num_inputs);
ictx->attrs.insert(input_def.number_attr());
}
if (ictx->attrs.find(input_def.type_attr()) == ictx->attrs.end()) {
op->operation.MutableAttrs()->Set(input_def.type_attr(),
inputs[0]->handle->dtype);
op->operation.MutableAttrs()->Set(input_def.type_attr(), dtype);
ictx->attrs.insert(input_def.type_attr());
}
}

void OpInferMixedTypeInputListAttrs(TFE_Op* op,
const tensorflow::OpDef::ArgDef& input_def,
TFE_TensorHandle** inputs, int num_inputs) {
void OpInferMixedTypeInputListAttrs(
TFE_Op* op, const tensorflow::OpDef::ArgDef& input_def,
const std::vector<tensorflow::DataType>& dtypes) {
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
if (ictx->attrs.find(input_def.type_list_attr()) == ictx->attrs.end()) {
std::unique_ptr<tensorflow::DataType[]> dtypes(
new tensorflow::DataType[num_inputs]);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
op->operation.MutableAttrs()->Set(
input_def.type_list_attr(),
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.get(),
num_inputs));
tensorflow::gtl::ArraySlice<const tensorflow::DataType>(dtypes.data(),
dtypes.size()));
ictx->attrs.insert(input_def.type_list_attr());
}
}
Expand All @@ -675,10 +669,15 @@ tensorflow::Status OpInferInputListAttrs(TFE_Op* op, TFE_TensorHandle** inputs,
TFE_OpInferenceContext* ictx = op->inference_ctx.get();
const auto& input_def = ictx->op_def->input_arg(ictx->input_arg_idx++);
if (!input_def.type_list_attr().empty()) {
OpInferMixedTypeInputListAttrs(op, input_def, inputs, num_inputs);
std::vector<tensorflow::DataType> dtypes(num_inputs);
for (int i = 0; i < num_inputs; ++i) {
dtypes[i] = inputs[i]->handle->dtype;
}
OpInferMixedTypeInputListAttrs(op, input_def, dtypes);
} else if (!input_def.type_attr().empty() &&
!input_def.number_attr().empty()) {
OpInferSingleTypeInputListAttrs(op, input_def, inputs, num_inputs);
OpInferSingleTypeInputListAttrs(op, input_def, inputs[0]->handle->dtype,
num_inputs);
} else {
return tensorflow::errors::InvalidArgument("Invalid input list definition");
}
Expand Down

0 comments on commit 37e3630

Please sign in to comment.