From c1afb54027590da6245d918c42de7dbdc2d23eb1 Mon Sep 17 00:00:00 2001 From: "waleed.masoom" Date: Wed, 3 Apr 2024 14:18:52 -0400 Subject: [PATCH] fix: update setupValuerAndSetter to use default values when pointer types are nil --- scan.go | 6 +++++- tests/go.mod | 2 +- tests/scan_test.go | 48 ++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 52 insertions(+), 4 deletions(-) diff --git a/scan.go b/scan.go index 415b9f0d74..8ab044cab4 100644 --- a/scan.go +++ b/scan.go @@ -72,7 +72,11 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int } if len(joinFields) == 0 || len(joinFields[idx]) == 0 { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + if value := reflect.ValueOf(values[idx]).Elem(); field.OwnerSchema == nil && value.Kind() == reflect.Ptr && value.IsNil() { + db.AddError(field.Set(db.Statement.Context, reflectValue, field.DefaultValueInterface)) + } else { + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } } else { // joinFields count is larger than 2 when using join var isNilPtrValue bool var relValue reflect.Value diff --git a/tests/go.mod b/tests/go.mod index 3d3901d931..827f2a4d02 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -17,7 +17,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.8.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/tests/scan_test.go b/tests/scan_test.go index 6f2e9f54dd..1008a44388 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -148,7 +148,7 @@ func TestScanRows(t *testing.T) { }) if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") + t.Errorf("Should find expected results, got %+v", results) } var ages int @@ -158,7 +158,51 @@ func TestScanRows(t *testing.T) { var name string if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { - t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + t.Fatalf("failed to scan name, got error %v, name: %v", err, name) + } +} + +func TestScanRowsNullValuesScanToZeroValues(t *testing.T) { + user1 := User{Name: "ScanRowsUser1", Age: 1} + user2 := User{Name: "ScanRowsUser2", Age: 10} + user3 := User{Name: "ScanRowsUser3", Age: 0} + DB.Save(&user1).Save(&user2).Save(&user3) + + rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, CASE WHEN age = 0 THEN NULL ELSE age END AS age").Rows() + if err != nil { + t.Errorf("Not error should happen, got %v", err) + } + + type Result struct { + Name string + Age int + } + + var results []Result + var result Result + for rows.Next() { + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + results = append(results, result) + } + + sort.Slice(results, func(i, j int) bool { + return strings.Compare(results[i].Name, results[j].Name) <= -1 + }) + + if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 0}}) { + t.Errorf("Should find expected results, got %+v", results) + } + + var ages int + if err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 10 { + t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) + } + + var name string + if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { + t.Fatalf("failed to scan name, got error %v, name: %v", err, name) } }