Update ImageUtils to support resizing one, three, or four channel images #94
Changes from 12 commits
9b104c6
811c086
882d9ee
7bfdfb2
ba5d332
e4fed25
c12efd8
887b306
e82e302
8c4a320
ae46bf2
65fe425
3c757ca
@@ -16,6 +16,7 @@

 package com.databricks.sparkdl

+import java.awt.color.ColorSpace
 import java.awt.image.BufferedImage
 import java.awt.{Color, Image}

@@ -24,9 +25,20 @@ import org.apache.spark.sql.Row

 private[sparkdl] object ImageUtils {

+  // Set of OpenCV modes supported by our image utilities
+  private val supportedModes: Set[String] = Set("CV_8UC1", "CV_8UC3", "CV_8UC4")

+  // Map from OpenCV mode to Java BufferedImage type
+  private val openCVModeToImageType: Map[Int, Int] = Map(
+    ImageSchema.ocvTypes("CV_8UC1") -> BufferedImage.TYPE_BYTE_GRAY,
+    ImageSchema.ocvTypes("CV_8UC3") -> BufferedImage.TYPE_3BYTE_BGR,
+    ImageSchema.ocvTypes("CV_8UC4") -> BufferedImage.TYPE_4BYTE_ABGR
+  )

   /**
-   * Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1 & 3
-   * channel images. If the image has 3 channels, we assume the channels are in BGR order.
+   * Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1, 3, & 4
+   * channel images. If the Row image has 3 or 4 channels, we assume the channels are in BGR(A)
+   * order.
    *
    * @param rowImage Image in spark.ml.image format.
    * @return Java BGR BufferedImage.
@@ -35,81 +47,89 @@ private[sparkdl] object ImageUtils {
     val height = ImageSchema.getHeight(rowImage)
     val width = ImageSchema.getWidth(rowImage)
     val channels = ImageSchema.getNChannels(rowImage)
+    val mode = ImageSchema.getMode(rowImage)
     val imageData = ImageSchema.getData(rowImage)
     require(
       imageData.length == height * width * channels,
       s"""| Only one byte per channel is currently supported, got ${imageData.length} bytes for
           | image of size ($height, $width, $channels).
        """.stripMargin
     )
-    val image = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR)

-    var offset, h = 0
-    var r, g, b: Byte = 0
-    while (h < height) {
-      var w = 0
-      while (w < width) {
-        channels match {
-          case 1 =>
-            b = imageData(offset)
-            g = b
-            r = b
-          case 3 =>
-            b = imageData(offset)
-            g = imageData(offset + 1)
-            r = imageData(offset + 2)
-          case _ =>
-            require(false, s"`Channels` must be 1 or 3, got $channels.")
-        }
+    val imageType = openCVModeToImageType.getOrElse(mode,
+      throw new UnsupportedOperationException("Cannot convert row image with " +
+        s"unsupported OpenCV mode = ${mode} to BufferedImage. Supported OpenCV modes: " +
+        s"${supportedModes.map(ImageSchema.ocvTypes(_)).mkString(", ")}"))

-        val color = new Color(r & 0xff, g & 0xff, b & 0xff)
-        image.setRGB(w, h, color.getRGB)
-        offset += channels
-        w += 1
+    val image = new BufferedImage(width, height, imageType)
+    var offset = 0
+    val raster = image.getRaster
+    // NOTE: This code assumes the raw image data in rowImage directly corresponds to the
Does this change the data compared to the previous way?

Good Q, it doesn't change the data layout compared to the previous way. The BufferedImage types supported by our utilities ([...]

Ah, I was wondering more about the content: whether what we're doing manually produces the exact same result as [...]

Ahh I see. Actually, the resize results may change when feeding grayscale images to DeepImageFeaturizer, but not for three-channel images (our DIF tests only run against three-channel images, so they don't catch the change). The difference for grayscale images is that running [...] as the original code did doesn't necessarily set the grayscale pixel at [...] That being said, I could see changes in model behavior being sufficiently undesirable that we wouldn't want to make this change. Another thing: I just realized this PR no longer validates that the input images to DIF have either 1 or 3 channels. I'm assuming we should add this validation back? The practical implication is that, as it stands, this PR allows users to feed 4-channel images to DIF without hitting an error.

I think direct raster access is the right way. ImageSchema.decode does the same, and using it for the BGR(A) images makes the code more consistent. Why do you want to restrict the number of channels on DIF?

@tomasatdatabricks mainly I just didn't want to inadvertently release support for four-channel inputs in DIF without adding unit tests :P. Although I suppose we already supported one-channel inputs without specifically testing DIF on one-channel images.

OK, the current logic looks okay to me. The most important thing is not to change the logic if possible, so there is no breaking change for users. I doubt this is a problem at this point. For the one- and four-channel input stuff: if we want any testing, it seems enough to have unit tests to make sure this resizing function returns the correct number of channels given [...]

SGTM. Should we add the unit test in this PR or in a follow-up? I'm happy to do either, it's a pretty small test. Thanks again for the reviews :)
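One way to make the "does this change the data" question checkable is a small round-trip test. The sketch below is not part of the PR: it mirrors the raster write and read loops from the diff and checks that arbitrary bytes come back unchanged for each supported BufferedImage type. `RasterRoundTrip`, `toBuffered`, `fromBuffered`, and `roundTrips` are hypothetical names. It only covers byte-for-byte round-tripping; it says nothing about how the old `setRGB` path or a later resize treats grayscale values.

```scala
import java.awt.image.BufferedImage

object RasterRoundTrip {
  // Mirrors the write loop in the diff: the c-th byte of each pixel goes to band `channels - c`.
  def toBuffered(data: Array[Byte], width: Int, height: Int, imageType: Int): BufferedImage = {
    val image = new BufferedImage(width, height, imageType)
    val channels = image.getColorModel.getNumComponents
    require(data.length == width * height * channels, "expected one byte per channel")
    val raster = image.getRaster
    var offset = 0
    for (h <- 0 until height; w <- 0 until width; c <- 1 to channels) {
      raster.setSample(w, h, channels - c, data(offset) & 0xff)
      offset += 1
    }
    image
  }

  // Mirrors the read loop in the diff, using the same band order.
  def fromBuffered(image: BufferedImage): Array[Byte] = {
    val channels = image.getColorModel.getNumComponents
    val raster = image.getRaster
    val out = new Array[Byte](image.getHeight * image.getWidth * channels)
    var offset = 0
    for (h <- 0 until image.getHeight; w <- 0 until image.getWidth; c <- 1 to channels) {
      out(offset) = raster.getSample(w, h, channels - c).toByte
      offset += 1
    }
    out
  }

  // True if arbitrary bytes come back unchanged for the given BufferedImage type.
  def roundTrips(imageType: Int): Boolean = {
    val (width, height) = (3, 2)
    val channels = new BufferedImage(1, 1, imageType).getColorModel.getNumComponents
    val data = Array.tabulate[Byte](width * height * channels)(i => (i * 37 + 5).toByte)
    fromBuffered(toBuffered(data, width, height, imageType)).sameElements(data)
  }
}
```

Calling `RasterRoundTrip.roundTrips(BufferedImage.TYPE_BYTE_GRAY)` (and likewise for `TYPE_3BYTE_BGR` and `TYPE_4BYTE_ABGR`) exercises the one-, three-, and four-channel cases.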
+    // raster of our output Java BufferedImage.
+    for (h <- 0 until height) {
+      for (w <- 0 until width) {
+        for (c <- 1 to channels) {
+          val cIdx = channels - c
+          raster.setSample(w, h, cIdx, imageData(offset) & 0xff)
+          offset += 1
+        }
       }
-      h += 1
     }
     image
   }

+  /** Returns the OCV type (int) of the passed-in image */
+  private def getOCVType(img: BufferedImage): Int = {
+    val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
+    val hasAlpha = img.getColorModel.hasAlpha
+    if (isGray) {
+      ImageSchema.ocvTypes("CV_8UC1")
+    } else if (hasAlpha) {
+      ImageSchema.ocvTypes("CV_8UC4")
+    } else {
+      ImageSchema.ocvTypes("CV_8UC3")
+    }
+  }

   /**
    * Takes a Java BufferedImage and returns a Row Image (spImage).
    *
    * @param image Java BufferedImage.
-   * @return Row image in spark.ml.image format with 3 channels in BGR order.
+   * @return Row image in spark.ml.image format with channels in BGR(A) order.
    */
   private[sparkdl] def spImageFromBufferedImage(image: BufferedImage, origin: String = null): Row = {
-    val channels = 3
+    val channels = image.getColorModel.getNumComponents
     val height = image.getHeight
     val width = image.getWidth

+    val isGray = image.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
+    val hasAlpha = image.getColorModel.hasAlpha
     val decoded = new Array[Byte](height * width * channels)
-    var offset, h = 0
-    while (h < height) {
-      var w = 0
-      while (w < width) {
-        val color = new Color(image.getRGB(w, h))
-        decoded(offset) = color.getBlue.toByte
-        decoded(offset + 1) = color.getGreen.toByte
-        decoded(offset + 2) = color.getRed.toByte
-        offset += channels
-        w += 1
+    // NOTE: This code assumes the raster of our Java BufferedImage directly corresponds to
+    // raw image data
+    val raster = image.getRaster
+    var offset = 0
+    for (h <- 0 until height) {
+      for (w <- 0 until width) {
+        for (c <- 1 to channels) {
+          val cIdx = channels - c
+          decoded(offset) = raster.getSample(w, h, cIdx).toByte
+          offset += 1
+        }
       }
-      h += 1
     }
-    Row(origin, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), decoded)
+    Row(origin, height, width, channels, getOCVType(image), decoded)

Nit: extra blank line
   }


   /**
-   * Resizes an image and returns it as an Array[Byte]. Only 1 and 3 channel inputs, where each
+   * Resizes an image and returns it as an Array[Byte]. Only 1, 3, and 4 channel inputs, where each
    * channel is a single Byte, are currently supported. Only BGR channel order is supported but
    * this might work for other channel orders.
    *
    * @param tgtHeight desired height of output image.
    * @param tgtWidth desired width of output image.
-   * @param tgtChannels number of channels of output image (must be 3), may be used later to
-   *                    support more channels.
+   * @param tgtChannels number of channels (must be 1, 3, or 4) in output image.
    * @param spImage image to resize.
    * @param scaleHint hint which algorhitm to use, see java.awt.Image#SCALE_SCALE_AREA_AVERAGING
    * @return resized image, if the input was BGR or 1 channel, the output will be BGR.
@@ -120,8 +140,6 @@ private[sparkdl] object ImageUtils {
       tgtChannels: Int,
       spImage: Row,
       scaleHint: Int = Image.SCALE_AREA_AVERAGING): Row = {
-    require(tgtChannels == 3, s"`tgtChannels` was set to $tgtChannels, must be 3.")

     val height = ImageSchema.getHeight(spImage)
     val width = ImageSchema.getWidth(spImage)
     val nChannels = ImageSchema.getNChannels(spImage)
@@ -130,7 +148,14 @@ private[sparkdl] object ImageUtils {
       spImage
     } else {
       val srcImg = spImageToBufferedImage(spImage)
-      val tgtImg = new BufferedImage(tgtWidth, tgtHeight, BufferedImage.TYPE_3BYTE_BGR)
+      val tgtImgType = tgtChannels match {
+        case 1 => BufferedImage.TYPE_BYTE_GRAY
+        case 3 => BufferedImage.TYPE_3BYTE_BGR
I would add [...]
+        case 4 => BufferedImage.TYPE_4BYTE_ABGR
+        case _ => throw new UnsupportedOperationException("Image resize: number of output " +
+          s"channels must be 1, 3, or 4, got ${tgtChannels}.")
+      }
+      val tgtImg = new BufferedImage(tgtWidth, tgtHeight, tgtImgType)
       // scaledImg is a java.awt.Image which supports drawing but not pixel lookup by index.
       val scaledImg = srcImg.getScaledInstance(tgtWidth, tgtHeight, scaleHint)
       // Draw scaledImage onto resized (usually smaller) tgtImg so extract individual pixel values.
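The channel-count test suggested in review could look roughly like the sketch below. It is a standalone approximation, not the PR's code: `ResizeChannels`, `targetType`, and `resize` are hypothetical names, and it scales with `Graphics2D.drawImage` directly rather than `getScaledInstance`, so that it depends only on `BufferedImage`. It checks the output shape and channel count, not pixel values.

```scala
import java.awt.image.BufferedImage

object ResizeChannels {
  // Same channel-count to BufferedImage-type mapping as in the diff.
  def targetType(tgtChannels: Int): Int = tgtChannels match {
    case 1 => BufferedImage.TYPE_BYTE_GRAY
    case 3 => BufferedImage.TYPE_3BYTE_BGR
    case 4 => BufferedImage.TYPE_4BYTE_ABGR
    case _ => throw new UnsupportedOperationException(
      s"Image resize: number of output channels must be 1, 3, or 4, got $tgtChannels.")
  }

  // Hypothetical stand-in for the resize step: draws `src` scaled onto a target image whose
  // type is chosen from the requested channel count. (Assumption: drawImage with explicit
  // width/height replaces the PR's getScaledInstance call.)
  def resize(src: BufferedImage, tgtWidth: Int, tgtHeight: Int, tgtChannels: Int): BufferedImage = {
    val tgt = new BufferedImage(tgtWidth, tgtHeight, targetType(tgtChannels))
    val g = tgt.createGraphics()
    try {
      g.drawImage(src, 0, 0, tgtWidth, tgtHeight, null)
      ()
    } finally {
      g.dispose()
    }
    tgt
  }
}
```

A test would then loop over 1, 3, and 4 and compare `getColorModel.getNumComponents` on the result with the requested count, plus one case asserting that an unsupported count throws.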
@tomasatdatabricks As of right now, this error can only occur if users try to pass an image row with an unsupported OpenCV mode to DeepImageFeaturizer. I wonder if the current error message is vague/unhelpful to the user in this case... maybe the image validation should be pushed into DeepImageFeaturizer (but it also seems useful to validate the row image here)?

Yes, I think it should definitely be here as well. I believe you should always try to throw an informative exception. DeepImageFeaturizer is just one of the possible users of this function. Also, it does not have to have the same restrictions.
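As a rough illustration of what a more informative check might look like, here is a minimal sketch that reports supported modes by name as well as by integer code. It is not the PR's implementation: `ModeValidation` and `requireSupportedMode` are hypothetical, and the local `ocvTypes` map is a stand-in for Spark's `ImageSchema.ocvTypes`, restricted to the three modes in the diff and using OpenCV's type codes.

```scala
object ModeValidation {
  // Stand-in for ImageSchema.ocvTypes, limited to the modes supported in the diff.
  val ocvTypes: Map[String, Int] = Map("CV_8UC1" -> 0, "CV_8UC3" -> 16, "CV_8UC4" -> 24)

  // Hypothetical helper: returns the mode's name, or throws with a message that lists every
  // supported mode by name and code.
  def requireSupportedMode(mode: Int): String =
    ocvTypes.collectFirst { case (name, code) if code == mode => name }.getOrElse {
      val supported = ocvTypes.toSeq.sortBy(_._2)
        .map { case (name, code) => s"$name ($code)" }
        .mkString(", ")
      throw new UnsupportedOperationException(
        s"Cannot convert row image with unsupported OpenCV mode = $mode to BufferedImage. " +
          s"Supported OpenCV modes: $supported")
    }
}
```

Keeping a check like this in the utility, as suggested above, means every caller gets the same message; DeepImageFeaturizer could still layer its own, stricter validation on top.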