Skip to content

Commit

Permalink
Fix panics when array ptr is used for InsertAll and DeleteAll (#250)
Browse files Browse the repository at this point in the history
* Fix panics when array ptr is used for InsertAll and DeleteAll

* update test function name

* rename function name
  • Loading branch information
Fs02 committed Nov 5, 2021
1 parent e534415 commit 6aa81d4
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 25 deletions.
4 changes: 2 additions & 2 deletions association.go
Expand Up @@ -131,7 +131,7 @@ func (a Association) ReferenceField() string {

// ReferenceValue of the association.
func (a Association) ReferenceValue() interface{} {
return indirect(a.rv.Field(a.data.referenceIndex))
return indirectInterface(a.rv.Field(a.data.referenceIndex))
}

// ForeignField of the association.
Expand All @@ -154,7 +154,7 @@ func (a Association) ForeignValue() interface{} {
rv = rv.Elem()
}

return indirect(rv.Field(a.data.foreignIndex))
return indirectInterface(rv.Field(a.data.foreignIndex))
}

// Through return intermediary association.
Expand Down
16 changes: 10 additions & 6 deletions collection.go
Expand Up @@ -32,7 +32,7 @@ func (c *Collection) Table() string {
return tn.Table()
}

return tableName(c.rt.Elem())
return tableName(indirectReflectType(c.rt.Elem()))
}

// PrimaryFields column name of this collection.
Expand Down Expand Up @@ -77,7 +77,7 @@ func (c Collection) PrimaryValues() []interface{} {
)

for j := range values {
values[j] = c.rv.Index(j).Field(index[i]).Interface()
values[j] = c.rvIndex(j).Field(index[i]).Interface()
}

pValues[i] = values
Expand All @@ -89,7 +89,7 @@ func (c Collection) PrimaryValues() []interface{} {
)

for i := 0; i < c.rv.Len(); i++ {
for j, id := range c.rv.Index(i).Interface().(primary).PrimaryValues() {
for j, id := range c.rvIndex(i).Interface().(primary).PrimaryValues() {
tmp[j] = append(tmp[j], id)
}
}
Expand All @@ -112,9 +112,13 @@ func (c Collection) PrimaryValue() interface{} {
panic("rel: composite primary key is not supported")
}

func (c Collection) rvIndex(index int) reflect.Value {
return reflect.Indirect(c.rv.Index(index))
}

// Get an element from the underlying slice as a document.
func (c Collection) Get(index int) *Document {
return NewDocument(c.rv.Index(index).Addr())
return NewDocument(c.rvIndex(index).Addr())
}

// Len of the underlying slice.
Expand All @@ -141,7 +145,7 @@ func (c Collection) Add() *Document {

c.rv.Set(reflect.Append(c.rv, drv))

return NewDocument(c.rv.Index(index).Addr())
return NewDocument(c.rvIndex(index).Addr())
}

// Truncate collection.
Expand Down Expand Up @@ -202,6 +206,6 @@ func newCollection(v interface{}, rv reflect.Value, readonly bool) *Collection {
v: v,
rv: rv,
rt: rt,
data: extractDocumentData(rt.Elem(), false),
data: extractDocumentData(indirectReflectType(rt.Elem()), false),
}
}
16 changes: 1 addition & 15 deletions document.go
Expand Up @@ -276,15 +276,10 @@ func (d Document) Scanners(fields []string) []interface{} {
result = make([]interface{}, len(fields))
)

val := d.rv
if val.Kind() == reflect.Ptr {
val = reflect.Indirect(val)
}

for index, field := range fields {
if structIndex, ok := d.data.index[field]; ok {
var (
fv = val.Field(structIndex)
fv = d.rv.Field(structIndex)
ft = fv.Type()
)

Expand Down Expand Up @@ -393,9 +388,6 @@ func newDocument(v interface{}, rv reflect.Value, readonly bool) *Document {
}
rt = rt.Elem()
}
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
}

if rt.Kind() != reflect.Struct {
panic("rel: must be a struct or pointer to a struct")
Expand All @@ -410,9 +402,6 @@ func newDocument(v interface{}, rv reflect.Value, readonly bool) *Document {
}

func extractDocumentData(rt reflect.Type, skipAssoc bool) documentData {
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
}
if data, cached := documentDataCache.Load(rt); cached {
return data.(documentData)
}
Expand Down Expand Up @@ -583,9 +572,6 @@ func searchPrimary(rt reflect.Type) ([]string, []int) {
}

func tableName(rt reflect.Type) string {
if rt.Kind() == reflect.Ptr {
rt = rt.Elem()
}
// check for cache
if name, cached := tablesCache.Load(rt); cached {
return name.(string)
Expand Down
51 changes: 50 additions & 1 deletion repository_test.go
Expand Up @@ -415,7 +415,7 @@ func TestRepository_FindAll(t *testing.T) {
cur.AssertExpectations(t)
}

func TestRepository_FindAllPointer(t *testing.T) {
func TestRepository_FindAll_ptrElem(t *testing.T) {
var (
users []*User
adapter = &testAdapter{}
Expand Down Expand Up @@ -1215,6 +1215,41 @@ func TestRepository_InsertAll_compositePrimaryFields(t *testing.T) {
adapter.AssertExpectations(t)
}

func TestRepository_InsertAll_ptrElem(t *testing.T) {
var (
users = []*User{
{Name: "name1"},
{Name: "name2", Age: 12},
}
adapter = &testAdapter{}
repo = New(adapter)
mutates = []map[string]Mutate{
{
"name": Set("name", "name1"),
"age": Set("age", 0),
"created_at": Set("created_at", Now()),
"updated_at": Set("updated_at", Now()),
},
{
"name": Set("name", "name2"),
"age": Set("age", 12),
"created_at": Set("created_at", Now()),
"updated_at": Set("updated_at", Now()),
},
}
)

adapter.On("InsertAll", From("users"), mock.Anything, mutates).Return([]interface{}{1, 2}, nil).Once()

assert.Nil(t, repo.InsertAll(context.TODO(), &users))
assert.Equal(t, []*User{
{ID: 1, Name: "name1", Age: 0, CreatedAt: Now(), UpdatedAt: Now()},
{ID: 2, Name: "name2", Age: 12, CreatedAt: Now(), UpdatedAt: Now()},
}, users)

adapter.AssertExpectations(t)
}

func TestRepository_InsertAll_empty(t *testing.T) {
var (
users []User
Expand Down Expand Up @@ -2914,6 +2949,20 @@ func TestRepository_DeleteAll_emptySlice(t *testing.T) {
adapter.AssertExpectations(t)
}

func TestRepository_DeleteAll_ptrElem(t *testing.T) {
var (
adapter = &testAdapter{}
repo = New(adapter)
users = []*User{{ID: 1}}
)

adapter.On("Delete", From("users").Where(In("id", users[0].ID))).Return(1, nil).Once()

assert.Nil(t, repo.DeleteAll(context.TODO(), &users))

adapter.AssertExpectations(t)
}

func TestRepository_MustDeleteAll(t *testing.T) {
var (
adapter = &testAdapter{}
Expand Down
10 changes: 9 additions & 1 deletion util.go
Expand Up @@ -7,7 +7,7 @@ import (
"strings"
)

func indirect(rv reflect.Value) interface{} {
func indirectInterface(rv reflect.Value) interface{} {
if rv.Kind() == reflect.Ptr {
if rv.IsNil() {
return nil
Expand All @@ -19,6 +19,14 @@ func indirect(rv reflect.Value) interface{} {
return rv.Interface()
}

func indirectReflectType(rt reflect.Type) reflect.Type {
if rt.Kind() == reflect.Ptr {
return rt.Elem()
}

return rt
}

func must(err error) {
if err != nil {
panic(err)
Expand Down

0 comments on commit 6aa81d4

Please sign in to comment.