Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions bin/pytorch_inference/CCommandParser.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace torch {

const std::string CCommandParser::REQUEST_ID{"request_id"};
const std::string CCommandParser::TOKENS{"tokens"};
const std::string CCommandParser::INPUTS{"inputs"};
const std::string CCommandParser::VAR_ARG_PREFIX{"arg_"};
const std::string CCommandParser::UNKNOWN_ID;

Expand Down Expand Up @@ -90,23 +91,37 @@ bool CCommandParser::validateJson(const rapidjson::Document& doc,
return false;
}

if (doc.HasMember(TOKENS) == false) {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: missing field [" + TOKENS + "]");
return false;
}
if (doc.HasMember(TOKENS)) {
const rapidjson::Value& tokens = doc[TOKENS];
if (tokens.IsArray() == false) {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: expected an array [" + TOKENS + "]");
return false;
}

const rapidjson::Value& tokens = doc[TOKENS];
if (tokens.IsArray() == false) {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: expected an array [" + TOKENS + "]");
return false;
}
if (checkArrayContainsUInts(tokens) == false) {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: array [" + TOKENS +
"] contains values that are not unsigned integers");
return false;
}
} else if (doc.HasMember(INPUTS)) {
const rapidjson::Value& inputs = doc[INPUTS];
if (inputs.IsArray() == false) {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: expected an array [" + INPUTS + "]");
return false;
}

if (checkArrayContainsUInts(tokens) == false) {
if (checkArrayContainsDoubles(inputs) == false) {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: array [" + INPUTS +
"] contains values that are not doubles");
return false;
}
} else {
errorHandler(doc[REQUEST_ID].GetString(),
"Invalid command: array [" + TOKENS +
"] contains values that are not unsigned integers");
"Invalid command: missing field [" + TOKENS + "|" + INPUTS + "]");
return false;
}

Expand Down Expand Up @@ -145,16 +160,37 @@ bool CCommandParser::checkArrayContainsUInts(const rapidjson::Value& arr) const
return allInts;
}

//! Returns true if every element of \p arr is numeric.
//!
//! Accepts any JSON number, not just values stored as doubles:
//! rapidjson's IsDouble() is false for numbers written without a
//! decimal point (e.g. 3), yet GetDouble() — used later when the
//! request is converted — retrieves any numeric type. Checking
//! IsNumber() therefore avoids rejecting valid inputs such as [1, 2.5].
bool CCommandParser::checkArrayContainsDoubles(const rapidjson::Value& arr) const {
    for (auto itr = arr.Begin(); itr != arr.End(); ++itr) {
        if (itr->IsNumber() == false) {
            // Bail out on the first non-numeric element; no need to
            // scan the rest of the array.
            return false;
        }
    }
    return true;
}

void CCommandParser::jsonToRequest(const rapidjson::Document& doc) {

m_Request.s_RequestId = doc[REQUEST_ID].GetString();
const rapidjson::Value& arr = doc[TOKENS];

// wipe any previous
m_Request.s_Tokens.clear();
m_Request.s_Tokens.reserve(arr.Size());
if (doc.HasMember(TOKENS)) {
const rapidjson::Value& arr = doc[TOKENS];
m_Request.s_Tokens.reserve(arr.Size());
for (auto itr = arr.Begin(); itr != arr.End(); ++itr) {
m_Request.s_Tokens.push_back(itr->GetUint64());
}
}

for (auto itr = arr.Begin(); itr != arr.End(); ++itr) {
m_Request.s_Tokens.push_back(itr->GetUint64());
m_Request.s_Inputs.clear();
if (doc.HasMember(INPUTS)) {
const rapidjson::Value& arr = doc[INPUTS];
m_Request.s_Inputs.reserve(arr.Size());
for (auto itr = arr.Begin(); itr != arr.End(); ++itr) {
m_Request.s_Inputs.push_back(itr->GetDouble());
}
}

std::uint64_t varCount{1};
Expand All @@ -175,5 +211,9 @@ void CCommandParser::jsonToRequest(const rapidjson::Document& doc) {
varArgName = VAR_ARG_PREFIX + std::to_string(varCount);
}
}

//! A request is token based (BERT style) when its token list is
//! non-empty; otherwise it is expected to carry floating point inputs.
bool CCommandParser::SRequest::hasTokens() {
    return !s_Tokens.empty();
}
}
}
6 changes: 5 additions & 1 deletion bin/pytorch_inference/CCommandParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,21 @@ class CCommandParser {
public:
static const std::string REQUEST_ID;
static const std::string TOKENS;
static const std::string INPUTS;
static const std::string VAR_ARG_PREFIX;
static const std::string UNKNOWN_ID;

using TUint64Vec = std::vector<std::uint64_t>;
using TUint64VecVec = std::vector<TUint64Vec>;
using TDoubleVec = std::vector<double>;

struct SRequest {
std::string s_RequestId;
TUint64Vec s_Tokens;
TUint64VecVec s_SecondaryArguments;
TDoubleVec s_Inputs;

void clear();
bool hasTokens();
};

using TRequestHandlerFunc = std::function<bool(SRequest&)>;
Expand All @@ -78,6 +81,7 @@ class CCommandParser {
bool validateJson(const rapidjson::Document& doc,
const TErrorHandlerFunc& errorHandler) const;
bool checkArrayContainsUInts(const rapidjson::Value& arr) const;
bool checkArrayContainsDoubles(const rapidjson::Value& arr) const;
void jsonToRequest(const rapidjson::Document& doc);

private:
Expand Down
104 changes: 69 additions & 35 deletions bin/pytorch_inference/Main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <torch/script.h>

#include <memory>
#include <sstream>
#include <string>

namespace {
Expand All @@ -33,45 +34,71 @@ const std::string ERROR{"error"};
torch::Tensor infer(torch::jit::script::Module& module,
ml::torch::CCommandParser::SRequest& request) {

torch::Tensor tokensTensor =
torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
{1, static_cast<std::int64_t>(request.s_Tokens.size())},
at::dtype(torch::kInt64));

std::vector<torch::jit::IValue> inputs;
inputs.reserve(1 + request.s_SecondaryArguments.size());
inputs.push_back(tokensTensor);

for (auto& args : request.s_SecondaryArguments) {
inputs.emplace_back(torch::from_blob(
static_cast<void*>(args.data()),
{1, static_cast<std::int64_t>(args.size())}, at::dtype(torch::kInt64)));
if (request.hasTokens()) {
inputs.reserve(1 + request.s_SecondaryArguments.size());

// BERT UInt tokens
inputs.emplace_back(
torch::from_blob(static_cast<void*>(request.s_Tokens.data()),
{1, static_cast<std::int64_t>(request.s_Tokens.size())},
at::dtype(torch::kInt64)));

for (auto& args : request.s_SecondaryArguments) {
inputs.emplace_back(torch::from_blob(
static_cast<void*>(args.data()),
{1, static_cast<std::int64_t>(args.size())}, at::dtype(torch::kInt64)));
}
} else {
// floating point inputs
inputs.emplace_back(
torch::from_blob(static_cast<void*>(request.s_Inputs.data()),
{1, static_cast<std::int64_t>(request.s_Inputs.size())},
at::dtype(torch::kFloat64))
.to(torch::kFloat32));
}

torch::NoGradGuard noGrad;
auto tuple = module.forward(inputs).toTuple();
return tuple->elements()[0].toTensor();
auto result = module.forward(inputs);
if (result.isTuple()) {
// For BERT models the result tensor is the first element in a tuple
return result.toTuple()->elements()[0].toTensor();
} else {
return result.toTensor();
}
}

// Serialize a 1-D tensor as a single JSON array of doubles.
void writeTensor(torch::TensorAccessor<float, 1> accessor,
                 ml::core::CRapidJsonLineWriter<rapidjson::OStreamWrapper>& jsonWriter) {
    jsonWriter.StartArray();
    const auto numElements = accessor.size(0);
    for (int index = 0; index < numElements; ++index) {
        // JSON has no float type; widen each element to double.
        jsonWriter.Double(static_cast<double>(accessor[index]));
    }
    jsonWriter.EndArray();
}

// Serialize a 2-D tensor as one JSON array per row. No enclosing
// array is written here; the caller brackets the rows.
void writeTensor(torch::TensorAccessor<float, 2> accessor,
                 ml::core::CRapidJsonLineWriter<rapidjson::OStreamWrapper>& jsonWriter) {
    const auto numRows = accessor.size(0);
    const auto numCols = accessor.size(1);
    for (int row = 0; row < numRows; ++row) {
        jsonWriter.StartArray();
        for (int col = 0; col < numCols; ++col) {
            // JSON has no float type; widen each element to double.
            jsonWriter.Double(static_cast<double>(accessor[row][col]));
        }
        jsonWriter.EndArray();
    }
}

template<std::size_t N>
void writePrediction(const torch::Tensor& prediction,
const std::string& requestId,
std::ostream& outputStream) {

torch::Tensor view;
auto sizes = prediction.sizes();
// Some models return a 3D tensor in which case
// the first dimension must have size == 1
if (sizes.size() == 3 && sizes[0] == 1) {
view = prediction[0];
} else {
view = prediction;
}

// creating the accessor will throw if view does not
// have exactly 2 dimensions. Do this before writing
// creating the accessor will throw if the tensor does
// not have exactly N dimensions. Do this before writing
// any output so the error message isn't mingled with
// a partial result
auto accessor = view.accessor<float, 2>();
auto accessor = prediction.accessor<float, N>();

rapidjson::OStreamWrapper writeStream(outputStream);
ml::core::CRapidJsonLineWriter<rapidjson::OStreamWrapper> jsonWriter(writeStream);
Expand All @@ -81,13 +108,7 @@ void writePrediction(const torch::Tensor& prediction,
jsonWriter.Key(INFERENCE);
jsonWriter.StartArray();

for (int i = 0; i < accessor.size(0); ++i) {
jsonWriter.StartArray();
for (int j = 0; j < accessor.size(1); ++j) {
jsonWriter.Double(static_cast<double>(accessor[i][j]));
}
jsonWriter.EndArray();
}
writeTensor(accessor, jsonWriter);

jsonWriter.EndArray();
jsonWriter.EndObject();
Expand All @@ -107,10 +128,23 @@ void writeError(const std::string& requestId, const std::string& message, std::o
bool handleRequest(ml::torch::CCommandParser::SRequest& request,
torch::jit::script::Module& module,
std::ostream& outputStream) {

try {
torch::Tensor results = infer(module, request);
writePrediction(results, request.s_RequestId, outputStream);
auto sizes = results.sizes();
// Some models return a 3D tensor in which case
// the first dimension must have size == 1
if (sizes.size() == 3 && sizes[0] == 1) {
writePrediction<2>(results[0], request.s_RequestId, outputStream);
} else if (sizes.size() == 2) {
writePrediction<2>(results, request.s_RequestId, outputStream);
} else if (sizes.size() == 1) {
writePrediction<1>(results, request.s_RequestId, outputStream);
} else {
std::ostringstream ss;
ss << "Cannot convert results tensor of size [" << sizes << "]";
writeError(request.s_RequestId, ss.str(), outputStream);
}

} catch (std::runtime_error& e) {
writeError(request.s_RequestId, e.what(), outputStream);
}
Expand Down
16 changes: 13 additions & 3 deletions bin/pytorch_inference/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def write_request(request, destination):
json.dump(request, destination)


def compare_results(expected, actual):
def compare_results(expected, actual, tolerance):
try:
if expected['request_id'] != actual['request_id']:
print("request_ids do not match [{}], [{}]".format(expected['request_id'], actual['request_id']), flush=True)
Expand All @@ -103,7 +103,7 @@ def compare_results(expected, actual):

are_close = True
for j in range(len(expected_row)):
are_close = are_close and math.isclose(expected_row[j], actual_row[j], rel_tol=1e-04)
are_close = are_close and math.isclose(expected_row[j], actual_row[j], abs_tol=tolerance)

if are_close == False:
print("row [{}] values are not close {}, {}".format(i, expected_row, actual_row), flush=True)
Expand Down Expand Up @@ -159,16 +159,26 @@ def main():
return

expected = test_evaluation[doc_count]['expected_output']

tolerance = 1e-04
if 'how_close' in test_evaluation[doc_count]:
tolerance = test_evaluation[doc_count]['how_close']

# compare to expected
if compare_results(expected, result) == False:
if compare_results(expected, result, tolerance) == False:
print()
print('ERROR: inference result [{}] does not match expected results'.format(doc_count))
print()
results_match = False

doc_count = doc_count +1

if doc_count != len(test_evaluation):
print()
print('ERROR: The number of inference results [{}] does not match expected count [{}]'.format(doc_count, len(test_evaluation)))
print()
results_match = False

if results_match:
print()
print('SUCCESS: inference results match expected', flush=True)
Expand Down
20 changes: 20 additions & 0 deletions bin/pytorch_inference/examples/simplest/test_run.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
[
{
"description": "Approximate sin(x) for x = pi",
"how_close": 0.2,
"input": {
"request_id": "one",
"inputs": [3.1416, 9.8696, 31.0063]
},
"expected_output": {"request_id": "one", "inference": [[0.0]]}
},
{
"description": "Approximate sin(x) for x = pi/2",
"how_close": 0.2,
"input": {
"request_id": "two",
"inputs": [1.5708, 2.4674, 3.8758]
},
"expected_output": {"request_id": "two", "inference": [[1.0]]}
}
]
Loading