From 76b2ed9ecfa881c0d467aa7231aacb6b012c47e1 Mon Sep 17 00:00:00 2001 From: paolobolettieri Date: Fri, 26 Feb 2016 16:09:07 +0100 Subject: [PATCH] JavaCV version of OpenCV caffe_googlenet.cpp --- samples/CaffeGooglenet.java | 122 ++++++++++++++++++++++++++++++++++++ 1 file changed, 122 insertions(+) create mode 100644 samples/CaffeGooglenet.java diff --git a/samples/CaffeGooglenet.java b/samples/CaffeGooglenet.java new file mode 100644 index 00000000..8fe794e1 --- /dev/null +++ b/samples/CaffeGooglenet.java @@ -0,0 +1,122 @@ +/* + * JavaCV version of OpenCV caffe_googlenet.cpp + * https://github.com/ludv1x/opencv_contrib/blob/master/modules/dnn/samples/caffe_googlenet.cpp + * + * Paolo Bolettieri + */ + +import static org.bytedeco.javacpp.opencv_core.minMaxLoc; +import static org.bytedeco.javacpp.opencv_dnn.createCaffeImporter; +import static org.bytedeco.javacpp.opencv_imgcodecs.imread; +import static org.bytedeco.javacpp.opencv_imgproc.resize; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.bytedeco.javacpp.opencv_core.Mat; +import org.bytedeco.javacpp.opencv_core.Point; +import org.bytedeco.javacpp.opencv_core.Size; +import org.bytedeco.javacpp.opencv_dnn.Blob; +import org.bytedeco.javacpp.opencv_dnn.Importer; +import org.bytedeco.javacpp.opencv_dnn.Net; + +public class CaffeGooglenet { + + /* Find best class for the blob (i. e. class with maximal probability) */ + public static void getMaxClass(Blob probBlob, Point classId, double[] classProb) { + Mat probMat =probBlob.matRefConst().reshape(1, 1); //reshape the blob to 1x1000 matrix + minMaxLoc(probMat, null, classProb, null, classId, null); + } + + public static List readClassNames() { + String filename = "synset_words.txt"; + List classNames = null; + + try (BufferedReader br = new BufferedReader(new FileReader(new File(filename)))) { + classNames = new ArrayList(); + String name = null; + while ((name = br.readLine()) != null) { + classNames.add(name.substring(name.indexOf(' ')+1)); + } + } catch (IOException ex) { + System.err.println("File with classes labels not found " + filename); + System.exit(-1); + } + return classNames; + } + + public static void main(String[] args) throws Exception { + String modelTxt = "bvlc_googlenet.prototxt"; + String modelBin = "bvlc_googlenet.caffemodel"; + String imageFile = (args.length > 0) ? args[0] : "space_shuttle.jpg"; + + //! [Create the importer of Caffe model] + Importer importer = null; + //Try to import Caffe GoogleNet model + try { + importer = createCaffeImporter(modelTxt, modelBin); + //Importer can throw errors, we will catch them + } catch (Exception e) { + e.printStackTrace(); + } + //! [Create the importer of Caffe model] + + if (importer == null) { + System.err.println("Can't load network by using the following files: "); + System.err.println("prototxt: " + modelTxt); + System.err.println("caffemodel: " + modelBin); + System.err.println("bvlc_googlenet.caffemodel can be downloaded here:"); + System.err.println("http://dl.caffe.berkeleyvision.org/bvlc_googlenet.caffemodel"); + System.exit(-1); + } + + //! [Initialize network] + Net net = new Net(); + importer.populateNet(net); + importer.close(); + //We don't need importer anymore + //! [Initialize network] + + //! [Prepare blob] + Mat img = imread(imageFile); + + if (img.empty()) { + System.err.println("Can't read image from the file: " + imageFile); + System.exit(-1); + } + + resize(img, img, new Size(224, 224)); //GoogLeNet accepts only 224x224 RGB-images + Blob inputBlob = new Blob(img); //Convert Mat to dnn::Blob image batch + //! [Prepare blob] + + //! [Set input blob] + net.setBlob(".data", inputBlob); //set the network input + //! [Set input blob] + + //! [Make forward pass] + net.forward(); //compute output + //! [Make forward pass] + + //! [Gather output] + Blob prob = net.getBlob("prob"); //gather output of "prob" layer + + Point classId = new Point(); + double[] classProb = new double[1]; + getMaxClass(prob, classId, classProb);//find the best class + //! [Gather output] + + //! [Print results] + List classNames = readClassNames(); + + System.out.println("Best class: #" + classId.x() + " '" + classNames.get(classId.x()) + "'"); + System.out.println("Best class: #" + classId.x()); + System.out.println("Probability: " + classProb[0] * 100 + "%"); + + //! [Print results] + } //main + +}