Skip to content

Commit

Permalink
fix(torch): throw error if multiple models are provided
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and mergify[bot] committed Mar 24, 2021
1 parent fd3e476 commit efbd1f9
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions src/backends/torch/torchlib.cc
Expand Up @@ -303,14 +303,28 @@ namespace dd
this->_mlmodel._proto = dest_net;
}

bool unsupported_model_configuration
= this->_mlmodel._traced.empty() && this->_mlmodel._proto.empty()
&& !NativeFactory::valid_template_def(_template);
bool model_not_found = this->_mlmodel._traced.empty()
&& this->_mlmodel._proto.empty()
&& !NativeFactory::valid_template_def(_template);

if (unsupported_model_configuration)
if (model_not_found)
throw MLLibInternalException("Use of libtorch backend needs either: "
"traced net, protofile or native template");

bool multiple_models_found
= ((!this->_mlmodel._traced.empty()) + (!this->_mlmodel._proto.empty())
+ NativeFactory::valid_template_def(_template))
> 1;
if (multiple_models_found)
{
this->_logger->error("traced: {}, proto: {}, template: {}",
this->_mlmodel._traced, this->_mlmodel._proto,
_template);
throw MLLibInternalException(
"Only one of these must be provided: traced net, protofile or "
"native template");
}

// FIXME(louis): out of if(bert) because we allow not to specify template
// at predict. Should we change this?
this->_inputc._input_format = "bert";
Expand Down

0 comments on commit efbd1f9

Please sign in to comment.