/
LightFaceDetection.java
94 lines (79 loc) · 3.6 KB
/
LightFaceDetection.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
/*
* Copyright 2021 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.face;
import ai.djl.ModelException;
import ai.djl.inference.Predictor;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.IOException;
import java.io.OutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
/**
* An example of inference using a face detection model.
*
* <p>See this <a
* href="https://github.com/deepjavalibrary/djl/blob/master/examples/docs/face_detection.md">doc</a>
* for information about this example.
*/
public final class LightFaceDetection {

    private static final Logger logger = LoggerFactory.getLogger(LightFaceDetection.class);

    private LightFaceDetection() {}

    /**
     * Runs the face-detection example and logs the detected objects.
     *
     * @param args unused
     * @throws IOException if the input image cannot be read or the result image cannot be written
     * @throws ModelException if the model cannot be loaded
     * @throws TranslateException if inference fails
     */
    public static void main(String[] args) throws IOException, ModelException, TranslateException {
        DetectedObjects detection = LightFaceDetection.predict();
        logger.info("{}", detection);
    }

    /**
     * Detects faces in {@code src/test/resources/largest_selfie.jpg} with the UltraNet PyTorch
     * model and writes an annotated copy to {@code build/output/ultranet_detected.png}.
     *
     * <p>Note: the bounding boxes are drawn onto the loaded {@link Image} before it is saved.
     *
     * @return the objects detected in the image
     * @throws IOException if the input image cannot be read or the result image cannot be written
     * @throws ModelException if the model cannot be loaded
     * @throws TranslateException if inference fails
     */
    public static DetectedObjects predict() throws IOException, ModelException, TranslateException {
        Path facePath = Paths.get("src/test/resources/largest_selfie.jpg");
        Image img = ImageFactory.getInstance().fromFile(facePath);

        Criteria<Image, DetectedObjects> criteria =
                Criteria.builder()
                        .setTypes(Image.class, DetectedObjects.class)
                        .optModelUrls("https://resources.djl.ai/test-models/pytorch/ultranet.zip")
                        .optTranslator(newTranslator())
                        .optProgress(new ProgressBar())
                        .optEngine("PyTorch") // Use PyTorch engine
                        .build();

        // Resources are closed in reverse order: the predictor first, then the model.
        try (ZooModel<Image, DetectedObjects> model = criteria.loadModel();
                Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            DetectedObjects detection = predictor.predict(img);
            saveBoundingBoxImage(img, detection);
            return detection;
        }
    }

    /**
     * Builds the {@link FaceDetectionTranslator} configured for the UltraNet model.
     *
     * <p>Double literals are used on purpose. The former float literals (e.g. {@code 0.85f})
     * widened to slightly different doubles (0.8500000238...).
     */
    private static FaceDetectionTranslator newTranslator() {
        double confThresh = 0.85;
        double nmsThresh = 0.45;
        double[] variance = {0.1, 0.2};
        int topK = 5000;
        int[][] scales = {{10, 16, 24}, {32, 48}, {64, 96}, {128, 192, 256}};
        int[] steps = {8, 16, 32, 64};
        return new FaceDetectionTranslator(confThresh, nmsThresh, variance, topK, scales, steps);
    }

    /**
     * Draws the detections onto {@code img} (mutating it) and saves the result as a PNG under
     * {@code build/output}, creating the directory if necessary.
     *
     * @param img the image to annotate and save
     * @param detection the detections to draw
     * @throws IOException if the directory cannot be created or the file cannot be written
     */
    private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
            throws IOException {
        Path outputDir = Paths.get("build/output");
        Files.createDirectories(outputDir);

        img.drawBoundingBoxes(detection);

        Path imagePath = outputDir.resolve("ultranet_detected.png");
        // Close the stream deterministically; previously the handle leaked until garbage collection.
        try (OutputStream os = Files.newOutputStream(imagePath)) {
            img.save(os, "png");
        }
        logger.info("Face detection result image has been saved in: {}", imagePath);
    }
}