-
Notifications
You must be signed in to change notification settings - Fork 73
/
multi_ctx_rwlock.go
137 lines (121 loc) · 3 KB
/
multi_ctx_rwlock.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
package ctxmu
import (
"context"
"sync"
"sync/atomic"
"time"
)
// MultiCtxLocker is a family of exclusive locks addressed by key.
//
// Lock acquires the lock for key, waiting no longer than ctx allows; a
// non-nil error means the lock was not acquired. Unlock releases a lock
// previously acquired for key.
type MultiCtxLocker interface {
	Lock(ctx context.Context, key interface{}) error
	Unlock(key interface{})
}
// MultiCtxRWLocker extends MultiCtxLocker with shared (read) locking per key.
//
// RLock acquires a shared lock for key, waiting no longer than ctx allows;
// a non-nil error means the lock was not acquired. RUnlock releases a
// shared lock previously acquired for key.
type MultiCtxRWLocker interface {
	MultiCtxLocker
	RLock(ctx context.Context, key interface{}) error
	RUnlock(key interface{})
}
// MultiCtxRWMutex hands out one context-aware read/write lock per key.
// Per-key entries are reference counted, created on first use and removed
// (with their lock recycled through pool) once the last user releases them.
//
// Construct it with NewMultiCtxRWMutex or NewDefaultMultiCtxRWMutex: the
// zero value has no pool factory. It embeds sync types and must not be
// copied after first use.
type MultiCtxRWMutex struct {
	// pool recycles per-key lock values between entries.
	pool sync.Pool
	// locks maps each key to its *ctxRWLockRefCounter.
	locks sync.Map
}
// NewDefaultMultiCtxRWMutex returns a MultiCtxRWMutex whose per-key locks
// are freshly allocated *CtxRWMutex values.
func NewDefaultMultiCtxRWMutex() *MultiCtxRWMutex {
	newLock := func() CtxRWLocker { return &CtxRWMutex{} }
	return NewMultiCtxRWMutex(newLock)
}
// NewMultiCtxRWMutex returns a MultiCtxRWMutex that draws its per-key locks
// from a sync.Pool. newCtxRWLock is invoked whenever the pool has no
// recycled lock to hand out.
func NewMultiCtxRWMutex(newCtxRWLock func() CtxRWLocker) *MultiCtxRWMutex {
	m := new(MultiCtxRWMutex)
	m.pool.New = func() interface{} {
		return newCtxRWLock()
	}
	return m
}
// ctxRWLockRefCounter is the entry stored per key in MultiCtxRWMutex.locks:
// a lock plus the number of goroutines currently using (holding or waiting
// on) it.
type ctxRWLockRefCounter struct {
	// count is the reference count, accessed with sync/atomic. A value of -1
	// marks an entry that has been retired and is being removed from the map.
	// Keep this int64 as the first field: the first word of an allocated
	// struct is the only one guaranteed 64-bit aligned on 32-bit platforms,
	// which 64-bit atomic operations require.
	count int64
	// lock is the per-key lock, obtained from MultiCtxRWMutex.pool.
	lock CtxRWLocker
}
// Lock acquires the exclusive lock for key, waiting until it is held or ctx
// is done. If acquisition fails, the reference taken on key's entry is
// dropped again and the error is returned.
func (m *MultiCtxRWMutex) Lock(ctx context.Context, key interface{}) (err error) {
	var counter *ctxRWLockRefCounter
	if counter, err = m.incrGetRWLockRefCounter(ctx, key); err != nil {
		return err
	}
	if err = counter.lock.Lock(ctx); err != nil {
		// Not acquired: give back our reference so the entry can be recycled.
		m.decrPutRWLockRefCounter(key, counter)
	}
	return err
}
// LockWithTimout is Lock bounded by a context that expires after timeout.
// (The method name's spelling is kept for API compatibility.)
func (m *MultiCtxRWMutex) LockWithTimout(timeout time.Duration, key interface{}) (err error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return m.Lock(ctx, key)
}
// Unlock releases the exclusive lock for key acquired by a successful Lock
// and drops that caller's reference on the entry. It panics if key has no
// registered entry.
func (m *MultiCtxRWMutex) Unlock(key interface{}) {
	c := m.mustGetCounter(key)
	c.lock.Unlock()
	m.decrPutRWLockRefCounter(key, c)
}
// RLock acquires a shared lock for key, waiting until it is held or ctx is
// done. If acquisition fails, the reference taken on key's entry is dropped
// again and the error is returned.
func (m *MultiCtxRWMutex) RLock(ctx context.Context, key interface{}) (err error) {
	var counter *ctxRWLockRefCounter
	if counter, err = m.incrGetRWLockRefCounter(ctx, key); err != nil {
		return err
	}
	if err = counter.lock.RLock(ctx); err != nil {
		// Not acquired: give back our reference so the entry can be recycled.
		m.decrPutRWLockRefCounter(key, counter)
	}
	return err
}
// RLockWithTimout is RLock bounded by a context that expires after timeout.
// (The method name's spelling is kept for API compatibility.)
func (m *MultiCtxRWMutex) RLockWithTimout(timeout time.Duration, key interface{}) (err error) {
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()
	return m.RLock(ctx, key)
}
// RUnlock releases a shared lock for key acquired by a successful RLock and
// drops that caller's reference on the entry. It panics if key has no
// registered entry.
func (m *MultiCtxRWMutex) RUnlock(key interface{}) {
	c := m.mustGetCounter(key)
	c.lock.RUnlock()
	m.decrPutRWLockRefCounter(key, c)
}
// mustGetCounter returns the entry registered for key. A missing entry means
// an unlock was attempted without a matching successful lock, so it panics.
func (m *MultiCtxRWMutex) mustGetCounter(key interface{}) (counter *ctxRWLockRefCounter) {
	v, ok := m.locks.Load(key)
	if ok {
		return v.(*ctxRWLockRefCounter)
	}
	panic("key's lock has been invalidly freed")
}
// incrGetRWLockRefCounter returns the live entry for key, creating and
// registering one if none exists, and increments its reference count so the
// entry cannot be retired while the caller uses it.
//
// An entry whose count is negative has been retired by
// decrPutRWLockRefCounter and is about to be deleted from m.locks; such
// entries are skipped and the lookup retried. ctx is checked before every
// attempt, so a caller retrying against a retired entry can give up. On
// error the returned counter is nil and err is ctx.Err().
func (m *MultiCtxRWMutex) incrGetRWLockRefCounter(ctx context.Context, key interface{}) (counter *ctxRWLockRefCounter, err error) {
	// spare is a freshly built entry that has not been published in m.locks.
	// It is reused across retries, and its lock goes back to the pool if it
	// is never stored, so losing a LoadOrStore race does not leak a lock.
	var spare *ctxRWLockRefCounter
	defer func() {
		if spare != nil {
			m.pool.Put(spare.lock)
		}
	}()

	for {
		if err = ctx.Err(); err != nil {
			return nil, err
		}

		// Fast path: reuse an existing entry without touching the pool.
		actual, ok := m.locks.Load(key)
		if !ok {
			if spare == nil {
				// Assert to the interface: the pool's New returns whatever the
				// factory given to NewMultiCtxRWMutex builds, which need not
				// be a *CtxRWMutex.
				spare = &ctxRWLockRefCounter{lock: m.pool.Get().(CtxRWLocker)}
			}
			var loaded bool
			actual, loaded = m.locks.LoadOrStore(key, spare)
			if !loaded {
				spare = nil // ownership passed to m.locks
			}
		}

		c := actual.(*ctxRWLockRefCounter)
		// Atomic read: count is concurrently modified with atomic operations.
		old := atomic.LoadInt64(&c.count)
		if old < 0 {
			// Entry is being retired; retry until it is deleted and a fresh
			// one can be stored.
			continue
		}
		if atomic.CompareAndSwapInt64(&c.count, old, old+1) {
			return c, nil
		}
	}
}
// decrPutRWLockRefCounter drops one reference to counter, which must be the
// entry currently stored under key. If the count is then zero, the entry is
// retired: its count is set to -1 so incrGetRWLockRefCounter will skip it,
// its lock is returned to the pool, and the entry is deleted from m.locks.
func (m *MultiCtxRWMutex) decrPutRWLockRefCounter(key interface{}, counter *ctxRWLockRefCounter) {
	atomic.AddInt64(&counter.count, -1)
	// Only one goroutine can win the 0 -> -1 transition. If another caller
	// re-acquired the entry after the decrement, the count is no longer 0,
	// the CAS fails and the entry stays live.
	if atomic.CompareAndSwapInt64(&counter.count, 0, -1) {
		m.pool.Put(counter.lock)
		// The retired entry stays in the map until this Delete, so no new
		// entry can have been stored under key in the meantime.
		m.locks.Delete(key)
	}
}