Skip to content

Commit

Permalink
Fix failed to guess relations for embedded types, close #3224
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Aug 4, 2020
1 parent c11c939 commit ff985b9
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 18 deletions.
1 change: 1 addition & 0 deletions migrator/migrator.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func (m Migrator) AutoMigrate(values ...interface{}) error {
}
return nil
}); err != nil {
fmt.Println(err)
return err
}
}
Expand Down
2 changes: 2 additions & 0 deletions schema/field.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ type Field struct {
TagSettings map[string]string
Schema *Schema
EmbeddedSchema *Schema
OwnerSchema *Schema
ReflectValueOf func(reflect.Value) reflect.Value
ValueOf func(reflect.Value) (value interface{}, zero bool)
Set func(reflect.Value, interface{}) error
Expand Down Expand Up @@ -321,6 +322,7 @@ func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
}
for _, ef := range field.EmbeddedSchema.Fields {
ef.Schema = schema
ef.OwnerSchema = field.EmbeddedSchema
ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
// index is negative means is pointer
if field.FieldType.Kind() == reflect.Struct {
Expand Down
69 changes: 53 additions & 16 deletions schema/relationship.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"reflect"
"regexp"
"strings"
"sync"

"github.com/jinzhu/inflection"
"gorm.io/gorm/clause"
Expand Down Expand Up @@ -66,9 +67,16 @@ func (schema *Schema) parseRelation(field *Field) {
}
)

if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
if field.OwnerSchema != nil {
if relation.FieldSchema, err = Parse(fieldValue, &sync.Map{}, schema.namer); err != nil {
schema.err = err
return
}
} else {
if relation.FieldSchema, err = Parse(fieldValue, schema.cacheStore, schema.namer); err != nil {
schema.err = err
return
}
}

if polymorphic := field.TagSettings["POLYMORPHIC"]; polymorphic != "" {
Expand All @@ -78,7 +86,7 @@ func (schema *Schema) parseRelation(field *Field) {
} else {
switch field.IndirectFieldType.Kind() {
case reflect.Struct, reflect.Slice:
schema.guessRelation(relation, field, true)
schema.guessRelation(relation, field, guessHas)
default:
schema.err = fmt.Errorf("unsupported data type %v for %v on field %v", relation.FieldSchema, schema, field.Name)
}
Expand Down Expand Up @@ -316,21 +324,50 @@ func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Fiel
}
}

func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessHas bool) {
type guessLevel int

const (
guessHas guessLevel = iota
guessEmbeddedHas
guessBelongs
guessEmbeddedBelongs
)

func (schema *Schema) guessRelation(relation *Relationship, field *Field, gl guessLevel) {
var (
primaryFields, foreignFields []*Field
primarySchema, foreignSchema = schema, relation.FieldSchema
)

if !guessHas {
primarySchema, foreignSchema = relation.FieldSchema, schema
reguessOrErr := func(err string, args ...interface{}) {
switch gl {
case guessHas:
schema.guessRelation(relation, field, guessEmbeddedHas)
case guessEmbeddedHas:
schema.guessRelation(relation, field, guessBelongs)
case guessBelongs:
schema.guessRelation(relation, field, guessEmbeddedBelongs)
default:
schema.err = fmt.Errorf(err, args...)
}
}

reguessOrErr := func(err string, args ...interface{}) {
if guessHas {
schema.guessRelation(relation, field, false)
switch gl {
case guessEmbeddedHas:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema
} else {
schema.err = fmt.Errorf(err, args...)
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
}
case guessBelongs:
primarySchema, foreignSchema = relation.FieldSchema, schema
case guessEmbeddedBelongs:
if field.OwnerSchema != nil {
primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema
} else {
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
}
}

Expand All @@ -345,8 +382,8 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
}
} else {
for _, primaryField := range primarySchema.PrimaryFields {
lookUpName := schema.Name + primaryField.Name
if !guessHas {
lookUpName := primarySchema.Name + primaryField.Name
if gl == guessBelongs {
lookUpName = field.Name + primaryField.Name
}

Expand All @@ -358,7 +395,7 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
}

if len(foreignFields) == 0 {
reguessOrErr("failed to guess %v's relations with %v's field %v 1 g %v", relation.FieldSchema, schema, field.Name, guessHas)
reguessOrErr("failed to guess %v's relations with %v's field %v, guess level: %v", relation.FieldSchema, schema, field.Name, gl)
return
} else if len(relation.primaryKeys) > 0 {
for idx, primaryKey := range relation.primaryKeys {
Expand Down Expand Up @@ -394,11 +431,11 @@ func (schema *Schema) guessRelation(relation *Relationship, field *Field, guessH
relation.References = append(relation.References, &Reference{
PrimaryKey: primaryFields[idx],
ForeignKey: foreignField,
OwnPrimaryKey: schema == primarySchema && guessHas,
OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas),
})
}

if guessHas {
if gl == guessHas || gl == guessEmbeddedHas {
relation.Type = "has"
} else {
relation.Type = BelongsTo
Expand Down
8 changes: 6 additions & 2 deletions tests/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,12 @@ func TestCallbacks(t *testing.T) {
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, after: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c4", "c3"},
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}},
results: []string{"c3", "c5", "c1", "c2", "c4"},
},
{
callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}},
results: []string{"c5", "c1", "c2", "c3", "c4"},
},
}

Expand Down
14 changes: 14 additions & 0 deletions tests/embedded_struct_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"testing"

"gorm.io/gorm"
. "gorm.io/gorm/utils/tests"
)

func TestEmbeddedStruct(t *testing.T) {
Expand Down Expand Up @@ -152,3 +153,16 @@ func TestEmbeddedScanValuer(t *testing.T) {
t.Errorf("Failed to create got error %v", err)
}
}

func TestEmbeddedRelations(t *testing.T) {
type AdvancedUser struct {
User `gorm:"embedded"`
Advanced bool
}

DB.Debug().Migrator().DropTable(&AdvancedUser{})

if err := DB.Debug().AutoMigrate(&AdvancedUser{}); err != nil {
t.Errorf("Failed to auto migrate advanced user, got error %v", err)
}
}

0 comments on commit ff985b9

Please sign in to comment.