From 55f17a0c162c4e8557265a3b2d2ca3b4993eb6a0 Mon Sep 17 00:00:00 2001 From: black Date: Wed, 15 Mar 2023 17:01:51 +0800 Subject: [PATCH 1/3] fix cond in scopes --- callbacks.go | 8 +------ chainable_api.go | 33 ++++++++++++++++++++++++++++ migrator.go | 8 +------ statement.go | 12 ++++++----- tests/scopes_test.go | 51 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 19 deletions(-) diff --git a/callbacks.go b/callbacks.go index de979e459..0da4ecf99 100644 --- a/callbacks.go +++ b/callbacks.go @@ -74,13 +74,7 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) *DB { // call scopes - for len(db.Statement.scopes) > 0 { - scopes := db.Statement.scopes - db.Statement.scopes = nil - for _, scope := range scopes { - db = scope(db) - } - } + db = db.executeScopes(false) var ( curTime = time.Now() diff --git a/chainable_api.go b/chainable_api.go index a85235e01..b50992846 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -366,6 +366,39 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { return tx } +func (db *DB) executeScopes(keepScopes bool) (tx *DB) { + tx = db.getInstance() + scopes := db.Statement.scopes + if len(scopes) == 0 { + return tx + } + tx.Statement.scopes = nil + + conditions := make([]clause.Interface, 0, 4) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + + for _, scope := range scopes { + tx = scope(tx) + if cs, ok := tx.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { + conditions = append(conditions, cs.Expression.(clause.Interface)) + cs.Expression = nil + tx.Statement.Clauses["WHERE"] = cs + } + } + + for _, condition := range conditions { + tx.Statement.AddClause(condition) + } + if keepScopes { + tx.Statement.scopes = scopes + } + return tx +} + // Preload preload associations with given conditions // // // get all users, and preload all non-cancelled orders diff --git a/migrator.go b/migrator.go index 9c7cc2c49..6da231b20 100644 --- a/migrator.go +++ b/migrator.go @@ -12,13 +12,7 @@ func (db *DB) Migrator() Migrator { tx := db.getInstance() // apply scopes to migrator - for len(tx.Statement.scopes) > 0 { - scopes := tx.Statement.scopes - tx.Statement.scopes = nil - for _, scope := range scopes { - tx = scope(tx) - } - } + tx.executeScopes(false) return tx.Dialector.Migrator(tx.Session(&Session{})) } diff --git a/statement.go b/statement.go index bc959f0b6..162f0697e 100644 --- a/statement.go +++ b/statement.go @@ -324,11 +324,9 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - for _, scope := range v.Statement.scopes { - v = scope(v) - } + v.executeScopes(true) - if cs, ok := v.Statement.Clauses["WHERE"]; ok { + if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { @@ -336,9 +334,13 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] } } conds = append(conds, clause.And(where.Exprs...)) - } else if cs.Expression != nil { + } else { conds = append(conds, cs.Expression) } + if v.Statement == stmt { + cs.Expression = nil + stmt.Statement.Clauses["WHERE"] = cs + } } case map[interface{}]interface{}: for i, j := range v { diff --git a/tests/scopes_test.go b/tests/scopes_test.go index ab3807ea2..b4a9d11cc 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -72,3 +72,54 @@ func TestScopes(t *testing.T) { t.Errorf("select max(id)") } } + +func TestComplexScopes(t *testing.T) { + tests := []struct { + name string + queryFn func(tx *gorm.DB) *gorm.DB + expected string + }{ + { + name: "depth_1", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: "SELECT * FROM `languages` WHERE a = 1 AND (b = 2 OR c = 3)", + }, { + name: "depth_1_pre_cond", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Where("z = 0").Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, + ).Find(&Language{}) + }, + expected: "SELECT * FROM `languages` WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)", + }, { + name: "depth_2", + queryFn: func(tx *gorm.DB) *gorm.DB { + return tx.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, + func(d *gorm.DB) *gorm.DB { + return d. + Or(d.Scopes( + func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, + func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, + )). + Or("c = 3") + }, + func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, + ).Find(&Language{}) + }, + expected: "SELECT * FROM `languages` WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) + }) + } +} From ec0b0d6fabb46903b5eaa7fb6a7b65c1f1922277 Mon Sep 17 00:00:00 2001 From: black Date: Wed, 15 Mar 2023 17:34:49 +0800 Subject: [PATCH 2/3] replace quote --- tests/scopes_test.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/scopes_test.go b/tests/scopes_test.go index b4a9d11cc..52c6b37b1 100644 --- a/tests/scopes_test.go +++ b/tests/scopes_test.go @@ -87,7 +87,7 @@ func TestComplexScopes(t *testing.T) { func(d *gorm.DB) *gorm.DB { return d.Where(d.Or("b = 2").Or("c = 3")) }, ).Find(&Language{}) }, - expected: "SELECT * FROM `languages` WHERE a = 1 AND (b = 2 OR c = 3)", + expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, }, { name: "depth_1_pre_cond", queryFn: func(tx *gorm.DB) *gorm.DB { @@ -96,7 +96,7 @@ func TestComplexScopes(t *testing.T) { func(d *gorm.DB) *gorm.DB { return d.Or(d.Where("b = 2").Or("c = 3")) }, ).Find(&Language{}) }, - expected: "SELECT * FROM `languages` WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)", + expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, }, { name: "depth_2", queryFn: func(tx *gorm.DB) *gorm.DB { @@ -113,7 +113,7 @@ func TestComplexScopes(t *testing.T) { func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, ).Find(&Language{}) }, - expected: "SELECT * FROM `languages` WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)", + expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, }, } From 349957c54c053427a374f1f35a4f8a13f68dd458 Mon Sep 17 00:00:00 2001 From: black Date: Thu, 23 Mar 2023 13:00:26 +0800 Subject: [PATCH 3/3] fix execute scopes --- callbacks.go | 4 +++- chainable_api.go | 5 +---- migrator.go | 4 +++- statement.go | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/callbacks.go b/callbacks.go index 0da4ecf99..ca6b6d507 100644 --- a/callbacks.go +++ b/callbacks.go @@ -74,7 +74,9 @@ func (cs *callbacks) Raw() *processor { func (p *processor) Execute(db *DB) *DB { // call scopes - db = db.executeScopes(false) + for len(db.Statement.scopes) > 0 { + db = db.executeScopes() + } var ( curTime = time.Now() diff --git a/chainable_api.go b/chainable_api.go index b50992846..19d405cc7 100644 --- a/chainable_api.go +++ b/chainable_api.go @@ -366,7 +366,7 @@ func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { return tx } -func (db *DB) executeScopes(keepScopes bool) (tx *DB) { +func (db *DB) executeScopes() (tx *DB) { tx = db.getInstance() scopes := db.Statement.scopes if len(scopes) == 0 { @@ -393,9 +393,6 @@ func (db *DB) executeScopes(keepScopes bool) (tx *DB) { for _, condition := range conditions { tx.Statement.AddClause(condition) } - if keepScopes { - tx.Statement.scopes = scopes - } return tx } diff --git a/migrator.go b/migrator.go index 6da231b20..037afc35b 100644 --- a/migrator.go +++ b/migrator.go @@ -12,7 +12,9 @@ func (db *DB) Migrator() Migrator { tx := db.getInstance() // apply scopes to migrator - tx.executeScopes(false) + for len(tx.Statement.scopes) > 0 { + tx = tx.executeScopes() + } return tx.Dialector.Migrator(tx.Session(&Session{})) } diff --git a/statement.go b/statement.go index 162f0697e..59c0b772c 100644 --- a/statement.go +++ b/statement.go @@ -324,7 +324,7 @@ func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) [] case clause.Expression: conds = append(conds, v) case *DB: - v.executeScopes(true) + v.executeScopes() if cs, ok := v.Statement.Clauses["WHERE"]; ok && cs.Expression != nil { if where, ok := cs.Expression.(clause.Where); ok {