diff --git a/concurrent_map.go b/concurrent_map.go index 72428a4..68925b0 100644 --- a/concurrent_map.go +++ b/concurrent_map.go @@ -6,10 +6,10 @@ import ( "sync" ) -var SHARD_COUNT = 32 +const SHARD_COUNT = 32 // A "thread" safe map of type string:Anything. -// To avoid lock bottlenecks this map is dived to several (SHARD_COUNT) map shards. +// To avoid lock bottlenecks this map is dived to several (len(m)) map shards. type ConcurrentMap []*ConcurrentMapShared // A "thread" safe string to anything map. @@ -19,9 +19,15 @@ type ConcurrentMapShared struct { } // Creates a new concurrent map. -func New() ConcurrentMap { - m := make(ConcurrentMap, SHARD_COUNT) - for i := 0; i < SHARD_COUNT; i++ { +func New(shardCount ...int) ConcurrentMap { + var nShards int + if len(shardCount) > 0 { + nShards = shardCount[0] + } else { + nShards = SHARD_COUNT + } + m := make(ConcurrentMap, nShards) + for i := 0; i < nShards; i++ { m[i] = &ConcurrentMapShared{items: make(map[string]interface{})} } return m @@ -31,7 +37,7 @@ func New() ConcurrentMap { func (m ConcurrentMap) GetShard(key string) *ConcurrentMapShared { hasher := fnv.New32() hasher.Write([]byte(key)) - return m[uint(hasher.Sum32())%uint(SHARD_COUNT)] + return m[uint(hasher.Sum32())%uint(len(m))] } func (m ConcurrentMap) MSet(data map[string]interface{}) { @@ -79,7 +85,7 @@ func (m ConcurrentMap) Get(key string) (interface{}, bool) { // Returns the number of elements within the map. func (m ConcurrentMap) Count() int { count := 0 - for i := 0; i < SHARD_COUNT; i++ { + for i := 0; i < len(m); i++ { shard := m[i] shard.RLock() count += len(shard.items) @@ -126,7 +132,7 @@ func (m ConcurrentMap) Iter() <-chan Tuple { ch := make(chan Tuple) go func() { wg := sync.WaitGroup{} - wg.Add(SHARD_COUNT) + wg.Add(len(m)) // Foreach shard. for _, shard := range m { go func(shard *ConcurrentMapShared) { @@ -150,7 +156,7 @@ func (m ConcurrentMap) IterBuffered() <-chan Tuple { ch := make(chan Tuple, m.Count()) go func() { wg := sync.WaitGroup{} - wg.Add(SHARD_COUNT) + wg.Add(len(m)) // Foreach shard. for _, shard := range m { go func(shard *ConcurrentMapShared) { @@ -188,7 +194,7 @@ func (m ConcurrentMap) Keys() []string { go func() { // Foreach shard. wg := sync.WaitGroup{} - wg.Add(SHARD_COUNT) + wg.Add(len(m)) for _, shard := range m { go func(shard *ConcurrentMapShared) { // Foreach key, value pair. diff --git a/concurrent_map_bench_test.go b/concurrent_map_bench_test.go index 47cb8d8..7b8c0e6 100644 --- a/concurrent_map_bench_test.go +++ b/concurrent_map_bench_test.go @@ -3,6 +3,8 @@ package cmap import "testing" import "strconv" +var m ConcurrentMap + func BenchmarkItems(b *testing.B) { m := New() @@ -177,10 +179,8 @@ func GetSet(m ConcurrentMap, finished chan struct{}) (set func(key, value string } func runWithShards(bench func(b *testing.B), b *testing.B, shardsCount int) { - oldShardsCount := SHARD_COUNT - SHARD_COUNT = shardsCount + m = New(shardsCount) bench(b) - SHARD_COUNT = oldShardsCount } func BenchmarkKeys(b *testing.B) { diff --git a/concurrent_map_test.go b/concurrent_map_test.go index 2e794aa..49be012 100644 --- a/concurrent_map_test.go +++ b/concurrent_map_test.go @@ -271,12 +271,8 @@ func TestConcurrent(t *testing.T) { } func TestJsonMarshal(t *testing.T) { - SHARD_COUNT = 2 - defer func() { - SHARD_COUNT = 32 - }() expected := "{\"a\":1,\"b\":2}" - m := New() + m := New(2) m.Set("a", 1) m.Set("b", 2) j, err := json.Marshal(m)