Adding SLM protobuf serialization #269

Merged
merged 1 commit on Sep 13, 2022
@@ -17,6 +17,8 @@
package org.tribuo.regression.slm;

import ai.onnx.proto.OnnxMl;
import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tribuo.Example;
import org.tribuo.Excuse;
@@ -26,15 +28,22 @@
import org.tribuo.ONNXExportable;
import org.tribuo.Prediction;
import org.tribuo.VariableInfo;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SparseVector;
import org.tribuo.math.la.Tensor;
import org.tribuo.math.la.VectorTuple;
import org.tribuo.math.protos.TensorProto;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.provenance.TrainerProvenance;
import org.tribuo.regression.ImmutableRegressionInfo;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.Regressor.DimensionTuple;
import org.tribuo.regression.impl.SkeletalIndependentRegressionSparseModel;
import org.tribuo.regression.slm.protos.SparseLinearModelProto;
import org.tribuo.util.Util;
import org.tribuo.util.onnx.ONNXContext;
import org.tribuo.util.onnx.ONNXNode;
import org.tribuo.util.onnx.ONNXOperators;
@@ -54,6 +63,7 @@
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.logging.Logger;
import java.util.stream.Collectors;

/**
* The inference time version of a sparse linear regression model.
@@ -64,6 +74,11 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel
private static final long serialVersionUID = 3L;
private static final Logger logger = Logger.getLogger(SparseLinearModel.class.getName());

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private SparseVector[] weights;
private final DenseVector featureMeans;
/**
@@ -93,6 +108,81 @@ public class SparseLinearModel extends SkeletalIndependentRegressionSparseModel
this.enet41MappingFix = true;
}

/**
* Deserialization factory.
* @param version The serialized object version.
* @param className The class name.
* @param message The serialized data.
* @return The deserialized model.
* @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}.
*/
public static SparseLinearModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
if (version < 0 || version > CURRENT_VERSION) {
throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
}
SparseLinearModelProto proto = message.unpack(SparseLinearModelProto.class);

ModelDataCarrier<?> carrier = ModelDataCarrier.deserialize(proto.getMetadata());
if (!carrier.outputDomain().getOutput(0).getClass().equals(Regressor.class)) {
throw new IllegalStateException("Invalid protobuf, output domain is not a regression domain, found " + carrier.outputDomain().getClass());
}
@SuppressWarnings("unchecked") // guarded by getClass
ImmutableOutputInfo<Regressor> outputDomain = (ImmutableOutputInfo<Regressor>) carrier.outputDomain();

String[] dimensions = new String[proto.getDimensionsCount()];
if (dimensions.length != outputDomain.size()) {
throw new IllegalStateException("Invalid protobuf, found insufficient dimension names, expected " + outputDomain.size() + ", found " + dimensions.length);
}
for (int i = 0; i < dimensions.length; i++) {
dimensions[i] = proto.getDimensions(i);
}

SparseVector[] weights = new SparseVector[outputDomain.size()];
if (weights.length != proto.getWeightsCount()) {
throw new IllegalStateException("Invalid protobuf, expected same weight dimension as output domain size, found " + proto.getWeightsCount() + " weights and " + outputDomain.size() + " output dimensions");
}
int featureSize = proto.getBias() ? carrier.featureDomain().size() + 1 : carrier.featureDomain().size();
for (int i = 0; i < weights.length; i++) {
Tensor deser = Tensor.deserialize(proto.getWeights(i));
if (deser instanceof SparseVector) {
SparseVector v = (SparseVector) deser;
if (v.size() == featureSize) {
weights[i] = v;
} else {
throw new IllegalStateException("Invalid protobuf, weights size and feature domain do not match, expected " + featureSize + ", found " + v.size());
}
} else {
throw new IllegalStateException("Invalid protobuf, expected a SparseVector, found " + deser.getClass());
}
}

Tensor featureMeansTensor = Tensor.deserialize(proto.getFeatureMeans());
if (!(featureMeansTensor instanceof DenseVector)) {
throw new IllegalStateException("Invalid protobuf, feature means must be a dense vector, found " + featureMeansTensor.getClass());
}
DenseVector featureMeans = (DenseVector) featureMeansTensor;
if (featureMeans.size() != featureSize) {
throw new IllegalStateException("Invalid protobuf, feature means not the right size, expected " + featureSize + ", found " + featureMeans.size());
}
Tensor featureNormsTensor = Tensor.deserialize(proto.getFeatureNorms());
if (!(featureNormsTensor instanceof DenseVector)) {
throw new IllegalStateException("Invalid protobuf, feature means must be a dense vector, found " + featureNormsTensor.getClass());
}
DenseVector featureNorms = (DenseVector) featureNormsTensor;
if (featureNorms.size() != featureSize) {
throw new IllegalStateException("Invalid protobuf, feature means not the right size, expected " + featureSize + ", found " + featureNorms.size());
}
double[] yMean = Util.toPrimitiveDouble(proto.getYMeanList());
if (yMean.length != outputDomain.size()) {
throw new IllegalStateException("Invalid protobuf, y means not the right size, expected " + carrier.outputDomain().size() + " found " + yMean.length);
}
double[] yNorm = Util.toPrimitiveDouble(proto.getYNormList());
if (yNorm.length != outputDomain.size()) {
throw new IllegalStateException("Invalid protobuf, y norms not the right size, expected " + carrier.outputDomain().size() + " found " + yNorm.length);
}

return new SparseLinearModel(carrier.name(), dimensions, carrier.provenance(), carrier.featureDomain(), outputDomain,
weights, featureMeans, featureNorms, yMean, yNorm, proto.getBias());
}
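
Reviewer's note: a minimal sketch of the version gate at the top of this factory. Any version above CURRENT_VERSION is rejected before the Any payload is unpacked; the Any.getDefaultInstance() argument below is a hypothetical stand-in, not real model data.

```java
try {
    SparseLinearModel.deserializeFromProto(SparseLinearModel.CURRENT_VERSION + 1,
            SparseLinearModel.class.getName(), Any.getDefaultInstance());
} catch (IllegalArgumentException expected) {
    // "Unknown version 1, this class supports at most version 0"
} catch (InvalidProtocolBufferException e) {
    // declared by the factory, but the version check fires before any parsing here
}
```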

private static Map<String, List<String>> generateActiveFeatures(String[] dimensionNames, ImmutableFeatureMap featureMap, SparseVector[] weightsArray) {
Map<String, List<String>> map = new HashMap<>();

@@ -224,6 +314,30 @@ public Map<String, SparseVector> getWeights() {
return output;
}

@Override
public ModelProto serialize() {
ModelDataCarrier<Regressor> carrier = createDataCarrier();

SparseLinearModelProto.Builder modelBuilder = SparseLinearModelProto.newBuilder();
modelBuilder.setMetadata(carrier.serialize());
modelBuilder.addAllDimensions(Arrays.asList(dimensions));
for (SparseVector v : weights) {
modelBuilder.addWeights(v.serialize());
}
modelBuilder.setFeatureMeans(featureMeans.serialize());
modelBuilder.setFeatureNorms(featureVariance.serialize());
modelBuilder.setBias(bias);
modelBuilder.addAllYMean(Arrays.stream(yMean).boxed().collect(Collectors.toList()));
modelBuilder.addAllYNorm(Arrays.stream(yVariance).boxed().collect(Collectors.toList()));

ModelProto.Builder builder = ModelProto.newBuilder();
builder.setSerializedData(Any.pack(modelBuilder.build()));
builder.setClassName(SparseLinearModel.class.getName());
builder.setVersion(CURRENT_VERSION);

return builder.build();
}

@Override
public OnnxMl.ModelProto exportONNXModel(String domain, long modelVersion) {
ONNXContext onnx = new ONNXContext();