Skip to content

Commit

Permalink
Added support for implict output variable names
Browse files Browse the repository at this point in the history
  • Loading branch information
nadavbar committed Jun 13, 2017
1 parent e98efa3 commit b92b54c
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 24 deletions.
28 changes: 18 additions & 10 deletions src/CNTKModelObjectWrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,14 @@ void CNTKModelObjectWrap::JsInputToCntk(Handle<Object> inputsObj, Handle<Array>
Nan::HandleScope scope;

// Get the name of the output nodes
for (unsigned int i=0; i < outputsArr->Length(); i++)
{
Local<String> outputNode = Nan::To<String>(Nan::Get(outputsArr, i).ToLocalChecked()).ToLocalChecked();
String::Value outputNodeVal(outputNode);
wstring outputNodeName(reinterpret_cast<wchar_t*>(*outputNodeVal));
outputVariablesNamesOut.push_back(outputNodeName);
if (!outputsArr.IsEmpty()) {
for (unsigned int i = 0; i < outputsArr->Length(); i++)
{
Local<String> outputNode = Nan::To<String>(Nan::Get(outputsArr, i).ToLocalChecked()).ToLocalChecked();
String::Value outputNodeVal(outputNode);
wstring outputNodeName(reinterpret_cast<wchar_t*>(*outputNodeVal));
outputVariablesNamesOut.push_back(outputNodeName);
}
}

// get the input value names & values
Expand Down Expand Up @@ -121,21 +123,27 @@ void CNTKModelObjectWrap::JsInputToCntk(Handle<Object> inputsObj, Handle<Array>

NAN_METHOD(CNTKModelObjectWrap::Eval) {
Nan::HandleScope scope;
if (info.Length() < 3 || !info[0]->IsObject() || !info[1]->IsArray() || !info[2]->IsFunction())
if (info.Length() < 2 || !info[info.Length() - 1]->IsFunction() || !info[0]->IsObject() || (info.Length() > 2 && !info[1]->IsArray()))
{
Nan::ThrowTypeError("Bad usage, expected arguments are: input args[key: input node name (string), value: input data (array of arrays)], output node names[array of strings], completion callback [function]");
Nan::ThrowTypeError("Bad usage, expected arguments are: input args[key: input node name (string), value: input data (array of arrays)], optional: output node names[array of strings], completion callback [function]");
return;
}

Local<Object> inputDataObj = Nan::To<Object>(info[0]).ToLocalChecked();
Local<Array> outputNodesArr = info[1].As<Array>();

Local<Array> outputNodesArr;
if (info.Length() > 2)
{
outputNodesArr = info[1].As<Array>();
}

CNTKEvalInputDataFloat inputData;
CNTKEvalOutputVariablesNames outputVariables;
JsInputToCntk(inputDataObj, outputNodesArr, inputData, outputVariables);

CNTKModelObjectWrap* objectWrap = Nan::ObjectWrap::Unwrap<CNTKModelObjectWrap>(info.This());
Callback *callback = new Callback(info[2].As<Function>());

Callback *callback = new Callback(info[info.Length() -1].As<Function>());

AsyncQueueWorker(new EvalModelAsyncWorker(callback, objectWrap->_model, inputData, outputVariables, CNTK::DeviceDescriptor::UseDefaultDevice()));
}
41 changes: 29 additions & 12 deletions src/EvalModelAsyncWorker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ using std::string;
using std::map;
using std::stringstream;
using std::unordered_map;
using std::vector;


EvalModelAsyncWorker::EvalModelAsyncWorker(Nan::Callback *callback, CNTK::FunctionPtr model,
Expand Down Expand Up @@ -71,22 +72,38 @@ void EvalModelAsyncWorker::Execute()
inputVars[inputVar] = inputValue;
}

for (auto it = _outputVariablesNames.begin(); it != _outputVariablesNames.end(); it++)
if (_outputVariablesNames.size() > 0)
{
CNTK::Variable outputVar;
if (!CNTKUtils::GetOutputVaraiableByName(_model, *it, outputVar))
for (auto it = _outputVariablesNames.begin(); it != _outputVariablesNames.end(); it++)
{
stringstream errorMessageStream;
errorMessageStream << "Output variable: '" << it->c_str() << "' was not found in model.";
_errorMessage = errorMessageStream.str();
_errorOccured = true;
return;
}
CNTK::Variable outputVar;
if (!CNTKUtils::GetOutputVaraiableByName(_model, *it, outputVar))
{
stringstream errorMessageStream;
errorMessageStream << "Output variable: '" << it->c_str() << "' was not found in model.";
_errorMessage = errorMessageStream.str();
_errorOccured = true;
return;
}

CNTK::ValuePtr outputValue;
CNTK::ValuePtr outputValue;

_outputVars[outputVar] = outputValue;
_outputVarsByName[*it] = outputVar;
_outputVars[outputVar] = outputValue;
_outputVarsByName[*it] = outputVar;
}
}
else
{
// Output vars weren't specified explicitly, so we'll just get them from the model
vector<CNTK::Variable> outputVars = _model->Outputs();
for (auto outputVar : outputVars)
{
wstring varName = outputVar.Name();
_outputVariablesNames.push_back(varName);
CNTK::ValuePtr outputValue;
_outputVars[outputVar] = outputValue;
_outputVarsByName[varName] = outputVar;
}
}

_model->Forward(inputVars, _outputVars, _device);
Expand Down
4 changes: 2 additions & 2 deletions test/basic.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ cntk.loadModel(modelPath, (err, model) => {
'input' : [rgbToOneChannel(img1), rgbToOneChannel(img2) ]
}

outputNodes = ['output']
//outputNodes = ['output']
console.info('Calling eval')
model.eval(inputData, outputNodes, (err, res)=>{
model.eval(inputData, (err, res)=>{
if (err) {
console.info(err);
return;
Expand Down

0 comments on commit b92b54c

Please sign in to comment.