diff --git a/CHANGELOG.md b/CHANGELOG.md index f1868ac..b524115 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +## [5.20.0] - 2023-06-17 +### Added +- Expanded Option type SQL Value support to handle value custom types and honour the `driver.Valuer` interface. + +### Changed +- Option sql.Scanner to support custom types. + ## [5.19.0] - 2023-06-14 ### Added - strconvext.ParseBool(...) which is a drop-in replacement for the std lin strconv.ParseBool(..) with a few more supported values. @@ -55,7 +62,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added `timext.NanoTime` for fast low level monotonic time with nanosecond precision. -[Unreleased]: https://github.com/go-playground/pkg/compare/v5.19.0...HEAD +[Unreleased]: https://github.com/go-playground/pkg/compare/v5.20.0...HEAD +[5.20.0]: https://github.com/go-playground/pkg/compare/v5.19.0..v5.20.0 [5.19.0]: https://github.com/go-playground/pkg/compare/v5.18.0..v5.19.0 [5.18.0]: https://github.com/go-playground/pkg/compare/v5.17.2..v5.18.0 [5.17.2]: https://github.com/go-playground/pkg/compare/v5.17.1..v5.17.2 diff --git a/README.md b/README.md index d6a0445..b0e92f1 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ # pkg -![Project status](https://img.shields.io/badge/version-5.19.0-green.svg) +![Project status](https://img.shields.io/badge/version-5.20.0-green.svg) [![Lint & Test](https://github.com/go-playground/pkg/actions/workflows/go.yml/badge.svg)](https://github.com/go-playground/pkg/actions/workflows/go.yml) [![Coverage Status](https://coveralls.io/repos/github/go-playground/pkg/badge.svg?branch=master)](https://coveralls.io/github/go-playground/pkg?branch=master) [![GoDoc](https://godoc.org/github.com/go-playground/pkg?status.svg)](https://pkg.go.dev/mod/github.com/go-playground/pkg/v5) diff --git a/values/option/option.go b/values/option/option.go index ddafb6a..68f740a 100644 --- a/values/option/option.go +++ b/values/option/option.go @@ -15,6 +15,12 @@ import ( var ( scanType = reflect.TypeOf((*sql.Scanner)(nil)).Elem() byteSliceType = reflect.TypeOf(([]byte)(nil)) + valuerType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + timeType = reflect.TypeOf((*time.Time)(nil)).Elem() + stringType = reflect.TypeOf((*string)(nil)).Elem() + int64Type = reflect.TypeOf((*int64)(nil)).Elem() + float64Type = reflect.TypeOf((*float64)(nil)).Elem() + boolType = reflect.TypeOf((*bool)(nil)).Elem() ) // Option represents a values that represents a values existence. @@ -97,11 +103,43 @@ func (o *Option[T]) UnmarshalJSON(data []byte) error { } // Value implements the driver.Valuer interface. +// +// This honours the `driver.Valuer` interface if the value implements it. +// It also supports custom types of the std types and treats all else as []byte/ func (o Option[T]) Value() (driver.Value, error) { - if o.isSome { - return o.Unwrap(), nil + if o.IsNone() { + return nil, nil + } + value := o.Unwrap() + val := reflect.ValueOf(value) + + if val.Type().Implements(valuerType) { + return val.Interface().(driver.Valuer).Value() + } + switch val.Kind() { + case reflect.String: + return val.Convert(stringType).Interface(), nil + case reflect.Bool: + return val.Convert(boolType).Interface(), nil + case reflect.Int64: + return val.Convert(int64Type).Interface(), nil + case reflect.Float64: + return val.Convert(float64Type).Interface(), nil + case reflect.Slice, reflect.Array: + if val.Type().ConvertibleTo(byteSliceType) { + return val.Convert(byteSliceType).Interface(), nil + } + return json.Marshal(val.Interface()) + case reflect.Struct: + if val.CanConvert(timeType) { + return val.Convert(timeType).Interface(), nil + } + return json.Marshal(val.Interface()) + case reflect.Map: + return json.Marshal(val.Interface()) + default: + return val.Interface(), nil } - return nil, nil } // Scan implements the sql.Scanner interface. @@ -130,68 +168,68 @@ func (o *Option[T]) Scan(value any) error { if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.String).Interface().(T)) + *o = Some(reflect.ValueOf(v.String).Convert(val.Type()).Interface().(T)) case reflect.Bool: var v sql.NullBool if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Bool).Interface().(T)) + *o = Some(reflect.ValueOf(v.Bool).Convert(val.Type()).Interface().(T)) case reflect.Uint8: var v sql.NullByte if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Byte).Interface().(T)) + *o = Some(reflect.ValueOf(v.Byte).Convert(val.Type()).Interface().(T)) case reflect.Float64: var v sql.NullFloat64 if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Float64).Interface().(T)) + *o = Some(reflect.ValueOf(v.Float64).Convert(val.Type()).Interface().(T)) case reflect.Int16: var v sql.NullInt16 if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Int16).Interface().(T)) + *o = Some(reflect.ValueOf(v.Int16).Convert(val.Type()).Interface().(T)) case reflect.Int32: var v sql.NullInt32 if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Int32).Interface().(T)) + *o = Some(reflect.ValueOf(v.Int32).Convert(val.Type()).Interface().(T)) case reflect.Int64: var v sql.NullInt64 if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Int64).Interface().(T)) + *o = Some(reflect.ValueOf(v.Int64).Convert(val.Type()).Interface().(T)) case reflect.Interface: - *o = Some(reflect.ValueOf(value).Interface().(T)) + *o = Some(reflect.ValueOf(value).Convert(val.Type()).Interface().(T)) case reflect.Struct: - if val.Type() == reflect.TypeOf(time.Time{}) { + if val.CanConvert(timeType) { switch t := value.(type) { case string: tm, err := time.Parse(time.RFC3339Nano, t) if err != nil { return err } - *o = Some(reflect.ValueOf(tm).Interface().(T)) + *o = Some(reflect.ValueOf(tm).Convert(val.Type()).Interface().(T)) case []byte: tm, err := time.Parse(time.RFC3339Nano, string(t)) if err != nil { return err } - *o = Some(reflect.ValueOf(tm).Interface().(T)) + *o = Some(reflect.ValueOf(tm).Convert(val.Type()).Interface().(T)) default: var v sql.NullTime if err := v.Scan(value); err != nil { return err } - *o = Some(reflect.ValueOf(v.Time).Interface().(T)) + *o = Some(reflect.ValueOf(v.Time).Convert(val.Type()).Interface().(T)) } return nil } diff --git a/values/option/option_test.go b/values/option/option_test.go index 0d213f6..3b6cdf3 100644 --- a/values/option/option_test.go +++ b/values/option/option_test.go @@ -4,13 +4,163 @@ package optionext import ( + "database/sql/driver" "encoding/json" + "reflect" "testing" "time" . "github.com/go-playground/assert/v2" ) +type valueTest struct { +} + +func (valueTest) Value() (driver.Value, error) { + return "value", nil +} + +type customStringType string + +type testStructType struct { + Name string +} + +func TestSQLDriverValue(t *testing.T) { + + var v valueTest + Equal(t, reflect.TypeOf(v).Implements(valuerType), true) + + // none + nOpt := None[string]() + nVal, err := nOpt.Value() + Equal(t, err, nil) + Equal(t, nVal, nil) + + // string + convert custom string type + sOpt := Some("myString") + sVal, err := sOpt.Value() + Equal(t, err, nil) + + _, ok := sVal.(string) + Equal(t, ok, true) + Equal(t, sVal, "myString") + + sCustOpt := Some(customStringType("string")) + sCustVal, err := sCustOpt.Value() + Equal(t, err, nil) + Equal(t, sCustVal, "string") + + _, ok = sCustVal.(string) + Equal(t, ok, true) + + // bool + bOpt := Some(true) + bVal, err := bOpt.Value() + Equal(t, err, nil) + + _, ok = bVal.(bool) + Equal(t, ok, true) + Equal(t, bVal, true) + + // int64 + iOpt := Some(int64(2)) + iVal, err := iOpt.Value() + Equal(t, err, nil) + + _, ok = iVal.(int64) + Equal(t, ok, true) + Equal(t, iVal, int64(2)) + + // float64 + fOpt := Some(1.1) + fVal, err := fOpt.Value() + Equal(t, err, nil) + + _, ok = fVal.(float64) + Equal(t, ok, true) + Equal(t, fVal, 1.1) + + // time.Time + dt := time.Now().UTC() + dtOpt := Some(dt) + dtVal, err := dtOpt.Value() + Equal(t, err, nil) + + _, ok = dtVal.(time.Time) + Equal(t, ok, true) + Equal(t, dtVal, dt) + + // Slice []byte + b := []byte("myBytes") + bytesOpt := Some(b) + bytesVal, err := bytesOpt.Value() + Equal(t, err, nil) + + _, ok = bytesVal.([]byte) + Equal(t, ok, true) + Equal(t, bytesVal, b) + + // Slice []uint8 + b2 := []uint8("myBytes") + bytes2Opt := Some(b2) + bytes2Val, err := bytes2Opt.Value() + Equal(t, err, nil) + + _, ok = bytes2Val.([]byte) + Equal(t, ok, true) + Equal(t, bytes2Val, b2) + + // Array []byte + a := []byte{'1', '2', '3'} + arrayOpt := Some(a) + arrayVal, err := arrayOpt.Value() + Equal(t, err, nil) + + _, ok = arrayVal.([]byte) + Equal(t, ok, true) + Equal(t, arrayVal, a) + + // Slice []byte + data := []testStructType{{Name: "test"}} + b, err = json.Marshal(data) + Equal(t, err, nil) + + dataOpt := Some(data) + dataVal, err := dataOpt.Value() + Equal(t, err, nil) + + _, ok = dataVal.([]byte) + Equal(t, ok, true) + Equal(t, dataVal, b) + + // Map + data2 := map[string]int{"test": 1} + b, err = json.Marshal(data2) + Equal(t, err, nil) + + data2Opt := Some(data2) + data2Val, err := data2Opt.Value() + Equal(t, err, nil) + + _, ok = data2Val.([]byte) + Equal(t, ok, true) + Equal(t, data2Val, b) + + // Struct + data3 := testStructType{Name: "test"} + b, err = json.Marshal(data3) + Equal(t, err, nil) + + data3Opt := Some(data3) + data3Val, err := data3Opt.Value() + Equal(t, err, nil) + + _, ok = data3Val.([]byte) + Equal(t, ok, true) + Equal(t, data3Val, b) +} + type customScanner struct { S string } @@ -20,7 +170,7 @@ func (c *customScanner) Scan(src interface{}) error { return nil } -func TestSQL(t *testing.T) { +func TestSQLScanner(t *testing.T) { value := int64(123) var optionI64 Option[int64] var optionI32 Option[int32] @@ -115,6 +265,12 @@ func TestSQL(t *testing.T) { err = optionMap.Scan([]byte(`{"name":"test"}`)) Equal(t, err, nil) Equal(t, optionMap, Some(map[string]any{"name": "test"})) + + // test custom types + var ct Option[customStringType] + err = ct.Scan("test") + Equal(t, err, nil) + Equal(t, ct, Some(customStringType("test"))) } func TestNilOption(t *testing.T) {