diff --git a/pkg/rain/model.go b/pkg/rain/model.go index 5b2d30b..2cee098 100644 --- a/pkg/rain/model.go +++ b/pkg/rain/model.go @@ -40,7 +40,25 @@ type scanColumnPlan struct { } type rowScanPlan struct { - columns []scanColumnPlan + columns []scanColumnPlan + clearIndices []int + + int64ValueCols []scanColumnPlan + int64PointerCols []scanColumnPlan + + stringValueCols []scanColumnPlan + stringPointerCols []scanColumnPlan + + boolValueCols []scanColumnPlan + boolPointerCols []scanColumnPlan + + float64ValueCols []scanColumnPlan + float64PointerCols []scanColumnPlan + + timeValueCols []scanColumnPlan + timePointerCols []scanColumnPlan + + otherCols []scanColumnPlan } type rowScanPlanKey struct { @@ -53,15 +71,6 @@ type rowScanPlanKey struct { var rowScanPlanCache sync.Map -type boundRowScanPlan struct { - columns []boundColumnScan - clearIndices []int -} - -type boundColumnScan struct { - scan func(reflect.Value) error -} - var modelMetaCache sync.Map func lookupModelMeta(model any) (*modelMeta, reflect.Value, error) { @@ -218,13 +227,12 @@ func scanRowsAgainstTableDirect(rows *sql.Rows, dest any, table *schema.TableDef } scanTargets, scanned := newScanTargets(cols, plan, nil, nil) - bound := plan.bind(scanned) if err := rows.Scan(scanTargets...); err != nil { return err } - return scanDirectRowWithPlan(target, bound) + return scanDirectRow(target, plan, scanned) case reflect.Slice: elemType := target.Type().Elem() structType, pointerElems, err := sliceElementStructType(elemType) @@ -237,7 +245,6 @@ func scanRowsAgainstTableDirect(rows *sql.Rows, dest any, table *schema.TableDef } scanTargets, scanned := newScanTargets(cols, plan, nil, nil) - bound := plan.bind(scanned) zeroElem := reflect.Zero(elemType) // Use a local slice header to grow the result set. If rows.Scan fails, @@ -250,7 +257,7 @@ func scanRowsAgainstTableDirect(rows *sql.Rows, dest any, table *schema.TableDef // Clear any previous generic scanned values to avoid carrying over data // for non-direct columns. Direct columns use pointers to scratch variables // that are overwritten by rows.Scan. - for _, idx := range bound.clearIndices { + for _, idx := range plan.clearIndices { scanned[idx] = nil } @@ -278,7 +285,7 @@ func scanRowsAgainstTableDirect(rows *sql.Rows, dest any, table *schema.TableDef scanTarget = item } - if err := scanDirectRowWithPlan(scanTarget, bound); err != nil { + if err := scanDirectRow(scanTarget, plan, scanned); err != nil { return err } } @@ -341,9 +348,315 @@ func newScanTargets(cols []string, plan *rowScanPlan, scanTargets, scanned []any return scanTargets, scanned } -func scanDirectRowWithPlan(target reflect.Value, bound *boundRowScanPlan) error { - for i := range bound.columns { - if err := bound.columns[i].scan(target); err != nil { +func scanDirectRow(target reflect.Value, plan *rowScanPlan, scanned []any) error { + for i := range plan.int64ValueCols { + col := &plan.int64ValueCols[i] + v := scanned[col.scanIndex].(*sql.NullInt64) + if !v.Valid { + return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) + } + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + switch field.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if field.OverflowInt(v.Int64) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) + } + field.SetInt(v.Int64) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if v.Int64 < 0 || field.OverflowUint(uint64(v.Int64)) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) + } + field.SetUint(uint64(v.Int64)) + default: + if err := assignRawValueToField(field, v.Int64); err != nil { + return err + } + } + } + for i := range plan.int64PointerCols { + col := &plan.int64PointerCols[i] + v := scanned[col.scanIndex].(*sql.NullInt64) + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if !v.Valid { + field.SetZero() + continue + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + switch field.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if field.OverflowInt(v.Int64) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) + } + field.SetInt(v.Int64) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if v.Int64 < 0 || field.OverflowUint(uint64(v.Int64)) { + return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) + } + field.SetUint(uint64(v.Int64)) + default: + if err := assignRawValueToField(field, v.Int64); err != nil { + return err + } + } + } + for i := range plan.stringValueCols { + col := &plan.stringValueCols[i] + v := scanned[col.scanIndex].(*sql.NullString) + if !v.Valid { + return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) + } + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if field.Kind() == reflect.String { + field.SetString(v.String) + } else { + if err := assignRawValueToField(field, v.String); err != nil { + return err + } + } + } + for i := range plan.stringPointerCols { + col := &plan.stringPointerCols[i] + v := scanned[col.scanIndex].(*sql.NullString) + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if !v.Valid { + field.SetZero() + continue + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + if field.Kind() == reflect.String { + field.SetString(v.String) + } else { + if err := assignRawValueToField(field, v.String); err != nil { + return err + } + } + } + for i := range plan.boolValueCols { + col := &plan.boolValueCols[i] + v := scanned[col.scanIndex].(*sql.NullBool) + if !v.Valid { + return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) + } + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if field.Kind() == reflect.Bool { + field.SetBool(v.Bool) + } else { + if err := assignRawValueToField(field, v.Bool); err != nil { + return err + } + } + } + for i := range plan.boolPointerCols { + col := &plan.boolPointerCols[i] + v := scanned[col.scanIndex].(*sql.NullBool) + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if !v.Valid { + field.SetZero() + continue + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + if field.Kind() == reflect.Bool { + field.SetBool(v.Bool) + } else { + if err := assignRawValueToField(field, v.Bool); err != nil { + return err + } + } + } + for i := range plan.float64ValueCols { + col := &plan.float64ValueCols[i] + v := scanned[col.scanIndex].(*sql.NullFloat64) + if !v.Valid { + return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) + } + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if field.Kind() == reflect.Float32 || field.Kind() == reflect.Float64 { + if field.OverflowFloat(v.Float64) { + return fmt.Errorf("rain: value %f overflows field %s", v.Float64, field.Type()) + } + field.SetFloat(v.Float64) + } else { + if err := assignRawValueToField(field, v.Float64); err != nil { + return err + } + } + } + for i := range plan.float64PointerCols { + col := &plan.float64PointerCols[i] + v := scanned[col.scanIndex].(*sql.NullFloat64) + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if !v.Valid { + field.SetZero() + continue + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + if field.Kind() == reflect.Float32 || field.Kind() == reflect.Float64 { + if field.OverflowFloat(v.Float64) { + return fmt.Errorf("rain: value %f overflows field %s", v.Float64, field.Type()) + } + field.SetFloat(v.Float64) + } else { + if err := assignRawValueToField(field, v.Float64); err != nil { + return err + } + } + } + for i := range plan.timeValueCols { + col := &plan.timeValueCols[i] + v := scanned[col.scanIndex].(*sql.NullTime) + if !v.Valid { + return fmt.Errorf("rain: cannot assign NULL to non-pointer field %s", col.fieldType) + } + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if field.Type() == reflect.TypeFor[time.Time]() { + *field.Addr().Interface().(*time.Time) = v.Time + } else { + if err := assignRawValueToField(field, v.Time); err != nil { + return err + } + } + } + for i := range plan.timePointerCols { + col := &plan.timePointerCols[i] + v := scanned[col.scanIndex].(*sql.NullTime) + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + if !v.Valid { + field.SetZero() + continue + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + if field.Type() == reflect.TypeFor[time.Time]() { + *field.Addr().Interface().(*time.Time) = v.Time + } else { + if err := assignRawValueToField(field, v.Time); err != nil { + return err + } + } + } + for i := range plan.otherCols { + col := &plan.otherCols[i] + var field reflect.Value + if col.isComplex { + var err error + field, err = fieldByIndexAlloc(target, col.fieldIndex) + if err != nil { + return err + } + } else { + field = target.Field(col.index0) + } + rowVal := scanned[col.scanIndex] + if !col.isDirect && col.isJSON { + if s, ok := rowVal.(string); ok { + rowVal = []byte(s) + } + } + if err := assignRawValueToField(field, rowVal); err != nil { return err } } @@ -443,10 +756,14 @@ func newRowScanPlanForColumns(cols []string, modelType reflect.Type, table *sche return nil, err } - plan := &rowScanPlan{columns: make([]scanColumnPlan, 0, len(cols))} + plan := &rowScanPlan{ + columns: make([]scanColumnPlan, 0, len(cols)), + clearIndices: make([]int, 0), + } for idx, name := range cols { fieldInfo, ok := meta.byColumn[name] if !ok { + plan.clearIndices = append(plan.clearIndices, idx) continue } @@ -479,8 +796,11 @@ func newRowScanPlanForColumns(cols []string, modelType reflect.Type, table *sche } isDirect := !isJSON && isSimpleDirectType(fieldType) + if !isDirect { + plan.clearIndices = append(plan.clearIndices, idx) + } - plan.columns = append(plan.columns, scanColumnPlan{ + colPlan := scanColumnPlan{ columnName: name, scanIndex: idx, fieldIndex: fieldInfo.index, @@ -490,203 +810,62 @@ func newRowScanPlanForColumns(cols []string, modelType reflect.Type, table *sche isDirect: isDirect, columnDef: columnDef, fieldType: fieldType, - }) - } - - actual, _ := rowScanPlanCache.LoadOrStore(key, plan) - return actual.(*rowScanPlan), nil -} - -func (p *rowScanPlan) bind(scanned []any) *boundRowScanPlan { - bound := &boundRowScanPlan{ - columns: make([]boundColumnScan, len(p.columns)), - } - - for i := range p.columns { - col := &p.columns[i] - idx := col.scanIndex - val := scanned[idx] - - // Pre-calculate which indices in the scanned slice should be cleared. - // These are the ones where we scan into the scanned slice itself, - // rather than into a specialized sql.Null* type. - if scanned[idx] == nil { - bound.clearIndices = append(bound.clearIndices, idx) } - - fieldIndex := col.fieldIndex - isComplex := col.isComplex - index0 := col.index0 - isDirect := col.isDirect - isJSON := col.isJSON - - var scanFn func(reflect.Value) error + plan.columns = append(plan.columns, colPlan) if isDirect { - switch v := val.(type) { - case *sql.NullInt64: - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(index0) - } - if !v.Valid { - return assignRawValueToField(field, nil) - } - switch field.Kind() { - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - if field.OverflowInt(v.Int64) { - return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) - } - field.SetInt(v.Int64) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - if v.Int64 < 0 || field.OverflowUint(uint64(v.Int64)) { - return fmt.Errorf("rain: value %d overflows field %s", v.Int64, field.Type()) - } - field.SetUint(uint64(v.Int64)) - default: - return assignRawValueToField(field, v.Int64) - } - return nil + isPtr := fieldType.Kind() == reflect.Pointer + baseType := fieldType + if isPtr { + baseType = fieldType.Elem() + } + + switch baseType.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + if isPtr { + plan.int64PointerCols = append(plan.int64PointerCols, colPlan) + } else { + plan.int64ValueCols = append(plan.int64ValueCols, colPlan) } - case *sql.NullString: - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(index0) - } - if !v.Valid { - return assignRawValueToField(field, nil) - } - if field.Kind() == reflect.String { - field.SetString(v.String) - } else { - return assignRawValueToField(field, v.String) - } - return nil + case reflect.String: + if isPtr { + plan.stringPointerCols = append(plan.stringPointerCols, colPlan) + } else { + plan.stringValueCols = append(plan.stringValueCols, colPlan) } - case *sql.NullBool: - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(index0) - } - if !v.Valid { - return assignRawValueToField(field, nil) - } - if field.Kind() == reflect.Bool { - field.SetBool(v.Bool) - } else { - return assignRawValueToField(field, v.Bool) - } - return nil + case reflect.Bool: + if isPtr { + plan.boolPointerCols = append(plan.boolPointerCols, colPlan) + } else { + plan.boolValueCols = append(plan.boolValueCols, colPlan) } - case *sql.NullFloat64: - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(index0) - } - if !v.Valid { - return assignRawValueToField(field, nil) - } - if field.Kind() == reflect.Float32 || field.Kind() == reflect.Float64 { - if field.OverflowFloat(v.Float64) { - return fmt.Errorf("rain: value %f overflows field %s", v.Float64, field.Type()) - } - field.SetFloat(v.Float64) - } else { - return assignRawValueToField(field, v.Float64) - } - return nil + case reflect.Float32, reflect.Float64: + if isPtr { + plan.float64PointerCols = append(plan.float64PointerCols, colPlan) + } else { + plan.float64ValueCols = append(plan.float64ValueCols, colPlan) } - case *sql.NullTime: - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } + case reflect.Struct: + if baseType == reflect.TypeFor[time.Time]() { + if isPtr { + plan.timePointerCols = append(plan.timePointerCols, colPlan) } else { - field = target.Field(index0) - } - if !v.Valid { - return assignRawValueToField(field, nil) + plan.timeValueCols = append(plan.timeValueCols, colPlan) } - if field.Type() == reflect.TypeFor[time.Time]() { - *field.Addr().Interface().(*time.Time) = v.Time - } else { - return assignRawValueToField(field, v.Time) - } - return nil + } else { + plan.otherCols = append(plan.otherCols, colPlan) } default: - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(index0) - } - return assignRawValueToField(field, scanned[idx]) - } + plan.otherCols = append(plan.otherCols, colPlan) } } else { - scanFn = func(target reflect.Value) error { - var field reflect.Value - if isComplex { - var err error - field, err = fieldByIndexAlloc(target, fieldIndex) - if err != nil { - return err - } - } else { - field = target.Field(index0) - } - - rowVal := scanned[idx] - if isJSON { - if s, ok := rowVal.(string); ok { - rowVal = []byte(s) - } - } - - return assignRawValueToField(field, rowVal) - } + plan.otherCols = append(plan.otherCols, colPlan) } - bound.columns[i] = boundColumnScan{scan: scanFn} } - return bound + + actual, _ := rowScanPlanCache.LoadOrStore(key, plan) + return actual.(*rowScanPlan), nil } func isSimpleDirectType(t reflect.Type) bool { diff --git a/pkg/rain/model_internal_test.go b/pkg/rain/model_internal_test.go index 48d3347..3249fa7 100644 --- a/pkg/rain/model_internal_test.go +++ b/pkg/rain/model_internal_test.go @@ -216,20 +216,23 @@ func TestBoundDirectFallbackReadsCurrentScannedValue(t *testing.T) { Name string `db:"name"` } - scanned := []any{"stale"} - plan := &rowScanPlan{columns: []scanColumnPlan{{ + colPlan := scanColumnPlan{ scanIndex: 0, fieldIndex: []int{0}, index0: 0, isDirect: true, fieldType: reflect.TypeFor[string](), - }}} - bound := plan.bind(scanned) + } + plan := &rowScanPlan{ + columns: []scanColumnPlan{colPlan}, + stringValueCols: []scanColumnPlan{colPlan}, + } - scanned[0] = "fresh" + scanned := []any{&sql.NullString{String: "stale", Valid: true}} + scanned[0].(*sql.NullString).String = "fresh" var got row - if err := scanDirectRowWithPlan(reflect.ValueOf(&got).Elem(), bound); err != nil { + if err := scanDirectRow(reflect.ValueOf(&got).Elem(), plan, scanned); err != nil { t.Fatalf("scan direct fallback: %v", err) } if got.Name != "fresh" {