Skip to content

Commit

Permalink
feat: more versatile column matcher
Browse files Browse the repository at this point in the history
  • Loading branch information
licaonfee committed Aug 16, 2020
1 parent 109b081 commit 4f75134
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 69 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ deps:

utest:
$(GOTEST) -race -count 1 -timeout 30s -coverprofile coverage.out ./...
cover:
cover: utest
$(GOTOOL) cover -func=coverage.out

clean:
Expand Down
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ func main() {
m := magiccol.DefaultMapper()
//Use mysql native Time type see
//https://github.com/go-sql-driver/mysql#timetime-support
m.Type(reflect.TypeOf(mysql.NullTime{}), "DATE", "DATETIME", "TIMESTAMP")
custom := reflect.TypeOf(mysql.NullTime{})
match := []magiccol.Matcher{
magiccol.DatabaseTypeAs("DATE", custom),
magiccol.DatabaseTypeAs("DATETIME", custom),
magiccol.DatabaseTypeAs("TIMESTAMP", custom),
}
m.Match(match)

sc, err := magiccol.NewScanner(magiccol.Options{Rows:r, Mapper: m})
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions magiccol.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ type Options struct {
//Rows must be a valid sql.Rows object
Rows Rows
//Mapper can be nil, if so DefaultMapper is used
Mapper Mapper
Mapper *Mapper
}

//Rows allow to mock sql.Rows object
Expand Down Expand Up @@ -56,7 +56,7 @@ func NewScanner(o Options) (*Scanner, error) {
values := make([]reflect.Value, len(cols))
for i := 0; i < len(cols); i++ {
t := tp[i]
refType := o.Mapper.Get(t.DatabaseTypeName(), t.ScanType())
refType := o.Mapper.Get(t)
v := reflect.New(refType)
pointers[i] = v.Interface()
values[i] = v
Expand Down
48 changes: 31 additions & 17 deletions magiccol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package magiccol_test

import (
"database/sql"
"database/sql/driver"
"errors"
"reflect"
"strings"
"testing"

"github.com/DATA-DOG/go-sqlmock"
Expand Down Expand Up @@ -81,41 +81,55 @@ func TestScan(t *testing.T) {
rowError := errors.New("row error")
tests := []struct {
name string
csvData string
rows [][]driver.Value
columns []*sqlmock.Column
want []map[string]interface{}
wantErr error
errorAt int
}{
{
name: "success",
columns: []*sqlmock.Column{
sqlmock.NewColumn("name").OfType("VARCHAR", ""),
sqlmock.NewColumn("age").OfType("INTEGER", int64(0)),
},
csvData: strings.Join([]string{`"jhon",35`, `"jeremy",29`}, "\n"),
rows: [][]driver.Value{
{"jhon", 35},
{"jeremy", 29},
},
want: []map[string]interface{}{
{"name": "jhon", "age": int64(35)},
{"name": "jeremy", "age": int64(29)},
},
wantErr: nil,
errorAt: -1,
},
{
name: "Rows error",
columns: []*sqlmock.Column{
sqlmock.NewColumn("name").OfType("VARCHAR", ""),
sqlmock.NewColumn("address").OfType("VARCHAR", ""),
},
rows: [][]driver.Value{
{"jeimy", "oak"},
{"jhon", "jhonson"},
},
csvData: `"jeimy"`,
want: nil,
wantErr: rowError,
errorAt: 1,
},
{
name: "Scan error",
columns: []*sqlmock.Column{
sqlmock.NewColumn("name").OfType("INTEGER", int64(0)),
sqlmock.NewColumn("id").OfType("INTEGER", int64(0)),
sqlmock.NewColumn("moto").OfType("VARCHAR", ""),
},
csvData: `"jeimy"`,
want: nil,
wantErr: errors.New(""),
rows: [][]driver.Value{
{11, "foo"},
{"invalidata", "bar"}},
want: []map[string]interface{}{},
wantErr: errors.New("sss"),
errorAt: -1,
},
}
for _, tt := range tests {
Expand All @@ -125,10 +139,12 @@ func TestScan(t *testing.T) {
t.Error(err)
}
r := mock.NewRowsWithColumnDefinition(tt.columns...)
r.FromCSVString(tt.csvData)
for i := 0; i < len(tt.rows); i++ {
r.AddRow(tt.rows[i]...)
}
mock.ExpectQuery("SELECT").WillReturnRows(r)
if tt.wantErr != nil {
r.RowError(0, tt.wantErr)
if tt.wantErr != nil && tt.errorAt >= 0 {
r.RowError(tt.errorAt, tt.wantErr)
}
rows, _ := db.Query("SELECT")
m, err := magiccol.NewScanner(magiccol.Options{Rows: rows})
Expand All @@ -140,15 +156,13 @@ func TestScan(t *testing.T) {
got = append(got, m.Value())
}
if m.Err() != tt.wantErr {
e := errors.New("")
if !(m.Err() != nil && tt.wantErr != nil) {
t.Errorf("Scan() err = %v , want = %v", m.Err(), tt.wantErr)
}
if !errors.As(m.Err(), &e) || m.Err().Error() != tt.wantErr.Error() {
var e error
if !errors.As(m.Err(), &e) {
t.Errorf("Scan() err = %v , want = %v", m.Err(), tt.wantErr)
}
}
if m.Err() != nil {

if m.Err() != nil && tt.wantErr != nil {
return
}
if !reflect.DeepEqual(got, tt.want) {
Expand Down
71 changes: 52 additions & 19 deletions mappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@ import (
"time"
)

//Mapper translate sql types to golang types
type Mapper interface {
//Get typeName should be sql type as is called in sql.ColumnType.DatabaseTypeName()
Get(typeName string, fallback reflect.Type) reflect.Type
//Type allow to set alias, extends or fix mapper behaviour
Type(t reflect.Type, asTypes ...string)
//ColumnType is identical as defined in sql.ColumnType struct
type ColumnType interface {
Name() string
DatabaseTypeName() string
ScanType() reflect.Type
Nullable() (nullable bool, ok bool)
DecimalSize() (precision int64, scale int64, ok bool)
Length() (length int64, ok bool)
}

//Matcher return a type and true if column definition match
//on a negative match reflect.Type should be null but is not mandatory
type Matcher func(ColumnType) (reflect.Type, bool)

var (
stringType = reflect.TypeOf("")
intType = reflect.TypeOf(int64(0))
Expand All @@ -24,32 +30,59 @@ var (
durationType = reflect.TypeOf(time.Duration(0))
)

//LookupMapper implements Mapper interface
type LookupMapper struct {
m map[string]reflect.Type
//Mapper translate sql types to golang types
type Mapper struct {
m map[string]reflect.Type
match []Matcher
}

//Get do a map lookup if type is not found return a ScanType itself
func (l LookupMapper) Get(typeName string, fallback reflect.Type) reflect.Type {
t, ok := l.m[typeName]
func (l *Mapper) Get(col ColumnType) reflect.Type {
for _, m := range l.match {
t, ok := m(col)
if ok {
return t
}
}
t, ok := l.m[col.DatabaseTypeName()]
if !ok {
return fallback
return col.ScanType()
}
return t
}

//Type method allow to set custom types as scanneable types
func (l *LookupMapper) Type(t reflect.Type, asType ...string) {
for _, x := range asType {
tp := x
l.m[tp] = t
//Match method allow to set custom types as scanneable types
//if m is nil then is a no-op
func (l *Mapper) Match(m ...Matcher) {
for i := 0; i < len(m); i++ {
if m[i] != nil {
l.match = append(l.match, m[i])
}
}
}

func DatabaseTypeAs(databaseTypeName string, t reflect.Type) Matcher {
return func(col ColumnType) (reflect.Type, bool) {
if col.DatabaseTypeName() == databaseTypeName {
return t, true
}
return nil, false
}
}

func ColumnNameAs(columnName string, t reflect.Type) Matcher {
return func(col ColumnType) (reflect.Type, bool) {
if col.Name() == columnName {
return t, true
}
return nil, false
}
}

//DefaultMapper provides a mapping for most common sql types
//type list reference used is:
//http://jakewheat.github.io/sql-overview/sql-2011-foundation-grammar.html#predefined-type
func DefaultMapper() Mapper {
func DefaultMapper() *Mapper {
m := map[string]reflect.Type{
//Character types
"CHARACTER": stringType,
Expand Down Expand Up @@ -96,5 +129,5 @@ func DefaultMapper() Mapper {
//Interval type
"INTERVAL": durationType,
}
return &LookupMapper{m: m}
return &Mapper{m: m}
}
89 changes: 60 additions & 29 deletions mappers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,63 +7,94 @@ import (
"github.com/licaonfee/magiccol"
)

type col struct {
name string
databaseType string
scanType reflect.Type
}

func (c col) Name() string { return c.name }
func (c col) DatabaseTypeName() string { return c.databaseType }
func (c col) ScanType() reflect.Type { return c.scanType }
func (c col) Nullable() (nullable bool, ok bool) { return false, false }
func (c col) DecimalSize() (precision int64, scale int64, ok bool) { return 0, 0, false }
func (c col) Length() (length int64, ok bool) { return 0, false }

func TestMapperGet(t *testing.T) {
tests := []struct {
name string
typeName string
fallback reflect.Type
want reflect.Type
name string
col magiccol.ColumnType
match []magiccol.Matcher
want reflect.Type
}{
{
name: "char type",
typeName: "VARCHAR",
fallback: nil,
want: reflect.TypeOf(""),
name: "char type",
col: col{databaseType: "VARCHAR"},
want: reflect.TypeOf(""),
},
{
name: "fallback",
col: col{databaseType: "MISSING TYPE", scanType: reflect.TypeOf(uint8(0))},
want: reflect.TypeOf(uint8(0)),
},
{
name: "fallback",
typeName: "MISSING TYPE",
fallback: reflect.TypeOf(uint8(0)),
want: reflect.TypeOf(uint8(0)),
name: "match priority",
col: col{databaseType: "INTEGER", scanType: reflect.TypeOf(uint64(0))},
match: []magiccol.Matcher{
func(c magiccol.ColumnType) (reflect.Type, bool) {
return reflect.TypeOf(float64(0.0)), true
}},
want: reflect.TypeOf(float64(0.0)),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := magiccol.DefaultMapper()
got := m.Get(tt.typeName, tt.fallback)
for i := 0; i < len(tt.match); i++ {
m.Match(tt.match[i])
}
got := m.Get(tt.col)
if got != tt.want {
t.Errorf("Get() got = %v , want = %v", got, tt.want)
}
})
}
}

func TestMapperType(t *testing.T) {
func TestMatchers(t *testing.T) {
tests := []struct {
name string
rType reflect.Type
as []string
name string
positive magiccol.ColumnType
negative magiccol.ColumnType
m magiccol.Matcher
want reflect.Type
}{
{
name: "new type",
rType: reflect.TypeOf(false),
as: []string{"MY_BOOL", "FLAGTYPE"},
name: "DatabaseTypeAs",
positive: col{databaseType: "EXOTIC"},
negative: col{databaseType: "DOUBLE"},
m: magiccol.DatabaseTypeAs("EXOTIC", reflect.TypeOf(float32(0))),
want: reflect.TypeOf(float32(0)),
},
{
name: "overwrite",
rType: reflect.TypeOf(int64(0)),
as: []string{"VARCHAR"},
name: "ColumnNameAs",
positive: col{name: "delimited_data"},
negative: col{name: "name"},
m: magiccol.ColumnNameAs("delimited_data", reflect.TypeOf([]string{})),
want: reflect.TypeOf([]string{}),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
m := magiccol.DefaultMapper()
m.Type(tt.rType, tt.as...)
for _, n := range tt.as {
if tp := m.Get(n, nil); tp != tt.rType {
t.Errorf("Type(%s) not set %v", n, tt.rType)
}
got, ok := tt.m(tt.positive)
if !ok || got != tt.want {
t.Errorf("Matcher() got = (%v,%v) , want = (%v, %v)", got, ok, tt.want, true)
}
got, ok = tt.m(tt.negative)
if ok || got != nil {
t.Errorf("Matcher() got = (%v,%v) , want = (%v, %v)", got, ok, nil, false)
}

})
}
}

0 comments on commit 4f75134

Please sign in to comment.