Skip to content

Commit

Permalink
[opencv] Add opencv extension
Browse files Browse the repository at this point in the history
Change-Id: I738397d316ae128dd66f6ff5de8d52bfc56927e2
  • Loading branch information
frankfliu committed Nov 2, 2021
1 parent 11d4013 commit 56f626a
Show file tree
Hide file tree
Showing 24 changed files with 584 additions and 93 deletions.
58 changes: 21 additions & 37 deletions api/src/main/java/ai/djl/modality/cv/BufferedImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.file.Path;
import java.util.List;
Expand All @@ -61,16 +60,6 @@ public Image fromFile(Path path) throws IOException {
return new BufferedImageWrapper(image);
}

/** {@inheritDoc} */
@Override
public Image fromUrl(URL url) throws IOException {
BufferedImage image = ImageIO.read(url);
if (image == null) {
throw new IOException("Failed to read image from: " + url);
}
return new BufferedImageWrapper(image);
}

/** {@inheritDoc} */
@Override
public Image fromInputStream(InputStream is) throws IOException {
Expand Down Expand Up @@ -148,7 +137,7 @@ protected void save(BufferedImage image, OutputStream os, String type) throws IO

private class BufferedImageWrapper implements Image {

private final BufferedImage image;
private BufferedImage image;

BufferedImageWrapper(BufferedImage image) {
this.image = image;
Expand All @@ -174,19 +163,22 @@ public Object getWrappedImage() {

/** {@inheritDoc} */
@Override
public Image getSubimage(int x, int y, int w, int h) {
public Image getSubImage(int x, int y, int w, int h) {
return new BufferedImageWrapper(image.getSubimage(x, y, w, h));
}

/** {@inheritDoc} */
@Override
public Image duplicate(Type type) {
private void convertIdNeeded() {
if (image.getType() == BufferedImage.TYPE_INT_ARGB) {
return;
}

BufferedImage newImage =
new BufferedImage(image.getWidth(), image.getHeight(), getType(type));
new BufferedImage(
image.getWidth(), image.getHeight(), BufferedImage.TYPE_INT_ARGB);
Graphics2D g = newImage.createGraphics();
g.drawImage(image, 0, 0, null);
g.dispose();
return new BufferedImageWrapper(newImage);
image = newImage;
}

/** {@inheritDoc} */
Expand Down Expand Up @@ -244,6 +236,9 @@ public void save(OutputStream os, String type) throws IOException {
/** {@inheritDoc} */
@Override
public void drawBoundingBoxes(DetectedObjects detections) {
// Make image copy with alpha channel because original image was jpg
convertIdNeeded();

Graphics2D g = (Graphics2D) image.getGraphics();
int stroke = 2;
g.setStroke(new BasicStroke(stroke));
Expand All @@ -269,10 +264,9 @@ public void drawBoundingBoxes(DetectedObjects detections) {
drawText(g, className, x, y, stroke, 4);
// If we have a mask instead of a plain rectangle, draw tha mask
if (box instanceof Mask) {
Mask mask = (Mask) box;
drawMask(image, mask);
drawMask((Mask) box);
} else if (box instanceof Landmark) {
drawLandmarks(image, box);
drawLandmarks(box);
}
}
g.dispose();
Expand All @@ -281,6 +275,9 @@ public void drawBoundingBoxes(DetectedObjects detections) {
/** {@inheritDoc} */
@Override
public void drawJoints(Joints joints) {
// Make image copy with alpha channel because original image was jpg
convertIdNeeded();

Graphics2D g = (Graphics2D) image.getGraphics();
int stroke = 2;
g.setStroke(new BasicStroke(stroke));
Expand All @@ -297,13 +294,6 @@ public void drawJoints(Joints joints) {
g.dispose();
}

private int getType(Type type) {
if (type == Type.TYPE_INT_ARGB) {
return BufferedImage.TYPE_INT_ARGB;
}
throw new IllegalArgumentException("the type is not supported!");
}

private Color randomColor() {
return new Color(RandomUtils.nextInt(255));
}
Expand All @@ -321,7 +311,7 @@ private void drawText(Graphics2D g, String text, int x, int y, int stroke, int p
g.drawString(text, x + padding, y + ascent);
}

private void drawMask(BufferedImage image, Mask mask) {
private void drawMask(Mask mask) {
float r = RandomUtils.nextFloat();
float g = RandomUtils.nextFloat();
float b = RandomUtils.nextFloat();
Expand All @@ -343,13 +333,7 @@ private void drawMask(BufferedImage image, Mask mask) {
probDist.length, probDist[0].length, BufferedImage.TYPE_INT_ARGB);
for (int xCor = 0; xCor < probDist.length; xCor++) {
for (int yCor = 0; yCor < probDist[xCor].length; yCor++) {
float opacity = probDist[xCor][yCor];
if (opacity < 0.1) {
opacity = 0f;
}
if (opacity > 0.8) {
opacity = 0.8f;
}
float opacity = probDist[xCor][yCor] * 0.8f;
maskImage.setRGB(xCor, yCor, new Color(r, g, b, opacity).darker().getRGB());
}
}
Expand All @@ -358,7 +342,7 @@ private void drawMask(BufferedImage image, Mask mask) {
gR.dispose();
}

private void drawLandmarks(BufferedImage image, BoundingBox box) {
private void drawLandmarks(BoundingBox box) {
Graphics2D g = (Graphics2D) image.getGraphics();
g.setColor(new Color(246, 96, 0));
BasicStroke bStroke = new BasicStroke(4, BasicStroke.CAP_BUTT, BasicStroke.JOIN_MITER);
Expand Down
15 changes: 1 addition & 14 deletions api/src/main/java/ai/djl/modality/cv/Image.java
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,7 @@ public interface Image {
* @param h the height of the specified rectangular region
* @return subimage of this image
*/
Image getSubimage(int x, int y, int w, int h);

/**
* Gets a deep copy of the original image given the type.
*
* @param type the type of the copied image
* @return the copy of the original image.
*/
Image duplicate(Type type);
Image getSubImage(int x, int y, int w, int h);

/**
* Converts image to a RGB {@link NDArray}.
Expand Down Expand Up @@ -129,11 +121,6 @@ public int numChannels() {
}
}

/** Type indicates the type options for images. */
enum Type {
TYPE_INT_ARGB
}

/** Interpolation indicates the Interpolation options for resizinig an image. */
enum Interpolation {
NEAREST,
Expand Down
35 changes: 26 additions & 9 deletions api/src/main/java/ai/djl/modality/cv/ImageFactory.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,27 +19,40 @@
import java.net.URL;
import java.nio.file.Path;
import java.nio.file.Paths;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* {@code ImageFactory} contains image creation mechanism on top of different platforms like PC and
* Android. System will choose appropriate Factory based on the supported image type.
*/
public abstract class ImageFactory {

private static final Logger logger = LoggerFactory.getLogger(ImageFactory.class);

private static final String[] FACTORIES = {
"ai.djl.opencv.OpenCVImageFactory",
"ai.djl.modality.cv.BufferedImageFactory",
"ai.djl.android.core.BitmapImageFactory"
};

private static ImageFactory factory = newInstance();

private static ImageFactory newInstance() {
String className = "ai.djl.modality.cv.BufferedImageFactory";
int index = 0;
if ("http://www.android.com/".equals(System.getProperty("java.vendor.url"))) {
className = "ai.djl.android.core.BitmapImageFactory";
index = 2;
}
try {
Class<? extends ImageFactory> clazz =
Class.forName(className).asSubclass(ImageFactory.class);
return clazz.getConstructor().newInstance();
} catch (ReflectiveOperationException e) {
throw new IllegalStateException("Create new ImageFactory failed!", e);
for (int i = index; i < FACTORIES.length; ++i) {
try {
Class<? extends ImageFactory> clazz =
Class.forName(FACTORIES[i]).asSubclass(ImageFactory.class);
return clazz.getConstructor().newInstance();
} catch (ReflectiveOperationException e) {
logger.trace("", e);
}
}
throw new IllegalStateException("Create new ImageFactory failed!");
}

/**
Expand Down Expand Up @@ -76,7 +89,11 @@ public static void setImageFactory(ImageFactory factory) {
* @return {@link Image}
* @throws IOException URL is not valid.
*/
public abstract Image fromUrl(URL url) throws IOException;
public Image fromUrl(URL url) throws IOException {
try (InputStream is = url.openStream()) {
return fromInputStream(is);
}
}

/**
* Gets {@link Image} from URL.
Expand Down
1 change: 1 addition & 0 deletions bom/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ dependencies {
api "ai.djl.mxnet:mxnet-engine:${version}"
api "ai.djl.mxnet:mxnet-model-zoo:${version}"
api "ai.djl.mxnet:mxnet-native-auto:${mxnet_version}"
api "ai.djl.opencv:opencv:${version}"
api "ai.djl.onnxruntime:onnxruntime-engine:${version}"
api "ai.djl.paddlepaddle:paddlepaddle-engine:${version}"
api "ai.djl.paddlepaddle:paddlepaddle-model-zoo:${version}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ private static Image getSubImage(Image img, BoundingBox box) {
rect.getWidth() * width,
rect.getHeight() * height,
0.18);
return img.getSubimage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);
return img.getSubImage(squareBox[0], squareBox[1], squareBox[2], squareBox[2]);
}

private static int[] extendSquare(
Expand All @@ -90,13 +90,11 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("test.png");
// OpenJDK can't save jpg with alpha channel
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
}

private static Predictor<Image, Classifications> getClassifier()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ private static Image getSubImage(Image img, BoundingBox box) {
(int) (extended[2] * width),
(int) (extended[3] * height)
};
return img.getSubimage(recovered[0], recovered[1], recovered[2], recovered[3]);
return img.getSubImage(recovered[0], recovered[1], recovered[2], recovered[3]);
}

private static double[] extendRect(double xmin, double ymin, double width, double height) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,10 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("detected-dog_bike_car.png");
// OpenJDK can't save jpg with alpha channel
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,10 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("instances.png");
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
logger.info("Segmentation result image has been saved in: {}", imagePath);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,11 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("detected-dog_bike_car.png");
// OpenJDK can't save jpg with alpha channel
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
logger.info("Detected objects image has been saved in: {}", imagePath);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,11 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("detected-tensorflow-model-dog_bike_car.png");
// OpenJDK can't save jpg with alpha channel
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
logger.info("Detected objects image has been saved in: {}", imagePath);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ private static Image predictPersonInImage(Image img)
Rectangle rect = item.getBoundingBox().getBounds();
int width = img.getWidth();
int height = img.getHeight();
return img.getSubimage(
return img.getSubImage(
(int) (rect.getX() * width),
(int) (rect.getY() * height),
(int) (rect.getWidth() * width),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("ultranet_detected.png");
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
logger.info("Face detection result image has been saved in: {}", imagePath);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ private static void saveBoundingBoxImage(Image img, DetectedObjects detection)
Path outputDir = Paths.get("build/output");
Files.createDirectories(outputDir);

// Make image copy with alpha channel because original image was jpg
Image newImage = img.duplicate(Image.Type.TYPE_INT_ARGB);
newImage.drawBoundingBoxes(detection);
img.drawBoundingBoxes(detection);

Path imagePath = outputDir.resolve("retinaface_detected.png");
newImage.save(Files.newOutputStream(imagePath), "png");
img.save(Files.newOutputStream(imagePath), "png");
logger.info("Face detection result image has been saved in: {}", imagePath);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import org.slf4j.Logger;
Expand All @@ -53,7 +53,7 @@ public static void main(String[] args) throws ModelException, TranslateException
ImageFactory imageFactory = ImageFactory.getInstance();

List<Image> inputImages =
Arrays.asList(imageFactory.fromFile(Paths.get(imagePath + "fox.png")));
Collections.singletonList(imageFactory.fromFile(Paths.get(imagePath + "fox.png")));

List<Image> enhancedImages = enhance(inputImages);

Expand Down
Loading

0 comments on commit 56f626a

Please sign in to comment.