diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 67f26c1a93968..f2e41845908af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1683,6 +1683,19 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)
 
+  /**
+   * SPARK-38809 - Config option to allow skipping null values for hash-based stream-stream joins.
+   * It is possible to see nulls if the state was written with an older version of Spark,
+   * the state was corrupted on disk, or if there was an issue with the state iterators.
+   */
+  val STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS =
+    buildConf("spark.sql.streaming.stateStore.skipNullsForStreamStreamJoins.enabled")
+      .internal()
+      .doc("When true, this config will skip null values in hash-based stream-stream joins.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val VARIABLE_SUBSTITUTE_ENABLED =
     buildConf("spark.sql.variable.substitute")
       .doc("This enables substitution using syntax like `${var}`, `${system:var}`, " +
@@ -3551,6 +3564,9 @@ class SQLConf extends Serializable with Logging {
 
   def stateStoreFormatValidationEnabled: Boolean = getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
 
+  def stateStoreSkipNullsForStreamStreamJoins: Boolean =
+    getConf(STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS)
+
   def checkpointLocation: Option[String] = getConf(CHECKPOINT_LOCATION)
 
   def isUnsupportedOperationCheckEnabled: Boolean = getConf(UNSUPPORTED_OPERATION_CHECK_ENABLED)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
index 58af8272d1c09..529db2609cd45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala
@@ -52,6 +52,9 @@ class StateStoreConf(
   val formatValidationCheckValue: Boolean =
     extraOptions.getOrElse(StateStoreConf.FORMAT_VALIDATION_CHECK_VALUE_CONFIG, "true") == "true"
 
+  /** Whether to skip null values for hash-based stream-stream joins. */
+  val skipNullsForStreamStreamJoins: Boolean = sqlConf.stateStoreSkipNullsForStreamStreamJoins
+
   /** The compression codec used to compress delta and snapshot files. */
   val compressionCodec: String = sqlConf.stateStoreCompressionCodec
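
Usage sketch (not part of the patch): because the option is internal, there is no dedicated API; it is toggled by its key on an existing session before the stream-stream join query starts. A minimal sketch, assuming a running SparkSession named `spark`:

    // Enable skipping of null state values for stream-stream joins.
    // Defaults to false; intended only for recovering from corrupt or legacy state.
    spark.conf.set(
      "spark.sql.streaming.stateStore.skipNullsForStreamStreamJoins.enabled", "true")
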
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index 56c47d564a3b3..d17c6e8e862ca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -222,8 +222,12 @@ class SymmetricHashJoinStateManager(
       valueRemoved = false
     }
 
-    // Find the next value satisfying the condition, updating `currentKey` and `numValues` if
-    // needed. Returns null when no value can be found.
+    /**
+     * Find the next value satisfying the condition, updating `currentKey` and `numValues` if
+     * needed. Returns null when no value can be found.
+     * Note that null values are skipped explicitly when
+     * STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS is set to true.
+     */
     private def findNextValueForIndex(): ValueAndMatchPair = {
       // Loop across all values for the current key, and then all other keys, until we find a
       // value satisfying the removal condition.
@@ -233,7 +237,9 @@ class SymmetricHashJoinStateManager(
         if (hasMoreValuesForCurrentKey) {
           // First search the values for the current key.
           val valuePair = keyWithIndexToValue.get(currentKey, index)
-          if (removalCondition(valuePair.value)) {
+          if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
+            index += 1
+          } else if (removalCondition(valuePair.value)) {
             return valuePair
           } else {
             index += 1
@@ -597,22 +603,30 @@ class SymmetricHashJoinStateManager(
     /**
      * Get all values and indices for the provided key.
      * Should not return null.
+     * Note that null values are skipped explicitly when
+     * STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS is set to true.
      */
     def getAll(key: UnsafeRow, numValues: Long): Iterator[KeyWithIndexAndValue] = {
-      val keyWithIndexAndValue = new KeyWithIndexAndValue()
-      var index = 0
       new NextIterator[KeyWithIndexAndValue] {
+        private val keyWithIndexAndValue = new KeyWithIndexAndValue()
+        private var index: Long = 0L
+
+        private def hasMoreValues = index < numValues
         override protected def getNext(): KeyWithIndexAndValue = {
-          if (index >= numValues) {
-            finished = true
-            null
-          } else {
+          while (hasMoreValues) {
             val keyWithIndex = keyWithIndexRow(key, index)
             val valuePair = valueRowConverter.convertValue(stateStore.get(keyWithIndex))
-            keyWithIndexAndValue.withNew(key, index, valuePair)
-            index += 1
-            keyWithIndexAndValue
+            if (valuePair == null && storeConf.skipNullsForStreamStreamJoins) {
+              index += 1
+            } else {
+              keyWithIndexAndValue.withNew(key, index, valuePair)
+              index += 1
+              return keyWithIndexAndValue
+            }
           }
+
+          finished = true
+          return null
         }
 
         override protected def close(): Unit = {}
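
Illustration (not part of the patch): findNextValueForIndex and getAll above share the same scan shape, advancing past a null slot when the flag is on and treating the value normally otherwise. A standalone sketch of that shape, where `lookup` is a hypothetical stand-in for `keyWithIndexToValue.get` and `skipNulls` for `storeConf.skipNullsForStreamStreamJoins`:

    def firstMatching[V <: AnyRef](numValues: Long, skipNulls: Boolean)
        (lookup: Long => V)(predicate: V => Boolean): Option[V] = {
      var index = 0L
      while (index < numValues) {
        val value = lookup(index)
        if (value == null && skipNulls) {
          index += 1 // tolerate the hole left by corrupt or legacy state
        } else if (predicate(value)) {
          return Some(value) // a real, matching value
        } else {
          index += 1 // real value, but it does not satisfy the predicate
        }
      }
      None // scanned every claimed slot without a match
    }

With skipNulls = false, a null slot reaches `predicate` and triggers a NullPointerException, which is exactly the pre-patch failure mode the new flag guards against.
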
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
index deeebe1fc42bf..30d39ebcc4a91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManagerSuite.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
 import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
 import org.apache.spark.sql.execution.streaming.StatefulOperatorStateInfo
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.LeftSide
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming.StreamTest
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -52,6 +53,12 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
     }
   }
 
+  SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
+    test(s"StreamingJoinStateManager V${version} - all operations with nulls in middle") {
+      testAllOperationsWithNullsInMiddle(version)
+    }
+  }
+
   SymmetricHashJoinStateManager.supportedVersions.foreach { version =>
     test(s"SPARK-35689: StreamingJoinStateManager V${version} - " +
       "printable key of keyWithIndexToValue") {
@@ -167,6 +174,55 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
     }
   }
 
+  /* Test removeByValue with nulls in the middle, simulated by updating numValues
+   * on the state manager. */
+  private def testAllOperationsWithNullsInMiddle(stateFormatVersion: Int): Unit = {
+    // Test with skipNullsForStreamStreamJoins set to false, which throws a
+    // NullPointerException while iterating and also returns null values as part of get.
+    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion) { manager =>
+      implicit val mgr = manager
+
+      val ex = intercept[Exception] {
+        appendAndTest(40, 50, 200, 300)
+        assert(numRows === 3)
+        updateNumValues(40, 4) // create a null at the end
+        append(40, 400)
+        updateNumValues(40, 7) // create nulls in between and at the end
+        removeByValue(50)
+      }
+      assert(ex.isInstanceOf[NullPointerException])
+      assert(getNumValues(40) === 7) // we should get 7 with no nulls skipped
+
+      removeByValue(300)
+      assert(getNumValues(40) === 1) // only 400 should remain
+      assert(get(40) === Seq(400))
+      removeByValue(400)
+      assert(get(40) === Seq.empty)
+      assert(numRows === 0) // ensure all elements are removed
+    }
+
+    // Test with skipNullsForStreamStreamJoins set to true, which skips nulls
+    // and continues iterating as part of removeByValue as well as get.
+    withJoinStateManager(inputValueAttribs, joinKeyExprs, stateFormatVersion, true) { manager =>
+      implicit val mgr = manager
+
+      appendAndTest(40, 50, 200, 300)
+      assert(numRows === 3)
+      updateNumValues(40, 4) // create a null at the end
+      append(40, 400)
+      updateNumValues(40, 7) // create nulls in between and at the end
+
+      removeByValue(50)
+      assert(getNumValues(40) === 3) // we should now get (400, 200, 300) with nulls skipped
+
+      removeByValue(300)
+      assert(getNumValues(40) === 1) // only 400 should remain
+      assert(get(40) === Seq(400))
+      removeByValue(400)
+      assert(get(40) === Seq.empty)
+      assert(numRows === 0) // ensure all elements are removed
+    }
+  }
+
   val watermarkMetadata = new MetadataBuilder().putLong(EventTimeWatermark.delayKey, 10).build()
   val inputValueSchema = new StructType()
     .add(StructField("time", IntegerType, metadata = watermarkMetadata))
@@ -205,6 +261,11 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
     manager.updateNumValuesTestOnly(toJoinKeyRow(key), numValues)
   }
 
+  def getNumValues(key: Int)
+      (implicit manager: SymmetricHashJoinStateManager): Int = {
+    manager.get(toJoinKeyRow(key)).size
+  }
+
   def get(key: Int)(implicit manager: SymmetricHashJoinStateManager): Seq[Int] = {
     manager.get(toJoinKeyRow(key)).map(toValueInt).toSeq.sorted
   }
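
Why the test works (not part of the patch): updateNumValuesTestOnly rewrites only the key-to-numValues metadata, so the manager believes key 40 has values at indices 0 through 6 while the store holds rows at only four of them, and a read at any phantom index returns null. A hypothetical model of key 40's slots after the updates above, assuming appendAndTest fills indices 0-2 and append(40, 400) lands at index 4:

    val stored = Map(0L -> 50, 1L -> 200, 2L -> 300, 4L -> 400)
    val numValues = 7L // what the metadata now claims
    val slots = (0L until numValues).map(stored.get) // None at indices 3, 5 and 6
    assert(slots.count(_.isEmpty) == 3) // the three null slots to be skipped
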
@@ -232,22 +293,26 @@ class SymmetricHashJoinStateManagerSuite extends StreamTest with BeforeAndAfter
     manager.metrics.numKeys
   }
 
-  def withJoinStateManager(
-      inputValueAttribs: Seq[Attribute],
-      joinKeyExprs: Seq[Expression],
-      stateFormatVersion: Int)(f: SymmetricHashJoinStateManager => Unit): Unit = {
+  def withJoinStateManager(
+      inputValueAttribs: Seq[Attribute],
+      joinKeyExprs: Seq[Expression],
+      stateFormatVersion: Int,
+      skipNullsForStreamStreamJoins: Boolean = false)
+      (f: SymmetricHashJoinStateManager => Unit): Unit = {
     withTempDir { file =>
-      val storeConf = new StateStoreConf()
-      val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
-      val manager = new SymmetricHashJoinStateManager(
-        LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration,
-        partitionId = 0, stateFormatVersion)
-      try {
-        f(manager)
-      } finally {
-        manager.abortIfNeeded()
+      withSQLConf(SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key ->
+        skipNullsForStreamStreamJoins.toString) {
+        val storeConf = new StateStoreConf(spark.sqlContext.conf)
+        val stateInfo = StatefulOperatorStateInfo(file.getAbsolutePath, UUID.randomUUID, 0, 0, 5)
+        val manager = new SymmetricHashJoinStateManager(
+          LeftSide, inputValueAttribs, joinKeyExprs, Some(stateInfo), storeConf, new Configuration,
+          partitionId = 0, stateFormatVersion)
+        try {
+          f(manager)
+        } finally {
+          manager.abortIfNeeded()
+        }
       }
     }
     StateStore.stop()
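
Sanity-check sketch (not part of the patch): the helper above relies on the flag surviving the SQLConf-to-StateStoreConf hop, which can be asserted directly. A sketch, assuming an active SparkSession `spark` such as the one this suite inherits from StreamTest:

    spark.conf.set(SQLConf.STATE_STORE_SKIP_NULLS_FOR_STREAM_STREAM_JOINS.key, "true")
    assert(new StateStoreConf(spark.sessionState.conf).skipNullsForStreamStreamJoins)
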