From 0eaa6358c36253a31c73bd95b52e2395fad02242 Mon Sep 17 00:00:00 2001 From: Fs02 Date: Fri, 1 Mar 2024 20:34:39 +0900 Subject: [PATCH] Fix nested preload with duplicate ptr belongs to --- cursor.go | 8 ++++- document.go | 7 ++++ repository_test.go | 88 ++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 102 insertions(+), 1 deletion(-) diff --git a/cursor.go b/cursor.go index 78efd4c8..d9f233da 100644 --- a/cursor.go +++ b/cursor.go @@ -92,8 +92,14 @@ func scanMulti(cur Cursor, keyField string, keyType reflect.Type, cols map[any][ key, found := doc.Value(keyField) mustTrue(found, "rel: key field not found") + needCopy := false for _, col := range cols[key] { - col.Append(doc) + if needCopy { + col.Append(doc.Copy()) + } else { + col.Append(doc) + needCopy = true + } } // create new doc for next scan diff --git a/document.go b/document.go index 5e4fa28a..34d92dfe 100644 --- a/document.go +++ b/document.go @@ -266,6 +266,13 @@ func (d Document) NewDocument() *Document { return newZeroDocument(d.rt) } +// Copy returns copy of this document +func (d Document) Copy() *Document { + rv := reflect.New(d.rt) + rv.Elem().Set(d.rv) + return NewDocument(rv) +} + // Append is alias for Assign for compatibility with internal slice interface func (d *Document) Append(o *Document) { d.Assign(o) diff --git a/repository_test.go b/repository_test.go index d6a3ee04..e90b69fb 100644 --- a/repository_test.go +++ b/repository_test.go @@ -4025,6 +4025,94 @@ func TestRepository_Preload_scanErrors(t *testing.T) { cur.AssertExpectations(t) } +type ScheduledQuestion struct { + ID int + QuestionID int + Question *Question +} + +type Question struct { + ID int + Answers []Answers +} + +type Answers struct { + ID int + QuestionID int +} + +func TestRepository_Preload_nestedWithDuplicatePtrBelongsTo(t *testing.T) { + var ( + adapter = &testAdapter{} + repo = New(adapter) + scheduledQuestions = []ScheduledQuestion{} + cur = &testCursor{} + ) + + { + adapter.On("Query", From("scheduled_questions")).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "question_id"}, nil).Once() + cur.On("Next").Return(true).Times(2) + cur.MockScan(1, 1).Once() + cur.MockScan(2, 1).Once() + cur.On("Next").Return(false).Once() + + assert.Nil(t, repo.FindAll(context.TODO(), &scheduledQuestions)) + } + + { + adapter.On("Query", From("questions").Where(In("id", 1))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id"}, nil).Once() + cur.On("Next").Return(true).Times(1) + cur.MockScan(1).Once() + cur.On("Next").Return(false).Once() + + assert.Nil(t, repo.Preload(context.TODO(), &scheduledQuestions, "question")) + } + + { + adapter.On("Query", From("answers").Where(In("question_id", 1))).Return(cur, nil).Once() + cur.On("Close").Return(nil).Once() + cur.On("Fields").Return([]string{"id", "question_id"}, nil).Once() + cur.On("Next").Return(true).Times(1) + cur.MockScan(1, 1).Once() + cur.On("Next").Return(false).Once() + + assert.Nil(t, repo.Preload(context.TODO(), &scheduledQuestions, "question.answers")) + } + + assert.Equal(t, []ScheduledQuestion{ + { + ID: 1, + QuestionID: 1, + Question: &Question{ + ID: 1, + Answers: []Answers{ + { + ID: 1, + QuestionID: 1, + }, + }, + }, + }, + { + ID: 2, + QuestionID: 1, + Question: &Question{ + ID: 1, + Answers: []Answers{ + { + ID: 1, + QuestionID: 1, + }, + }, + }, + }, + }, scheduledQuestions) +} + func TestRepository_MustPreload(t *testing.T) { var ( adapter = &testAdapter{}