Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds ONNX export support to the sparse linear models #163

Merged
merged 9 commits into from
Sep 7, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import static org.junit.jupiter.api.Assertions.fail;

public class TestSGDLinear {
private static final Logger logger = Logger.getLogger(TestSGDLinear.class.getName());

private static final LinearSGDTrainer t = new LinearSGDTrainer(new LogMulticlass(),new AdaGrad(0.1,0.1),5,1000, Trainer.DEFAULT_SEED);

Expand Down Expand Up @@ -122,32 +123,39 @@ public void testOnnxSerialization() throws IOException, OrtException {
outputMapping.put(l.getB(), l.getA());
}

// Load in via ORT
OrtEnvironment env = OrtEnvironment.getEnvironment();
env.close();
ONNXExternalModel<Label> onnxModel = ONNXExternalModel.createOnnxModel(new LabelFactory(),featureMapping,outputMapping,new DenseTransformer(),new LabelTransformer(),new OrtSession.SessionOptions(),onnxFile,"input");

// Generate predictions
List<Prediction<Label>> nativePredictions = model.predict(p.getB());
List<Prediction<Label>> onnxPredictions = onnxModel.predict(p.getB());

// Assert the predictions are identical
for (int i = 0; i < nativePredictions.size(); i++) {
Prediction<Label> tribuo = nativePredictions.get(i);
Prediction<Label> external = onnxPredictions.get(i);
assertEquals(tribuo.getOutput().getLabel(),external.getOutput().getLabel());
assertEquals(tribuo.getOutput().getScore(),external.getOutput().getScore(),1e-6);
for (Map.Entry<String,Label> l : tribuo.getOutputScores().entrySet()) {
Label other = external.getOutputScores().get(l.getKey());
if (other == null) {
fail("Failed to find label " + l.getKey() + " in ORT prediction.");
} else {
assertEquals(l.getValue().getScore(),other.getScore(),1e-6);
String arch = System.getProperty("os.arch");
if (arch.equalsIgnoreCase("amd64") || arch.equalsIgnoreCase("x86_64")) {
// Initialise the OrtEnvironment to load the native library
// (as OrtSession.SessionOptions doesn't trigger the static initializer).
OrtEnvironment env = OrtEnvironment.getEnvironment();
env.close();
// Load in via ORT
ONNXExternalModel<Label> onnxModel = ONNXExternalModel.createOnnxModel(new LabelFactory(), featureMapping, outputMapping, new DenseTransformer(), new LabelTransformer(), new OrtSession.SessionOptions(), onnxFile, "input");

// Generate predictions
List<Prediction<Label>> nativePredictions = model.predict(p.getB());
List<Prediction<Label>> onnxPredictions = onnxModel.predict(p.getB());

// Assert the predictions are identical
for (int i = 0; i < nativePredictions.size(); i++) {
Prediction<Label> tribuo = nativePredictions.get(i);
Prediction<Label> external = onnxPredictions.get(i);
assertEquals(tribuo.getOutput().getLabel(), external.getOutput().getLabel());
assertEquals(tribuo.getOutput().getScore(), external.getOutput().getScore(), 1e-6);
for (Map.Entry<String, Label> l : tribuo.getOutputScores().entrySet()) {
Label other = external.getOutputScores().get(l.getKey());
if (other == null) {
fail("Failed to find label " + l.getKey() + " in ORT prediction.");
} else {
assertEquals(l.getValue().getScore(), other.getScore(), 1e-6);
}
}
}
}

onnxModel.close();
onnxModel.close();
} else {
logger.warning("ORT based tests only supported on x86_64, found " + arch);
}

onnxFile.toFile().delete();
}
Expand Down
30 changes: 30 additions & 0 deletions Core/src/main/java/org/tribuo/onnx/ONNXUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@
package org.tribuo.onnx;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.ByteString;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;

/**
* Helper functions for building ONNX protos.
Expand Down Expand Up @@ -44,4 +49,29 @@ public static OnnxMl.TypeProto buildTensorTypeNode(ONNXShape shape, OnnxMl.Tenso

return builder.build();
}

/**
* Builds a TensorProto containing the array.
* <p>
* Downcasts the doubles into floats as ONNX's fp64 support is poor compared to fp32.
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
* @param context The naming context.
* @param name The base name for the proto.
* @param parameters The array to store in the proto.
* @return A TensorProto containing the array as floats.
*/
public static OnnxMl.TensorProto arrayBuilder(ONNXContext context, String name, double[] parameters) {
OnnxMl.TensorProto.Builder arrBuilder = OnnxMl.TensorProto.newBuilder();
arrBuilder.setName(context.generateUniqueName(name));
arrBuilder.addDims(parameters.length);
arrBuilder.setDataType(OnnxMl.TensorProto.DataType.FLOAT.getNumber());
ByteBuffer buffer = ByteBuffer.allocate(parameters.length*4).order(ByteOrder.LITTLE_ENDIAN);
FloatBuffer floatBuffer = buffer.asFloatBuffer();
for (int i = 0; i < parameters.length; i++) {
floatBuffer.put((float)parameters[i]);
}
floatBuffer.rewind();
arrBuilder.setRawData(ByteString.copyFrom(buffer));
return arrBuilder.build();
}

}
6 changes: 6 additions & 0 deletions Core/src/main/java/org/tribuo/onnx/package-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,11 @@
/**
* Interfaces and utilities for exporting Tribuo {@link org.tribuo.Model}s in
* <a href="https://onnx.ai">ONNX</a> format.
* <p>
* ONNX exported models use floats where Tribuo uses doubles, this is due
* to comparatively poor support for fp64 in ONNX deployment environments
* as compared to fp32. In addition fp32 executes better on the various
* accelerator backends available in
* <a href="https://onnxruntime.ai">ONNX Runtime</a>.
*/
package org.tribuo.onnx;
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
import static org.junit.jupiter.api.Assertions.fail;

public class TestSGDLinear {
private static final Logger logger = Logger.getLogger(TestSGDLinear.class.getName());

private static final LinearSGDTrainer hinge = new LinearSGDTrainer(new Hinge(),new AdaGrad(0.1,0.1),5,1000, Trainer.DEFAULT_SEED);
private static final LinearSGDTrainer sigmoid = new LinearSGDTrainer(new BinaryCrossEntropy(),new AdaGrad(0.1,0.1),5,1000, Trainer.DEFAULT_SEED);
Expand Down Expand Up @@ -120,31 +121,38 @@ public void testOnnxSerialization() throws IOException, OrtException {
outputMapping.put(l.getB(), l.getA());
}

// Load in via ORT
OrtEnvironment env = OrtEnvironment.getEnvironment();
env.close();
ONNXExternalModel<MultiLabel> onnxModel = ONNXExternalModel.createOnnxModel(new MultiLabelFactory(),featureMapping,outputMapping,new DenseTransformer(),new MultiLabelTransformer(),new OrtSession.SessionOptions(),onnxFile,"input");

// Generate predictions
List<Prediction<MultiLabel>> nativePredictions = model.predict(test);
List<Prediction<MultiLabel>> onnxPredictions = onnxModel.predict(test);

// Assert the predictions are identical
for (int i = 0; i < nativePredictions.size(); i++) {
Prediction<MultiLabel> tribuo = nativePredictions.get(i);
Prediction<MultiLabel> external = onnxPredictions.get(i);
assertEquals(tribuo.getOutput().getLabelSet(), external.getOutput().getLabelSet());
for (Map.Entry<String,MultiLabel> l : tribuo.getOutputScores().entrySet()) {
MultiLabel other = external.getOutputScores().get(l.getKey());
if (other == null) {
fail("Failed to find label " + l.getKey() + " in ORT prediction.");
} else {
assertEquals(l.getValue().getScore(),other.getScore(),1e-6);
String arch = System.getProperty("os.arch");
if (arch.equalsIgnoreCase("amd64") || arch.equalsIgnoreCase("x86_64")) {
// Initialise the OrtEnvironment to load the native library
// (as OrtSession.SessionOptions doesn't trigger the static initializer).
OrtEnvironment env = OrtEnvironment.getEnvironment();
env.close();
// Load in via ORT
ONNXExternalModel<MultiLabel> onnxModel = ONNXExternalModel.createOnnxModel(new MultiLabelFactory(),featureMapping,outputMapping,new DenseTransformer(),new MultiLabelTransformer(),new OrtSession.SessionOptions(),onnxFile,"input");

// Generate predictions
List<Prediction<MultiLabel>> nativePredictions = model.predict(test);
List<Prediction<MultiLabel>> onnxPredictions = onnxModel.predict(test);

// Assert the predictions are identical
for (int i = 0; i < nativePredictions.size(); i++) {
Prediction<MultiLabel> tribuo = nativePredictions.get(i);
Prediction<MultiLabel> external = onnxPredictions.get(i);
assertEquals(tribuo.getOutput().getLabelSet(), external.getOutput().getLabelSet());
for (Map.Entry<String,MultiLabel> l : tribuo.getOutputScores().entrySet()) {
MultiLabel other = external.getOutputScores().get(l.getKey());
if (other == null) {
fail("Failed to find label " + l.getKey() + " in ORT prediction.");
} else {
assertEquals(l.getValue().getScore(),other.getScore(),1e-6);
}
}
}
}

onnxModel.close();
onnxModel.close();
} else {
logger.warning("ORT based tests only supported on x86_64, found " + arch);
}

onnxFile.toFile().delete();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -134,11 +134,26 @@ public ImmutableOutputInfo<Regressor> constructInfoForExternalModel(Map<Regresso
// Validate inputs are dense
OutputFactory.validateMapping(mapping);

// Coalesce all the mappings into a single Regressor.
String[] names = new String[mapping.size()];
double[] values = new double[mapping.size()];
double[] variances = new double[mapping.size()];
int i = 0;
for (Map.Entry<Regressor,Integer> m : mapping.entrySet()) {
Regressor r = m.getKey();
if (r.size() != 1) {
throw new IllegalArgumentException("Expected to find a DimensionTuple, found multiple dimensions for a single integer. Found = " + r);
}
names[i] = r.getNames()[0];
values[i] = r.getValues()[0];
variances[i] = r.getVariances()[0];
i++;
}
Regressor newRegressor = new Regressor(names,values,variances);

MutableRegressionInfo info = new MutableRegressionInfo();

for (Map.Entry<Regressor,Integer> e : mapping.entrySet()) {
info.observe(e.getKey());
}
info.observe(newRegressor);

return new ImmutableRegressionInfo(info,mapping);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
import static org.junit.jupiter.api.Assertions.fail;

public class TestSGDLinear {
private static final Logger logger = Logger.getLogger(TestSGDLinear.class.getName());

private static final LinearSGDTrainer t = new LinearSGDTrainer(new SquaredLoss(), new AdaGrad(0.1,0.1),5,1000, Trainer.DEFAULT_SEED);

Expand Down Expand Up @@ -110,24 +111,31 @@ public void testOnnxSerialization() throws IOException, OrtException {
outputMapping.put(l.getB(), l.getA());
}

// Load in via ORT
OrtEnvironment env = OrtEnvironment.getEnvironment();
env.close();
ONNXExternalModel<Regressor> onnxModel = ONNXExternalModel.createOnnxModel(new RegressionFactory(),featureMapping,outputMapping,new DenseTransformer(),new RegressorTransformer(),new OrtSession.SessionOptions(),onnxFile,"input");

// Generate predictions
List<Prediction<Regressor>> nativePredictions = model.predict(p.getB());
List<Prediction<Regressor>> onnxPredictions = onnxModel.predict(p.getB());

// Assert the predictions are identical
for (int i = 0; i < nativePredictions.size(); i++) {
Prediction<Regressor> tribuo = nativePredictions.get(i);
Prediction<Regressor> external = onnxPredictions.get(i);
assertArrayEquals(tribuo.getOutput().getNames(),external.getOutput().getNames());
assertArrayEquals(tribuo.getOutput().getValues(),external.getOutput().getValues(),1e-6);
}
String arch = System.getProperty("os.arch");
if (arch.equalsIgnoreCase("amd64") || arch.equalsIgnoreCase("x86_64")) {
// Initialise the OrtEnvironment to load the native library
// (as OrtSession.SessionOptions doesn't trigger the static initializer).
OrtEnvironment env = OrtEnvironment.getEnvironment();
env.close();
// Load in via ORT
ONNXExternalModel<Regressor> onnxModel = ONNXExternalModel.createOnnxModel(new RegressionFactory(),featureMapping,outputMapping,new DenseTransformer(),new RegressorTransformer(),new OrtSession.SessionOptions(),onnxFile,"input");

// Generate predictions
List<Prediction<Regressor>> nativePredictions = model.predict(p.getB());
List<Prediction<Regressor>> onnxPredictions = onnxModel.predict(p.getB());

// Assert the predictions are identical
for (int i = 0; i < nativePredictions.size(); i++) {
Prediction<Regressor> tribuo = nativePredictions.get(i);
Prediction<Regressor> external = onnxPredictions.get(i);
assertArrayEquals(tribuo.getOutput().getNames(),external.getOutput().getNames());
assertArrayEquals(tribuo.getOutput().getValues(),external.getOutput().getValues(),1e-6);
}

onnxModel.close();
onnxModel.close();
} else {
logger.warning("ORT based tests only supported on x86_64, found " + arch);
}

onnxFile.toFile().delete();
}
Expand Down
16 changes: 11 additions & 5 deletions Regression/SLM/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@
<artifactId>tribuo-math</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>${commonsmath.version}</version>
</dependency>
<!-- test time dependencies -->
<dependency>
<groupId>${project.groupId}</groupId>
Expand All @@ -55,17 +60,18 @@
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>${project.groupId}</groupId>
<artifactId>tribuo-onnx</artifactId>
<version>${project.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.junit.jupiter</groupId>
<artifactId>junit-jupiter</artifactId>
<version>${junit.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
<version>${commonsmath.version}</version>
</dependency>
</dependencies>

<build>
Expand Down
Loading