Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 208 additions & 47 deletions pkg/rain/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
"time"
)

type scannerInterface = interface {
Scan(src any) error
}

type modelField struct {
index []int
}
Expand Down Expand Up @@ -60,8 +64,8 @@ func buildModelMeta(meta *modelMeta, typ reflect.Type, prefix []int) {
}

current := append(append([]int{}, prefix...), fieldIndex)
if field.Anonymous && field.Type.Kind() == reflect.Struct {
buildModelMeta(meta, field.Type, current)
if embedded := embeddedStructType(field); embedded != nil {
buildModelMeta(meta, embedded, current)
continue
}

Expand All @@ -77,6 +81,22 @@ func buildModelMeta(meta *modelMeta, typ reflect.Type, prefix []int) {
}
}

func embeddedStructType(field reflect.StructField) reflect.Type {
if !field.Anonymous {
return nil
}

typ := field.Type
if typ.Kind() == reflect.Pointer {
typ = typ.Elem()
}
if typ.Kind() != reflect.Struct {
return nil
}

return typ
}

func relationTagName(tag string) string {
trimmed := strings.TrimSpace(tag)
if trimmed == "" || trimmed == "-" {
Expand Down Expand Up @@ -144,7 +164,10 @@ func scanCurrentRow(rows *sql.Rows, target reflect.Value) error {
continue
}

field := target.FieldByIndex(fieldInfo.index)
field, err := fieldByIndexAlloc(target, fieldInfo.index)
if err != nil {
return err
}
scanTarget, finalize, err := prepareScanTarget(field)
if err != nil {
return err
Expand All @@ -169,62 +192,200 @@ func scanCurrentRow(rows *sql.Rows, target reflect.Value) error {
return nil
}

func fieldByIndexAlloc(value reflect.Value, index []int) (reflect.Value, error) {
current := value
for position, part := range index {
field := current.Field(part)
if position < len(index)-1 && field.Kind() == reflect.Pointer {
if field.IsNil() {
if !field.CanSet() {
return reflect.Value{}, fmt.Errorf("rain: embedded pointer field %s is not settable", field.Type())
}
field.Set(reflect.New(field.Type().Elem()))
}
current = field.Elem()
continue
}
current = field
}
return current, nil
}

func prepareScanTarget(field reflect.Value) (any, func() error, error) {
if scanTarget, finalize, ok := scannerTarget(field); ok {
return scanTarget, finalize, nil
}

if field.Kind() != reflect.Pointer {
if !field.CanAddr() {
return nil, nil, fmt.Errorf("rain: field %s is not addressable", field.Type())
}
return field.Addr().Interface(), nil, nil
}

elemType := field.Type().Elem()
switch {
case elemType.Kind() == reflect.String:
holder := sql.Null[string]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
value := holder.V
field.Set(reflect.ValueOf(&value))
for _, handler := range nullablePrimitiveHandlers() {
if scanTarget, finalize, ok := handler(field); ok {
return scanTarget, finalize, nil
}
}

return nil, nil, fmt.Errorf("rain: unsupported nullable field type %s", field.Type())
}

type nullableHandler func(field reflect.Value) (target any, finalize func() error, ok bool)

func nullablePrimitiveHandlers() []nullableHandler {
return []nullableHandler{
nullableStringTarget,
nullableSignedIntTarget,
nullableUnsignedIntTarget,
nullableFloatTarget,
nullableBoolTarget,
nullableTimeTarget,
}
}

func scannerTarget(field reflect.Value) (any, func() error, bool) {
scannerType := reflect.TypeFor[scannerInterface]()

if field.Kind() != reflect.Pointer {
if field.CanAddr() && field.Addr().Type().Implements(scannerType) {
return field.Addr().Interface(), nil, true
}
return nil, nil, false
}

fieldType := field.Type()
if fieldType.Implements(scannerType) {
receiver := reflect.New(fieldType.Elem())
return receiver.Interface(), func() error {
field.Set(receiver)
return nil
}, nil
case elemType.Kind() == reflect.Int || elemType.Kind() == reflect.Int64:
holder := sql.Null[int64]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
ptr := reflect.New(elemType)
ptr.Elem().SetInt(holder.V)
field.Set(ptr)
}, true
}
Comment on lines +258 to +265
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Scanner-only pointer types silently unsupported

The first branch inside the pointer case requires that fieldType implements both scannerInterface and valuerInterface before handling it. Any *T type that implements sql.Scanner (pointer receiver) but does not implement driver.Valuer will fall through this block.

The second branch (fieldType.Elem().Implements(scannerType)) only rescues types where T — not *T — has a value-receiver Scan method, which is uncommon. So the typical custom scanner type:

type Status string
func (s *Status) Scan(src any) error { ... }  // pointer receiver, no Value()

…used as Status *Status on a struct will reach nullablePrimitiveHandlers(), find no match, and return "rain: unsupported nullable field type *Status" — silently failing at runtime despite the type correctly implementing the standard sql.Scanner interface.

The driver.Valuer requirement is not necessary for scan-path correctness and should be removed from the guard. The same receiver.Interface()/field.Set(receiver) pattern works regardless of whether the type also satisfies Valuer:

	if fieldType.Implements(scannerType) {
		receiver := reflect.New(fieldType.Elem())
		if receiver.Type().Implements(scannerType) {
			return receiver.Interface(), func() error {
				field.Set(receiver)
				return nil
			}, true
		}
	}
Prompt To Fix With AI
This is a comment left during a code review.
Path: pkg/rain/model.go
Line: 264-273

Comment:
**Scanner-only pointer types silently unsupported**

The first branch inside the pointer case requires that `fieldType` implements *both* `scannerInterface` and `valuerInterface` before handling it. Any `*T` type that implements `sql.Scanner` (pointer receiver) but does **not** implement `driver.Valuer` will fall through this block.

The second branch (`fieldType.Elem().Implements(scannerType)`) only rescues types where `T` — not `*T` — has a value-receiver `Scan` method, which is uncommon. So the typical custom scanner type:

```go
type Status string
func (s *Status) Scan(src any) error { ... }  // pointer receiver, no Value()
```

…used as `Status *Status` on a struct will reach `nullablePrimitiveHandlers()`, find no match, and return `"rain: unsupported nullable field type *Status"` — silently failing at runtime despite the type correctly implementing the standard `sql.Scanner` interface.

The `driver.Valuer` requirement is not necessary for scan-path correctness and should be removed from the guard. The same `receiver.Interface()`/`field.Set(receiver)` pattern works regardless of whether the type also satisfies `Valuer`:

```go
	if fieldType.Implements(scannerType) {
		receiver := reflect.New(fieldType.Elem())
		if receiver.Type().Implements(scannerType) {
			return receiver.Interface(), func() error {
				field.Set(receiver)
				return nil
			}, true
		}
	}
```

How can I resolve this? If you propose a fix, please make it concise.


if fieldType.Elem().Implements(scannerType) {
receiver := reflect.New(fieldType.Elem())
return receiver.Interface(), func() error {
field.Set(receiver)
return nil
}, nil
case elemType.Kind() == reflect.Bool:
holder := sql.Null[bool]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
value := holder.V
field.Set(reflect.ValueOf(&value))
}, true
}

return nil, nil, false
}

func nullableStringTarget(field reflect.Value) (any, func() error, bool) {
if field.Type().Elem().Kind() != reflect.String {
return nil, nil, false
}

holder := sql.Null[string]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}, nil
case elemType == reflect.TypeFor[time.Time]():
holder := sql.Null[time.Time]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
value := holder.V
field.Set(reflect.ValueOf(&value))
}
value := holder.V
field.Set(reflect.ValueOf(&value))
return nil
}, true
}

func nullableSignedIntTarget(field reflect.Value) (any, func() error, bool) {
elemType := field.Type().Elem()
switch elemType.Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
default:
return nil, nil, false
}

holder := sql.Null[int64]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}, nil
}
ptr := reflect.New(elemType)
ptr.Elem().SetInt(holder.V)
field.Set(ptr)
return nil
}, true
}

func nullableUnsignedIntTarget(field reflect.Value) (any, func() error, bool) {
elemType := field.Type().Elem()
switch elemType.Kind() {
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
default:
return nil, nil, fmt.Errorf("rain: unsupported nullable field type %s", field.Type())
return nil, nil, false
}

holder := sql.Null[int64]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
if holder.V < 0 {
return fmt.Errorf("rain: cannot scan negative value %d into %s", holder.V, field.Type())
}
ptr := reflect.New(elemType)
ptr.Elem().SetUint(uint64(holder.V))
field.Set(ptr)
return nil
}, true
}

func nullableFloatTarget(field reflect.Value) (any, func() error, bool) {
elemType := field.Type().Elem()
if elemType.Kind() != reflect.Float32 && elemType.Kind() != reflect.Float64 {
return nil, nil, false
}

holder := sql.Null[float64]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
ptr := reflect.New(elemType)
ptr.Elem().SetFloat(holder.V)
field.Set(ptr)
return nil
}, true
}

func nullableBoolTarget(field reflect.Value) (any, func() error, bool) {
if field.Type().Elem().Kind() != reflect.Bool {
return nil, nil, false
}

holder := sql.Null[bool]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
value := holder.V
field.Set(reflect.ValueOf(&value))
return nil
}, true
}

func nullableTimeTarget(field reflect.Value) (any, func() error, bool) {
if field.Type().Elem() != reflect.TypeFor[time.Time]() {
return nil, nil, false
}

holder := sql.Null[time.Time]{}
return &holder, func() error {
if !holder.Valid {
field.Set(reflect.Zero(field.Type()))
return nil
}
value := holder.V
field.Set(reflect.ValueOf(&value))
return nil
}, true
}
Loading
Loading