Skip to content

Commit

Permalink
added blazeface test
Browse files Browse the repository at this point in the history
  • Loading branch information
cansik committed Aug 3, 2021
1 parent e9037bb commit 2eae01f
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 0 deletions.
90 changes: 90 additions & 0 deletions src/main/java/ch/bildspur/vision/MediaPipeBlazeFaceNetwork.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package ch.bildspur.vision;

import ch.bildspur.vision.network.ObjectDetectionNetwork;
import ch.bildspur.vision.result.ObjectDetectionResult;
import ch.bildspur.vision.result.ResultList;
import ch.bildspur.vision.util.MathUtils;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.opencv.global.opencv_dnn;
import org.bytedeco.opencv.opencv_core.*;
import org.bytedeco.opencv.opencv_dnn.Net;
import org.bytedeco.opencv.opencv_text.FloatVector;

import java.nio.file.Path;
import java.util.ArrayList;
import java.util.List;

import static org.bytedeco.opencv.global.opencv_core.CV_32F;
import static org.bytedeco.opencv.global.opencv_dnn.*;

/**
* Based on https://github.com/Linzaer/Ultra-Light-Fast-Generic-Face-Detector-1MB/blob/master/caffe/ultra_face_opencvdnn_inference.py
* Adapted and improved a lot.
*/
public class MediaPipeBlazeFaceNetwork extends ObjectDetectionNetwork {
private Path modelPath;
protected Net net;

private int width;
private int height;

private Scalar imageMean = Scalar.all(127);
private float imageStd = 128.0f;

public MediaPipeBlazeFaceNetwork(Path modelPath, int width, int height) {
this.modelPath = modelPath;
this.width = width;
this.height = height;
}

@Override
public boolean setup() {
net = readNetFromONNX(modelPath.toAbsolutePath().toString());

if (DeepVision.ENABLE_CUDA_BACKEND) {
net.setPreferableBackend(opencv_dnn.DNN_BACKEND_CUDA);
net.setPreferableTarget(opencv_dnn.DNN_TARGET_CUDA);
}

if (net.empty()) {
System.out.println("Can't load network!");
return false;
}

return true;
}

@Override
public ResultList<ObjectDetectionResult> run(Mat frame) {
// convert image into batch of images
Mat inputBlob = blobFromImage(frame,
1 / imageStd,
new Size(width, height),
imageMean,
false, false, CV_32F);

// set input
net.setInput(inputBlob);

// create output layers
StringVector outNames = net.getUnconnectedOutLayersNames();
MatVector outs = new MatVector(outNames.size());

// run detection
net.forward(outs, outNames);

// extract boxes and scores
Mat boxesOut = outs.get(0);
Mat confidencesOut = outs.get(1);

// boxes
Mat boxes = boxesOut.reshape(0, boxesOut.size(1));

// class confidences (BACKGROUND, face)
Mat confidences = confidencesOut.reshape(0, confidencesOut.size(1));

return new ResultList<>();
}
}
70 changes: 70 additions & 0 deletions src/test/java/ch/bildspur/vision/test/BlazeFaceTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package ch.bildspur.vision.test;


import ch.bildspur.vision.DeepVisionPreview;
import ch.bildspur.vision.MediaPipeBlazeFaceNetwork;
import ch.bildspur.vision.TextBoxesNetwork;
import ch.bildspur.vision.result.ObjectDetectionResult;
import processing.core.PApplet;
import processing.core.PImage;

import java.nio.file.Paths;
import java.util.List;

public class BlazeFaceTest extends PApplet {

public static void main(String... args) {
BlazeFaceTest sketch = new BlazeFaceTest();
sketch.runSketch();
}

public void settings() {
size(640, 480, FX2D);
}

PImage testImage;

DeepVisionPreview vision = new DeepVisionPreview(this);
MediaPipeBlazeFaceNetwork network;
List<ObjectDetectionResult> detections;

public void setup() {
colorMode(HSB, 360, 100, 100);

testImage = loadImage(sketchPath("data/faces.png"));

println("creating network...");
network = new MediaPipeBlazeFaceNetwork(Paths.get("networks/face_detection_back_256x256_barracuda.onnx"), 256, 256);

println("loading model...");
network.setup();

//network.setConfidenceThreshold(0.2f);

println("inferencing...");
detections = network.run(testImage);
println("done!");

for (ObjectDetectionResult detection : detections) {
System.out.println(detection.getClassName() + "\t[" + detection.getConfidence() + "]");
}

println("found " + detections.size() + " texts!");
}

public void draw() {
background(55);

image(testImage, 0, 0);

noFill();
strokeWeight(2f);

stroke(200, 80, 100);
for (ObjectDetectionResult detection : detections) {
rect(detection.getX(), detection.getY(), detection.getWidth(), detection.getHeight());
}

surface.setTitle("BlazeFace Test - FPS: " + Math.round(frameRate));
}
}

0 comments on commit 2eae01f

Please sign in to comment.