diff --git a/enginetest/memory_engine_test.go b/enginetest/memory_engine_test.go index 83ec8bc44a..e89400bde1 100644 --- a/enginetest/memory_engine_test.go +++ b/enginetest/memory_engine_test.go @@ -17,7 +17,6 @@ package enginetest_test import ( "fmt" "testing" - "time" "github.com/dolthub/go-mysql-server/memory" "github.com/dolthub/go-mysql-server/sql/expression" @@ -112,45 +111,14 @@ func TestSingleScript(t *testing.T) { var scripts = []enginetest.ScriptTest{ { - Name: "show create triggers", + Name: "UUIDs used in the wild.", SetUpScript: []string{ - "create table a (x int primary key)", - "create trigger a1 before insert on a for each row set new.x = new.x + 1", - "create table b (y int primary key)", - "create trigger b1 before insert on b for each row set new.x = new.x + 2", + "SET @uuid = '6ccd780c-baba-1026-9564-5b8c656024db'", }, Assertions: []enginetest.ScriptTestAssertion{ { - Query: "show create trigger a1", - Expected: []sql.Row{ - { - "a1", // Trigger - "", // sql_mode - "create trigger a1 before insert on a for each row set new.x = new.x + 1", // SQL Original Statement - sql.Collation_Default.CharacterSet().String(), // character_set_client - sql.Collation_Default.String(), // collation_connection - sql.Collation_Default.String(), // Database Collation - time.Unix(0, 0).UTC(), // Created - }, - }, - }, - { - Query: "show create trigger b1", - Expected: []sql.Row{ - { - "b1", // Trigger - "", // sql_mode - "create trigger b1 before insert on b for each row set new.x = new.x + 2", // SQL Original Statement - sql.Collation_Default.CharacterSet().String(), // character_set_client - sql.Collation_Default.String(), // collation_connection - sql.Collation_Default.String(), // Database Collation - time.Unix(0, 0).UTC(), // Created - }, - }, - }, - { - Query: "show create trigger b2", - ExpectedErr: sql.ErrTriggerDoesNotExist, + Query: `SELECT IS_UUID(@uuid)`, + Expected: []sql.Row{{int8(1)}}, }, }, }, diff --git a/enginetest/queries.go b/enginetest/queries.go index 0e1e8163d2..27202ae1d5 100755 --- a/enginetest/queries.go +++ b/enginetest/queries.go @@ -90,6 +90,10 @@ var QueryTests = []QueryTest{ {1, 50.0}, }, }, + { + Query: "select max(pk),c2 from one_pk group by c1 order by 1", + Expected: []sql.Row{{0, 1}, {1, 11}, {2, 21}, {3, 31}}, + }, { Query: "SELECT pk1, SUM(c1) FROM two_pk WHERE pk1 = 0", Expected: []sql.Row{{0, 10.0}}, diff --git a/enginetest/script_queries.go b/enginetest/script_queries.go index b7f32f1966..d72c70622b 100755 --- a/enginetest/script_queries.go +++ b/enginetest/script_queries.go @@ -320,6 +320,59 @@ var ScriptTests = []ScriptTest{ }, }, }, + { + Name: "UUIDs used in the wild.", + SetUpScript: []string{ + "SET @uuid = '6ccd780c-baba-1026-9564-5b8c656024db'", + "SET @binuuid = '0011223344556677'", + }, + Assertions: []ScriptTestAssertion{ + { + Query: `SELECT IS_UUID(UUID())`, + Expected: []sql.Row{{int8(1)}}, + }, + { + Query: `SELECT IS_UUID(@uuid)`, + Expected: []sql.Row{{int8(1)}}, + }, + { + Query: `SELECT BIN_TO_UUID(UUID_TO_BIN(@uuid))`, + Expected: []sql.Row{{"6ccd780c-baba-1026-9564-5b8c656024db"}}, + }, + { + Query: `SELECT BIN_TO_UUID(UUID_TO_BIN(@uuid, 1), 1)`, + Expected: []sql.Row{{"6ccd780c-baba-1026-9564-5b8c656024db"}}, + }, + { + Query: `SELECT UUID_TO_BIN(NULL)`, + Expected: []sql.Row{{nil}}, + }, + { + Query: `SELECT HEX(UUID_TO_BIN(@uuid))`, + Expected: []sql.Row{{"6CCD780CBABA102695645B8C656024DB"}}, + }, + { + Query: `SELECT UUID_TO_BIN(123)`, + RequiredErr: true, + }, + { + Query: `SELECT BIN_TO_UUID(123)`, + RequiredErr: true, + }, + { + Query: `SELECT BIN_TO_UUID(X'00112233445566778899aabbccddeeff')`, + Expected: []sql.Row{{"00112233-4455-6677-8899-aabbccddeeff"}}, + }, + { + Query: `SELECT BIN_TO_UUID('0011223344556677')`, + Expected: []sql.Row{{"30303131-3232-3333-3434-353536363737"}}, + }, + { + Query: `SELECT BIN_TO_UUID(@binuuid)`, + Expected: []sql.Row{{"30303131-3232-3333-3434-353536363737"}}, + }, + }, + }, { Name: "CrossDB Queries", SetUpScript: []string{ diff --git a/go.mod b/go.mod index 67373fced7..c9f5b8e913 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/go-sql-driver/mysql v1.4.1 github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b github.com/google/go-cmp v0.3.0 // indirect + github.com/google/uuid v1.2.0 github.com/hashicorp/golang-lru v0.5.3 github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 // indirect github.com/konsorten/go-windows-terminal-sequences v1.0.2 // indirect diff --git a/go.sum b/go.sum index a002fd85c2..669d8ce329 100755 --- a/go.sum +++ b/go.sum @@ -36,6 +36,8 @@ github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5y github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/golang-lru v0.5.3 h1:YPkqC67at8FYaadspW/6uE0COsBxS2656RLEr8Bppgk= github.com/hashicorp/golang-lru v0.5.3/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/jehiah/go-strftime v0.0.0-20171201141054-1d33003b3869 h1:IPJ3dvxmJ4uczJe5YQdrYB16oTJlGSC/OyZDqUk9xX4= diff --git a/sql/expression/function/registry.go b/sql/expression/function/registry.go index bd1eafe5d8..590e87cf6a 100644 --- a/sql/expression/function/registry.go +++ b/sql/expression/function/registry.go @@ -33,6 +33,7 @@ var Defaults = []sql.Function{ sql.Function1{Name: "atan", Fn: NewAtan}, sql.Function1{Name: "avg", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewAvg(e) }}, sql.Function1{Name: "bin", Fn: NewBin}, + sql.FunctionN{Name: "bin_to_uuid", Fn: NewBinToUUID}, sql.Function1{Name: "bit_length", Fn: NewBitlength}, sql.Function1{Name: "ceil", Fn: NewCeil}, sql.Function1{Name: "ceiling", Fn: NewCeil}, @@ -74,6 +75,7 @@ var Defaults = []sql.Function{ sql.Function2{Name: "ifnull", Fn: NewIfNull}, sql.Function2{Name: "instr", Fn: NewInstr}, sql.Function1{Name: "is_binary", Fn: NewIsBinary}, + sql.Function1{Name: "is_uuid", Fn: NewIsUUID}, sql.FunctionN{Name: "json_array", Fn: NewJSONArray}, sql.Function1{Name: "json_arrayagg", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewJSONArrayAgg(e) }}, sql.FunctionN{Name: "json_array_append", Fn: NewJSONArrayAppend}, @@ -154,16 +156,18 @@ var Defaults = []sql.Function{ sql.Function1{Name: "sum", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewSum(e) }}, sql.Function1{Name: "tan", Fn: NewTan}, sql.Function1{Name: "time_to_sec", Fn: NewTimeToSec}, + sql.Function2{Name: "timediff", Fn: NewTimeDiff}, sql.FunctionN{Name: "timestamp", Fn: NewTimestamp}, sql.Function1{Name: "to_base64", Fn: NewToBase64}, sql.Function1{Name: "trim", Fn: NewTrimFunc(bTrimType)}, sql.Function1{Name: "ucase", Fn: NewUpper}, sql.Function1{Name: "unhex", Fn: NewUnhex}, sql.FunctionN{Name: "unix_timestamp", Fn: NewUnixTimestamp}, - sql.FunctionN{Name: "utc_timestamp", Fn: NewUTCTimestamp}, - sql.Function2{Name: "timediff", Fn: NewTimeDiff}, sql.Function1{Name: "upper", Fn: NewUpper}, sql.NewFunction0("user", NewUser), + sql.FunctionN{Name: "utc_timestamp", Fn: NewUTCTimestamp}, + sql.Function0{Name: "uuid", Fn: NewUUIDFunc}, + sql.FunctionN{Name: "uuid_to_bin", Fn: NewUUIDToBin}, sql.FunctionN{Name: "week", Fn: NewWeek}, sql.Function1{Name: "values", Fn: NewValues}, sql.Function1{Name: "weekday", Fn: NewWeekday}, diff --git a/sql/expression/function/uuid.go b/sql/expression/function/uuid.go new file mode 100644 index 0000000000..6c845cc3ae --- /dev/null +++ b/sql/expression/function/uuid.go @@ -0,0 +1,485 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "fmt" + + "github.com/dolthub/vitess/go/sqltypes" + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/google/uuid" + + "github.com/dolthub/go-mysql-server/sql" +) + +// UUID() +// +// Returns a Universal Unique Identifier (UUID) generated according to RFC 4122, “A Universally Unique IDentifier (UUID) +// URN Namespace” (http://www.ietf.org/rfc/rfc4122.txt). A UUID is designed as a number that is globally unique in space +// and time. Two calls to UUID() are expected to generate two different values, even if these calls are performed on two +// separate devices not connected to each other. +// +// Warning Although UUID() values are intended to be unique, they are not necessarily unguessable or unpredictable. +// If unpredictability is required, UUID values should be generated some other way. UUID() returns a value that conforms +// to UUID version 1 as described in RFC 4122. The value is a 128-bit number represented as a utf8 string of five +// hexadecimal numbers in aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee format: + +// The first three numbers are generated from the low, middle, and high parts of a timestamp. The high part also includes +// the UUID version number. +// +// The fourth number preserves temporal uniqueness in case the timestamp value loses monotonicity +// (for example, due to daylight saving time). +// +// The fifth number is an IEEE 802 node number that provides spatial uniqueness. A random number is substituted if the +// latter is not available (for example, because the host device has no Ethernet card, or it is unknown how to find the +// hardware address of an interface on the host operating system). In this case, spatial uniqueness cannot be guaranteed. +// Nevertheless, a collision should have very low probability. +// +// The MAC address of an interface is taken into account only on FreeBSD, Linux, and Windows. On other operating systems, +// MySQL uses a randomly generated 48-bit number. +// https://dev.mysql.com/doc/refman/8.0/en/miscellaneous-functions.html#function_uuid + +type UUIDFunc struct{} + +var _ sql.FunctionExpression = &UUIDFunc{} + +func NewUUIDFunc() sql.Expression { + return UUIDFunc{} +} + +func (u UUIDFunc) String() string { + return "UUID()" +} + +func (u UUIDFunc) Type() sql.Type { + return sql.MustCreateStringWithDefaults(sqltypes.VarChar, 36) +} + +func (u UUIDFunc) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + nUUID, err := uuid.NewUUID() + if err != nil { + return nil, err + } + + return nUUID.String(), nil +} + +func (u UUIDFunc) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 0 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 0) + } + + return UUIDFunc{}, nil +} + +func (u UUIDFunc) FunctionName() string { + return "uuid" +} + +func (u UUIDFunc) Resolved() bool { + return true +} + +// Children returns the children expressions of this expression. +func (u UUIDFunc) Children() []sql.Expression { + return nil +} + +// IsNullable returns whether the expression can be null. +func (u UUIDFunc) IsNullable() bool { + return false +} + +// IS_UUID(string_uuid) +// +// Returns 1 if the argument is a valid string-format UUID, 0 if the argument is not a valid UUID, and NULL if the +// argument is NULL. +// +// “Valid” means that the value is in a format that can be parsed. That is, it has the correct length and contains only +// the permitted characters (hexadecimal digits in any lettercase and, optionally, dashes and curly braces). + +type IsUUID struct { + child sql.Expression +} + +var _ sql.FunctionExpression = &IsUUID{} + +func NewIsUUID(arg sql.Expression) sql.Expression { + return IsUUID{child: arg} +} + +func (u IsUUID) String() string { + return fmt.Sprintf("IS_UUID(%s)", u.child) +} + +func (u IsUUID) Type() sql.Type { + return sql.Int8 +} + +func (u IsUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := u.child.Eval(ctx, row) + if err != nil { + return 0, err + } + + if str == nil { + return nil, nil + } + + switch str := str.(type) { + case string: + _, err := uuid.Parse(str) + if err != nil { + return int8(0), nil + } + + return int8(1), nil + case []byte: + _, err := uuid.ParseBytes(str) + if err != nil { + return int8(0), nil + } + + return int8(1), nil + default: + return int8(0), nil + } +} + +func (u IsUUID) WithChildren(children ...sql.Expression) (sql.Expression, error) { + if len(children) != 1 { + return nil, sql.ErrInvalidChildrenNumber.New(u, len(children), 1) + } + + return IsUUID{child: children[0]}, nil +} + +func (u IsUUID) FunctionName() string { + return "is_uuid" +} + +func (u IsUUID) Resolved() bool { + return u.child.Resolved() +} + +// Children returns the children expressions of this expression. +func (u IsUUID) Children() []sql.Expression { + return []sql.Expression{u.child} +} + +// IsNullable returns whether the expression can be null. +func (u IsUUID) IsNullable() bool { + return false +} + +// UUID_TO_BIN(string_uuid), UUID_TO_BIN(string_uuid, swap_flag) +// +// Converts a string UUID to a binary UUID and returns the result. (The IS_UUID() function description lists the +// permitted string UUID formats.) The return binary UUID is a VARBINARY(16) value. If the UUID argument is NULL, +// the return value is NULL. If any argument is invalid, an error occurs. +// +// UUID_TO_BIN() takes one or two arguments: +// +// The one-argument form takes a string UUID value. The binary result is in the same order as the string argument. +// +// The two-argument form takes a string UUID value and a flag value: +// +// If swap_flag is 0, the two-argument form is equivalent to the one-argument form. The binary result is in the same +// order as the string argument. +// +// If swap_flag is 1, the format of the return value differs: The time-low and time-high parts (the first and third +// groups of hexadecimal digits, respectively) are swapped. This moves the more rapidly varying part to the right and +// can improve indexing efficiency if the result is stored in an indexed column. +// +// Time-part swapping assumes the use of UUID version 1 values, such as are generated by the UUID() function. For UUID +// values produced by other means that do not follow version 1 format, time-part swapping provides no benefit. For +// details about version 1 format, see the UUID() function description. + +type UUIDToBin struct { + inputUUID sql.Expression + swapFlag sql.Expression +} + +var _ sql.FunctionExpression = (*UUIDToBin)(nil) + +func NewUUIDToBin(args ...sql.Expression) (sql.Expression, error) { + switch len(args) { + case 1: + return UUIDToBin{inputUUID: args[0]}, nil + case 2: + return UUIDToBin{inputUUID: args[0], swapFlag: args[1]}, nil + default: + return nil, sql.ErrInvalidArgumentNumber.New("UUID_TO_BIN", "1 or 2", len(args)) + } +} + +func (ub UUIDToBin) String() string { + if ub.swapFlag != nil { + return fmt.Sprintf("UUID_TO_BIN(%s, %s)", ub.inputUUID, ub.swapFlag) + } else { + return fmt.Sprintf("UUID_TO_BIN(%s)", ub.inputUUID) + } +} + +func (ub UUIDToBin) Type() sql.Type { + return sql.MustCreateBinary(query.Type_VARBINARY, int64(16)) +} + +func (ub UUIDToBin) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := ub.inputUUID.Eval(ctx, row) + if err != nil { + return 0, err + } + + // Get the inputted uuid as a string. + converted, err := sql.LongText.Convert(str) + if err != nil { + return nil, err + } + + // If the UUID argument is NULL, the return value is NULL. + if converted == nil { + return nil, nil + } + + uuidAsStr, ok := converted.(string) + if !ok { + return nil, fmt.Errorf("invalid data format passed to UUID_TO_BIN") + } + + parsed, err := uuid.Parse(uuidAsStr) + if err != nil { + return nil, err + } + + // If no swap flag is passed we can return uuid's byte format as is. + if ub.swapFlag == nil { + bt, err := parsed.MarshalBinary() + if err != nil { + return nil, err + } + return string(bt), nil + } + + sf, err := ub.swapFlag.Eval(ctx, row) + if err != nil { + return nil, err + } + + sf, err = sql.Int8.Convert(sf) + if err != nil { + return nil, err + } + + // If the swap flag is 0 we can return uuid's byte format as is. + if sf == nil || sf.(int8) == 0 { + bt, err := parsed.MarshalBinary() + if err != nil { + return nil, err + } + + return string(bt), nil + } else if sf.(int8) == 1 { + encoding := swapUUIDBytes(parsed) + return string(encoding), nil + } else { + return nil, fmt.Errorf("UUID_TO_BIN received invalid swap flag") + } +} + +// swapUUIDBytes swaps the time-low and time-high parts (the first and third groups of hexadecimal digits, respectively) +func swapUUIDBytes(cur uuid.UUID) []byte { + ret := make([]byte, 16) + + copy(ret[0:2], cur[6:8]) + copy(ret[2:4], cur[4:6]) + copy(ret[4:8], cur[0:4]) + copy(ret[8:], cur[8:]) + + return ret +} + +func (ub UUIDToBin) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewUUIDToBin(children...) +} + +func (ub UUIDToBin) FunctionName() string { + return "uuid_to_bin" +} + +func (ub UUIDToBin) Resolved() bool { + return ub.inputUUID.Resolved() +} + +// Children returns the children expressions of this expression. +func (ub UUIDToBin) Children() []sql.Expression { + if ub.swapFlag == nil { + return []sql.Expression{ub.inputUUID} + } + + return []sql.Expression{ub.inputUUID, ub.swapFlag} +} + +// IsNullable returns whether the expression can be null. +func (ub UUIDToBin) IsNullable() bool { + return false +} + +// BIN_TO_UUID(binary_uuid), BIN_TO_UUID(binary_uuid, swap_flag) + +// BIN_TO_UUID() is the inverse of UUID_TO_BIN(). It converts a binary UUID to a string UUID and returns the result. +// The binary value should be a UUID as a VARBINARY(16) value. The return value is a utf8 string of five hexadecimal +// numbers separated by dashes. (For details about this format, see the UUID() function description.) If the UUID +// argument is NULL, the return value is NULL. If any argument is invalid, an error occurs. + +// BIN_TO_UUID() takes one or two arguments: + +// The one-argument form takes a binary UUID value. The UUID value is assumed not to have its time-low and time-high +// parts swapped. The string result is in the same order as the binary argument. + +// The two-argument form takes a binary UUID value and a swap-flag value: + +// If swap_flag is 0, the two-argument form is equivalent to the one-argument form. The string result is in the same +// order as the binary argument. + +// If swap_flag is 1, the UUID value is assumed to have its time-low and time-high parts swapped. These parts are +// swapped back to their original position in the result value. + +type BinToUUID struct { + inputBinary sql.Expression + swapFlag sql.Expression +} + +var _ sql.FunctionExpression = (*BinToUUID)(nil) + +func NewBinToUUID(args ...sql.Expression) (sql.Expression, error) { + switch len(args) { + case 1: + return BinToUUID{inputBinary: args[0]}, nil + case 2: + return BinToUUID{inputBinary: args[0], swapFlag: args[1]}, nil + default: + return nil, sql.ErrInvalidArgumentNumber.New("BIN_TO_UUID", "1 or 2", len(args)) + } +} + +func (bu BinToUUID) String() string { + if bu.swapFlag != nil { + return fmt.Sprintf("BIN_TO_UUID(%s, %s)", bu.inputBinary, bu.swapFlag) + } else { + return fmt.Sprintf("BIN_TO_UUID(%s)", bu.inputBinary) + } +} + +func (bu BinToUUID) Type() sql.Type { + return sql.MustCreateStringWithDefaults(sqltypes.VarChar, 36) +} + +func (bu BinToUUID) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) { + str, err := bu.inputBinary.Eval(ctx, row) + if err != nil { + return 0, err + } + + if str == nil { + return nil, nil + } + + // Get the inputted uuid as a string. + converted, err := sql.MustCreateBinary(query.Type_VARBINARY, int64(16)).Convert(str) + if err != nil { + return nil, err + } + + uuidAsByteString, ok := converted.(string) + if !ok { + return nil, fmt.Errorf("invalid data format passed to BIN_TO_UUID") + } + + asBytes := []byte(uuidAsByteString) + parsed, err := uuid.FromBytes(asBytes) + if err != nil { + return nil, err + } + + // If no swap flag is passed we can return uuid's string format as is. + if bu.swapFlag == nil { + return parsed.String(), nil + } + + sf, err := bu.swapFlag.Eval(ctx, row) + if err != nil { + return nil, err + } + + sf, err = sql.Int8.Convert(sf) + if err != nil { + return nil, err + } + + // If the swap flag is 0 we can return uuid's string format as is. + if sf.(int8) == 0 { + return parsed.String(), nil + } else if sf.(int8) == 1 { + encoding := unswapUUIDBytes(parsed) + parsed, err = uuid.FromBytes(encoding) + + if err != nil { + return nil, err + } + + return parsed.String(), nil + } else { + return nil, fmt.Errorf("UUID_TO_BIN received invalid swap flag") + } +} + +// unswapUUIDBytes unswaps the time-low and time-high parts (the third and first groups of hexadecimal digits, respectively) +func unswapUUIDBytes(cur uuid.UUID) []byte { + ret := make([]byte, 16) + + copy(ret[0:4], cur[4:8]) + copy(ret[4:6], cur[2:4]) + copy(ret[6:8], cur[0:2]) + copy(ret[8:], cur[8:]) + + return ret +} + +func (bu BinToUUID) WithChildren(children ...sql.Expression) (sql.Expression, error) { + return NewBinToUUID(children...) +} + +func (bu BinToUUID) FunctionName() string { + return "bin_to_uuid" +} + +func (bu BinToUUID) Resolved() bool { + return bu.inputBinary.Resolved() +} + +// Children returns the children expressions of this expression. +func (bu BinToUUID) Children() []sql.Expression { + if bu.swapFlag == nil { + return []sql.Expression{bu.inputBinary} + } + + return []sql.Expression{bu.inputBinary, bu.swapFlag} +} + +// IsNullable returns whether the expression can be null. +func (bu BinToUUID) IsNullable() bool { + return false +} diff --git a/sql/expression/function/uuid_test.go b/sql/expression/function/uuid_test.go new file mode 100644 index 0000000000..c9a0970ed2 --- /dev/null +++ b/sql/expression/function/uuid_test.go @@ -0,0 +1,216 @@ +// Copyright 2021 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package function + +import ( + "regexp" + "testing" + + "github.com/dolthub/vitess/go/vt/proto/query" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/go-mysql-server/sql/expression" +) + +func TestUUID(t *testing.T) { + // Generate a UUID and validate that is a legitimate uuid + uuidE := NewUUIDFunc() + ctx := sql.NewEmptyContext() + + result, err := uuidE.Eval(ctx, sql.Row{nil}) + require.NoError(t, err) + + myUUID := result.(string) + _, err = uuid.Parse(myUUID) + require.NoError(t, err) + + // validate that generated uuid is legitimate for IsUUID + val := NewIsUUID(uuidE) + require.Equal(t, int8(1), eval(t, val, sql.Row{nil})) + + // Use a UUID regex as a sanity check + re2 := regexp.MustCompile(`\b[0-9a-f]{8}\b-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-\b[0-9a-f]{12}\b`) + require.True(t, re2.MatchString(myUUID)) +} + +func TestIsUUID(t *testing.T) { + testCases := []struct { + name string + rowType sql.Type + value interface{} + expected interface{} + }{ + {"uuid form 1", sql.LongText, "{12345678-1234-5678-1234-567812345678}", int8(1)}, + {"uuid form 2", sql.LongText, "12345678123456781234567812345678", int8(1)}, + {"uuid form 3", sql.LongText, "12345678-1234-5678-1234-567812345678", int8(1)}, + {"NULL", sql.Null, nil, nil}, + {"random int", sql.Int8, 1, int8(0)}, + {"random bool", sql.Boolean, false, int8(0)}, + {"random string", sql.LongText, "12345678-dasd-fasdf8", int8(0)}, + {"swapped uuid", sql.LongText, "5678-1234-12345678-1234-567812345678", int8(0)}, + } + + for _, tt := range testCases { + f := NewIsUUID(expression.NewLiteral(tt.value, tt.rowType)) + + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, eval(t, f, sql.Row{nil})) + }) + + req := require.New(t) + req.False(f.IsNullable()) + } +} + +func TestUUIDToBinValid(t *testing.T) { + validTestCases := []struct { + name string + uuidType sql.Type + uuid interface{} + hasSwap bool + swapType sql.Type + swapValue interface{} + expected interface{} + }{ + {"valid uuid; swap=0", sql.LongText, "6ccd780c-baba-1026-9564-5b8c656024db", true, sql.Int8, int8(0), "6CCD780CBABA102695645B8C656024DB"}, + {"valid uuid; swap=nil", sql.LongText, "6ccd780c-baba-1026-9564-5b8c656024db", true, sql.Null, nil, "6CCD780CBABA102695645B8C656024DB"}, + {"valid uuid; swap=1", sql.LongText, "6ccd780c-baba-1026-9564-5b8c656024db", true, sql.Int8, int8(1), "1026BABA6CCD780C95645B8C656024DB"}, + {"valid uuid; no swap", sql.LongText, "6ccd780c-baba-1026-9564-5b8c656024db", false, nil, nil, "6CCD780CBABA102695645B8C656024DB"}, + {"null uuid; no swap", sql.Null, nil, false, nil, nil, nil}, + } + + for _, tt := range validTestCases { + var f sql.Expression + var err error + + if tt.hasSwap { + f, err = NewUUIDToBin(expression.NewLiteral(tt.uuid, tt.uuidType), expression.NewLiteral(tt.swapValue, tt.swapType)) + } else { + f, err = NewUUIDToBin(expression.NewLiteral(tt.uuid, tt.uuidType)) + } + + require.NoError(t, err) + + // Convert to hex to make testing easier + h := NewHex(f) + + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, eval(t, h, sql.Row{nil})) + }) + + req := require.New(t) + req.False(f.IsNullable()) + } +} + +func TestUUIDToBinFailing(t *testing.T) { + failingTestCases := []struct { + name string + uuidType sql.Type + uuid interface{} + swapType sql.Type + swapValue interface{} + }{ + {"bad swap value", sql.LongText, "6ccd780c-baba-1026-9564-5b8c656024db", sql.Int8, int8(2)}, + {"bad uuid value", sql.LongText, "sdasdsad", sql.Int8, int8(0)}, + {"bad uuid value2", sql.Int8, int8(0), sql.Int8, int8(0)}, + } + + for _, tt := range failingTestCases { + f, err := NewUUIDToBin(expression.NewLiteral(tt.uuid, tt.uuidType), expression.NewLiteral(tt.swapValue, tt.swapType)) + require.NoError(t, err) + + t.Run(tt.name, func(t *testing.T) { + ctx := sql.NewEmptyContext() + _, err := f.Eval(ctx, sql.Row{nil}) + require.Error(t, err) + }) + } +} + +func TestBinToUUID(t *testing.T) { + // Test that UUID_TO_BIN to BIN_TO_UUID is reflexive + uuidE := eval(t, NewUUIDFunc(), sql.Row{nil}) + + f, err := NewUUIDToBin(expression.NewLiteral(uuidE, sql.LongText)) + require.NoError(t, err) + + retUUID, err := NewBinToUUID(f) + require.NoError(t, err) + + require.Equal(t, uuidE, eval(t, retUUID, sql.Row{nil})) + + // Run UUID_TO_BIN through a series of test cases. + validTestCases := []struct { + name string + uuidType sql.Type + binary interface{} + hasSwap bool + swapType sql.Type + swapValue interface{} + expected interface{} + }{ + {"valid uuid; swap=0", sql.MustCreateBinary(query.Type_VARBINARY, int64(16)), []byte("lxºº & d[e`$Û"), true, sql.Int8, int8(0), "6c78c2ba-c2ba-2026-2064-5b656024c39b"}, + {"valid uuid; swap=1", sql.MustCreateBinary(query.Type_VARBINARY, int64(16)), []byte("&ººlÍxd[e`$Û"), true, sql.Int8, int8(1), "ba6cc38d-bac2-26c2-7864-5b656024c39b"}, + {"valid uuid; no swap", sql.MustCreateBinary(query.Type_VARBINARY, int64(16)), []byte("lxºº & d[e`$Û"), false, nil, nil, "6c78c2ba-c2ba-2026-2064-5b656024c39b"}, + {"null input", sql.Null, nil, false, nil, nil, nil}, + } + + for _, tt := range validTestCases { + var f sql.Expression + var err error + + if tt.hasSwap { + f, err = NewBinToUUID(expression.NewLiteral(tt.binary, tt.uuidType), expression.NewLiteral(tt.swapValue, tt.swapType)) + } else { + f, err = NewBinToUUID(expression.NewLiteral(tt.binary, tt.uuidType)) + } + require.NoError(t, err) + + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, eval(t, f, sql.Row{nil})) + }) + + req := require.New(t) + req.False(f.IsNullable()) + } +} + +func TestBinToUUIDFailing(t *testing.T) { + failingTestCases := []struct { + name string + uuidType sql.Type + uuid interface{} + swapType sql.Type + swapValue interface{} + }{ + {"bad swap value", sql.MustCreateBinary(query.Type_VARBINARY, int64(16)), "helo", sql.Int8, int8(2)}, + {"bad binary value", sql.MustCreateBinary(query.Type_VARBINARY, int64(16)), "sdasdsad", sql.Int8, int8(0)}, + {"bad input value", sql.Int8, int8(0), sql.Int8, int8(0)}, + } + + for _, tt := range failingTestCases { + f, err := NewBinToUUID(expression.NewLiteral(tt.uuid, tt.uuidType), expression.NewLiteral(tt.swapValue, tt.swapType)) + require.NoError(t, err) + + t.Run(tt.name, func(t *testing.T) { + ctx := sql.NewEmptyContext() + _, err := f.Eval(ctx, sql.Row{nil}) + require.Error(t, err) + }) + } +}