[SPARK-47385] Fix tuple encoders with Option inputs
apache#40755 adds a null check on the input of the child deserializer in the tuple encoder. This breaks the deserializer for the `Option` type, because a null input should be deserialized into `None` rather than null. This PR adds a boolean parameter to `ExpressionEncoder.tuple` so that only the call site that apache#40755 intended to fix (`Dataset.joinWith`) enables the null check.
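
For illustration, a minimal standalone sketch of the affected pattern, closely mirroring the new `DatasetSuite` test (the local `SparkSession` setup and the field name of `SingleData` are assumptions made for this sketch):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}

// Mirrors the single-field case class used by the new DatasetSuite test; field name assumed.
case class SingleData(id: Int)

object Spark47385Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-47385").getOrCreate()

    // Tuple encoder whose second component is an Option of a product type.
    implicit val enc: Encoder[(SingleData, Option[SingleData])] =
      Encoders.tuple(Encoders.product[SingleData], Encoders.product[Option[SingleData]])

    val input = Seq(
      (SingleData(1), Some(SingleData(1))),
      (SingleData(2), None)
    )

    // With the unconditional null check from apache#40755, the None row came back
    // as null instead of None; with this change the data round-trips unchanged.
    val ds = spark.createDataFrame(input).as[(SingleData, Option[SingleData])]
    ds.collect().foreach(println)

    spark.stop()
  }
}
```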

Tested by a new unit test in `DatasetSuite`.

Closes apache#45508 from chenhao-db/SPARK-47385.

Authored-by: Chenhao Li <chenhao.li@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit 9986462)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
chenhao-db authored and cloud-fan committed Mar 14, 2024
1 parent e98872f commit 45ba922
Showing 3 changed files with 23 additions and 3 deletions.
10 changes: 8 additions & 2 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -79,8 +79,14 @@ object ExpressionEncoder {
    * Given a set of N encoders, constructs a new encoder that produce objects as items in an
    * N-tuple. Note that these encoders should be unresolved so that information about
    * name/positional binding is preserved.
+   * When `useNullSafeDeserializer` is true, the deserialization result for a child will be null if
+   * the input is null. It is false by default as most deserializers handle null input properly and
+   * don't require an extra null check. Some of them are null-tolerant, such as the deserializer for
+   * `Option[T]`, and we must not set it to true in this case.
    */
-  def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
+  def tuple(
+      encoders: Seq[ExpressionEncoder[_]],
+      useNullSafeDeserializer: Boolean = false): ExpressionEncoder[_] = {
     if (encoders.length > 22) {
       throw QueryExecutionErrors.elementsOfTupleExceedLimitError()
     }
@@ -125,7 +131,7 @@ object ExpressionEncoder {
         case GetColumnByOrdinal(0, _) => input
       }
 
-      if (enc.objSerializer.nullable) {
+      if (useNullSafeDeserializer && enc.objSerializer.nullable) {
         nullSafe(input, childDeserializer)
       } else {
         childDeserializer
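
For readers skimming the hunk above: the `nullSafe` helper guards a child deserializer with an is-null check on its input. A rough Catalyst-expression sketch of the idea (names and shape are approximations, not the committed helper):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, If, IsNull, Literal}

// Sketch only: if the serialized input column is null, produce a null object
// instead of running the child deserializer on a struct full of null fields.
def nullSafeSketch(input: Expression, childDeserializer: Expression): Expression =
  If(IsNull(input), Literal.create(null, childDeserializer.dataType), childDeserializer)
```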
4 changes: 3 additions & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1173,7 +1173,9 @@ class Dataset[T] private[sql](
     }
 
     implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
+      ExpressionEncoder
+        .tuple(Seq(this.exprEnc, other.exprEnc), useNullSafeDeserializer = true)
+        .asInstanceOf[Encoder[(T, U)]]
 
     val leftResultExpr = {
       if (!this.exprEnc.isSerializedAsStructForTopLevel) {
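
This call site lives in `Dataset.joinWith`, where an outer join can legitimately produce a null struct for the unmatched side, so the null-safe behaviour introduced by apache#40755 is still wanted here. A REPL-style usage sketch, assuming an active `SparkSession` named `spark` and hypothetical `Person`/`Order` case classes declared at the top level:

```scala
// Hypothetical case classes, declared top-level so Spark can derive their encoders.
case class Person(id: Int, name: String)
case class Order(personId: Int, amount: Double)

import spark.implicits._

val people = Seq(Person(1, "a"), Person(2, "b")).toDS()
val orders = Seq(Order(1, 10.0)).toDS()

// Person 2 has no matching order, so the right side of its result tuple is a null
// struct; joinWith keeps useNullSafeDeserializer = true so it deserializes to null
// rather than an Order with null fields.
val joined = people.joinWith(orders, people("id") === orders("personId"), "left_outer")
joined.show(truncate = false)
```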
12 changes: 12 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -2197,6 +2197,18 @@ class DatasetSuite extends QueryTest
     )
     assert(result == expected)
   }
+
+  test("SPARK-47385: Tuple encoder with Option inputs") {
+    implicit val enc: Encoder[(SingleData, Option[SingleData])] =
+      Encoders.tuple(Encoders.product[SingleData], Encoders.product[Option[SingleData]])
+
+    val input = Seq(
+      (SingleData(1), Some(SingleData(1))),
+      (SingleData(2), None)
+    )
+    val ds = spark.createDataFrame(input).as[(SingleData, Option[SingleData])]
+    checkDataset(ds, input: _*)
+  }
 }
 
 case class Bar(a: Int)
