# Model export and deployment tutorial

Tribuo works best as a library that provides training and deployment inside the JVM where the application is running. However, sometimes you need to deploy models elsewhere, either in another programming environment like Python or in a cloud service. To support these use cases, many of Tribuo's models can be exported as [ONNX](https://onnx.ai) models, a cross-platform model exchange format. ONNX is widely supported across industry, for edge devices, hardware accelerators, and cloud services. Tribuo also supports loading ONNX models and scoring them as native Tribuo models; for more information, see the external models tutorial.

This tutorial shows how to export models in ONNX format, how to recover the provenance information from Tribuo-exported ONNX models, and how to deploy an ONNX model in [OCI Data Science](https://www.oracle.com/data-science/cloud-infrastructure-data-science.html), though other cloud providers support ONNX models too. We'll export a factorization machine and examine the provenance stored in the exported file. Then we'll build an ensemble from that factorization machine and some other models and export the ensemble, before concluding by deploying the model to OCI.

## Setup

This tutorial requires ONNX Runtime to score the exported models, so by default it will only run on x86\_64 platforms. ONNX Runtime can be compiled for ARM64, but that binary is not included in the Maven Central jar that Tribuo depends on. To run the tutorial on ARM, you will need to compile ONNX Runtime from source.
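If you're unsure which architecture your JVM is running on, you can check the `os.arch` system property before ONNX Runtime fails to load its native library. The plain-Java sketch below treats `amd64` and `x86_64` as x86\_64. `amd64` is commonly reported on Linux and Windows, and `x86_64` on Intel macOS. Those values are an assumption about typical JDK builds, not something Tribuo specifies.

```java
import java.util.Locale;

/** Sketch: report whether the JVM looks like an x86_64 build (assumed os.arch values). */
final class PlatformCheck {
    // Assumption: HotSpot JDKs report "amd64" on Linux/Windows and "x86_64" on Intel macOS.
    static boolean isX86_64(String osArch) {
        String arch = osArch.toLowerCase(Locale.ROOT);
        return arch.equals("amd64") || arch.equals("x86_64");
    }

    public static void main(String[] args) {
        String arch = System.getProperty("os.arch");
        if (isX86_64(arch)) {
            System.out.println("os.arch = " + arch + ", the prebuilt ONNX Runtime binary should load");
        } else {
            System.out.println("os.arch = " + arch + ", you will likely need to build ONNX Runtime from source");
        }
    }
}
```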

We're going to use MNIST as the example dataset for this tutorial, so you'll need to download it if you haven't already.

First the training set:
`wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz`

Then the test set:
`wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz`

As usual we'll load in some jars for classification problems, along with Tribuo's ONNX Runtime and OCI interfaces.

```
%jars ./tribuo-classification-experiments-4.2.0-SNAPSHOT-jar-with-dependencies.jar
%jars ./tribuo-onnx-4.2.0-SNAPSHOT-jar-with-dependencies.jar
%jars ./tribuo-json-4.2.0-SNAPSHOT-jar-with-dependencies.jar
```

```java
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*; // Map, HashMap, List (imported by default in JShell, needed in plain Java)

import org.tribuo.*;
import org.tribuo.classification.*;
import org.tribuo.classification.ensemble.*;
import org.tribuo.classification.evaluation.*;
import org.tribuo.classification.sgd.fm.FMClassificationTrainer;
import org.tribuo.classification.sgd.linear.*;
import org.tribuo.classification.sgd.objectives.LogMulticlass;
import org.tribuo.ensemble.*;
import org.tribuo.data.csv.CSVLoader;
import org.tribuo.datasource.IDXDataSource;
import org.tribuo.evaluation.TrainTestSplitter;
import org.tribuo.interop.onnx.*;
import org.tribuo.math.optimisers.*;
import org.tribuo.onnx.*;
import org.tribuo.util.Util;
import com.oracle.labs.mlrg.olcut.provenance.ProvenanceUtil;
import com.oracle.labs.mlrg.olcut.util.Pair;

import ai.onnxruntime.*;
```

Then we'll load in MNIST.

```java
var labelFactory = new LabelFactory();
var labelEvaluator = new LabelEvaluator();
var mnistTrainSource = new IDXDataSource<>(Paths.get("train-images-idx3-ubyte.gz"),Paths.get("train-labels-idx1-ubyte.gz"),labelFactory);
var mnistTestSource = new IDXDataSource<>(Paths.get("t10k-images-idx3-ubyte.gz"),Paths.get("t10k-labels-idx1-ubyte.gz"),labelFactory);
var mnistTrain = new MutableDataset<>(mnistTrainSource);
var mnistTest = new MutableDataset<>(mnistTestSource);
System.out.println(String.format("MNIST train size = %d, number of features = %d, number of classes = %d",mnistTrain.size(),mnistTrain.getFeatureMap().size(),mnistTrain.getOutputInfo().size()));
System.out.println(String.format("MNIST test size = %d, number of features = %d, number of classes = %d",mnistTest.size(),mnistTest.getFeatureMap().size(),mnistTest.getOutputInfo().size()));
```

```
MNIST train size = 60000, number of features = 717, number of classes = 10
MNIST test size = 10000, number of features = 668, number of classes = 10
```
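MNIST images have 28 × 28 = 784 pixels, yet the two splits report 717 and 668 features. We assume this is because zero-valued pixels are stored sparsely. Under that assumption, a pixel position that is zero in every image of a split never enters that split's feature map, which would also explain why the smaller test set has fewer features. The plain-Java sketch below counts the positions that are non-zero in at least one image. It uses toy data rather than the IDX files, and it illustrates the assumption rather than Tribuo's loader.

```java
import java.util.BitSet;
import java.util.List;

/** Counts pixel positions that are non-zero in at least one flattened image. */
final class ObservedFeatures {
    static int countObserved(List<int[]> images, int numPixels) {
        BitSet seen = new BitSet(numPixels);
        for (int[] image : images) {
            for (int p = 0; p < numPixels; p++) {
                if (image[p] != 0) {
                    seen.set(p); // a sparse representation would only record this pixel
                }
            }
        }
        return seen.cardinality();
    }

    public static void main(String[] args) {
        // Three toy 2x2 "images"; pixel 3 is zero in all of them.
        List<int[]> images = List.of(
                new int[]{0, 5, 0, 0},
                new int[]{7, 0, 0, 0},
                new int[]{0, 1, 9, 0});
        System.out.println(countObserved(images, 4)); // prints 3
    }
}
```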


## Exporting a single classification model

We're going to train a multi-class [Factorization Machine](https://ieeexplore.ieee.org/document/5694074). This is a non-linear model that captures all pairwise feature interactions, approximating each interaction weight with the dot product of small per-feature embedding vectors. It's similar to logistic regression with an additional feature-feature interaction term, one per output label. In Tribuo, factorization machines are trained with stochastic gradient descent, using the same SGD algorithms Tribuo uses for other models. We'll use AdaGrad, as it's usually a good baseline.
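For a single output label, the second-order factorization machine from the linked paper scores an input `x` as `w0 + Σ_i w_i x_i + Σ_{i<j} <v_i, v_j> x_i x_j`. Here each feature `i` has a `k`-dimensional factor vector `v_i`, which corresponds to the "factor size" in the trainer below.

The pairwise sum can be computed in `O(k·d)` time rather than `O(k·d²)` using the identity `Σ_{i<j} <v_i, v_j> x_i x_j = ½ Σ_f [(Σ_i v_{i,f} x_i)² - Σ_i v_{i,f}² x_i²]`. The plain-Java sketch below implements both forms on a dense input so you can check they agree. It illustrates the model equation only; it is not Tribuo's `FMClassificationModel` code.

```java
/** Second-order factorization machine score for a single output, dense input. */
final class FmScore {
    // Naive O(k * d^2) form: adds <v_i, v_j> x_i x_j for every pair i < j.
    static double naive(double bias, double[] w, double[][] v, double[] x) {
        double score = bias;
        for (int i = 0; i < x.length; i++) {
            score += w[i] * x[i];
        }
        for (int i = 0; i < x.length; i++) {
            for (int j = i + 1; j < x.length; j++) {
                double dot = 0.0;
                for (int f = 0; f < v[i].length; f++) {
                    dot += v[i][f] * v[j][f];
                }
                score += dot * x[i] * x[j];
            }
        }
        return score;
    }

    // O(k * d) form: 0.5 * sum_f [(sum_i v_if x_i)^2 - sum_i (v_if x_i)^2].
    static double fast(double bias, double[] w, double[][] v, double[] x) {
        double score = bias;
        for (int i = 0; i < x.length; i++) {
            score += w[i] * x[i];
        }
        int k = v[0].length;
        for (int f = 0; f < k; f++) {
            double sum = 0.0;
            double sumSq = 0.0;
            for (int i = 0; i < x.length; i++) {
                double t = v[i][f] * x[i];
                sum += t;
                sumSq += t * t;
            }
            score += 0.5 * (sum * sum - sumSq);
        }
        return score;
    }

    public static void main(String[] args) {
        double[] w = {0.5, -1.0, 2.0};
        double[][] v = {{1.0, 0.0}, {0.5, 1.0}, {-1.0, 2.0}}; // d = 3 features, k = 2 factors
        double[] x = {1.0, 2.0, 3.0};
        System.out.println(naive(0.1, w, v, x) + " vs " + fast(0.1, w, v, x));
    }
}
```

Both methods give approximately 11.6 for this toy input: a linear part of 4.6 plus pairwise interactions summing to 7.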

```java
var fmLabelTrainer = new FMClassificationTrainer(new LogMulticlass(),  // Loss function
                                                 new AdaGrad(0.1,0.1), // Gradient optimiser
                                                 5,                    // Number of training epochs
                                                 30000,                // Logging interval
                                                 Trainer.DEFAULT_SEED, // RNG seed
                                                 6,                    // Factor size
                                                 0.1                   // Factor initialisation variance
                                                 );
```
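The optimiser above is constructed as `new AdaGrad(0.1,0.1)`. As far as we know, Tribuo reads these arguments as the initial learning rate and an epsilon term, but check the `AdaGrad` Javadoc before relying on that. AdaGrad scales each parameter's step by the inverse square root of that parameter's accumulated squared gradients, so frequently updated parameters take smaller steps.

Below is a textbook AdaGrad update on a dense parameter array in plain Java. Implementations differ in where epsilon enters the denominator, and this sketch is not Tribuo's code.

```java
/** Textbook AdaGrad: G_i += g_i^2, then w_i -= lr * g_i / (eps + sqrt(G_i)). */
final class AdaGradSketch {
    private final double learningRate;
    private final double epsilon;
    private final double[] gradSquares; // per-parameter running sum of squared gradients

    AdaGradSketch(double learningRate, double epsilon, int numParams) {
        this.learningRate = learningRate;
        this.epsilon = epsilon;
        this.gradSquares = new double[numParams];
    }

    /** Applies one update to the weights in place using the supplied gradient. */
    void step(double[] weights, double[] gradient) {
        for (int i = 0; i < weights.length; i++) {
            gradSquares[i] += gradient[i] * gradient[i];
            weights[i] -= learningRate * gradient[i] / (epsilon + Math.sqrt(gradSquares[i]));
        }
    }

    public static void main(String[] args) {
        var opt = new AdaGradSketch(0.1, 0.1, 2);
        double[] w = {1.0, 1.0};
        // Apply the same gradient twice: the second step on w[0] is smaller than the first.
        opt.step(w, new double[]{2.0, 0.0});
        double afterFirst = w[0];
        opt.step(w, new double[]{2.0, 0.0});
        System.out.println("step 1 = " + (1.0 - afterFirst) + ", step 2 = " + (afterFirst - w[0]));
    }
}
```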

After defining the model we train it as usual. Factorization machines take a little longer to train than logistic regression does, but not excessively so.

```java
var fmStartTime = System.currentTimeMillis();
var fmMNIST = fmLabelTrainer.train(mnistTrain);
var fmEndTime = System.currentTimeMillis();
System.out.println("Training factorization machine took " + Util.formatDuration(fmStartTime,fmEndTime));
```

```
Training factorization machine took (00:00:11:687)
```


Then we evaluate it using Tribuo's built-in evaluation system.

```java
fmStartTime = System.currentTimeMillis();
var mnistFMEval = labelEvaluator.evaluate(fmMNIST,mnistTest);
fmEndTime = System.currentTimeMillis();
System.out.println("Scoring factorization machine took " + Util.formatDuration(fmStartTime,fmEndTime));
System.out.println(mnistFMEval.toString());
System.out.println(mnistFMEval.getConfusionMatrix().toString());
```

```
Scoring factorization machine took (00:00:00:431)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         959          21          31       0.979       0.969       0.974
1                           1,135       1,120          15          22       0.987       0.981       0.984
2                           1,032         976          56          57       0.946       0.945       0.945
3                           1,010         952          58          39       0.943       0.961       0.952
4                             982         952          30          49       0.969       0.951       0.960
5                             892         857          35          63       0.961       0.932       0.946
6                             958         920          38          30       0.960       0.968       0.964
7                           1,028         969          59          36       0.943       0.964       0.
```

We get about 95% accuracy on MNIST, which is pretty good for a fairly simple model. Now let's export it to ONNX, then load it back in via Tribuo's ONNX Runtime interface and compare the performance. We'll use this model again in the reproducibility tutorial, so we'll save it to disk in the tutorials folder.

Tribuo `Model`s which support ONNX export implement the `ONNXExportable` interface which defines methods for constructing an ONNX protobuf and saving it to disk.

```java
var fmMNISTPath = Paths.get(".","fm-mnist.onnx");
fmMNIST.saveONNXModel("org.tribuo.tutorials.onnxexport.fm", // namespace for the model
                      0,                                    // model version number
                      fmMNISTPath                           // path to save the model
                      );
```

To load an ONNX model we need to define the mapping between Tribuo's feature names and the indices that the ONNX model understands. Fortunately for models exported from Tribuo we already have that information, as it is stored in the feature and output maps. We'll extract it into the general form that `ONNXExternalModel` expects.

```java
// Feature name -> feature index, taken from the model's feature map
Map<String, Integer> mnistFeatureMap = new HashMap<>();
for (VariableInfo f : fmMNIST.getFeatureIDMap()){
    VariableIDInfo id = (VariableIDInfo) f;
    mnistFeatureMap.put(id.getName(),id.getID());
}
// Label -> output index, inverting the (index, label) pairs from the output map
Map<Label, Integer> mnistOutputMap = new HashMap<>();
for (Pair<Integer,Label> l : fmMNIST.getOutputIDInfo()) {
    mnistOutputMap.put(l.getB(), l.getA());
}
```

Now we'll define a test function that compares two sets of predictions. ONNX Runtime computes in single precision while Tribuo uses double precision, so the prediction scores are generally not bitwise identical and must be compared within a tolerance.

```java
public boolean checkPredictions(List<Prediction<Label>> nativePredictions, List<Prediction<Label>> onnxPredictions, double delta) {
    // Guard against mismatched lists before indexing into both
    if (nativePredictions.size() != onnxPredictions.size()) {
        System.out.println("Prediction lists differ in length - "
                + nativePredictions.size() + " and " + onnxPredictions.size());
        return false;
    }
    for (int i = 0; i < nativePredictions.size(); i++) {
        Prediction<Label> tribuo = nativePredictions.get(i);
        Prediction<Label> external = onnxPredictions.get(i);
        // Check the predicted label
        if (!tribuo.getOutput().getLabel().equals(external.getOutput().getLabel())) {
            System.out.println("At index " + i + " predictions are not equal - "
                    + tribuo.getOutput().getLabel() + " and "
                    + external.getOutput().getLabel());
            return false;
        }
        // Check the maximum score
        if (Math.abs(tribuo.getOutput().getScore() - external.getOutput().getScore()) > delta) {
            System.out.println("At index " + i + " predictions are not equal - "
                    + tribuo.getOutput() + " and "
                    + external.getOutput());
            return false;
        }
        // Check the score distribution
        for (Map.Entry<String, Label> l : tribuo.getOutputScores().entrySet()) {
            Label other = external.getOutputScores().get(l.getKey());
            if (other == null) {
                System.out.println("At index " + i + " failed to find label " + l.getKey() + " in ORT prediction.");
                return false;
            } else {
                if (Math.abs(l.getValue().getScore() - other.getScore()) > delta) {
                    System.out.println("At index " + i + " predictions are not equal - "
                            + tribuo.getOutputScores() + " and "
                            + external.getOutputScores());
                    return false;
                }
            }
        }
    }
    return true;
}
```
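To see why a tolerance is needed at all, the plain-Java sketch below computes the same softmax over a few made-up logits twice, once in double precision and once in single precision. The single-precision pass mimics, as an assumption, a float32 computation inside ONNX Runtime. The sketch reports the largest absolute difference, which will typically be small but non-zero and on the order of float rounding error, well under the 1e-5 delta used in this tutorial.

```java
/** Compares a softmax computed in double precision with one computed in float precision. */
final class PrecisionDemo {
    static double[] softmaxDouble(double[] logits) {
        double max = Double.NEGATIVE_INFINITY;
        for (double l : logits) {
            max = Math.max(max, l); // subtract the max for numerical stability
        }
        double[] out = new double[logits.length];
        double sum = 0.0;
        for (int i = 0; i < logits.length; i++) {
            out[i] = Math.exp(logits[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    static float[] softmaxFloat(float[] logits) {
        float max = Float.NEGATIVE_INFINITY;
        for (float l : logits) {
            max = Math.max(max, l);
        }
        float[] out = new float[logits.length];
        float sum = 0.0f;
        for (int i = 0; i < logits.length; i++) {
            out[i] = (float) Math.exp(logits[i] - max); // each term rounded to float
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) {
            out[i] /= sum;
        }
        return out;
    }

    /** Largest absolute difference between the double and float softmax outputs. */
    static double maxAbsDiff(double[] logits) {
        float[] asFloat = new float[logits.length];
        for (int i = 0; i < logits.length; i++) {
            asFloat[i] = (float) logits[i];
        }
        double[] d = softmaxDouble(logits);
        float[] f = softmaxFloat(asFloat);
        double worst = 0.0;
        for (int i = 0; i < d.length; i++) {
            worst = Math.max(worst, Math.abs(d[i] - f[i]));
        }
        return worst;
    }

    public static void main(String[] args) {
        double[] logits = {2.3, -1.7, 0.45, 5.1, 3.3};
        System.out.println("max |double - float| = " + maxAbsDiff(logits));
    }
}
```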

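Next we'll construct the `ONNXExternalModel`, loading our freshly created ONNX model with the feature and output mappings we built earlier. First we create a `SessionOptions`, which controls how ONNX Runtime performs inference. By default it uses a single thread on one CPU. By setting values in the options object before building the external model, we can make it run on multiple threads or use GPUs and other accelerator hardware supported by ONNX Runtime. The `DenseTransformer` converts Tribuo examples into the dense input the model expects, the `LabelTransformer` converts the model's output back into Tribuo `Label`s, and `"input"` is the name of the model's input.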

```java
var ortEnv = OrtEnvironment.getEnvironment();
var sessionOpts = new OrtSession.SessionOptions();
var denseTransformer = new DenseTransformer();
var labelTransformer = new LabelTransformer();
ONNXExternalModel<Label> onnxFM = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,
                    denseTransformer, labelTransformer, sessionOpts, fmMNISTPath, "input");
```
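Conceptually, densifying an example means using the feature-name-to-index map built earlier to place each feature value at its column in a fixed-width vector. The plain-Java sketch below shows that idea, with a `Map<String, Double>` standing in for a Tribuo `Example`. Each known feature is written into a zero-initialised `float[]` at its index, and names missing from the map are ignored. This is our own illustration, not Tribuo's `DenseTransformer` implementation. The helper name `toDense` and the feature names are hypothetical.

```java
import java.util.Arrays;
import java.util.Map;

final class DenseSketch {
    /** Hypothetical helper: writes named feature values into a dense float vector. */
    static float[] toDense(Map<String, Double> example, Map<String, Integer> featureIndex) {
        // Assumes indices are contiguous, 0 .. featureIndex.size() - 1; absent features stay 0.0f
        float[] dense = new float[featureIndex.size()];
        for (Map.Entry<String, Double> e : example.entrySet()) {
            Integer idx = featureIndex.get(e.getKey());
            if (idx != null) { // features unseen at training time are dropped
                dense[idx] = e.getValue().floatValue();
            }
        }
        return dense;
    }

    public static void main(String[] args) {
        Map<String, Integer> featureIndex = Map.of("px-0", 0, "px-1", 1, "px-2", 2);
        Map<String, Double> example = Map.of("px-2", 0.5, "px-9", 1.0);
        System.out.println(Arrays.toString(toDense(example, featureIndex))); // prints [0.0, 0.0, 0.5]
    }
}
```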

An `ONNXExternalModel` is a Tribuo model so we can use the same evaluation infrastructure.

```java
var onnxStartTime = System.currentTimeMillis();
var mnistONNXEval = labelEvaluator.evaluate(onnxFM,mnistTest);
var onnxEndTime = System.currentTimeMillis();
System.out.println("Scoring ONNX factorization machine took " + Util.formatDuration(onnxStartTime,onnxEndTime));
System.out.println(mnistONNXEval.toString());
System.out.println(mnistONNXEval.getConfusionMatrix().toString());
```

```
Scoring ONNX factorization machine took (00:00:00:822)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         959          21          31       0.979       0.969       0.974
1                           1,135       1,120          15          22       0.987       0.981       0.984
2                           1,032         976          56          57       0.946       0.945       0.945
3                           1,010         952          58          39       0.943       0.961       0.952
4                             982         952          30          49       0.969       0.951       0.960
5                             892         857          35          63       0.961       0.932       0.946
6                             958         920          38          30       0.960       0.968       0.964
7                           1,028         969          59          36       0.943       0.964    
```

The two models produce the same evaluation, but their probability values could still differ slightly, so let's check them with our more precise comparison function. `checkPredictions` logs the first divergence it finds and returns `false` if the predictions differ, `true` otherwise. We'll use a delta of 1e-5 and treat differences below that threshold as irrelevant.

```java
System.out.println("Predictions are equal - " + 
                    checkPredictions(mnistFMEval.getPredictions(), mnistONNXEval.getPredictions(), 1e-5));
```

```
Predictions are equal - true
```


An important part of a Tribuo model is its provenance. We don't want to lose that information when exporting models to ONNX, so we encode the provenance in the ONNX protobuf. It uses OLCUT's marshalled provenance format, and the protobuf definitions are available in OLCUT, so the provenance can be parsed by other systems. As a result, when loading a Tribuo-exported ONNX model, the `ONNXExternalModel` class has two provenance objects: one for the `ONNXExternalModel` itself, and one for the original `Model` object.

Let's examine both of these provenances. First the one for the `ONNXExternalModel`:

```java
System.out.println("ONNXExternalModel provenance:\n" + ProvenanceUtil.formattedProvenanceString(onnxFM.getProvenance()));
```

```
ONNXExternalModel provenance:
ONNXExternalModel(
	class-name = org.tribuo.interop.onnx.ONNXExternalModel
	dataset = Dataset(
			class-name = org.tribuo.Dataset
			datasource = DataSource(
					description = unknown-external-data
					outputFactory = LabelFactory(
							class-name = org.tribuo.classification.LabelFactory
						)
					datasource-creation-time = 2021-10-26T17:51:37.741785-04:00
				)
			transformations = List[]
			is-sequence = false
			is-dense = false
			num-examples = -1
			num-features = 717
			num-outputs = 10
			tribuo-version = 4.2.0-SNAPSHOT
		)
	trainer = Trainer(
			class-name = org.tribuo.Trainer
			fileModifiedTime = 2021-10-26T17:51:36.243-04:00
			modelHash = 8DD82B31BD7CFC1C520942590E173AED07AF33C97C32021EE94738FA9FF4CC89
			location = file:/Users/apocock/Development/Tribuo/tutorials/./fm-mnist.onnx
		)
	trained-at = 2021-10-26T17:51:37.739663-04:00
	instance-values = Map{
		model-domain=org.tribuo.tutorials.onnxexport.fm
		model-metadata-TRIBUO_PROVENANCE
```

This has the location the ONNX file was loaded from, a hash of the file, and timestamps for both the ONNX file and the model object wrapping it.
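The `modelHash` entry is 64 hexadecimal digits, which is consistent with a SHA-256 digest. We haven't confirmed the exact algorithm OLCUT uses here. If it is a SHA-256 hash of the raw file bytes, you could check a copy of `fm-mnist.onnx` against the recorded value with plain JDK code like the following. `sha256Hex` is a helper name of our own.

```java
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

final class FileHash {
    /** Returns the upper-case hex SHA-256 digest of the given bytes. */
    static String sha256Hex(byte[] data) {
        try {
            byte[] digest = MessageDigest.getInstance("SHA-256").digest(data);
            StringBuilder sb = new StringBuilder(digest.length * 2);
            for (byte b : digest) {
                sb.append(String.format("%02X", b & 0xff)); // two hex digits per byte
            }
            return sb.toString();
        } catch (NoSuchAlgorithmException e) {
            // Every Java platform is required to provide SHA-256, so this should not happen
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) throws IOException {
        Path onnxFile = Paths.get(args.length > 0 ? args[0] : "fm-mnist.onnx");
        // Compare this value with the modelHash shown in the provenance
        System.out.println(sha256Hex(Files.readAllBytes(onnxFile)));
    }
}
```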

Now let's look at the original Model provenance:

```java
System.out.println("ONNX file provenance:\n" + ProvenanceUtil.formattedProvenanceString(onnxFM.getTribuoProvenance().get()));
```

```
ONNX file provenance:
FMClassificationModel(
	class-name = org.tribuo.classification.sgd.fm.FMClassificationModel
	dataset = MutableDataset(
			class-name = org.tribuo.MutableDataset
			datasource = IDXDataSource(
					class-name = org.tribuo.datasource.IDXDataSource
					outputFactory = LabelFactory(
							class-name = org.tribuo.classification.LabelFactory
						)
					outputPath = /Users/apocock/Development/Tribuo/tutorials/train-labels-idx1-ubyte.gz
					featuresPath = /Users/apocock/Development/Tribuo/tutorials/train-images-idx3-ubyte.gz
					features-file-modified-time = 2000-07-21T14:20:24-04:00
					output-resource-hash = 3552534A0A558BBED6AED32B30C495CCA23D567EC52CAC8BE1A0730E8010255C
					datasource-creation-time = 2021-10-26T17:51:22.314557-04:00
					output-file-modified-time = 2000-07-21T14:20:27-04:00
					idx-feature-type = UBYTE
					features-resource-hash = 440FCABF73CC546FA21475E81EA370265605F56BE210A4024D2CA8F203523609
					host-short-name = DataSource
				)
			tran
```

We can also check that the provenance extracted from the ONNX file is the same as the provenance in the original model object.

```java
var equality = fmMNIST.getProvenance().equals(onnxFM.getTribuoProvenance().get()) ? "equal" : "not equal";
System.out.println("Provenances are " + equality);
```

```
Provenances are equal
```


## Exporting an ensemble

Tribuo allows the creation of arbitrary ensembles, which are usually powerful models worth deploying. We're going to make a three-member voting ensemble out of our factorization machine and two other models, and export that to ONNX as well. The other models are a logistic regression and a smaller factorization machine, but we could use any classification model supported by Tribuo, including another ensemble. As this is a small ensemble of similar models, our goal is to demonstrate the functionality rather than to improve much on MNIST performance.

```java
var lrTrainer = new LogisticRegressionTrainer();
var smallFMTrainer = new FMClassificationTrainer(new LogMulticlass(),  // Loss function
                                                 new AdaGrad(0.1,0.1), // Gradient optimiser
                                                 2,                    // Number of training epochs
                                                 30000,                // Logging interval
                                                 42L,                  // RNG seed
                                                 3,                    // Factor size
                                                 0.1                   // Factor initialisation variance
                                                 );
var lrModel = lrTrainer.train(mnistTrain);
var smallFMModel = smallFMTrainer.train(mnistTrain);
```

Tribuo's `WeightedEnsembleModel` class allows the creation of arbitrary ensembles, with or without voting weights. We're going to create an unweighted ensemble of our three models using the standard `VotingCombiner`, which takes a majority vote over the labels predicted by the three models, with ties broken by the first label.

```java
var ensemble = WeightedEnsembleModel.createEnsembleFromExistingModels("ensemble", // Model name
                                           List.of(fmMNIST,lrModel,smallFMModel), // Ensemble members
                                           new VotingCombiner());                 // Combination operator
```
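Conceptually, an unweighted voting combiner gives each ensemble member one vote for its predicted label and returns the label with the most votes. The plain-Java sketch below does that over `String` labels and reports each label's share of the votes. It breaks ties in favour of whichever tied label comes first in a fixed `labelOrder` list, which is our reading of "ties broken by the first label". Tribuo's `VotingCombiner` may define that order differently, for example by label ID, so treat this as an illustration rather than its implementation.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

final class MajorityVote {
    /** Each label's share of the member votes, iterated in labelOrder order. */
    static Map<String, Double> voteShares(List<String> memberPredictions, List<String> labelOrder) {
        Map<String, Double> shares = new LinkedHashMap<>();
        for (String label : labelOrder) {
            shares.put(label, 0.0);
        }
        for (String predicted : memberPredictions) {
            shares.merge(predicted, 1.0 / memberPredictions.size(), Double::sum);
        }
        return shares;
    }

    /** Majority label; on a tie, the label appearing first in labelOrder wins. */
    static String vote(List<String> memberPredictions, List<String> labelOrder) {
        String best = null;
        double bestShare = -1.0;
        for (Map.Entry<String, Double> e : voteShares(memberPredictions, labelOrder).entrySet()) {
            if (e.getValue() > bestShare) { // strict '>' keeps the earlier label on ties
                best = e.getKey();
                bestShare = e.getValue();
            }
        }
        return best;
    }

    public static void main(String[] args) {
        List<String> labels = List.of("0", "1", "2", "3", "4", "5", "6", "7", "8", "9");
        System.out.println(vote(List.of("7", "1", "7"), labels)); // prints 7 (two of three votes)
        System.out.println(vote(List.of("3", "5", "8"), labels)); // prints 3 (three-way tie)
    }
}
```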

```java
var ensembleStartTime = System.currentTimeMillis();
var ensembleEval = labelEvaluator.evaluate(ensemble,mnistTest);
var ensembleEndTime = System.currentTimeMillis();
System.out.println("Scoring ensemble took " + Util.formatDuration(ensembleStartTime,ensembleEndTime));
System.out.println(ensembleEval.toString());
System.out.println(ensembleEval.getConfusionMatrix().toString());
```

```
Scoring ensemble took (00:00:00:603)
Class                           n          tp          fn          fp      recall        prec          f1
0                             980         965          15          43       0.985       0.957       0.971
1                           1,135       1,119          16          34       0.986       0.971       0.978
2                           1,032         979          53          86       0.949       0.919       0.934
3                           1,010         926          84          38       0.917       0.961       0.938
4                             982         937          45          49       0.954       0.950       0.952
5                             892         837          55          49       0.938       0.945       0.942
6                             958         922          36          32       0.962       0.966       0.964
7                           1,028         978          50          52       0.951       0.950       0.950
8        
```

As before, we use the `saveONNXModel` method from the `ONNXExportable` interface to write out the model. Note that if any of the ensemble members isn't `ONNXExportable`, this call will throw a runtime exception.

```java
var ensemblePath = Paths.get(".","ensemble-mnist.onnx");
ensemble.saveONNXModel("org.tribuo.tutorials.onnxexport.ensemble", // namespace for the model
                       0,                                           // model version number
                       ensemblePath                                 // path to save the model
                       );
```
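If you'd rather find a non-exportable member before calling `saveONNXModel`, you can test each member against the export interface first. The sketch below is self-contained plain Java. `Member`, `Exportable`, `ExportableMember` and `OpaqueMember` are hypothetical stand-ins for Tribuo models and `ONNXExportable`, not Tribuo types, so only the `instanceof` filtering pattern carries over. With a real ensemble you would iterate over its members instead. We believe `EnsembleModel` exposes them via `getModels()`, but check the Javadoc.

```java
import java.util.List;
import java.util.stream.Collectors;

/** Pre-flight export check; every type here is a hypothetical stand-in, NOT a Tribuo class. */
final class ExportCheck {
    interface Member { String name(); }   // stand-in for a model
    interface Exportable {}               // stand-in for a capability such as ONNXExportable

    static class SimpleMember implements Member {
        private final String name;
        SimpleMember(String name) { this.name = name; }
        @Override public String name() { return name; }
    }

    static final class ExportableMember extends SimpleMember implements Exportable {
        ExportableMember(String name) { super(name); }
    }

    static final class OpaqueMember extends SimpleMember {
        OpaqueMember(String name) { super(name); }
    }

    /** Names of the members that do not implement the export interface. */
    static List<String> nonExportable(List<? extends Member> members) {
        return members.stream()
                .filter(m -> !(m instanceof Exportable))
                .map(Member::name)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Member> members = List.of(new ExportableMember("fm"),
                                       new OpaqueMember("custom"),
                                       new ExportableMember("lr"));
        List<String> bad = nonExportable(members);
        // prints "cannot export: [custom]"
        System.out.println(bad.isEmpty() ? "all members exportable" : "cannot export: " + bad);
    }
}
```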

We can load this model into `ONNXExternalModel` as well:

```java
var onnxEnsemble = ONNXExternalModel.createOnnxModel(labelFactory, mnistFeatureMap, mnistOutputMap,
                    denseTransformer, labelTransformer, sessionOpts, ensemblePath, "input");
onnxStartTime = System.currentTimeMillis();
var mnistONNXEnsembleEval = labelEvaluator.evaluate(onnxEnsemble,mnistTest);
onnxEndTime = System.currentTimeMillis();
System.out.println("Scoring ONNX ensemble took " + Util.formatDuration(onnxStartTime,onnxEndTime));
System.out.println("Predictions are equal - " + 
                    checkPredictions(ensembleEval.getPredictions(), mnistONNXEnsembleEval.getPredictions(), 1e-5));
```

```
Scoring ONNX ensemble took (00:00:00:994)
Predictions are equal - true
```


## Deploying the model

This portion of the tutorial describes how to deploy the ONNX model on OCI Data Science using its model deployment service. ONNX models can also be deployed on many other machine learning cloud services, or through a functions-as-a-service offering using something like ONNX Runtime.