add complex type support for azure search sink
mhamilton723 committed Oct 8, 2019
1 parent de8b542 commit aec1672
Showing 8 changed files with 141 additions and 158 deletions.
2 changes: 1 addition & 1 deletion notebooks/samples/AzureSearchIndex - Met Artworks.ipynb
@@ -189,7 +189,7 @@
},
"outputs": [],
"source": [
"from mmlspark.io.http import *\n",
"from mmlspark.cognitive import *\n",
"data_processed.writeToAzureSearch(options)"
]
},
@@ -12,15 +12,14 @@
from pyspark.ml.param.shared import *
from pyspark.sql import DataFrame


def streamToAzureSearch(df, options=dict()):
def streamToAzureSearch(df, **options):
jvm = SparkContext.getOrCreate()._jvm
writer = jvm.com.microsoft.ml.spark.cognitive.AzureSearchWriter
return writer.stream(df._jdf, options)

setattr(pyspark.sql.DataFrame, 'streamToAzureSearch', streamToAzureSearch)

def writeToAzureSearch(df, options=dict()):
def writeToAzureSearch(df, **options):
jvm = SparkContext.getOrCreate()._jvm
writer = jvm.com.microsoft.ml.spark.cognitive.AzureSearchWriter
writer.write(df._jdf, options)
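With the signature change from options=dict() to **options, the Python helpers now take index options as keyword arguments. A minimal usage sketch, assuming the option names accepted by the Scala writer further below; the key, service, and column values are placeholders:

```python
from mmlspark.cognitive import *  # assumed to attach writeToAzureSearch to DataFrame

df.writeToAzureSearch(
    subscriptionKey="<azure-search-admin-key>",  # placeholder
    serviceName="<search-service>",              # placeholder
    indexName="test-index",
    keyCol="id",                # column holding each document's key
    actionCol="searchAction",   # defaults to "@search.action" if omitted
)
```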
23 changes: 0 additions & 23 deletions src/main/python/mmlspark/io/http/BingImageReader.py

This file was deleted.

139 changes: 73 additions & 66 deletions src/main/scala/com/microsoft/ml/spark/cognitive/AzureSearch.scala
@@ -17,7 +17,7 @@ import org.apache.spark.sql.functions.{col, struct, to_json, udf, expr}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import com.microsoft.ml.spark.cognitive.IndexJsonProtocol._
import com.microsoft.ml.spark.cognitive.AzureSearchProtocol._
import spray.json._
import DefaultJsonProtocol._

@@ -122,7 +122,7 @@ class AddDocuments(override val uid: String) extends CognitiveServicesBase(uid)
override def transform(dataset: Dataset[_]): DataFrame = {
if (get(url).isEmpty) {
setUrl(s"https://$getServiceName.search.windows.net" +
s"/indexes/$getIndexName/docs/index?api-version=2017-11-11")
s"/indexes/$getIndexName/docs/index?api-version=${AzureSearchAPIConstants.DefaultAPIVersion}")
}
super.transform(dataset)
}
@@ -153,40 +153,53 @@ object AzureSearchWriter extends IndexParser with SLogging {
df.withColumn(collectionColName, expr(s"filter($collectionColName, x -> x is not null)"))
}

private def dfToIndexJson(schema: StructType,
indexName: String,
private def convertFields(fields: Seq[StructField],
keyCol: String,
searchActionCol: String,
searchableCols: List[String],
filterableCols: List[String],
sortableCols: List[String],
facetableCols: List[String],
retrievableCols: List[String]): String = {
val is = IndexSchema(indexName, schema.fields.filterNot(_.name == searchActionCol).map(sf =>
Field(
prefix: Option[String]): Seq[IndexField] = {
fields.filterNot(_.name == searchActionCol).map { sf =>
val fullName = prefix.map(_ + sf.name).getOrElse(sf.name)
val (innerType, innerFields) = sparkTypeToEdmType(sf.dataType)
IndexField(
sf.name,
sparkTypeToEdmType(sf.dataType),
searchableCols.contains(sf.name),
filterableCols.contains(sf.name),
sortableCols.contains(sf.name),
facetableCols.contains(sf.name),
keyCol == sf.name,
retrievableCols.contains(sf.name),
None,
None,
None,
None
innerType,
None, None, None, None, None,
if (keyCol == fullName) Some(true) else None,
None, None, None, None,
structFieldToSearchFields(sf.dataType,
keyCol, searchActionCol, prefix=Some(prefix.getOrElse("") + sf.name + "."))
)
))
}
}

private def structFieldToSearchFields(schema: DataType,
keyCol: String,
searchActionCol: String,
prefix: Option[String] = None
): Option[Seq[IndexField]] = {
schema match {
case StructType(fields) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case ArrayType(StructType(fields), _) => Some(convertFields(fields, keyCol, searchActionCol, prefix))
case _ => None
}
}

private def dfToIndexJson(schema: StructType,
indexName: String,
keyCol: String,
searchActionCol: String): String = {
val is = IndexInfo(
Some(indexName),
structFieldToSearchFields(schema, keyCol, searchActionCol).get,
None, None, None, None, None, None, None, None
)
is.toJson.compactPrint
}

private def prepareDF(df: DataFrame, options: Map[String, String] = Map()): DataFrame = {
val applicableOptions = Set(
"subscriptionKey", "actionCol", "serviceName", "indexName", "indexJson",
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol", "searchableCols", "filterableCols",
"facetableCols", "retrievableCols", "sortableCols"
"apiVersion", "batchSize", "fatalErrors", "filterNulls", "keyCol"
)

options.keys.foreach(k =>
@@ -196,39 +209,21 @@ object AzureSearchWriter extends IndexParser with SLogging {
val actionCol = options.getOrElse("actionCol", "@search.action")
val serviceName = options("serviceName")
val indexJsonOpt = options.get("indexJson")
val apiVersion = options.getOrElse("apiVersion", "2017-11-11")
val apiVersion = options.getOrElse("apiVersion", AzureSearchAPIConstants.DefaultAPIVersion)
val batchSize = options.getOrElse("batchSize", "100").toInt
val fatalErrors = options.getOrElse("fatalErrors", "true").toBoolean
val filterNulls = options.getOrElse("filterNulls", "false").toBoolean

val keyCol = options.get("keyCol")
val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)
if (indexJsonOpt.isDefined) {
List("keyCol", "searchableCols", "filterableCols", "facetableCols", "indexName").foreach(opt =>
List("keyCol", "indexName").foreach(opt =>
assert(options.get(opt).isEmpty, s"Cannot set both indexJson options and $opt")
)
}

val keyCol = options.get("keyCol")

val defaultParseEmpty = { field: String =>
options.get(field).map(s => s.split(",").toList)
.getOrElse(List())
}
val defaultParseFull = { field: String =>
options.get(field).map(s => s.split(",").toList)
.getOrElse(df.schema.fieldNames.toList)
}

val searchableCols = defaultParseFull("searchableCols")
val filterableCols = defaultParseEmpty("filterableCols")
val facetableCols = defaultParseEmpty("facetableCols")
val retrievableCols = defaultParseFull("retrievableCols")
val sortableCols = defaultParseEmpty("sortableCols")

val indexName = options.getOrElse("indexName", parseIndexJson(indexJsonOpt.get).name.get)

val indexJson = indexJsonOpt.getOrElse {
dfToIndexJson(df.schema, indexName, keyCol.get, actionCol,
searchableCols, filterableCols, sortableCols, facetableCols, retrievableCols)
dfToIndexJson(df.schema, indexName, keyCol.get, actionCol)
}

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)
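With the per-column convenience options (searchableCols, filterableCols, sortableCols, facetableCols, retrievableCols) removed above, per-field attributes are now expressed through an explicit indexJson option instead. A sketch of such a definition, following the Azure Search index schema; the field names and attributes here are illustrative:

```python
import json

# Illustrative index definition: nested "fields" under an Edm.ComplexType
# field replace the old per-column attribute options.
index_definition = {
    "name": "test-index",
    "fields": [
        {"name": "id", "type": "Edm.String", "key": True},
        {"name": "title", "type": "Edm.String", "searchable": True},
        {"name": "artist", "type": "Edm.ComplexType", "fields": [
            {"name": "name", "type": "Edm.String", "searchable": True},
            {"name": "birthYear", "type": "Edm.Int32", "filterable": True},
        ]},
    ],
}

# Per the assertion above, indexName and keyCol must not be passed
# alongside indexJson; both are read from the JSON itself.
df.writeToAzureSearch(
    subscriptionKey="<key>",      # placeholder
    serviceName="<service>",      # placeholder
    indexJson=json.dumps(index_definition),
)
```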
@@ -265,43 +260,55 @@ object AzureSearchWriter extends IndexParser with SLogging {
t.substring("Collection(".length).dropRight(1)
}

private def edmTypeToSparkType(dt: String, allowCollections: Boolean = true): DataType = dt match {
case t if allowCollections && isEdmCollection(t) =>
ArrayType(edmTypeToSparkType(getEdmCollectionElement(t), false), containsNull = false)
private[ml] def edmTypeToSparkType(dt: String,
fields: Option[Seq[IndexField]]): DataType = dt match {
case t if isEdmCollection(t) =>
throw new IllegalArgumentException("Azure search does not allow nested collections," +
" consider using Edm.ComplexType")
ArrayType(edmTypeToSparkType(getEdmCollectionElement(t), fields), containsNull = false)
case "Edm.String" => StringType
case "Edm.Boolean" => BooleanType
case "Edm.Int64" => LongType
case "Edm.Int32" => IntegerType
case "Edm.Double" => DoubleType
case "Edm.DateTimeOffset" => StringType //See if there's a way to use spark datetimes
case "Edm.GeographyPoint" => StringType
case "Edm.ComplexType" => StringType
case "Edm.ComplexType" => StructType(fields.get.map(f =>
StructField(f.name, edmTypeToSparkType(f.`type`, f.fields))))
}

private def sparkTypeToEdmType(dt: DataType, allowCollections: Boolean = true): String = dt match {
case ArrayType(it, _) if allowCollections =>
"Collection(" + sparkTypeToEdmType(it, false) + ")"
case StringType => "Edm.String"
case BooleanType => "Edm.Boolean"
case IntegerType => "Edm.Int32"
case LongType => "Edm.Int64"
case DoubleType => "Edm.Double"
case DateType => "Edm.DateTimeOffset"
case _ => "Edm.ComplexType"
private def sparkTypeToEdmType(dt: DataType,
allowCollections: Boolean = true): (String, Option[Seq[IndexField]]) = {
dt match {
case ArrayType(it, _) if allowCollections =>
val (innerType, innerFields) = sparkTypeToEdmType(it, allowCollections = false)
(s"Collection($innerType)", innerFields)
case ArrayType(it, _) if !allowCollections =>
val (innerType, innerFields) = sparkTypeToEdmType(it, allowCollections)
("Edm.ComplexType", innerFields)
case StringType => ("Edm.String", None)
case BooleanType => ("Edm.Boolean", None)
case IntegerType => ("Edm.Int32", None)
case LongType => ("Edm.Int64", None)
case DoubleType => ("Edm.Double", None)
case DateType => ("Edm.DateTimeOffset", None)
case StructType(fields) => ("Edm.ComplexType", Some(fields.map{f=>
val (innerType, innerFields) = sparkTypeToEdmType(f.dataType)
IndexField(f.name, innerType,None, None, None, None, None, None, None, None, None, None, innerFields)
}))
}
}

@scala.annotation.tailrec
private def dtEqualityModuloNullability(dt1: DataType, dt2: DataType): Boolean = (dt1, dt2) match {
case (ArrayType(it1, _), ArrayType(it2, _)) => dtEqualityModuloNullability(it1, it2)
case (StructType(fields1), StructType(fields2)) =>
fields1.zip(fields2).forall {
case (sf1, sf2) => sf1.name==sf2.name && dtEqualityModuloNullability(sf1.dataType, sf2.dataType)
}
case _ => dt1 == dt2
}

private def checkSchemaParity(schema: StructType, indexJson: String, searchActionCol: String): Unit = {
val indexInfo = parseIndexJson(indexJson)
val indexFields = indexInfo.fields.map(f => (f.name, edmTypeToSparkType(f.`type`))).toMap
val indexFields = indexInfo.fields.map(f => (f.name, edmTypeToSparkType(f.`type`, f.fields))).toMap

assert(schema(searchActionCol).dataType == StringType)
schema.toList.filter(_.name != searchActionCol).foreach { field =>
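Taken together, the two conversions above define a recursive mapping: StructType becomes Edm.ComplexType with nested fields, ArrayType becomes Collection(...), and an array of structs lands on Collection(Edm.ComplexType). A sketch of the correspondence for a nested DataFrame schema, with illustrative column names:

```python
from pyspark.sql.types import (ArrayType, IntegerType, StringType,
                               StructField, StructType)

# Each Spark type below maps to the EDM type in the trailing comment.
schema = StructType([
    StructField("id", StringType()),               # Edm.String
    StructField("tags", ArrayType(StringType())),  # Collection(Edm.String)
    StructField("artist", StructType([             # Edm.ComplexType
        StructField("name", StringType()),         #   Edm.String
        StructField("birthYear", IntegerType()),   #   Edm.Int32
    ])),
    StructField("exhibits", ArrayType(StructType([  # Collection(Edm.ComplexType)
        StructField("title", StringType()),         #   Edm.String
    ]))),
])
```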
@@ -3,18 +3,20 @@

package com.microsoft.ml.spark.cognitive

import com.microsoft.ml.spark.cognitive._
import com.microsoft.ml.spark.cognitive.AzureSearchProtocol._
import com.microsoft.ml.spark.cognitive.RESTHelpers._
import org.apache.commons.io.IOUtils
import org.apache.http.client.methods.{HttpGet, HttpPost}
import org.apache.http.entity.StringEntity
import org.apache.log4j.{LogManager, Logger}
import org.apache.spark.sql.types._
import spray.json._

import scala.util.{Failure, Success, Try}
import AzureSearchProtocol._

import RESTHelpers._
object AzureSearchAPIConstants {
val DefaultAPIVersion = "2019-05-06"
}
import AzureSearchAPIConstants._

trait IndexParser {
def parseIndexJson(str: String): IndexInfo = {
@@ -25,7 +27,7 @@ trait IndexParser {
trait IndexLister {
def getExisting(key: String,
serviceName: String,
apiVersion: String = "2017-11-11"): Seq[String] = {
apiVersion: String = DefaultAPIVersion): Seq[String] = {
val indexListRequest = new HttpGet(
s"https://$serviceName.search.windows.net/indexes?api-version=$apiVersion&$$select=name"
)
@@ -46,7 +48,7 @@ object SearchIndex extends IndexParser with IndexLister {
def createIfNoneExists(key: String,
serviceName: String,
indexJson: String,
apiVersion: String = "2017-11-11"): Unit = {
apiVersion: String = DefaultAPIVersion): Unit = {
val indexName = parseIndexJson(indexJson).name.get

val existingIndexNames = getExisting(key, serviceName, apiVersion)
@@ -84,7 +86,7 @@ object SearchIndex extends IndexParser with IndexLister {
private def validIndexField(field: IndexField): Try[IndexField] = {
for {
_ <- validName(field.name)
_ <- validType(field.`type`)
_ <- validType(field.`type`, field.fields)
_ <- validSearchable(field.`type`, field.searchable)
_ <- validSortable(field.`type`, field.sortable)
_ <- validFacetable(field.`type`, field.facetable)
@@ -100,29 +102,19 @@ object SearchIndex extends IndexParser with IndexLister {
Try(fields.map(f => validIndexField(f).get))
}

private val ValidFieldTypes = Seq("Edm.String",
"Collection(Edm.String)",
"Edm.Int32",
"Edm.Int64",
"Edm.Double",
"Edm.Boolean",
"Edm.DateTimeOffset",
"Edm.GeographyPoint")

private def validName(n: String): Try[String] = {
if (n.isEmpty) {
Failure(new IllegalArgumentException("Empty name"))
} else Success(n)
}

private def validType(t: String): Try[String] = {
if (ValidFieldTypes.contains(t)) {
Success(t)
} else Failure(new IllegalArgumentException("Invalid field type"))
private def validType(t: String, fields: Option[Seq[IndexField]]): Try[String] = {
val tdt = Try(AzureSearchWriter.edmTypeToSparkType(t,fields))
tdt.map(_ => t)
}

private def validSearchable(t: String, s: Option[Boolean]): Try[Option[Boolean]] = {
if (Seq("Edm.String", "Collection(Edm.String)").contains(t)) {
if (Set("Edm.String", "Collection(Edm.String)")(t)) {
Success(s)
} else if (s.contains(true)) {
Failure(new IllegalArgumentException("Only Edm.String and Collection(Edm.String) fields can be searchable"))
@@ -193,7 +185,7 @@ object SearchIndex extends IndexParser with IndexLister {
def getStatistics(indexName: String,
key: String,
serviceName: String,
apiVersion: String = "2017-11-11"): (Int, Int) = {
apiVersion: String = DefaultAPIVersion): (Int, Int) = {
val getStatsRequest = new HttpGet(
s"https://$serviceName.search.windows.net/indexes/$indexName/stats?api-version=$apiVersion")
getStatsRequest.setHeader("api-key", key)
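For reference, getStatistics wraps the stats endpoint shown above; a rough Python equivalent of the request it issues (the response keys are an assumption based on the Azure Search stats API):

```python
import requests  # third-party HTTP client, assumed available

service_name, index_name = "<search-service>", "test-index"  # placeholders
resp = requests.get(
    f"https://{service_name}.search.windows.net"
    f"/indexes/{index_name}/stats?api-version=2019-05-06",
    headers={"api-key": "<azure-search-admin-key>"},  # placeholder
)
resp.raise_for_status()
stats = resp.json()  # assumed shape: {"documentCount": ..., "storageSize": ...}
```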
