Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix JSON string comparison. #475

Merged
merged 2 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions enginetest/queries.go
Original file line number Diff line number Diff line change
Expand Up @@ -2688,6 +2688,22 @@ var QueryTests = []QueryTest{
Query: `SELECT JSON_UNQUOTE('"\t\\u0032"')`,
Expected: []sql.Row{{"\t2"}},
},
{
Query: `SELECT JSON_UNQUOTE(JSON_EXTRACT('{"xid":"hello"}', '$.xid')) = "hello"`,
Expected: []sql.Row{{true}},
},
{
Query: `SELECT JSON_EXTRACT('{"xid":"hello"}', '$.xid') = "hello"`,
Expected: []sql.Row{{true}},
},
{
Query: `SELECT JSON_EXTRACT('{"xid":"hello"}', '$.xid') = '"hello"'`,
Expected: []sql.Row{{false}},
},
{
Query: `SELECT JSON_UNQUOTE(JSON_EXTRACT('{"xid":null}', '$.xid'))`,
Expected: []sql.Row{{"null"}},
},
{
Query: `SELECT CONNECTION_ID()`,
Expected: []sql.Row{{uint32(1)}},
Expand Down
86 changes: 86 additions & 0 deletions internal/strings/unquote.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package strings

import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"unicode/utf8"
)

// The implementation is taken from TiDB
// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L89
func Unquote(s string) (string, error) {
ret := new(bytes.Buffer)
for i := 0; i < len(s); i++ {
if s[i] == '\\' {
i++
if i == len(s) {
return "", fmt.Errorf("Missing a closing quotation mark in string")
}
switch s[i] {
case '"':
ret.WriteByte('"')
case 'b':
ret.WriteByte('\b')
case 'f':
ret.WriteByte('\f')
case 'n':
ret.WriteByte('\n')
case 'r':
ret.WriteByte('\r')
case 't':
ret.WriteByte('\t')
case '\\':
ret.WriteByte('\\')
case 'u':
if i+4 > len(s) {
return "", fmt.Errorf("Invalid unicode: %s", s[i+1:])
}
char, size, err := decodeEscapedUnicode([]byte(s[i+1 : i+5]))
if err != nil {
return "", err
}
ret.Write(char[0:size])
i += 4
default:
// For all other escape sequences, backslash is ignored.
ret.WriteByte(s[i])
}
} else {
ret.WriteByte(s[i])
}
}

str := ret.String()
strlen := len(str)
// Remove prefix and suffix '"'.
if strlen > 1 {
head, tail := str[0], str[strlen-1]
if head == '"' && tail == '"' {
return str[1 : strlen-1], nil
}
}
return str, nil
}

// decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629.
// According RFC 3629, the max length of utf8 characters is 4 bytes.
// And MySQL use 4 bytes to represent the unicode which must be in [0, 65536).
// The implementation is taken from TiDB:
// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L136
func decodeEscapedUnicode(s []byte) (char [4]byte, size int, err error) {
size, err = hex.Decode(char[0:2], s)
if err != nil || size != 2 {
// The unicode must can be represented in 2 bytes.
return char, 0, err
}
var unicode uint16
err = binary.Read(bytes.NewReader(char[0:2]), binary.BigEndian, &unicode)
if err != nil {
return char, 0, err
}
size = utf8.RuneLen(rune(unicode))
utf8.EncodeRune(char[0:size], rune(unicode))
return
}
84 changes: 2 additions & 82 deletions sql/expression/function/json_unquote.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,10 @@
package function

import (
"bytes"
"encoding/binary"
"encoding/hex"
"fmt"
"reflect"
"unicode/utf8"

"github.com/dolthub/go-mysql-server/internal/strings"
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
)
Expand Down Expand Up @@ -78,82 +75,5 @@ func (js *JSONUnquote) Eval(ctx *sql.Context, row sql.Row) (interface{}, error)
return nil, sql.ErrInvalidType.New(reflect.TypeOf(ex).String())
}

return unquote(str)
}

// The implementation is taken from TiDB
// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L89
func unquote(s string) (string, error) {
ret := new(bytes.Buffer)
for i := 0; i < len(s); i++ {
if s[i] == '\\' {
i++
if i == len(s) {
return "", fmt.Errorf("Missing a closing quotation mark in string")
}
switch s[i] {
case '"':
ret.WriteByte('"')
case 'b':
ret.WriteByte('\b')
case 'f':
ret.WriteByte('\f')
case 'n':
ret.WriteByte('\n')
case 'r':
ret.WriteByte('\r')
case 't':
ret.WriteByte('\t')
case '\\':
ret.WriteByte('\\')
case 'u':
if i+4 > len(s) {
return "", fmt.Errorf("Invalid unicode: %s", s[i+1:])
}
char, size, err := decodeEscapedUnicode([]byte(s[i+1 : i+5]))
if err != nil {
return "", err
}
ret.Write(char[0:size])
i += 4
default:
// For all other escape sequences, backslash is ignored.
ret.WriteByte(s[i])
}
} else {
ret.WriteByte(s[i])
}
}

str := ret.String()
strlen := len(str)
// Remove prefix and suffix '"'.
if strlen > 1 {
head, tail := str[0], str[strlen-1]
if head == '"' && tail == '"' {
return str[1 : strlen-1], nil
}
}
return str, nil
}

// decodeEscapedUnicode decodes unicode into utf8 bytes specified in RFC 3629.
// According RFC 3629, the max length of utf8 characters is 4 bytes.
// And MySQL use 4 bytes to represent the unicode which must be in [0, 65536).
// The implementation is taken from TiDB:
// https://github.com/pingcap/tidb/blob/a594287e9f402037b06930026906547000006bb6/types/json/binary_functions.go#L136
func decodeEscapedUnicode(s []byte) (char [4]byte, size int, err error) {
size, err = hex.Decode(char[0:2], s)
if err != nil || size != 2 {
// The unicode must can be represented in 2 bytes.
return char, 0, err
}
var unicode uint16
err = binary.Read(bytes.NewReader(char[0:2]), binary.BigEndian, &unicode)
if err != nil {
return char, 0, err
}
size = utf8.RuneLen(rune(unicode))
utf8.EncodeRune(char[0:size], rune(unicode))
return
return strings.Unquote(str)
}
13 changes: 9 additions & 4 deletions sql/stringtype.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"gopkg.in/src-d/go-errors.v1"

"github.com/dolthub/go-mysql-server/internal/regex"
istrings "github.com/dolthub/go-mysql-server/internal/strings"
)

const (
Expand Down Expand Up @@ -267,12 +268,16 @@ func (t stringType) Convert(v interface{}) (interface{}, error) {
return nil, nil
}
val = s.Decimal.String()
case JSONDocument:
if s.Val == nil {
return "", nil
case JSONValue:
str, err := s.ToString(nil)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this case be handling all JSONValues, not just JSONDocuments?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just leave as it was, but makes sense. If we use JSONValue then we need to remove the following piece of code:

if s.Val == nil {
	return "", nil
}

tests still pass without it.

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we don't handle arbitrary JSONValues somehow in this function then JSONValues that aren't JSONDocuments won't be supported, right? Are users supposed to be able to use their own implementations of JSONValue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed! Also, it fixes another bug related to JSON null values. They were translated to "" instead of "null"

if err != nil {
return nil, err
}

return s.ToString(nil)
val, err = istrings.Unquote(str)
if err != nil {
return nil, err
}
default:
return nil, ErrConvertToSQL.New(t)
}
Expand Down
1 change: 1 addition & 0 deletions sql/stringtype_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ func TestStringConvert(t *testing.T) {
{MustCreateStringWithDefaults(sqltypes.Text, 3), strings.Repeat("𒁏", int(tinyTextBlobMax/Collation_Default.CharacterSet().MaxLength())+1), nil, true},
{MustCreateBinary(sqltypes.VarBinary, 3), []byte{01, 02, 03, 04}, nil, true},
{MustCreateStringWithDefaults(sqltypes.VarChar, 3), []byte("abcd"), nil, true},
{MustCreateStringWithDefaults(sqltypes.Char, 20), JSONDocument{Val: nil}, "null", false},
}

for _, test := range tests {
Expand Down