Skip to content

Commit

Permalink
Merge pull request #272 from Mistobaan/features/tensorflow/add_graph_dot
Browse files Browse the repository at this point in the history
tensorflow: bind DotGraph
  • Loading branch information
saudet authored Aug 31, 2016
2 parents 576a262 + 7843af0 commit 2ecc400
Showing 1 changed file with 71 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FunctionPointer;
import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
Expand All @@ -35,6 +37,7 @@
import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.annotation.Platform;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.javacpp.annotation.StdString;
import org.bytedeco.javacpp.tools.Info;
import org.bytedeco.javacpp.tools.InfoMap;
import org.bytedeco.javacpp.tools.InfoMapper;
Expand All @@ -49,7 +52,10 @@
"tensorflow/core/lib/core/stringpiece.h", */ "tensorflow/core/platform/types.h", "tensorflow/core/platform/mutex.h",
"tensorflow/core/platform/macros.h", "tensorflow/core/util/port.h", "tensorflow/core/lib/core/error_codes.pb.h",
"tensorflow/core/platform/logging.h", "tensorflow/core/lib/core/status.h", "tensorflow/core/platform/protobuf.h",
"tensorflow/core/platform/file_system.h", "tensorflow/core/platform/env.h", "tensorflow/core/protobuf/config.pb.h", "tensorflow/core/framework/cost_graph.pb.h",
"tensorflow/core/platform/file_system.h", "tensorflow/core/platform/env.h",
"tensorflow/core/graph/dot.h",
"tensorflow/core/graph/graph.h",
"tensorflow/core/protobuf/config.pb.h", "tensorflow/core/framework/cost_graph.pb.h",
"tensorflow/core/framework/step_stats.pb.h", "tensorflow/core/framework/versions.pb.h", "tensorflow/core/public/session_options.h",
"tensorflow/core/lib/core/threadpool.h", "tensorflow/core/framework/allocation_description.pb.h", "tensorflow/core/framework/allocator.h",
"tensorflow/core/framework/tensor_shape.pb.h", "tensorflow/core/framework/types.pb.h", "tensorflow/core/framework/tensor.pb.h",
Expand Down Expand Up @@ -157,6 +163,23 @@ public void map(InfoMap infoMap) {
}
}

infoMap.put(new Info("tensorflow::DotOptions::edge_label")
.javaText("@MemberSetter public native DotOptions edge_label(EdgeLabelFunction edge_label_function);"))
.put(new Info("tensorflow::DotOptions::node_label")
.javaText("@MemberSetter public native DotOptions node_label(NodeLabelFunction node_label_function);"))
.put(new Info("tensorflow::DotOptions::edge_cost")
.javaText("@MemberSetter public native DotOptions edge_cost(EdgeCostFunction edge_cost_function);"))
.put(new Info("tensorflow::DotOptions::node_cost")
.javaText("@MemberSetter public native DotOptions node_cost(NodeCostFunction node_cost_function);"))
.put(new Info("tensorflow::DotOptions::node_color")
.javaText("@MemberSetter public native DotOptions node_color(NodeColorFunction node_color_function);"))

.put(new Info("std::function<double(const *tensorflow::Edge)>").pointerTypes("EdgeCostFunction"))
.put(new Info("std::function<double(const *tensorflow::Node)>").pointerTypes("NodeCostFunction"))
.put(new Info("std::function<std::string(const *tensorflow::Node)>").pointerTypes("NodeLabelFunction"))
.put(new Info("std::function<std::string(const *tensorflow::Edge)>").pointerTypes("EdgeLabelFunction"))
.put(new Info("std::function<int(const *tensorflow::Node)>").pointerTypes("NodeColorFunction"));

infoMap.put(new Info("tensorflow::gtl::ArraySlice").annotations("@ArraySlice"))
.put(new Info("tensorflow::StringPiece").annotations("@StringPiece").valueTypes("BytePointer", "String").pointerTypes("BytePointer"))
.put(new Info("tensorflow::ops::Const(tensorflow::StringPiece, tensorflow::GraphDefBuilder::Options&)")
Expand Down Expand Up @@ -191,6 +214,53 @@ public static class ConsiderFunction extends FunctionPointer {
public native @Cast("bool") boolean call(@Cast("const tensorflow::Node*") Pointer node);
}

public static class NodeColorFunction extends FunctionPointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public NodeColorFunction(Pointer p) { super(p); }
protected NodeColorFunction() { allocate(); }
private native void allocate();
public native @Cast("int") int call(@Cast("const tensorflow::Node*") Pointer node);
}

public static class NodeCostFunction extends FunctionPointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public NodeCostFunction(Pointer p) { super(p); }
protected NodeCostFunction() { allocate(); }
private native void allocate();
public native @Cast("double") double call(@Cast("const tensorflow::Node*") Pointer node);
}

public static class EdgeCostFunction extends FunctionPointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public EdgeCostFunction(Pointer p) { super(p); }
protected EdgeCostFunction() { allocate(); }
private native void allocate();
public native @Cast("double") double call(@Cast("const tensorflow::Edge*") Pointer node);
}

public static class NodeLabelFunction extends FunctionPointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public NodeLabelFunction(Pointer p) { super(p); }
protected NodeLabelFunction() { allocate(); }
private native void allocate();
public native @StdString BytePointer call(@Cast("const tensorflow::Node*") Pointer node);
}

public static class EdgeLabelFunction extends FunctionPointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public EdgeLabelFunction(Pointer p) { super(p); }
protected EdgeLabelFunction() { allocate(); }
private native void allocate();
public native @StdString BytePointer call(@Cast("const tensorflow::Edge*") Pointer node);
}



@Documented @Retention(RetentionPolicy.RUNTIME)
@Target({ElementType.METHOD, ElementType.PARAMETER})
@Cast({"tensorflow::gtl::ArraySlice", "&"}) @Adapter("ArraySliceAdapter")
Expand Down

0 comments on commit 2ecc400

Please sign in to comment.