Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #7149 Datavec: ImageLoader.scalingIfNeed may missing channel scal… #7159

Merged
merged 6 commits into from Feb 19, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -36,8 +36,8 @@
/**
* Image loader for taking images
* and converting them to matrices
* @author Adam Gibson
*
* @author Adam Gibson
*/
public class ImageLoader extends BaseImageLoader {

Expand All @@ -50,8 +50,8 @@ public class ImageLoader extends BaseImageLoader {
registry.registerServiceProvider(new com.twelvemonkeys.imageio.plugins.jpeg.JPEGImageWriterSpi());
registry.registerServiceProvider(new com.twelvemonkeys.imageio.plugins.psd.PSDImageReaderSpi());
registry.registerServiceProvider(Arrays.asList(new com.twelvemonkeys.imageio.plugins.bmp.BMPImageReaderSpi(),
new com.twelvemonkeys.imageio.plugins.bmp.CURImageReaderSpi(),
new com.twelvemonkeys.imageio.plugins.bmp.ICOImageReaderSpi()));
new com.twelvemonkeys.imageio.plugins.bmp.CURImageReaderSpi(),
new com.twelvemonkeys.imageio.plugins.bmp.ICOImageReaderSpi()));
}

public ImageLoader() {
Expand All @@ -61,9 +61,9 @@ public ImageLoader() {
/**
* Instantiate an image with the given
* height and width
*
* @param height the height to load*
* @param width the width to load

*/
public ImageLoader(long height, long width) {
super();
Expand All @@ -75,8 +75,9 @@ public ImageLoader(long height, long width) {
/**
* Instantiate an image with the given
* height and width
* @param height the height to load
* @param width the width to load
*
* @param height the height to load
* @param width the width to load
* @param channels the number of channels for the image*
*/
public ImageLoader(long height, long width, long channels) {
Expand All @@ -89,9 +90,10 @@ public ImageLoader(long height, long width, long channels) {
/**
* Instantiate an image with the given
* height and width
* @param height the height to load
* @param width the width to load
* @param channels the number of channels for the image*
*
* @param height the height to load
* @param width the width to load
* @param channels the number of channels for the image*
* @param centerCropIfNeeded to crop before rescaling and converting
*/
public ImageLoader(long height, long width, long channels, boolean centerCropIfNeeded) {
Expand Down Expand Up @@ -121,6 +123,7 @@ public INDArray asRowVector(InputStream inputStream) throws IOException {

/**
* Convert an image in to a row vector
*
* @param image the image to convert
* @return the row vector based on a rastered
* representation of the image
Expand All @@ -140,8 +143,9 @@ public INDArray asRowVector(BufferedImage image) {
/**
* Changes the input stream in to an
* bgr based raveled(flattened) vector
*
* @param file the input stream to convert
* @return the raveled bgr values for this input stream
* @return the raveled bgr values for this input stream
*/
public INDArray toRaveledTensor(File file) {
try {
Expand All @@ -157,8 +161,9 @@ public INDArray toRaveledTensor(File file) {
/**
* Changes the input stream in to an
* bgr based raveled(flattened) vector
*
* @param is the input stream to convert
* @return the raveled bgr values for this input stream
* @return the raveled bgr values for this input stream
*/
public INDArray toRaveledTensor(InputStream is) {
return toBgr(is).ravel();
Expand All @@ -167,6 +172,7 @@ public INDArray toRaveledTensor(InputStream is) {
/**
* Convert an image in to a raveled tensor of
* the bgr values of the image
*
* @param image the image to parse
* @return the raveled tensor of bgr values
*/
Expand Down Expand Up @@ -211,7 +217,7 @@ public INDArray toBgr(InputStream inputStream) {
}
}

private org.datavec.image.data.Image toBgrImage(InputStream inputStream){
private org.datavec.image.data.Image toBgrImage(InputStream inputStream) {
try {
BufferedImage image = ImageIO.read(inputStream);
INDArray img = toBgr(image);
Expand All @@ -237,6 +243,7 @@ public INDArray toBgr(BufferedImage image) {
/**
* Convert an image file
* in to a matrix
*
* @param f the file to convert
* @return a 2d matrix of a rastered version of the image
* @throws IOException
Expand All @@ -247,6 +254,7 @@ public INDArray asMatrix(File f) throws IOException {

/**
* Convert an input stream to a matrix
*
* @param inputStream the input stream to convert
* @return the input stream to convert
*/
Expand Down Expand Up @@ -283,6 +291,7 @@ public org.datavec.image.data.Image asImageMatrix(InputStream inputStream) throw

/**
* Convert an BufferedImage to a matrix
*
* @param image the BufferedImage to convert
* @return the input stream to convert
*/
Expand All @@ -297,7 +306,7 @@ public INDArray asMatrix(BufferedImage image) {

for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
ret.putScalar(new int[] {i, j}, image.getRGB(j, i));
ret.putScalar(new int[]{i, j}, image.getRGB(j, i));
}
}
return ret;
Expand All @@ -307,8 +316,8 @@ public INDArray asMatrix(BufferedImage image) {
/**
* Slices up an image in to a mini batch.
*
* @param f the file to load from
* @param numMiniBatches the number of images in a mini batch
* @param f the file to load from
* @param numMiniBatches the number of images in a mini batch
* @param numRowsPerSlice the number of rows for each image
* @return a tensor representing one image as a mini batch
*/
Expand All @@ -327,6 +336,7 @@ public int[] flattenedImageFromFile(File f) throws IOException {

/**
* Load a rastered image from file
*
* @param file the file to load
* @return the rastered image
* @throws IOException
Expand All @@ -339,6 +349,7 @@ public int[][] fromFile(File file) throws IOException {

/**
* Load a rastered image from file
*
* @param file the file to load
* @return the rastered image
* @throws IOException
Expand All @@ -349,17 +360,17 @@ public int[][][] fromFileMultipleChannels(File file) throws IOException {

int w = image.getWidth(), h = image.getHeight();
int bands = image.getSampleModel().getNumBands();
int[][][] ret = new int[(int)Math.min(channels, Integer.MAX_VALUE)]
[(int)Math.min(h, Integer.MAX_VALUE)]
[(int)Math.min(w, Integer.MAX_VALUE)];
int[][][] ret = new int[(int) Math.min(channels, Integer.MAX_VALUE)]
[(int) Math.min(h, Integer.MAX_VALUE)]
[(int) Math.min(w, Integer.MAX_VALUE)];
byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();

for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
for (int k = 0; k < channels; k++) {
if (k >= bands)
break;
ret[k][i][j] = pixels[(int)Math.min(channels * w * i + channels * j + k, Integer.MAX_VALUE)];
ret[k][i][j] = pixels[(int) Math.min(channels * w * i + channels * j + k, Integer.MAX_VALUE)];
}
}
}
Expand All @@ -368,6 +379,7 @@ public int[][][] fromFileMultipleChannels(File file) throws IOException {

/**
* Convert a matrix in to a buffereed image
*
* @param matrix the
* @return {@link java.awt.image.BufferedImage}
*/
Expand All @@ -393,14 +405,15 @@ private static int[] rasterData(INDArray matrix) {

/**
* Convert the given image to an rgb image
* @param arr the array to use
*
* @param arr the array to use
* @param image the image to set
*/
public void toBufferedImageRGB(INDArray arr, BufferedImage image) {
if (arr.rank() < 3)
throw new IllegalArgumentException("Arr must be 3d");

image = scalingIfNeed(image, arr.size(-2), arr.size(-1), true);
image = scalingIfNeed(image, arr.size(-2), arr.size(-1), image.getType(), true);
for (int i = 0; i < image.getHeight(); i++) {
for (int j = 0; j < image.getWidth(); j++) {
int r = arr.slice(2).getInt(i, j);
Expand All @@ -416,12 +429,12 @@ public void toBufferedImageRGB(INDArray arr, BufferedImage image) {
/**
* Converts a given Image into a BufferedImage
*
* @param img The Image to be converted
* @param img The Image to be converted
* @param type The color model of BufferedImage
* @return The converted BufferedImage
*/
public static BufferedImage toBufferedImage(Image img, int type) {
if (img instanceof BufferedImage) {
if (img instanceof BufferedImage && ((BufferedImage) img).getType() == type) {
return (BufferedImage) img;
}

Expand Down Expand Up @@ -463,7 +476,7 @@ protected INDArray toINDArrayBGR(BufferedImage image) {
int bands = image.getSampleModel().getNumBands();

byte[] pixels = ((DataBufferByte) image.getRaster().getDataBuffer()).getData();
int[] shape = new int[] {height, width, bands};
int[] shape = new int[]{height, width, bands};

INDArray ret2 = Nd4j.create(1, pixels.length);
for (int i = 0; i < ret2.length(); i++) {
Expand Down Expand Up @@ -491,32 +504,29 @@ public BufferedImage centerCropIfNeeded(BufferedImage img) {
}

protected BufferedImage scalingIfNeed(BufferedImage image, boolean needAlpha) {
return scalingIfNeed(image, height, width, needAlpha);
return scalingIfNeed(image, height, width, channels, needAlpha);
}

protected BufferedImage scalingIfNeed(BufferedImage image, long dstHeight, long dstWidth, boolean needAlpha) {
protected BufferedImage scalingIfNeed(BufferedImage image, long dstHeight, long dstWidth, long dstImageType, boolean needAlpha) {
Image scaled;
// Scale width and height first if necessary
if (dstHeight > 0 && dstWidth > 0 && (image.getHeight() != dstHeight || image.getWidth() != dstWidth)) {
Image scaled = image.getScaledInstance((int) dstWidth, (int) dstHeight, Image.SCALE_SMOOTH);

if (needAlpha && image.getColorModel().hasAlpha() && channels == BufferedImage.TYPE_4BYTE_ABGR) {
return toBufferedImage(scaled, BufferedImage.TYPE_4BYTE_ABGR);
} else {
if (channels == BufferedImage.TYPE_BYTE_GRAY)
return toBufferedImage(scaled, BufferedImage.TYPE_BYTE_GRAY);
else
return toBufferedImage(scaled, BufferedImage.TYPE_3BYTE_BGR);
}
scaled = image.getScaledInstance((int) dstWidth, (int) dstHeight, Image.SCALE_SMOOTH);
} else {
if (image.getType() == BufferedImage.TYPE_4BYTE_ABGR || image.getType() == BufferedImage.TYPE_3BYTE_BGR) {
return image;
} else if (needAlpha && image.getColorModel().hasAlpha() && channels == BufferedImage.TYPE_4BYTE_ABGR) {
return toBufferedImage(image, BufferedImage.TYPE_4BYTE_ABGR);
} else {
if (channels == BufferedImage.TYPE_BYTE_GRAY)
return toBufferedImage(image, BufferedImage.TYPE_BYTE_GRAY);
else
return toBufferedImage(image, BufferedImage.TYPE_3BYTE_BGR);
}
scaled = image;
}

// Transfer imageType if necessary and transfer to BufferedImage.
if (scaled instanceof BufferedImage && ((BufferedImage) scaled).getType() == dstImageType) {
return (BufferedImage) scaled;
}
if (needAlpha && image.getColorModel().hasAlpha() && dstImageType == BufferedImage.TYPE_4BYTE_ABGR) {
return toBufferedImage(scaled, BufferedImage.TYPE_4BYTE_ABGR);
} else {
if (dstImageType == BufferedImage.TYPE_BYTE_GRAY)
return toBufferedImage(scaled, BufferedImage.TYPE_BYTE_GRAY);
else
return toBufferedImage(scaled, BufferedImage.TYPE_3BYTE_BGR);
}
}

Expand Down
Expand Up @@ -128,6 +128,31 @@ public void testScalingIfNeed() throws Exception {

}

@Test
public void testScalingIfNeedWhenSuitableSizeButDiffChannel() {
int width1 = 60;
int height1 = 110;
int channel1 = BufferedImage.TYPE_BYTE_GRAY;
BufferedImage img1 = makeRandomBufferedImage(true, width1, height1);
ImageLoader loader1 = new ImageLoader(height1, width1, channel1);
BufferedImage scaled1 = loader1.scalingIfNeed(img1, false);
assertEquals(width1, scaled1.getWidth());
assertEquals(height1, scaled1.getHeight());
assertEquals(channel1, scaled1.getType());
assertEquals(1, scaled1.getSampleModel().getNumBands());

int width2 = 70;
int height2 = 120;
int channel2 = BufferedImage.TYPE_BYTE_GRAY;
BufferedImage img2 = makeRandomBufferedImage(false, width2, height2);
ImageLoader loader2 = new ImageLoader(height2, width2, channel2);
BufferedImage scaled2 = loader2.scalingIfNeed(img2, false);
assertEquals(width2, scaled2.getWidth());
assertEquals(height2, scaled2.getHeight());
assertEquals(channel2, scaled2.getType());
assertEquals(1, scaled2.getSampleModel().getNumBands());
}

@Test
public void testToBufferedImageRGB() {
BufferedImage img = makeRandomBufferedImage(false);
Expand All @@ -150,13 +175,19 @@ public void testToBufferedImageRGB() {

}

private BufferedImage makeRandomBufferedImage(boolean alpha) {
int w = rng.nextInt() % 100 + 100;
int h = rng.nextInt() % 100 + 100;
/**
* Generate a Random BufferedImage with specified width and height
*
* @param alpha Is image alpha
* @param width Proposed width
* @param height Proposed height
* @return Generated BufferedImage
*/
private BufferedImage makeRandomBufferedImage(boolean alpha, int width, int height) {
int type = alpha ? BufferedImage.TYPE_4BYTE_ABGR : BufferedImage.TYPE_3BYTE_BGR;
BufferedImage img = new BufferedImage(w, h, type);
for (int i = 0; i < h; ++i) {
for (int j = 0; j < w; ++j) {
BufferedImage img = new BufferedImage(width, height, type);
for (int i = 0; i < height; ++i) {
for (int j = 0; j < width; ++j) {
int a = (alpha ? rng.nextInt() : 1) & 0xff;
int r = rng.nextInt() & 0xff;
int g = rng.nextInt() & 0xff;
Expand All @@ -167,4 +198,14 @@ private BufferedImage makeRandomBufferedImage(boolean alpha) {
}
return img;
}

/**
* Generate a Random BufferedImage with random width and height
*
* @param alpha Is image alpha
* @return Generated BufferedImage
*/
private BufferedImage makeRandomBufferedImage(boolean alpha) {
return makeRandomBufferedImage(alpha, rng.nextInt() % 100 + 100, rng.nextInt() % 100 + 100);
}
}