Skip to content

Commit

Permalink
Added automatic inference of input and output varaibles
Browse files Browse the repository at this point in the history
  • Loading branch information
nadavbar committed Jun 13, 2017
1 parent b92b54c commit 4f51ccc
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 30 deletions.
97 changes: 71 additions & 26 deletions src/CNTKModelObjectWrap.cc
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,28 @@ NAN_METHOD(CNTKModelObjectWrap::New)
}
}

void CNTKModelObjectWrap::JsArrayToCntkInputData(Local<Object> dataObj, CNTKEvalInputDataHolder<float> &inputData)
{
// get number of rows:
Local<String> lengthSymb = Nan::New<String>("length").ToLocalChecked();
inputData.numberOfSamples = Nan::To<int32_t>(Nan::Get(dataObj, lengthSymb).ToLocalChecked()).FromMaybe(0);

// Insert object data to each row
// TODO: We might be able to optimize the initialiation by resizing according to the input data shape
// for now just leave this as is and let std do the resizing for us
for (int j = 0; j < inputData.numberOfSamples; j++)
{
Local<Object> entryObj = Nan::To<Object>(Nan::Get(dataObj, j).ToLocalChecked()).ToLocalChecked();
// TODO: We might be able to optimize this for networks with fixes length input by calling this per items
int itemsCount = Nan::To<int32_t>(Nan::Get(entryObj, lengthSymb).ToLocalChecked()).FromMaybe(0);
for (int k = 0; k < itemsCount; k++)
{
float value = static_cast<float>(Nan::To<double_t>(Nan::Get(entryObj, k).ToLocalChecked()).FromMaybe(0.0));
inputData.data.push_back(value);
}
}
}

void CNTKModelObjectWrap::JsInputToCntk(Handle<Object> inputsObj, Handle<Array> outputsArr, CNTKEvalInputDataFloat& inputDataOut, CNTKEvalOutputVariablesNames& outputVariablesNamesOut)
{
Nan::HandleScope scope;
Expand All @@ -85,38 +107,61 @@ void CNTKModelObjectWrap::JsInputToCntk(Handle<Object> inputsObj, Handle<Array>
}
}

// get the input value names & values
Local<Array> inputKeyNames = Nan::GetPropertyNames(inputsObj).ToLocalChecked();
for (unsigned int i=0; i < inputKeyNames->Length(); i++)
// get the input value names & value

if (inputsObj->IsArray())
{
CNTKEvalInputDataHolder<float> inputData;
Local<String> inputNode = Nan::To<String>(Nan::Get(inputKeyNames, i).ToLocalChecked()).ToLocalChecked();
String::Value inputNodeVal(inputNode);
inputData.inputVaraibleName = reinterpret_cast<wchar_t*>(*inputNodeVal);

Local<Object> dataObj = Nan::To<Object>(Nan::Get(inputsObj, inputNode).ToLocalChecked()).ToLocalChecked();

// get number of rows:
Local<String> lengthSymb = Nan::New<String>("length").ToLocalChecked();
inputData.numberOfSamples = Nan::To<int32_t>(Nan::Get(dataObj, lengthSymb).ToLocalChecked()).FromMaybe(0);

// Insert object data to each row
// TODO: We might be able to optimize the initialiation by resizing according to the input data shape
// for now just leave this as is and let std do the resizing for us
for (int j=0; j < inputData.numberOfSamples; j++)
Local<Array> inputsArr = inputsObj.As<Array>();

bool isArrayOfArrays = false;
if (inputsArr->Length() > 0) {
Local<Object> firstObj = Nan::To<Object>(Nan::Get(inputsObj, 0).ToLocalChecked()).ToLocalChecked();
Local<String> lengthSymb = Nan::New<String>("length").ToLocalChecked();
int firstArrLength = Nan::To<int32_t>(Nan::Get(firstObj, lengthSymb).ToLocalChecked()).FromMaybe(0);
if (firstArrLength > 0) {
Local<Object> firstNestedObj = Nan::To<Object>(Nan::Get(firstObj, 0).ToLocalChecked()).ToLocalChecked();
int nestedArrLength = Nan::To<int32_t>(Nan::Get(firstNestedObj, lengthSymb).ToLocalChecked()).FromMaybe(0);
isArrayOfArrays = nestedArrLength > 0;
}
}

// if this is an array of arrays of samples
if (isArrayOfArrays)
{
Local<Object> entryObj = Nan::To<Object>(Nan::Get(dataObj, j).ToLocalChecked()).ToLocalChecked();
// TODO: We might be able to optimize this for networks with fixes length input by calling this per items
int itemsCount = Nan::To<int32_t>(Nan::Get(entryObj, lengthSymb).ToLocalChecked()).FromMaybe(0);
for (int k=0; k < itemsCount; k++)
for (unsigned int i = 0; i < inputsArr->Length(); i++)
{
float value = static_cast<float>(Nan::To<double_t>(Nan::Get(entryObj, k).ToLocalChecked()).FromMaybe(0.0));
inputData.data.push_back(value);
CNTKEvalInputDataHolder<float> inputData;

Local<Object> dataObj = Nan::To<Object>(Nan::Get(inputsArr, i).ToLocalChecked()).ToLocalChecked();

JsArrayToCntkInputData(dataObj, inputData);
inputDataOut.push_back(inputData);
}
}
else // only one array which contain the samples
{
CNTKEvalInputDataHolder<float> inputData;
JsArrayToCntkInputData(inputsObj, inputData);
inputDataOut.push_back(inputData);
}
}
else // object with keys
{
Local<Array> inputKeyNames = Nan::GetPropertyNames(inputsObj).ToLocalChecked();

// TODO: we might need to return the number of samples as well
inputDataOut.push_back(inputData);
for (unsigned int i = 0; i < inputKeyNames->Length(); i++)
{
CNTKEvalInputDataHolder<float> inputData;
Local<String> inputNode = Nan::To<String>(Nan::Get(inputKeyNames, i).ToLocalChecked()).ToLocalChecked();
String::Value inputNodeVal(inputNode);
inputData.inputVaraibleName = reinterpret_cast<wchar_t*>(*inputNodeVal);

Local<Object> dataObj = Nan::To<Object>(Nan::Get(inputsObj, inputNode).ToLocalChecked()).ToLocalChecked();

JsArrayToCntkInputData(dataObj, inputData);

inputDataOut.push_back(inputData);
}
}

}
Expand Down
1 change: 1 addition & 0 deletions src/CNTKModelObjectWrap.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ class CNTKModelObjectWrap : public Nan::ObjectWrap {
~CNTKModelObjectWrap();

static void JsInputToCntk(v8::Handle<v8::Object> inputsObj, v8::Handle<v8::Array> outputsArr, CNTKEvalInputDataFloat& inputDataOut, CNTKEvalOutputVariablesNames& outputVariablesNamesOut);
static void CNTKModelObjectWrap::JsArrayToCntkInputData(v8::Local<v8::Object> dataObj, CNTKEvalInputDataHolder<float> &inputData);

static NAN_METHOD(New);
static NAN_METHOD(Eval);
Expand Down
19 changes: 18 additions & 1 deletion src/EvalModelAsyncWorker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,27 @@ void EvalModelAsyncWorker::Execute()
try
{
unordered_map<CNTK::Variable, CNTK::ValuePtr> inputVars;
int index = 0;
auto modelInputs = _model->Arguments();
for (auto it = _inputData.begin(); it != _inputData.end(); it++)
{
CNTK::Variable inputVar;
// TODO: optimize this such that the model wrap will hold this instead
if (!CNTKUtils::GetInputVariableByName(_model, it->inputVaraibleName, inputVar))

if (it->inputVaraibleName.empty())
{
if (index > modelInputs.size())
{
stringstream errorMessageStream;
errorMessageStream << "Error: Provided number of input variables exceed the number that the model expects (" << modelInputs.size() << ")";
_errorMessage = errorMessageStream.str();
_errorOccured = true;
return;
}

inputVar = modelInputs[index];
}
else if (!CNTKUtils::GetInputVariableByName(_model, it->inputVaraibleName, inputVar))
{
stringstream errorMessageStream;
errorMessageStream << "Input variable: '" << it->inputVaraibleName.c_str() << "' was not found in model.";
Expand All @@ -70,6 +86,7 @@ void EvalModelAsyncWorker::Execute()
CNTK::ValuePtr inputValue = CNTK::MakeSharedObject<CNTK::Value>(CNTK::MakeSharedObject<CNTK::NDArrayView>(inputShape, it->data, true));

inputVars[inputVar] = inputValue;
index++;
}

if (_outputVariablesNames.size() > 0)
Expand Down
14 changes: 11 additions & 3 deletions test/basic.js
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,21 @@ cntk.loadModel(modelPath, (err, model) => {
img1 = images1[0];
pixel.parse(testImagePath2).then(function(images2) {
img2 = images2[0];
inputData = {
// inputs data can be an object with variable names
/*inputData = {
'input' : [rgbToOneChannel(img1), rgbToOneChannel(img2) ]
}
}*/

// this also works
//inputData = [[rgbToOneChannel(img1), rgbToOneChannel(img2)]]

// and this works as well
inputData = [rgbToOneChannel(img1), rgbToOneChannel(img2)]

// you can optionally specify output nodes that you are interested in
//outputNodes = ['output']
console.info('Calling eval')
model.eval(inputData, (err, res)=>{
model.eval(inputData, /*outputNodes,*/ (err, res)=>{
if (err) {
console.info(err);
return;
Expand Down

0 comments on commit 4f51ccc

Please sign in to comment.