diff --git a/scan.go b/scan.go index 415b9f0d74..6e1177cfb1 100644 --- a/scan.go +++ b/scan.go @@ -66,13 +66,21 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) joinedNestedSchemaMap := make(map[string]interface{}) + fieldsWithValueMap := make(map[string]bool) for idx, field := range fields { if field == nil { continue } if len(joinFields) == 0 || len(joinFields[idx]) == 0 { - db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + fieldIsEmbeddedPointerTypeStruct := len(field.BindNames) > 1 && len(field.StructField.Index) > 0 && field.StructField.Index[0] < 0 + fieldValue := reflect.ValueOf(values[idx]).Elem() + if !fieldIsEmbeddedPointerTypeStruct && fieldValue.Kind() == reflect.Ptr && fieldValue.IsNil() { + db.AddError(field.Set(db.Statement.Context, reflectValue, field.DefaultValueInterface)) + } else { + fieldsWithValueMap[field.BindName()] = fieldValue.Kind() == reflect.Ptr && !fieldValue.IsNil() + db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) + } } else { // joinFields count is larger than 2 when using join var isNilPtrValue bool var relValue reflect.Value @@ -109,6 +117,37 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int // release data to pool field.NewValuePool.Put(values[idx]) } + + if dest := reflect.Indirect(db.Statement.ReflectValue); len(db.Statement.Clauses) == 0 && dest.Kind() == reflect.Struct { + resetEmbeddedPointerTypeStruct(dest, db.Statement.Schema, fieldsWithValueMap) + } +} + +func resetEmbeddedPointerTypeStruct(dest reflect.Value, schema *schema.Schema, fieldsWithValueMap map[string]bool) { + for i := 0; i < dest.NumField(); i++ { + field := schema.ParseField(dest.Type().Field(i)) + if field.EmbeddedSchema != nil && field.FieldType.Kind() == reflect.Ptr { + if !wasValueScannedIntoEmbeddedStruct(field, fieldsWithValueMap) && dest.Field(i).Kind() == reflect.Ptr && !dest.Field(i).IsNil() { + dest.Field(i).Set(reflect.Zero(dest.Field(i).Type())) + } + } + } +} + +func wasValueScannedIntoEmbeddedStruct(field *schema.Field, fieldsWithValueMap map[string]bool) bool { + if fieldsWithValueMap[field.BindName()] { + return true + } + + if field.EmbeddedSchema != nil { + for _, embeddedField := range field.EmbeddedSchema.Fields { + if wasValueScannedIntoEmbeddedStruct(embeddedField, fieldsWithValueMap) { + return true + } + } + } + + return false } // ScanMode scan data mode diff --git a/tests/go.mod b/tests/go.mod index 3d3901d931..827f2a4d02 100644 --- a/tests/go.mod +++ b/tests/go.mod @@ -17,7 +17,7 @@ require ( require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/go-sql-driver/mysql v1.8.0 // indirect + github.com/go-sql-driver/mysql v1.8.1 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect diff --git a/tests/scan_test.go b/tests/scan_test.go index 6f2e9f54dd..0c756a0c95 100644 --- a/tests/scan_test.go +++ b/tests/scan_test.go @@ -5,6 +5,7 @@ import ( "sort" "strings" "testing" + "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" @@ -126,7 +127,7 @@ func TestScanRows(t *testing.T) { rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { - t.Errorf("Not error should happen, got %v", err) + t.Errorf("No error should happen, got %v", err) } type Result struct { @@ -148,7 +149,7 @@ func TestScanRows(t *testing.T) { }) if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { - t.Errorf("Should find expected results") + t.Errorf("Should find expected results, got %+v", results) } var ages int @@ -158,7 +159,104 @@ func TestScanRows(t *testing.T) { var name string if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { - t.Fatalf("failed to scan ages, got error %v, ages: %v", err, name) + t.Fatalf("failed to scan name, got error %v, name: %v", err, name) + } +} + +func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) { + DB.Save(&User{}) + + rows, err := DB.Table("users"). + Select(` + NULL AS bool_field, + NULL AS int_field, + NULL AS int8_field, + NULL AS int16_field, + NULL AS int32_field, + NULL AS int64_field, + NULL AS uint_field, + NULL AS uint8_field, + NULL AS uint16_field, + NULL AS uint32_field, + NULL AS uint64_field, + NULL AS float32_field, + NULL AS float64_field, + NULL AS string_field, + NULL AS time_field, + NULL AS time_ptr_field, + NULL AS embedded_int_field, + NULL AS nested_embedded_int_field, + NULL AS embedded_ptr_int_field + `).Rows() + if err != nil { + t.Errorf("No error should happen, got %v", err) + } + + type NestedEmbeddedStruct struct { + NestedEmbeddedIntField int + } + + type EmbeddedStruct struct { + EmbeddedIntField int + NestedEmbeddedStruct `gorm:"embedded"` + } + + type EmbeddedPtrStruct struct { + EmbeddedPtrIntField int + *NestedEmbeddedStruct `gorm:"embedded"` + } + + type Result struct { + BoolField bool + IntField int + Int8Field int8 + Int16Field int16 + Int32Field int32 + Int64Field int64 + UIntField uint + UInt8Field uint8 + UInt16Field uint16 + UInt32Field uint32 + UInt64Field uint64 + Float32Field float32 + Float64Field float64 + StringField string + TimeField time.Time + TimePtrField *time.Time + EmbeddedStruct `gorm:"embedded"` + *EmbeddedPtrStruct `gorm:"embedded"` + } + + currTime := time.Now() + result := Result{ + BoolField: true, + IntField: 1, + Int8Field: 1, + Int16Field: 1, + Int32Field: 1, + Int64Field: 1, + UIntField: 1, + UInt8Field: 1, + UInt16Field: 1, + UInt32Field: 1, + UInt64Field: 1, + Float32Field: 1.1, + Float64Field: 1.1, + StringField: "hello", + TimeField: currTime, + TimePtrField: &currTime, + EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1}}, + EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1}}, + } + + for rows.Next() { + if err := DB.ScanRows(rows, &result); err != nil { + t.Errorf("should get no error, but got %v", err) + } + } + + if !reflect.DeepEqual(result, Result{}) { + t.Errorf("Should find zero values in struct fields, got %+v", result) } }