Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds Tribuo provenance as a metadata field to exported ONNX models #182

Merged
merged 8 commits into from
Oct 18, 2021
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.tribuo.interop.onnx.LabelTransformer;
import org.tribuo.interop.onnx.ONNXExternalModel;
import org.tribuo.math.optimisers.AdaGrad;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.test.Helpers;

import java.io.IOException;
Expand All @@ -54,12 +55,15 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

public class TestFMClassification {
Expand Down Expand Up @@ -176,6 +180,14 @@ public void testOnnxSerialization() throws IOException, OrtException {
}
}

// Check that the provenance can be extracted and is the same
ModelProvenance modelProv = model.getProvenance();
Optional<ModelProvenance> optProv = onnxModel.getTribuoProvenance();
assertTrue(optProv.isPresent());
ModelProvenance onnxProv = optProv.get();
assertNotSame(onnxProv, modelProv);
assertEquals(modelProv,onnxProv);

onnxModel.close();
} else {
logger.warning("ORT based tests only supported on x86_64, found " + arch);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.test.Helpers;

import java.io.IOException;
Expand All @@ -55,11 +56,15 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

public class TestSGDLinear {
Expand All @@ -82,10 +87,10 @@ public static LinearSGDModel testSGDLinear(Pair<Dataset<Label>,Dataset<Label>> p
LabelEvaluation evaluation = e.evaluate(m,p.getB());
Map<String, List<Pair<String,Double>>> features = m.getTopFeatures(3);
Assertions.assertNotNull(features);
Assertions.assertFalse(features.isEmpty());
assertFalse(features.isEmpty());
features = m.getTopFeatures(-1);
Assertions.assertNotNull(features);
Assertions.assertFalse(features.isEmpty());
assertFalse(features.isEmpty());
return m;
}

Expand Down Expand Up @@ -152,6 +157,14 @@ public void testOnnxSerialization() throws IOException, OrtException {
}
}

// Check that the provenance can be extracted and is the same
ModelProvenance modelProv = model.getProvenance();
Optional<ModelProvenance> optProv = onnxModel.getTribuoProvenance();
assertTrue(optProv.isPresent());
ModelProvenance onnxProv = optProv.get();
assertNotSame(onnxProv, modelProv);
assertEquals(modelProv,onnxProv);

onnxModel.close();
} else {
logger.warning("ORT based tests only supported on x86_64, found " + arch);
Expand Down Expand Up @@ -209,10 +222,10 @@ public void testSerializedModel(String resourceName) throws IOException, ClassNo
LabelEvaluation evaluation = e.evaluate(m,LabelledDataGenerator.denseTrainTest().getB());
Map<String, List<Pair<String,Double>>> features = m.getTopFeatures(3);
Assertions.assertNotNull(features);
Assertions.assertFalse(features.isEmpty());
assertFalse(features.isEmpty());
features = m.getTopFeatures(-1);
Assertions.assertNotNull(features);
Assertions.assertFalse(features.isEmpty());
assertFalse(features.isEmpty());
} else {
fail("Invalid model type found, expected Label");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.provenance.ModelProvenance;

Expand Down Expand Up @@ -204,6 +205,14 @@ protected OnnxMl.ModelProto innerExportONNXModel(OnnxMl.GraphProto graph, String
builder.setDocString(toString());
builder.addOpsetImport(ONNXOperators.getOpsetProto());
builder.setIrVersion(6);

// Extract provenance and store in metadata
OnnxMl.StringStringEntryProto.Builder metaBuilder = OnnxMl.StringStringEntryProto.newBuilder();
metaBuilder.setKey(ONNXExportable.PROVENANCE_METADATA_FIELD);
String serializedProvenance = ONNXExportable.SERIALIZER.marshalAndSerialize(getProvenance());
metaBuilder.setValue(serializedProvenance);
builder.addMetadataProps(metaBuilder.build());

return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import org.tribuo.math.LinearParameters;
import org.tribuo.math.la.DenseMatrix;
import org.tribuo.onnx.ONNXContext;
import org.tribuo.onnx.ONNXExportable;
import org.tribuo.onnx.ONNXOperators;
import org.tribuo.provenance.ModelProvenance;

Expand Down Expand Up @@ -179,6 +180,14 @@ protected OnnxMl.ModelProto innerExportONNXModel(OnnxMl.GraphProto graph, String
builder.setDocString(toString());
builder.addOpsetImport(ONNXOperators.getOpsetProto());
builder.setIrVersion(6);

// Extract provenance and store in metadata
OnnxMl.StringStringEntryProto.Builder metaBuilder = OnnxMl.StringStringEntryProto.newBuilder();
metaBuilder.setKey(ONNXExportable.PROVENANCE_METADATA_FIELD);
String serializedProvenance = ONNXExportable.SERIALIZER.marshalAndSerialize(getProvenance());
metaBuilder.setValue(serializedProvenance);
builder.addMetadataProps(metaBuilder.build());

return builder.build();
}

Expand Down
4 changes: 4 additions & 0 deletions Core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-core</artifactId>
</dependency>
<dependency>
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-config-protobuf</artifactId>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
Expand Down
21 changes: 20 additions & 1 deletion Core/src/main/java/org/tribuo/Model.java
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public Model<T> copy() {
}

/**
* Copies a model, replacing it's provenance and name with the supplied values.
* Copies a model, replacing its provenance and name with the supplied values.
* <p>
* Used to provide the provenance removal functionality.
* @param newName The new name.
Expand All @@ -297,5 +297,24 @@ public String toString() {
return provenanceOutput;
}
}

/**
* Casts the model to the specified output type, assuming it is valid.
* <p>
* If it's not valid, throws {@link ClassCastException}.
* @param inputModel The model to cast.
* @param outputType The output type to cast to.
* @param <T> The output type.
* @return The model cast to the correct value.
*/
public static <T extends Output<T>> Model<T> castModel(Model<?> inputModel, Class<T> outputType) {
if (inputModel.validate(outputType)) {
@SuppressWarnings("unchecked") // guarded by validate
Model<T> castedModel = (Model<T>) inputModel;
return castedModel;
} else {
throw new ClassCastException("Attempted to cast model to " + outputType.getName() + " which is not valid for model " + inputModel.toString());
}
}

}
23 changes: 23 additions & 0 deletions Core/src/main/java/org/tribuo/onnx/ONNXExportable.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
package org.tribuo.onnx;

import ai.onnx.proto.OnnxMl;
import com.oracle.labs.mlrg.olcut.config.protobuf.ProtoProvenanceSerialization;
import com.oracle.labs.mlrg.olcut.provenance.io.ProvenanceSerialization;
import org.tribuo.provenance.ModelProvenance;

import java.io.BufferedOutputStream;
import java.io.FileOutputStream;
Expand All @@ -29,6 +32,17 @@
*/
public interface ONNXExportable {

/**
* The provenance serializer.
*/
public static final ProvenanceSerialization SERIALIZER = new ProtoProvenanceSerialization(true);

/**
* The name of the ONNX metadata field where the provenance information is stored
* in exported models.
*/
public static final String PROVENANCE_METADATA_FIELD = "TRIBUO_PROVENANCE";

/**
* Exports this {@link org.tribuo.Model} as an ONNX protobuf.
* @param domain A reverse-DNS name to namespace the model (e.g., org.tribuo.classification.sgd.linear).
Expand Down Expand Up @@ -74,4 +88,13 @@ default public void saveONNXModel(String domain, long modelVersion, Path outputP
}
}

/**
* Serializes the model provenance to a String.
* @param provenance The provenance to serialize.
* @return The serialized form of the ModelProvenance.
*/
default public String serializeProvenance(ModelProvenance provenance) {
return SERIALIZER.marshalAndSerialize(provenance);
}

}
5 changes: 5 additions & 0 deletions Interop/ONNX/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@
<artifactId>tribuo-util-tokenization</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>com.oracle.labs.olcut</groupId>
<artifactId>olcut-config-protobuf</artifactId>
<version>${olcut.version}</version>
</dependency>
<dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
Expand Down
Loading