From 72a65551ac4579dd2035d2b6355f23e8cf8573f1 Mon Sep 17 00:00:00 2001 From: Minh Cung Date: Sun, 29 Mar 2026 14:49:50 +1100 Subject: [PATCH 1/2] feat(rain): expand model scanning type support --- pkg/rain/model.go | 273 ++++++++++++++++++++++++++------ pkg/rain/model_internal_test.go | 203 ++++++++++++++++++++++++ 2 files changed, 429 insertions(+), 47 deletions(-) create mode 100644 pkg/rain/model_internal_test.go diff --git a/pkg/rain/model.go b/pkg/rain/model.go index aa23ada..57f29a2 100644 --- a/pkg/rain/model.go +++ b/pkg/rain/model.go @@ -2,6 +2,7 @@ package rain import ( "database/sql" + "database/sql/driver" "fmt" "reflect" "strings" @@ -9,6 +10,14 @@ import ( "time" ) +type scannerInterface = interface { + Scan(src any) error +} + +type valuerInterface = interface { + Value() (driver.Value, error) +} + type modelField struct { index []int } @@ -60,8 +69,8 @@ func buildModelMeta(meta *modelMeta, typ reflect.Type, prefix []int) { } current := append(append([]int{}, prefix...), fieldIndex) - if field.Anonymous && field.Type.Kind() == reflect.Struct { - buildModelMeta(meta, field.Type, current) + if embedded := embeddedStructType(field); embedded != nil { + buildModelMeta(meta, embedded, current) continue } @@ -77,6 +86,22 @@ func buildModelMeta(meta *modelMeta, typ reflect.Type, prefix []int) { } } +func embeddedStructType(field reflect.StructField) reflect.Type { + if !field.Anonymous { + return nil + } + + typ := field.Type + if typ.Kind() == reflect.Pointer { + typ = typ.Elem() + } + if typ.Kind() != reflect.Struct { + return nil + } + + return typ +} + func relationTagName(tag string) string { trimmed := strings.TrimSpace(tag) if trimmed == "" || trimmed == "-" { @@ -144,7 +169,10 @@ func scanCurrentRow(rows *sql.Rows, target reflect.Value) error { continue } - field := target.FieldByIndex(fieldInfo.index) + field, err := fieldByIndexAlloc(target, fieldInfo.index) + if err != nil { + return err + } scanTarget, finalize, err := prepareScanTarget(field) if err != nil { return err @@ -169,7 +197,30 @@ func scanCurrentRow(rows *sql.Rows, target reflect.Value) error { return nil } +func fieldByIndexAlloc(value reflect.Value, index []int) (reflect.Value, error) { + current := value + for position, part := range index { + field := current.Field(part) + if position < len(index)-1 && field.Kind() == reflect.Pointer { + if field.IsNil() { + if !field.CanSet() { + return reflect.Value{}, fmt.Errorf("rain: embedded pointer field %s is not settable", field.Type()) + } + field.Set(reflect.New(field.Type().Elem())) + } + current = field.Elem() + continue + } + current = field + } + return current, nil +} + func prepareScanTarget(field reflect.Value) (any, func() error, error) { + if scanTarget, finalize, ok := scannerTarget(field); ok { + return scanTarget, finalize, nil + } + if field.Kind() != reflect.Pointer { if !field.CanAddr() { return nil, nil, fmt.Errorf("rain: field %s is not addressable", field.Type()) @@ -177,54 +228,182 @@ func prepareScanTarget(field reflect.Value) (any, func() error, error) { return field.Addr().Interface(), nil, nil } - elemType := field.Type().Elem() - switch { - case elemType.Kind() == reflect.String: - holder := sql.Null[string]{} - return &holder, func() error { - if !holder.Valid { - field.Set(reflect.Zero(field.Type())) + for _, handler := range nullablePrimitiveHandlers() { + if scanTarget, finalize, ok := handler(field); ok { + return scanTarget, finalize, nil + } + } + + return nil, nil, fmt.Errorf("rain: unsupported nullable field type %s", field.Type()) +} + +type nullableHandler func(field reflect.Value) (target any, finalize func() error, ok bool) + +func nullablePrimitiveHandlers() []nullableHandler { + return []nullableHandler{ + nullableStringTarget, + nullableSignedIntTarget, + nullableUnsignedIntTarget, + nullableFloatTarget, + nullableBoolTarget, + nullableTimeTarget, + } +} + +func scannerTarget(field reflect.Value) (any, func() error, bool) { + scannerType := reflect.TypeFor[scannerInterface]() + valuerType := reflect.TypeFor[valuerInterface]() + + if field.Kind() != reflect.Pointer { + if field.CanAddr() && field.Addr().Type().Implements(scannerType) { + return field.Addr().Interface(), nil, true + } + return nil, nil, false + } + + fieldType := field.Type() + if fieldType.Implements(scannerType) && fieldType.Implements(valuerType) { + receiver := reflect.New(fieldType.Elem()) + if receiver.Type().Implements(scannerType) { + return receiver.Interface(), func() error { + field.Set(receiver) return nil - } - value := holder.V - field.Set(reflect.ValueOf(&value)) + }, true + } + } + + if fieldType.Elem().Implements(scannerType) { + receiver := reflect.New(fieldType.Elem()) + return receiver.Interface(), func() error { + field.Set(receiver) return nil - }, nil - case elemType.Kind() == reflect.Int || elemType.Kind() == reflect.Int64: - holder := sql.Null[int64]{} - return &holder, func() error { - if !holder.Valid { - field.Set(reflect.Zero(field.Type())) - return nil - } - ptr := reflect.New(elemType) - ptr.Elem().SetInt(holder.V) - field.Set(ptr) + }, true + } + + return nil, nil, false +} + +func nullableStringTarget(field reflect.Value) (any, func() error, bool) { + if field.Type().Elem().Kind() != reflect.String { + return nil, nil, false + } + + holder := sql.Null[string]{} + return &holder, func() error { + if !holder.Valid { + field.Set(reflect.Zero(field.Type())) return nil - }, nil - case elemType.Kind() == reflect.Bool: - holder := sql.Null[bool]{} - return &holder, func() error { - if !holder.Valid { - field.Set(reflect.Zero(field.Type())) - return nil - } - value := holder.V - field.Set(reflect.ValueOf(&value)) + } + value := holder.V + field.Set(reflect.ValueOf(&value)) + return nil + }, true +} + +func nullableSignedIntTarget(field reflect.Value) (any, func() error, bool) { + elemType := field.Type().Elem() + signed := map[reflect.Kind]bool{ + reflect.Int: true, + reflect.Int8: true, + reflect.Int16: true, + reflect.Int32: true, + reflect.Int64: true, + } + if !signed[elemType.Kind()] { + return nil, nil, false + } + + holder := sql.Null[int64]{} + return &holder, func() error { + if !holder.Valid { + field.Set(reflect.Zero(field.Type())) return nil - }, nil - case elemType == reflect.TypeFor[time.Time](): - holder := sql.Null[time.Time]{} - return &holder, func() error { - if !holder.Valid { - field.Set(reflect.Zero(field.Type())) - return nil - } - value := holder.V - field.Set(reflect.ValueOf(&value)) + } + ptr := reflect.New(elemType) + ptr.Elem().SetInt(holder.V) + field.Set(ptr) + return nil + }, true +} + +func nullableUnsignedIntTarget(field reflect.Value) (any, func() error, bool) { + elemType := field.Type().Elem() + unsigned := map[reflect.Kind]bool{ + reflect.Uint: true, + reflect.Uint8: true, + reflect.Uint16: true, + reflect.Uint32: true, + reflect.Uint64: true, + } + if !unsigned[elemType.Kind()] { + return nil, nil, false + } + + holder := sql.Null[int64]{} + return &holder, func() error { + if !holder.Valid { + field.Set(reflect.Zero(field.Type())) return nil - }, nil - default: - return nil, nil, fmt.Errorf("rain: unsupported nullable field type %s", field.Type()) + } + if holder.V < 0 { + return fmt.Errorf("rain: cannot scan negative value %d into %s", holder.V, field.Type()) + } + ptr := reflect.New(elemType) + ptr.Elem().SetUint(uint64(holder.V)) + field.Set(ptr) + return nil + }, true +} + +func nullableFloatTarget(field reflect.Value) (any, func() error, bool) { + elemType := field.Type().Elem() + if elemType.Kind() != reflect.Float32 && elemType.Kind() != reflect.Float64 { + return nil, nil, false + } + + holder := sql.Null[float64]{} + return &holder, func() error { + if !holder.Valid { + field.Set(reflect.Zero(field.Type())) + return nil + } + ptr := reflect.New(elemType) + ptr.Elem().SetFloat(holder.V) + field.Set(ptr) + return nil + }, true +} + +func nullableBoolTarget(field reflect.Value) (any, func() error, bool) { + if field.Type().Elem().Kind() != reflect.Bool { + return nil, nil, false } + + holder := sql.Null[bool]{} + return &holder, func() error { + if !holder.Valid { + field.Set(reflect.Zero(field.Type())) + return nil + } + value := holder.V + field.Set(reflect.ValueOf(&value)) + return nil + }, true +} + +func nullableTimeTarget(field reflect.Value) (any, func() error, bool) { + if field.Type().Elem() != reflect.TypeFor[time.Time]() { + return nil, nil, false + } + + holder := sql.Null[time.Time]{} + return &holder, func() error { + if !holder.Valid { + field.Set(reflect.Zero(field.Type())) + return nil + } + value := holder.V + field.Set(reflect.ValueOf(&value)) + return nil + }, true } diff --git a/pkg/rain/model_internal_test.go b/pkg/rain/model_internal_test.go new file mode 100644 index 0000000..a14c0a2 --- /dev/null +++ b/pkg/rain/model_internal_test.go @@ -0,0 +1,203 @@ +package rain + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "path/filepath" + "strings" + "testing" + + _ "modernc.org/sqlite" +) + +type modelScanStatus string + +func (s *modelScanStatus) Scan(src any) error { + switch value := src.(type) { + case string: + *s = modelScanStatus(strings.ToUpper(value)) + return nil + case []byte: + *s = modelScanStatus(strings.ToUpper(string(value))) + return nil + default: + return fmt.Errorf("unsupported status source %T", src) + } +} + +func (s modelScanStatus) Value() (driver.Value, error) { + return string(s), nil +} + +type ModelScanEmbedded struct { + ID int64 `db:"id"` +} + +type ModelScanProfile struct { + Name *string `db:"name"` +} + +type modelScanRow struct { + ModelScanEmbedded + *ModelScanProfile + Age *int32 `db:"age"` + Score *float32 `db:"score"` + Visits *uint16 `db:"visits"` + Status *modelScanStatus `db:"status"` + Disabled *bool `db:"disabled"` +} + +type modelUnsupportedRow struct { + Payload *struct{} `db:"payload"` +} + +func openModelInternalDB(t *testing.T) *sql.DB { + t.Helper() + + dbPath := filepath.Join(t.TempDir(), "model-internal.sqlite") + db, err := sql.Open("sqlite", dbPath) + if err != nil { + t.Fatalf("open sqlite db: %v", err) + } + t.Cleanup(func() { + _ = db.Close() + }) + + return db +} + +func TestScanRowsSupportsExpandedNullableTypesAndEmbeddedStructs(t *testing.T) { + t.Parallel() + + db := openModelInternalDB(t) + if _, err := db.Exec(` + CREATE TABLE scan_rows ( + id INTEGER NOT NULL, + name TEXT, + age INTEGER, + score REAL, + visits INTEGER, + status TEXT, + disabled INTEGER + ) + `); err != nil { + t.Fatalf("create table: %v", err) + } + + if _, err := db.Exec(` + INSERT INTO scan_rows(id, name, age, score, visits, status, disabled) + VALUES (1, 'alice', 33, 9.5, 12, 'active', NULL) + `); err != nil { + t.Fatalf("insert table row: %v", err) + } + + rows, err := db.Query(` + SELECT id, name, age, score, visits, status, disabled + FROM scan_rows + `) + if err != nil { + t.Fatalf("query rows: %v", err) + } + t.Cleanup(func() { + _ = rows.Close() + }) + + var scanned modelScanRow + if err := scanRows(rows, &scanned); err != nil { + t.Fatalf("scan rows: %v", err) + } + + if scanned.ID != 1 { + t.Fatalf("expected embedded id 1, got %d", scanned.ID) + } + if scanned.ModelScanProfile == nil || scanned.Name == nil || *scanned.Name != "alice" { + t.Fatalf("expected embedded profile name alice, got %#v", scanned.ModelScanProfile) + } + if scanned.Age == nil || *scanned.Age != 33 { + t.Fatalf("expected age 33, got %#v", scanned.Age) + } + if scanned.Score == nil || *scanned.Score != float32(9.5) { + t.Fatalf("expected score 9.5, got %#v", scanned.Score) + } + if scanned.Visits == nil || *scanned.Visits != uint16(12) { + t.Fatalf("expected visits 12, got %#v", scanned.Visits) + } + if scanned.Status == nil || *scanned.Status != modelScanStatus("ACTIVE") { + t.Fatalf("expected custom scanner status ACTIVE, got %#v", scanned.Status) + } + if scanned.Disabled != nil { + t.Fatalf("expected null bool pointer, got %#v", scanned.Disabled) + } +} + +func TestScanRowsDiscardsUnknownColumnsAndLeavesMissingFields(t *testing.T) { + t.Parallel() + + db := openModelInternalDB(t) + if _, err := db.Exec(` + CREATE TABLE unknown_columns ( + id INTEGER NOT NULL, + ghost TEXT + ) + `); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := db.Exec(`INSERT INTO unknown_columns(id, ghost) VALUES (7, 'ignored')`); err != nil { + t.Fatalf("insert row: %v", err) + } + + type row struct { + ID int64 `db:"id"` + Name *string `db:"name"` + } + + rows, err := db.Query(`SELECT id, ghost FROM unknown_columns`) + if err != nil { + t.Fatalf("query rows: %v", err) + } + t.Cleanup(func() { + _ = rows.Close() + }) + + var scanned row + if err := scanRows(rows, &scanned); err != nil { + t.Fatalf("scan rows: %v", err) + } + + if scanned.ID != 7 { + t.Fatalf("expected id 7, got %d", scanned.ID) + } + if scanned.Name != nil { + t.Fatalf("expected missing field to remain nil, got %#v", scanned.Name) + } +} + +func TestScanRowsUnsupportedNullableTypeReturnsClearError(t *testing.T) { + t.Parallel() + + db := openModelInternalDB(t) + if _, err := db.Exec(`CREATE TABLE unsupported_type (payload TEXT)`); err != nil { + t.Fatalf("create table: %v", err) + } + if _, err := db.Exec(`INSERT INTO unsupported_type(payload) VALUES ('bad')`); err != nil { + t.Fatalf("insert row: %v", err) + } + + rows, err := db.Query(`SELECT payload FROM unsupported_type`) + if err != nil { + t.Fatalf("query rows: %v", err) + } + t.Cleanup(func() { + _ = rows.Close() + }) + + var scanned modelUnsupportedRow + err = scanRows(rows, &scanned) + if err == nil { + t.Fatalf("expected unsupported nullable type error") + } + if !strings.Contains(err.Error(), "unsupported nullable field type") { + t.Fatalf("expected clear unsupported type message, got %v", err) + } +} From 862a716549051bdb05dbdf8b21279a69dfa265d3 Mon Sep 17 00:00:00 2001 From: Minh Cung Date: Sun, 29 Mar 2026 15:12:29 +1100 Subject: [PATCH 2/2] fix(rain): support scanner-only pointer fields --- pkg/rain/model.go | 40 +++++++++------------------------ pkg/rain/model_internal_test.go | 5 ----- 2 files changed, 11 insertions(+), 34 deletions(-) diff --git a/pkg/rain/model.go b/pkg/rain/model.go index 57f29a2..50493b9 100644 --- a/pkg/rain/model.go +++ b/pkg/rain/model.go @@ -2,7 +2,6 @@ package rain import ( "database/sql" - "database/sql/driver" "fmt" "reflect" "strings" @@ -14,10 +13,6 @@ type scannerInterface = interface { Scan(src any) error } -type valuerInterface = interface { - Value() (driver.Value, error) -} - type modelField struct { index []int } @@ -252,7 +247,6 @@ func nullablePrimitiveHandlers() []nullableHandler { func scannerTarget(field reflect.Value) (any, func() error, bool) { scannerType := reflect.TypeFor[scannerInterface]() - valuerType := reflect.TypeFor[valuerInterface]() if field.Kind() != reflect.Pointer { if field.CanAddr() && field.Addr().Type().Implements(scannerType) { @@ -262,14 +256,12 @@ func scannerTarget(field reflect.Value) (any, func() error, bool) { } fieldType := field.Type() - if fieldType.Implements(scannerType) && fieldType.Implements(valuerType) { + if fieldType.Implements(scannerType) { receiver := reflect.New(fieldType.Elem()) - if receiver.Type().Implements(scannerType) { - return receiver.Interface(), func() error { - field.Set(receiver) - return nil - }, true - } + return receiver.Interface(), func() error { + field.Set(receiver) + return nil + }, true } if fieldType.Elem().Implements(scannerType) { @@ -302,14 +294,9 @@ func nullableStringTarget(field reflect.Value) (any, func() error, bool) { func nullableSignedIntTarget(field reflect.Value) (any, func() error, bool) { elemType := field.Type().Elem() - signed := map[reflect.Kind]bool{ - reflect.Int: true, - reflect.Int8: true, - reflect.Int16: true, - reflect.Int32: true, - reflect.Int64: true, - } - if !signed[elemType.Kind()] { + switch elemType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + default: return nil, nil, false } @@ -328,14 +315,9 @@ func nullableSignedIntTarget(field reflect.Value) (any, func() error, bool) { func nullableUnsignedIntTarget(field reflect.Value) (any, func() error, bool) { elemType := field.Type().Elem() - unsigned := map[reflect.Kind]bool{ - reflect.Uint: true, - reflect.Uint8: true, - reflect.Uint16: true, - reflect.Uint32: true, - reflect.Uint64: true, - } - if !unsigned[elemType.Kind()] { + switch elemType.Kind() { + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + default: return nil, nil, false } diff --git a/pkg/rain/model_internal_test.go b/pkg/rain/model_internal_test.go index a14c0a2..5b7ac5c 100644 --- a/pkg/rain/model_internal_test.go +++ b/pkg/rain/model_internal_test.go @@ -2,7 +2,6 @@ package rain import ( "database/sql" - "database/sql/driver" "fmt" "path/filepath" "strings" @@ -26,10 +25,6 @@ func (s *modelScanStatus) Scan(src any) error { } } -func (s modelScanStatus) Value() (driver.Value, error) { - return string(s), nil -} - type ModelScanEmbedded struct { ID int64 `db:"id"` }