[SPARK] Properly translate In filter w/ numbers
Use generic trait instead of concrete impl
Relates #556

(cherry picked from commit fd80b13)
costin committed Oct 7, 2015
1 parent 1129391 commit bbe6154
Showing 4 changed files with 71 additions and 20 deletions.
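The change in a nutshell: before this commit, a pushed-down IN was rendered in the non-strict case as a single match query whose value was every literal folded into one whitespace-joined string (see the old extract(values) call further down), which does not match numeric fields. The In case now drops null literals (per the code comment below, Spark hands mixed string/number arguments over as nulls) and, when not in strict push-down, splits numbers into a terms filter and strings into a match query. A minimal repro sketch, assuming an SQLContext named sqlContext is in scope and an index resource like "sparksql-test/pd_in_number" holding the test documents (both names are assumptions, not taken from the commit):

import org.elasticsearch.spark.sql._

// After this commit the numeric IN below is pushed down as a terms filter
// instead of a single match query against the string "1 2 3".
val df = sqlContext.esDF("sparksql-test/pd_in_number")
df.filter("participants IN (1, 2, 3)").count()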
@@ -263,7 +263,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
val target = wrapIndex("sparksql-test/scala-basic-write")

val newCfg = collection.mutable.Map(cfg.toSeq: _*) += ("es.read.field.include" -> "id, name, url")

val dataFrame = sqc.esDF(target, newCfg)
assertTrue(dataFrame.count > 300)
val schema = dataFrame.schema.treeString
@@ -300,7 +300,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
val target = wrapIndex("sparksql-test/scala-basic-write")

val dfNoQuery = JavaEsSparkSQL.esDF(sqc, target, cfg.asJava)
val query = s"""{ "query" : { "query_string" : { "query" : "name:me*" } } //, "fields" : ["name"]
}"""
val dfWQuery = JavaEsSparkSQL.esDF(sqc, target, query, cfg.asJava)

@@ -453,7 +453,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
val filter = df.filter(df("airport").equalTo("OTP"))
if (strictPushDown) {
assertEquals(0, filter.count())
// however if we change the arguments to be lower cased, it will Spark who's going to filter out the data
// however if we change the arguments to be lower cased, it will be Spark who's going to filter out the data
return
}

@@ -517,14 +517,31 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus

if (strictPushDown) {
assertEquals(0, filter.count())
// however if we change the arguments to be lower cased, it will Spark who's going to filter out the data
// however if we change the arguments to be lower cased, it will be Spark who's going to filter out the data
return
}

assertEquals(2, filter.count())
assertEquals("jan", filter.select("tag").sort("tag").take(2)(1)(0))
}

@Test
def testDataSourcePushDown08InWithNumber() {
val df = esDataSource("pd_in_number")
var filter = df.filter("participants IN (1, 2, 3)")

assertEquals(1, filter.count())
assertEquals("long", filter.select("tag").sort("tag").take(1)(0)(0))
}

@Test
def testDataSourcePushDown08InWithNumberAndStrings() {
val df = esDataSource("pd_in_number")
var filter = df.filter("participants IN (2, 'bar', 1, 'foo')")

assertEquals(0, filter.count())
}
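A sketch of the same push-down exercised through the typed Column API rather than a SQL string (not part of the commit; assumes Spark 1.5's Column.isin, which should reach the data source as the same sources.In filter as the predicate string above):

@Test
def testDataSourcePushDown08InWithNumberViaColumnApi() {
  val df = esDataSource("pd_in_number")
  // isin(1, 2, 3) should arrive as In("participants", Array(1, 2, 3))
  val filter = df.filter(df("participants").isin(1, 2, 3))

  assertEquals(1, filter.count())
}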

@Test
def testDataSourcePushDown09StartsWith() {
val df = esDataSource("pd_starts_with")
@@ -10,7 +10,6 @@ import scala.collection.Seq

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DataTypes.BinaryType
Expand Down Expand Up @@ -51,7 +50,7 @@ class DataFrameValueWriter(writeUnknownTypes: Boolean = false) extends Filtering

private[spark] def writeStruct(schema: StructType, value: Any, generator: Generator): Result = {
value match {
case r: GenericRowWithSchema =>
case r: Row =>
generator.writeBeginObject()

schema.fields.view.zipWithIndex foreach {
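The writeStruct change above is the "generic trait instead of concrete impl" part of the commit message: matching only on GenericRowWithSchema let any other Row implementation fall through the pattern. A minimal standalone sketch of the idea (names are illustrative, not the project's code):

import org.apache.spark.sql.Row

// Matching on the Row trait accepts every implementation, schema-backed or not.
def render(value: Any): String = value match {
  case r: Row => (0 until r.length).map(r.get).mkString("{", ",", "}")
  case null   => "null"
  case other  => other.toString
}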
@@ -1,10 +1,10 @@
package org.elasticsearch.spark.sql

import java.util.Locale

import scala.None
import scala.Null
import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.collection.mutable.LinkedHashMap

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
@@ -44,7 +44,9 @@ import org.elasticsearch.hadoop.util.IOUtils
import org.elasticsearch.hadoop.util.StringUtils
import org.elasticsearch.spark.cfg.SparkSettingsManager
import org.elasticsearch.spark.serialization.ScalaValueWriter
import org.elasticsearch.spark.sql.Utils._
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.LinkedHashSet
import scala.collection.mutable.ArrayOps

private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {

@@ -152,8 +154,14 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @
case LessThan(attribute, value) => s"""{"range":{"$attribute":{"lt" :${extract(value)}}}}"""
case LessThanOrEqual(attribute, value) => s"""{"range":{"$attribute":{"lte":${extract(value)}}}}"""
case In(attribute, values) => {
if (strictPushDown) s"""{"terms":{"$attribute":${extractAsJsonArray(values)}}}"""
else s"""{"query":{"match":{"$attribute":${extract(values)}}}}"""
// when dealing with mixed types (strings and numbers) Spark converts the Strings to null (gets confused by the type field)
// this leads to incorrect query DSL, which is why the nulls are filtered out
val filtered = values filter (_ != null)
if (filtered.isEmpty) {
return ""
}
if (strictPushDown) s"""{"terms":{"$attribute":${extractAsJsonArray(filtered)}}}"""
else s"""{"or":{"filters":[${extractMatchArray(attribute, filtered)}]}}"""
}
case IsNull(attribute) => s"""{"missing":{"field":"$attribute"}}"""
case IsNotNull(attribute) => s"""{"exists":{"field":"$attribute"}}"""
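For the all-numeric filter from the new test, participants IN (1, 2, 3), the branch above should render roughly the following query DSL (derived by reading the code, not captured from a live cluster):

strict push-down:     {"terms":{"participants":[1,2,3]}}
non-strict push-down: {"or":{"filters":[{"terms":{"participants":[1,2,3]}}]}}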
@@ -207,9 +215,39 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @
extract(value, true, true)
}

private def extractMatchArray(attribute: String, ar: Array[Any]):String = {
// use a set to avoid duplicate values
// especially since Spark conversion might turn each user param into null
val numbers = LinkedHashSet.empty[AnyRef]
val strings = LinkedHashSet.empty[AnyRef]

// move numbers into a separate list for a terms query combined with a bool
for (i <- ar) i.asInstanceOf[AnyRef] match {
case null => // ignore
case n:Number => numbers += extract(i, false, false)
case _ => strings += extract(i, false, false)
}

if (numbers.isEmpty) {
if (strings.isEmpty) {
return StringUtils.EMPTY
}
return s"""{"query":{"match":{"$attribute":${strings.mkString("\"", " ", "\"")}}}}"""
//s"""{"query":{"$attribute":${strings.mkString("\"", " ", "\"")}}}"""
}
else {
// translate the numbers into a terms query
val str = s"""{"terms":{"$attribute":${numbers.mkString("[", ",", "]")}}}"""
if (strings.isEmpty) return str
// if needed, add the strings as a match query
else return str + s""",{"query":{"match":{"$attribute":${strings.mkString("\"", " ", "\"")}}}}"""
}
}
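As a worked example of extractMatchArray: if the mixed values from the second new test, 2, "bar", 1 and "foo", were to reach it intact for the attribute participants, the numbers and strings would be split and the method would return

{"terms":{"participants":[2,1]}},{"query":{"match":{"participants":"bar foo"}}}

which the In case above then wraps in {"or":{"filters":[...]}}. In practice, as the comment in the In case notes, the string literals tend to arrive as nulls and are filtered out, leaving only the numbers to query.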

private def extract(value: Any, inJsonFormat: Boolean, asJsonArray: Boolean):String = {
// common-case implies primitives and String so try these before using the full-blown ValueWriter
value match {
case null => "null"
case u: Unit => "null"
case b: Boolean => b.toString
case c: Char => if (inJsonFormat) StringUtils.toJsonString(c) else c.toString()
@@ -220,14 +258,14 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @
case f: Float => f.toString
case d: Double => d.toString
case s: String => if (inJsonFormat) StringUtils.toJsonString(s) else s
case ar: Array[Any] =>
if (asJsonArray) (for (i <- ar) yield extract(i, true, false)).distinct.mkString("[", ",", "]")
else (for (i <- ar) yield extract(i, false, false)).distinct.mkString("\"", " ", "\"")
// new in Spark 1.4
case utf if (isClass(utf, "org.apache.spark.sql.types.UTF8String")
// new in Spark 1.5
|| isClass(utf, "org.apache.spark.unsafe.types.UTF8String"))
=> if (inJsonFormat) StringUtils.toJsonString(utf.toString()) else utf.toString()
case ar: Array[Any] =>
if (asJsonArray) (for (i <- ar) yield extract(i, true, false)).mkString("[", ",", "]")
else (for (i <- ar) yield extract(i, false, false)).mkString("\"", " ", "\"")
case a: AnyRef => {
val storage = new FastByteArrayOutputStream()
val generator = new JacksonJsonGenerator(storage)
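The array branch of extract above backs extractAsJsonArray (strict push-down) as well as the whitespace-joined quoted form used for match queries: as a JSON array the values are deduplicated and comma-separated, otherwise they are folded into one quoted string. A standalone sketch mirroring that behavior (illustrative only, not the project's code):

// renderArray(Array[Any](1, 2, 2, 3), asJsonArray = true)  yields [1,2,3]
// renderArray(Array[Any](1, 2, 3), asJsonArray = false)    yields "1 2 3"
def renderArray(values: Array[Any], asJsonArray: Boolean): String =
  if (asJsonArray) values.map(_.toString).distinct.mkString("[", ",", "]")
  else values.map(_.toString).distinct.mkString("\"", " ", "\"")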
@@ -1,11 +1,9 @@
package org.elasticsearch.spark.sql

import java.util.Properties

import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.propertiesAsScalaMapConverter
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.types.BinaryType
import org.apache.spark.sql.types.BooleanType
import org.apache.spark.sql.types.ByteType
@@ -41,8 +39,7 @@ import org.elasticsearch.hadoop.serialization.dto.mapping.MappingUtils
import org.elasticsearch.hadoop.util.Assert
import org.elasticsearch.hadoop.util.IOUtils
import org.elasticsearch.hadoop.util.StringUtils
import org.elasticsearch.spark.sql.Utils.ROOT_LEVEL_NAME
import org.elasticsearch.spark.sql.Utils.ROW_ORDER_PROPERTY
import org.elasticsearch.spark.sql.Utils._

private[sql] object SchemaUtils {
case class Schema(field: Field, struct: StructType)
@@ -60,11 +57,11 @@ private[sql] object SchemaUtils {
val repo = new RestRepository(cfg)
try {
if (repo.indexExists(true)) {

var field = repo.getMapping.skipHeaders()
val readIncludeCfg = cfg.getProperty(readInclude)
val readExcludeCfg = cfg.getProperty(readExclude)

// apply mapping filtering only when present to minimize configuration settings (big when dealing with large mappings)
if (StringUtils.hasText(readIncludeCfg) || StringUtils.hasText(readExcludeCfg)) {
// apply any include/exclude that can restrict the DataFrame to just a subset of fields
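A usage sketch for the include/exclude handling above, mirroring the es.read.field.include test earlier in this commit (the SQLContext value and the index resource are assumptions, not taken from the commit):

import org.elasticsearch.spark.sql._

// Only id, name and url should survive mapping discovery and end up in the schema.
val cfg = Map("es.read.field.include" -> "id, name, url")
val df = sqlContext.esDF("sparksql-test/scala-basic-write", cfg)
df.schema.printTreeString()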
