diff --git a/spark/core/src/main/scala/org/elasticsearch/spark/serialization/ScalaValueReader.scala b/spark/core/src/main/scala/org/elasticsearch/spark/serialization/ScalaValueReader.scala
index 174250c56..7995959c1 100644
--- a/spark/core/src/main/scala/org/elasticsearch/spark/serialization/ScalaValueReader.scala
+++ b/spark/core/src/main/scala/org/elasticsearch/spark/serialization/ScalaValueReader.scala
@@ -126,7 +126,7 @@ class ScalaValueReader extends AbstractValueReader with SettingsAware {
     }
   }
 
-  def nullValue() = { None }
+  def nullValue() = { null }
 
   def textValue(value: String, parser: Parser) = { checkNull (parseText, value, parser) }
 
   protected def parseText(value:String, parser: Parser) = { value }
diff --git a/spark/core/src/test/scala/org/elasticsearch/spark/ScalaExtendedBooleanValueReaderTest.scala b/spark/core/src/test/scala/org/elasticsearch/spark/ScalaExtendedBooleanValueReaderTest.scala
index 7040a4b74..08c836bab 100644
--- a/spark/core/src/test/scala/org/elasticsearch/spark/ScalaExtendedBooleanValueReaderTest.scala
+++ b/spark/core/src/test/scala/org/elasticsearch/spark/ScalaExtendedBooleanValueReaderTest.scala
@@ -49,7 +49,7 @@ class ScalaExtendedBooleanValueReaderTest(jsonString: String, expected: Expected
 
   def isNull: Matcher[AnyRef] = {
     return new BaseMatcher[AnyRef] {
-      override def matches(item: scala.Any): Boolean = item == None
+      override def matches(item: scala.Any): Boolean = item == null
       override def describeTo(description: Description): Unit = description.appendText("null")
     }
   }
diff --git a/spark/core/src/test/scala/org/elasticsearch/spark/ScalaValueReaderTest.scala b/spark/core/src/test/scala/org/elasticsearch/spark/ScalaValueReaderTest.scala
index 6c7ea2286..505c0cc0d 100644
--- a/spark/core/src/test/scala/org/elasticsearch/spark/ScalaValueReaderTest.scala
+++ b/spark/core/src/test/scala/org/elasticsearch/spark/ScalaValueReaderTest.scala
@@ -26,8 +26,8 @@ class ScalaValueReaderTest extends BaseValueReaderTest {
 
   override def createValueReader() = new ScalaValueReader()
 
-  override def checkNull(result: Object): Unit = { assertEquals(None, result)}
-  override def checkEmptyString(result: Object): Unit = { assertEquals(None, result)}
+  override def checkNull(result: Object): Unit = { assertEquals(null, result)}
+  override def checkEmptyString(result: Object): Unit = { assertEquals(null, result)}
   override def checkInteger(result: Object): Unit = { assertEquals(Int.MaxValue, result)}
   override def checkLong(result: Object): Unit = { assertEquals(Long.MaxValue, result)}
   override def checkDouble(result: Object): Unit = { assertEquals(Double.MaxValue, result)}
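The core change above is the whole behavioral story: `ScalaValueReader.nullValue()` now hands back a plain Java `null` instead of Scala's `None`, so null and empty-string fields reach Spark SQL as proper SQL NULLs rather than as a Scala sentinel object. A minimal sketch of that convention, with `emptyAsNull` standing in for the `es.field.read.empty.as.null` setting exercised by the integration tests below (the object and method names here are illustrative, not library code):

```scala
// Sketch only: null or empty input surfaces as Java null, which Spark SQL
// treats as SQL NULL, instead of Scala's None.
object EmptyAsNullSketch {
  // `emptyAsNull` stands in for the es.field.read.empty.as.null setting.
  def readString(raw: String, emptyAsNull: Boolean = true): String =
    if (raw == null || (emptyAsNull && raw.isEmpty)) null else raw

  def main(args: Array[String]): Unit = {
    assert(readString(null) == null)                  // null stays null
    assert(readString("") == null)                    // default: empty -> null
    assert(readString("", emptyAsNull = false) == "") // opt out: empty survives
    assert(readString("Scala") == "Scala")
  }
}
```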
diff --git a/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index feb395ffb..99158099d 100644
--- a/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -23,7 +23,6 @@ import java.{lang => jl}
 import java.sql.Timestamp
 import java.{util => ju}
 import java.util.concurrent.TimeUnit
-
 import scala.collection.JavaConversions.propertiesAsScalaMap
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -68,6 +67,8 @@ import org.junit.runners.Parameterized.Parameters
 import org.junit.runners.MethodSorters
 import com.esotericsoftware.kryo.io.{Input => KryoInput}
 import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
 import javax.xml.bind.DatatypeConverter
 import org.elasticsearch.hadoop.{EsHadoopIllegalArgumentException, EsHadoopIllegalStateException}
 import org.apache.spark.sql.types.DoubleType
@@ -419,6 +420,33 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
     //results.take(5).foreach(println)
   }
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("description", StringType, true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter("language = 'Python'")
+    assertEquals(1, nullDescriptionsDf.count)
+    assertEquals(null, nullDescriptionsDf.first().getAs("description"))
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter("language = 'Python'")
+    assertEquals(1, emptyDescriptionsDf.count)
+    assertEquals("", emptyDescriptionsDf.first().getAs("description"))
+  }
 
   @Test
   def test0WriteFieldNameWithPercentage() {
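For callers, the sql-13 test above translates into a read-side choice. A hedged usage sketch, written against the Spark 2.x `SparkSession` API used by the sql-20 and sql-30 modules below (the index name comes from the tests; an already-configured `es.*` connection is assumed):

```scala
// Usage sketch, not part of the patch: reading the index written by the
// tests, first with the default empty-as-null behavior, then with
// es.field.read.empty.as.null disabled.
import org.apache.spark.sql.SparkSession

object EmptyStringsReadUsage {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("empty-strings-usage").getOrCreate()

    // Default: the Python row's empty description comes back as SQL NULL.
    spark.read.format("org.elasticsearch.spark.sql")
      .load("empty_strings_test")
      .filter("description is null")
      .show()

    // Opt out: the empty string round-trips as "".
    spark.read.format("org.elasticsearch.spark.sql")
      .option("es.field.read.empty.as.null", "no")
      .load("empty_strings_test")
      .filter("description = ''")
      .show()

    spark.stop()
  }
}
```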
diff --git a/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index 489321981..83ceaf670 100644
--- a/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -27,7 +27,6 @@ import java.nio.file.Paths
 import java.sql.Timestamp
 import java.{util => ju}
 import java.util.concurrent.TimeUnit
-
 import scala.collection.JavaConversions.propertiesAsScalaMap
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -86,6 +85,8 @@ import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
 import com.esotericsoftware.kryo.io.{Input => KryoInput}
 import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
 import javax.xml.bind.DatatypeConverter
 import org.apache.spark.sql.SparkSession
 import org.elasticsearch.hadoop.EsAssume
@@ -438,6 +439,33 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
     //results.take(5).foreach(println)
   }
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("description", StringType, true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter(row => row.getAs("description") == null)
+    assertEquals(1, nullDescriptionsDf.count)
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val nullDescriptionsDf2 = outputDf2.filter(row => row.getAs("description") == null)
+    assertEquals(0, nullDescriptionsDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter(row => row.getAs("description") == "")
+    assertEquals(1, emptyDescriptionsDf.count)
+  }
 
   @Test
   def test0WriteFieldNameWithPercentage() {
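Note that the sql-20 variant asserts through typed `Row` predicates instead of the SQL-string filters used in sql-13. The distinction matters because SQL equality never matches NULL: under the default empty-as-null read, `description = ''` selects nothing, and only a null check finds the Python row. An illustrative helper, not part of the patch (Spark 2.x Dataset API assumed):

```scala
// Both filter styles count the same rows on an empty-as-null read; the SQL
// form must use IS NULL, since "description = ''" cannot match a NULL.
import org.apache.spark.sql.{DataFrame, Row}

object NullFilterStyles {
  def countNullDescriptions(df: DataFrame): (Long, Long) = {
    val viaSql    = df.filter("description is null").count()
    val viaLambda = df.filter((row: Row) => row.getAs[String]("description") == null).count()
    (viaSql, viaLambda)
  }
}
```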
diff --git a/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index 86acac5cb..1166c6ac2 100644
--- a/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-30/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -27,7 +27,6 @@ import java.nio.file.Paths
 import java.sql.Timestamp
 import java.{util => ju}
 import java.util.concurrent.TimeUnit
-
 import scala.collection.JavaConversions.propertiesAsScalaMap
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.mapAsJavaMapConverter
@@ -86,6 +85,8 @@ import org.junit.runners.Parameterized
 import org.junit.runners.Parameterized.Parameters
 import com.esotericsoftware.kryo.io.{Input => KryoInput}
 import com.esotericsoftware.kryo.io.{Output => KryoOutput}
+import org.apache.spark.rdd.RDD
+
 import javax.xml.bind.DatatypeConverter
 import org.apache.spark.sql.SparkSession
 import org.elasticsearch.hadoop.EsAssume
@@ -98,6 +99,7 @@ import org.junit.Assert._
 import org.junit.ClassRule
 
 object AbstractScalaEsScalaSparkSQL {
+  @transient
   val conf = new SparkConf()
     .setAll(propertiesAsScalaMap(TestSettings.TESTING_PROPS))
     .setAppName("estest")
@@ -438,7 +440,34 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val results = sqc.sql("SELECT name FROM datfile WHERE id >=1 AND id <=10")
     //results.take(5).foreach(println)
   }
-
+
+  @Test
+  def testEmptyStrings(): Unit = {
+    val data = Seq(("Java", "20000"), ("Python", ""), ("Scala", "3000"))
+    val rdd: RDD[Row] = sc.parallelize(data).map(row => Row(row._1, row._2))
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("description", StringType, true)
+    ))
+    val inputDf = sqc.createDataFrame(rdd, schema)
+    inputDf.write
+      .format("org.elasticsearch.spark.sql")
+      .save("empty_strings_test")
+    val reader = sqc.read.format("org.elasticsearch.spark.sql")
+    val outputDf = reader.load("empty_strings_test")
+    assertEquals(data.size, outputDf.count)
+    val nullDescriptionsDf = outputDf.filter(row => row.getAs("description") == null)
+    assertEquals(1, nullDescriptionsDf.count)
+
+    val reader2 = sqc.read.format("org.elasticsearch.spark.sql").option("es.field.read.empty.as.null", "no")
+    val outputDf2 = reader2.load("empty_strings_test")
+    assertEquals(data.size, outputDf2.count)
+    val nullDescriptionsDf2 = outputDf2.filter(row => row.getAs("description") == null)
+    assertEquals(0, nullDescriptionsDf2.count)
+    val emptyDescriptionsDf = outputDf2.filter(row => row.getAs("description") == "")
+    assertEquals(1, emptyDescriptionsDf.count)
+  }
+
   @Test
   def test0WriteFieldNameWithPercentage() {
     val index = wrapIndex("spark-test-scala-sql-field-with-percentage")
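The sql-30 diff also marks the companion object's `SparkConf` as `@transient`. Reading between the lines, this keeps driver-side configuration out of anything Java serialization touches, such as closures or fixtures that capture the enclosing object. A standalone demonstration of the mechanism (the motivation is our inference, not stated in the diff; `Holder` and `TransientDemo` are made-up names):

```scala
// Fields marked @transient are skipped by Java serialization, so a
// deserialized copy sees them as null.
import java.io._

class Holder extends Serializable {
  @transient val scratch: StringBuilder = new StringBuilder("driver-only")
  val kept: String = "serialized"
}

object TransientDemo {
  def main(args: Array[String]): Unit = {
    val out = new ByteArrayOutputStream()
    new ObjectOutputStream(out).writeObject(new Holder)
    val copy = new ObjectInputStream(new ByteArrayInputStream(out.toByteArray))
      .readObject().asInstanceOf[Holder]
    println(copy.kept)            // prints "serialized"
    println(copy.scratch == null) // prints "true": transient field not written
  }
}
```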