Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update ImageUtils to support resizing one, three, or four channel images #94

Merged
merged 13 commits into from Jan 30, 2018
Binary file added python/tests/resources/images/1_channels/big.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added python/tests/resources/images/3_channels/big.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added python/tests/resources/images/4_channels/big.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Expand Up @@ -115,8 +115,9 @@ class DeepImageFeaturizer(override val uid: String) extends Transformer with Def
val height = model.height
val width = model.width

val resizeUdf = udf((image: Row) => ImageUtils.resizeImage(height, width, 3, image,
DeepImageFeaturizer.scaleHints(getScaleHint)), imSchema)
val resizeUdf = udf((image: Row) => { ImageUtils.resizeImage(height, width, 3, image,
DeepImageFeaturizer.scaleHints(getScaleHint))
}, imSchema)

val imageDF = dataFrame
.withColumn(RESIZED_IMAGE_COL, resizeUdf(col(getInputCol)))
Expand Down
93 changes: 73 additions & 20 deletions src/main/scala/com/databricks/sparkdl/ImageUtils.scala
Expand Up @@ -16,6 +16,7 @@

package com.databricks.sparkdl

import java.awt.color.ColorSpace
import java.awt.image.BufferedImage
import java.awt.{Color, Image}

Expand All @@ -25,8 +26,8 @@ import org.apache.spark.sql.Row
private[sparkdl] object ImageUtils {

/**
* Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1 & 3
* channel images. If the image has 3 channels, we assume the channels are in BGR order.
* Takes a Row image (spImage) and returns a Java BufferedImage. Currently supports 1, 3, & 4
* channel images. If the image has 3 or 4 channels, we assume the channels are in BGR(A) order.
*
* @param rowImage Image in spark.ml.image format.
* @return Java BGR BufferedImage.
Expand All @@ -42,10 +43,15 @@ private[sparkdl] object ImageUtils {
| image of size ($height, $width, $channels).
""".stripMargin
)
val image = new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR)

val image = channels match {
case 1 => new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Technically there can be images with 2 channels. We should throw UnsupportedOperationException or something like that and inform user that resize does not work for 2 channel images.

Actually, thinking about it, we should probably check that the open cv type is one of the supported ones - e.g. one of CV_8UC{1,3,4}, so how about something something like this :

val mode = ImageSchema.getMode(rowImage)

val image = mode match {
case imageSchema.ocvTypes("CV_8UC1") => new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY)
....
case _ => throw new UnsupportedOperationException("Can not resize images with mode = " + mode)
}

case 3 => new BufferedImage(width, height, BufferedImage.TYPE_3BYTE_BGR)
case 4 => new BufferedImage(width, height, BufferedImage.TYPE_4BYTE_ABGR)
}

var offset, h = 0
var r, g, b: Byte = 0
var r, g, b, a: Byte = 0
while (h < height) {
var w = 0
while (w < width) {
Expand All @@ -58,11 +64,20 @@ private[sparkdl] object ImageUtils {
b = imageData(offset)
g = imageData(offset + 1)
r = imageData(offset + 2)
case 4 =>
b = imageData(offset)
g = imageData(offset + 1)
r = imageData(offset + 2)
a = imageData(offset + 3)
case _ =>
require(false, s"`Channels` must be 1 or 3, got $channels.")
require(false, s"`Channels` must be 1, 3, or 4, got $channels.")
}

val color = new Color(r & 0xff, g & 0xff, b & 0xff)
val color = if (channels < 4) {
new Color(r & 0xff, g & 0xff, b & 0xff)
} else {
new Color(r & 0xff, g & 0xff, b & 0xff, a & 0xff)
}
image.setRGB(w, h, color.getRGB)
offset += channels
w += 1
Expand All @@ -72,15 +87,40 @@ private[sparkdl] object ImageUtils {
image
}

/** Returns the number of channels in the passed-in buffered image. */
Copy link
Collaborator Author

@smurching smurching Jan 24, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This method and getOCVType invert the mapping done in ImageSchema.scala

private def getNumChannels(img: BufferedImage): Int = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not call image.getColorModel.getNumComponents?

val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
val hasAlpha = img.getColorModel.hasAlpha
if (isGray) {
1
} else if (hasAlpha) {
4
} else {
3
}
}

/** Returns the OCV type (int) of the passed-in image */
private def getOCVType(img: BufferedImage): Int = {
val isGray = img.getColorModel.getColorSpace.getType == ColorSpace.TYPE_GRAY
val hasAlpha = img.getColorModel.hasAlpha
if (isGray) {
ImageSchema.ocvTypes("CV_8UC1")
} else if (hasAlpha) {
ImageSchema.ocvTypes("CV_8UC4")
} else {
ImageSchema.ocvTypes("CV_8UC3")
}
}

/**
* Takes a Java BufferedImage and returns a Row Image (spImage).
*
* @param image Java BufferedImage.
* @return Row image in spark.ml.image format with 3 channels in BGR order.
* @return Row image in spark.ml.image format with channels in BGR(A) order.
*/
private[sparkdl] def spImageFromBufferedImage(image: BufferedImage, origin: String = null): Row = {
val channels = 3
val channels = getNumChannels(image)
val height = image.getHeight
val width = image.getWidth

Expand All @@ -89,27 +129,37 @@ private[sparkdl] object ImageUtils {
while (h < height) {
var w = 0
while (w < width) {
val color = new Color(image.getRGB(w, h))
decoded(offset) = color.getBlue.toByte
decoded(offset + 1) = color.getGreen.toByte
decoded(offset + 2) = color.getRed.toByte
val color = new Color(image.getRGB(w, h), image.getColorModel.hasAlpha)
channels match {
case 1 =>
decoded(offset) = color.getBlue.toByte
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: Unfortunately this doesn't always yield the original blue byte written in spImageToBufferedImage. In general the setRGB and getRGB methods don't yield consistent results when working with grayscale BufferedImages (BufferedImages with image type TYPE_BYTE_GRAY), see this code snippet:

https://gist.github.com/smurching/d6e918a84c79155b616f9339b30fdfca

It makes sense to me that this happens, since it doesn't really make sense to tell a grayscale image "set the pixel at 0, 0 to RGB (x, y, z)" and expect the RGB channel values to actually be (x, y, z) - if this were the case, the image might not be gray. I'm not sure how to work around it...

Copy link
Collaborator Author

@smurching smurching Jan 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually I think this can be fixed by stealing some code from ImageSchema that directly manipulates the Image raster, will push an update...

case 3 =>
decoded(offset) = color.getBlue.toByte
decoded(offset + 1) = color.getGreen.toByte
decoded(offset + 2) = color.getRed.toByte
case 4 =>
decoded(offset) = color.getBlue.toByte
decoded(offset + 1) = color.getGreen.toByte
decoded(offset + 2) = color.getRed.toByte
decoded(offset + 3) = color.getAlpha.toByte
}
offset += channels
w += 1
}
h += 1
}
Row(origin, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), decoded)
Row(origin, height, width, channels, getOCVType(image), decoded)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: extra blank line

}

/**
* Resizes an image and returns it as an Array[Byte]. Only 1 and 3 channel inputs, where each
* Resizes an image and returns it as an Array[Byte]. Only 1, 3, and 4 channel inputs, where each
* channel is a single Byte, are currently supported. Only BGR channel order is supported but
* this might work for other channel orders.
*
* @param tgtHeight desired height of output image.
* @param tgtWidth desired width of output image.
* @param tgtChannels number of channels of output image (must be 3), may be used later to
* support more channels.
* @param tgtChannels number of channels in output image.
* @param spImage image to resize.
* @param scaleHint hint which algorhitm to use, see java.awt.Image#SCALE_SCALE_AREA_AVERAGING
* @return resized image, if the input was BGR or 1 channel, the output will be BGR.
Expand All @@ -120,8 +170,6 @@ private[sparkdl] object ImageUtils {
tgtChannels: Int,
spImage: Row,
scaleHint: Int = Image.SCALE_AREA_AVERAGING): Row = {
require(tgtChannels == 3, s"`tgtChannels` was set to $tgtChannels, must be 3.")

val height = ImageSchema.getHeight(spImage)
val width = ImageSchema.getWidth(spImage)
val nChannels = ImageSchema.getNChannels(spImage)
Expand All @@ -130,14 +178,19 @@ private[sparkdl] object ImageUtils {
spImage
} else {
val srcImg = spImageToBufferedImage(spImage)
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, BufferedImage.TYPE_3BYTE_BGR)
val tgtImgType = tgtChannels match {
case 1 => BufferedImage.TYPE_BYTE_GRAY
case 3 => BufferedImage.TYPE_3BYTE_BGR
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would add
case _ => throw meaningful exception.

case 4 => BufferedImage.TYPE_4BYTE_ABGR
}
val tgtImg = new BufferedImage(tgtWidth, tgtHeight, tgtImgType)
// scaledImg is a java.awt.Image which supports drawing but not pixel lookup by index.
val scaledImg = srcImg.getScaledInstance(tgtWidth, tgtHeight, scaleHint)
// Draw scaledImage onto resized (usually smaller) tgtImg so extract individual pixel values.
val graphic = tgtImg.createGraphics()
graphic.drawImage(scaledImg, 0, 0, null)
graphic.dispose()
spImageFromBufferedImage(tgtImg, origin=ImageSchema.getOrigin(spImage))
spImageFromBufferedImage(tgtImg, origin = ImageSchema.getOrigin(spImage))
}
}
}
59 changes: 31 additions & 28 deletions src/test/scala/com/databricks/sparkdl/ImageUtilsSuite.scala
Expand Up @@ -22,54 +22,57 @@ import javax.imageio.ImageIO

import scala.util.Random

import org.scalatest.FunSuite

import org.apache.spark.ml.image.ImageSchema
import org.apache.spark.sql.Row

import org.scalatest.FunSuite

object ImageUtilsSuite {
val biggerImage: Row = {
val biggerFile = getClass.getResource("/sparkdl/test-image-collection/00081101.jpg").getFile
val imageBuffer = ImageIO.read(new File(biggerFile))
ImageUtils.spImageFromBufferedImage(imageBuffer)
}

val smallerImage: Row = {
val smallerFile = getClass.getResource("/sparkdl/00081101-small-version.png").getFile
val imageBuffer = ImageIO.read(new File(smallerFile))
/** Read image data into a BufferedImage, then use our utility method to convert to a row image */
def getImageRow(resourcePath: String): Row = {
val resourceUrl = getClass.getResource(resourcePath).getFile
val imageBuffer = ImageIO.read(new File(resourceUrl))
ImageUtils.spImageFromBufferedImage(imageBuffer)
}


def smallerImage: Row = getImageRow("/sparkdl/00081101-small-version.png")
def biggerImage: Row = getImageRow("/sparkdl/test-image-collection/00081101.jpg")
}

class ImageUtilsSuite extends FunSuite {
// We want to make sure to test ImageUtils in headless mode to ensure it'll work on all systems.
assert(System.getProperty("java.awt.headless") === "true")

import ImageUtilsSuite._

test("Test spImage resize.") {
val tgtHeight: Int = ImageSchema.getHeight(smallerImage)
val tgtWidth: Int = ImageSchema.getWidth(smallerImage)
val tgtChannels: Int = ImageSchema.getNChannels(smallerImage)
def getImagePath(imageSize: String, numChannels: Int): String = {
s"/sparkdl/test-image-collection/${numChannels}_channels/$imageSize.png"
}
for (channels <- Seq(1, 3, 4)) {
val smallerImage = ImageUtilsSuite.getImageRow(getImagePath("small", channels))
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I cheated here & had the test compare the original image to a smaller image generated via the resizeImage code path (smallerImage is an image file generated by saving the result of calling resizeImage on biggerImage). @MrBago do you remember how you generated the original image pair for this test?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mean you compare against image resized using BufferedImage or ImageUtils.resizeImage?

Just a suggestion, why don't we store only the original image, and compare against the image resized in java using BufferedImage on the fly?

val biggerImage = ImageUtilsSuite.getImageRow(getImagePath("big", channels))

val testImage = ImageUtils.resizeImage(tgtHeight, tgtWidth, tgtChannels, biggerImage)
assert(testImage === smallerImage, "Resizing image did not produce expected smaller image.")
val tgtHeight: Int = ImageSchema.getHeight(smallerImage)
val tgtWidth: Int = ImageSchema.getWidth(smallerImage)
val tgtChannels: Int = ImageSchema.getNChannels(smallerImage)

val testImage = ImageUtils.resizeImage(tgtHeight, tgtWidth, tgtChannels, biggerImage)
assert(testImage === smallerImage, "Resizing image did not produce expected smaller image.")
}
}

test ("Test Row image -> BufferedImage -> Row image") {
val height = 200
val width = 100
val channels = 3

val rand = new Random(971)
val imageData = Array.ofDim[Byte](height * width * channels)
rand.nextBytes(imageData)
val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes("CV_8UC3"), imageData)
val bufferedImage = ImageUtils.spImageToBufferedImage(spImage)
val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage)
assert(spImage === testImage, "Image changed during conversion.")
for (channels <- Seq(3, 4)) {
Copy link
Collaborator Author

@smurching smurching Jan 24, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: This test will fail for channels = 1, see my comment in spImageFromBufferedImage

val rand = new Random(971)
val imageData = Array.ofDim[Byte](height * width * channels)
rand.nextBytes(imageData)
val ocvType = s"CV_8UC$channels"
val spImage = Row(null, height, width, channels, ImageSchema.ocvTypes(ocvType), imageData)
val bufferedImage = ImageUtils.spImageToBufferedImage(spImage)
val testImage = ImageUtils.spImageFromBufferedImage(bufferedImage)
assert(spImage === testImage, s"Image changed during conversion")
}
}

test("Simple BufferedImage from Row Image") {
Expand Down