diff --git a/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala b/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
index a46f58367..919830745 100644
--- a/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
+++ b/spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala
@@ -263,7 +263,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val target = wrapIndex("sparksql-test/scala-basic-write")
 
     val newCfg = collection.mutable.Map(cfg.toSeq: _*) += ("es.read.field.include" -> "id, name, url")
-    
+
     val dataFrame = sqc.esDF(target, newCfg)
     assertTrue(dataFrame.count > 300)
     val schema = dataFrame.schema.treeString
@@ -300,7 +300,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val target = wrapIndex("sparksql-test/scala-basic-write")
 
     val dfNoQuery = JavaEsSparkSQL.esDF(sqc, target, cfg.asJava)
-    val query = s"""{ "query" : { "query_string" : { "query" : "name:me*" } } //, "fields" : ["name"]
+    val query = s"""{ "query" : { "query_string" : { "query" : "name:me*" } } //, "fields" : ["name"]
     }"""
 
     val dfWQuery = JavaEsSparkSQL.esDF(sqc, target, query, cfg.asJava)
@@ -453,7 +453,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     val filter = df.filter(df("airport").equalTo("OTP"))
 
     if (strictPushDown) {
      assertEquals(0, filter.count())
-      // however if we change the arguments to be lower cased, it will Spark who's going to filter out the data
+      // however if we change the arguments to be lower cased, it will be Spark who's going to filter out the data
      return
    }
@@ -517,7 +517,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
 
    if (strictPushDown) {
      assertEquals(0, filter.count())
-      // however if we change the arguments to be lower cased, it will Spark who's going to filter out the data
+      // however if we change the arguments to be lower cased, it will be Spark who's going to filter out the data
      return
    }
 
@@ -525,6 +525,23 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertEquals("jan", filter.select("tag").sort("tag").take(2)(1)(0))
   }
 
+  @Test
+  def testDataSourcePushDown08InWithNumber() {
+    val df = esDataSource("pd_in_number")
+    var filter = df.filter("participants IN (1, 2, 3)")
+
+    assertEquals(1, filter.count())
+    assertEquals("long", filter.select("tag").sort("tag").take(1)(0)(0))
+  }
+
+  @Test
+  def testDataSourcePushDown08InWithNumberAndStrings() {
+    val df = esDataSource("pd_in_number")
+    var filter = df.filter("participants IN (2, 'bar', 1, 'foo')")
+
+    assertEquals(0, filter.count())
+  }
+
   @Test
   def testDataSourcePushDown09StartsWith() {
     val df = esDataSource("pd_starts_with")
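
The literal list in df.filter("participants IN (1, 2, 3)") reaches the connector through Spark's DataSource filter API as an org.apache.spark.sql.sources.In filter, which is what the DefaultSource change further down handles. A minimal sketch of that filter shape, for reference only (the esDataSource helper and the pd_in_number index are defined elsewhere in this suite):

    import org.apache.spark.sql.sources.{Filter, In}

    // what Catalyst pushes down for df.filter("participants IN (1, 2, 3)")
    val pushed: Filter = In("participants", Array(1, 2, 3))

    // for mixed literals such as IN (2, 'bar', 1, 'foo'), the values array handed to the
    // connector can contain nulls for entries Spark could not cast to the column type,
    // which is what the null filtering added in DefaultSource.scala guards against
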
diff --git a/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/DataFrameValueWriter.scala b/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/DataFrameValueWriter.scala
index d7906a327..b1024b3ca 100644
--- a/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/DataFrameValueWriter.scala
+++ b/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/DataFrameValueWriter.scala
@@ -10,7 +10,6 @@ import scala.collection.Seq
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
 import org.apache.spark.sql.types.ArrayType
 import org.apache.spark.sql.types.DataType
 import org.apache.spark.sql.types.DataTypes.BinaryType
@@ -51,7 +50,7 @@ class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends Filtering
 
   private[spark] def writeStruct(schema: StructType, value: Any, generator: Generator): Result = {
     value match {
-      case r: GenericRowWithSchema =>
+      case r: Row =>
        generator.writeBeginObject()
 
        schema.fields.view.zipWithIndex foreach {
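
The writeStruct change widens the match from the concrete GenericRowWithSchema to the Row interface, so struct values that arrive as plain rows (without an attached schema) are still written as objects using the StructType passed in. A minimal sketch of such a value, with field names borrowed from the tests (illustration only):

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.{LongType, StringType, StructField, StructType}

    val schema = StructType(Seq(
      StructField("tag", StringType),
      StructField("participants", LongType)))

    // Row(...) builds a plain GenericRow that carries no schema of its own; it previously
    // fell through the GenericRowWithSchema case but is covered by matching on Row
    val nested: Row = Row("long", 3L)
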
//s"""{"query":{"$attribute":${strings.mkString("\"", " ", "\"")}}}""" + } + else { + // translate the numbers into a terms query + val str = s"""{"terms":{"$attribute":${numbers.mkString("[", ",", "]")}}}""" + if (strings.isEmpty) return str + // if needed, add the strings as a match query + else return str + s""",{"query":{"match":{"$attribute":${strings.mkString("\"", " ", "\"")}}}}""" + } + } + private def extract(value: Any, inJsonFormat: Boolean, asJsonArray: Boolean):String = { // common-case implies primitives and String so try these before using the full-blown ValueWriter value match { + case null => "null" case u: Unit => "null" case b: Boolean => b.toString case c: Char => if (inJsonFormat) StringUtils.toJsonString(c) else c.toString() @@ -220,14 +258,14 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @ case f: Float => f.toString case d: Double => d.toString case s: String => if (inJsonFormat) StringUtils.toJsonString(s) else s + case ar: Array[Any] => + if (asJsonArray) (for (i <- ar) yield extract(i, true, false)).distinct.mkString("[", ",", "]") + else (for (i <- ar) yield extract(i, false, false)).distinct.mkString("\"", " ", "\"") // new in Spark 1.4 case utf if (isClass(utf, "org.apache.spark.sql.types.UTF8String") // new in Spark 1.5 || isClass(utf, "org.apache.spark.unsafe.types.UTF8String")) => if (inJsonFormat) StringUtils.toJsonString(utf.toString()) else utf.toString() - case ar: Array[Any] => - if (asJsonArray) (for (i <- ar) yield extract(i, true, false)).mkString("[", ",", "]") - else (for (i <- ar) yield extract(i, false, false)).mkString("\"", " ", "\"") case a: AnyRef => { val storage = new FastByteArrayOutputStream() val generator = new JacksonJsonGenerator(storage) diff --git a/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala index 6660cd9d3..f9ae2f6ba 100644 --- a/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala +++ b/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala @@ -1,11 +1,9 @@ package org.elasticsearch.spark.sql import java.util.Properties - import scala.collection.JavaConverters.asScalaBufferConverter import scala.collection.JavaConverters.propertiesAsScalaMapConverter import scala.collection.mutable.ArrayBuffer - import org.apache.spark.sql.types.BinaryType import org.apache.spark.sql.types.BooleanType import org.apache.spark.sql.types.ByteType @@ -41,8 +39,7 @@ import org.elasticsearch.hadoop.serialization.dto.mapping.MappingUtils import org.elasticsearch.hadoop.util.Assert import org.elasticsearch.hadoop.util.IOUtils import org.elasticsearch.hadoop.util.StringUtils -import org.elasticsearch.spark.sql.Utils.ROOT_LEVEL_NAME -import org.elasticsearch.spark.sql.Utils.ROW_ORDER_PROPERTY +import org.elasticsearch.spark.sql.Utils._ private[sql] object SchemaUtils { case class Schema(field: Field, struct: StructType) @@ -60,11 +57,11 @@ private[sql] object SchemaUtils { val repo = new RestRepository(cfg) try { if (repo.indexExists(true)) { - + var field = repo.getMapping.skipHeaders() val readIncludeCfg = cfg.getProperty(readInclude) val readExcludeCfg = cfg.getProperty(readExclude) - + // apply mapping filtering only when present to minimize configuration settings (big when dealing with large mappings) if (StringUtils.hasText(readIncludeCfg) || StringUtils.hasText(readExcludeCfg)) { // apply any possible include/exclude that can define restrict the DataFrame to just a 
diff --git a/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala b/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
index 6660cd9d3..f9ae2f6ba 100644
--- a/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
+++ b/spark/sql-13/src/main/scala/org/elasticsearch/spark/sql/SchemaUtils.scala
@@ -1,11 +1,9 @@
 package org.elasticsearch.spark.sql
 
 import java.util.Properties
-
 import scala.collection.JavaConverters.asScalaBufferConverter
 import scala.collection.JavaConverters.propertiesAsScalaMapConverter
 import scala.collection.mutable.ArrayBuffer
-
 import org.apache.spark.sql.types.BinaryType
 import org.apache.spark.sql.types.BooleanType
 import org.apache.spark.sql.types.ByteType
@@ -41,8 +39,7 @@ import org.elasticsearch.hadoop.serialization.dto.mapping.MappingUtils
 import org.elasticsearch.hadoop.util.Assert
 import org.elasticsearch.hadoop.util.IOUtils
 import org.elasticsearch.hadoop.util.StringUtils
-import org.elasticsearch.spark.sql.Utils.ROOT_LEVEL_NAME
-import org.elasticsearch.spark.sql.Utils.ROW_ORDER_PROPERTY
+import org.elasticsearch.spark.sql.Utils._
 
 private[sql] object SchemaUtils {
   case class Schema(field: Field, struct: StructType)
@@ -60,11 +57,11 @@ private[sql] object SchemaUtils {
     val repo = new RestRepository(cfg)
     try {
       if (repo.indexExists(true)) {
-        
+
         var field = repo.getMapping.skipHeaders()
         val readIncludeCfg = cfg.getProperty(readInclude)
         val readExcludeCfg = cfg.getProperty(readExclude)
-        
+
         // apply mapping filtering only when present to minimize configuration settings (big when dealing with large mappings)
         if (StringUtils.hasText(readIncludeCfg) || StringUtils.hasText(readExcludeCfg)) {
           // apply any possible include/exclude that can define restrict the DataFrame to just a number of fields
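
For reference on the include/exclude handling touched in SchemaUtils (the mapping filter is applied only when es.read.field.include or es.read.field.exclude is actually set), a minimal usage sketch in the spirit of the first test hunk, assuming a running Elasticsearch instance, the sparksql-test/scala-basic-write index used by the suite, and sqc being the SQLContext the tests use:

    import org.elasticsearch.spark.sql._

    // restrict the mapping-derived schema (and the fields read) to the listed fields
    val reduced = sqc.esDF(
      "sparksql-test/scala-basic-write",
      Map("es.read.field.include" -> "id, name, url"))

    // only the included fields should show up in the schema
    reduced.printSchema()
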