Skip to content

Commit

Permalink
Tensorflow-Java 0.4.0 update (#195)
Browse files Browse the repository at this point in the history
* Updating to TF-Java 0.4.0-SNAPSHOT.

* Removing unnecessary init code.

* Updates for the latest TF-Java snapshot.

* Bumping to released TF-Java 0.4.0.
  • Loading branch information
Craigacp committed Dec 9, 2021
1 parent 831e1a2 commit ac9802e
Show file tree
Hide file tree
Showing 22 changed files with 159 additions and 211 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.la.DenseVector;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.VectorTuple;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import com.oracle.labs.mlrg.olcut.provenance.Provenancable;
import org.tribuo.Example;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.la.SGDVector;

import java.io.Serializable;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
import org.tribuo.Example;
import org.tribuo.Feature;
import org.tribuo.ImmutableFeatureMap;
import org.tribuo.Output;
import org.tribuo.math.la.SGDVector;
import org.tribuo.math.la.VectorTuple;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public BiFunction<Ops, Pair<Placeholder<? extends TNumber>,Operand<TNumber>>,Ope
return (ops,pair) -> {
@SuppressWarnings("unchecked") // cast off the wildcard to the superclass
Placeholder<TNumber> placeholder = (Placeholder<TNumber>) pair.getA();
return ops.math.mean(ops.nn.raw.softmaxCrossEntropyWithLogits(pair.getB(),placeholder).loss(),ops.constant(0));
return ops.math.mean(ops.nn.softmaxCrossEntropyWithLogits(pair.getB(),placeholder).loss(),ops.constant(0));
};
// TODO - migrate over to TF-Java's CategoricalCrossEntropy when we've fixed the issue we had applying this.
// It should be roughly the block below.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tensorflow.Operand;
import org.tensorflow.Tensor;
import org.tensorflow.framework.op.FrameworkOps;
import org.tensorflow.ndarray.FloatNdArray;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.ndarray.index.Indices;
Expand Down Expand Up @@ -73,9 +74,10 @@ public MultiLabelConverter() {}
@Override
public BiFunction<Ops, Pair<Placeholder<? extends TNumber>,Operand<TNumber>>,Operand<TNumber>> loss() {
return (ops,pair) -> {
FrameworkOps frameworkOps = FrameworkOps.create(ops);
@SuppressWarnings("unchecked") // cast off the wildcard to the superclass
Placeholder<TNumber> placeholder = (Placeholder<TNumber>) pair.getA();
return ops.math.mean(ops.nn.sigmoidCrossEntropyWithLogits(placeholder,pair.getB()),ops.constant(0));
return ops.math.mean(frameworkOps.nn.sigmoidCrossEntropyWithLogits(placeholder,pair.getB()),ops.constant(0));
};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public RegressorConverter() {}
*/
@Override
public BiFunction<Ops, Pair<Placeholder<? extends TNumber>,Operand<TNumber>>,Operand<TNumber>> loss() {
return (ops, pair) -> new MeanSquaredError(ops, "tribuo-mse", Reduction.SUM_OVER_BATCH_SIZE).call(pair.getA(),pair.getB());
return (ops, pair) -> new MeanSquaredError("tribuo-mse", Reduction.SUM_OVER_BATCH_SIZE).call(ops,pair.getA(),pair.getB());
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
* <p>
* N.B. TensorFlow support is experimental and may change without a major version bump.
*/
public class TensorFlowCheckpointModel<T extends Output<T>> extends TensorFlowModel<T> implements Closeable {
public final class TensorFlowCheckpointModel<T extends Output<T>> extends TensorFlowModel<T> implements Closeable {
private static final Logger logger = Logger.getLogger(TensorFlowCheckpointModel.class.getName());

private static final long serialVersionUID = 200L;
Expand All @@ -65,8 +65,8 @@ public class TensorFlowCheckpointModel<T extends Output<T>> extends TensorFlowMo

private boolean initialized;

TensorFlowCheckpointModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, GraphDef graphDef, String checkpointDirectory, String checkpointName, int batchSize, String initName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
super(name, description, featureIDMap, outputIDMap, graphDef, batchSize, initName, outputName, featureConverter, outputConverter);
TensorFlowCheckpointModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, GraphDef graphDef, String checkpointDirectory, String checkpointName, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
super(name, description, featureIDMap, outputIDMap, graphDef, batchSize, outputName, featureConverter, outputConverter);
this.checkpointDirectory = checkpointDirectory;
this.checkpointName = checkpointName;
try {
Expand Down Expand Up @@ -155,12 +155,12 @@ public String getCheckpointName() {
public TensorFlowNativeModel<T> convertToNativeModel() {
Map<String, TensorFlowUtil.TensorTuple> tensorMap = TensorFlowUtil.extractMarshalledVariables(modelGraph,session);
return new TensorFlowNativeModel<>(name, provenance, featureIDMap,
outputIDInfo, modelGraph.toGraphDef(), tensorMap, batchSize, initName, outputName, featureConverter, outputConverter);
outputIDInfo, modelGraph.toGraphDef(), tensorMap, batchSize, outputName, featureConverter, outputConverter);
}

@Override
protected TensorFlowCheckpointModel<T> copy(String newName, ModelProvenance newProvenance) {
return new TensorFlowCheckpointModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(),checkpointDirectory,checkpointName,batchSize,initName,outputName, featureConverter, outputConverter);
return new TensorFlowCheckpointModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(),checkpointDirectory,checkpointName,batchSize,outputName, featureConverter, outputConverter);
}

private void writeObject(java.io.ObjectOutputStream out) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,27 @@
/*
* Copyright (c) 2015-2021, Oracle and/or its affiliates. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.tribuo.interop.tensorflow;

import com.oracle.labs.mlrg.olcut.util.Pair;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.Graph;
import org.tensorflow.Operation;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.SessionFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.GraphDef;
Expand Down Expand Up @@ -42,7 +58,6 @@ public abstract class TensorFlowModel<T extends Output<T>> extends Model<T> impl
private static final long serialVersionUID = 200L;

protected int batchSize;
protected final String initName;
protected final String outputName;
protected final FeatureConverter featureConverter;
protected final OutputConverter<T> outputConverter;
Expand All @@ -58,18 +73,16 @@ public abstract class TensorFlowModel<T extends Output<T>> extends Model<T> impl
* @param outputIDInfo The output domain.
* @param trainedGraphDef The graph definition.
* @param batchSize The test time batch size.
* @param initName The name of the initialization operation.
* @param outputName The name of the output operation.
* @param featureConverter The feature converter.
* @param outputConverter The output converter.
*/
protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, GraphDef trainedGraphDef, int batchSize, String initName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
protected TensorFlowModel(String name, ModelProvenance provenance, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDInfo, GraphDef trainedGraphDef, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
super(name, provenance, featureIDMap, outputIDInfo, outputConverter.generatesProbabilities());
this.modelGraph = new Graph();
this.modelGraph.importGraphDef(trainedGraphDef);
this.session = new Session(modelGraph);
this.batchSize = batchSize;
this.initName = initName;
this.outputName = outputName;
this.featureConverter = featureConverter;
this.outputConverter = outputConverter;
Expand Down Expand Up @@ -209,7 +222,7 @@ public void exportModel(String path) throws IOException {
}
Operation outputOp = modelGraph.operation(outputName);
Signature modelSig = sigBuilder.output(outputName, outputOp.output(0)).build();
ConcreteFunction concFunc = ConcreteFunction.create(modelSig, session);
SessionFunction concFunc = SessionFunction.create(modelSig, session);
SavedModelBundle.exporter(path).withFunction(concFunc).export();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,17 @@
* <p>
* N.B. TensorFlow support is experimental and may change without a major version bump.
*/
public class TensorFlowNativeModel<T extends Output<T>> extends TensorFlowModel<T> {
public final class TensorFlowNativeModel<T extends Output<T>> extends TensorFlowModel<T> {
private static final long serialVersionUID = 200L;

TensorFlowNativeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, GraphDef trainedGraphDef, Map<String, TensorFlowUtil.TensorTuple> tensorMap, int batchSize, String initName, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
super(name, description, featureIDMap, outputIDMap, trainedGraphDef, batchSize, initName, outputName, featureConverter, outputConverter);
// Initialises the parameters.
session.run(initName);
TensorFlowNativeModel(String name, ModelProvenance description, ImmutableFeatureMap featureIDMap, ImmutableOutputInfo<T> outputIDMap, GraphDef trainedGraphDef, Map<String, TensorFlowUtil.TensorTuple> tensorMap, int batchSize, String outputName, FeatureConverter featureConverter, OutputConverter<T> outputConverter) {
super(name, description, featureIDMap, outputIDMap, trainedGraphDef, batchSize, outputName, featureConverter, outputConverter);
TensorFlowUtil.restoreMarshalledVariables(session,tensorMap);
}

@Override
protected TensorFlowNativeModel<T> copy(String newName, ModelProvenance newProvenance) {
return new TensorFlowNativeModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(), TensorFlowUtil.extractMarshalledVariables(modelGraph,session),batchSize,initName,outputName, featureConverter, outputConverter);
return new TensorFlowNativeModel<>(newName,newProvenance,featureIDMap,outputIDInfo,modelGraph.toGraphDef(), TensorFlowUtil.extractMarshalledVariables(modelGraph,session),batchSize,outputName, featureConverter, outputConverter);
}

/**
Expand All @@ -67,7 +65,7 @@ protected TensorFlowNativeModel<T> copy(String newName, ModelProvenance newProve
public TensorFlowCheckpointModel<T> convertToCheckpointModel(String checkpointDirectory, String checkpointName) {
session.save(Paths.get(checkpointDirectory,checkpointName).toString());
return new TensorFlowCheckpointModel<>(name, provenance, featureIDMap,
outputIDInfo, modelGraph.toGraphDef(), checkpointDirectory, checkpointName, batchSize, initName, outputName, featureConverter, outputConverter);
outputIDInfo, modelGraph.toGraphDef(), checkpointDirectory, checkpointName, batchSize, outputName, featureConverter, outputConverter);
}

private void writeObject(java.io.ObjectOutputStream out) throws IOException {
Expand All @@ -89,8 +87,6 @@ private void readObject(java.io.ObjectInputStream in) throws IOException, ClassN
modelGraph = new Graph();
modelGraph.importGraphDef(GraphDef.parseFrom(modelBytes));
session = new Session(modelGraph);
// Initialises the parameters.
session.run(initName);
TensorFlowUtil.restoreMarshalledVariables(session,tensorMap);
}
}
Loading

0 comments on commit ac9802e

Please sign in to comment.