3 changes: 3 additions & 0 deletions bin/pytorch_inference/CResultWriter.cc
@@ -136,6 +136,9 @@ std::string CResultWriter::createInnerResult(const ::torch::Tensor& results) {
     case 2:
         this->writePrediction<2>(results, jsonWriter);
         break;
+    case 1:
+        this->writePrediction<1>(results, jsonWriter);
+        break;
     default: {
         std::ostringstream ss;
         ss << "Cannot convert results tensor of size [" << sizes << ']';
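Note: writePrediction<N> itself is not shown in this diff. As a rough, hypothetical sketch (the free-function framing, the float element type, and the printout are assumptions for illustration), the kind of rank-templated dispatch the new case 1 feeds into could look like this:

#include <torch/torch.h>

#include <cstddef>
#include <iostream>

// Hypothetical stand-in for the rank-templated writer the switch above
// dispatches to; the real writePrediction<N> writes JSON via a jsonWriter.
template<std::size_t N_DIMENSIONS>
void writePrediction(const torch::Tensor& results) {
    // Assumes float results purely for the sketch.
    auto accessor = results.accessor<float, N_DIMENSIONS>();
    std::cout << "writing a " << N_DIMENSIONS << "D result, dimension 0 has "
              << accessor.size(0) << " element(s)\n";
}

int main() {
    torch::Tensor result{torch::ones({4})};
    switch (result.sizes().size()) {
    case 2:
        writePrediction<2>(result);
        break;
    case 1: // the newly handled rank
        writePrediction<1>(result);
        break;
    default:
        std::cout << "unsupported rank\n";
        break;
    }
    return 0;
}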
18 changes: 18 additions & 0 deletions bin/pytorch_inference/CResultWriter.h
@@ -191,6 +191,24 @@ class CResultWriter : public TStringBufWriter {
         jsonWriter.onObjectEnd();
     }
 
+    //! Write a 1D inference result
+    template<typename T>
+    void writeInferenceResults(const ::torch::TensorAccessor<T, 1UL>& accessor,
+                               TStringBufWriter& jsonWriter) {
+
+        jsonWriter.onKey(RESULT);
+        jsonWriter.onObjectBegin();
+        jsonWriter.onKey(INFERENCE);
+        // The Java side requires a 3D array, so wrap the 1D result in two
+        // extra outer arrays.
+        jsonWriter.onArrayBegin();
+        jsonWriter.onArrayBegin();
+        this->writeTensor(accessor, jsonWriter);
+        jsonWriter.onArrayEnd();
+        jsonWriter.onArrayEnd();
+        jsonWriter.onObjectEnd();
+    }
+
 private:
     core::CJsonOutputStreamWrapper m_WrappedOutputStream;
 };
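The new overload exists purely for output-format compatibility: a 1D result is written inside two extra arrays so the consumer still receives a 3D array. A minimal, standalone sketch of that nesting (using std::ostream rather than the project's jsonWriter, so this is not the production code path):

#include <torch/torch.h>

#include <cstdint>
#include <iostream>

// Illustrative only: emit a 1D float tensor as a triple-nested JSON array.
void writeOneDAsThreeD(const torch::Tensor& tensor, std::ostream& out) {
    auto accessor = tensor.accessor<float, 1>();
    out << "[[[";
    for (std::int64_t i = 0; i < accessor.size(0); ++i) {
        if (i > 0) {
            out << ',';
        }
        out << accessor[i];
    }
    out << "]]]";
}

int main() {
    writeOneDAsThreeD(torch::ones({3}), std::cout); // prints [[[1,1,1]]]
    std::cout << '\n';
    return 0;
}

For torch::ones({1}) the same nesting gives [[[1]]], which is exactly the string checked by the new unit test further down.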
7 changes: 6 additions & 1 deletion bin/pytorch_inference/Main.cc
@@ -73,7 +73,12 @@ torch::Tensor infer(torch::jit::script::Module& module_,
         // For transformers the result tensor is the first element in a tuple.
         all.push_back(output.toTuple()->elements()[0].toTensor());
     } else {
-        all.push_back(output.toTensor());
+        auto outputTensor = output.toTensor();
+        if (outputTensor.dim() == 0) { // If the output is a scalar, we need to reshape it into a 1x1 tensor
+            all.push_back(outputTensor.reshape({1, 1}));
+        } else {
+            all.push_back(std::move(outputTensor));
+        }
     }
 
     inputs.clear();
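A zero-dimensional output has dim() == 0 and no sizes to index, so it is given explicit dimensions before being collected. A minimal standalone illustration of that reshape (the value 0.5 is arbitrary):

#include <torch/torch.h>

#include <iostream>

int main() {
    // A zero-dimensional (scalar) tensor, as a traced model can return for a
    // single-value prediction.
    torch::Tensor scalar{torch::scalar_tensor(0.5)};
    std::cout << "dim: " << scalar.dim() << '\n'; // prints 0

    // Reshaping to {1, 1} mirrors what the change above does before pushing
    // the tensor onto the results vector.
    torch::Tensor reshaped{scalar.reshape({1, 1})};
    std::cout << "sizes: " << reshaped.sizes() << '\n'; // prints [1, 1]
    return 0;
}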
12 changes: 12 additions & 0 deletions bin/pytorch_inference/unittest/CResultWriterTest.cc
@@ -80,6 +80,18 @@ BOOST_AUTO_TEST_CASE(testCreateInnerInferenceResult) {
     BOOST_REQUIRE_EQUAL(expected, innerPortion);
 }
 
+BOOST_AUTO_TEST_CASE(testCreateInnerInferenceResultFor1DimensionalResult) {
+    std::ostringstream output;
+    ml::torch::CResultWriter resultWriter{output};
+    ::torch::Tensor tensor{::torch::ones({1})};
+    std::string innerPortion{resultWriter.createInnerResult(tensor)};
+    std::string expected = "\"result\":{\"inference\":"
+                           "[[[1]]]}";
+    LOG_INFO(<< "expected: " << expected);
+    LOG_INFO(<< "actual: " << innerPortion);
+    BOOST_REQUIRE_EQUAL(expected, innerPortion);
+}
+
 BOOST_AUTO_TEST_CASE(testWrapAndWriteInferenceResult) {
     std::string innerPortion{
         "\"result\":{\"inference\":"
4 changes: 4 additions & 0 deletions docs/CHANGELOG.asciidoc
@@ -35,6 +35,10 @@
 * Allow the user to force a detector to shift time series state by a specific amount.
   (See {ml-pull}2695[#2695].)
 
+=== Bug Fixes
+
+* Allow pytorch_inference results to include zero-dimensional tensors.
+
 == {es} version 8.15.2
 
 === Enhancements