Skip to content

Commit

Permalink
Fix ambiguous column when using same column name in join table, close #…
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jul 9, 2020
1 parent e1084e7 commit 2ae0653
Show file tree
Hide file tree
Showing 9 changed files with 40 additions and 23 deletions.
20 changes: 10 additions & 10 deletions association.go
Expand Up @@ -122,7 +122,7 @@ func (association *Association) Replace(values ...interface{}) error {
)

if _, rvs := schema.GetIdentityFieldValuesMap(relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 {
tx.Not(clause.IN{Column: column, Values: values})
}
}
Expand All @@ -138,7 +138,7 @@ func (association *Association) Replace(values ...interface{}) error {
}

if _, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields); len(pvs) > 0 {
column, values := schema.ToQueryValues(foreignKeys, pvs)
column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap)
}
case schema.Many2Many:
Expand All @@ -164,14 +164,14 @@ func (association *Association) Replace(values ...interface{}) error {
}

_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
if column, values := schema.ToQueryValues(joinPrimaryKeys, pvs); len(values) > 0 {
if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 {
tx.Where(clause.IN{Column: column, Values: values})
} else {
return ErrorPrimaryKeyRequired
}

_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
if relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs); len(relValues) > 0 {
if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 {
tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues}))
}

Expand Down Expand Up @@ -208,23 +208,23 @@ func (association *Association) Delete(values ...interface{}) error {
tx := association.DB.Model(reflect.New(rel.Schema.ModelType).Interface())

_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, rel.Schema.PrimaryFields)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.PrimaryFieldDBNames, pvs)
pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})

_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, primaryFields)
relColumn, relValues := schema.ToQueryValues(foreignKeys, rvs)
relColumn, relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
case schema.HasOne, schema.HasMany:
tx := association.DB.Model(reflect.New(rel.FieldSchema.ModelType).Interface())

_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(foreignKeys, pvs)
pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})

_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, rel.FieldSchema.PrimaryFields)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.PrimaryFieldDBNames, rvs)
relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error
Expand All @@ -250,11 +250,11 @@ func (association *Association) Delete(values ...interface{}) error {
}

_, pvs := schema.GetIdentityFieldValuesMap(reflectValue, primaryFields)
pcolumn, pvalues := schema.ToQueryValues(joinPrimaryKeys, pvs)
pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs)
conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues})

_, rvs := schema.GetIdentityFieldValuesMapFromValues(values, relPrimaryFields)
relColumn, relValues := schema.ToQueryValues(joinRelPrimaryKeys, rvs)
relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs)
conds = append(conds, clause.IN{Column: relColumn, Values: relValues})

association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(modelValue).Error
Expand Down
4 changes: 2 additions & 2 deletions callbacks/delete.go
Expand Up @@ -35,15 +35,15 @@ func Delete(db *gorm.DB) {

if db.Statement.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields)
column, values := schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
column, values := schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)

if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}

if db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields)
column, values = schema.ToQueryValues(db.Statement.Schema.PrimaryFieldDBNames, queryValues)
column, values = schema.ToQueryValues(db.Statement.Schema.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues)

if len(values) > 0 {
db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
Expand Down
4 changes: 2 additions & 2 deletions callbacks/preload.go
Expand Up @@ -49,7 +49,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}

joinResults := rel.JoinTable.MakeSlice().Elem()
column, values := schema.ToQueryValues(joinForeignKeys, joinForeignValues)
column, values := schema.ToQueryValues(rel.JoinTable.Table, joinForeignKeys, joinForeignValues)
tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface())

// convert join identity map to relation identity map
Expand Down Expand Up @@ -93,7 +93,7 @@ func preload(db *gorm.DB, rels []*schema.Relationship, conds []interface{}) {
}

reflectResults := rel.FieldSchema.MakeSlice().Elem()
column, values := schema.ToQueryValues(relForeignKeys, foreignValues)
column, values := schema.ToQueryValues(rel.FieldSchema.Table, relForeignKeys, foreignValues)

for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
Expand Down
4 changes: 3 additions & 1 deletion schema/relationship.go
Expand Up @@ -462,10 +462,12 @@ func (rel *Relationship) ParseConstraint() *Constraint {
}

func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []clause.Expression) {
table := rel.FieldSchema.Table
foreignFields := []*Field{}
relForeignKeys := []string{}

if rel.JoinTable != nil {
table = rel.JoinTable.Table
for _, ref := range rel.References {
if ref.OwnPrimaryKey {
foreignFields = append(foreignFields, ref.PrimaryKey)
Expand Down Expand Up @@ -500,7 +502,7 @@ func (rel *Relationship) ToQueryConditions(reflectValue reflect.Value) (conds []
}

_, foreignValues := GetIdentityFieldValuesMap(reflectValue, foreignFields)
column, values := ToQueryValues(relForeignKeys, foreignValues)
column, values := ToQueryValues(table, relForeignKeys, foreignValues)

conds = append(conds, clause.IN{Column: column, Values: values})
return
Expand Down
12 changes: 9 additions & 3 deletions schema/utils.go
Expand Up @@ -5,6 +5,7 @@ import (
"regexp"
"strings"

"gorm.io/gorm/clause"
"gorm.io/gorm/utils"
)

Expand Down Expand Up @@ -164,18 +165,23 @@ func GetIdentityFieldValuesMapFromValues(values []interface{}, fields []*Field)
}

// ToQueryValues to query values
func ToQueryValues(foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) {
queryValues := make([]interface{}, len(foreignValues))
if len(foreignKeys) == 1 {
for idx, r := range foreignValues {
queryValues[idx] = r[0]
}

return foreignKeys[0], queryValues
return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues
} else {
columns := make([]clause.Column, len(foreignKeys))
for idx, key := range foreignKeys {
columns[idx] = clause.Column{Table: table, Name: key}
}

for idx, r := range foreignValues {
queryValues[idx] = r
}
return columns, queryValues
}
return foreignKeys, queryValues
}
4 changes: 2 additions & 2 deletions soft_delete.go
Expand Up @@ -58,15 +58,15 @@ func (SoftDeleteClause) ModifyStatement(stmt *Statement) {

if stmt.Schema != nil {
_, queryValues := schema.GetIdentityFieldValuesMap(stmt.ReflectValue, stmt.Schema.PrimaryFields)
column, values := schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
column, values := schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)

if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
}

if stmt.Dest != stmt.Model && stmt.Model != nil {
_, queryValues = schema.GetIdentityFieldValuesMap(reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
column, values = schema.ToQueryValues(stmt.Schema.PrimaryFieldDBNames, queryValues)
column, values = schema.ToQueryValues(stmt.Schema.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)

if len(values) > 0 {
stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
Expand Down
9 changes: 9 additions & 0 deletions statement.go
Expand Up @@ -107,6 +107,15 @@ func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) {
writer.WriteString(" AS ")
stmt.DB.Dialector.QuoteTo(writer, v.Alias)
}
case []clause.Column:
writer.WriteByte('(')
for idx, d := range v {
if idx > 0 {
writer.WriteString(",")
}
stmt.QuoteTo(writer, d)
}
writer.WriteByte(')')
case string:
stmt.DB.Dialector.QuoteTo(writer, v)
case []string:
Expand Down
2 changes: 1 addition & 1 deletion tests/go.mod
Expand Up @@ -6,7 +6,7 @@ require (
github.com/google/uuid v1.1.1
github.com/jinzhu/now v1.1.1
github.com/lib/pq v1.6.0
gorm.io/driver/mysql v0.2.8
gorm.io/driver/mysql v0.2.9
gorm.io/driver/postgres v0.2.5
gorm.io/driver/sqlite v1.0.8
gorm.io/driver/sqlserver v0.2.4
Expand Down
4 changes: 2 additions & 2 deletions tests/multi_primary_keys_test.go
Expand Up @@ -13,7 +13,7 @@ type Blog struct {
Locale string `gorm:"primary_key"`
Subject string
Body string
Tags []Tag `gorm:"many2many:blogs_tags;"`
Tags []Tag `gorm:"many2many:blog_tags;"`
SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"`
LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"`
}
Expand All @@ -22,7 +22,7 @@ type Tag struct {
ID uint `gorm:"primary_key"`
Locale string `gorm:"primary_key"`
Value string
Blogs []*Blog `gorm:"many2many:blogs_tags"`
Blogs []*Blog `gorm:"many2many:blog_tags"`
}

func compareTags(tags []Tag, contents []string) bool {
Expand Down

0 comments on commit 2ae0653

Please sign in to comment.