
Commit 2e4c4b0

Consider document routing when deleting and overwriting data in SparkSQL
fixes #1030
Parent: 81a0893
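
Why routing matters here: documents indexed with a routing value (for example, children of a join field, which are routed to their parent's id) live on a shard chosen by that routing rather than by their _id. A scan-and-delete pass that omits the routing hashes the _id alone and can miss such documents, so an overwrite would leave stale children behind. Below is a minimal sketch of the bulk delete action shapes involved; it is illustrative only, not the shipped code, and the class and method names are invented. Note the version split the commit encodes: 7.x bulk metadata uses "routing", earlier versions "_routing".

// Sketch only: shapes of the bulk delete actions this commit emits.
// A routed document must repeat the routing used at index time, or
// Elasticsearch hashes the _id alone and targets the wrong shard.
public final class RoutedDeleteSketch {
    static String deleteAction(String id, String routing, boolean es7OrLater) {
        if (routing == null || routing.isEmpty()) {
            return String.format("{\"delete\":{\"_id\":\"%s\"}}%n", id);
        }
        // ES 7.x renamed the bulk metadata field from "_routing" to "routing".
        String field = es7OrLater ? "routing" : "_routing";
        return String.format("{\"delete\":{\"_id\":\"%s\", \"%s\":\"%s\"}}%n", id, field, routing);
    }

    public static void main(String[] args) {
        // A join-field child is routed to its parent's id ("1" here).
        System.out.print(deleteAction("10", "1", true));   // {"delete":{"_id":"10", "routing":"1"}}
        System.out.print(deleteAction("10", null, true));  // {"delete":{"_id":"10"}}
    }
}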

5 files changed: 157 additions, 6 deletions


mr/src/main/java/org/elasticsearch/hadoop/rest/RestRepository.java

Lines changed: 21 additions & 4 deletions
@@ -384,18 +384,35 @@ public void delete() {
             sb.append("&search_type=scan");
         }
         String scanQuery = sb.toString();
-        ScrollReader scrollReader = new ScrollReader(new ScrollReaderConfig(new JdkValueReader()));
+        ScrollReaderConfig readerConf = new ScrollReaderConfig(true, new JdkValueReader());
+        ScrollReader scrollReader = new ScrollReader(readerConf);

         // start iterating
         ScrollQuery sq = scanAll(scanQuery, null, scrollReader);
         try {
             BytesArray entry = new BytesArray(0);

-            // delete each retrieved batch
-            String format = "{\"delete\":{\"_id\":\"%s\"}}\n";
+            // delete each retrieved batch, keep routing in mind:
+            String baseFormat = "{\"delete\":{\"_id\":\"%s\"}}\n";
+            String routedFormat;
+            if (client.internalVersion.onOrAfter(EsMajorVersion.V_7_X)) {
+                routedFormat = "{\"delete\":{\"_id\":\"%s\", \"routing\":\"%s\"}}\n";
+            } else {
+                routedFormat = "{\"delete\":{\"_id\":\"%s\", \"_routing\":\"%s\"}}\n";
+            }
             while (sq.hasNext()) {
                 entry.reset();
-                entry.add(StringUtils.toUTF(String.format(format, sq.next()[0])));
+                Object[] kv = sq.next();
+                @SuppressWarnings("unchecked")
+                Map<String, Object> value = (Map<String, Object>) kv[1];
+                @SuppressWarnings("unchecked")
+                Map<String, Object> metadata = (Map<String, Object>) value.get("_metadata");
+                String routing = (String) metadata.get("_routing");
+                if (StringUtils.hasText(routing)) {
+                    entry.add(StringUtils.toUTF(String.format(routedFormat, kv[0], routing)));
+                } else {
+                    entry.add(StringUtils.toUTF(String.format(baseFormat, kv[0])));
+                }
                 writeProcessedToIndex(entry);
             }

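To see what the rewritten loop unpacks, here is a hedged illustration of the [id, value] pairs sq.next() is assumed to yield once the reader is built with ScrollReaderConfig(true, ...): the hit's source map carries an extra "_metadata" entry holding fields such as "_routing". All scaffolding below is invented for illustration.

import java.util.HashMap;
import java.util.Map;

// Hypothetical shape of a scroll hit consumed by the delete loop above.
public final class ScrollHitShape {
    public static void main(String[] args) {
        Map<String, Object> metadata = new HashMap<>();
        metadata.put("_routing", "1");            // present only for routed docs

        Map<String, Object> value = new HashMap<>();
        value.put("name", "kimchy");              // the document source...
        value.put("_metadata", metadata);         // ...plus the injected metadata

        Object[] kv = new Object[] { "10", value };   // what sq.next() yields

        @SuppressWarnings("unchecked")
        Map<String, Object> meta = (Map<String, Object>) ((Map<String, Object>) kv[1]).get("_metadata");
        System.out.println(meta.get("_routing"));     // "1" -> use the routed delete format
    }
}
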
mr/src/main/java/org/elasticsearch/hadoop/serialization/ScrollReader.java

Lines changed: 5 additions & 1 deletion
@@ -190,7 +190,11 @@ public ScrollReaderConfig(ValueReader reader, Mapping resolvedMapping, boolean r
     }

     public ScrollReaderConfig(ValueReader reader) {
-        this(reader, null, false, "_metadata", false, false, Collections.<String> emptyList(), Collections.<String> emptyList(), Collections.<String> emptyList());
+        this(false, reader);
+    }
+
+    public ScrollReaderConfig(boolean readMetadata, ValueReader reader) {
+        this(reader, null, readMetadata, "_metadata", false, false, Collections.<String> emptyList(), Collections.<String> emptyList(), Collections.<String> emptyList());
     }

     public ScrollReaderConfig(ValueReader reader, Mapping resolvedMapping, Settings cfg) {
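
A small usage sketch of the new overload (not part of the commit; import paths are assumed from the file layout above): the one-argument constructor used to hard-code readMetadata to false, so RestRepository.delete() had no way to request a hit's routing; the two-argument form opts in without restating all nine parameters.

// Assumed imports; ScrollReaderConfig is declared inside ScrollReader.java.
import org.elasticsearch.hadoop.serialization.ScrollReader.ScrollReaderConfig;
import org.elasticsearch.hadoop.serialization.builder.JdkValueReader;

public final class ReaderConfigUsage {
    public static void main(String[] args) {
        ScrollReaderConfig plain  = new ScrollReaderConfig(new JdkValueReader());        // no metadata on hits
        ScrollReaderConfig routed = new ScrollReaderConfig(true, new JdkValueReader());  // hits carry "_metadata"
    }
}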

spark/sql-13/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 65 additions & 0 deletions
@@ -1204,6 +1204,71 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertEquals(1, df.count)
   }

+  @Test
+  def testEsDataFrame52OverwriteExistingDataSourceWithJoinField() {
+    // Join added in 6.0.
+    EsAssume.versionOnOrAfter(EsMajorVersion.V_6_X, "Join added in 6.0.")
+
+    // using long-form joiner values
+    val schema = StructType(Seq(
+      StructField("id", StringType, nullable = false),
+      StructField("company", StringType, nullable = true),
+      StructField("name", StringType, nullable = true),
+      StructField("joiner", StructType(Seq(
+        StructField("name", StringType, nullable = false),
+        StructField("parent", StringType, nullable = true)
+      )))
+    ))
+
+    val parents = Seq(
+      Row("1", "Elastic", null, Row("company", null)),
+      Row("2", "Fringe Cafe", null, Row("company", null)),
+      Row("3", "WATIcorp", null, Row("company", null))
+    )
+
+    val firstChildren = Seq(
+      Row("10", null, "kimchy", Row("employee", "1")),
+      Row("20", null, "April Ryan", Row("employee", "2")),
+      Row("21", null, "Charlie", Row("employee", "2")),
+      Row("30", null, "Alvin Peats", Row("employee", "3"))
+    )
+
+    val index = wrapIndex("sparksql-test-scala-overwrite-join")
+    val typename = "join"
+    val target = s"$index/$typename"
+    RestUtils.delete(index)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typename, "data/join/mapping.json")
+
+    sqc.createDataFrame(sc.makeRDD(parents ++ firstChildren), schema)
+      .write
+      .format("es")
+      .options(Map(ES_MAPPING_ID -> "id", ES_MAPPING_JOIN -> "joiner"))
+      .save(target)
+
+    assertThat(RestUtils.get(target + "/10?routing=1"), containsString("kimchy"))
+    assertThat(RestUtils.get(target + "/10?routing=1"), containsString(""""_routing":"1""""))
+
+    // Overwrite the data using a new dataset:
+    val newChildren = Seq(
+      Row("110", null, "costinl", Row("employee", "1")),
+      Row("111", null, "jbaiera", Row("employee", "1")),
+      Row("121", null, "Charlie", Row("employee", "2")),
+      Row("130", null, "Damien", Row("employee", "3"))
+    )
+
+    sqc.createDataFrame(sc.makeRDD(parents ++ newChildren), schema)
+      .write
+      .format("es")
+      .options(cfg ++ Map(ES_MAPPING_ID -> "id", ES_MAPPING_JOIN -> "joiner"))
+      .mode(SaveMode.Overwrite)
+      .save(target)
+
+    assertFalse(RestUtils.exists(target + "/10?routing=1"))
+    assertThat(RestUtils.get(target + "/110?routing=1"), containsString("costinl"))
+    assertThat(RestUtils.get(target + "/110?routing=1"), containsString(""""_routing":"1""""))
+  }
+
   @Test
   def testEsDataFrame53OverwriteExistingDataSourceFromAnotherDataSource() {
     // to keep the select static

spark/sql-20/src/itest/scala/org/elasticsearch/spark/integration/AbstractScalaEsSparkSQL.scala

Lines changed: 65 additions & 0 deletions
@@ -1222,6 +1222,71 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
     assertEquals(1, df.count)
   }

+  @Test
+  def testEsDataFrame52OverwriteExistingDataSourceWithJoinField() {
+    // Join added in 6.0.
+    EsAssume.versionOnOrAfter(EsMajorVersion.V_6_X, "Join added in 6.0.")
+
+    // using long-form joiner values
+    val schema = StructType(Seq(
+      StructField("id", StringType, nullable = false),
+      StructField("company", StringType, nullable = true),
+      StructField("name", StringType, nullable = true),
+      StructField("joiner", StructType(Seq(
+        StructField("name", StringType, nullable = false),
+        StructField("parent", StringType, nullable = true)
+      )))
+    ))
+
+    val parents = Seq(
+      Row("1", "Elastic", null, Row("company", null)),
+      Row("2", "Fringe Cafe", null, Row("company", null)),
+      Row("3", "WATIcorp", null, Row("company", null))
+    )
+
+    val firstChildren = Seq(
+      Row("10", null, "kimchy", Row("employee", "1")),
+      Row("20", null, "April Ryan", Row("employee", "2")),
+      Row("21", null, "Charlie", Row("employee", "2")),
+      Row("30", null, "Alvin Peats", Row("employee", "3"))
+    )
+
+    val index = wrapIndex("sparksql-test-scala-overwrite-join")
+    val typename = "join"
+    val target = s"$index/$typename"
+    RestUtils.delete(index)
+    RestUtils.touch(index)
+    RestUtils.putMapping(index, typename, "data/join/mapping.json")
+
+    sqc.createDataFrame(sc.makeRDD(parents ++ firstChildren), schema)
+      .write
+      .format("es")
+      .options(Map(ES_MAPPING_ID -> "id", ES_MAPPING_JOIN -> "joiner"))
+      .save(target)
+
+    assertThat(RestUtils.get(target + "/10?routing=1"), containsString("kimchy"))
+    assertThat(RestUtils.get(target + "/10?routing=1"), containsString(""""_routing":"1""""))
+
+    // Overwrite the data using a new dataset:
+    val newChildren = Seq(
+      Row("110", null, "costinl", Row("employee", "1")),
+      Row("111", null, "jbaiera", Row("employee", "1")),
+      Row("121", null, "Charlie", Row("employee", "2")),
+      Row("130", null, "Damien", Row("employee", "3"))
+    )
+
+    sqc.createDataFrame(sc.makeRDD(parents ++ newChildren), schema)
+      .write
+      .format("es")
+      .options(cfg ++ Map(ES_MAPPING_ID -> "id", ES_MAPPING_JOIN -> "joiner"))
+      .mode(SaveMode.Overwrite)
+      .save(target)
+
+    assertFalse(RestUtils.exists(target + "/10?routing=1"))
+    assertThat(RestUtils.get(target + "/110?routing=1"), containsString("costinl"))
+    assertThat(RestUtils.get(target + "/110?routing=1"), containsString(""""_routing":"1""""))
+  }
+
   @Test
   def testEsDataFrame53OverwriteExistingDataSourceFromAnotherDataSource() {
     // to keep the select static

spark/sql-20/src/main/scala/org/elasticsearch/spark/sql/DataFrameFieldExtractor.scala

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ class DataFrameFieldExtractor extends ScalaMapFieldExtractor {

     // Return the value or unpack the value if it's a row-schema tuple
     obj match {
-      case (row: Row, struct: StructType) => row
+      case (row: Row, _: StructType) => row
       case any => any
     }
   }
