Skip to content

Commit

Permalink
get by index
Browse files Browse the repository at this point in the history
  • Loading branch information
latolukasz committed Feb 19, 2024
1 parent 3f981b7 commit 3814a8c
Show file tree
Hide file tree
Showing 12 changed files with 137 additions and 24 deletions.
6 changes: 5 additions & 1 deletion db.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,11 @@ type rowsStruct struct {
}

func (r *rowsStruct) Next() bool {
return r.sqlRows.Next()
has := r.sqlRows.Next()
if !has {
_ = r.sqlRows.Close()
}
return has
}

func (r *rowsStruct) Columns() []string {
Expand Down
4 changes: 2 additions & 2 deletions editable_entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ func initNewEntity(elem reflect.Value, fields *tableFields) {
}
}

func IsDirty[E any](orm ORM, id uint64) (oldValues, newValues Bind, hasChanges bool) {
return isDirty(orm, getEntitySchema[E](orm), id)
func IsDirty[E any, I ID](orm ORM, id I) (oldValues, newValues Bind, hasChanges bool) {
return isDirty(orm, getEntitySchema[E](orm), uint64(id))
}

func isDirty(orm ORM, schema *entitySchema, id uint64) (oldValues, newValues Bind, hasChanges bool) {
Expand Down
64 changes: 64 additions & 0 deletions entity_iterator.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (

type EntityIterator[E any] interface {
Next() bool
ID() uint64
Index() int
Len() int
Entity() *E
All() []*E
Expand All @@ -17,6 +19,8 @@ type EntityIterator[E any] interface {

type EntityAnonymousIterator interface {
Next() bool
ID() uint64
Index() int
Len() int
Entity() any
Reset()
Expand All @@ -40,6 +44,17 @@ func (lc *localCacheIDsIterator[E]) Next() bool {
return true
}

func (lc *localCacheIDsIterator[E]) Index() int {
return lc.index
}

func (lc *localCacheIDsIterator[E]) ID() uint64 {
if lc.index == -1 {
return 0
}
return lc.ids[lc.index]
}

func (lc *localCacheIDsIterator[E]) Len() int {
return len(lc.ids)
}
Expand Down Expand Up @@ -140,6 +155,14 @@ func (el *emptyResultsIterator[E]) Next() bool {
return false
}

func (el *emptyResultsIterator[E]) Index() int {
return -1
}

func (el *emptyResultsIterator[E]) ID() uint64 {
return 0
}

func (el *emptyResultsIterator[E]) Len() int {
return 0
}
Expand Down Expand Up @@ -172,6 +195,17 @@ func (ei *entityIterator[E]) Next() bool {
return true
}

func (ei *entityIterator[E]) ID() uint64 {
if ei.index == -1 {
return 0
}
return reflect.ValueOf(ei.rows[ei.index]).Elem().FieldByName("ID").Uint()
}

func (ei *entityIterator[E]) Index() int {
return ei.index
}

func (ei *entityIterator[E]) Len() int {
return len(ei.rows)
}
Expand Down Expand Up @@ -209,6 +243,17 @@ func (ea *entityAnonymousIterator) Next() bool {
return true
}

func (ea *entityAnonymousIterator) ID() uint64 {
if ea.index == -1 {
return 0
}
return ea.rows.Index(ea.index).Elem().FieldByName("ID").Uint()
}

func (ea *entityAnonymousIterator) Index() int {
return ea.index
}

func (ea *entityAnonymousIterator) Len() int {
return ea.rows.Len()
}
Expand All @@ -230,6 +275,14 @@ func (el *emptyResultsAnonymousIterator) Next() bool {
return false
}

func (el *emptyResultsAnonymousIterator) ID() uint64 {
return 0
}

func (el *emptyResultsAnonymousIterator) Index() int {
return -1
}

func (el *emptyResultsAnonymousIterator) Len() int {
return 0
}
Expand Down Expand Up @@ -258,6 +311,17 @@ func (lc *localCacheIDsAnonymousIterator) Next() bool {
return true
}

func (lc *localCacheIDsAnonymousIterator) ID() uint64 {
if lc.index == -1 {
return 0
}
return lc.ids[lc.index]
}

func (lc *localCacheIDsAnonymousIterator) Index() int {
return lc.index
}

func (lc *localCacheIDsAnonymousIterator) Len() int {
return len(lc.ids)
}
Expand Down
20 changes: 16 additions & 4 deletions entity_schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -399,10 +399,6 @@ func (e *entitySchema) init(registry *registry, entityType reflect.Type) error {
e.uniqueIndices[name][i-1] = index[i]
}
}
err := e.validateIndexes(uniqueIndices, indices)
if err != nil {
return err
}
for indexName, indexColumns := range indices {
where := ""
for i := 0; i < len(indexColumns); i++ {
Expand All @@ -427,6 +423,10 @@ func (e *entitySchema) init(registry *registry, entityType reflect.Type) error {
e.cachedIndexes[indexName] = definition
}
}
err := e.validateIndexes(uniqueIndices, indices)
if err != nil {
return err
}
for _, plugin := range registry.plugins {
pluginInterfaceValidateEntitySchema, isInterface := plugin.(PluginInterfaceValidateEntitySchema)
if isInterface {
Expand Down Expand Up @@ -462,6 +462,18 @@ func (e *entitySchema) validateIndexes(uniqueIndices map[string]map[int]string,
break
}
if same == len(v) {
def, found := e.indexes[k]
if found {
def.Duplicated = true
e.indexes[k] = def
break
}
def, found = e.indexes[k2]
if found {
def.Duplicated = true
e.indexes[k2] = def
break
}
return fmt.Errorf("duplicated index %s with %s in %s", k, k2, e.t.String())
}
}
Expand Down
1 change: 1 addition & 0 deletions flush_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ func testFlushInsert(t *testing.T, async, local, redis bool) {

// Adding empty entity
newEntity := schema.NewEntity(orm).(*flushEntity)
assert.NotNil(t, newEntity.BoolArray)
newEntity.ReferenceRequired = Reference[flushEntityReference](reference.ID)
newEntity.Name = "Name"
assert.NotEmpty(t, newEntity.ID)
Expand Down
6 changes: 3 additions & 3 deletions get_by_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,22 @@ import (
"strconv"
)

func MustByID[E any](orm ORM, id uint64) *E {
func MustByID[E any, I ID](orm ORM, id I) *E {
entity, found := GetByID[E](orm, id)
if !found {
panic(fmt.Errorf("entity withd ID %d not found", id))
}
return entity
}

func GetByID[E any](orm ORM, id uint64) (entity *E, found bool) {
func GetByID[E any, I ID](orm ORM, id I) (entity *E, found bool) {
var e E
cE := orm.(*ormImplementation)
schema := cE.engine.registry.entitySchemas[reflect.TypeOf(e)]
if schema == nil {
panic(fmt.Errorf("entity '%T' is not registered", e))
}
value, found := getByID(cE, id, schema)
value, found := getByID(cE, uint64(id), schema)
if value == nil {
return nil, false
}
Expand Down
7 changes: 4 additions & 3 deletions get_by_index.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@ import (
)

type indexDefinition struct {
Cached bool
Columns []string
Where string
Cached bool
Columns []string
Where string
Duplicated bool
}

func (d indexDefinition) CreteWhere(hasNil bool, attributes []any) Where {
Expand Down
10 changes: 6 additions & 4 deletions get_by_index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,11 @@ import (
)

type getByIndexEntity struct {
ID uint64 `orm:"localCache;redisCache"`
Name string `orm:"index=Name"`
Age uint32 `orm:"index=Age;cached"`
Born *time.Time `orm:"index=Age:2;cached"`
ID uint64 `orm:"localCache;redisCache"`
Name string `orm:"index=Name"`
Age uint32 `orm:"index=Age;unique=Fake;cached"`
Born *time.Time `orm:"index=Age:2;unique=Fake:2;cached"`
Other int `orm:"unique=Fake:3"`
}

func TestGetByIndexNoCache(t *testing.T) {
Expand Down Expand Up @@ -53,6 +54,7 @@ func testGetByIndex(t *testing.T, local, redis bool) {
for i := 0; i < 10; i++ {
entity = NewEntity[getByIndexEntity](orm)
entity.Age = 10
entity.Other = i
if i >= 5 {
entity.Name = "Test Name"
entity.Age = 18
Expand Down
4 changes: 2 additions & 2 deletions get_by_reference.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (

const redisValidSetValue = "Y"

func GetByReference[E any](orm ORM, referenceName string, id uint64) EntityIterator[E] {
func GetByReference[E any, I ID](orm ORM, referenceName string, id I) EntityIterator[E] {
if id == 0 {
return nil
}
Expand All @@ -25,7 +25,7 @@ func GetByReference[E any](orm ORM, referenceName string, id uint64) EntityItera
if !def.Cached {
return Search[E](orm, NewWhere("`"+referenceName+"` = ?", id), nil)
}
return getCachedByReference[E](orm, referenceName, id, schema)
return getCachedByReference[E](orm, referenceName, uint64(id), schema)
}

func getCachedByReference[E any](orm ORM, key string, id uint64, schema *entitySchema) EntityIterator[E] {
Expand Down
4 changes: 4 additions & 0 deletions orm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ import (
"github.com/puzpuzpuz/xsync/v2"
)

type ID interface {
int | uint | uint8 | uint16 | uint32 | uint64 | int8 | int16 | int32 | int64
}

type Meta map[string]string

func (m Meta) Get(key string) string {
Expand Down
3 changes: 3 additions & 0 deletions registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,9 @@ func (r *registry) Validate() (Engine, error) {
e.registry.entitySchemasQuickMap[reflect.PointerTo(entityType)] = schema
e.registry.entities[name] = entityType
if schema.hasLocalCache {
if r.localCaches == nil {
r.localCaches = make(map[string]LocalCache)
}
r.localCaches[schema.getCacheKey()] = newLocalCache(schema.getCacheKey(), schema.localCacheLimit, schema)
}
extractEnums(schema.fields, e.registry)
Expand Down
32 changes: 27 additions & 5 deletions schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ func (td *TableSQLSchemaDefinition) CreateTableSQL() string {
}
var indexDefinitions []string
for _, indexEntity := range td.EntityIndexes {
indexDefinitions = append(indexDefinitions, buildCreateIndexSQL(indexEntity))
if !indexEntity.Duplicated {
indexDefinitions = append(indexDefinitions, buildCreateIndexSQL(indexEntity))
}
}
sort.Strings(indexDefinitions)
for _, value := range indexDefinitions {
Expand All @@ -77,6 +79,7 @@ func (td *TableSQLSchemaDefinition) CreateTableSQL() string {
type IndexSchemaDefinition struct {
Name string
Unique bool
Duplicated bool
columnsMap map[int]string
}

Expand Down Expand Up @@ -194,6 +197,18 @@ func getSchemaChanges(orm ORM, entitySchema *entitySchema) (preAlters, alters, p
pool.QueryRow(orm, NewWhere(fmt.Sprintf("SHOW CREATE TABLE `%s`", entitySchema.GetTableName())), &skip, &sqlSchema.DBCreateSchema)
lines := strings.Split(sqlSchema.DBCreateSchema, "\n")
for x := 1; x < len(lines); x++ {
l := strings.Trim(lines[x], " ")
if strings.HasPrefix(l, "CONSTRAINT ") {
alter := fmt.Sprintf("ALTER TABLE `%s`.`%s`\n", pool.GetConfig().GetDatabaseName(), entitySchema.GetTableName())
parts := strings.Split(l, " ")
alter += " DROP FOREIGN KEY " + parts[1] + ";"
preAlters = append(preAlters, Alter{
SQL: alter,
Safe: true,
Pool: pool.GetConfig().GetCode(),
})
continue
}
if lines[x][2] != 96 {
for _, field := range strings.Split(lines[x], " ") {
if strings.HasPrefix(field, "CHARSET=") {
Expand Down Expand Up @@ -327,7 +342,7 @@ OUTER:
break
}
}
if !hasIndex {
if !hasIndex && !indexEntity.Duplicated {
newIndexes = append(newIndexes, buildCreateIndexSQL(indexEntity))
hasAlters = true
}
Expand All @@ -339,7 +354,7 @@ OUTER:
}
hasIndex := false
for _, index := range sqlSchema.EntityIndexes {
if index.Name == key.Name {
if index.Name == key.Name && !index.Duplicated {
hasIndex = true
break
}
Expand Down Expand Up @@ -491,6 +506,10 @@ func checkColumn(engine Engine, schema *entitySchema, field *reflect.StructField
current, has := indexes[indexColumn[0]]
if !has {
current = &IndexSchemaDefinition{Name: indexColumn[0], Unique: unique, columnsMap: map[int]string{location: prefix + field.Name}}
schemaDef, hasDef := schema.indexes[indexColumn[0]]
if hasDef && schemaDef.Duplicated {
current.Duplicated = true
}
indexes[indexColumn[0]] = current
} else {
current.columnsMap[location] = prefix + field.Name
Expand Down Expand Up @@ -565,7 +584,8 @@ func checkColumn(engine Engine, schema *entitySchema, field *reflect.StructField
columns = append(columns, structFields...)
continue
} else if fieldType.Implements(reflect.TypeOf((*referenceInterface)(nil)).Elem()) {
definition, addNotNullIfNotSet, defaultValue = handleInt("uint64", attributes, !isRequired)
refIDType := reflect.New(reflect.New(fieldType).Interface().(referenceInterface).getType()).Elem().FieldByName("ID").Type().String()
definition, addNotNullIfNotSet, defaultValue = handleInt(refIDType, attributes, !isRequired)
} else if fieldType.Implements(reflect.TypeOf((*EnumValues)(nil)).Elem()) {
def := reflect.New(fieldType).Interface().(EnumValues)
definition, addNotNullIfNotSet, addDefaultNullIfNullable, defaultValue, err = handleSetEnum("enum", fieldType.String(), schema, def, !isRequired)
Expand Down Expand Up @@ -599,7 +619,9 @@ func checkColumn(engine Engine, schema *entitySchema, field *reflect.StructField

func handleInt(typeAsString string, attributes map[string]string, nullable bool) (string, bool, string) {
if nullable {
typeAsString = typeAsString[1:]
if strings.HasPrefix(typeAsString, "*") {
typeAsString = typeAsString[1:]
}
return convertIntToSchema(typeAsString, attributes), false, "nil"
}
return convertIntToSchema(typeAsString, attributes), true, "'0'"
Expand Down

0 comments on commit 3814a8c

Please sign in to comment.