Skip to content

Commit

Permalink
Adds protobuf serialization for the various baseline predictors, clas…
Browse files Browse the repository at this point in the history
…sifier chains, and viterbi (#277)

* Adding protobuf serialization to ViterbiModel and DummyClassifierModel.

* Adding protobuf serialization to DummyRegressionModel.

* Adding classifier chain and independent multi-label protobuf serialization.
  • Loading branch information
Craigacp committed Sep 23, 2022
1 parent 65678fb commit 04c6fc3
Show file tree
Hide file tree
Showing 33 changed files with 9,356 additions and 66 deletions.
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,26 +16,37 @@

package org.tribuo.classification.baseline;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import com.sun.org.apache.xpath.internal.operations.Mod;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.ImmutableLabelInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer.DummyType;
import org.tribuo.classification.protos.DummyClassifierModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.Tensor;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;

import static org.tribuo.Trainer.DEFAULT_SEED;

Expand All @@ -45,6 +56,11 @@
public class DummyClassifierModel extends Model<Label> {
// Version identifier for Java's built-in object serialization.
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
* <p>
* Written into the serialized {@code ModelProto} by {@code serialize()}, and used by
* {@code deserializeFromProto} as the upper bound on the versions it accepts.
*/
public static final int CURRENT_VERSION = 0;

// Which kind of dummy classifier this is; persisted in the protobuf by its enum name.
private final DummyType dummyType;

// Label persisted in the protobuf's constant label field.
// NOTE(review): presumably the label emitted for the constant-style dummy types - confirm against predict().
private final Label constantLabel;
Expand Down Expand Up @@ -82,6 +98,55 @@ public class DummyClassifierModel extends Model<Label> {
this.rng = null;
}

/**
 * Builds a model from fully specified state, as done by {@code deserializeFromProto}.
 * <p>
 * The supplied cdf array is stored as-is (no defensive copy), and may be null.
 * A fresh {@link Random} is always created from the supplied seed.
 * @param name The model name.
 * @param provenance The model provenance.
 * @param featureMap The feature domain.
 * @param outputInfo The output domain.
 * @param type The dummy model type.
 * @param constantLabel The constant label.
 * @param cdf The cumulative distribution values, or null.
 * @param seed The RNG seed.
 */
private DummyClassifierModel(String name,
                             ModelProvenance provenance,
                             ImmutableFeatureMap featureMap,
                             ImmutableOutputInfo<Label> outputInfo,
                             DummyType type,
                             Label constantLabel,
                             double[] cdf,
                             long seed) {
    super(name, provenance, featureMap, outputInfo, false);
    // Random state first, then the values copied straight from the arguments.
    this.seed = seed;
    this.rng = new Random(seed);
    this.cdf = cdf;
    this.constantLabel = constantLabel;
    this.dummyType = type;
}

/**
 * Deserialization factory.
 * <p>
 * Unpacks a {@code DummyClassifierModelProto} from the supplied message, validates that it
 * describes a label model, and rebuilds the {@code DummyClassifierModel} from its fields.
 * @param version The serialized object version.
 * @param className The class name.
 * @param message The serialized data.
 * @return The deserialized model.
 * @throws InvalidProtocolBufferException If the protobuf could not be parsed from the {@code message}.
 * @throws IllegalArgumentException If the version is negative or greater than {@link #CURRENT_VERSION}.
 * @throws IllegalStateException If the output domain does not hold labels, or the stored constant output is not a label.
 */
public static DummyClassifierModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    final boolean supportedVersion = version >= 0 && version <= CURRENT_VERSION;
    if (!supportedVersion) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }
    final DummyClassifierModelProto modelProto = message.unpack(DummyClassifierModelProto.class);

    // The carrier's output domain is wildcard-typed, so the output with id 0 is inspected
    // to confirm the domain holds labels before the cast below.
    final ModelDataCarrier<?> metadata = ModelDataCarrier.deserialize(modelProto.getMetadata());
    final Class<?> firstOutputClass = metadata.outputDomain().getOutput(0).getClass();
    if (!Label.class.equals(firstOutputClass)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + metadata.outputDomain().getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    final ImmutableOutputInfo<Label> labelDomain = (ImmutableOutputInfo<Label>) metadata.outputDomain();

    // The type is stored by enum name; valueOf rejects names which are not DummyType constants.
    final DummyType type = DummyType.valueOf(modelProto.getDummyType());

    final Output<?> storedOutput = Output.deserialize(modelProto.getConstantLabel());
    if (!(storedOutput instanceof Label)) {
        throw new IllegalStateException("Invalid protobuf, expected a label, found " + storedOutput.getClass());
    }
    final Label label = (Label) storedOutput;

    // An empty cdf field in the protobuf maps back to a null array.
    final double[] cdfValues = modelProto.getCdfCount() > 0
            ? Util.toPrimitiveDouble(modelProto.getCdfList())
            : null;

    return new DummyClassifierModel(
            metadata.name(),
            metadata.provenance(),
            metadata.featureDomain(),
            labelDomain,
            type,
            label,
            cdfValues,
            modelProto.getSeed());
}

@Override
public Prediction<Label> predict(Example<Label> example) {
switch (dummyType) {
Expand Down Expand Up @@ -110,6 +175,27 @@ public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
return Optional.of(new Excuse<>(example,predict(example),getTopFeatures(1)));
}

/**
 * Serializes this model into a {@code ModelProto}.
 * <p>
 * The model metadata, dummy type name, constant label, seed and (when present) the cdf values
 * are written into a {@code DummyClassifierModelProto}, which is packed into the returned
 * proto along with this class's name and {@link #CURRENT_VERSION}.
 * @return The serialized model.
 */
@Override
public ModelProto serialize() {
    DummyClassifierModelProto.Builder payload = DummyClassifierModelProto.newBuilder()
            .setMetadata(createDataCarrier().serialize())
            .setDummyType(dummyType.name())
            .setConstantLabel(constantLabel.serialize())
            .setSeed(seed);
    // The cdf field is only populated when this model holds a cdf array.
    if (cdf != null) {
        List<Double> boxedCdf = Arrays.stream(cdf).boxed().collect(Collectors.toList());
        payload.addAllCdf(boxedCdf);
    }

    return ModelProto.newBuilder()
            .setVersion(CURRENT_VERSION)
            .setClassName(DummyClassifierModel.class.getName())
            .setSerializedData(Any.pack(payload.build()))
            .build();
}

@Override
protected DummyClassifierModel copy(String newName, ModelProvenance newProvenance) {
switch (dummyType) {
Expand Down

0 comments on commit 04c6fc3

Please sign in to comment.