Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ImageUtils to support resizing one, three, or four channel images #94

Merged
merged 13 commits into from Jan 30, 2018
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
3 changes: 2 additions & 1 deletion python/tests/transformers/image_utils.py
Expand Up @@ -41,7 +41,8 @@ def _getSampleJPEGDir():
return os.path.join(cur_dir, "../resources/images")

def getImageFiles():
return glob(os.path.join(_getSampleJPEGDir(), "*"))
return [path for path in glob(os.path.join(_getSampleJPEGDir(), "*"))
if not os.path.isdir(path)]

def getSampleImageDF():
return imageIO.readImagesWithCustomFn(path=_getSampleJPEGDir(), decode_f=imageIO.PIL_decode)
Expand Down
117 changes: 71 additions & 46 deletions src/main/scala/com/databricks/sparkdl/ImageUtils.scala
Expand Up @@ -16,6 +16,7 @@

package com.databricks.sparkdl

import java.awt.color.ColorSpace
import java.awt.image.BufferedImage
import java.awt.{Color, Image}

Expand All @@ -24,9 +25,20 @@ import org.apache.spark.sql.Row

private[sparkdl] object ImageUtils {

// Set of OpenCV modes supported by our image utilities
private val supportedModes: Set[String] = Set("CV_8UC1", "CV_8UC3", "CV_8UC4")

// Map from OpenCV mode to Java BufferedImage type
private val openCVModeToImageType: Map[Int, Int] = Map(
ImageSchema.ocvTypes("CV_8UC1") -> BufferedImage.TYPE_BYTE_GRAY,
ImageSchema.ocvTypes("CV_8UC3") -> BufferedImage.TYPE_3BYTE_BGR,
ImageSchema.ocvTypes("CV_8UC4") -> BufferedImage.TYPE_4BYTE_ABGR
)

/**
* Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1 & 3
* channel images. If the image has 3 channels, we assume the channels are in BGR order.
* Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1, 3, & 4
* channel images. If the Row image has 3 or 4 channels, we assume the channels are in BGR(A)
* order.
*
* @param rowImage Image in spark.ml.image format.
* @return Java BGR BufferedImage.
Expand All @@ -35,81 +47,89 @@ private[sparkdl] object ImageUtils {
val height = ImageSchema.getHeight(rowImage)
val width = ImageSchema.getWidth(rowImage)
val channels = ImageSchema.getNChannels(rowImage)
val mode = ImageSchema.getMode(rowImage)
val imageData = ImageSchema.getData(rowImage)
require(
imageData.length == height * width * channels,
s"""| Only one byte per channel is currently supported, got ${imageData.length} bytes for
| image of size ($height, $width, $channels).
""".stripMargin
)
val image = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR)

var offset, h = 0
var r, g, b: Byte = 0
while (h < height) {
var w = 0
while (w < width) {
channels match {
case 1 =>
b = imageData(offset)
g = b
r = b
case 3 =>
b = imageData(offset)
g = imageData(offset + 1)
r = imageData(offset + 2)
case _ =>
require(false, s"`Channels` must be 1 or 3, got $channels.")
}
val imageType = openCVModeToImageType.getOrElse(mode,
throw new UnsupportedOperationException("Cannot convert row image with " +
Copy link
Collaborator Author

@smurching smurching Jan 26, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomasatdatabricks As of right now, this error can only occur if users try to pass an image row with an unsupported OpenCV mode to DeepImageFeaturizer. I wonder if the current error message is vague/unhelpful to the user in this case...maybe the image validation should be pushed into DeepImageFeaturizer (but it also seems useful to validate the row image here)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I think it should definitely be here as well. I believe you should always try to throw informative exception. Deep ImageFeaturizer is just one of the possible users of this function. Also, it does not have to have the same restrictions.

s"unsupported OpenCV mode = ${mode} to BufferedImage. Supported OpenCV modes: " +
s"${supportedModes.map(ImageSchema.ocvTypes(_)).mkString(", ")}"))

val color = new Color(r & 0xff, g & 0xff, b & 0xff)
image.setRGB(w, h, color.getRGB)
offset += channels
w += 1
val image = new BufferedImage(width, height, imageType)
var offset = 0
val raster = image.getRaster
// NOTE: This code assumes the raw image data in rowImage directly corresponds to the
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this change the data compared to the previous way?

Copy link
Collaborator Author

@smurching smurching Jan 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good Q, it doesn't change the data layout compared to the previous way. The BufferedImage types supported by our utilities (TYPE_BYTE_GRAY, TYPE_3BYTE_BGR, TYPE_4BYTE_AGBR) have a fixed BGR(A) channel ordering so this approach should also be consistent across Java versions (i.e. the internal raster representation of the BufferedImage data won't suddenly change on us, see BufferedImage docs).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I was wondering more about the content, whether what we're doing manually produced the exact same result as Color & Color.getRGB. The tests look like they maintain the same resizing results as before the code change here but wanted to confirm, since doing anything different to the image can affect the results of applying models to it. (Part of it is, I'm not sure how reasonable this assumption stated here is.)

Copy link
Collaborator Author

@smurching smurching Jan 29, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ahh I see. Actually the resize results may change when feeding grayscale images to DeepImageFeaturizer, but not for three-channel images (our DIF tests only run against three-channel images so they don't catch the change).

The difference for grayscale images is that running

val b = rowImageData(offset)
val color = new Color(b, b, b)
image.setRGB(x, y, color.getRGB)

as the original code did doesn't necessarily set the grayscale pixel at (x, y) to b. This happens because Java takes a weighted combination of the (r, g, b) channels when setting the grayscale pixel value. However with the raster approach the pixel at (x, y) is actually directly set to b. Given that b was likely obtained by calling raster.getSample (how ImageSchema.readImages reads the data of a grayscale image), the raster approach seems ok to me.

That being said I could see changes in model behavior being sufficiently undesirable that we wouldn't want to make this change.

Another thing: just realized this PR no longer validates that the input images to DIF are either 1 or 3 channels, I'm assuming we should add this validation back? The practical implication being that as it stands, this PR allows users to feed 4-channel images to DIF without hitting an error.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the direct raster access is the the right way, ImageSchema.decode does the same and using it for the BGR(A) images makes the code more consistent.

Why do you want to restrict the number of channels on DIF?

Copy link
Collaborator Author

@smurching smurching Jan 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomasatdatabricks mainly I just didn't want to inadvertently release support for four-channel inputs in DIF without adding unit tests :P. Although I suppose we already supported one-channel inputs without specifically testing DIF on one-channel images.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok the current logic looks okay to me. the most important thing is not to change the logic if possible so there is no breaking change for users. i doubt this is a problem at this point.

for the one- and four- channel input stuff -- if we want any testing, it seems enough to have unit tests to make sure this resizing function returns the correct number of channels given targetChannels. Just checking the number of channels seems good enough - we could also check the content of the image if we want to catch any breaking changes in the future.

Copy link
Collaborator Author

@smurching smurching Jan 30, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SGTM. Should we add the unit test in this PR or in a follow-up? I'm happy to do either, it's a pretty small test. Thanks again for the reviews :)

// raster of our output Java BufferedImage.
for (h <- 0 until height) {
for (w <- 0 until width) {
for (c <- 1 to channels) {
val cIdx = channels - c
raster.setSample(w, h, cIdx, imageData(offset) & 0xff)
offset += 1
}
}
h += 1
}
image
}

/** Returns the OCV type (int) of the passed-in image */
private def getOCVType(img: BufferedImage): Int = {
val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
val hasAlpha = img.getColorModel.hasAlpha
if (isGray) {
ImageSchema.ocvTypes("CV_8UC1")
} else if (hasAlpha) {
ImageSchema.ocvTypes("CV_8UC4")
} else {
ImageSchema.ocvTypes("CV_8UC3")
}
}

/**
* Takes a Java BufferedImage and returns a Row Image (spImage).
*
* @param image Java BufferedImage.
* @return Row image in spark.ml.image format with 3 channels in BGR order.
* @return Row image in spark.ml.image format with channels in BGR(A) order.
*/
private[sparkdl] def spImageFromBufferedImage(image: BufferedImage, origin: String = null): Row = {
val channels = 3
val channels = image.getColorModel.getNumComponents
val height = image.getHeight
val width = image.getWidth

val isGray = image.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
val hasAlpha = image.getColorModel.hasAlpha
val decoded = new Array[Byte](height * width * channels)
var offset, h = 0
while (h < height) {
var w = 0
while (w < width) {
val color = new Color(image.getRGB(w, h))
decoded(offset) = color.getBlue.toByte
decoded(offset + 1) = color.getGreen.toByte
decoded(offset + 2) = color.getRed.toByte
offset += channels
w += 1
// NOTE: This code assumes the raster of our Java BufferedImage directly corresponds to
// raw image data
val raster = image.getRaster
var offset = 0
for (h <- 0 until height) {
for (w <- 0 until width) {
for (c <- 1 to channels) {
val cIdx = channels - c
decoded(offset) = raster.getSample(w, h, cIdx).toByte
offset += 1
}
}
h += 1
}
Row(origin, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), decoded)
Row(origin, height, width, channels, getOCVType(image), decoded)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: extra blank line

}


/**
* Resizes an image and returns it as an Array[Byte]. Only 1 and 3 channel inputs, where each
* Resizes an image and returns it as an Array[Byte]. Only 1, 3, and 4 channel inputs, where each
* channel is a single Byte, are currently supported. Only BGR channel order is supported but
* this might work for other channel orders.
*
* @param tgtHeight desired height of output image.
* @param tgtWidth desired width of output image.
* @param tgtChannels number of channels of output image (must be 3), may be used later to
* support more channels.
* @param tgtChannels number of channels (must be 1, 3, or 4) in output image.
* @param spImage image to resize.
* @param scaleHint hint which algorhitm to use, see java.awt.Image#SCALE_SCALE_AREA_AVERAGING
* @return resized image, if the input was BGR or 1 channel, the output will be BGR.
Expand All @@ -120,8 +140,6 @@ private[sparkdl] object ImageUtils {
tgtChannels: Int,
spImage: Row,
scaleHint: Int = Image.SCALE_AREA_AVERAGING): Row = {
require(tgtChannels == 3, s"`tgtChannels` was set to $tgtChannels, must be 3.")

val height = ImageSchema.getHeight(spImage)
val width = ImageSchema.getWidth(spImage)
val nChannels = ImageSchema.getNChannels(spImage)
Expand All @@ -130,7 +148,14 @@ private[sparkdl] object ImageUtils {
spImage
} else {
val srcImg = spImageToBufferedImage(spImage)
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, BufferedImage.TYPE_3BYTE_BGR)
val tgtImgType = tgtChannels match {
case 1 => BufferedImage.TYPE_BYTE_GRAY
case 3 => BufferedImage.TYPE_3BYTE_BGR
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add
case _ => throw meaningful exception.

case 4 => BufferedImage.TYPE_4BYTE_ABGR
case _ => throw new UnsupportedOperationException("Image resize: number of output " +
s"channels must be 1, 3, or 4, got ${tgtChannels}.")
}
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, tgtImgType)
// scaledImg is a java.awt.Image which supports drawing but not pixel lookup by index.
val scaledImg = srcImg.getScaledInstance(tgtWidth, tgtHeight, scaleHint)
// Draw scaledImage onto resized (usually smaller) tgtImg so extract individual pixel values.
Expand Down
71 changes: 45 additions & 26 deletions src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala
Expand Up @@ -16,31 +16,29 @@

package com.databricks.sparkdl

import java.awt.Color
import java.awt.{Color, Image}
import java.awt.image.BufferedImage
import java.io.File
import javax.imageio.ImageIO

import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.Row

import org.scalatest.FunSuite

object ImageUtilsSuite {
val biggerImage: Row = {
val biggerFile = getClass.getResource("/sparkdl/test-image-collection/00081101.jpg").getFile
val imageBuffer = ImageIO.read(new File(biggerFile))
ImageUtils.spImageFromBufferedImage(imageBuffer)
}

val smallerImage: Row = {
val smallerFile = getClass.getResource("/sparkdl/00081101-small-version.png").getFile
val imageBuffer = ImageIO.read(new File(smallerFile))
/** Read image data into a BufferedImage, then use our utility method to convert to a row image */
def getImageRow(resourcePath: String): Row = {
val resourceFilename = getClass.getResource(resourcePath).getFile
val imageBuffer = ImageIO.read(new File(resourceFilename))
ImageUtils.spImageFromBufferedImage(imageBuffer)
}


def smallerImage: Row = getImageRow("/sparkdl/00081101-small-version.png")
def biggerImage: Row = getImageRow("/sparkdl/test-image-collection/00081101.jpg")
}

class ImageUtilsSuite extends FunSuite {
Expand All @@ -50,26 +48,47 @@ class ImageUtilsSuite extends FunSuite {
import ImageUtilsSuite._

test("Test spImage resize.") {
val tgtHeight: Int = ImageSchema.getHeight(smallerImage)
val tgtWidth: Int = ImageSchema.getWidth(smallerImage)
val tgtChannels: Int = ImageSchema.getNChannels(smallerImage)
def javaResize(imagePath: String, tgtWidth: Int, tgtHeight: Int): Row = {
// Read BufferedImage directly from file
val resourceFilename = getClass.getResource(imagePath).getFile
val srcImg = ImageIO.read(new File(resourceFilename))
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, srcImg.getType)
// scaledImg is a java.awt.Image which supports drawing but not pixel lookup by index.
val scaledImg = srcImg.getScaledInstance(tgtWidth, tgtHeight, Image.SCALE_AREA_AVERAGING)
// Draw scaledImage onto resized (usually smaller) tgtImg so extract individual pixel values.
val graphic = tgtImg.createGraphics()
graphic.drawImage(scaledImg, 0, 0, null)
graphic.dispose()
ImageUtils.spImageFromBufferedImage(tgtImg)
}

val testImage = ImageUtils.resizeImage(tgtHeight, tgtWidth, tgtChannels, biggerImage)
assert(testImage === smallerImage, "Resizing image did not produce expected smaller image.")
for (channels <- Seq(1, 3, 4)) {
val path = s"/sparkdl/test-image-collection/${channels}_channels/00074201.png"
val biggerImage = getImageRow(path)
val tgtHeight: Int = ImageSchema.getHeight(biggerImage) / 2
val tgtWidth: Int = ImageSchema.getWidth(biggerImage) / 2
val tgtChannels: Int = ImageSchema.getNChannels(biggerImage)

val expectedImage = javaResize(path, tgtWidth, tgtHeight)
val resizedImage = ImageUtils.resizeImage(tgtHeight, tgtWidth, tgtChannels, biggerImage)
assert(resizedImage === expectedImage, "Resizing image did not produce expected smaller " +
"image.")
}
}

test ("Test Row image -> BufferedImage -> Row image") {
val height = 200
val width = 100
val channels = 3

val rand = new Random(971)
val imageData = Array.ofDim[Byte](height * width * channels)
rand.nextBytes(imageData)
val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), imageData)
val bufferedImage = ImageUtils.spImageToBufferedImage(spImage)
val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage)
assert(spImage === testImage, "Image changed during conversion.")
for (channels <- Seq(1, 3, 4)) {
val rand = new Random(971)
val imageData = Array.ofDim[Byte](height * width * channels)
rand.nextBytes(imageData)
val ocvType = s"CV_8UC$channels"
val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes(ocvType), imageData)
val bufferedImage = ImageUtils.spImageToBufferedImage(spImage)
val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage)
assert(spImage === testImage, "Image changed during conversion")
}
}

test("Simple BufferedImage from Row Image") {
Expand Down