Skip to content
Permalink
Browse files
fix(spanner): invalid numeric should throw an error (#3926)
* fix(spanner): invalid numeric should throw an error

* Add round and error modes.

* Update comments.

* Fix comments.

* Fix comments.
  • Loading branch information
hengfengli committed Aug 23, 2021
1 parent 68b8eb8 commit cde8697be01f1ef57806275c0ddf54f87bb9a571
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 0 deletions.
@@ -24,6 +24,7 @@ import (
"math/big"
"reflect"
"strconv"
"strings"
"time"

"cloud.google.com/go/civil"
@@ -49,13 +50,56 @@ const (
NumericScaleDigits = 9
)

// LossOfPrecisionHandlingOption describes the option to deal with loss of
// precision on numeric values.
type LossOfPrecisionHandlingOption int

const (
// NumericRound automatically rounds a numeric value that has a higher
// precision than what is supported by Spanner, e.g., 0.1234567895 rounds
// to 0.123456790.
NumericRound LossOfPrecisionHandlingOption = iota
// NumericError returns an error for numeric values that have a higher
// precision than what is supported by Spanner. E.g. the client returns an
// error if the application tries to insert the value 0.1234567895.
NumericError
)

// LossOfPrecisionHandling configures how to deal with loss of precision on
// numeric values. The value of this configuration is global and will be used
// for all Spanner clients.
var LossOfPrecisionHandling LossOfPrecisionHandlingOption

// NumericString returns a string representing a *big.Rat in a format compatible
// with Spanner SQL. It returns a floating-point literal with 9 digits after the
// decimal point.
func NumericString(r *big.Rat) string {
return r.FloatString(NumericScaleDigits)
}

// validateNumeric returns nil if there are no errors. It will return an error
// when the numeric number is not valid.
func validateNumeric(r *big.Rat) error {
if r == nil {
return nil
}
// Add one more digit to the scale component to find out if there are more
// digits than required.
strRep := r.FloatString(NumericScaleDigits + 1)
strRep = strings.TrimRight(strRep, "0")
strRep = strings.TrimLeft(strRep, "-")
s := strings.Split(strRep, ".")
whole := s[0]
scale := s[1]
if len(scale) > NumericScaleDigits {
return fmt.Errorf("max scale for a numeric is %d. The requested numeric has more", NumericScaleDigits)
}
if len(whole) > NumericPrecisionDigits-NumericScaleDigits {
return fmt.Errorf("max precision for the whole component of a numeric is %d. The requested numeric has a whole component with precision %d", NumericPrecisionDigits-NumericScaleDigits, len(whole))
}
return nil
}

var (
// CommitTimestamp is a special value used to tell Cloud Spanner to insert
// the commit timestamp of the transaction into a column. It can be used in
@@ -2671,6 +2715,15 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) {
}
pt = listType(floatType())
case big.Rat:
switch LossOfPrecisionHandling {
case NumericError:
err = validateNumeric(&v)
if err != nil {
return nil, nil, err
}
case NumericRound:
// pass
}
pb.Kind = stringKind(NumericString(&v))
pt = numericType()
case []big.Rat:
@@ -2695,6 +2748,15 @@ func encodeValue(v interface{}) (*proto3.Value, *sppb.Type, error) {
}
pt = listType(numericType())
case *big.Rat:
switch LossOfPrecisionHandling {
case NumericError:
err = validateNumeric(v)
if err != nil {
return nil, nil, err
}
case NumericRound:
// pass
}
if v != nil {
pb.Kind = stringKind(NumericString(v))
}
@@ -228,6 +228,8 @@ func TestEncodeValue(t *testing.T) {
numValuePtr := big.NewRat(12345, 1e3)
var numNilPtr *big.Rat
num2ValuePtr := big.NewRat(12345, 1e4)
maxNumValuePtr, _ := (&big.Rat{}).SetString("99999999999999999999999999999.999999999")
minNumValuePtr, _ := (&big.Rat{}).SetString("-99999999999999999999999999999.999999999")

var (
tString = stringType()
@@ -296,6 +298,8 @@ func TestEncodeValue(t *testing.T) {
// NUMERIC / NUMERIC ARRAY
{*numValuePtr, numericProto(numValuePtr), tNumeric, "big.Rat"},
{numValuePtr, numericProto(numValuePtr), tNumeric, "*big.Rat"},
{maxNumValuePtr, numericProto(maxNumValuePtr), tNumeric, "max numeric"},
{minNumValuePtr, numericProto(minNumValuePtr), tNumeric, "min numeric"},
{numNilPtr, nullProto(), tNumeric, "*big.Rat with null"},
{NullNumeric{*numValuePtr, true}, numericProto(numValuePtr), tNumeric, "NullNumeric with value"},
{NullNumeric{*numValuePtr, false}, nullProto(), tNumeric, "NullNumeric with null"},
@@ -429,6 +433,40 @@ func TestEncodeValue(t *testing.T) {
}
}

// Test encoding invalid values.
func TestEncodeInvalidValues(t *testing.T) {
type CustomNumeric big.Rat

invalidNumPtr1 := big.NewRat(11234567891, 1e10)
invalidNumPtr2, _ := (&big.Rat{}).SetString("199999999999999999999999999999.999999999")

// Enable error mode.
LossOfPrecisionHandling = NumericError

for i, test := range []struct {
desc string
in interface{}
errMsg string
}{
// NUMERIC
{desc: "numeric pointer with invalid scale component", in: invalidNumPtr1, errMsg: "max scale for a numeric is 9. The requested numeric has more"},
{desc: "numeric pointer with invalid whole component", in: invalidNumPtr2, errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"},
{desc: "numeric with invalid scale component", in: *invalidNumPtr1, errMsg: "max scale for a numeric is 9. The requested numeric has more"},
{desc: "numeric with invalid whole component", in: *invalidNumPtr2, errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"},
// CUSTOM NUMERIC
{desc: "custom numeric type with invalid scale component", in: CustomNumeric(*invalidNumPtr1), errMsg: "max scale for a numeric is 9. The requested numeric has more"},
{desc: "custom numeric type with invalid whole component", in: CustomNumeric(*invalidNumPtr2), errMsg: "max precision for the whole component of a numeric is 29. The requested numeric has a whole component with precision 30"},
} {
_, _, err := encodeValue(test.in)
if err == nil {
t.Fatalf("#%d (%s): want error during encoding, but got nil", i, test.desc)
}
if err.Error() != test.errMsg {
t.Errorf("#%d (%s): incorrect error message, got %v, want %v", i, test.desc, err, test.errMsg)
}
}
}

type encodeTest struct {
desc string
in interface{}

0 comments on commit cde8697

Please sign in to comment.