Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds protobuf serialization for the various baseline predictors, classifier chains, and Viterbi #277

Merged
merged 3 commits into from Sep 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2015-2020, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2015, 2022, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -16,26 +16,37 @@

package org.tribuo.classification.baseline;

import com.google.protobuf.Any;
import com.google.protobuf.InvalidProtocolBufferException;
import com.oracle.labs.mlrg.olcut.util.Pair;
import com.sun.org.apache.xpath.internal.operations.Mod;
import org.tribuo.Example;
import org.tribuo.Excuse;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.ImmutableOutputInfo;
import org.tribuo.Model;
import org.tribuo.Output;
import org.tribuo.Prediction;
import org.tribuo.classification.ImmutableLabelInfo;
import org.tribuo.classification.Label;
import org.tribuo.classification.LabelFactory;
import org.tribuo.classification.baseline.DummyClassifierTrainer.DummyType;
import org.tribuo.classification.protos.DummyClassifierModelProto;
import org.tribuo.impl.ModelDataCarrier;
import org.tribuo.math.la.DenseSparseMatrix;
import org.tribuo.math.la.Tensor;
import org.tribuo.protos.core.ModelProto;
import org.tribuo.provenance.ModelProvenance;
import org.tribuo.util.Util;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import java.util.stream.Collectors;

import static org.tribuo.Trainer.DEFAULT_SEED;

Expand All @@ -45,6 +56,11 @@
public class DummyClassifierModel extends Model<Label> {
private static final long serialVersionUID = 1L;

/**
* Protobuf serialization version.
*/
public static final int CURRENT_VERSION = 0;

private final DummyType dummyType;

private final Label constantLabel;
Expand Down Expand Up @@ -82,6 +98,55 @@ public class DummyClassifierModel extends Model<Label> {
this.rng = null;
}

/**
 * Builds a dummy classifier directly from fully specified state.
 * <p>
 * Within the visible code this constructor is invoked by
 * {@link #deserializeFromProto(int, String, Any)} to rebuild a model from its
 * protobuf form. The random number generator is always freshly created from
 * {@code seed}.
 * @param name The model name.
 * @param provenance The model provenance.
 * @param featureMap The feature domain.
 * @param outputInfo The output (label) domain.
 * @param type The kind of dummy classifier.
 * @param constantLabel The constant label held by the model.
 * @param cdf The CDF array, may be null; it is stored as given without copying.
 * @param seed The seed for the model's RNG.
 */
private DummyClassifierModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureMap,
        ImmutableOutputInfo<Label> outputInfo, DummyType type, Label constantLabel,
        double[] cdf, long seed) {
    super(name, provenance, featureMap, outputInfo, false);
    // RNG state first: the generator is derived purely from the supplied seed.
    this.seed = seed;
    this.rng = new Random(seed);
    // Remaining model state is stored as supplied.
    this.cdf = cdf;
    this.constantLabel = constantLabel;
    this.dummyType = type;
}

/**
 * Deserialization factory which rebuilds a {@code DummyClassifierModel} from its protobuf form.
 * <p>
 * The version is validated before the message is unpacked. The output domain
 * and the constant label are then both checked to be {@link Label} based.
 * @param version The serialized object version.
 * @param className The class name (not read by this method).
 * @param message The serialized data.
 * @return The deserialized model.
 * @throws InvalidProtocolBufferException If the message cannot be unpacked as a {@code DummyClassifierModelProto}.
 * @throws IllegalArgumentException If the version is negative or greater than {@link #CURRENT_VERSION}.
 * @throws IllegalStateException If the output domain's first output is not a {@link Label}, or the
 * constant label does not deserialize to a {@link Label}.
 */
public static DummyClassifierModel deserializeFromProto(int version, String className, Any message) throws InvalidProtocolBufferException {
    boolean supportedVersion = version >= 0 && version <= CURRENT_VERSION;
    if (!supportedVersion) {
        throw new IllegalArgumentException("Unknown version " + version + ", this class supports at most version " + CURRENT_VERSION);
    }

    DummyClassifierModelProto modelProto = message.unpack(DummyClassifierModelProto.class);
    ModelDataCarrier<?> metadata = ModelDataCarrier.deserialize(modelProto.getMetadata());

    // The carrier is untyped, so inspect the first output to confirm this is a label domain.
    ImmutableOutputInfo<?> rawDomain = metadata.outputDomain();
    Class<?> firstOutputClass = rawDomain.getOutput(0).getClass();
    if (!firstOutputClass.equals(Label.class)) {
        throw new IllegalStateException("Invalid protobuf, output domain is not a label domain, found " + rawDomain.getClass());
    }
    @SuppressWarnings("unchecked") // guarded by getClass
    ImmutableOutputInfo<Label> labelDomain = (ImmutableOutputInfo<Label>) rawDomain;

    DummyType type = DummyType.valueOf(modelProto.getDummyType());

    Output<?> deserializedOutput = Output.deserialize(modelProto.getConstantLabel());
    if (!(deserializedOutput instanceof Label)) {
        throw new IllegalStateException("Invalid protobuf, expected a label, found " + deserializedOutput.getClass());
    }
    Label label = (Label) deserializedOutput;

    // An empty repeated field means no CDF was stored; represent that as null.
    double[] cdfValues = modelProto.getCdfCount() > 0 ? Util.toPrimitiveDouble(modelProto.getCdfList()) : null;

    return new DummyClassifierModel(metadata.name(), metadata.provenance(), metadata.featureDomain(), labelDomain,
            type, label, cdfValues, modelProto.getSeed());
}

@Override
public Prediction<Label> predict(Example<Label> example) {
switch (dummyType) {
Expand Down Expand Up @@ -110,6 +175,27 @@ public Optional<Excuse<Label>> getExcuse(Example<Label> example) {
return Optional.of(new Excuse<>(example,predict(example),getTopFeatures(1)));
}

/**
 * Serializes this model to a {@link ModelProto}.
 * <p>
 * The model-specific payload (metadata, dummy type name, constant label, seed
 * and, when present, the CDF) is written into a {@code DummyClassifierModelProto},
 * packed into an {@link Any}, and wrapped with this class's name and
 * {@link #CURRENT_VERSION}.
 * @return The protobuf representation of this model.
 */
@Override
public ModelProto serialize() {
    ModelDataCarrier<Label> metadata = createDataCarrier();

    DummyClassifierModelProto.Builder payload = DummyClassifierModelProto.newBuilder()
            .setMetadata(metadata.serialize())
            .setDummyType(dummyType.name())
            .setConstantLabel(constantLabel.serialize())
            .setSeed(seed);
    // The CDF is only written when the model holds one; a null CDF leaves the repeated field empty.
    if (cdf != null) {
        payload.addAllCdf(Arrays.stream(cdf).boxed().collect(Collectors.toList()));
    }

    return ModelProto.newBuilder()
            .setSerializedData(Any.pack(payload.build()))
            .setClassName(DummyClassifierModel.class.getName())
            .setVersion(CURRENT_VERSION)
            .build();
}

@Override
protected DummyClassifierModel copy(String newName, ModelProvenance newProvenance) {
switch (dummyType) {
Expand Down