Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ import org.apache.spark.ml.util._
import org.apache.spark.ml.{ComplexParamsReadable, NamespaceInjections, PipelineModel}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.ml.functions.vector_to_array
import org.apache.spark.sql.functions.{col, expr, struct, to_json, to_utc_timestamp, date_format, when}
import org.apache.spark.sql.functions.{col, expr, from_json, struct, to_json, to_utc_timestamp,
date_format, when}
import org.apache.spark.sql.streaming.DataStreamWriter
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
Expand Down Expand Up @@ -249,6 +250,57 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging
}
}

/**
 * Rewrites string-typed `Edm.GeographyPoint` columns as GeoJSON structs.
 *
 * Azure AI Search wants spatial values as a GeoJSON object such as
 * `{"type":"Point","coordinates":[lon, lat]}`; a JSON-encoded string is rejected with a
 * `400 Bad Request` because the writer escapes the whole string
 * (see [[https://github.com/microsoft/SynapseML/issues/2420]]).
 *
 * Every '''top-level''' index field whose type is `Edm.GeographyPoint` is looked up in the
 * DataFrame. When the matching column is a `StringType` it is replaced by the result of
 * `from_json` using the schema
 * `StructType(type: StringType, coordinates: ArrayType(DoubleType))`, so that a later
 * `to_json` produces a GeoJSON object. Columns of any other type, and index fields with
 * no matching column, pass through unchanged. GeographyPoint fields nested inside complex
 * types are not converted (the same top-level-only scope as `convertDateTimeToISO8601`).
 *
 * The JSON is parsed in Spark's `FAILFAST` mode, so malformed input raises an exception
 * rather than turning into `null` and being sent to Azure Search.
 *
 * @param df        DataFrame that may contain GeographyPoint columns
 * @param indexJson JSON string describing the index schema
 * @return a DataFrame whose string GeographyPoint columns have become GeoJSON structs
 */
private[ml] def convertGeographyPointToStruct(df: DataFrame, indexJson: String): DataFrame = {
  // Canonical GeoJSON point layout that from_json should produce.
  val pointSchema = StructType(Seq(
    StructField("type", StringType),
    StructField("coordinates", ArrayType(DoubleType))
  ))
  // FAILFAST: raise on malformed JSON instead of yielding null.
  val failFast = Map("mode" -> "FAILFAST")

  val geoPointNames = parseIndexJson(indexJson).fields.collect {
    case field if field.`type` == "Edm.GeographyPoint" => field.name
  }

  geoPointNames.foldLeft(df) { (acc, name) =>
    if (!acc.columns.contains(name)) {
      // The index declares the field but the DataFrame has no such column.
      acc
    } else {
      acc.schema(name).dataType match {
        case StringType =>
          val source = col(name)
          // Null inputs stay null: `when` without `otherwise` yields null for them.
          acc.withColumn(name, when(source.isNotNull, from_json(source, pointSchema, failFast)))
        case _ =>
          // Not a string (e.g. already a struct); checkSchemaParity will validate it.
          acc
      }
    }
  }
}

private def dfToIndexJson(schema: StructType,
indexName: String,
keyCol: String,
Expand Down Expand Up @@ -328,17 +380,18 @@ object AzureSearchWriter extends IndexParser with IndexJsonGetter with SLogging

SearchIndex.createIfNoneExists(subscriptionKey, serviceName, indexJson, apiVersion)
val dateConvertedDF = convertDateTimeToISO8601(preppedDF, indexJson)
val geoConvertedDF = convertGeographyPointToStruct(dateConvertedDF, indexJson)

logInfo("checking schema parity")
checkSchemaParity(dateConvertedDF.schema, indexJson, actionCol)
checkSchemaParity(geoConvertedDF.schema, indexJson, actionCol)

val df1 = if (filterNulls) {
val collectionColumns = parseIndexJson(indexJson).fields
.filter(_.`type`.startsWith("Collection"))
.map(_.name)
collectionColumns.foldLeft(dateConvertedDF) { (ndf, c) => filterOutNulls(ndf, c) }
collectionColumns.foldLeft(geoConvertedDF) { (ndf, c) => filterOutNulls(ndf, c) }
} else {
dateConvertedDF
geoConvertedDF
}

// Convert date/timestamp columns to ISO8601 strings for Azure Search
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,123 @@ class SearchWriterSuite extends SearchWriterSuiteUtilities {

}

test("Handle GeoJSON GeographyPoint fields supplied as strings") {

val in = generateIndexName()
val df = spark.createDataFrame(Seq(
("upload", "0", """{"type":"Point","coordinates":[-122.3493, 47.6205]}"""),
("upload", "1", """{"type":"Point","coordinates":[-122.3351, 47.6080]}""")
)).toDF("searchAction", "id", "location")

val indexJson =
s"""
|{
| "name": "$in",
| "fields": [
| { "name": "id", "type": "Edm.String", "key": true, "searchable": true, "retrievable": true },
| { "name": "location", "type": "Edm.GeographyPoint", "searchable": false,
| "filterable": true, "retrievable": true, "sortable": true }
| ]
|}
|""".stripMargin

AzureSearchWriter.write(df,
Map(
"subscriptionKey" -> azureSearchKey,
"actionCol" -> "searchAction",
"serviceName" -> testServiceName,
"indexJson" -> indexJson
)
)

// With fatalErrors=true (default) any 400 from Azure Search becomes a thrown
// RuntimeException, so reaching this `assertSize` proves the documents were
// accepted as valid spatial objects -- a count of 2 is only achievable if the
// GeoJSON strings were correctly parsed and serialized as GeoJSON objects.
retryWithBackoff(assertSize(in, 2))

Comment on lines +207 to +208
Copy link

Copilot AI Apr 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test only asserts document count after the write. If GeoJSON parsing were to fail and produce null (or otherwise lose the spatial payload), Azure Search could still ingest the documents and this test would pass. To make the test validate the intended behavior, consider fetching the stored documents (or querying/selecting location) and asserting the location field is present and shaped as a GeoJSON object (e.g., has type and coordinates).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in 95a7ae1. Two complementary guarantees now make the count check meaningful:

  1. The new FAILFAST parsing mode raises a SparkException on malformed GeoJSON before the request is even built (covered by the unit test `convertGeographyPointToStruct fails fast on malformed GeoJSON ...`).
  2. AzureSearchWriter.write runs with fatalErrors=true by default, so any 400 from the service throws a RuntimeException and fails the test before assertSize is reached.

Combined with the new unit tests that assert the converted struct shape and parsed coordinates directly, a count of 2 is only achievable if the documents were accepted as valid spatial objects. The repo's existing search tests follow the same assertSize-only pattern and there's no helper for fetching individual documents, so I kept the e2e style consistent.

}

test("convertGeographyPointToStruct parses GeoJSON strings into structs") {
  // Index schema declaring `location` as a top-level GeographyPoint field.
  val indexJson =
    """
      |{
      | "name": "unit-test-geo",
      | "fields": [
      | { "name": "id", "type": "Edm.String", "key": true },
      | { "name": "location", "type": "Edm.GeographyPoint" }
      | ]
      |}
      |""".stripMargin

  // One row with a GeoJSON string and one row with a null location.
  val input = spark.createDataFrame(Seq(
    ("0", """{"type":"Point","coordinates":[-122.3493, 47.6205]}"""),
    ("1", null)
  )).toDF("id", "location")

  val result = AzureSearchWriter.convertGeographyPointToStruct(input, indexJson)

  // The string column must now carry the GeoJSON struct type.
  val geoStruct = StructType(Seq(
    StructField("type", StringType),
    StructField("coordinates", ArrayType(DoubleType))
  ))
  assert(result.schema("location").dataType == geoStruct)

  val sorted = result.orderBy("id").collect()
  val populatedRow = sorted.head
  val nullRow = sorted(1)

  // Row "0": the parsed struct exposes the point type and its coordinates.
  val point = populatedRow.getStruct(populatedRow.fieldIndex("location"))
  assert(point.getString(0) == "Point")
  assert(point.getSeq[Double](1) == Seq(-122.3493, 47.6205))

  // Row "1": a null input stays null after conversion.
  assert(nullRow.isNullAt(nullRow.fieldIndex("location")))
}

test("convertGeographyPointToStruct leaves struct columns untouched") {
  // A location column that is already a struct, with non-nullable members so that any
  // re-parsing into the canonical (nullable) GeoJSON schema would change the data type.
  val locationType = StructType(Seq(
    StructField("type", StringType, nullable = false),
    StructField("coordinates", ArrayType(DoubleType, containsNull = false), nullable = false)
  ))
  val schema = StructType(Seq(
    StructField("id", StringType),
    StructField("location", locationType)
  ))

  val records = Seq(Row("0", Row("Point", Seq(-122.3493, 47.6205))))
  val df = spark.createDataFrame(spark.sparkContext.parallelize(records), schema)

  val indexJson =
    """
      |{
      | "name": "unit-test-geo",
      | "fields": [
      | { "name": "id", "type": "Edm.String", "key": true },
      | { "name": "location", "type": "Edm.GeographyPoint" }
      | ]
      |}
      |""".stripMargin

  val result = AzureSearchWriter.convertGeographyPointToStruct(df, indexJson)

  // The column type after conversion must equal the type we started with.
  assert(result.schema("location").dataType == schema("location").dataType)
}

test("convertGeographyPointToStruct fails fast on malformed GeoJSON instead of silently nulling") {
  val indexJson =
    """
      |{
      | "name": "unit-test-geo",
      | "fields": [
      | { "name": "id", "type": "Edm.String", "key": true },
      | { "name": "location", "type": "Edm.GeographyPoint" }
      | ]
      |}
      |""".stripMargin

  // A single row whose location is not parseable JSON.
  val malformed = spark.createDataFrame(Seq(
    ("0", "{not valid json")
  )).toDF("id", "location")

  // Building the converted DataFrame is expected to succeed; the failure is only
  // expected once rows are materialized.
  val result = AzureSearchWriter.convertGeographyPointToStruct(malformed, indexJson)

  intercept[org.apache.spark.SparkException] {
    result.collect()
  }
}

}
Loading