Skip to content

Commit

Permalink
Integrate protobuf in matlab.
Browse files Browse the repository at this point in the history
The separate protobuf (e.g.,
https://gist.github.com/jiayuzhou/b5029bb1ba7bd7f1d911) is likely to
crash Matlab due to the conflict below:

[libprotobuf ERROR google/protobuf/descriptor_database.cc:57] File
already exists in database: caffe.proto
[libprotobuf FATAL google/protobuf/descriptor.cc:1018] CHECK failed:
generated_database_->Add(encoded_file_descriptor, size)
  • Loading branch information
jiayuzhou committed Jul 2, 2015
1 parent 805a995 commit a7397da
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 2 deletions.
3 changes: 2 additions & 1 deletion Makefile
@@ -1,5 +1,6 @@
PROJECT := caffe


CONFIG_FILE := Makefile.config
# Explicitly check for the config file, otherwise make -k will proceed anyway.
ifeq ($(wildcard $(CONFIG_FILE)),)
Expand Down Expand Up @@ -456,7 +457,7 @@ $(MAT$(PROJECT)_SO): $(MAT$(PROJECT)_SRC) $(STATIC_NAME)
exit 1; \
fi
@ echo MEX $<
$(Q)$(MATLAB_DIR)/bin/mex $(MAT$(PROJECT)_SRC) \
$(Q)$(MATLAB_DIR)/bin/mex -I$(MEXPLUS_DIR) $(MAT$(PROJECT)_SRC) \
CXX="$(CXX)" \
CXXFLAGS="\$$CXXFLAGS $(MATLAB_CXXFLAGS)" \
CXXLIBS="\$$CXXLIBS $(STATIC_LINK_COMMAND) $(LDFLAGS)" -output $@
Expand Down
1 change: 1 addition & 0 deletions Makefile.config.example
Expand Up @@ -45,6 +45,7 @@ BLAS := atlas
# MATLAB directory should contain the mex binary in /bin.
# MATLAB_DIR := /usr/local
# MATLAB_DIR := /Applications/MATLAB_R2012b.app
# MEXPLUS_DIR := /path/to/your/mexplus/include

# NOTE: this is required only if you will compile the python interface.
# We need to be able to find Python.h and numpy/arrayobject.h.
Expand Down
11 changes: 11 additions & 0 deletions matlab/+caffe/fromDatum.m
@@ -0,0 +1,11 @@
function [ label, image ] = fromDatum( varargin )
%FROMDATUM decode image and label from caffe protobuf.

CHECK(nargin > 0, ['usage: '...
'[ label, image ] = fromDatum( datum )']);
datum = varargin{1};

[label, image] = caffe_('from_datum', datum);

end

67 changes: 66 additions & 1 deletion matlab/+caffe/private/caffe_.cpp
Expand Up @@ -14,12 +14,14 @@
#include <vector>

#include "mex.h"

#include "mexplus.h"
#include "caffe/caffe.hpp"

#define MEX_ARGS int nlhs, mxArray **plhs, int nrhs, const mxArray **prhs

using namespace caffe; // NOLINT(build/namespaces)
//using namespace std;
//using namespace mexplus;

// Do CHECK and throw a Mex error if check fails
inline void mxCHECK(bool expr, const char* msg) {
Expand Down Expand Up @@ -478,6 +480,68 @@ static void read_mean(MEX_ARGS) {
mxFree(mean_proto_file);
}

// Usage: caffe_('from_datum', datum)
static void from_datum(MEX_ARGS) {
mexplus::OutputArguments output(nlhs, plhs, 2);
mxCHECK(nrhs == 1 && mxIsChar(prhs[0]),
"Usage: caffe_('from_datum', datum)");

caffe::Datum datum;
std::basic_string<char> datum_content = mexplus::MxArray::to<string>(prhs[0]); //datum received from matlab

mxCHECK(datum.ParseFromString(datum_content),
"Failed to parse datum.");
output.set(0, datum.data());

if (datum.has_encoded() && datum.encoded()) {
output.set(0, datum.data());
}
else {
vector<mwIndex> dimensions(3);
dimensions[0] = (datum.has_height()) ? datum.height() : 0;
dimensions[1] = (datum.has_width()) ? datum.width() : 0;
dimensions[2] = (datum.has_channels()) ? datum.channels() : 0;
mexplus::MxArray array;
vector<mwIndex> subscripts(3);
int index = 0;
if (datum.has_data()) {
array.reset(mxCreateNumericArray(dimensions.size(),
&dimensions[0],
mxUINT8_CLASS,
mxREAL));
const string& data = datum.data();
for (int k = dimensions[2] - 1; k >= 0; --k) { // BGR to RGB order.
subscripts[2] = k;
for (int i = 0; i < dimensions[0]; ++i) {
subscripts[0] = i;
for (int j = 0; j < dimensions[1]; ++j) {
subscripts[1] = j;
array.set(subscripts, data[index++]);
}
}
}
}
else if (datum.float_data_size() > 0) {
array.reset(mxCreateNumericArray(dimensions.size(),
&dimensions[0],
mxSINGLE_CLASS,
mxREAL));
for (int k = dimensions[2] - 1; k >= 0; --k) { // BGR to RGB order.
subscripts[2] = k;
for (int i = 0; i < dimensions[0]; ++i) {
subscripts[0] = i;
for (int j = 0; j < dimensions[1]; ++j) {
subscripts[1] = j;
array.set(subscripts, datum.float_data(index++));
}
}
}
}
output.set(0, array.release());
}
output.set(1, (datum.has_label()) ? datum.label() : 0);
}

/** -----------------------------------------------------------------
** Available commands.
**/
Expand Down Expand Up @@ -515,6 +579,7 @@ static handler_registry handlers[] = {
{ "get_init_key", get_init_key },
{ "reset", reset },
{ "read_mean", read_mean },
{ "from_datum", from_datum },
// The end.
{ "END", NULL },
};
Expand Down
72 changes: 72 additions & 0 deletions matlab/demo/lmdb_datum_demo.m
@@ -0,0 +1,72 @@
% Example of using lmdb and caffe protobuf (datum) in matlab.
%
% by Jiayu, July 1, 2015.
%
% NOTE 1. start matlab with a specified libtiff.5.dylib.
% DYLD_INSERT_LIBRARIES=/usr/local/lib/libtiff.5.dylib /Applications/MATLAB_R2012b.app/bin/matlab &
%
% 2. install matlab-lmdb
% https://github.com/illidanlab/matlab-lmdb
%
% 3. the image num (the first input_num) in the model file should set to 1.
% will fix later.

if exist('../+caffe', 'dir')
addpath('..');
else
error('Please run this demo from caffe/matlab/demo');
end

addpath ../../../matlab-lmdb/ % change to your matlab-lmdb path

cur_director = pwd;
net_model = strcat(cur_director, '/../../examples/mnist/lenet.prototxt');
net_weights = strcat(cur_director, '/../../examples/mnist/lenet_iter_10000.caffemodel');
db_path = strcat(cur_director, '/../../examples/mnist/mnist_test_lmdb');
use_gpu = 0;
phase = 'test';


% create caffe net instance
caffe.set_mode_cpu();
net = caffe.Net(net_model, net_weights, phase);

% load an existing lmdb database (crated using the shell in example).
database = lmdb.DB(db_path, 'RDONLY', true, 'NOLOCK', true);
cursor = database.cursor('RDONLY', true);

max_count = 10; % maximum test cases

count = 0;
correctNum = 0;
while cursor.next()
key = cursor.key;
value = cursor.value;

% transform datum.
[image, label] = caffe.fromDatum(value);

% prepare image
data = single(image);
data = permute(data, [2,1,3]);

% generate prediction
scores = net.forward({data});
predict_class = find(scores{1}==1) - 1; % shift 1


fprintf('[%u] Class %u predicted as %u \n', count+1, label, predict_class)

if(predict_class == label)
correctNum = correctNum + 1;
end

count = count + 1;
if (count >= max_count)
break;
end
end

fprintf('Correctly classified %d images out of %d ( %d percent)\n', correctNum, count, correctNum/count * 100)

clear cursor;

0 comments on commit a7397da

Please sign in to comment.