forked from tensorflow/rust
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request tensorflow#68 from Enet4/export-saved-model
Load from saved model support
- Loading branch information
Showing 3 changed files with 199 additions and 15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
# Exports an untrained linear-regression model (y_hat = w * x + b) as a
# TensorFlow SavedModel under examples/saved-regression-model, so the
# companion Rust example can load and fine-tune it.
import tensorflow as tf
from tensorflow.python.saved_model.builder import SavedModelBuilder
from tensorflow.python.saved_model.signature_def_utils import build_signature_def
from tensorflow.python.saved_model.signature_constants import REGRESS_METHOD_NAME
from tensorflow.python.saved_model.tag_constants import TRAINING, SERVING
from tensorflow.python.saved_model.utils import build_tensor_info

# Training-input placeholders; explicitly named so the loader can fetch
# them via operation_by_name("x") / ("y").
x = tf.placeholder(tf.float32, name='x')
y = tf.placeholder(tf.float32, name='y')

# Model parameters: w initialized uniform in [-1, 1), b initialized to zero.
w = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='w')
b = tf.Variable(tf.zeros([1]), name='b')
y_hat = w * x + b

# Mean-squared-error loss minimized with plain gradient descent; the
# train op is named so it can be targeted by name after loading.
loss = tf.reduce_mean(tf.square(y_hat - y))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss, name='train')

# Named initializer op covering all global variables.
init = tf.variables_initializer(tf.global_variables(), name='init')

directory = 'examples/saved-regression-model'
builder = SavedModelBuilder(directory)

with tf.Session(graph=tf.get_default_graph()) as sess:
    # Variables must be initialized before they can be saved.
    sess.run(init)

    # Regression signature: inputs "x"/"y", output "out" (the prediction).
    signature_inputs = {
        "x": build_tensor_info(x),
        "y": build_tensor_info(y)
    }
    signature_outputs = {
        "out": build_tensor_info(y_hat)
    }
    signature_def = build_signature_def(
        signature_inputs, signature_outputs,
        REGRESS_METHOD_NAME)
    # Tag the meta-graph with both TRAINING and SERVING so loaders asking
    # for either tag set (e.g. the Rust example's ["train", "serve"]) match.
    builder.add_meta_graph_and_variables(
        sess, [TRAINING, SERVING],
        signature_def_map={
            REGRESS_METHOD_NAME: signature_def
        },
        assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
    builder.save(as_text=False)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
extern crate random; | ||
extern crate tensorflow; | ||
|
||
use random::Source; | ||
use std::error::Error; | ||
use std::result::Result; | ||
use std::path::Path; | ||
use std::process::exit; | ||
use tensorflow::Code; | ||
use tensorflow::Graph; | ||
use tensorflow::Session; | ||
use tensorflow::SessionOptions; | ||
use tensorflow::Status; | ||
use tensorflow::StepWithGraph; | ||
use tensorflow::Tensor; | ||
|
||
fn main() { | ||
// Putting the main code in another function serves two purposes: | ||
// 1. We can use the try! macro. | ||
// 2. We can call exit safely, which does not run any destructors. | ||
exit(match run() { | ||
Ok(_) => 0, | ||
Err(e) => { | ||
println!("{}", e); | ||
1 | ||
} | ||
}) | ||
} | ||
|
||
fn run() -> Result<(), Box<Error>> { | ||
let export_dir = "examples/saved-regression-model"; // y = w * x + b | ||
if !Path::new(export_dir).exists() { | ||
return Err(Box::new(Status::new_set(Code::NotFound, | ||
&format!("Run 'python regression_savedmodel.py' to generate \ | ||
{} and try again.", | ||
export_dir)) | ||
.unwrap())); | ||
} | ||
|
||
// Generate some test data. | ||
let w = 0.1; | ||
let b = 0.3; | ||
let num_points = 100; | ||
let steps = 201; | ||
let mut rand = random::default(); | ||
let mut x = Tensor::new(&[num_points as u64]); | ||
let mut y = Tensor::new(&[num_points as u64]); | ||
for i in 0..num_points { | ||
x[i] = (2.0 * rand.read::<f64>() - 1.0) as f32; | ||
y[i] = w * x[i] + b; | ||
} | ||
|
||
// Load the saved model exported by regression_savedmodel.py. | ||
let mut graph = Graph::new(); | ||
let mut session = Session::from_saved_model(&SessionOptions::new(), | ||
&["train", "serve"], | ||
&mut graph, | ||
export_dir)?; | ||
let op_x = graph.operation_by_name_required("x")?; | ||
let op_y = graph.operation_by_name_required("y")?; | ||
let op_train = graph.operation_by_name_required("train")?; | ||
let op_w = graph.operation_by_name_required("w")?; | ||
let op_b = graph.operation_by_name_required("b")?; | ||
|
||
// Train the model (e.g. for fine tuning). | ||
let mut train_step = StepWithGraph::new(); | ||
train_step.add_input(&op_x, 0, &x); | ||
train_step.add_input(&op_y, 0, &y); | ||
train_step.add_target(&op_train); | ||
for _ in 0..steps { | ||
try!(session.run(&mut train_step)); | ||
} | ||
|
||
// Grab the data out of the session. | ||
let mut output_step = StepWithGraph::new(); | ||
let w_ix = output_step.request_output(&op_w, 0); | ||
let b_ix = output_step.request_output(&op_b, 0); | ||
try!(session.run(&mut output_step)); | ||
|
||
// Check our results. | ||
let w_hat: f32 = try!(output_step.take_output(w_ix)).data()[0]; | ||
let b_hat: f32 = try!(output_step.take_output(b_ix)).data()[0]; | ||
println!("Checking w: expected {}, got {}. {}", | ||
w, | ||
w_hat, | ||
if (w - w_hat).abs() < 1e-3 { | ||
"Success!" | ||
} else { | ||
"FAIL" | ||
}); | ||
println!("Checking b: expected {}, got {}. {}", | ||
b, | ||
b_hat, | ||
if (b - b_hat).abs() < 1e-3 { | ||
"Success!" | ||
} else { | ||
"FAIL" | ||
}); | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters