New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update ImageUtils to support resizing one, three, or four channel images #94
Changes from 4 commits
9b104c6
811c086
882d9ee
7bfdfb2
ba5d332
e4fed25
c12efd8
887b306
e82e302
8c4a320
ae46bf2
65fe425
3c757ca
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,6 +16,7 @@ | |
|
||
package com.databricks.sparkdl | ||
|
||
import java.awt.color.ColorSpace | ||
import java.awt.image.BufferedImage | ||
import java.awt.{Color, Image} | ||
|
||
|
@@ -25,8 +26,8 @@ import org.apache.spark.sql.Row | |
private[sparkdl] object ImageUtils { | ||
|
||
/** | ||
* Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1 & 3 | ||
* channel images. If the image has 3 channels, we assume the channels are in BGR order. | ||
* Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1, 3, & 4 | ||
* channel images. If the image has 3 or 4 channels, we assume the channels are in BGR(A) order. | ||
* | ||
* @param rowImage Image in spark.ml.image format. | ||
* @return Java BGR BufferedImage. | ||
|
@@ -42,10 +43,15 @@ private[sparkdl] object ImageUtils { | |
| image of size ($height, $width, $channels). | ||
""".stripMargin | ||
) | ||
val image = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR) | ||
|
||
val image = channels match { | ||
case 1 => new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY) | ||
case 3 => new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR) | ||
case 4 => new BufferedImage(width, height, BufferedImage.TYPE_4BYTE_ABGR) | ||
} | ||
|
||
var offset, h = 0 | ||
var r, g, b: Byte = 0 | ||
var r, g, b, a: Byte = 0 | ||
while (h < height) { | ||
var w = 0 | ||
while (w < width) { | ||
|
@@ -58,11 +64,20 @@ private[sparkdl] object ImageUtils { | |
b = imageData(offset) | ||
g = imageData(offset + 1) | ||
r = imageData(offset + 2) | ||
case 4 => | ||
b = imageData(offset) | ||
g = imageData(offset + 1) | ||
r = imageData(offset + 2) | ||
a = imageData(offset + 3) | ||
case _ => | ||
require(false, s"`Channels` must be 1 or 3, got $channels.") | ||
require(false, s"`Channels` must be 1, 3, or 4, got $channels.") | ||
} | ||
|
||
val color = new Color(r & 0xff, g & 0xff, b & 0xff) | ||
val color = if (channels < 4) { | ||
new Color(r & 0xff, g & 0xff, b & 0xff) | ||
} else { | ||
new Color(r & 0xff, g & 0xff, b & 0xff, a & 0xff) | ||
} | ||
image.setRGB(w, h, color.getRGB) | ||
offset += channels | ||
w += 1 | ||
|
@@ -72,15 +87,40 @@ private[sparkdl] object ImageUtils { | |
image | ||
} | ||
|
||
/** Returns the number of channels in the passed-in buffered image. */ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This method and |
||
private def getNumChannels(img: BufferedImage): Int = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not call image.getColorModel.getNumComponents? |
||
val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY | ||
val hasAlpha = img.getColorModel.hasAlpha | ||
if (isGray) { | ||
1 | ||
} else if (hasAlpha) { | ||
4 | ||
} else { | ||
3 | ||
} | ||
} | ||
|
||
/** Returns the OCV type (int) of the passed-in image */ | ||
private def getOCVType(img: BufferedImage): Int = { | ||
val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY | ||
val hasAlpha = img.getColorModel.hasAlpha | ||
if (isGray) { | ||
ImageSchema.ocvTypes("CV_8UC1") | ||
} else if (hasAlpha) { | ||
ImageSchema.ocvTypes("CV_8UC4") | ||
} else { | ||
ImageSchema.ocvTypes("CV_8UC3") | ||
} | ||
} | ||
|
||
/** | ||
* Takes a Java BufferedImage and returns a Row Image (spImage). | ||
* | ||
* @param image Java BufferedImage. | ||
* @return Row image in spark.ml.image format with 3 channels in BGR order. | ||
* @return Row image in spark.ml.image format with channels in BGR(A) order. | ||
*/ | ||
private[sparkdl] def spImageFromBufferedImage(image: BufferedImage, origin: String = null): Row = { | ||
val channels = 3 | ||
val channels = getNumChannels(image) | ||
val height = image.getHeight | ||
val width = image.getWidth | ||
|
||
|
@@ -89,27 +129,37 @@ private[sparkdl] object ImageUtils { | |
while (h < height) { | ||
var w = 0 | ||
while (w < width) { | ||
val color = new Color(image.getRGB(w, h)) | ||
decoded(offset) = color.getBlue.toByte | ||
decoded(offset + 1) = color.getGreen.toByte | ||
decoded(offset + 2) = color.getRed.toByte | ||
val color = new Color(image.getRGB(w, h), image.getColorModel.hasAlpha) | ||
channels match { | ||
case 1 => | ||
decoded(offset) = color.getBlue.toByte | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: Unfortunately this doesn't always yield the original blue byte written in https://gist.github.com/smurching/d6e918a84c79155b616f9339b30fdfca It makes sense to me that this happens, since it doesn't really make sense to tell a grayscale image "set the pixel at 0, 0 to RGB (x, y, z)" and expect the RGB channel values to actually be (x, y, z) - if this were the case, the image might not be gray. I'm not sure how to work around it... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually I think this can be fixed by stealing some code from ImageSchema that directly manipulates the Image raster, will push an update... |
||
case 3 => | ||
decoded(offset) = color.getBlue.toByte | ||
decoded(offset + 1) = color.getGreen.toByte | ||
decoded(offset + 2) = color.getRed.toByte | ||
case 4 => | ||
decoded(offset) = color.getBlue.toByte | ||
decoded(offset + 1) = color.getGreen.toByte | ||
decoded(offset + 2) = color.getRed.toByte | ||
decoded(offset + 3) = color.getAlpha.toByte | ||
} | ||
offset += channels | ||
w += 1 | ||
} | ||
h += 1 | ||
} | ||
Row(origin, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), decoded) | ||
Row(origin, height, width, channels, getOCVType(image), decoded) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: extra blank line |
||
} | ||
|
||
/** | ||
* Resizes an image and returns it as an Array[Byte]. Only 1 and 3 channel inputs, where each | ||
* Resizes an image and returns it as an Array[Byte]. Only 1, 3, and 4 channel inputs, where each | ||
* channel is a single Byte, are currently supported. Only BGR channel order is supported but | ||
* this might work for other channel orders. | ||
* | ||
* @param tgtHeight desired height of output image. | ||
* @param tgtWidth desired width of output image. | ||
* @param tgtChannels number of channels of output image (must be 3), may be used later to | ||
* support more channels. | ||
* @param tgtChannels number of channels in output image. | ||
* @param spImage image to resize. | ||
* @param scaleHint hint which algorhitm to use, see java.awt.Image#SCALE_SCALE_AREA_AVERAGING | ||
* @return resized image, if the input was BGR or 1 channel, the output will be BGR. | ||
|
@@ -120,8 +170,6 @@ private[sparkdl] object ImageUtils { | |
tgtChannels: Int, | ||
spImage: Row, | ||
scaleHint: Int = Image.SCALE_AREA_AVERAGING): Row = { | ||
require(tgtChannels == 3, s"`tgtChannels` was set to $tgtChannels, must be 3.") | ||
|
||
val height = ImageSchema.getHeight(spImage) | ||
val width = ImageSchema.getWidth(spImage) | ||
val nChannels = ImageSchema.getNChannels(spImage) | ||
|
@@ -130,14 +178,19 @@ private[sparkdl] object ImageUtils { | |
spImage | ||
} else { | ||
val srcImg = spImageToBufferedImage(spImage) | ||
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, BufferedImage.TYPE_3BYTE_BGR) | ||
val tgtImgType = tgtChannels match { | ||
case 1 => BufferedImage.TYPE_BYTE_GRAY | ||
case 3 => BufferedImage.TYPE_3BYTE_BGR | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would add |
||
case 4 => BufferedImage.TYPE_4BYTE_ABGR | ||
} | ||
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, tgtImgType) | ||
// scaledImg is a java.awt.Image which supports drawing but not pixel lookup by index. | ||
val scaledImg = srcImg.getScaledInstance(tgtWidth, tgtHeight, scaleHint) | ||
// Draw scaledImage onto resized (usually smaller) tgtImg so extract individual pixel values. | ||
val graphic = tgtImg.createGraphics() | ||
graphic.drawImage(scaledImg, 0, 0, null) | ||
graphic.dispose() | ||
spImageFromBufferedImage(tgtImg, origin=ImageSchema.getOrigin(spImage)) | ||
spImageFromBufferedImage(tgtImg, origin = ImageSchema.getOrigin(spImage)) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,54 +22,57 @@ import javax.imageio.ImageIO | |
|
||
import scala.util.Random | ||
|
||
import org.scalatest.FunSuite | ||
|
||
import org.apache.spark.ml.image.ImageSchema | ||
import org.apache.spark.sql.Row | ||
|
||
import org.scalatest.FunSuite | ||
|
||
object ImageUtilsSuite { | ||
val biggerImage: Row = { | ||
val biggerFile = getClass.getResource("/sparkdl/test-image-collection/00081101.jpg").getFile | ||
val imageBuffer = ImageIO.read(new File(biggerFile)) | ||
ImageUtils.spImageFromBufferedImage(imageBuffer) | ||
} | ||
|
||
val smallerImage: Row = { | ||
val smallerFile = getClass.getResource("/sparkdl/00081101-small-version.png").getFile | ||
val imageBuffer = ImageIO.read(new File(smallerFile)) | ||
/** Read image data into a BufferedImage, then use our utility method to convert to a row image */ | ||
def getImageRow(resourcePath: String): Row = { | ||
val resourceUrl = getClass.getResource(resourcePath).getFile | ||
val imageBuffer = ImageIO.read(new File(resourceUrl)) | ||
ImageUtils.spImageFromBufferedImage(imageBuffer) | ||
} | ||
|
||
|
||
def smallerImage: Row = getImageRow("/sparkdl/00081101-small-version.png") | ||
def biggerImage: Row = getImageRow("/sparkdl/test-image-collection/00081101.jpg") | ||
} | ||
|
||
class ImageUtilsSuite extends FunSuite { | ||
// We want to make sure to test ImageUtils in headless mode to ensure it'll work on all systems. | ||
assert(System.getProperty("java.awt.headless") === "true") | ||
|
||
import ImageUtilsSuite._ | ||
|
||
test("Test spImage resize.") { | ||
val tgtHeight: Int = ImageSchema.getHeight(smallerImage) | ||
val tgtWidth: Int = ImageSchema.getWidth(smallerImage) | ||
val tgtChannels: Int = ImageSchema.getNChannels(smallerImage) | ||
def getImagePath(imageSize: String, numChannels: Int): String = { | ||
s"/sparkdl/test-image-collection/${numChannels}_channels/$imageSize.png" | ||
} | ||
for (channels <- Seq(1, 3, 4)) { | ||
val smallerImage = ImageUtilsSuite.getImageRow(getImagePath("small", channels)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: I cheated here & had the test compare the original image to a smaller image generated via the resizeImage code path ( There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you mean you compare against image resized using BufferedImage or ImageUtils.resizeImage? Just a suggestion, why don't we store only the original image, and compare against the image resized in java using BufferedImage on the fly? |
||
val biggerImage = ImageUtilsSuite.getImageRow(getImagePath("big", channels)) | ||
|
||
val testImage = ImageUtils.resizeImage(tgtHeight, tgtWidth, tgtChannels, biggerImage) | ||
assert(testImage === smallerImage, "Resizing image did not produce expected smaller image.") | ||
val tgtHeight: Int = ImageSchema.getHeight(smallerImage) | ||
val tgtWidth: Int = ImageSchema.getWidth(smallerImage) | ||
val tgtChannels: Int = ImageSchema.getNChannels(smallerImage) | ||
|
||
val testImage = ImageUtils.resizeImage(tgtHeight, tgtWidth, tgtChannels, biggerImage) | ||
assert(testImage === smallerImage, "Resizing image did not produce expected smaller image.") | ||
} | ||
} | ||
|
||
test ("Test Row image -> BufferedImage -> Row image") { | ||
val height = 200 | ||
val width = 100 | ||
val channels = 3 | ||
|
||
val rand = new Random(971) | ||
val imageData = Array.ofDim[Byte](height * width * channels) | ||
rand.nextBytes(imageData) | ||
val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), imageData) | ||
val bufferedImage = ImageUtils.spImageToBufferedImage(spImage) | ||
val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage) | ||
assert(spImage === testImage, "Image changed during conversion.") | ||
for (channels <- Seq(3, 4)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note: This test will fail for channels = 1, see my comment in |
||
val rand = new Random(971) | ||
val imageData = Array.ofDim[Byte](height * width * channels) | ||
rand.nextBytes(imageData) | ||
val ocvType = s"CV_8UC$channels" | ||
val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes(ocvType), imageData) | ||
val bufferedImage = ImageUtils.spImageToBufferedImage(spImage) | ||
val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage) | ||
assert(spImage === testImage, s"Image changed during conversion") | ||
} | ||
} | ||
|
||
test("Simple BufferedImage from Row Image") { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Technically there can be images with 2 channels. We should throw UnsupportedOperationException or something like that and inform user that resize does not work for 2 channel images.
Actually, thinking about it, we should probably check that the open cv type is one of the supported ones - e.g. one of CV_8UC{1,3,4}, so how about something something like this :
val mode = ImageSchema.getMode(rowImage)
val image = mode match {
case imageSchema.ocvTypes("CV_8UC1") => new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY)
....
case _ => throw new UnsupportedOperationException("Can not resize images with mode = " + mode)
}