Skip to content

Commit

Permalink
Merge pull request tensorflow#68 from Enet4/export-saved-model
Browse files Browse the repository at this point in the history
Load from saved model support
  • Loading branch information
adamcrume committed Mar 11, 2017
2 parents 022f4d0 + e0f4612 commit 8bd62df
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 15 deletions.
43 changes: 43 additions & 0 deletions examples/regression_savedmodel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import tensorflow as tf
from tensorflow.python.saved_model.builder import SavedModelBuilder
from tensorflow.python.saved_model.signature_def_utils import build_signature_def
from tensorflow.python.saved_model.signature_constants import REGRESS_METHOD_NAME
from tensorflow.python.saved_model.tag_constants import TRAINING, SERVING
from tensorflow.python.saved_model.utils import build_tensor_info

x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')

w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b

loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')

init = tf.variables_initializer(tf.global_variables(), name='init')

directory = 'examples/saved-regression-model'
builder = SavedModelBuilder(directory)

with tf.Session(graph=tf.get_default_graph()) as sess:
sess.run(init)

signature_inputs = {
"x": build_tensor_info(x),
"y": build_tensor_info(y)
}
signature_outputs = {
"out": build_tensor_info(y_hat)
}
signature_def = build_signature_def(
signature_inputs, signature_outputs,
REGRESS_METHOD_NAME)
builder.add_meta_graph_and_variables(
sess, [TRAINING, SERVING],
signature_def_map={
REGRESS_METHOD_NAME: signature_def
},
assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
builder.save(as_text=False)
100 changes: 100 additions & 0 deletions examples/regression_savedmodel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
extern crate random;
extern crate tensorflow;

use random::Source;
use std::error::Error;
use std::result::Result;
use std::path::Path;
use std::process::exit;
use tensorflow::Code;
use tensorflow::Graph;
use tensorflow::Session;
use tensorflow::SessionOptions;
use tensorflow::Status;
use tensorflow::StepWithGraph;
use tensorflow::Tensor;

fn main() {
// Putting the main code in another function serves two purposes:
// 1. We can use the try! macro.
// 2. We can call exit safely, which does not run any destructors.
exit(match run() {
Ok(_) => 0,
Err(e) => {
println!("{}", e);
1
}
})
}

fn run() -> Result<(), Box<Error>> {
let export_dir = "examples/saved-regression-model"; // y = w * x + b
if !Path::new(export_dir).exists() {
return Err(Box::new(Status::new_set(Code::NotFound,
&format!("Run 'python regression_savedmodel.py' to generate \
{} and try again.",
export_dir))
.unwrap()));
}

// Generate some test data.
let w = 0.1;
let b = 0.3;
let num_points = 100;
let steps = 201;
let mut rand = random::default();
let mut x = Tensor::new(&[num_points as u64]);
let mut y = Tensor::new(&[num_points as u64]);
for i in 0..num_points {
x[i] = (2.0 * rand.read::<f64>() - 1.0) as f32;
y[i] = w * x[i] + b;
}

// Load the saved model exported by regression_savedmodel.py.
let mut graph = Graph::new();
let mut session = Session::from_saved_model(&SessionOptions::new(),
&["train", "serve"],
&mut graph,
export_dir)?;
let op_x = graph.operation_by_name_required("x")?;
let op_y = graph.operation_by_name_required("y")?;
let op_train = graph.operation_by_name_required("train")?;
let op_w = graph.operation_by_name_required("w")?;
let op_b = graph.operation_by_name_required("b")?;

// Train the model (e.g. for fine tuning).
let mut train_step = StepWithGraph::new();
train_step.add_input(&op_x, 0, &x);
train_step.add_input(&op_y, 0, &y);
train_step.add_target(&op_train);
for _ in 0..steps {
try!(session.run(&mut train_step));
}

// Grab the data out of the session.
let mut output_step = StepWithGraph::new();
let w_ix = output_step.request_output(&op_w, 0);
let b_ix = output_step.request_output(&op_b, 0);
try!(session.run(&mut output_step));

// Check our results.
let w_hat: f32 = try!(output_step.take_output(w_ix)).data()[0];
let b_hat: f32 = try!(output_step.take_output(b_ix)).data()[0];
println!("Checking w: expected {}, got {}. {}",
w,
w_hat,
if (w - w_hat).abs() < 1e-3 {
"Success!"
} else {
"FAIL"
});
println!("Checking b: expected {}, got {}. {}",
b,
b_hat,
if (b - b_hat).abs() < 1e-3 {
"Success!"
} else {
"FAIL"
});
Ok(())
}
71 changes: 56 additions & 15 deletions src/session.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use tf;
use libc::c_int;
use libc::{c_char, c_int};
use std::ffi::CString;
use std::marker;
use std::path::Path;
use std::ptr;
use super::Code;
use super::DataType;
Expand Down Expand Up @@ -32,6 +34,45 @@ impl Session {
}
}

/// Loads a session from an exported model.
pub fn from_saved_model<P: AsRef<Path>, Tag: AsRef<str>, Tags: IntoIterator<Item = Tag>>
(options: &SessionOptions,
tags: Tags,
graph: &mut Graph,
export_dir: P)
-> Result<Self> {
let mut status = Status::new();

let export_dir_cstr =
try!(export_dir.as_ref()
.to_str()
.and_then(|s| CString::new(s.as_bytes()).ok())
.ok_or_else(|| invalid_arg!("Invalid export directory path")));

let tags_cstr: Vec<_> = try!(tags.into_iter()
.map(|t| CString::new(t.as_ref()))
.collect::<::std::result::Result<_, _>>()
.map_err(|_| invalid_arg!("Invalid tag name")));
// keeping tags_cstr to retain strings in memory
let tags_ptr: Vec<*const c_char> = tags_cstr.iter().map(|t| t.as_ptr()).collect();

let inner = unsafe {
tf::TF_LoadSessionFromSavedModel(options.inner,
ptr::null(),
export_dir_cstr.as_ptr(),
tags_ptr.as_ptr(),
tags_ptr.len() as c_int,
graph.inner(),
ptr::null_mut(),
status.inner())
};
if inner.is_null() {
Err(status)
} else {
Ok(Session { inner: inner })
}
}

/// Closes the session.
pub fn close(&mut self) -> Result<()> {
let mut status = Status::new();
Expand Down Expand Up @@ -143,19 +184,19 @@ impl<'l> StepWithGraph<'l> {
index: c_int,
tensor: &'l Tensor<T>) {
self.input_ports.push(tf::TF_Output {
oper: operation.inner(),
index: index,
});
oper: operation.inner(),
index: index,
});
self.input_tensors.push(tensor.inner);
}

/// Requests that an output is fetched from the graph after running this step.
/// Returns an index that you can then use to fetch this output from the step after running it.
pub fn request_output(&mut self, operation: &Operation, index: c_int) -> OutputToken {
self.output_ports.push(tf::TF_Output {
oper: operation.inner(),
index: index,
});
oper: operation.inner(),
index: index,
});
self.output_tensors.push(ptr::null_mut());
OutputToken { index: self.output_tensors.len() - 1 }
}
Expand All @@ -172,13 +213,13 @@ impl<'l> StepWithGraph<'l> {
{}",
output_idx,
self.output_tensors.len()))
.unwrap());
.unwrap());
}
if self.output_tensors[output_idx].is_null() {
return Err(Status::new_set(Code::Unavailable,
"Output not available. Either it was already taken, or \
this step has not been sucessfully run yet.")
.unwrap());
.unwrap());
}
let actual_data_type = self.output_data_type(output_idx).unwrap();
if actual_data_type != T::data_type() {
Expand Down Expand Up @@ -260,13 +301,13 @@ mod tests {
let y = {
let mut nd = g.new_operation("Mul", "y").unwrap();
nd.add_input(Output {
operation: &two,
index: 0,
});
operation: &two,
index: 0,
});
nd.add_input(Output {
operation: &x,
index: 0,
});
operation: &x,
index: 0,
});
nd.finish().unwrap()
};
let options = SessionOptions::new();
Expand Down

0 comments on commit 8bd62df

Please sign in to comment.