Skip to content

Commit

Permalink
ignore nullable in DeepImageFeaturizer.validateSchema (#143)
Browse files Browse the repository at this point in the history
In Spark SQL, nullability is a hint used during optimization and codegen to skip null checks; it is not intended as an enforcement mechanism, nor does it imply that null values actually exist. The nullability information might also get dropped as data moves through the pipeline.

This PR switches to DataType.equalsIgnoreNullability for the check.

* ignore nullable in DeepImageFeaturizer.validateSchema
  • Loading branch information
mengxr committed Jul 2, 2018
1 parent a44fcbb commit 973e9da
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 3 deletions.
Expand Up @@ -25,14 +25,13 @@ import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DataTypeShim, StructType}
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.tensorflow.framework.GraphDef
import org.tensorframes.impl.DebugRowOps
import org.tensorframes.{Shape, ShapeDescription}



class DeepImageFeaturizer(override val uid: String) extends Transformer with DefaultParamsWritable {

def this() = this(Identifiable.randomUID("deepImageFeaturizer"))
Expand Down Expand Up @@ -65,7 +64,7 @@ class DeepImageFeaturizer(override val uid: String) extends Transformer with Def
val fieldIndex = schema.fieldIndex(inputColumnName)
val colType = schema.fields(fieldIndex).dataType
require(
colType == ImageSchema.columnSchema,
DataTypeShim.equalsIgnoreNullability(colType, ImageSchema.columnSchema),
s"inputCol must be an image column with schema ImageSchema.columnSchema, got ${colType}"
)
}
Expand Down
26 changes: 26 additions & 0 deletions src/main/scala/org/apache/spark/sql/types/DataTypeShim.scala
@@ -0,0 +1,26 @@
/*
* Copyright 2017 Databricks, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.types

object DataTypeShim {

  /**
   * Checks whether two Spark SQL data types are the same once nullability is disregarded.
   *
   * This is a thin pass-through to `DataType.equalsIgnoreNullability`; nullability of
   * ArrayType, MapType and StructType is ignored by the comparison.
   *
   * NOTE(review): presumably this shim lives in the `org.apache.spark.sql.types` package
   * because the underlying Spark method is not accessible from outside it -- confirm.
   *
   * @param left  first data type to compare
   * @param right second data type to compare
   * @return the result of `DataType.equalsIgnoreNullability(left, right)`
   */
  def equalsIgnoreNullability(left: DataType, right: DataType): Boolean =
    DataType.equalsIgnoreNullability(left, right)
}
Expand Up @@ -128,4 +128,18 @@ class DeepImageFeaturizerSuite extends FunSuite with TestSparkContext with Defau
.setOutputCol("myOutput")
testDefaultReadWrite(featurizer)
}

// Schema validation is expected to succeed when every field of the image struct,
// and the image column itself, is marked nullable.
test("DeepImageFeaturizer accepts nullable") {
  // Start from the "image" column type of the suite's `data` fixture
  // (assumed to be a DataFrame with an image struct column -- defined outside this block).
  val imageStruct = data.schema("image").dataType.asInstanceOf[StructType]
  // Rebuild the struct with each field flagged as nullable.
  val nullableFields = imageStruct.fields.map(field => field.copy(nullable = true))
  val nullableImageSchema = StructType(nullableFields)
  // Wrap it in a single-column schema whose "image" column is nullable as well.
  val nullableSchema =
    StructType(Seq(StructField("image", nullableImageSchema, nullable = true)))

  val featurizer = new DeepImageFeaturizer()
    .setModelName("_test")
    .setInputCol("image")
    .setOutputCol("features")

  // The test passes as long as transformSchema does not throw.
  withClue("featurizer should accept nullable schemas") {
    featurizer.transformSchema(nullableSchema)
  }
}
}

0 comments on commit 973e9da

Please sign in to comment.