Skip to content

Commit

Permalink
added JSONLiteral, fixed enginetests
Browse files Browse the repository at this point in the history
  • Loading branch information
andy-wm-arthur committed Mar 11, 2021
1 parent 2503feb commit a675dc5
Show file tree
Hide file tree
Showing 12 changed files with 138 additions and 49 deletions.
8 changes: 8 additions & 0 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ package enginetest

import (
"context"
"encoding/json"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -3308,6 +3309,13 @@ func MustConvert(val interface{}, err error) interface{} {
return val
}

func MustJSON(s string) (doc interface{}) {
if err := json.Unmarshal([]byte(s), &doc); err != nil {
panic(err)
}
return doc
}

var pid uint64

func NewContext(harness Harness) *sql.Context {
Expand Down
8 changes: 4 additions & 4 deletions enginetest/insert_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ var InsertQueries = []WriteQueryTest{
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
float32(math.MaxFloat32), float64(math.MaxFloat64),
MustConvert(sql.Timestamp.Convert("2037-04-05 12:51:36")), MustConvert(sql.Date.Convert("2231-11-07")),
"random text", sql.True, ([]byte)(`{"key":"value"}`), "blobdata",
"random text", sql.True, MustJSON(`{"key":"value"}`), "blobdata",
}},
},
{
Expand All @@ -106,7 +106,7 @@ var InsertQueries = []WriteQueryTest{
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
float32(math.MaxFloat32), float64(math.MaxFloat64),
MustConvert(sql.Timestamp.Convert("2037-04-05 12:51:36")), MustConvert(sql.Date.Convert("2231-11-07")),
"random text", sql.True, ([]byte)(`{"key":"value"}`), "blobdata",
"random text", sql.True, MustJSON(`{"key":"value"}`), "blobdata",
}},
},
{
Expand All @@ -124,7 +124,7 @@ var InsertQueries = []WriteQueryTest{
uint8(0), uint16(0), uint32(0), uint64(0),
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
sql.Timestamp.Zero(), sql.Date.Zero(),
"", sql.False, ([]byte)(`""`), "",
"", sql.False, MustJSON(`""`), "",
}},
},
{
Expand All @@ -142,7 +142,7 @@ var InsertQueries = []WriteQueryTest{
uint8(0), uint16(0), uint32(0), uint64(0),
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
sql.Timestamp.Zero(), sql.Date.Zero(),
"", sql.False, ([]byte)(`""`), "",
"", sql.False, MustJSON(`""`), "",
}},
},
{
Expand Down
2 changes: 1 addition & 1 deletion enginetest/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -1993,7 +1993,7 @@ var QueryTests = []QueryTest{
},
},
{
Query: `SELECT JSON_EXTRACT("foo", "$")`,
Query: `SELECT JSON_EXTRACT('"foo"', "$")`,
Expected: []sql.Row{{"foo"}},
},
{
Expand Down
8 changes: 4 additions & 4 deletions enginetest/replace_queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ var ReplaceQueries = []WriteQueryTest{
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
float32(math.MaxFloat32), float64(math.MaxFloat64),
MustConvert(sql.Timestamp.Convert("2037-04-05 12:51:36")), MustConvert(sql.Date.Convert("2231-11-07")),
"random text", sql.True, ([]byte)(`{"key":"value"}`), "blobdata",
"random text", sql.True, MustJSON(`{"key":"value"}`), "blobdata",
}},
},
{
Expand All @@ -104,7 +104,7 @@ var ReplaceQueries = []WriteQueryTest{
uint8(math.MaxUint8), uint16(math.MaxUint16), uint32(math.MaxUint32), uint64(math.MaxUint64),
float32(math.MaxFloat32), float64(math.MaxFloat64),
MustConvert(sql.Timestamp.Convert("2037-04-05 12:51:36")), MustConvert(sql.Date.Convert("2231-11-07")),
"random text", sql.True, ([]byte)(`{"key":"value"}`), "blobdata",
"random text", sql.True, MustJSON(`{"key":"value"}`), "blobdata",
}},
},
{
Expand All @@ -122,7 +122,7 @@ var ReplaceQueries = []WriteQueryTest{
uint8(0), uint16(0), uint32(0), uint64(0),
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
sql.Timestamp.Zero(), sql.Date.Zero(),
"", sql.False, ([]byte)(`""`), "",
"", sql.False, MustJSON(`""`), "",
}},
},
{
Expand All @@ -140,7 +140,7 @@ var ReplaceQueries = []WriteQueryTest{
uint8(0), uint16(0), uint32(0), uint64(0),
float32(math.SmallestNonzeroFloat32), float64(math.SmallestNonzeroFloat64),
sql.Timestamp.Zero(), sql.Date.Zero(),
"", sql.False, ([]byte)(`""`), "",
"", sql.False, MustJSON(`""`), "",
}},
},
{
Expand Down
3 changes: 3 additions & 0 deletions sql/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ var (
// node or expression is called with an invalid child type. This error is indicative of a bug.
ErrInvalidChildType = errors.NewKind("%T: invalid child type, got %T, expected %T")

// ErrInvalidJSONText is returned when a JSON string cannot be parsed or unmarshalled
ErrInvalidJSONText = errors.NewKind("Invalid JSON text: %s")

// ErrDeleteRowNotFound
ErrDeleteRowNotFound = errors.NewKind("row was not found when attempting to delete")

Expand Down
40 changes: 19 additions & 21 deletions sql/expression/function/json_extract.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
package function

import (
"encoding/json"
"fmt"
"github.com/dolthub/go-mysql-server/sql/expression"
"strings"

"github.com/oliveagle/jsonpath"
Expand All @@ -41,9 +41,26 @@ func NewJSONExtract(args ...sql.Expression) (sql.Expression, error) {
return nil, sql.ErrInvalidArgumentNumber.New("JSON_EXTRACT", 2, len(args))
}

// TODO(andy) make this an analysis step
args, err := maybeConvertLiteral(args...)
if err != nil {
return nil, err
}

return &JSONExtract{args[0], args[1:]}, nil
}

func maybeConvertLiteral(args ...sql.Expression) ([]sql.Expression, error) {
if lit, ok := args[0].(*expression.Literal); ok {
json, err := expression.JSONLiteralFromLiteral(lit)
if err != nil {
return nil, err
}
args[0] = json
}
return args, nil
}

// FunctionName implements sql.FunctionExpression
func (j *JSONExtract) FunctionName() string {
return "json_extract"
Expand All @@ -67,12 +84,7 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
span, ctx := ctx.Span("function.JSONExtract")
defer span.Finish()

js, err := j.JSON.Eval(ctx, row)
if err != nil {
return nil, err
}

doc, err := unmarshalVal(js)
doc, err := j.JSON.Eval(ctx, row)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -104,20 +116,6 @@ func (j *JSONExtract) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
return result, nil
}

func unmarshalVal(v interface{}) (interface{}, error) {
v, err := sql.JSON.Convert(v)
if err != nil {
return nil, err
}

var doc interface{}
if err := json.Unmarshal(v.([]byte), &doc); err != nil {
return nil, err
}

return doc, nil
}

// IsNullable implements the sql.Expression interface.
func (j *JSONExtract) IsNullable() bool {
for _, p := range j.Paths {
Expand Down
6 changes: 3 additions & 3 deletions sql/expression/function/json_extract_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,14 @@ func TestJSONExtract(t *testing.T) {
require.NoError(t, err)

json := map[string]interface{}{
"a": []interface{}{1, 2, 3, 4},
"a": []interface{}{float64(1), float64(2), float64(3), float64(4)},
"b": map[string]interface{}{
"c": "foo",
"d": true,
},
"e": []interface{}{
[]interface{}{1, 2},
[]interface{}{3, 4},
[]interface{}{float64(1), float64(2)},
[]interface{}{float64(3), float64(4)},
},
}

Expand Down
84 changes: 84 additions & 0 deletions sql/expression/json_literal.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Copyright 2020-2021 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package expression

import (
"encoding/json"
"fmt"

"github.com/dolthub/go-mysql-server/sql"
)

// JSONLiteral represents a JSON string literal.
type JSONLiteral struct {
json []byte
}

// NewJSONLiteral creates a new JSONLiteral expression.
func NewJSONLiteral(json string) JSONLiteral {
return JSONLiteral{ []byte(json)}
}

// Resolved implements the Expression interface.
func (l JSONLiteral) Resolved() bool {
return true
}

// IsNullable implements the Expression interface.
func (l JSONLiteral) IsNullable() bool {
return false
}

// Type implements the Expression interface.
func (l JSONLiteral) Type() sql.Type {
return sql.JSON
}

// Eval implements the Expression interface.
func (l JSONLiteral) Eval(ctx *sql.Context, row sql.Row) (doc interface{}, err error) {
if err = json.Unmarshal(l.json, &doc); err != nil {
return nil, sql.ErrInvalidJSONText.New(err.Error())
}
return doc, nil
}

func (l JSONLiteral) String() string {
return string(l.json)
}

func (l JSONLiteral) DebugString() string {
return fmt.Sprintf("JSON(%s)", string(l.json))
}

// WithChildren implements the Expression interface.
func (l JSONLiteral) WithChildren(children ...sql.Expression) (sql.Expression, error) {
if len(children) != 0 {
return nil, sql.ErrInvalidChildrenNumber.New(l, len(children), 0)
}
return l, nil
}

// Children implements the Expression interface.
func (JSONLiteral) Children() []sql.Expression {
return nil
}

func JSONLiteralFromLiteral(literal *Literal) (JSONLiteral, error) {
if _, ok := literal.Type().(sql.StringType); !ok {
return JSONLiteral{}, sql.ErrInvalidJSONText.New(literal.String())
}
s := literal.Value().(string)
return JSONLiteral{json: []byte(s)}, nil
}
2 changes: 2 additions & 0 deletions sql/expression/literal.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ type Literal struct {
fieldType sql.Type
}

var _ sql.Expression = &Literal{}

// NewLiteral creates a new Literal expression.
func NewLiteral(value interface{}, fieldType sql.Type) *Literal {
// TODO(juanjux): we should probably check here if the type is sql.VarChar and the
Expand Down
6 changes: 6 additions & 0 deletions sql/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,12 @@ func (t jsonType) Convert(v interface{}) (doc interface{}, err error) {
switch v := v.(type) {
case []byte:
err = json.Unmarshal(v, &doc)
case string:
if err = json.Unmarshal([]byte(v), &doc); err != nil {
// if |v| does not encode a valid JSON document
// return it as a naked string value
return v, nil
}
default:
// validate that |v| can be marshalled
if _, err = json.Marshal(v); err == nil {
Expand Down
6 changes: 3 additions & 3 deletions sql/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ func TestJsonConvert(t *testing.T) {
expectedVal interface{}
expectedErr bool
}{
{"", []byte(`""`), false},
{[]int{1, 2}, []byte("[1,2]"), false},
{`{"a": true, "b": 3}`, []byte(`{"a":true,"b":3}`), false},
{"", mustJSON(`""`), false},
{[]int{1, 2}, []int{1, 2}, false},
{`{"a": true, "b": 3}`, mustJSON(`{"a":true,"b":3}`), false},
}

for _, test := range tests {
Expand Down
14 changes: 1 addition & 13 deletions sql/type.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package sql

import (
"encoding/json"
"fmt"
"io"
"strconv"
Expand Down Expand Up @@ -518,18 +517,7 @@ func UnderlyingType(t Type) Type {
func convertForJSON(t Type, v interface{}) (interface{}, error) {
switch t := t.(type) {
case jsonType:
val, err := t.Convert(v)
if err != nil {
return nil, err
}

var doc interface{}
err = json.Unmarshal(val.([]byte), &doc)
if err != nil {
return nil, err
}

return doc, nil
return t.Convert(v)
case arrayType:
return convertArrayForJSON(t, v)
default:
Expand Down

0 comments on commit a675dc5

Please sign in to comment.