Skip to content

Commit

Permalink
fix: nested preload with join panic when find (#6877)
Browse files Browse the repository at this point in the history
  • Loading branch information
black-06 committed Mar 9, 2024
1 parent c4c9aa4 commit e4e23d2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 4 deletions.
21 changes: 17 additions & 4 deletions callbacks/preload.go
Expand Up @@ -121,10 +121,23 @@ func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relati
}
} else if rel := relationships.Relations[name]; rel != nil {
if joined, nestedJoins := isJoined(name); joined {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
switch rv := db.Statement.ReflectValue; rv.Kind() {
case reflect.Slice, reflect.Array:
for i := 0; i < rv.Len(); i++ {
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i))
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
}
case reflect.Struct:
reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv)
tx := preloadDB(db, reflectValue, reflectValue.Interface())
if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil {
return err
}
default:
return gorm.ErrInvalidData
}
} else {
tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks})
Expand Down
10 changes: 10 additions & 0 deletions tests/preload_test.go
Expand Up @@ -8,6 +8,8 @@ import (
"sync"
"testing"

"github.com/stretchr/testify/require"

"gorm.io/gorm"
"gorm.io/gorm/clause"
. "gorm.io/gorm/utils/tests"
Expand Down Expand Up @@ -362,6 +364,14 @@ func TestNestedPreloadWithNestedJoin(t *testing.T) {
t.Errorf("failed to find value, got err: %v", err)
}
AssertEqual(t, find2, value)

var finds []Value
err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error
if err != nil {
t.Errorf("failed to find value, got err: %v", err)
}
require.Len(t, finds, 1)
AssertEqual(t, finds[0], value)
}

func TestEmbedPreload(t *testing.T) {
Expand Down

0 comments on commit e4e23d2

Please sign in to comment.