Skip to content

Commit

Permalink
finishers are Get, All, Execute only | logger removed
Browse files Browse the repository at this point in the history
  • Loading branch information
amirrezaask committed Mar 17, 2022
1 parent 8167266 commit 2b2768b
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 285 deletions.
59 changes: 33 additions & 26 deletions binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,32 @@ import (

// makeNewPointersOf creates a map of [field name] -> pointer to fill it
// recursively. it will go down until reaches a driver.Valuer implementation, it will stop there.
func (b *binder[T]) makeNewPointersOf(v reflect.Value) map[string]interface{} {
func (b *binder) makeNewPointersOf(v reflect.Value) interface{} {
m := map[string]interface{}{}
actualV := v
for actualV.Type().Kind() == reflect.Ptr {
actualV = actualV.Elem()
}
for i := 0; i < actualV.NumField(); i++ {
f := actualV.Field(i)
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr()))
fm := b.makeNewPointersOf(f)
for k, p := range fm {
m[k] = p
if actualV.Type().Kind() == reflect.Struct {
for i := 0; i < actualV.NumField(); i++ {
f := actualV.Field(i)
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr()))
fm := b.makeNewPointersOf(f).(map[string]interface{})
for k, p := range fm {
m[k] = p
}
} else {
var fm *field
fm = b.s.getField(actualV.Type().Field(i))
if fm == nil {
fm = fieldMetadata(actualV.Type().Field(i), b.s.columnConstraints)[0]
}
m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface()
}
} else {
var fm *field
fm = b.s.getField(actualV.Type().Field(i))
if fm == nil {
var ec EntityConfigurator
(*new(T)).ConfigureEntity(&ec)
fm = fieldMetadata(actualV.Type().Field(i), ec.columnConstraints)[0]
}
m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface()
}
} else {
return v.Addr().Interface()
}

return m
Expand All @@ -42,28 +44,33 @@ func (b *binder[T]) makeNewPointersOf(v reflect.Value) map[string]interface{} {
// ptrsFor first allocates for all struct fields recursively until reaches a driver.Value impl
// then it will put them in a map with their correct field name as key, then loops over cts
// and for each one gets appropriate one from the map and adds it to pointer list.
func (b *binder[T]) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} {
nameToPtr := b.makeNewPointersOf(v)
func (b *binder) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} {
ptrs := b.makeNewPointersOf(v)
var scanInto []interface{}
for _, ct := range cts {
if nameToPtr[ct.Name()] != nil {
scanInto = append(scanInto, nameToPtr[ct.Name()])
if reflect.TypeOf(ptrs).Kind() == reflect.Map {
nameToPtr := ptrs.(map[string]interface{})
for _, ct := range cts {
if nameToPtr[ct.Name()] != nil {
scanInto = append(scanInto, nameToPtr[ct.Name()])
}
}
} else {
scanInto = append(scanInto, ptrs)
}

return scanInto
}

type binder[T Entity] struct {
type binder struct {
s *schema
}

func newBinder[T Entity](s *schema) *binder[T] {
return &binder[T]{s: s}
func newBinder(s *schema) *binder {
return &binder{s: s}
}

// bind binds given rows to the given object at obj. obj should be a pointer
func (b *binder[T]) bind(rows *sql.Rows, obj interface{}) error {
func (b *binder) bind(rows *sql.Rows, obj interface{}) error {
cts, err := rows.ColumnTypes()
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func TestBind(t *testing.T) {

u := &User{}
md := schemaOfHeavyReflectionStuff(u)
err = newBinder[User](md).bind(rows, u)
err = newBinder(md).bind(rows, u)
assert.NoError(t, err)

assert.Equal(t, "amirreza", u.Name)
Expand All @@ -57,7 +57,7 @@ func TestBind(t *testing.T) {

md := schemaOfHeavyReflectionStuff(&User{})
var users []*User
err = newBinder[User](md).bind(rows, &users)
err = newBinder(md).bind(rows, &users)
assert.NoError(t, err)

assert.Equal(t, "amirreza", users[0].Name)
Expand Down
7 changes: 0 additions & 7 deletions connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ type connection struct {
Dialect *Dialect
DB *sql.DB
Schemas map[string]*schema
Logger Logger
}

func (c *connection) Schematic() {
Expand Down Expand Up @@ -58,19 +57,13 @@ func GetConnection(name string) *connection {
}

func (c *connection) exec(q string, args ...any) (sql.Result, error) {
globalLogger.Debugf(q)
globalLogger.Debugf("%v", args)
return c.DB.Exec(q, args...)
}

func (c *connection) query(q string, args ...any) (*sql.Rows, error) {
globalLogger.Debugf(q)
globalLogger.Debugf("%v", args)
return c.DB.Query(q, args...)
}

func (c *connection) queryRow(q string, args ...any) *sql.Row {
globalLogger.Debugf(q)
globalLogger.Debugf("%v", args)
return c.DB.QueryRow(q, args...)
}
62 changes: 0 additions & 62 deletions logger.go

This file was deleted.

45 changes: 16 additions & 29 deletions orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
)

var globalConnections = map[string]*connection{}
var globalLogger Logger

// Schematic prints all information ORM inferred from your entities in startup, remember to pass
// your entities in Entities when you call SetupConnections if you want their data inferred
Expand All @@ -27,11 +26,6 @@ func Schematic() {
}
}

type Config struct {
// LogLevel
LogLevel LogLevel
}

type ConnectionConfig struct {
// Name of your database connection, it's up to you to name them anything
// just remember that having a connection name is mandatory if
Expand All @@ -54,7 +48,6 @@ type ConnectionConfig struct {
func SetupConnections(configs ...ConnectionConfig) error {
// configure logger
var err error
globalLogger, err = newZapLogger(LogLevelDev)
if err != nil {
return err
}
Expand All @@ -74,9 +67,6 @@ func setupConnection(config ConnectionConfig) error {
config.Name = "default"
}

globalLogger.Infof("Generating schema definitions for connection %s entities", config.Name)
globalLogger.Infof("Entities are: %v", entitiesAsList(config.Entities))

for _, entity := range config.Entities {
s := schemaOfHeavyReflectionStuff(entity)
var configurator EntityConfigurator
Expand All @@ -93,8 +83,6 @@ func setupConnection(config ConnectionConfig) error {

globalConnections[fmt.Sprintf("%s", config.Name)] = s

globalLogger.Infof("%s registered successfully.", config.Name)

return nil
}

Expand All @@ -113,7 +101,6 @@ func Insert(objs ...Entity) error {
if len(objs) == 0 {
return nil
}
globalLogger.Debugf("Going to insert %d objects", len(objs))
s := getSchemaFor(objs[0])
cols := s.Columns(false)
var values [][]interface{}
Expand Down Expand Up @@ -166,10 +153,8 @@ func isZero(val interface{}) bool {
// insert it.
func Save(obj Entity) error {
if isZero(getSchemaFor(obj).getPK(obj)) {
globalLogger.Debugf("Given object has no primary key set, going to insert it.")
return Insert(obj)
} else {
globalLogger.Debugf("Given object has primary key set, going for update.")
return Update(obj)
}
}
Expand All @@ -179,7 +164,7 @@ func Find[T Entity](id interface{}) (T, error) {
var q string
out := new(T)
md := getSchemaFor(*out)
q, args, err := NewQueryBuilder[T]().
q, args, err := NewQueryBuilder[T](md).
SetDialect(md.getDialect()).
Table(md.Table).
Select(md.Columns(true)...).
Expand Down Expand Up @@ -213,7 +198,7 @@ func toTuples(obj Entity, withPK bool) [][2]interface{} {
// Update given Entity in database.
func Update(obj Entity) error {
s := getSchemaFor(obj)
q, args, err := NewQueryBuilder[Entity]().SetDialect(s.getDialect()).Sets(toTuples(obj, false)...).Where(s.pkName(), genericGetPKValue(obj)).Table(s.Table).ToSql()
q, args, err := NewQueryBuilder[Entity](s).SetDialect(s.getDialect()).Sets(toTuples(obj, false)...).Where(s.pkName(), genericGetPKValue(obj)).Table(s.Table).ToSql()

if err != nil {
return err
Expand All @@ -226,7 +211,7 @@ func Update(obj Entity) error {
func Delete(obj Entity) error {
s := getSchemaFor(obj)
genericSet(obj, "deleted_at", sql.NullTime{Time: time.Now(), Valid: true})
query, args, err := NewQueryBuilder[Entity]().SetDialect(s.getDialect()).Table(s.Table).Where(s.pkName(), genericGetPKValue(obj)).SetDelete().ToSql()
query, args, err := NewQueryBuilder[Entity](s).SetDialect(s.getDialect()).Table(s.Table).Where(s.pkName(), genericGetPKValue(obj)).SetDelete().ToSql()
if err != nil {
return err
}
Expand All @@ -240,7 +225,7 @@ func bind[T Entity](output interface{}, q string, args []interface{}) error {
if err != nil {
return err
}
return newBinder[T](outputMD).bind(rows, output)
return newBinder(outputMD).bind(rows, output)
}

// HasManyConfig contains all information we need for querying HasMany relationships.
Expand All @@ -264,8 +249,9 @@ type HasManyConfig struct {
// HasMany[Comment](&Post{})
// is for Post HasMany Comment relationship.
func HasMany[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] {
q := NewQueryBuilder[PROPERTY]()
outSchema := getSchemaFor(*new(PROPERTY))

q := NewQueryBuilder[PROPERTY](outSchema)
// getting config from our cache
c, ok := getSchemaFor(owner).relations[outSchema.Table].(HasManyConfig)
if !ok {
Expand Down Expand Up @@ -300,8 +286,8 @@ type HasOneConfig struct {
// HasOne[HeaderPicture](&Post{})
// is for Post HasOne HeaderPicture relationship.
func HasOne[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] {
q := NewQueryBuilder[PROPERTY]()
property := getSchemaFor(*new(PROPERTY))
q := NewQueryBuilder[PROPERTY](property)
c, ok := getSchemaFor(owner).relations[property.Table].(HasOneConfig)
if !ok {
q.err = fmt.Errorf("wrong config passed for HasOne")
Expand Down Expand Up @@ -338,8 +324,8 @@ type BelongsToConfig struct {
// OWNER type parameter and property argument, so
// property BelongsTo OWNER.
func BelongsTo[OWNER Entity](property Entity) *QueryBuilder[OWNER] {
q := NewQueryBuilder[OWNER]()
owner := getSchemaFor(*new(OWNER))
q := NewQueryBuilder[OWNER](owner)
c, ok := getSchemaFor(property).relations[owner.Table].(BelongsToConfig)
if !ok {
q.err = fmt.Errorf("wrong config passed for BelongsTo")
Expand Down Expand Up @@ -396,15 +382,16 @@ type BelongsToManyConfig struct {

// BelongsToMany configures a QueryBuilder for a BelongsToMany relationship
func BelongsToMany[OWNER Entity](property Entity) *QueryBuilder[OWNER] {
q := NewQueryBuilder[OWNER]()
out := new(OWNER)
c, ok := getSchemaFor(property).relations[getSchemaFor(*out).Table].(BelongsToManyConfig)
out := *new(OWNER)
outSchema := getSchemaFor(out)
q := NewQueryBuilder[OWNER](outSchema)
c, ok := getSchemaFor(property).relations[outSchema.Table].(BelongsToManyConfig)
if !ok {
q.err = fmt.Errorf("wrong config passed for HasMany")
}
return q.
Select(getSchemaFor(*out).Columns(true)...).
Table(getSchemaFor(*out).Table).
Select(outSchema.Columns(true)...).
Table(outSchema.Table).
WhereIn(c.OwnerLookupColumn, Raw(fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
c.IntermediateOwnerID,
c.IntermediateTable, c.IntermediatePropertyID), genericGetPKValue(property)))
Expand Down Expand Up @@ -501,8 +488,8 @@ func addProperty(to Entity, items ...Entity) error {

// Query creates a new QueryBuilder for given type parameter, sets dialect and table as well.
func Query[E Entity]() *QueryBuilder[E] {
q := NewQueryBuilder[E]()
s := getSchemaFor(*new(E))
q := NewQueryBuilder[E](s)
q.SetDialect(s.getDialect()).Table(s.Table)
return q
}
Expand Down Expand Up @@ -537,7 +524,7 @@ func QueryRaw[OUTPUT Entity](q string, args ...interface{}) ([]OUTPUT, error) {
return nil, err
}
var output []OUTPUT
err = newBinder[OUTPUT](getSchemaFor(*o)).bind(rows, &output)
err = newBinder(getSchemaFor(*o)).bind(rows, &output)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit 2b2768b

Please sign in to comment.