Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

JavaCV version of OpenCV caffe_googlenet.cpp #341

Merged
merged 1 commit into from Feb 27, 2016
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
122 changes: 122 additions & 0 deletions samples/CaffeGooglenet.java
@@ -0,0 +1,122 @@
/*
* JavaCV version of OpenCV caffe_googlenet.cpp
* https://github.com/ludv1x/opencv_contrib/blob/master/modules/dnn/samples/caffe_googlenet.cpp
*
* Paolo Bolettieri <paolo.bolettieri@gmail.com>
*/

import static org.bytedeco.javacpp.opencv_core.minMaxLoc;
import static org.bytedeco.javacpp.opencv_dnn.createCaffeImporter;
import static org.bytedeco.javacpp.opencv_imgcodecs.imread;
import static org.bytedeco.javacpp.opencv_imgproc.resize;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;

import org.bytedeco.javacpp.opencv_core.Mat;
import org.bytedeco.javacpp.opencv_core.Point;
import org.bytedeco.javacpp.opencv_core.Size;
import org.bytedeco.javacpp.opencv_dnn.Blob;
import org.bytedeco.javacpp.opencv_dnn.Importer;
import org.bytedeco.javacpp.opencv_dnn.Net;

public class CaffeGooglenet {

/* Find best class for the blob (i. e. class with maximal probability) */
public static void getMaxClass(Blob probBlob, Point classId, double[] classProb) {
Mat probMat =probBlob.matRefConst().reshape(1, 1); //reshape the blob to 1x1000 matrix
minMaxLoc(probMat, null, classProb, null, classId, null);
}

public static List<String> readClassNames() {
String filename = "synset_words.txt";
List<String> classNames = null;

try (BufferedReader br = new BufferedReader(new FileReader(new File(filename)))) {
classNames = new ArrayList<String>();
String name = null;
while ((name = br.readLine()) != null) {
classNames.add(name.substring(name.indexOf(' ')+1));
}
} catch (IOException ex) {
System.err.println("File with classes labels not found " + filename);
System.exit(-1);
}
return classNames;
}

public static void main(String[] args) throws Exception {
String modelTxt = "bvlc_googlenet.prototxt";
String modelBin = "bvlc_googlenet.caffemodel";
String imageFile = (args.length > 0) ? args[0] : "space_shuttle.jpg";

//! [Create the importer of Caffe model]
Importer importer = null;
//Try to import Caffe GoogleNet model
try {
importer = createCaffeImporter(modelTxt, modelBin);
//Importer can throw errors, we will catch them
} catch (Exception e) {
e.printStackTrace();
}
//! [Create the importer of Caffe model]

if (importer == null) {
System.err.println("Can't load network by using the following files: ");
System.err.println("prototxt: " + modelTxt);
System.err.println("caffemodel: " + modelBin);
System.err.println("bvlc_googlenet.caffemodel can be downloaded here:");
System.err.println("http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel");
System.exit(-1);
}

//! [Initialize network]
Net net = new Net();
importer.populateNet(net);
importer.close();
//We don't need importer anymore
//! [Initialize network]

//! [Prepare blob]
Mat img = imread(imageFile);

if (img.empty()) {
System.err.println("Can't read image from the file: " + imageFile);
System.exit(-1);
}

resize(img, img, new Size(224, 224)); //GoogLeNet accepts only 224x224 RGB-images
Blob inputBlob = new Blob(img); //Convert Mat to dnn::Blob image batch
//! [Prepare blob]

//! [Set input blob]
net.setBlob(".data", inputBlob); //set the network input
//! [Set input blob]

//! [Make forward pass]
net.forward(); //compute output
//! [Make forward pass]

//! [Gather output]
Blob prob = net.getBlob("prob"); //gather output of "prob" layer

Point classId = new Point();
double[] classProb = new double[1];
getMaxClass(prob, classId, classProb);//find the best class
//! [Gather output]

//! [Print results]
List<String> classNames = readClassNames();

System.out.println("Best class: #" + classId.x() + " '" + classNames.get(classId.x()) + "'");
System.out.println("Best class: #" + classId.x());
System.out.println("Probability: " + classProb[0] * 100 + "%");

//! [Print results]
} //main

}