[SPARK-43427][PROTOBUF] spark protobuf: allow upcasting unsigned integer types

### What changes were proposed in this pull request?

JIRA: https://issues.apache.org/jira/browse/SPARK-43427

Protobuf supports unsigned integer types, including uint32 and uint64. When deserializing protobuf values with fields of these types, `from_protobuf` currently maps them to the following Spark types:

```
uint32 => IntegerType
uint64 => LongType
```

IntegerType and LongType are [signed](https://spark.apache.org/docs/latest/sql-ref-datatypes.html) integer types, so this can lead to confusing results. If a uint32 value in a stored proto is above 2^31, or a uint64 value is above 2^63, its binary representation has a 1 in the highest bit, so reading it as a signed integer yields a negative number (i.e. overflow). No information is lost, since `IntegerType` and `LongType` contain 32 and 64 bits respectively, but the representation can be confusing.
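The reinterpretation can be reproduced outside Spark. A minimal Python sketch (not part of this PR) packs an out-of-range unsigned value and unpacks the same bytes as a signed integer:

```python
import struct

# Reinterpret an out-of-range uint32 bit pattern as a signed 32-bit int.
u32 = 2**31 + 5                                     # sign bit is set
s32 = struct.unpack('<i', struct.pack('<I', u32))[0]
print(s32)                                          # -2147483643

# Same effect for a uint64 bit pattern read back as a signed 64-bit long.
u64 = 2**63 + 1
s64 = struct.unpack('<q', struct.pack('<Q', u64))[0]
print(s64)                                          # -9223372036854775807
```

The bytes are identical in both cases; only the interpretation of the high bit changes, which is exactly what happens when a uint64 lands in a `LongType` column.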

In this PR, we add an option (`upcast.unsigned.ints`) to allow upcasting unsigned integer types into a larger integer type that can represent them natively, i.e.

```
uint32 => LongType
uint64 => Decimal(20, 0)
```
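The PR's deserializer performs these conversions with Java's `Integer.toUnsignedLong` and `java.lang.Long.toUnsignedString`. As a rough Python illustration (the helper names here are mine, not from the PR), recovering the unsigned value from the signed representation is just a widening mask, and `Decimal(20, 0)` is wide enough because the largest uint64 has exactly 20 decimal digits:

```python
def uint32_from_signed(x: int) -> int:
    # mirrors Java's Integer.toUnsignedLong: widen, then mask to 32 bits
    return x & 0xFFFFFFFF

def uint64_from_signed(x: int) -> int:
    # mirrors the value produced by java.lang.Long.toUnsignedString
    return x & 0xFFFFFFFFFFFFFFFF

print(uint32_from_signed(-2147483643))          # 2147483653 == 2**31 + 5
print(uint64_from_signed(-9223372036854775807)) # 9223372036854775809 == 2**63 + 1

# Decimal(20, 0) headroom check: uint64 max has exactly 20 digits.
assert len(str(2**64 - 1)) == 20
```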

This behavior is gated behind an option so that existing clients are unaffected.

**Example of current behavior**

Consider a protobuf message like:

```
syntax = "proto3";

message Test {
  uint64 val = 1;
}
```

If we compile the above and then generate a message with a value for `val` above 2^63:

```
import test_pb2

s = test_pb2.Test()
s.val = 9223372036854775809 # 2**63 + 1
serialized = s.SerializeToString()
print(serialized)
```

This generates the binary representation:

```
b'\x08\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01'
```
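Those bytes are a field-1 varint tag followed by a base-128 varint payload. A small decoder (a hypothetical helper, not from the PR) confirms they encode 2^63 + 1:

```python
def decode_varint(data: bytes, pos: int = 0):
    """Decode a protobuf base-128 varint starting at pos."""
    result = shift = 0
    while True:
        b = data[pos]
        pos += 1
        result |= (b & 0x7F) << shift
        if not b & 0x80:          # high bit clear: last byte of the varint
            return result, pos
        shift += 7

raw = b'\x08\x81\x80\x80\x80\x80\x80\x80\x80\x80\x01'
tag, pos = decode_varint(raw)
print(tag >> 3, tag & 7)          # field 1, wire type 0 (varint)
value, _ = decode_varint(raw, pos)
print(value)                      # 9223372036854775809 == 2**63 + 1
```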

Then, deserializing this using `from_protobuf`, we can see that it is represented as a negative number. I did this in a notebook so it's easier to see, but it could be reproduced in a Scala test as well:

![image](https://github.com/apache/spark/assets/1002986/7144e6a9-3f43-455e-94c3-9065ae88206e)

**Precedent**
I believe that unsigned integer types in Parquet are deserialized in a similar manner, i.e. put into a larger type so that the unsigned representation fits natively (https://issues.apache.org/jira/browse/SPARK-34817 and apache#31921), so an option providing similar behavior here would be useful.

### Why are the changes needed?
Improve unsigned integer deserialization behavior.

### Does this PR introduce any user-facing change?
Yes, adds a new option.

### How was this patch tested?
Unit Testing

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#43773 from justaparth/parth/43427-add-option-to-expand-unsigned-integers.

Authored-by: Parth Upadhyay <parth.upadhyay@gmail.com>
Signed-off-by: Hyukjin Kwon <gurwls223@apache.org>
justaparth authored and HyukjinKwon committed Dec 12, 2023
1 parent e434c9f commit 2a49fee
Showing 5 changed files with 96 additions and 3 deletions.
@@ -193,6 +193,11 @@ private[sql] class ProtobufDeserializer(

```scala
      case (INT, ShortType) =>
        (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short])

      case (INT, LongType) =>
        (updater, ordinal, value) =>
          updater.setLong(
            ordinal,
            Integer.toUnsignedLong(value.asInstanceOf[Int]))

      case (
          MESSAGE | BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING,
          ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated =>
```

@@ -201,6 +206,13 @@

```scala
      case (LONG, LongType) =>
        (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long])

      case (LONG, DecimalType.LongDecimal) =>
        (updater, ordinal, value) =>
          updater.setDecimal(
            ordinal,
            Decimal.fromString(
              UTF8String.fromString(
                java.lang.Long.toUnsignedString(value.asInstanceOf[Long]))))

      case (FLOAT, FloatType) =>
        (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float])
```
@@ -18,7 +18,7 @@ package org.apache.spark.sql.protobuf

```scala
import scala.jdk.CollectionConverters._

import com.google.protobuf.{Duration, DynamicMessage, Timestamp, WireFormat}
import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._
```

@@ -91,8 +91,17 @@ private[sql] class ProtobufSerializer(

```scala
        (getter, ordinal) => {
          getter.getInt(ordinal)
        }
      case (LongType, INT) if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT32 =>
        (getter, ordinal) => {
          getter.getLong(ordinal).toInt
        }
      case (LongType, LONG) =>
        (getter, ordinal) => getter.getLong(ordinal)
      case (DecimalType(), LONG)
          if fieldDescriptor.getLiteType == WireFormat.FieldType.UINT64 =>
        (getter, ordinal) => {
          getter.getDecimal(ordinal, 20, 0).toUnscaledLong
        }
      case (FloatType, FLOAT) =>
        (getter, ordinal) => getter.getFloat(ordinal)
      case (DoubleType, DOUBLE) =>
```
@@ -168,6 +168,18 @@ private[sql] class ProtobufOptions(

```scala
  // instead of string, so use caution if changing existing parsing logic.
  val enumsAsInts: Boolean =
    parameters.getOrElse("enums.as.ints", false.toString).toBoolean

  // Protobuf supports the unsigned integer types uint32 and uint64. By default this
  // library will deserialize them as the signed IntegerType and LongType respectively.
  // Very large unsigned values (above 2^31 for uint32 and above 2^63 for uint64)
  // overflow these types and end up represented as negative numbers.
  //
  // Enabling this option will upcast unsigned integers into a larger type,
  // i.e. LongType for uint32 and Decimal(20, 0) for uint64, so their representation
  // can contain large unsigned values without overflow.
  val upcastUnsignedInts: Boolean =
    parameters.getOrElse("upcast.unsigned.ints", false.toString).toBoolean
}

private[sql] object ProtobufOptions {
```
@@ -19,6 +19,7 @@ package org.apache.spark.sql.protobuf.utils

```scala
import scala.jdk.CollectionConverters._

import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor}
import com.google.protobuf.WireFormat

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.internal.Logging
```

@@ -67,9 +68,22 @@ object SchemaConverters extends Logging {

```scala
      existingRecordNames: Map[String, Int],
      protobufOptions: ProtobufOptions): Option[StructField] = {
    import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._

    val dataType = fd.getJavaType match {
      // When the protobuf type is unsigned and upcastUnsignedInts has been set,
      // use a larger type (LongType and Decimal(20, 0) for uint32 and uint64).
      case INT =>
        if (fd.getLiteType == WireFormat.FieldType.UINT32 && protobufOptions.upcastUnsignedInts) {
          Some(LongType)
        } else {
          Some(IntegerType)
        }
      case LONG =>
        if (fd.getLiteType == WireFormat.FieldType.UINT64
            && protobufOptions.upcastUnsignedInts) {
          Some(DecimalType.LongDecimal)
        } else {
          Some(LongType)
        }
      case FLOAT => Some(FloatType)
      case DOUBLE => Some(DoubleType)
      case BOOLEAN => Some(BooleanType)
```
@@ -1600,6 +1600,52 @@ class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Prot

```scala
  test("test unsigned integer types") {
    // Test that we correctly handle unsigned integer parsing.
    // We're using Integer/Long's `MIN_VALUE` as it has a 1 in the sign bit.
    val sample = spark.range(1).select(
      lit(
        SimpleMessage
          .newBuilder()
          .setUint32Value(Integer.MIN_VALUE)
          .setUint64Value(Long.MinValue)
          .build()
          .toByteArray
      ).as("raw_proto"))

    val expectedWithoutFlag = spark.range(1).select(
      lit(Integer.MIN_VALUE).as("uint32_value"),
      lit(Long.MinValue).as("uint64_value")
    )

    val expectedWithFlag = spark.range(1).select(
      lit(Integer.toUnsignedLong(Integer.MIN_VALUE).longValue).as("uint32_value"),
      lit(BigDecimal(java.lang.Long.toUnsignedString(Long.MinValue))).as("uint64_value")
    )

    checkWithFileAndClassName("SimpleMessage") { case (name, descFilePathOpt) =>
      List(
        Map.empty[String, String],
        Map("upcast.unsigned.ints" -> "false")).foreach(opts => {
        checkAnswer(
          sample.select(
            from_protobuf_wrapper($"raw_proto", name, descFilePathOpt, opts).as("proto"))
            .select("proto.uint32_value", "proto.uint64_value"),
          expectedWithoutFlag)
      })

      checkAnswer(
        sample.select(
          from_protobuf_wrapper(
            $"raw_proto",
            name,
            descFilePathOpt,
            Map("upcast.unsigned.ints" -> "true")).as("proto"))
          .select("proto.uint32_value", "proto.uint64_value"),
        expectedWithFlag)
    }
  }

  def testFromProtobufWithOptions(
      df: DataFrame,
```
