Skip to content

Commit

Permalink
tired
Browse files Browse the repository at this point in the history
  • Loading branch information
amirrezaask committed Nov 20, 2021
1 parent d1378b4 commit 7107956
Show file tree
Hide file tree
Showing 7 changed files with 186 additions and 63 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
**/.idea/*
cover.out
**db
61 changes: 32 additions & 29 deletions binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package orm

import (
"database/sql"
"fmt"
"reflect"
"unsafe"
)
Expand Down Expand Up @@ -43,46 +44,48 @@ func (o *ObjectMetadata) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []inter
return scanInto
}

// Bind binds given rows to the given object at v.
func (o *ObjectMetadata) Bind(rows *sql.Rows, v interface{}) error {
// Bind binds given rows to the given object at obj. obj should be a pointer
func (o *ObjectMetadata) Bind(rows *sql.Rows, obj interface{}) error {
cts, err := rows.ColumnTypes()
if err != nil {
return err
}

t := reflect.TypeOf(v)
vt := reflect.ValueOf(v)

if t.Kind() == reflect.Ptr {
vt = vt.Elem()
t = t.Elem()
t := reflect.TypeOf(obj)
v := reflect.ValueOf(obj)
if t.Kind() != reflect.Ptr {
return fmt.Errorf("obj should be a ptr")
}

var inputs [][]interface{}
if t.Kind() != reflect.Slice {
inputs = append(inputs, o.ptrsFor(reflect.ValueOf(v), cts))
} else {
for i := 0; i < vt.Len(); i++ {
p := vt.Index(i).Elem()
if p.Type().Kind() == reflect.Ptr {
p = p.Elem()
t = t.Elem()
v = v.Elem()
if t.Kind() == reflect.Slice {
t = t.Elem()
for rows.Next() {
var rowValue reflect.Value
if t.Kind() == reflect.Ptr {
rowValue = reflect.New(t.Elem())
} else {
rowValue = reflect.New(t)
}
newCts := make([]*sql.ColumnType, len(cts))
copy(newCts, cts)
ptrs := o.ptrsFor(p, newCts)
inputs = append(inputs, ptrs)
ptrs := o.ptrsFor(rowValue, newCts)
err = rows.Scan(ptrs...)
if err != nil {
return err
}
v = reflect.Append(v, rowValue)
}
}

i := 0
for rows.Next() && i < len(inputs) {
err = rows.Scan(inputs[i]...)
if err != nil {
return err
} else {
for rows.Next() {
ptrs := o.ptrsFor(v, cts)
err = rows.Scan(ptrs...)
if err != nil {
return err
}
}
i++
}

// v is either struct or slice
reflect.ValueOf(obj).Elem().Set(v)
return nil

}
32 changes: 14 additions & 18 deletions binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,15 +55,13 @@ func TestBind(t *testing.T) {
rows, err := db.Query(`SELECT * FROM users`)
assert.NoError(t, err)

amirreza := &User{}
milad := &User{}
md := ObjectMetadataFrom(amirreza, Sqlite3SQLDialect)

err = md.Bind(rows, []interface{}{amirreza, milad})
md := ObjectMetadataFrom(&User{}, Sqlite3SQLDialect)
var users []*User
err = md.Bind(rows, &users)
assert.NoError(t, err)

assert.Equal(t, "amirreza", amirreza.Name)
assert.Equal(t, "milad", milad.Name)
assert.Equal(t, "amirreza", users[0].Name)
assert.Equal(t, "milad", users[1].Name)
})
}

Expand Down Expand Up @@ -100,18 +98,16 @@ func TestBindNested(t *testing.T) {
rows, err := db.Query(`SELECT users.id, users.name, addresses.id, addresses.path FROM users INNER JOIN addresses ON addresses.user_id = users.id`)
assert.NoError(t, err)

amirreza := &ComplexUser{}
milad := &ComplexUser{}
md := ObjectMetadataFrom(amirreza, Sqlite3SQLDialect)

err = md.Bind(rows, []*ComplexUser{amirreza, milad})
md := ObjectMetadataFrom(&ComplexUser{}, Sqlite3SQLDialect)
var users []*ComplexUser
err = md.Bind(rows, &users)
assert.NoError(t, err)

assert.Equal(t, "amirreza", amirreza.Name)
assert.Equal(t, "milad", milad.Name)
assert.Equal(t, "kianpars", amirreza.Address.Path)
assert.Equal(t, "delfan", milad.Address.Path)
assert.Equal(t, 2, milad.Address.ID)
assert.Equal(t, 1, amirreza.Address.ID)
assert.Equal(t, "amirreza", users[0].Name)
assert.Equal(t, "milad", users[1].Name)
assert.Equal(t, "kianpars", users[0].Address.Path)
assert.Equal(t, "delfan", users[1].Address.Path)
assert.Equal(t, 2, users[1].Address.ID)
assert.Equal(t, 1, users[0].Address.ID)
})
}
32 changes: 32 additions & 0 deletions example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,35 @@ func TestExampleRepositoriesWithRelationHasOne(t *testing.T) {
assert.Equal(t, "amirreza", firstUser.Name)
assert.Equal(t, "ahvaz", firstUser.Address.AddressContent.Content)
}

func TestEntity_HasMany(t *testing.T) {
type Address struct {
ID int64
Content string
}
type User struct {
ID int64
Name string
Age int
Address Address `orm:"in_rel=true has=one left=id right=user_id"`
}
db, mockDB, err := sqlmock.New()
assert.NoError(t, err)
// create the repository using database connection and an instance of the type representing the table in database.
userRepository := orm.NewRepository(db, orm.PostgreSQLDialect, &User{})
firstUser := &User{
ID: 1,
}
var addresses []*Address
mockDB.ExpectQuery(`SELECT addresses.id, addresses.content FROM addresses`).
WithArgs(1).
WillReturnRows(sqlmock.NewRows([]string{"addresses.id", "addresses.content"}).
AddRow(1, "ahvaz"))

err = userRepository.Entity(firstUser).HasMany(&addresses)
assert.NoError(t, err)
assert.Len(t, addresses, 1)
}
func TestEntity_HasOne(t *testing.T) {

}
81 changes: 81 additions & 0 deletions model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package orm

import (
"fmt"
"reflect"

"github.com/golobby/orm/qb"
)

type Entity struct {
repo *Repository
obj interface{}
}

func (r *Repository) Entity(obj interface{}) *Entity {
return &Entity{r, obj}
}

func (e *Entity) HasMany(out interface{}) error {
t := reflect.TypeOf(out)
v := reflect.ValueOf(out)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
if t.Kind() == reflect.Slice {
t = t.Elem()
}
target := reflect.New(t).Interface()
repo := NewRepository(e.repo.conn, e.repo.dialect, target)

var q string
var args []interface{}
for _, field := range e.repo.metadata.Fields {
if !field.IsRel {
continue
}
if field.RelationMetadata.Table == repo.metadata.Table {
q, args = qb.NewSelect().
From(field.RelationMetadata.Table).
Select(field.RelationMetadata.objectMetadata.Columns(true)...).
Where(qb.WhereHelpers.Equal(field.RelationMetadata.LeftColumn, field.RelationMetadata.RightColumn)).
WithArgs(e.repo.getPkValue(e.obj)).
Build()
}
}
if q == "" {
return fmt.Errorf("cannot build the query")
}
return repo.Bind(out, q, args...)
}
func (e *Entity) HasOne(out interface{}) error {
t := reflect.TypeOf(out)
v := reflect.ValueOf(out)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
target := reflect.New(t).Interface()
repo := NewRepository(e.repo.conn, e.repo.dialect, target)

var q string
var args []interface{}
for _, field := range e.repo.metadata.Fields {
if !field.IsRel {
continue
}
if field.RelationMetadata.Table == repo.metadata.Table {
q, args = qb.NewSelect().
From(field.RelationMetadata.Table).
Select(field.RelationMetadata.objectMetadata.Columns(true)...).
Where(qb.WhereHelpers.Equal(field.RelationMetadata.LeftColumn, field.RelationMetadata.RightColumn)).
WithArgs(e.repo.getPkValue(e.obj)).
Build()
}
}
if q == "" {
return fmt.Errorf("cannot build the query")
}
return repo.Bind(out, q, args...)
}
24 changes: 13 additions & 11 deletions obj.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ import (
"github.com/iancoleman/strcase"
)

// Entity interface is for sake of documentation, if you want to change orm behaviour for:
// IEntity interface is for sake of documentation, if you want to change orm behaviour for:
// Table name generation -> implement Table for your model
// GetPKValue -> returns value of primary key of model, implementing this helps with performance.
// SetPKValue -> sets the value of primary key of mode, implementing this helps with performance.
type Entity interface {
type IEntity interface {
Table
GetPKValue
SetPKValue
Expand Down Expand Up @@ -68,7 +68,7 @@ func tableName(v interface{}) string {
return hv.Table()
}
t := reflect.TypeOf(v)
if t.Kind() == reflect.Ptr {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}

Expand All @@ -77,8 +77,8 @@ func tableName(v interface{}) string {
return strcase.ToSnake(pluralize.NewClient().Plural(name))
}

func (r *Repository) pkName(v interface{}) string {
for _, field := range r.metadata.Fields {
func (o *ObjectMetadata) pkName() string {
for _, field := range o.Fields {
if field.IsPK {
return field.Name
}
Expand Down Expand Up @@ -243,19 +243,19 @@ func fieldMetadataFromTag(t string) FieldTag {
}
return tag
}

func tableFromTypeName(name string) string {
return strcase.ToSnake(pluralize.NewClient().Plural(name))
}
func fieldsOf(obj interface{}, dialect *Dialect) []*FieldMetadata {
hasFields, is := obj.(HasFields)
if is {
return hasFields.Fields()
}
t := reflect.TypeOf(obj)
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Slice {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}

var fms []*FieldMetadata
for i := 0; i < t.NumField(); i++ {
ft := t.Field(i)
Expand All @@ -273,11 +273,13 @@ func fieldsOf(obj interface{}, dialect *Dialect) []*FieldMetadata {
if tagParsed.InRel == true {
fm.IsRel = true

fm.RelationMetadata = &RelationMetadata{}
fm.RelationMetadata = &RelationMetadata{Type: RelationTypeHasOne}
fm.RelationMetadata.objectMetadata = ObjectMetadataFrom(reflect.New(ft.Type).Interface(), dialect)

if tagParsed.With != "" {
fm.RelationMetadata.Table = tagParsed.With
} else {
fm.RelationMetadata.Table = tableFromTypeName(ft.Name)
}
if tagParsed.Left != "" {
fm.RelationMetadata.LeftColumn = tagParsed.Left
Expand Down
Loading

0 comments on commit 7107956

Please sign in to comment.