Skip to content

Commit

Permalink
nice additions to both C API and SimpleModel C++ API
Browse files Browse the repository at this point in the history
  • Loading branch information
memo committed Jun 1, 2017
1 parent ca7932c commit 95083c0
Show file tree
Hide file tree
Showing 17 changed files with 434 additions and 344 deletions.
9 changes: 9 additions & 0 deletions example-basic/src/example-basic.cpp
Expand Up @@ -6,6 +6,8 @@
* Python script constructs tensor flow graph which simply multiplies two numbers and exports binary model (see bin/py)
*
* openFrameworks code loads and processes pre-trained model (i.e. makes calculations/predictions)
* This is using the lower level C API
* For the higer level C++ API (using msa::tf::SimpleModel) see example-pix2pix-simple
*
*/

Expand All @@ -16,6 +18,8 @@ class ofApp : public ofBaseApp {
public:

// shared pointer to tensorflow::Session
// This is using the lower level C API
// For the higer level C++ API (using msa::tf::SimpleModel) see example-pix2pix-simple
msa::tf::Session_ptr session;


Expand All @@ -27,6 +31,11 @@ class ofApp : public ofBaseApp {

// Load graph (i.e. trained model) we exported from python, and initialize session
session = msa::tf::create_session_with_graph("models/model.pb");

if(!session) {
ofLogError() << "Model not found. " << msa::tf::missing_data_error();
ofExit(1);
}
}


Expand Down
35 changes: 15 additions & 20 deletions example-char-rnn/src/example-char-rnn.cpp
Expand Up @@ -30,6 +30,8 @@ class ofApp : public ofBaseApp {
public:

// shared pointer to tensorflow::Session
// This is using the lower level C API
// For the higer level C++ API (using msa::tf::SimpleModel) see example-pix2pix-simple
msa::tf::Session_ptr session;


Expand All @@ -54,9 +56,8 @@ class ofApp : public ofBaseApp {


// model file management
string model_root_dir = "models";
vector<string> model_names;
int cur_model_index = 0;
ofDirectory models_dir; // data/models folder which contains subfolders for each model
int cur_model_index = 0; // which model (i.e. folder) we're currently using


// random generator for sampling
Expand All @@ -81,29 +82,25 @@ class ofApp : public ofBaseApp {


// scan models dir
ofDirectory dir;
dir.listDir(model_root_dir);
if(dir.size()==0) {
ofLogError() << "Could not find models folder. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
models_dir.listDir("models");
if(models_dir.size()==0) {
ofLogError() << "Couldn't find models folder." << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
for(int i=0; i<dir.getFiles().size(); i++) model_names.push_back(dir.getName(i));
sort(model_names.begin(), model_names.end());
load_model_index(0);
models_dir.sort();
load_model_index(0); // load first model

// seed rng
rng.seed(ofGetSystemTimeMicros());
}


//--------------------------------------------------------------
// Load graph (model trained in and exported from python) by folder INDEX, and initialize session
// Load model by folder INDEX
void load_model_index(int index) {
cur_model_index = ofClamp(index, 0, model_names.size()-1);
load_model(model_root_dir + "/" + model_names[cur_model_index]);
cur_model_index = ofClamp(index, 0, models_dir.size()-1);
load_model(models_dir.getPath(cur_model_index));
}


Expand All @@ -114,9 +111,7 @@ class ofApp : public ofBaseApp {
session = msa::tf::create_session_with_graph(dir + "/graph_frz.pb");

if(!session) {
ofLogError() << "Could not initialize session. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
ofLogError() << "Session init error." << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
Expand Down Expand Up @@ -247,9 +242,9 @@ class ofApp : public ofBaseApp {
str << endl;

str << "Press number key to load model: " << endl;
for(int i=0; i<model_names.size(); i++) {
for(int i=0; i<models_dir.size(); i++) {
auto marker = (i==cur_model_index) ? ">" : " ";
str << " " << (i+1) << " : " << marker << " " << model_names[i] << endl;
str << " " << (i+1) << " : " << marker << " " << models_dir.getName(i) << endl;
}

str << endl;
Expand Down
35 changes: 15 additions & 20 deletions example-handwriting-rnn/src/example-handwriting-rnn.cpp
Expand Up @@ -30,6 +30,8 @@ class ofApp : public ofBaseApp {
public:

// shared pointer to tensorflow::Session
// This is using the lower level C API
// For the higer level C++ API (using msa::tf::SimpleModel) see example-pix2pix-simple
msa::tf::Session_ptr session;

// tensors in and out of model
Expand Down Expand Up @@ -63,9 +65,8 @@ class ofApp : public ofBaseApp {


// model file management
string model_root_dir = "models";
vector<string> model_names;
int cur_model_index = 0;
ofDirectory models_dir; // data/models folder which contains subfolders for each model
int cur_model_index = 0; // which model (i.e. folder) we're currently using


// random generator for sampling
Expand All @@ -91,18 +92,14 @@ class ofApp : public ofBaseApp {
ofBackground(220);

// scan models dir
ofDirectory dir;
dir.listDir(model_root_dir);
if(dir.size()==0) {
ofLogError() << "Could not find models folder. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
models_dir.listDir("models");
if(models_dir.size()==0) {
ofLogError() << "Couldn't find models folder." << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
for(int i=0; i<dir.getFiles().size(); i++) model_names.push_back(dir.getName(i));
sort(model_names.begin(), model_names.end());
load_model_index(0);
models_dir.sort();
load_model_index(0); // load first model

// seed rng
rng.seed(ofGetSystemTimeMicros());
Expand All @@ -119,10 +116,10 @@ class ofApp : public ofBaseApp {


//--------------------------------------------------------------
// Load graph (model trained in and exported from python) by folder INDEX, and initialize session
// Load model by folder INDEX
void load_model_index(int index) {
cur_model_index = ofClamp(index, 0, model_names.size()-1);
load_model(model_root_dir + "/" + model_names[cur_model_index]);
cur_model_index = ofClamp(index, 0, models_dir.size()-1);
load_model(models_dir.getPath(cur_model_index));
}


Expand All @@ -133,9 +130,7 @@ class ofApp : public ofBaseApp {
session = msa::tf::create_session_with_graph(dir + "/graph_frz.pb");

if(!session) {
ofLogError() << "Could not initialize session. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
ofLogError() << "Session init error." << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
Expand Down Expand Up @@ -222,9 +217,9 @@ class ofApp : public ofBaseApp {
str << endl;

str << "Press number key to load model: " << endl;
for(int i=0; i<model_names.size(); i++) {
for(int i=0; i<models_dir.size(); i++) {
auto marker = (i==cur_model_index) ? ">" : " ";
str << " " << (i+1) << " : " << marker << " " << model_names[i] << endl;
str << " " << (i+1) << " : " << marker << " " << models_dir.getName(i) << endl;
}


Expand Down
1 change: 1 addition & 0 deletions example-inception3/src/example-inception3.cpp
Expand Up @@ -19,6 +19,7 @@ class ofApp : public ofBaseApp {

// classifies pixels
// check the src of this class (ofxMSATFImageClassifier) to see how to do more generic stuff with ofxMSATensorFlow
// UPDATE: Actually the msa::tf::SimpleModel supercedes this. Need to update it.
msa::tf::ImageClassifier classifier;

// for webcam input
Expand Down
5 changes: 2 additions & 3 deletions example-mnist/src/example-mnist.cpp
Expand Up @@ -20,6 +20,7 @@ class ofApp : public ofBaseApp {

// classifies pixels
// check the src of this class (ofxMSATFImageClassifier) to see how to do more generic stuff with ofxMSATensorFlow
// UPDATE: Actually the msa::tf::SimpleModel supercedes this. Need to update it.
msa::tf::ImageClassifier classifier;

// simple visualization of weights layer,
Expand Down Expand Up @@ -65,9 +66,7 @@ class ofApp : public ofBaseApp {
// initialize classifier with these settings
classifier.setup(settings);
if(!classifier.getGraphDef()) {
ofLogError() << "Could not initialize session. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
ofLogError() << "Session init error." << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
Expand Down
49 changes: 20 additions & 29 deletions example-pix2pix-webcam/src/example-pix2pix-webcam.cpp
Expand Up @@ -14,16 +14,16 @@ I'm using a very simple and ghetto method of transforming the webcam input into
//--------------------------------------------------------------
class ofApp : public ofBaseApp {
public:
// a simple wrapper for a simple predictor model with one (n-dim) input and one (n-dim) output
// a simple wrapper for a simple predictor model with variable number of inputs and outputs
msa::tf::SimpleModel model;

// a bunch of properties of the models
// ideally should read from disk and vary with the model
// but trying to keep the code minimal so hardcoding them since they're the same for all models
const int input_shape[2] = {256, 256}; // dimensions {height, width} for input image
const int output_shape[2] = {256, 256}; // dimensions {height, width} for output image
ofVec2f input_range = {-1, 1}; // range of values {min, max} that model expects for input
ofVec2f output_range = {-1, 1}; // range of values {min, max} that model outputs
const ofVec2f input_range = {-1, 1}; // range of values {min, max} that model expects for input
const ofVec2f output_range = {-1, 1}; // range of values {min, max} that model outputs
const string input_op_name = "generator/generator_inputs"; // name of op to feed input to
const string output_op_name = "generator/generator_outputs"; // name of op to fetch output from

Expand All @@ -48,9 +48,8 @@ class ofApp : public ofBaseApp {
ofFloatImage img_out; // output from the model

// model file management
string model_root_dir = "models";
vector<string> model_names;
int cur_model_index = 0;
ofDirectory models_dir; // data/models folder which contains subfolders for each model
int cur_model_index = 0; // which model (i.e. folder) we're currently using


// color management for drawing
Expand Down Expand Up @@ -78,18 +77,15 @@ class ofApp : public ofBaseApp {


// scan models dir
ofDirectory dir;
dir.listDir(model_root_dir);
dir.sort();
if(dir.size()==0) {
ofLogError() << "Could not find models folder. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
models_dir.listDir("models");
if(models_dir.size()==0) {
ofLogError() << "Couldn't find models folder." << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
for(int i=0; i<dir.getFiles().size(); i++) model_names.push_back(dir.getName(i));
load_model_index(0);
models_dir.sort();
load_model_index(0); // load first model


// init video grabber
video_grabber.setDeviceID(0);
Expand All @@ -107,9 +103,8 @@ class ofApp : public ofBaseApp {
// note that it expects arrays for input op names and output op names, so just use {}
model.setup(ofFilePath::join(model_dir, "graph_frz.pb"), {input_op_name}, {output_op_name});
if(! model.is_loaded()) {
ofLogError() << "Model init error. Did you download the data files and place them in the data folder? ";
ofLogError() << "Download from https://github.com/memo/ofxMSATensorFlow/releases";
ofLogError() << "More info at https://github.com/memo/ofxMSATensorFlow/wiki";
ofLogError() << "Model init error.";
ofLogError() << msa::tf::missing_data_error();
assert(false);
ofExit(1);
}
Expand Down Expand Up @@ -179,10 +174,10 @@ class ofApp : public ofBaseApp {


//--------------------------------------------------------------
// Load model by folder INDEX, and initialise session
// Load model by folder INDEX
void load_model_index(int index) {
cur_model_index = ofClamp(index, 0, model_names.size()-1);
load_model(model_root_dir + "/" + model_names[cur_model_index]);
cur_model_index = ofClamp(index, 0, models_dir.size()-1);
load_model(models_dir.getPath(cur_model_index));
}


Expand Down Expand Up @@ -276,7 +271,7 @@ class ofApp : public ofBaseApp {

// run model on it
if(do_auto_run)
model.run(img_in, img_out, input_range, output_range);
model.run_image_to_image(img_in, img_out, input_range, output_range);

// DISPLAY STUFF
stringstream str;
Expand All @@ -293,17 +288,13 @@ class ofApp : public ofBaseApp {
str << endl;
str << "draw in the box on the left" << endl;
str << "or drag an image (PNG) into it" << endl;

str << endl;
str << "Press number key to load model: " << endl;
str << endl;

for(int i=0; i<model_names.size(); i++) {
str << "Press number key to load model: " << endl;
for(int i=0; i<models_dir.size(); i++) {
auto marker = (i==cur_model_index) ? ">" : " ";
str << " " << (i+1) << " : " << marker << " " << model_names[i] << endl;
str << " " << (i+1) << " : " << marker << " " << models_dir.getName(i) << endl;
}
str << endl;



ofPushMatrix();
Expand Down

0 comments on commit 95083c0

Please sign in to comment.