Skip to content
Permalink
Browse files
feat(spanner/spansql): support CAST and SAFE_CAST (#5057)
* feat(spanner/spansql): support CAST and SAFE_CAST

Adds support for the CAST and SAFE_CAST functions. This change also includes
a small refactor of the function evaluation to allow the function evaluation to
receive both parsing errors and type information for the arguments. This makes
it easier to implement the SAFE versions of the functions.

* fix: only ignore conversion errors in SAFE_CAST

* fix: outdent else statement

Co-authored-by: rahul2393 <rahulyadavsep92@gmail.com>
Co-authored-by: Hengfeng Li <hengfeng@google.com>
  • Loading branch information
3 people committed Nov 3, 2021
1 parent cbd5c8c commit 54cbf4c0a0305e680b213f84487110dfeaf8e7e1
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 15 deletions.
@@ -385,16 +385,19 @@ func (ec evalContext) evalArithOp(e spansql.ArithOp) (interface{}, error) {
}

func (ec evalContext) evalFunc(e spansql.Func) (interface{}, spansql.Type, error) {
var err error
if f, ok := functions[e.Name]; ok {
args := make([]interface{}, len(e.Args))
types := make([]spansql.Type, len(e.Args))
for i, arg := range e.Args {
val, err := ec.evalExpr(arg)
if err != nil {
if args[i], err = ec.evalExpr(arg); err != nil {
return nil, spansql.Type{}, err
}
args[i] = val
if te, ok := arg.(spansql.TypedExpr); ok {
types[i] = te.Type
}
}
return f.Eval(args)
return f.Eval(args, types)
}
return nil, spansql.Type{}, status.Errorf(codes.Unimplemented, "function %q is not implemented", e.Name)
}
@@ -464,6 +467,8 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
return bool(e), nil
case spansql.Paren:
return ec.evalExpr(e.Expr)
case spansql.TypedExpr:
return ec.evalTypedExpr(e)
case spansql.Func:
v, _, err := ec.evalFunc(e)
if err != nil {
@@ -662,6 +667,14 @@ func (ec evalContext) coerceString(target spansql.Expr, slit spansql.StringLiter
return nil, fmt.Errorf("unable to coerce string literal %q to match %v", slit, ci.Type)
}

func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{}, err error) {
val, err := ec.evalExpr(expr.Expr)
if err != nil {
return nil, err
}
return convert(val, expr.Type)
}

func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) {
switch v := lop.(type) {
case spansql.IntegerLiteral:
@@ -19,8 +19,11 @@ package spannertest
import (
"fmt"
"math"
"strconv"
"strings"
"time"

"cloud.google.com/go/civil"
"cloud.google.com/go/spanner/spansql"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
@@ -30,12 +33,21 @@ import (

type function struct {
// Eval evaluates the result of the function using the given input.
Eval func(values []interface{}) (interface{}, spansql.Type, error)
Eval func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error)
}

func firstErr(errors []error) error {
for _, err := range errors {
if err != nil {
return err
}
}
return nil
}

var functions = map[string]function{
"STARTS_WITH": {
Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
// TODO: Refine error messages to exactly match Spanner.
// Check input values first.
if len(values) != 2 {
@@ -53,7 +65,7 @@ var functions = map[string]function{
},
},
"LOWER": {
Eval: func(values []interface{}) (interface{}, spansql.Type, error) {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
if len(values) != 1 {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function LOWER for the given argument types")
}
@@ -66,6 +78,159 @@ var functions = map[string]function{
return strings.ToLower(values[0].(string)), spansql.Type{Base: spansql.String}, nil
},
},
"CAST": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
return cast(values, types, false)
},
},
"SAFE_CAST": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
return cast(values, types, true)
},
},
}

func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) {
name := "CAST"
if safe {
name = "SAFE_CAST"
}
if len(types) != 1 {
return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No type information for function %s for the given arguments", name)
}
if len(values) != 1 {
return nil, spansql.Type{}, status.Errorf(codes.InvalidArgument, "No matching signature for function %s for the given arguments", name)
}
// If the input type is an error, then the conversion itself failed.
if err, ok := values[0].(error); ok {
if safe {
return nil, types[0], nil
}
return nil, types[0], err
}
return values[0], types[0], nil
}

func convert(val interface{}, tp spansql.Type) (interface{}, error) {
// TODO: Implement more conversions.
if tp.Array {
return nil, status.Errorf(codes.Unimplemented, "conversion to ARRAY types is not implemented")
}
var res interface{}
var convertErr, err error
switch tp.Base {
case spansql.Int64:
res, convertErr, err = convertToInt64(val)
case spansql.Float64:
res, convertErr, err = convertToFloat64(val)
case spansql.String:
res, convertErr, err = convertToString(val)
case spansql.Bool:
res, convertErr, err = convertToBool(val)
case spansql.Date:
res, convertErr, err = convertToDate(val)
case spansql.Timestamp:
res, convertErr, err = convertToTimestamp(val)
case spansql.Numeric:
case spansql.JSON:
}
if err != nil {
return nil, err
}
if convertErr != nil {
res = convertErr
}
if res != nil {
return res, nil
}

return nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to %v", val, tp.Base.SQL())
}

func convertToInt64(val interface{}) (res int64, convertErr error, err error) {
switch v := val.(type) {
case int64:
return v, nil, nil
case string:
res, err := strconv.ParseInt(v, 10, 64)
if err != nil {
return 0, status.Errorf(codes.InvalidArgument, "invalid value for INT64: %q", v), nil
}
return res, nil, nil
}
return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to INT64", val)
}

func convertToFloat64(val interface{}) (res float64, convertErr error, err error) {
switch v := val.(type) {
case int64:
return float64(v), nil, nil
case float64:
return v, nil, nil
case string:
res, err := strconv.ParseFloat(v, 64)
if err != nil {
return 0, status.Errorf(codes.InvalidArgument, "invalid value for FLOAT64: %q", v), nil
}
return res, nil, nil
}
return 0, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to FLOAT64", val)
}

func convertToString(val interface{}) (res string, convertErr error, err error) {
switch v := val.(type) {
case string:
return v, nil, nil
case bool, int64, float64:
return fmt.Sprintf("%v", v), nil, nil
case civil.Date:
return v.String(), nil, nil
case time.Time:
return v.UTC().Format(time.RFC3339Nano), nil, nil
}
return "", nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to STRING", val)
}

func convertToBool(val interface{}) (res bool, convertErr error, err error) {
switch v := val.(type) {
case bool:
return v, nil, nil
case string:
res, err := strconv.ParseBool(v)
if err != nil {
return false, status.Errorf(codes.InvalidArgument, "invalid value for BOOL: %q", v), nil
}
return res, nil, nil
}
return false, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to BOOL", val)
}

func convertToDate(val interface{}) (res civil.Date, convertErr error, err error) {
switch v := val.(type) {
case civil.Date:
return v, nil, nil
case string:
res, err := civil.ParseDate(v)
if err != nil {
return civil.Date{}, status.Errorf(codes.InvalidArgument, "invalid value for DATE: %q", v), nil
}
return res, nil, nil
}
return civil.Date{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to DATE", val)
}

func convertToTimestamp(val interface{}) (res time.Time, convertErr error, err error) {
switch v := val.(type) {
case time.Time:
return v, nil, nil
case string:
res, err := time.Parse(time.RFC3339Nano, v)
if err != nil {
return time.Time{}, status.Errorf(codes.InvalidArgument, "invalid value for TIMESTAMP: %q", v), nil
}
return res, nil, nil
}
return time.Time{}, nil, status.Errorf(codes.Unimplemented, "unsupported conversion for %v to TIMESTAMP", val)
}

type aggregateFunc struct {
@@ -741,16 +741,23 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
t.Errorf("Updating with DML affected %d rows, want 3", n)
}

rows := client.Single().Query(ctx, spanner.NewStatement("SELECT CAST('Foo' AS INT64)"))
_, err = rows.Next()
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch for invalid CAST\n Got: %v\nWant: %v", g, w)
}
rows.Stop()

// Do some complex queries.
tests := []struct {
q string
params map[string]interface{}
want [][]interface{}
}{
{
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B')`,
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`,
nil,
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true}},
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}},
},
// Check handling of NULL values for the IS operator.
// There was a bug that returned errors for some of these cases.
@@ -128,11 +128,15 @@ var keywords = map[string]bool{
// funcs is the set of reserved keywords that are functions.
// https://cloud.google.com/spanner/docs/functions-and-operators
var funcs = make(map[string]bool)
var funcArgParsers = make(map[string]func(*parser) (Expr, *parseError))

func init() {
for _, f := range allFuncs {
funcs[f] = true
}
// Special case for CAST and SAFE_CAST
funcArgParsers["CAST"] = typedArgParser
funcArgParsers["SAFE_CAST"] = typedArgParser
}

var allFuncs = []string{
@@ -148,6 +152,10 @@ var allFuncs = []string{
"MIN",
"SUM",

// Cast functions.
"CAST",
"SAFE_CAST",

// Mathematical functions.
"ABS",

@@ -1893,8 +1893,16 @@ var baseTypes = map[string]TypeBase{
"JSON": JSON,
}

func (p *parser) parseBaseType() (Type, *parseError) {
return p.parseBaseOrParameterizedType(false)
}

func (p *parser) parseType() (Type, *parseError) {
debugf("parseType: %v", p)
return p.parseBaseOrParameterizedType(true)
}

func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) {
debugf("parseBaseOrParameterizedType: %v", p)

/*
array_type:
@@ -1928,7 +1936,7 @@ func (p *parser) parseType() (Type, *parseError) {
}
t.Base = base

if t.Base == String || t.Base == Bytes {
if withParam && (t.Base == String || t.Base == Bytes) {
if err := p.expect("("); err != nil {
return Type{}, err
}
@@ -2436,9 +2444,15 @@ func (p *parser) parseExprList() ([]Expr, *parseError) {
}

func (p *parser) parseParenExprList() ([]Expr, *parseError) {
return p.parseParenExprListWithParseFunc(func(p *parser) (Expr, *parseError) {
return p.parseExpr()
})
}

func (p *parser) parseParenExprListWithParseFunc(f func(*parser) (Expr, *parseError)) ([]Expr, *parseError) {
var list []Expr
err := p.parseCommaList("(", ")", func(p *parser) *parseError {
e, err := p.parseExpr()
e, err := f(p)
if err != nil {
return err
}
@@ -2448,6 +2462,26 @@ func (p *parser) parseParenExprList() ([]Expr, *parseError) {
return list, err
}

// Special argument parser for CAST and SAFE_CAST
var typedArgParser = func(p *parser) (Expr, *parseError) {
e, err := p.parseExpr()
if err != nil {
return nil, err
}
if err := p.expect("AS"); err != nil {
return nil, err
}
// typename in cast function must not be parameterized types
toType, err := p.parseBaseType()
if err != nil {
return nil, err
}
return TypedExpr{
Expr: e,
Type: toType,
}, nil
}

/*
Expressions
@@ -2800,7 +2834,13 @@ func (p *parser) parseLit() (Expr, *parseError) {
// this is a function invocation.
// The `funcs` map is keyed by upper case strings.
if name := strings.ToUpper(tok.value); funcs[name] && p.sniff("(") {
list, err := p.parseParenExprList()
var list []Expr
var err *parseError
if f, ok := funcArgParsers[name]; ok {
list, err = p.parseParenExprListWithParseFunc(f)
} else {
list, err = p.parseParenExprList()
}
if err != nil {
return nil, err
}

0 comments on commit 54cbf4c

Please sign in to comment.