/
reentrant.go
111 lines (81 loc) · 1.42 KB
/
reentrant.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package common
import (
"fmt"
"sync"
)
// pass is an ownership token for a reentrantmutex. Holding is tracked per
// pass (by pointer identity), so the same pass may Lock repeatedly without
// blocking itself, while a different pass has to wait.
type pass struct {
// remux is the mutex this pass was issued by and operates on.
remux *reentrantmutex
// c is the number of nested Lock calls currently held through this pass.
// Lock and Unlock modify it only while remux.mu is held.
c int
}
// reentrantmutex is a mutual-exclusion lock that can be re-acquired by the
// pass that already holds it. Ownership is identified by *pass, not by
// goroutine.
type reentrantmutex struct {
// mu guards current and the hold counters of the passes using this mutex.
// It is only held for the duration of short bookkeeping sections.
mu sync.Mutex
// current is the pass that holds the lock, or nil when the lock is free.
current *pass
}
// NewRentrantMutex returns a reentrantmutex in its initial, unlocked state:
// no pass holds it and its internal sync.Mutex is unlocked.
//
// The result is returned by value and contains a sync.Mutex, which must not
// be copied after first use. Store the result once and work with it through
// a pointer (its methods have pointer receivers).
func NewRentrantMutex() reentrantmutex {
	// The zero value is exactly the initial state: an unlocked sync.Mutex
	// and a nil current holder.
	return reentrantmutex{}
}
// NewPass issues a fresh pass bound to this mutex. The returned pass starts
// with a hold count of zero and does not hold the lock; call Lock on it to
// acquire. Each call returns a distinct pass.
func (m *reentrantmutex) NewPass() *pass {
	p := new(pass) // hold count starts at its zero value
	p.remux = m
	return p
}
// Lock acquires the owning mutex on behalf of this pass and blocks until it
// succeeds.
//
//   - If the mutex is free, the pass becomes the holder with a hold count
//     of 1.
//   - If the pass already holds the mutex, the call is re-entrant: it returns
//     immediately and increments the hold count. Every Lock must therefore be
//     paired with an Unlock.
//   - If another pass holds the mutex, Lock busy-waits, re-checking the holder
//     under the internal sync.Mutex until the mutex becomes free. Waiting
//     consumes CPU; keep critical sections short.
func (p *pass) Lock() {
	m := p.remux
	for {
		m.mu.Lock()
		switch m.current {
		case nil:
			// Fresh acquisition: set the count instead of incrementing it.
			// A count left over from an earlier hold (for example after the
			// mutex was force-released) must not carry over, otherwise the
			// matching Unlock would never bring it back to zero and the
			// mutex would stay held forever.
			m.current = p
			p.c = 1
			m.mu.Unlock()
			return
		case p:
			// Re-entrant acquisition by the current holder.
			p.c++
			m.mu.Unlock()
			return
		}
		// Held by a different pass: release the bookkeeping lock so the
		// holder can make progress, then poll again.
		m.mu.Unlock()
	}
}
// UnlockNow forcibly releases the mutex regardless of which pass holds it
// and how many nested acquisitions are outstanding. It is a no-op when the
// mutex is already free.
//
// The evicted pass's hold count is reset to zero. Without that reset, the
// pass would keep a stale count, and its next Lock/Unlock pair would never
// bring the count back to zero, leaving the mutex held forever.
func (m *reentrantmutex) UnlockNow() {
	m.mu.Lock()
	defer m.mu.Unlock()

	if holder := m.current; holder != nil {
		holder.c = 0
		m.current = nil
	}
}
// Unlock releases one level of this pass's hold on the mutex.
//
//   - If the mutex is free, Unlock is a no-op.
//   - If this pass is the holder, its hold count is decremented and the
//     mutex is released once the count reaches zero.
//   - If a different pass is the holder, Unlock panics with an error whose
//     message is "invalid pass".
//
// The internal sync.Mutex is released on every path, including the panic,
// so a caller that recovers from the panic does not leave the mutex's
// bookkeeping lock permanently held.
func (p *pass) Unlock() {
	m := p.remux
	m.mu.Lock()
	defer m.mu.Unlock()

	switch m.current {
	case nil:
		// Nothing is held; nothing to release.
	case p:
		p.c--
		if p.c == 0 {
			m.current = nil
		}
	default:
		panic(fmt.Errorf("invalid pass"))
	}
}
// HasLock reports whether any pass currently holds the mutex. The answer is
// a snapshot taken under the internal sync.Mutex and may be outdated by the
// time the caller acts on it.
func (m *reentrantmutex) HasLock() bool {
	m.mu.Lock()
	held := m.current != nil
	m.mu.Unlock()
	return held
}