Skip to content

Commit

Permalink
Better support Count in chain
Browse files Browse the repository at this point in the history
  • Loading branch information
jinzhu committed Jul 1, 2020
1 parent 9075b33 commit d02b592
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
2 changes: 2 additions & 0 deletions finisher_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ func (db *DB) Count(count *int64) (tx *DB) {

if len(tx.Statement.Selects) == 0 {
tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(1)"}})
defer tx.Statement.AddClause(clause.Select{})
} else if !strings.Contains(strings.ToLower(tx.Statement.Selects[0]), "count(") {
expr := clause.Expr{SQL: "count(1)"}

Expand All @@ -281,6 +282,7 @@ func (db *DB) Count(count *int64) (tx *DB) {
}

tx.Statement.AddClause(clause.Select{Expression: expr})
defer tx.Statement.AddClause(clause.Select{})
}

tx.Statement.Dest = count
Expand Down
8 changes: 8 additions & 0 deletions tests/count_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@ func TestCount(t *testing.T) {
t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users))
}

if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil {
t.Errorf(fmt.Sprintf("Count should work, but got err %v", err))
}

if count != int64(len(users)) {
t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users))
}

DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2)
if count1 != 1 || count2 != 3 {
t.Errorf("multiple count in chain should works")
Expand Down

0 comments on commit d02b592

Please sign in to comment.