diff --git a/pkg/config/manager.go b/pkg/config/manager.go index 7e8c100255a1..ad5b723c6126 100644 --- a/pkg/config/manager.go +++ b/pkg/config/manager.go @@ -18,6 +18,7 @@ package config import ( "fmt" "strings" + "sync" "github.com/cockroachdb/errors" "go.uber.org/zap" @@ -84,7 +85,10 @@ type Manager struct { keySourceMap *typeutil.ConcurrentMap[string, string] // store the key to config source, example: key is A.B.C and source is file which means the A.B.C's value is from file overlays *typeutil.ConcurrentMap[string, string] // store the highest priority configs which modified at runtime forbiddenKeys *typeutil.ConcurrentSet[string] - configCache *typeutil.ConcurrentMap[string, interface{}] + + cacheMutex sync.RWMutex + configCache map[string]any + // configCache *typeutil.ConcurrentMap[string, interface{}] } func NewManager() *Manager { @@ -94,36 +98,50 @@ func NewManager() *Manager { keySourceMap: typeutil.NewConcurrentMap[string, string](), overlays: typeutil.NewConcurrentMap[string, string](), forbiddenKeys: typeutil.NewConcurrentSet[string](), - configCache: typeutil.NewConcurrentMap[string, interface{}](), + configCache: make(map[string]any), } resetConfigCacheFunc := NewHandler("reset.config.cache", func(event *Event) { keyToRemove := strings.NewReplacer("/", ".").Replace(event.Key) - manager.configCache.Remove(keyToRemove) + manager.EvictCachedValue(keyToRemove) }) manager.Dispatcher.RegisterForKeyPrefix("", resetConfigCacheFunc) return manager } func (m *Manager) GetCachedValue(key string) (interface{}, bool) { - return m.configCache.Get(key) + m.cacheMutex.RLock() + defer m.cacheMutex.RUnlock() + value, ok := m.configCache[key] + return value, ok } -func (m *Manager) SetCachedValue(key string, value interface{}) { - m.configCache.Insert(key, value) +func (m *Manager) CASCachedValue(key string, origin string, value interface{}) bool { + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + current, err := m.GetConfig(key) + if err != nil { + return false + } + if current != origin { + return false + } + m.configCache[key] = value + return true } func (m *Manager) EvictCachedValue(key string) { - m.configCache.Remove(key) + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + delete(m.configCache, key) } func (m *Manager) EvictCacheValueByFormat(keys ...string) { - set := typeutil.NewSet(keys...) - m.configCache.Range(func(key string, value interface{}) bool { - if set.Contain(formatKey(key)) { - m.configCache.Remove(key) - } - return true - }) + m.cacheMutex.Lock() + defer m.cacheMutex.Unlock() + + for _, key := range keys { + delete(m.configCache, key) + } } func (m *Manager) GetConfig(key string) (string, error) { diff --git a/pkg/config/manager_test.go b/pkg/config/manager_test.go index b955071661d6..a5c868c48e04 100644 --- a/pkg/config/manager_test.go +++ b/pkg/config/manager_test.go @@ -222,7 +222,7 @@ func TestCachedConfig(t *testing.T) { time.Sleep(time.Second) _, exist := mgr.GetCachedValue("a.b") assert.False(t, exist) - mgr.SetCachedValue("a.b", "aaa") + mgr.CASCachedValue("a.b", "aaa", "aaa") val, exist := mgr.GetCachedValue("a.b") assert.True(t, exist) assert.Equal(t, "aaa", val.(string)) @@ -237,10 +237,9 @@ func TestCachedConfig(t *testing.T) { { _, exist := mgr.GetCachedValue("c.d") assert.False(t, exist) - mgr.SetCachedValue("cd", "xxx") - val, exist := mgr.GetCachedValue("cd") - assert.True(t, exist) - assert.Equal(t, "xxx", val.(string)) + mgr.CASCachedValue("cd", "", "xxx") + _, exist = mgr.GetCachedValue("cd") + assert.False(t, exist) // after refresh, the cached value should be reset ctx := context.Background() diff --git a/pkg/util/paramtable/param_item.go b/pkg/util/paramtable/param_item.go index 50d3aca3cb48..37ec78e6c286 100644 --- a/pkg/util/paramtable/param_item.go +++ b/pkg/util/paramtable/param_item.go @@ -100,8 +100,9 @@ func (pi *ParamItem) GetAsStrings() []string { return strings } } - realStrs := getAsStrings(pi.GetValue()) - pi.manager.SetCachedValue(pi.Key, realStrs) + val := pi.GetValue() + realStrs := getAsStrings(val) + pi.manager.CASCachedValue(pi.Key, val, realStrs) return realStrs } @@ -111,8 +112,9 @@ func (pi *ParamItem) GetAsBool() bool { return boolVal } } - boolVal := getAsBool(pi.GetValue()) - pi.manager.SetCachedValue(pi.Key, boolVal) + val := pi.GetValue() + boolVal := getAsBool(val) + pi.manager.CASCachedValue(pi.Key, val, boolVal) return boolVal } @@ -122,8 +124,9 @@ func (pi *ParamItem) GetAsInt() int { return intVal } } - intVal := getAsInt(pi.GetValue()) - pi.manager.SetCachedValue(pi.Key, intVal) + val := pi.GetValue() + intVal := getAsInt(val) + pi.manager.CASCachedValue(pi.Key, val, intVal) return intVal } @@ -133,8 +136,9 @@ func (pi *ParamItem) GetAsInt32() int32 { return int32Val } } - int32Val := int32(getAsInt64(pi.GetValue())) - pi.manager.SetCachedValue(pi.Key, int32Val) + val := pi.GetValue() + int32Val := int32(getAsInt64(val)) + pi.manager.CASCachedValue(pi.Key, val, int32Val) return int32Val } @@ -144,8 +148,9 @@ func (pi *ParamItem) GetAsUint() uint { return uintVal } } - uintVal := uint(getAsUint64(pi.GetValue())) - pi.manager.SetCachedValue(pi.Key, uintVal) + val := pi.GetValue() + uintVal := uint(getAsUint64(val)) + pi.manager.CASCachedValue(pi.Key, val, uintVal) return uintVal } @@ -155,8 +160,9 @@ func (pi *ParamItem) GetAsUint32() uint32 { return uint32Val } } - uint32Val := uint32(getAsUint64(pi.GetValue())) - pi.manager.SetCachedValue(pi.Key, uint32Val) + val := pi.GetValue() + uint32Val := uint32(getAsUint64(val)) + pi.manager.CASCachedValue(pi.Key, val, uint32Val) return uint32Val } @@ -166,8 +172,9 @@ func (pi *ParamItem) GetAsUint64() uint64 { return uint64Val } } - uint64Val := getAsUint64(pi.GetValue()) - pi.manager.SetCachedValue(pi.Key, uint64Val) + val := pi.GetValue() + uint64Val := getAsUint64(val) + pi.manager.CASCachedValue(pi.Key, val, uint64Val) return uint64Val } @@ -177,8 +184,9 @@ func (pi *ParamItem) GetAsUint16() uint16 { return uint16Val } } - uint16Val := uint16(getAsUint64(pi.GetValue())) - pi.manager.SetCachedValue(pi.Key, uint16Val) + val := pi.GetValue() + uint16Val := uint16(getAsUint64(val)) + pi.manager.CASCachedValue(pi.Key, val, uint16Val) return uint16Val } @@ -188,8 +196,9 @@ func (pi *ParamItem) GetAsInt64() int64 { return int64Val } } - int64Val := getAsInt64(pi.GetValue()) - pi.manager.SetCachedValue(pi.Key, int64Val) + val := pi.GetValue() + int64Val := getAsInt64(val) + pi.manager.CASCachedValue(pi.Key, val, int64Val) return int64Val } @@ -199,8 +208,9 @@ func (pi *ParamItem) GetAsFloat() float64 { return floatVal } } - floatVal := getAsFloat(pi.GetValue()) - pi.manager.SetCachedValue(pi.Key, floatVal) + val := pi.GetValue() + floatVal := getAsFloat(val) + pi.manager.CASCachedValue(pi.Key, val, floatVal) return floatVal } @@ -210,8 +220,9 @@ func (pi *ParamItem) GetAsDuration(unit time.Duration) time.Duration { return durationVal } } - durationVal := getAsDuration(pi.GetValue(), unit) - pi.manager.SetCachedValue(pi.Key, durationVal) + val := pi.GetValue() + durationVal := getAsDuration(val, unit) + pi.manager.CASCachedValue(pi.Key, val, durationVal) return durationVal }