Skip to content
Permalink
Browse files
feat(spanner/spansql): support EXTRACT (#5218)
* feat(spanner/spansql): support EXTRACT

* added separate Expr for Extract func and added unit and integration tests

* add test for year

* repleace atTimeZone func with atTimeZone expression

* fixing failing tests

* added negative test, reduced the valid extract part values.

* remove extra space

Co-authored-by: Rahul Yadav <irahul@google.com>
  • Loading branch information
rahul2393 and rahul2393 committed Dec 16, 2021
1 parent 2c664a6 commit 81b7c85a8993a36557ea4eb4ec0c47d1f93c4960
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 16 deletions.
@@ -469,6 +469,10 @@ func (ec evalContext) evalExpr(e spansql.Expr) (interface{}, error) {
return ec.evalExpr(e.Expr)
case spansql.TypedExpr:
return ec.evalTypedExpr(e)
case spansql.ExtractExpr:
return ec.evalExtractExpr(e)
case spansql.AtTimeZoneExpr:
return ec.evalAtTimeZoneExpr(e)
case spansql.Func:
v, _, err := ec.evalFunc(e)
if err != nil {
@@ -675,6 +679,61 @@ func (ec evalContext) evalTypedExpr(expr spansql.TypedExpr) (result interface{},
return convert(val, expr.Type)
}

func (ec evalContext) evalExtractExpr(expr spansql.ExtractExpr) (result interface{}, err error) {
val, err := ec.evalExpr(expr.Expr)
if err != nil {
return nil, err
}
switch expr.Part {
case "DATE":
switch v := val.(type) {
case time.Time:
return civil.DateOf(v), nil
case civil.Date:
return v, nil
}
case "DAY":
switch v := val.(type) {
case time.Time:
return int64(v.Day()), nil
case civil.Date:
return int64(v.Day), nil
}
case "MONTH":
switch v := val.(type) {
case time.Time:
return int64(v.Month()), nil
case civil.Date:
return int64(v.Month), nil
}
case "YEAR":
switch v := val.(type) {
case time.Time:
return int64(v.Year()), nil
case civil.Date:
return int64(v.Year), nil
}
}
return nil, fmt.Errorf("Extract with part %v not supported", expr.Part)
}

func (ec evalContext) evalAtTimeZoneExpr(expr spansql.AtTimeZoneExpr) (result interface{}, err error) {
val, err := ec.evalExpr(expr.Expr)
if err != nil {
return nil, err
}
switch v := val.(type) {
case time.Time:
loc, err := time.LoadLocation(expr.Zone)
if err != nil {
return nil, fmt.Errorf("AtTimeZone with %T not supported", v)
}
return v.In(loc), nil
default:
return nil, fmt.Errorf("AtTimeZone with %T not supported", val)
}
}

func evalLiteralOrParam(lop spansql.LiteralOrParam, params queryParams) (int64, error) {
switch v := lop.(type) {
case spansql.IntegerLiteral:
@@ -107,6 +107,32 @@ var functions = map[string]function{
return "", spansql.Type{Base: spansql.String}, nil
},
},
"EXTRACT": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
date, okArg1 := values[0].(civil.Date)
part, okArg2 := values[0].(int64)
if !(okArg1 || okArg2) {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function EXTRACT for the given argument types")
}
if okArg1 {
return date, spansql.Type{Base: spansql.Date}, nil
}
return part, spansql.Type{Base: spansql.Int64}, nil
},
},
"TIMESTAMP": {
Eval: func(values []interface{}, types []spansql.Type) (interface{}, spansql.Type, error) {
t, okArg1 := values[0].(string)
if !(okArg1) {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
}
timestamp, err := time.Parse(time.RFC3339, t)
if err != nil {
return nil, spansql.Type{}, status.Error(codes.InvalidArgument, "No matching signature for function TIMESTAMP for the given argument types")
}
return timestamp, spansql.Type{Base: spansql.Timestamp}, nil
},
},
}

func cast(values []interface{}, types []spansql.Type, safe bool) (interface{}, spansql.Type, error) {
@@ -748,16 +748,24 @@ func TestIntegration_ReadsAndQueries(t *testing.T) {
}
rows.Stop()

rows = client.Single().Query(ctx, spanner.NewStatement("SELECT EXTRACT(INVALID_PART FROM TIMESTAMP('2008-12-25T05:30:00Z')"))
_, err = rows.Next()
if g, w := spanner.ErrCode(err), codes.InvalidArgument; g != w {
t.Errorf("error code mismatch for invalid part from EXTRACT\n Got: %v\nWant: %v", g, w)
}
rows.Stop()

// Do some complex queries.
tests := []struct {
q string
params map[string]interface{}
want [][]interface{}
}{
{
`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64)`,

`SELECT 17, "sweet", TRUE AND FALSE, NULL, B"hello", STARTS_WITH('Foo', 'B'), STARTS_WITH('Bar', 'B'), CAST(17 AS STRING), SAFE_CAST(TRUE AS STRING), SAFE_CAST('Foo' AS INT64), EXTRACT(DATE FROM TIMESTAMP('2008-12-25T05:30:00Z') AT TIME ZONE 'Europe/Amsterdam'), EXTRACT(YEAR FROM TIMESTAMP('2008-12-25T05:30:00Z'))`,
nil,
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil}},
[][]interface{}{{int64(17), "sweet", false, nil, []byte("hello"), false, true, "17", "true", nil, civil.Date{Year: 2008, Month: 12, Day: 25}, int64(2008)}},
},
// Check handling of NULL values for the IS operator.
// There was a bug that returned errors for some of these cases.
@@ -1277,13 +1285,16 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
defer cancel()

tableName := "SongWriters"

err := updateDDL(t, adminClient,
`CREATE TABLE `+tableName+` (
Name STRING(50) NOT NULL,
NumSongs INT64,
CreatedAT TIMESTAMP,
CreatedDate DATE,
EstimatedSales INT64 NOT NULL,
CanonicalName STRING(50) AS (LOWER(Name)) STORED,
GeneratedCreatedDate DATE AS (EXTRACT(DATE FROM CreatedAT AT TIME ZONE "CET")) STORED,
GeneratedCreatedDay INT64 AS (EXTRACT(DAY FROM CreatedDate)) STORED,
) PRIMARY KEY (Name)`)
if err != nil {
t.Fatalf("Setting up fresh table: %v", err)
@@ -1295,16 +1306,18 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
}

// Insert some data.
d1, _ := civil.ParseDate("2016-11-15")
t1, _ := time.Parse(time.RFC3339Nano, "2016-11-15T15:04:05.999999999Z")
_, err = client.Apply(ctx, []*spanner.Mutation{
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales", "NumSongs"},
[]interface{}{"Average Writer", 10, 10}),
[]string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"},
[]interface{}{"Average Writer", 10, 10, t1, d1}),
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales"},
[]interface{}{"Great Writer", 100}),
[]string{"Name", "EstimatedSales", "CreatedAT", "CreatedDate"},
[]interface{}{"Great Writer", 100, t1, d1}),
spanner.Insert(tableName,
[]string{"Name", "EstimatedSales", "NumSongs"},
[]interface{}{"Poor Writer", 1, 50}),
[]string{"Name", "EstimatedSales", "NumSongs", "CreatedAT", "CreatedDate"},
[]interface{}{"Poor Writer", 1, 50, t1, d1}),
})
if err != nil {
t.Fatalf("Applying mutations: %v", err)
@@ -1317,7 +1330,7 @@ func TestIntegration_GeneratedColumns(t *testing.T) {
}

ri := client.Single().Query(ctx, spanner.NewStatement(
`SELECT CanonicalName, TotalSales FROM `+tableName+` ORDER BY Name`,
`SELECT CanonicalName, TotalSales, GeneratedCreatedDate, GeneratedCreatedDay FROM `+tableName+` ORDER BY Name`,
))
all, err := slurpRows(t, ri)
if err != nil {
@@ -1326,9 +1339,9 @@ func TestIntegration_GeneratedColumns(t *testing.T) {

// Great writer has nil because NumSongs is nil
want := [][]interface{}{
{"average writer", int64(100)},
{"great writer", nil},
{"poor writer", int64(50)},
{"average writer", int64(100), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
{"great writer", nil, civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
{"poor writer", int64(50), civil.Date{Year: 2016, Month: 11, Day: 15}, int64(15)},
}
if !reflect.DeepEqual(all, want) {
t.Errorf("Expected values are wrong.\n got %v\nwant %v", all, want)
@@ -134,9 +134,10 @@ func init() {
for _, f := range allFuncs {
funcs[f] = true
}
// Special case for CAST and SAFE_CAST
// Special case for CAST, SAFE_CAST and EXTRACT
funcArgParsers["CAST"] = typedArgParser
funcArgParsers["SAFE_CAST"] = typedArgParser
funcArgParsers["EXTRACT"] = extractArgParser
}

var allFuncs = []string{
@@ -1901,6 +1901,27 @@ func (p *parser) parseType() (Type, *parseError) {
return p.parseBaseOrParameterizedType(true)
}

var extractPartTypes = map[string]TypeBase{
"DAY": Int64,
"MONTH": Int64,
"YEAR": Int64,
"DATE": Date,
}

func (p *parser) parseExtractType() (Type, string, *parseError) {
var t Type
tok := p.next()
if tok.err != nil {
return Type{}, "", tok.err
}
base, ok := extractPartTypes[strings.ToUpper(tok.value)] // valid part types for EXTRACT is keyed by upper case strings.
if !ok {
return Type{}, "", p.errorf("got %q, want valid EXTRACT types", tok.value)
}
t.Base = base
return t, strings.ToUpper(tok.value), nil
}

func (p *parser) parseBaseOrParameterizedType(withParam bool) (Type, *parseError) {
debugf("parseBaseOrParameterizedType: %v", p)

@@ -2482,6 +2503,34 @@ var typedArgParser = func(p *parser) (Expr, *parseError) {
}, nil
}

// Special argument parser for EXTRACT
var extractArgParser = func(p *parser) (Expr, *parseError) {
partType, part, err := p.parseExtractType()
if err != nil {
return nil, err
}
if err := p.expect("FROM"); err != nil {
return nil, err
}
e, err := p.parseExpr()
if err != nil {
return nil, err
}
// AT TIME ZONE is optional
if p.eat("AT", "TIME", "ZONE") {
tok := p.next()
if tok.err != nil {
return nil, err
}
return ExtractExpr{Part: part, Type: partType, Expr: AtTimeZoneExpr{Expr: e, Zone: tok.string, Type: Type{Base: Timestamp}}}, nil
}
return ExtractExpr{
Part: part,
Expr: e,
Type: partType,
}, nil
}

/*
Expressions
@@ -340,6 +340,8 @@ func TestParseExpr(t *testing.T) {
{`STARTS_WITH(Bar, 'B')`, Func{Name: "STARTS_WITH", Args: []Expr{ID("Bar"), StringLiteral("B")}}},
{`CAST(Bar AS STRING)`, Func{Name: "CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: String}}}}},
{`SAFE_CAST(Bar AS INT64)`, Func{Name: "SAFE_CAST", Args: []Expr{TypedExpr{Expr: ID("Bar"), Type: Type{Base: Int64}}}}},
{`EXTRACT(DATE FROM TIMESTAMP AT TIME ZONE "America/Los_Angeles")`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("TIMESTAMP"), Zone: "America/Los_Angeles", Type: Type{Base: Timestamp}}}}}},
{`EXTRACT(DAY FROM DATE)`, Func{Name: "EXTRACT", Args: []Expr{ExtractExpr{Part: "DAY", Expr: ID("DATE"), Type: Type{Base: Int64}}}}},

// String literal:
// Accept double quote and single quote.
@@ -524,7 +526,9 @@ func TestParseDDL(t *testing.T) {
CREATE TABLE users (
user_id STRING(36) NOT NULL,
some_string STRING(16) NOT NULL,
some_time TIMESTAMP NOT NULL,
number_key INT64 AS (SAFE_CAST(SUBSTR(some_string, 2) AS INT64)) STORED,
generated_date DATE AS (EXTRACT(DATE FROM some_time AT TIME ZONE "CET")) STORED,
) PRIMARY KEY(user_id);
-- Trailing comment at end of file.
@@ -744,12 +748,20 @@ func TestParseDDL(t *testing.T) {
Columns: []ColumnDef{
{Name: "user_id", Type: Type{Base: String, Len: 36}, NotNull: true, Position: line(67)},
{Name: "some_string", Type: Type{Base: String, Len: 16}, NotNull: true, Position: line(68)},
{Name: "some_time", Type: Type{Base: Timestamp}, NotNull: true, Position: line(69)},
{
Name: "number_key", Type: Type{Base: Int64},
Generated: Func{Name: "SAFE_CAST", Args: []Expr{
TypedExpr{Expr: Func{Name: "SUBSTR", Args: []Expr{ID("some_string"), IntegerLiteral(2)}}, Type: Type{Base: Int64}},
}},
Position: line(69),
Position: line(70),
},
{
Name: "generated_date", Type: Type{Base: Date},
Generated: Func{Name: "EXTRACT", Args: []Expr{
ExtractExpr{Part: "DATE", Type: Type{Base: Date}, Expr: AtTimeZoneExpr{Expr: ID("some_time"), Zone: "CET", Type: Type{Base: Timestamp}}},
}},
Position: line(71),
},
},
PrimaryKey: []KeyPart{{Column: "user_id"}},
@@ -777,7 +789,7 @@ func TestParseDDL(t *testing.T) {
{Marker: "--", Isolated: true, Start: line(49), End: line(49), Text: []string{"Table with row deletion policy."}},

// Comment after everything else.
{Marker: "--", Isolated: true, Start: line(72), End: line(72), Text: []string{"Trailing comment at end of file."}},
{Marker: "--", Isolated: true, Start: line(74), End: line(74), Text: []string{"Trailing comment at end of file."}},
}}},
// No trailing comma:
{`ALTER TABLE T ADD COLUMN C2 INT64`, &DDL{Filename: "filename", List: []DDLStmt{
@@ -589,6 +589,20 @@ func (te TypedExpr) addSQL(sb *strings.Builder) {
sb.WriteString(te.Type.SQL())
}

func (ee ExtractExpr) SQL() string { return buildSQL(ee) }
func (ee ExtractExpr) addSQL(sb *strings.Builder) {
sb.WriteString(ee.Part)
sb.WriteString(" FROM ")
ee.Expr.addSQL(sb)
}

func (aze AtTimeZoneExpr) SQL() string { return buildSQL(aze) }
func (aze AtTimeZoneExpr) addSQL(sb *strings.Builder) {
aze.Expr.addSQL(sb)
sb.WriteString(" AT TIME ZONE ")
sb.WriteString(aze.Zone)
}

func idList(l []ID, join string) string {
var ss []string
for _, s := range l {
@@ -647,6 +647,24 @@ type TypedExpr struct {
func (TypedExpr) isBoolExpr() {} // possibly bool
func (TypedExpr) isExpr() {}

type ExtractExpr struct {
Part string
Type Type
Expr Expr
}

func (ExtractExpr) isBoolExpr() {} // possibly bool
func (ExtractExpr) isExpr() {}

type AtTimeZoneExpr struct {
Expr Expr
Type Type
Zone string
}

func (AtTimeZoneExpr) isBoolExpr() {} // possibly bool
func (AtTimeZoneExpr) isExpr() {}

// Paren represents a parenthesised expression.
type Paren struct {
Expr Expr

0 comments on commit 81b7c85

Please sign in to comment.