diff --git a/go.mod b/go.mod index 96a98c7407..8b8dcc58a4 100644 --- a/go.mod +++ b/go.mod @@ -79,7 +79,7 @@ require ( gopkg.in/square/go-jose.v2 v2.6.0 gopkg.in/yaml.v2 v2.4.0 gorm.io/driver/postgres v1.3.5 - gorm.io/gorm v1.24.5 + gorm.io/gorm v1.25.5 k8s.io/api v0.28.4 k8s.io/apiextensions-apiserver v0.28.2 k8s.io/apimachinery v0.28.4 @@ -182,7 +182,7 @@ require ( github.com/jackc/pgtype v1.14.0 // indirect github.com/jackc/pgx/v4 v4.16.0 // indirect github.com/jinzhu/inflection v1.0.0 // indirect - github.com/jinzhu/now v1.1.4 // indirect + github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/json-iterator/go v1.1.12 // indirect diff --git a/go.sum b/go.sum index 3e8a4e1989..ad1c9ecb9f 100644 --- a/go.sum +++ b/go.sum @@ -1063,8 +1063,9 @@ github.com/jinzhu/gorm v1.9.12/go.mod h1:vhTjlKSJUTWNtcbQtrMBFCxy7eXTzeCAzfL5fBZ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.0.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/jinzhu/now v1.1.4 h1:tHnRBy1i5F2Dh8BAFxqFzxKqqvezXrL2OW1TnX+Mlas= github.com/jinzhu/now v1.1.4/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= +github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20160803190731-bd40a432e4c7/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= @@ -2516,8 +2517,8 @@ gorm.io/driver/postgres v1.3.5/go.mod h1:EGCWefLFQSVFrHGy4J8EtiHCWX5Q8t0yz2Jt9aK gorm.io/driver/sqlite v1.3.1 h1:bwfE+zTEWklBYoEodIOIBwuWHpnx52Z9zJFW5F33WLk= gorm.io/driver/sqlserver v1.3.1 h1:F5t6ScMzOgy1zukRTIZgLZwKahgt3q1woAILVolKpOI= gorm.io/gorm v1.23.4/go.mod h1:l2lP/RyAtc1ynaTjFksBde/O8v9oOGIApu2/xRitmZk= -gorm.io/gorm v1.24.5 h1:g6OPREKqqlWq4kh/3MCQbZKImeB9e6Xgc4zD+JgNZGE= -gorm.io/gorm v1.24.5/go.mod h1:DVrVomtaYTbqs7gB/x2uVvqnXzv0nqjB396B8cG4dBA= +gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= +gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= gotest.tools v2.2.0+incompatible h1:VsBPFP1AI068pPrMxtb/S8Zkgf9xEmTLJjfM+P5UIEo= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.0.2/go.mod h1:3SzNCllyD9/Y+b5r9JIKQ474KzkZyqLqEfYqMsX94Bk= diff --git a/vendor/github.com/jinzhu/now/README.md b/vendor/github.com/jinzhu/now/README.md index a1fc4daa56..e81d31de6a 100644 --- a/vendor/github.com/jinzhu/now/README.md +++ b/vendor/github.com/jinzhu/now/README.md @@ -72,13 +72,17 @@ Don't be bothered with the `WeekStartDay` setting, you can use `Monday`, `Sunday ```go now.Monday() // 2013-11-18 00:00:00 Mon +now.Monday("17:44") // 2013-11-18 17:44:00 Mon now.Sunday() // 2013-11-24 00:00:00 Sun (Next Sunday) +now.Sunday("18:19:24") // 2013-11-24 18:19:24 Sun (Next Sunday) now.EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of next Sunday) t := time.Date(2013, 11, 24, 17, 51, 49, 123456789, time.Now().Location()) // 2013-11-24 17:51:49.123456789 Sun -now.With(t).Monday() // 2013-11-18 00:00:00 Sun (Last Monday if today is Sunday) -now.With(t).Sunday() // 2013-11-24 00:00:00 Sun (Beginning Of Today if today is Sunday) -now.With(t).EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of Today if today is Sunday) +now.With(t).Monday() // 2013-11-18 00:00:00 Mon (Last Monday if today is Sunday) +now.With(t).Monday("17:44") // 2013-11-18 17:44:00 Mon (Last Monday if today is Sunday) +now.With(t).Sunday() // 2013-11-24 00:00:00 Sun (Beginning Of Today if today is Sunday) +now.With(t).Sunday("18:19:24") // 2013-11-24 18:19:24 Sun (Beginning Of Today if today is Sunday) +now.With(t).EndOfSunday() // 2013-11-24 23:59:59.999999999 Sun (End of Today if today is Sunday) ``` ### Parse String to Time diff --git a/vendor/github.com/jinzhu/now/main.go b/vendor/github.com/jinzhu/now/main.go index e972cd8d55..8f78bc7521 100644 --- a/vendor/github.com/jinzhu/now/main.go +++ b/vendor/github.com/jinzhu/now/main.go @@ -154,13 +154,14 @@ func EndOfYear() time.Time { } // Monday monday -func Monday() time.Time { - return With(time.Now()).Monday() + +func Monday(strs ...string) time.Time { + return With(time.Now()).Monday(strs...) } // Sunday sunday -func Sunday() time.Time { - return With(time.Now()).Sunday() +func Sunday(strs ...string) time.Time { + return With(time.Now()).Sunday(strs...) } // EndOfSunday end of sunday diff --git a/vendor/github.com/jinzhu/now/now.go b/vendor/github.com/jinzhu/now/now.go index d1fb8ae1d7..2f524cc8d4 100644 --- a/vendor/github.com/jinzhu/now/now.go +++ b/vendor/github.com/jinzhu/now/now.go @@ -108,6 +108,7 @@ func (now *Now) EndOfYear() time.Time { } // Monday monday +/* func (now *Now) Monday() time.Time { t := now.BeginningOfDay() weekday := int(t.Weekday()) @@ -116,15 +117,42 @@ func (now *Now) Monday() time.Time { } return t.AddDate(0, 0, -weekday+1) } +*/ -// Sunday sunday -func (now *Now) Sunday() time.Time { - t := now.BeginningOfDay() - weekday := int(t.Weekday()) +func (now *Now) Monday(strs ...string) time.Time { + var parseTime time.Time + var err error + if len(strs) > 0 { + parseTime, err = now.Parse(strs...) + if err != nil { + panic(err) + } + } else { + parseTime = now.BeginningOfDay() + } + weekday := int(parseTime.Weekday()) + if weekday == 0 { + weekday = 7 + } + return parseTime.AddDate(0, 0, -weekday+1) +} + +func (now *Now) Sunday(strs ...string) time.Time { + var parseTime time.Time + var err error + if len(strs) > 0 { + parseTime, err = now.Parse(strs...) + if err != nil { + panic(err) + } + } else { + parseTime = now.BeginningOfDay() + } + weekday := int(parseTime.Weekday()) if weekday == 0 { - return t + weekday = 7 } - return t.AddDate(0, 0, (7 - weekday)) + return parseTime.AddDate(0, 0, (7 - weekday)) } // EndOfSunday end of sunday diff --git a/vendor/gorm.io/gorm/License b/vendor/gorm.io/gorm/LICENSE similarity index 100% rename from vendor/gorm.io/gorm/License rename to vendor/gorm.io/gorm/LICENSE diff --git a/vendor/gorm.io/gorm/README.md b/vendor/gorm.io/gorm/README.md index 0c9ab74e94..85ad3050c3 100644 --- a/vendor/gorm.io/gorm/README.md +++ b/vendor/gorm.io/gorm/README.md @@ -35,9 +35,7 @@ The fantastic ORM library for Golang, aims to be developer friendly. ## Contributors -Thank you for contributing to the GORM framework! - -[![Contributors](https://contrib.rocks/image?repo=go-gorm/gorm)](https://github.com/go-gorm/gorm/graphs/contributors) +[Thank you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework! ## License diff --git a/vendor/gorm.io/gorm/association.go b/vendor/gorm.io/gorm/association.go index 06229caa78..7c93ebea0d 100644 --- a/vendor/gorm.io/gorm/association.go +++ b/vendor/gorm.io/gorm/association.go @@ -14,6 +14,7 @@ import ( type Association struct { DB *DB Relationship *schema.Relationship + Unscope bool Error error } @@ -40,6 +41,15 @@ func (db *DB) Association(column string) *Association { return association } +func (association *Association) Unscoped() *Association { + return &Association{ + DB: association.DB, + Relationship: association.Relationship, + Error: association.Error, + Unscope: true, + } +} + func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { association.Error = association.buildCondition().Find(out, conds...).Error @@ -64,14 +74,30 @@ func (association *Association) Append(values ...interface{}) error { func (association *Association) Replace(values ...interface{}) error { if association.Error == nil { + reflectValue := association.DB.Statement.ReflectValue + rel := association.Relationship + + var oldBelongsToExpr clause.Expression + // we have to record the old BelongsTo value + if association.Unscope && rel.Type == schema.BelongsTo { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + oldBelongsToExpr = clause.IN{Column: column, Values: values} + } + } + // save associations if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { return association.Error } // set old associations's foreign key to null - reflectValue := association.DB.Statement.ReflectValue - rel := association.Relationship switch rel.Type { case schema.BelongsTo: if len(values) == 0 { @@ -91,6 +117,9 @@ func (association *Association) Replace(values ...interface{}) error { association.Error = association.DB.UpdateColumns(updateMap).Error } + if association.Unscope && oldBelongsToExpr != nil { + association.Error = association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field @@ -119,7 +148,11 @@ func (association *Association) Replace(values ...interface{}) error { if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) - association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + if association.Unscope { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error + } else { + association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error + } } case schema.Many2Many: var ( @@ -184,7 +217,8 @@ func (association *Association) Delete(values ...interface{}) error { switch rel.Type { case schema.BelongsTo: - tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface()) + associationDB := association.DB.Session(&Session{}) + tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { @@ -198,8 +232,21 @@ func (association *Association) Delete(values ...interface{}) error { conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + var foreignFields []*schema.Field + for _, ref := range rel.References { + if !ref.OwnPrimaryKey { + foreignFields = append(foreignFields, ref.ForeignKey) + } + } + if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { + column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) + association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error + } + } case schema.HasOne, schema.HasMany: - tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface()) + model := reflect.New(rel.FieldSchema.ModelType).Interface() + tx := association.DB.Model(model) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { @@ -212,7 +259,11 @@ func (association *Association) Delete(values ...interface{}) error { relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) - association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + if association.Unscope { + association.Error = tx.Clauses(conds...).Delete(model).Error + } else { + association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error + } case schema.Many2Many: var ( primaryFields, relPrimaryFields []*schema.Field @@ -353,9 +404,13 @@ func (association *Association) saveAssociation(clear bool, values ...interface{ } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() - fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) + var fieldValue reflect.Value if clear { - fieldValue = reflect.New(association.Relationship.Field.IndirectFieldType).Elem() + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) + } else { + fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) + reflect.Copy(fieldValue, oldFieldValue) } appendToFieldValues := func(ev reflect.Value) { diff --git a/vendor/gorm.io/gorm/callbacks.go b/vendor/gorm.io/gorm/callbacks.go index de979e4596..195d17203d 100644 --- a/vendor/gorm.io/gorm/callbacks.go +++ b/vendor/gorm.io/gorm/callbacks.go @@ -75,11 +75,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) - } + db = db.executeScopes() } var ( @@ -253,7 +249,7 @@ func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { names, sorted []string sortCallback func(*callback) error ) - sort.Slice(cs, func(i, j int) bool { + sort.SliceStable(cs, func(i, j int) bool { if cs[j].before == "*" && cs[i].before != "*" { return true } diff --git a/vendor/gorm.io/gorm/callbacks/associations.go b/vendor/gorm.io/gorm/callbacks/associations.go index 9d7c1412ec..f3cd464ae6 100644 --- a/vendor/gorm.io/gorm/callbacks/associations.go +++ b/vendor/gorm.io/gorm/callbacks/associations.go @@ -51,25 +51,40 @@ func SaveBeforeAssociations(create bool) func(db *gorm.DB) { } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) + identityMap := map[string]bool{} for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() != reflect.Struct { break } - if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value + if !isPtr { + rv = rv.Addr() + } objs = append(objs, obj) - if isPtr { - elems = reflect.Append(elems, rv) - } else { - elems = reflect.Append(elems, rv.Addr()) + elems = reflect.Append(elems, rv) + + relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) + for _, pf := range rel.FieldSchema.PrimaryFields { + if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { + relPrimaryValues = append(relPrimaryValues, pfv) + } + } + cacheKey := utils.ToStringKey(relPrimaryValues...) + if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { + if cacheKey != "" { // has primary fields + identityMap[cacheKey] = true + } + + distinctElems = reflect.Append(distinctElems, rv) } } } if elems.Len() > 0 { - if saveAssociations(db, rel, elems, selectColumns, restricted, nil) == nil { + if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } diff --git a/vendor/gorm.io/gorm/callbacks/callmethod.go b/vendor/gorm.io/gorm/callbacks/callmethod.go index bcaa03f3d9..fb90003790 100644 --- a/vendor/gorm.io/gorm/callbacks/callmethod.go +++ b/vendor/gorm.io/gorm/callbacks/callmethod.go @@ -13,11 +13,20 @@ func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) { case reflect.Slice, reflect.Array: db.Statement.CurDestIndex = 0 for i := 0; i < db.Statement.ReflectValue.Len(); i++ { - fc(reflect.Indirect(db.Statement.ReflectValue.Index(i)).Addr().Interface(), tx) + if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() { + fc(value.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + return + } db.Statement.CurDestIndex++ } case reflect.Struct: - fc(db.Statement.ReflectValue.Addr().Interface(), tx) + if db.Statement.ReflectValue.CanAddr() { + fc(db.Statement.ReflectValue.Addr().Interface(), tx) + } else { + db.AddError(gorm.ErrInvalidValue) + } } } } diff --git a/vendor/gorm.io/gorm/callbacks/create.go b/vendor/gorm.io/gorm/callbacks/create.go index 0fe1dc93a0..f0b7813985 100644 --- a/vendor/gorm.io/gorm/callbacks/create.go +++ b/vendor/gorm.io/gorm/callbacks/create.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -302,7 +303,8 @@ func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) { for _, column := range values.Columns { if field := stmt.Schema.LookUpField(column.Name); field != nil { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { - if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil) && field.AutoCreateTime == 0 { + if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil || + strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 { if field.AutoUpdateTime > 0 { assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime} switch field.AutoUpdateTime { diff --git a/vendor/gorm.io/gorm/callbacks/preload.go b/vendor/gorm.io/gorm/callbacks/preload.go index ea2570ba3e..15669c847b 100644 --- a/vendor/gorm.io/gorm/callbacks/preload.go +++ b/vendor/gorm.io/gorm/callbacks/preload.go @@ -3,6 +3,7 @@ package callbacks import ( "fmt" "reflect" + "strings" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -10,6 +11,98 @@ import ( "gorm.io/gorm/utils" ) +// parsePreloadMap extracts nested preloads. e.g. +// +// // schema has a "k0" relation and a "k7.k8" embedded relation +// parsePreloadMap(schema, map[string][]interface{}{ +// clause.Associations: {"arg1"}, +// "k1": {"arg2"}, +// "k2.k3": {"arg3"}, +// "k4.k5.k6": {"arg4"}, +// }) +// // preloadMap is +// map[string]map[string][]interface{}{ +// "k0": {}, +// "k7": { +// "k8": {}, +// }, +// "k1": {}, +// "k2": { +// "k3": {"arg3"}, +// }, +// "k4": { +// "k5.k6": {"arg4"}, +// }, +// } +func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { + preloadMap := map[string]map[string][]interface{}{} + setPreloadMap := func(name, value string, args []interface{}) { + if _, ok := preloadMap[name]; !ok { + preloadMap[name] = map[string][]interface{}{} + } + if value != "" { + preloadMap[name][value] = args + } + } + + for name, args := range preloads { + preloadFields := strings.Split(name, ".") + value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") + if preloadFields[0] == clause.Associations { + for _, relation := range s.Relationships.Relations { + if relation.Schema == s { + setPreloadMap(relation.Name, value, args) + } + } + + for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { + for _, value := range embeddedValues(embeddedRelations) { + setPreloadMap(embedded, value, args) + } + } + } else { + setPreloadMap(preloadFields[0], value, args) + } + } + return preloadMap +} + +func embeddedValues(embeddedRelations *schema.Relationships) []string { + if embeddedRelations == nil { + return nil + } + names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) + for _, relation := range embeddedRelations.Relations { + // skip first struct name + names = append(names, strings.Join(relation.Field.BindNames[1:], ".")) + } + for _, relations := range embeddedRelations.EmbeddedRelations { + names = append(names, embeddedValues(relations)...) + } + return names +} + +func preloadEmbedded(tx *gorm.DB, relationships *schema.Relationships, s *schema.Schema, preloads map[string][]interface{}, as []interface{}) error { + if relationships == nil { + return nil + } + preloadMap := parsePreloadMap(s, preloads) + for name := range preloadMap { + if embeddedRelations := relationships.EmbeddedRelations[name]; embeddedRelations != nil { + if err := preloadEmbedded(tx, embeddedRelations, s, preloadMap[name], as); err != nil { + return err + } + } else if rel := relationships.Relations[name]; rel != nil { + if err := preload(tx, rel, append(preloads[name], as), preloadMap[name]); err != nil { + return err + } + } else { + return fmt.Errorf("%s: %w (embedded) for schema %s", name, gorm.ErrUnsupportedRelation, s.Name) + } + } + return nil +} + func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue diff --git a/vendor/gorm.io/gorm/callbacks/query.go b/vendor/gorm.io/gorm/callbacks/query.go index 97fe8a49c4..e89dd19960 100644 --- a/vendor/gorm.io/gorm/callbacks/query.go +++ b/vendor/gorm.io/gorm/callbacks/query.go @@ -8,6 +8,8 @@ import ( "gorm.io/gorm" "gorm.io/gorm/clause" + "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) func Query(db *gorm.DB) { @@ -109,86 +111,141 @@ func BuildQuerySQL(db *gorm.DB) { } } + specifiedRelationsName := make(map[string]interface{}) for _, join := range db.Statement.Joins { - if db.Statement.Schema == nil { - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, - }) - } else if relation, ok := db.Statement.Schema.Relationships.Relations[join.Name]; ok { - tableAliasName := relation.Name - - columnStmt := gorm.Statement{ - Table: tableAliasName, DB: db, Schema: relation.FieldSchema, - Selects: join.Selects, Omits: join.Omits, - } + if db.Statement.Schema != nil { + var isRelations bool // is relations or raw sql + var relations []*schema.Relationship + relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] + if ok { + isRelations = true + relations = append(relations, relation) + } else { + // handle nested join like "Manager.Company" + nestedJoinNames := strings.Split(join.Name, ".") + if len(nestedJoinNames) > 1 { + isNestedJoin := true + gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) + currentRelations := db.Statement.Schema.Relationships.Relations + for _, relname := range nestedJoinNames { + // incomplete match, only treated as raw sql + if relation, ok = currentRelations[relname]; ok { + gussNestedRelations = append(gussNestedRelations, relation) + currentRelations = relation.FieldSchema.Relationships.Relations + } else { + isNestedJoin = false + break + } + } - selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) - for _, s := range relation.FieldSchema.DBNames { - if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { - clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ - Table: tableAliasName, - Name: s, - Alias: tableAliasName + "__" + s, - }) + if isNestedJoin { + isRelations = true + relations = gussNestedRelations + } } } - exprs := make([]clause.Expression, len(relation.References)) - for idx, ref := range relation.References { - if ref.OwnPrimaryKey { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.PrimaryKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + if isRelations { + genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join { + tableAliasName := relation.Name + if parentTableName != clause.CurrentTable { + tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName) } - } else { - if ref.PrimaryValue == "" { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: clause.CurrentTable, Name: ref.ForeignKey.DBName}, - Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + + columnStmt := gorm.Statement{ + Table: tableAliasName, DB: db, Schema: relation.FieldSchema, + Selects: join.Selects, Omits: join.Omits, + } + + selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) + for _, s := range relation.FieldSchema.DBNames { + if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { + clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ + Table: tableAliasName, + Name: s, + Alias: utils.NestedRelationName(tableAliasName, s), + }) } - } else { - exprs[idx] = clause.Eq{ - Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, - Value: ref.PrimaryValue, + } + + exprs := make([]clause.Expression, len(relation.References)) + for idx, ref := range relation.References { + if ref.OwnPrimaryKey { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + } + } else { + if ref.PrimaryValue == "" { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, + Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, + } + } else { + exprs[idx] = clause.Eq{ + Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, + Value: ref.PrimaryValue, + } + } } } - } - } - { - onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} - for _, c := range relation.FieldSchema.QueryClauses { - onStmt.AddClause(c) - } + { + onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} + for _, c := range relation.FieldSchema.QueryClauses { + onStmt.AddClause(c) + } - if join.On != nil { - onStmt.AddClause(join.On) - } + if join.On != nil { + onStmt.AddClause(join.On) + } - if cs, ok := onStmt.Clauses["WHERE"]; ok { - if where, ok := cs.Expression.(clause.Where); ok { - where.Build(&onStmt) - - if onSQL := onStmt.SQL.String(); onSQL != "" { - vars := onStmt.Vars - for idx, v := range vars { - bindvar := strings.Builder{} - onStmt.Vars = vars[0 : idx+1] - db.Dialector.BindVarTo(&bindvar, &onStmt, v) - onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + if cs, ok := onStmt.Clauses["WHERE"]; ok { + if where, ok := cs.Expression.(clause.Where); ok { + where.Build(&onStmt) + + if onSQL := onStmt.SQL.String(); onSQL != "" { + vars := onStmt.Vars + for idx, v := range vars { + bindvar := strings.Builder{} + onStmt.Vars = vars[0 : idx+1] + db.Dialector.BindVarTo(&bindvar, &onStmt, v) + onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) + } + + exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) + } } - - exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) } } + + return clause.Join{ + Type: joinType, + Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, + ON: clause.Where{Exprs: exprs}, + } } - } - fromClause.Joins = append(fromClause.Joins, clause.Join{ - Type: join.JoinType, - Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, - ON: clause.Where{Exprs: exprs}, - }) + parentTableName := clause.CurrentTable + for _, rel := range relations { + // joins table alias like "Manager, Company, Manager__Company" + nestedAlias := utils.NestedRelationName(parentTableName, rel.Name) + if _, ok := specifiedRelationsName[nestedAlias]; !ok { + fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel)) + specifiedRelationsName[nestedAlias] = nil + } + + if parentTableName != clause.CurrentTable { + parentTableName = utils.NestedRelationName(parentTableName, rel.Name) + } else { + parentTableName = rel.Name + } + } + } else { + fromClause.Joins = append(fromClause.Joins, clause.Join{ + Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, + }) + } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, @@ -215,32 +272,7 @@ func Preload(db *gorm.DB) { return } - preloadMap := map[string]map[string][]interface{}{} - for name := range db.Statement.Preloads { - preloadFields := strings.Split(name, ".") - if preloadFields[0] == clause.Associations { - for _, rel := range db.Statement.Schema.Relationships.Relations { - if rel.Schema == db.Statement.Schema { - if _, ok := preloadMap[rel.Name]; !ok { - preloadMap[rel.Name] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[rel.Name][value] = db.Statement.Preloads[name] - } - } - } - } else { - if _, ok := preloadMap[preloadFields[0]]; !ok { - preloadMap[preloadFields[0]] = map[string][]interface{}{} - } - - if value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), "."); value != "" { - preloadMap[preloadFields[0]][value] = db.Statement.Preloads[name] - } - } - } - + preloadMap := parsePreloadMap(db.Statement.Schema, db.Statement.Preloads) preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { preloadNames = append(preloadNames, key) @@ -257,9 +289,12 @@ func Preload(db *gorm.DB) { return } preloadDB.Statement.ReflectValue = db.Statement.ReflectValue + preloadDB.Statement.Unscoped = db.Statement.Unscoped for _, name := range preloadNames { - if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { + if relations := preloadDB.Statement.Schema.Relationships.EmbeddedRelations[name]; relations != nil { + db.AddError(preloadEmbedded(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), relations, db.Statement.Schema, preloadMap[name], db.Statement.Preloads[clause.Associations])) + } else if rel := preloadDB.Statement.Schema.Relationships.Relations[name]; rel != nil { db.AddError(preload(preloadDB.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}), rel, append(db.Statement.Preloads[name], db.Statement.Preloads[clause.Associations]...), preloadMap[name])) } else { db.AddError(fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name)) diff --git a/vendor/gorm.io/gorm/callbacks/row.go b/vendor/gorm.io/gorm/callbacks/row.go index 56be742e81..beaa189e1c 100644 --- a/vendor/gorm.io/gorm/callbacks/row.go +++ b/vendor/gorm.io/gorm/callbacks/row.go @@ -7,7 +7,7 @@ import ( func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) - if db.DryRun { + if db.DryRun || db.Error != nil { return } diff --git a/vendor/gorm.io/gorm/callbacks/update.go b/vendor/gorm.io/gorm/callbacks/update.go index b596df9a06..ff075dcf28 100644 --- a/vendor/gorm.io/gorm/callbacks/update.go +++ b/vendor/gorm.io/gorm/callbacks/update.go @@ -72,6 +72,7 @@ func Update(config *Config) func(db *gorm.DB) { db.Statement.AddClauseIfNotExists(clause.Update{}) if _, ok := db.Statement.Clauses["SET"]; !ok { if set := ConvertToAssignments(db.Statement); len(set) != 0 { + defer delete(db.Statement.Clauses, "SET") db.Statement.AddClause(set) } else { return @@ -137,7 +138,9 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { - field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + if stmt.ReflectValue.CanAddr() { + field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) + } } } case reflect.Struct: @@ -243,11 +246,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { } default: updatingSchema := stmt.Schema + var isDiffSchema bool if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { // different schema updatingStmt := &gorm.Statement{DB: stmt.DB} if err := updatingStmt.Parse(stmt.Dest); err == nil { updatingSchema = updatingStmt.Schema + isDiffSchema = true } } @@ -274,7 +279,13 @@ func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { if (ok || !isZero) && field.Updatable { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) - assignValue(field, value) + assignField := field + if isDiffSchema { + if originField := stmt.Schema.LookUpField(dbName); originField != nil { + assignField = originField + } + } + assignValue(assignField, value) } } } else { diff --git a/vendor/gorm.io/gorm/chainable_api.go b/vendor/gorm.io/gorm/chainable_api.go index 676fe914c5..3dc7256e6c 100644 --- a/vendor/gorm.io/gorm/chainable_api.go +++ b/vendor/gorm.io/gorm/chainable_api.go @@ -60,7 +60,7 @@ var tableRegexp = regexp.MustCompile(`(?i)(?:.+? AS (\w+)\s*(?:$|,)|^\w+\s+(\w+) // Table specify the table you would like to run db operations // // // Get a user -// db.Table("users").take(&result) +// db.Table("users").Take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { @@ -253,7 +253,10 @@ func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) if len(args) == 1 { if db, ok := args[0].(*DB); ok { - j := join{Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits} + j := join{ + Name: query, Conds: args, Selects: db.Statement.Selects, + Omits: db.Statement.Omits, JoinType: joinType, + } if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { j.On = &where } @@ -363,6 +366,36 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { return tx } +func (db *DB) executeScopes() (tx *DB) { + tx = db.getInstance() + scopes := db.Statement.scopes + if len(scopes) == 0 { + return tx + } + tx.Statement.scopes = nil + + conditions := make([]clause.Interface, 0, 4) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + + for _, scope := range scopes { + tx = scope(tx) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + } + + for _, condition := range conditions { + tx.Statement.AddClause(condition) + } + return tx +} + // Preload preload associations with given conditions // // // get all users, and preload all non-cancelled orders diff --git a/vendor/gorm.io/gorm/clause/expression.go b/vendor/gorm.io/gorm/clause/expression.go index 92ac7f2238..3140846ef9 100644 --- a/vendor/gorm.io/gorm/clause/expression.go +++ b/vendor/gorm.io/gorm/clause/expression.go @@ -126,7 +126,7 @@ func (expr NamedExpr) Build(builder Builder) { for _, v := range []byte(expr.SQL) { if v == '@' && !inName { inName = true - name = []byte{} + name = name[:0] } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { @@ -246,15 +246,19 @@ func (eq Eq) Build(builder Builder) { switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: - builder.WriteString(" IN (") rv := reflect.ValueOf(eq.Value) - for i := 0; i < rv.Len(); i++ { - if i > 0 { - builder.WriteByte(',') + if rv.Len() == 0 { + builder.WriteString(" IN (NULL)") + } else { + builder.WriteString(" IN (") + for i := 0; i < rv.Len(); i++ { + if i > 0 { + builder.WriteByte(',') + } + builder.AddVar(builder, rv.Index(i).Interface()) } - builder.AddVar(builder, rv.Index(i).Interface()) + builder.WriteByte(')') } - builder.WriteByte(')') default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") diff --git a/vendor/gorm.io/gorm/clause/limit.go b/vendor/gorm.io/gorm/clause/limit.go index 3ede7385ca..abda00551a 100644 --- a/vendor/gorm.io/gorm/clause/limit.go +++ b/vendor/gorm.io/gorm/clause/limit.go @@ -33,7 +33,7 @@ func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { - if (limit.Limit == nil || *limit.Limit == 0) && (v.Limit != nil && *v.Limit != 0) { + if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { limit.Limit = v.Limit } diff --git a/vendor/gorm.io/gorm/errors.go b/vendor/gorm.io/gorm/errors.go index 49cbfe64ac..cd76f1f522 100644 --- a/vendor/gorm.io/gorm/errors.go +++ b/vendor/gorm.io/gorm/errors.go @@ -21,6 +21,10 @@ var ( ErrPrimaryKeyRequired = errors.New("primary key required") // ErrModelValueRequired model value required ErrModelValueRequired = errors.New("model value required") + // ErrModelAccessibleFieldsRequired model accessible fields required + ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required") + // ErrSubQueryRequired sub query required + ErrSubQueryRequired = errors.New("sub query required") // ErrInvalidData unsupported data ErrInvalidData = errors.New("unsupported data") // ErrUnsupportedDriver unsupported driver @@ -41,4 +45,8 @@ var ( ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match") // ErrPreloadNotAllowed preload is not allowed when count is used ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used") + // ErrDuplicatedKey occurs when there is a unique key constraint violation + ErrDuplicatedKey = errors.New("duplicated key not allowed") + // ErrForeignKeyViolated occurs when there is a foreign key constraint violation + ErrForeignKeyViolated = errors.New("violates foreign key constraint") ) diff --git a/vendor/gorm.io/gorm/finisher_api.go b/vendor/gorm.io/gorm/finisher_api.go index 39d9fca3ec..f80aa6c042 100644 --- a/vendor/gorm.io/gorm/finisher_api.go +++ b/vendor/gorm.io/gorm/finisher_api.go @@ -33,9 +33,10 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { var rowsAffected int64 tx = db.getInstance() + // the reflection length judgment of the optimized value + reflectLen := reflectValue.Len() + callFc := func(tx *DB) error { - // the reflection length judgment of the optimized value - reflectLen := reflectValue.Len() for i := 0; i < reflectLen; i += batchSize { ends := i + batchSize if ends > reflectLen { @@ -53,7 +54,7 @@ func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) { return nil } - if tx.SkipDefaultTransaction { + if tx.SkipDefaultTransaction || reflectLen <= batchSize { tx.AddError(callFc(tx.Session(&Session{}))) } else { tx.AddError(tx.Transaction(callFc)) @@ -101,14 +102,13 @@ func (db *DB) Save(value interface{}) (tx *DB) { tx.Statement.Selects = append(tx.Statement.Selects, "*") } - tx = tx.callbacks.Update().Execute(tx) + updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) - if tx.Error == nil && tx.RowsAffected == 0 && !tx.DryRun && !selectedUpdate { - result := reflect.New(tx.Statement.Schema.ModelType).Interface() - if result := tx.Session(&Session{}).Limit(1).Find(result); result.RowsAffected == 0 { - return tx.Create(value) - } + if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { + return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } + + return updateTx } return @@ -490,7 +490,7 @@ func (db *DB) Count(count *int64) (tx *DB) { tx.Statement.Dest = count tx = tx.callbacks.Query().Execute(tx) - if tx.RowsAffected != 1 { + if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 { *count = tx.RowsAffected } @@ -531,6 +531,7 @@ func (db *DB) Scan(dest interface{}) (tx *DB) { tx.ScanRows(rows, dest) } else { tx.RowsAffected = 0 + tx.AddError(rows.Err()) } tx.AddError(rows.Close()) } @@ -622,7 +623,6 @@ func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err er if err != nil { return } - defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { @@ -707,7 +707,21 @@ func (db *DB) Rollback() *DB { func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because SavePoint not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.SavePoint(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } @@ -716,7 +730,21 @@ func (db *DB) SavePoint(name string) *DB { func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { + // close prepared statement, because RollbackTo not support prepared statement. + // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html + var ( + preparedStmtTx *PreparedStmtTX + isPreparedStmtTx bool + ) + // close prepared statement, because SavePoint not support prepared statement. + if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx.Tx + } db.AddError(savePointer.RollbackTo(db, name)) + // restore prepared statement + if isPreparedStmtTx { + db.Statement.ConnPool = preparedStmtTx + } } else { db.AddError(ErrUnsupportedDriver) } diff --git a/vendor/gorm.io/gorm/gorm.go b/vendor/gorm.io/gorm/gorm.go index 37595ddd78..775cd3de3b 100644 --- a/vendor/gorm.io/gorm/gorm.go +++ b/vendor/gorm.io/gorm/gorm.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "fmt" + "reflect" "sort" "sync" "time" @@ -47,6 +48,8 @@ type Config struct { QueryFields bool // CreateBatchSize default create batch size CreateBatchSize int + // TranslateError enabling error translation + TranslateError bool // ClauseBuilders clause builder ClauseBuilders map[string]clause.ClauseBuilder @@ -144,7 +147,7 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { } if config.NamingStrategy == nil { - config.NamingStrategy = schema.NamingStrategy{} + config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64 } if config.Logger == nil { @@ -177,17 +180,17 @@ func Open(dialector Dialector, opts ...Option) (db *DB, err error) { if config.Dialector != nil { err = config.Dialector.Initialize(db) - } - preparedStmt := &PreparedStmtDB{ - ConnPool: db.ConnPool, - Stmts: make(map[string]*Stmt), - Mux: &sync.RWMutex{}, - PreparedSQL: make([]string, 0, 100), + if err != nil { + if db, _ := db.DB(); db != nil { + _ = db.Close() + } + } } - db.cacheStore.Store(preparedStmtDBKey, preparedStmt) if config.PrepareStmt { + preparedStmt := NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) db.ConnPool = preparedStmt } @@ -248,24 +251,30 @@ func (db *DB) Session(config *Session) *DB { } if config.PrepareStmt { + var preparedStmt *PreparedStmtDB + if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok { - preparedStmt := v.(*PreparedStmtDB) - switch t := tx.Statement.ConnPool.(type) { - case Tx: - tx.Statement.ConnPool = &PreparedStmtTX{ - Tx: t, - PreparedStmtDB: preparedStmt, - } - default: - tx.Statement.ConnPool = &PreparedStmtDB{ - ConnPool: db.Config.ConnPool, - Mux: preparedStmt.Mux, - Stmts: preparedStmt.Stmts, - } + preparedStmt = v.(*PreparedStmtDB) + } else { + preparedStmt = NewPreparedStmtDB(db.ConnPool) + db.cacheStore.Store(preparedStmtDBKey, preparedStmt) + } + + switch t := tx.Statement.ConnPool.(type) { + case Tx: + tx.Statement.ConnPool = &PreparedStmtTX{ + Tx: t, + PreparedStmtDB: preparedStmt, + } + default: + tx.Statement.ConnPool = &PreparedStmtDB{ + ConnPool: db.Config.ConnPool, + Mux: preparedStmt.Mux, + Stmts: preparedStmt.Stmts, } - txConfig.ConnPool = tx.Statement.ConnPool - txConfig.PrepareStmt = true } + txConfig.ConnPool = tx.Statement.ConnPool + txConfig.PrepareStmt = true } if config.SkipHooks { @@ -347,10 +356,18 @@ func (db *DB) Callback() *callbacks { // AddError add error to db func (db *DB) AddError(err error) error { - if db.Error == nil { - db.Error = err - } else if err != nil { - db.Error = fmt.Errorf("%v; %w", db.Error, err) + if err != nil { + if db.Config.TranslateError { + if errTranslator, ok := db.Dialector.(ErrorTranslator); ok { + err = errTranslator.Translate(err) + } + } + + if db.Error == nil { + db.Error = err + } else { + db.Error = fmt.Errorf("%v; %w", db.Error, err) + } } return db.Error } @@ -358,12 +375,20 @@ func (db *DB) AddError(err error) error { // DB returns `*sql.DB` func (db *DB) DB() (*sql.DB, error) { connPool := db.ConnPool + if db.Statement != nil && db.Statement.ConnPool != nil { + connPool = db.Statement.ConnPool + } + if tx, ok := connPool.(*sql.Tx); ok && tx != nil { + return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil + } if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() + if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil { + return sqldb, err + } } - if sqldb, ok := connPool.(*sql.DB); ok { + if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil { return sqldb, nil } @@ -377,11 +402,12 @@ func (db *DB) getInstance() *DB { if db.clone == 1 { // clone with new statement tx.Statement = &Statement{ - DB: tx, - ConnPool: db.Statement.ConnPool, - Context: db.Statement.Context, - Clauses: map[string]clause.Clause{}, - Vars: make([]interface{}, 0, 8), + DB: tx, + ConnPool: db.Statement.ConnPool, + Context: db.Statement.Context, + Clauses: map[string]clause.Clause{}, + Vars: make([]interface{}, 0, 8), + SkipHooks: db.Statement.SkipHooks, } } else { // with clone statement diff --git a/vendor/gorm.io/gorm/interfaces.go b/vendor/gorm.io/gorm/interfaces.go index cf9e07b97a..3bcc3d570b 100644 --- a/vendor/gorm.io/gorm/interfaces.go +++ b/vendor/gorm.io/gorm/interfaces.go @@ -86,3 +86,7 @@ type Rows interface { Err() error Close() error } + +type ErrorTranslator interface { + Translate(err error) error +} diff --git a/vendor/gorm.io/gorm/logger/sql.go b/vendor/gorm.io/gorm/logger/sql.go index bcacc7cf52..13e5d957dc 100644 --- a/vendor/gorm.io/gorm/logger/sql.go +++ b/vendor/gorm.io/gorm/logger/sql.go @@ -28,8 +28,10 @@ func isPrintable(s string) bool { return true } +// A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} +// RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability @@ -93,8 +95,10 @@ func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, a } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: vars[idx] = utils.ToString(v) - case float64, float32: - vars[idx] = fmt.Sprintf("%.6f", v) + case float32: + vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: vars[idx] = escaper + strings.ReplaceAll(v, escaper, "\\"+escaper) + escaper default: diff --git a/vendor/gorm.io/gorm/migrator.go b/vendor/gorm.io/gorm/migrator.go index 882fc4cc50..0e01f567d6 100644 --- a/vendor/gorm.io/gorm/migrator.go +++ b/vendor/gorm.io/gorm/migrator.go @@ -13,11 +13,7 @@ func (db *DB) Migrator() Migrator { // apply scopes to migrator for len(tx.Statement.scopes) > 0 { - scopes := tx.Statement.scopes - tx.Statement.scopes = nil - for _, scope := range scopes { - tx = scope(tx) - } + tx = tx.executeScopes() } return tx.Dialector.Migrator(tx.Session(&Session{})) @@ -30,9 +26,9 @@ func (db *DB) AutoMigrate(dst ...interface{}) error { // ViewOption view option type ViewOption struct { - Replace bool - CheckOption string - Query *DB + Replace bool // If true, exec `CREATE`. If false, exec `CREATE OR REPLACE` + CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION` + Query *DB // required subquery. } // ColumnType column type interface @@ -60,6 +56,14 @@ type Index interface { Option() string } +// TableType table type interface +type TableType interface { + Schema() string + Name() string + Type() string + Comment() (comment string, ok bool) +} + // Migrator migrator interface type Migrator interface { // AutoMigrate @@ -76,6 +80,7 @@ type Migrator interface { HasTable(dst interface{}) bool RenameTable(oldName, newName interface{}) error GetTables() (tableList []string, err error) + TableType(dst interface{}) (TableType, error) // Columns AddColumn(dst interface{}, field string) error diff --git a/vendor/gorm.io/gorm/migrator/index.go b/vendor/gorm.io/gorm/migrator/index.go index fe686e5afe..8845da95b3 100644 --- a/vendor/gorm.io/gorm/migrator/index.go +++ b/vendor/gorm.io/gorm/migrator/index.go @@ -17,12 +17,12 @@ func (idx Index) Table() string { return idx.TableName } -// Name return the name of the index. +// Name return the name of the index. func (idx Index) Name() string { return idx.NameValue } -// Columns return the columns fo the index +// Columns return the columns of the index func (idx Index) Columns() []string { return idx.ColumnList } @@ -37,7 +37,7 @@ func (idx Index) Unique() (unique bool, ok bool) { return idx.UniqueValue.Bool, idx.UniqueValue.Valid } -// Option return the optional attribute fo the index +// Option return the optional attribute of the index func (idx Index) Option() string { return idx.OptionValue } diff --git a/vendor/gorm.io/gorm/migrator/migrator.go b/vendor/gorm.io/gorm/migrator/migrator.go index b8aaef2b30..49bc937117 100644 --- a/vendor/gorm.io/gorm/migrator/migrator.go +++ b/vendor/gorm.io/gorm/migrator/migrator.go @@ -16,9 +16,16 @@ import ( "gorm.io/gorm/schema" ) -var ( - regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) -) +// This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), +// with a possible trailing non-digit character (\D?). + +// For example, values that can pass this regular expression are: +// - "123" +// - "abc456" +// -"%$#@789" +var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) + +// TODO:? Create const vars for raw sql queries ? // Migrator m struct type Migrator struct { @@ -115,7 +122,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { return err } } else { - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err @@ -125,7 +132,6 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { parseCheckConstraints = stmt.Schema.ParseCheckConstraints() ) for _, dbName := range stmt.Schema.DBNames { - field := stmt.Schema.FieldsByDBName[dbName] var foundColumn gorm.ColumnType for _, columnType := range columnTypes { @@ -137,12 +143,15 @@ func (m Migrator) AutoMigrate(values ...interface{}) error { if foundColumn == nil { // not found, add column - if err := execTx.Migrator().AddColumn(value, dbName); err != nil { + if err = execTx.Migrator().AddColumn(value, dbName); err != nil { + return err + } + } else { + // found, smartly migrate + field := stmt.Schema.FieldsByDBName[dbName] + if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { return err } - } else if err := execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { - // found, smart migrate - return err } } @@ -197,7 +206,7 @@ func (m Migrator) GetTables() (tableList []string, err error) { func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) - if err := m.RunWithValue(value, func(stmt *gorm.Statement) (errr error) { + if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { var ( createTableSQL = "CREATE TABLE ? (" values = []interface{}{m.CurrentTable(stmt)} @@ -208,7 +217,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { field := stmt.Schema.FieldsByDBName[dbName] if !field.IgnoreMigration { createTableSQL += "? ?" - hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(string(field.DataType)), "PRIMARY KEY") + hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } @@ -216,7 +225,7 @@ func (m Migrator) CreateTable(values ...interface{}) error { if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { createTableSQL += "PRIMARY KEY ?," - primaryKeys := []interface{}{} + primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) for _, field := range stmt.Schema.PrimaryFields { primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) } @@ -227,8 +236,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { defer func(value interface{}, name string) { - if errr == nil { - errr = tx.Migrator().CreateIndex(value, name) + if err == nil { + err = tx.Migrator().CreateIndex(value, name) } }(value, idx.Name) } else { @@ -278,8 +287,8 @@ func (m Migrator) CreateTable(values ...interface{}) error { createTableSQL += fmt.Sprint(tableOption) } - errr = tx.Exec(createTableSQL, values...).Error - return errr + err = tx.Exec(createTableSQL, values...).Error + return err }); err != nil { return err } @@ -500,7 +509,7 @@ func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnTy currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) dv, dvNotNull := columnType.DefaultValue() if dvNotNull && !currentDefaultNotNull { - // defalut value -> null + // default value -> null alterColumn = true } else if !dvNotNull && currentDefaultNotNull { // null -> default value @@ -559,14 +568,44 @@ func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { return columnTypes, execErr } -// CreateView create view +// CreateView create view from Query in gorm.ViewOption. +// Query in gorm.ViewOption is a [subquery] +// +// // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20 +// q := DB.Model(&User{}).Where("age > ?", 20) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q}) +// +// // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION +// q := DB.Model(&User{}) +// DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"}) +// +// [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery func (m Migrator) CreateView(name string, option gorm.ViewOption) error { - return gorm.ErrNotImplemented + if option.Query == nil { + return gorm.ErrSubQueryRequired + } + + sql := new(strings.Builder) + sql.WriteString("CREATE ") + if option.Replace { + sql.WriteString("OR REPLACE ") + } + sql.WriteString("VIEW ") + m.QuoteTo(sql, name) + sql.WriteString(" AS ") + + m.DB.Statement.AddVar(sql, option.Query) + + if option.CheckOption != "" { + sql.WriteString(" ") + sql.WriteString(option.CheckOption) + } + return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error } // DropView drop view func (m Migrator) DropView(name string) error { - return gorm.ErrNotImplemented + return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } func buildConstraint(constraint *schema.Constraint) (sql string, results []interface{}) { @@ -919,3 +958,8 @@ func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) { func (m Migrator) GetTypeAliases(databaseTypeName string) []string { return nil } + +// TableType return tableType gorm.TableType and execErr error +func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) { + return nil, errors.New("not support") +} diff --git a/vendor/gorm.io/gorm/migrator/table_type.go b/vendor/gorm.io/gorm/migrator/table_type.go new file mode 100644 index 0000000000..ed6e42a0ec --- /dev/null +++ b/vendor/gorm.io/gorm/migrator/table_type.go @@ -0,0 +1,33 @@ +package migrator + +import ( + "database/sql" +) + +// TableType table type implements TableType interface +type TableType struct { + SchemaValue string + NameValue string + TypeValue string + CommentValue sql.NullString +} + +// Schema returns the schema of the table. +func (ct TableType) Schema() string { + return ct.SchemaValue +} + +// Name returns the name of the table. +func (ct TableType) Name() string { + return ct.NameValue +} + +// Type returns the type of the table. +func (ct TableType) Type() string { + return ct.TypeValue +} + +// Comment returns the comment of current table. +func (ct TableType) Comment() (comment string, ok bool) { + return ct.CommentValue.String, ct.CommentValue.Valid +} diff --git a/vendor/gorm.io/gorm/model.go b/vendor/gorm.io/gorm/model.go index 3334d17cb1..fa705df1cd 100644 --- a/vendor/gorm.io/gorm/model.go +++ b/vendor/gorm.io/gorm/model.go @@ -4,9 +4,10 @@ import "time" // Model a basic GoLang struct which includes the following fields: ID, CreatedAt, UpdatedAt, DeletedAt // It may be embedded into your model or you may build your own model without it -// type User struct { -// gorm.Model -// } +// +// type User struct { +// gorm.Model +// } type Model struct { ID uint `gorm:"primarykey"` CreatedAt time.Time diff --git a/vendor/gorm.io/gorm/prepare_stmt.go b/vendor/gorm.io/gorm/prepare_stmt.go index e09fe814f6..aa944624c8 100644 --- a/vendor/gorm.io/gorm/prepare_stmt.go +++ b/vendor/gorm.io/gorm/prepare_stmt.go @@ -3,6 +3,7 @@ package gorm import ( "context" "database/sql" + "reflect" "sync" ) @@ -20,15 +21,24 @@ type PreparedStmtDB struct { ConnPool } -func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { - if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { - return dbConnector.GetDBConn() +func NewPreparedStmtDB(connPool ConnPool) *PreparedStmtDB { + return &PreparedStmtDB{ + ConnPool: connPool, + Stmts: make(map[string]*Stmt), + Mux: &sync.RWMutex{}, + PreparedSQL: make([]string, 0, 100), } +} +func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) { if sqldb, ok := db.ConnPool.(*sql.DB); ok { return sqldb, nil } + if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil { + return dbConnector.GetDBConn() + } + return nil, ErrInvalidDB } @@ -44,15 +54,15 @@ func (db *PreparedStmtDB) Close() { } } -func (db *PreparedStmtDB) Reset() { - db.Mux.Lock() - defer db.Mux.Unlock() +func (sdb *PreparedStmtDB) Reset() { + sdb.Mux.Lock() + defer sdb.Mux.Unlock() - for _, stmt := range db.Stmts { + for _, stmt := range sdb.Stmts { go stmt.Close() } - db.PreparedSQL = make([]string, 0, 100) - db.Stmts = make(map[string]*Stmt) + sdb.PreparedSQL = make([]string, 0, 100) + sdb.Stmts = make(map[string]*Stmt) } func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (Stmt, error) { @@ -117,6 +127,19 @@ func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (Conn tx, err := beginner.BeginTx(ctx, opt) return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err } + + beginner, ok := db.ConnPool.(ConnPoolBeginner) + if !ok { + return nil, ErrInvalidTransaction + } + + connPool, err := beginner.BeginTx(ctx, opt) + if err != nil { + return nil, err + } + if tx, ok := connPool.(Tx); ok { + return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil + } return nil, ErrInvalidTransaction } @@ -162,15 +185,19 @@ type PreparedStmtTX struct { PreparedStmtDB *PreparedStmtDB } +func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) { + return db.PreparedStmtDB.GetDBConn() +} + func (tx *PreparedStmtTX) Commit() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Commit() } return ErrInvalidTransaction } func (tx *PreparedStmtTX) Rollback() error { - if tx.Tx != nil { + if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() { return tx.Tx.Rollback() } return ErrInvalidTransaction diff --git a/vendor/gorm.io/gorm/scan.go b/vendor/gorm.io/gorm/scan.go index 12a7786208..736db4d3a7 100644 --- a/vendor/gorm.io/gorm/scan.go +++ b/vendor/gorm.io/gorm/scan.go @@ -4,10 +4,10 @@ import ( "database/sql" "database/sql/driver" "reflect" - "strings" "time" "gorm.io/gorm/schema" + "gorm.io/gorm/utils" ) // prepareValues prepare values slice @@ -50,7 +50,7 @@ func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns } } -func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][2]*schema.Field) { +func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) { for idx, field := range fields { if field != nil { values[idx] = field.NewValuePool.Get() @@ -65,28 +65,45 @@ func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []int db.RowsAffected++ db.AddError(rows.Scan(values...)) - joinedSchemaMap := make(map[*schema.Field]interface{}) + joinedNestedSchemaMap := make(map[string]interface{}) for idx, field := range fields { if field == nil { continue } - if len(joinFields) == 0 || joinFields[idx][0] == nil { + if len(joinFields) == 0 || len(joinFields[idx]) == 0 { db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx])) - } else { - joinSchema := joinFields[idx][0] - relValue := joinSchema.ReflectValueOf(db.Statement.Context, reflectValue) - if relValue.Kind() == reflect.Ptr { - if _, ok := joinedSchemaMap[joinSchema]; !ok { - if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { - continue - } + } else { // joinFields count is larger than 2 when using join + var isNilPtrValue bool + var relValue reflect.Value + // does not contain raw dbname + nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1] + // current reflect value + currentReflectValue := reflectValue + fullRels := make([]string, 0, len(nestedJoinSchemas)) + for _, joinSchema := range nestedJoinSchemas { + fullRels = append(fullRels, joinSchema.Name) + relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue) + if relValue.Kind() == reflect.Ptr { + fullRelsName := utils.JoinNestedRelationNames(fullRels) + // same nested structure + if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok { + if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() { + isNilPtrValue = true + break + } - relValue.Set(reflect.New(relValue.Type().Elem())) - joinedSchemaMap[joinSchema] = nil + relValue.Set(reflect.New(relValue.Type().Elem())) + joinedNestedSchemaMap[fullRelsName] = nil + } } + currentReflectValue = relValue + } + + if !isNilPtrValue { // ignore if value is nil + f := joinFields[idx][len(joinFields[idx])-1] + db.AddError(f.Set(db.Statement.Context, relValue, values[idx])) } - db.AddError(joinFields[idx][1].Set(db.Statement.Context, relValue, values[idx])) } // release data to pool @@ -163,7 +180,7 @@ func Scan(rows Rows, db *DB, mode ScanMode) { default: var ( fields = make([]*schema.Field, len(columns)) - joinFields [][2]*schema.Field + joinFields [][]*schema.Field sch = db.Statement.Schema reflectValue = db.Statement.ReflectValue ) @@ -217,15 +234,26 @@ func Scan(rows Rows, db *DB, mode ScanMode) { } else { matchedFieldCount[column] = 1 } - } else if names := strings.Split(column, "__"); len(names) > 1 { + } else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation if rel, ok := sch.Relationships.Relations[names[0]]; ok { - if field := rel.FieldSchema.LookUpField(strings.Join(names[1:], "__")); field != nil && field.Readable { + subNameCount := len(names) + // nested relation fields + relFields := make([]*schema.Field, 0, subNameCount-1) + relFields = append(relFields, rel.Field) + for _, name := range names[1 : subNameCount-1] { + rel = rel.FieldSchema.Relationships.Relations[name] + relFields = append(relFields, rel.Field) + } + // lastest name is raw dbname + dbName := names[subNameCount-1] + if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable { fields[idx] = field if len(joinFields) == 0 { - joinFields = make([][2]*schema.Field, len(columns)) + joinFields = make([][]*schema.Field, len(columns)) } - joinFields[idx] = [2]*schema.Field{rel.Field, field} + relFields = append(relFields, field) + joinFields[idx] = relFields continue } } diff --git a/vendor/gorm.io/gorm/schema/field.go b/vendor/gorm.io/gorm/schema/field.go index 1589d984a3..dd08e056b3 100644 --- a/vendor/gorm.io/gorm/schema/field.go +++ b/vendor/gorm.io/gorm/schema/field.go @@ -89,6 +89,10 @@ type Field struct { NewValuePool FieldNewValuePool } +func (field *Field) BindName() string { + return strings.Join(field.BindNames, ".") +} + // ParseField parses reflect.StructField to Field func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { var ( @@ -174,7 +178,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field { field.DataType = String field.Serializer = v } else { - var serializerName = field.TagSettings["JSON"] + serializerName := field.TagSettings["JSON"] if serializerName == "" { serializerName = field.TagSettings["SERIALIZER"] } @@ -580,8 +584,6 @@ func (field *Field) setupValuerAndSetter() { case **bool: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetBool(**data) - } else { - field.ReflectValueOf(ctx, value).SetBool(false) } case bool: field.ReflectValueOf(ctx, value).SetBool(data) @@ -601,8 +603,22 @@ func (field *Field) setupValuerAndSetter() { case **int64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetInt(**data) - } else { - field.ReflectValueOf(ctx, value).SetInt(0) + } + case **int: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) + } + case **int32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetInt(int64(**data)) } case int64: field.ReflectValueOf(ctx, value).SetInt(data) @@ -667,8 +683,22 @@ func (field *Field) setupValuerAndSetter() { case **uint64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetUint(**data) - } else { - field.ReflectValueOf(ctx, value).SetUint(0) + } + case **uint: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint8: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint16: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) + } + case **uint32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetUint(uint64(**data)) } case uint64: field.ReflectValueOf(ctx, value).SetUint(data) @@ -721,8 +751,10 @@ func (field *Field) setupValuerAndSetter() { case **float64: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetFloat(**data) - } else { - field.ReflectValueOf(ctx, value).SetFloat(0) + } + case **float32: + if data != nil && *data != nil { + field.ReflectValueOf(ctx, value).SetFloat(float64(**data)) } case float64: field.ReflectValueOf(ctx, value).SetFloat(data) @@ -767,8 +799,6 @@ func (field *Field) setupValuerAndSetter() { case **string: if data != nil && *data != nil { field.ReflectValueOf(ctx, value).SetString(**data) - } else { - field.ReflectValueOf(ctx, value).SetString("") } case string: field.ReflectValueOf(ctx, value).SetString(data) @@ -816,7 +846,7 @@ func (field *Field) setupValuerAndSetter() { field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error { switch data := v.(type) { case **time.Time: - if data != nil { + if data != nil && *data != nil { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data)) } case time.Time: @@ -852,14 +882,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { fieldValue := field.ReflectValueOf(ctx, value) if fieldValue.IsNil() { @@ -880,14 +908,12 @@ func (field *Field) setupValuerAndSetter() { reflectV := reflect.ValueOf(v) if !reflectV.IsValid() { field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) + } else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() { + return } else if reflectV.Type().AssignableTo(field.FieldType) { field.ReflectValueOf(ctx, value).Set(reflectV) } else if reflectV.Kind() == reflect.Ptr { - if reflectV.IsNil() || !reflectV.IsValid() { - field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem()) - } else { - return field.Set(ctx, value, reflectV.Elem().Interface()) - } + return field.Set(ctx, value, reflectV.Elem().Interface()) } else { if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() @@ -916,6 +942,8 @@ func (field *Field) setupValuerAndSetter() { sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem() } + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) { if s, ok := v.(*serializer); ok { if s.fieldValue != nil { @@ -923,11 +951,12 @@ func (field *Field) setupValuerAndSetter() { } else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil { if sameElemType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem()) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } else if sameType { field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer)) - s.Serializer = reflect.New(reflect.Indirect(reflect.ValueOf(field.Serializer)).Type()).Interface().(SerializerInterface) } + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) + s.Serializer = si.Interface().(SerializerInterface) } } else { err = oldFieldSetter(ctx, value, v) @@ -939,11 +968,15 @@ func (field *Field) setupValuerAndSetter() { func (field *Field) setupNewValuePool() { if field.Serializer != nil { + serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer)) + serializerType := serializerValue.Type() field.NewValuePool = &sync.Pool{ New: func() interface{} { + si := reflect.New(serializerType) + si.Elem().Set(serializerValue) return &serializer{ Field: field, - Serializer: field.Serializer, + Serializer: si.Interface().(SerializerInterface), } }, } diff --git a/vendor/gorm.io/gorm/schema/naming.go b/vendor/gorm.io/gorm/schema/naming.go index a258beed36..a2a0150a30 100644 --- a/vendor/gorm.io/gorm/schema/naming.go +++ b/vendor/gorm.io/gorm/schema/naming.go @@ -28,10 +28,11 @@ type Replacer interface { // NamingStrategy tables, columns naming strategy type NamingStrategy struct { - TablePrefix string - SingularTable bool - NameReplacer Replacer - NoLowerCase bool + TablePrefix string + SingularTable bool + NameReplacer Replacer + NoLowerCase bool + IdentifierMaxLength int } // TableName convert string to table name @@ -89,12 +90,16 @@ func (ns NamingStrategy) formatName(prefix, table, name string) string { prefix, table, name, }, "_"), ".", "_") - if utf8.RuneCountInString(formattedName) > 64 { + if ns.IdentifierMaxLength == 0 { + ns.IdentifierMaxLength = 64 + } + + if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) - formattedName = formattedName[0:56] + hex.EncodeToString(bs)[:8] + formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] } return formattedName } diff --git a/vendor/gorm.io/gorm/schema/relationship.go b/vendor/gorm.io/gorm/schema/relationship.go index 9436f28319..e03dcc520b 100644 --- a/vendor/gorm.io/gorm/schema/relationship.go +++ b/vendor/gorm.io/gorm/schema/relationship.go @@ -27,6 +27,8 @@ type Relationships struct { HasMany []*Relationship Many2Many []*Relationship Relations map[string]*Relationship + + EmbeddedRelations map[string]*Relationships } type Relationship struct { @@ -106,7 +108,7 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { } if schema.err == nil { - schema.Relationships.Relations[relation.Name] = relation + schema.setRelation(relation) switch relation.Type { case HasOne: schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) @@ -122,17 +124,51 @@ func (schema *Schema) parseRelation(field *Field) *Relationship { return relation } +func (schema *Schema) setRelation(relation *Relationship) { + // set non-embedded relation + if rel := schema.Relationships.Relations[relation.Name]; rel != nil { + if len(rel.Field.BindNames) > 1 { + schema.Relationships.Relations[relation.Name] = relation + } + } else { + schema.Relationships.Relations[relation.Name] = relation + } + + // set embedded relation + if len(relation.Field.BindNames) <= 1 { + return + } + relationships := &schema.Relationships + for i, name := range relation.Field.BindNames { + if i < len(relation.Field.BindNames)-1 { + if relationships.EmbeddedRelations == nil { + relationships.EmbeddedRelations = map[string]*Relationships{} + } + if r := relationships.EmbeddedRelations[name]; r == nil { + relationships.EmbeddedRelations[name] = &Relationships{} + } + relationships = relationships.EmbeddedRelations[name] + } else { + if relationships.Relations == nil { + relationships.Relations = map[string]*Relationship{} + } + relationships.Relations[relation.Name] = relation + } + } +} + // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` -// type User struct { -// Toys []Toy `gorm:"polymorphic:Owner;"` -// } -// type Pet struct { -// Toy Toy `gorm:"polymorphic:Owner;"` -// } -// type Toy struct { -// OwnerID int -// OwnerType string -// } +// +// type User struct { +// Toys []Toy `gorm:"polymorphic:Owner;"` +// } +// type Pet struct { +// Toy Toy `gorm:"polymorphic:Owner;"` +// } +// type Toy struct { +// OwnerID int +// OwnerType string +// } func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field, polymorphic string) { relation.Polymorphic = &Polymorphic{ Value: schema.Table, @@ -165,6 +201,11 @@ func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Fi } } + if primaryKeyField == nil { + schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) + return + } + // use same data type for foreign keys if copyableDataType(primaryKeyField.DataType) { relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType @@ -427,7 +468,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu foreignFields = append(foreignFields, f) } } else { - var primarySchemaName = primarySchema.Name + primarySchemaName := primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name } @@ -442,6 +483,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu primaryFields = primarySchema.PrimaryFields } + primaryFieldLoop: for _, primaryField := range primaryFields { lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { @@ -453,11 +495,18 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl gu lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } + for _, name := range lookUpNames { + if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { + foreignFields = append(foreignFields, f) + primaryFields = append(primaryFields, primaryField) + continue primaryFieldLoop + } + } for _, name := range lookUpNames { if f := foreignSchema.LookUpField(name); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) - break + continue primaryFieldLoop } } } diff --git a/vendor/gorm.io/gorm/schema/schema.go b/vendor/gorm.io/gorm/schema/schema.go index b34383bdc3..3e7459ce74 100644 --- a/vendor/gorm.io/gorm/schema/schema.go +++ b/vendor/gorm.io/gorm/schema/schema.go @@ -6,12 +6,27 @@ import ( "fmt" "go/ast" "reflect" + "strings" "sync" "gorm.io/gorm/clause" "gorm.io/gorm/logger" ) +type callbackType string + +const ( + callbackTypeBeforeCreate callbackType = "BeforeCreate" + callbackTypeBeforeUpdate callbackType = "BeforeUpdate" + callbackTypeAfterCreate callbackType = "AfterCreate" + callbackTypeAfterUpdate callbackType = "AfterUpdate" + callbackTypeBeforeSave callbackType = "BeforeSave" + callbackTypeAfterSave callbackType = "AfterSave" + callbackTypeBeforeDelete callbackType = "BeforeDelete" + callbackTypeAfterDelete callbackType = "AfterDelete" + callbackTypeAfterFind callbackType = "AfterFind" +) + // ErrUnsupportedDataType unsupported data type var ErrUnsupportedDataType = errors.New("unsupported data type") @@ -25,6 +40,7 @@ type Schema struct { PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field + FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships @@ -67,6 +83,27 @@ func (schema Schema) LookUpField(name string) *Field { return nil } +// LookUpFieldByBindName looks for the closest field in the embedded struct. +// +// type Struct struct { +// Embedded struct { +// ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID") +// } +// ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") +// } +func (schema Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { + if len(bindNames) == 0 { + return nil + } + for i := len(bindNames) - 1; i >= 0; i-- { + find := strings.Join(bindNames[:i], ".") + "." + name + if field, ok := schema.FieldsByBindName[find]; ok { + return field + } + } + return nil +} + type Tabler interface { TableName() string } @@ -140,15 +177,16 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema := &Schema{ - Name: modelType.Name(), - ModelType: modelType, - Table: tableName, - FieldsByName: map[string]*Field{}, - FieldsByDBName: map[string]*Field{}, - Relationships: Relationships{Relations: map[string]*Relationship{}}, - cacheStore: cacheStore, - namer: namer, - initialized: make(chan struct{}), + Name: modelType.Name(), + ModelType: modelType, + Table: tableName, + FieldsByName: map[string]*Field{}, + FieldsByBindName: map[string]*Field{}, + FieldsByDBName: map[string]*Field{}, + Relationships: Relationships{Relations: map[string]*Relationship{}}, + cacheStore: cacheStore, + namer: namer, + initialized: make(chan struct{}), } // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) @@ -176,6 +214,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam field.DBName = namer.ColumnName(schema.Table, field.Name) } + bindName := field.BindName() if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { @@ -184,6 +223,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { for idx, f := range schema.PrimaryFields { @@ -202,6 +242,9 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } + if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { + schema.FieldsByBindName[bindName] = field + } field.setupValuerAndSetter() } @@ -221,8 +264,18 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - if schema.PrioritizedPrimaryField == nil && len(schema.PrimaryFields) == 1 { - schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + if schema.PrioritizedPrimaryField == nil { + if len(schema.PrimaryFields) == 1 { + schema.PrioritizedPrimaryField = schema.PrimaryFields[0] + } else if len(schema.PrimaryFields) > 1 { + // If there are multiple primary keys, the AUTOINCREMENT field is prioritized + for _, field := range schema.PrimaryFields { + if field.AutoIncrement { + schema.PrioritizedPrimaryField = field + break + } + } + } } for _, field := range schema.PrimaryFields { @@ -249,14 +302,20 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam } } - callbacks := []string{"BeforeCreate", "AfterCreate", "BeforeUpdate", "AfterUpdate", "BeforeSave", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} - for _, name := range callbacks { - if methodValue := modelValue.MethodByName(name); methodValue.IsValid() { + callbackTypes := []callbackType{ + callbackTypeBeforeCreate, callbackTypeAfterCreate, + callbackTypeBeforeUpdate, callbackTypeAfterUpdate, + callbackTypeBeforeSave, callbackTypeAfterSave, + callbackTypeBeforeDelete, callbackTypeAfterDelete, + callbackTypeAfterFind, + } + for _, cbName := range callbackTypes { + if methodValue := callBackToMethodValue(modelValue, cbName); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": // TODO hack - reflect.Indirect(reflect.ValueOf(schema)).FieldByName(name).SetBool(true) + reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) default: - logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, name, name) + logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } } } @@ -283,6 +342,7 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } else { schema.FieldsByName[field.Name] = field + schema.FieldsByBindName[field.BindName()] = field } } @@ -309,6 +369,39 @@ func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Nam return schema, schema.err } +// This unrolling is needed to show to the compiler the exact set of methods +// that can be used on the modelType. +// Prior to go1.22 any use of MethodByName would cause the linker to +// abandon dead code elimination for the entire binary. +// As of go1.22 the compiler supports one special case of a string constant +// being passed to MethodByName. For enterprise customers or those building +// large binaries, this gives a significant reduction in binary size. +// https://github.com/golang/go/issues/62257 +func callBackToMethodValue(modelType reflect.Value, cbType callbackType) reflect.Value { + switch cbType { + case callbackTypeBeforeCreate: + return modelType.MethodByName(string(callbackTypeBeforeCreate)) + case callbackTypeAfterCreate: + return modelType.MethodByName(string(callbackTypeAfterCreate)) + case callbackTypeBeforeUpdate: + return modelType.MethodByName(string(callbackTypeBeforeUpdate)) + case callbackTypeAfterUpdate: + return modelType.MethodByName(string(callbackTypeAfterUpdate)) + case callbackTypeBeforeSave: + return modelType.MethodByName(string(callbackTypeBeforeSave)) + case callbackTypeAfterSave: + return modelType.MethodByName(string(callbackTypeAfterSave)) + case callbackTypeBeforeDelete: + return modelType.MethodByName(string(callbackTypeBeforeDelete)) + case callbackTypeAfterDelete: + return modelType.MethodByName(string(callbackTypeAfterDelete)) + case callbackTypeAfterFind: + return modelType.MethodByName(string(callbackTypeAfterFind)) + default: + return reflect.ValueOf(nil) + } +} + func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { diff --git a/vendor/gorm.io/gorm/schema/serializer.go b/vendor/gorm.io/gorm/schema/serializer.go index 9a6aa4fc0c..397edff034 100644 --- a/vendor/gorm.io/gorm/schema/serializer.go +++ b/vendor/gorm.io/gorm/schema/serializer.go @@ -70,8 +70,7 @@ type SerializerValuerInterface interface { } // JSONSerializer json serializer -type JSONSerializer struct { -} +type JSONSerializer struct{} // Scan implements serializer interface func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -110,8 +109,7 @@ func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value } // UnixSecondSerializer json serializer -type UnixSecondSerializer struct { -} +type UnixSecondSerializer struct{} // Scan implements serializer interface func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { @@ -141,8 +139,7 @@ func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect } // GobSerializer gob serializer -type GobSerializer struct { -} +type GobSerializer struct{} // Scan implements serializer interface func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) { diff --git a/vendor/gorm.io/gorm/schema/utils.go b/vendor/gorm.io/gorm/schema/utils.go index acf1a739bf..7fdda1855f 100644 --- a/vendor/gorm.io/gorm/schema/utils.go +++ b/vendor/gorm.io/gorm/schema/utils.go @@ -115,6 +115,11 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, notZero, zero bool ) + if reflectValue.Kind() == reflect.Ptr || + reflectValue.Kind() == reflect.Interface { + reflectValue = reflectValue.Elem() + } + switch reflectValue.Kind() { case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} @@ -133,7 +138,7 @@ func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, for i := 0; i < reflectValue.Len(); i++ { elem := reflectValue.Index(i) elemKey := elem.Interface() - if elem.Kind() != reflect.Ptr { + if elem.Kind() != reflect.Ptr && elem.CanAddr() { elemKey = elem.Addr().Interface() } diff --git a/vendor/gorm.io/gorm/statement.go b/vendor/gorm.io/gorm/statement.go index 0816529382..59c0b772cf 100644 --- a/vendor/gorm.io/gorm/statement.go +++ b/vendor/gorm.io/gorm/statement.go @@ -120,6 +120,8 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { write(v.Raw, stmt.Schema.DBNames[0]) + } else { + stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { write(v.Raw, v.Name) @@ -322,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - for _, scope := range v.Statement.scopes { - v = scope(v) - } + v.executeScopes() - if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { @@ -334,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } conds = append(conds, clause.And(where.Exprs...)) - } else if cs.Expression != nil { + } else { conds = append(conds, cs.Expression) } + if v.Statement == stmt { + cs.Expression = nil + stmt.Statement.Clauses["WHERE"] = cs + } } case map[interface{}]interface{}: for i, j := range v { diff --git a/vendor/gorm.io/gorm/utils/utils.go b/vendor/gorm.io/gorm/utils/utils.go index e08533cd68..c8fec5b09c 100644 --- a/vendor/gorm.io/gorm/utils/utils.go +++ b/vendor/gorm.io/gorm/utils/utils.go @@ -89,19 +89,28 @@ func Contains(elems []string, elem string) bool { return false } -func AssertEqual(src, dst interface{}) bool { - if !reflect.DeepEqual(src, dst) { - if valuer, ok := src.(driver.Valuer); ok { - src, _ = valuer.Value() - } +func AssertEqual(x, y interface{}) bool { + if reflect.DeepEqual(x, y) { + return true + } + if x == nil || y == nil { + return false + } - if valuer, ok := dst.(driver.Valuer); ok { - dst, _ = valuer.Value() - } + xval := reflect.ValueOf(x) + yval := reflect.ValueOf(y) + if xval.Kind() == reflect.Ptr && xval.IsNil() || + yval.Kind() == reflect.Ptr && yval.IsNil() { + return false + } - return reflect.DeepEqual(src, dst) + if valuer, ok := x.(driver.Valuer); ok { + x, _ = valuer.Value() + } + if valuer, ok := y.(driver.Valuer); ok { + y, _ = valuer.Value() } - return true + return reflect.DeepEqual(x, y) } func ToString(value interface{}) string { @@ -131,3 +140,20 @@ func ToString(value interface{}) string { } return "" } + +const nestedRelationSplit = "__" + +// NestedRelationName nested relationships like `Manager__Company` +func NestedRelationName(prefix, name string) string { + return prefix + nestedRelationSplit + name +} + +// SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` +func SplitNestedRelationName(name string) []string { + return strings.Split(name, nestedRelationSplit) +} + +// JoinNestedRelationNames nested relationships like `Manager__Company` +func JoinNestedRelationNames(relationNames []string) string { + return strings.Join(relationNames, nestedRelationSplit) +} diff --git a/vendor/modules.txt b/vendor/modules.txt index 5b917fdbbb..fac454f9df 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -454,7 +454,7 @@ github.com/jinzhu/copier # github.com/jinzhu/inflection v1.0.0 ## explicit github.com/jinzhu/inflection -# github.com/jinzhu/now v1.1.4 +# github.com/jinzhu/now v1.1.5 ## explicit; go 1.12 github.com/jinzhu/now # github.com/jmespath/go-jmespath v0.4.0 @@ -1300,8 +1300,8 @@ gopkg.in/yaml.v3 # gorm.io/driver/postgres v1.3.5 ## explicit; go 1.14 gorm.io/driver/postgres -# gorm.io/gorm v1.24.5 -## explicit; go 1.16 +# gorm.io/gorm v1.25.5 +## explicit; go 1.18 gorm.io/gorm gorm.io/gorm/callbacks gorm.io/gorm/clause