Skip to content

Commit

Permalink
Fix unsafe concurrent SingularTable method call
Browse files Browse the repository at this point in the history
  • Loading branch information
emirb committed Apr 14, 2019
1 parent 9df293e commit 5959487
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 7 deletions.
4 changes: 3 additions & 1 deletion main.go
Expand Up @@ -12,6 +12,7 @@ import (

// DB contains information for current db connection
type DB struct {
sync.Mutex
Value interface{}
Error error
RowsAffected int64
Expand Down Expand Up @@ -170,7 +171,8 @@ func (s *DB) HasBlockGlobalUpdate() bool {

// SingularTable use singular table by default
func (s *DB) SingularTable(enable bool) {
modelStructsMap = sync.Map{}
s.parent.Lock()
defer s.parent.Unlock()
s.parent.singularTable = enable
}

Expand Down
33 changes: 29 additions & 4 deletions main_test.go
Expand Up @@ -9,6 +9,7 @@ import (
"reflect"
"strconv"
"strings"
"sync"
"testing"
"time"

Expand Down Expand Up @@ -277,6 +278,30 @@ func TestTableName(t *testing.T) {
DB.SingularTable(false)
}

func TestTableNameConcurrently(t *testing.T) {
DB := DB.Model("")
if DB.NewScope(Order{}).TableName() != "orders" {
t.Errorf("Order's table name should be orders")
}

var wg sync.WaitGroup
wg.Add(10)

for i := 1; i <= 10; i++ {
go func(db *gorm.DB) {
DB.SingularTable(true)
wg.Done()
}(DB)
}
wg.Wait()

if DB.NewScope(Order{}).TableName() != "order" {
t.Errorf("Order's singular table name should be order")
}

DB.SingularTable(false)
}

func TestNullValues(t *testing.T) {
DB.DropTable(&NullValue{})
DB.AutoMigrate(&NullValue{})
Expand Down Expand Up @@ -1066,12 +1091,12 @@ func TestCountWithHaving(t *testing.T) {

DB.Create(getPreparedUser("user1", "pluck_user"))
DB.Create(getPreparedUser("user2", "pluck_user"))
user3:=getPreparedUser("user3", "pluck_user")
user3.Languages=[]Language{}
user3 := getPreparedUser("user3", "pluck_user")
user3.Languages = []Language{}
DB.Create(user3)

var count int
err:=db.Model(User{}).Select("users.id").
err := db.Model(User{}).Select("users.id").
Joins("LEFT JOIN user_languages ON user_languages.user_id = users.id").
Joins("LEFT JOIN languages ON user_languages.language_id = languages.id").
Group("users.id").Having("COUNT(languages.id) > 1").Count(&count).Error
Expand All @@ -1080,7 +1105,7 @@ func TestCountWithHaving(t *testing.T) {
t.Error("Unexpected error on query count with having")
}

if count!=2{
if count != 2 {
t.Error("Unexpected result on query count with having")
}
}
Expand Down
17 changes: 15 additions & 2 deletions model_struct.go
Expand Up @@ -40,9 +40,11 @@ func (s *ModelStruct) TableName(db *DB) string {
s.defaultTableName = tabler.TableName()
} else {
tableName := ToTableName(s.ModelType.Name())
db.parent.Lock()
if db == nil || (db.parent != nil && !db.parent.singularTable) {
tableName = inflection.Plural(tableName)
}
db.parent.Unlock()
s.defaultTableName = tableName
}
}
Expand Down Expand Up @@ -163,7 +165,18 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}

// Get Cached model struct
if value, ok := modelStructsMap.Load(reflectType); ok && value != nil {
isSingularTable := false
if scope.db != nil && scope.db.parent != nil {
scope.db.parent.Lock()
isSingularTable = scope.db.parent.singularTable
scope.db.parent.Unlock()
}

hashKey := struct {
singularTable bool
reflectType reflect.Type
}{isSingularTable, reflectType}
if value, ok := modelStructsMap.Load(hashKey); ok && value != nil {
return value.(*ModelStruct)
}

Expand Down Expand Up @@ -612,7 +625,7 @@ func (scope *Scope) GetModelStruct() *ModelStruct {
}
}

modelStructsMap.Store(reflectType, &modelStruct)
modelStructsMap.Store(hashKey, &modelStruct)

return &modelStruct
}
Expand Down

0 comments on commit 5959487

Please sign in to comment.