-
Notifications
You must be signed in to change notification settings - Fork 0
/
cacheqtype.go
128 lines (116 loc) · 2.45 KB
/
cacheqtype.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package recursive
import (
"net/netip"
"sync"
"time"
"github.com/miekg/dns"
)
// cacheQtype caches DNS responses keyed by question name (presumably one
// instance per query type, judging by the name — confirm with callers).
// Each qname maps to a list holding at most one entry per nameserver
// address, as maintained by set.
type cacheQtype struct {
	// cache maps qname to the per-nameserver cached responses; guarded by mu.
	cache map[string][]cacheValue
	// mu guards cache.
	mu sync.RWMutex
}
// newCacheQtype returns an empty cacheQtype with its map allocated, ready
// for use.
func newCacheQtype() *cacheQtype {
	cq := new(cacheQtype)
	cq.cache = map[string][]cacheValue{}
	return cq
}
// entries reports the total number of cached values across all qnames,
// counting expired entries that have not yet been evicted.
func (cq *cacheQtype) entries() (n int) {
	cq.mu.RLock()
	defer cq.mu.RUnlock()
	total := 0
	for qname := range cq.cache {
		total += len(cq.cache[qname])
	}
	return total
}
// set caches msg as the response obtained from nameserver nsaddr for the
// name in msg.Question[0], valid for ttl seconds from now. If an entry for
// that name and nameserver already exists it is overwritten in place;
// otherwise a new entry is appended. msg must carry at least one question.
func (cq *cacheQtype) set(nsaddr netip.Addr, msg *dns.Msg, ttl int) {
	name := msg.Question[0].Name
	deadline := time.Now().Add(time.Second * time.Duration(ttl))

	cq.mu.Lock()
	defer cq.mu.Unlock()

	list := cq.cache[name]
	for pos := range list {
		if list[pos].nsaddr != nsaddr {
			continue
		}
		list[pos].Msg, list[pos].expires = msg, deadline
		return
	}
	cq.cache[name] = append(list, cacheValue{
		nsaddr:  nsaddr,
		Msg:     msg,
		expires: deadline,
	})
}
// find returns the index of the entry in cvl whose nsaddr equals addr.
// The zero (invalid) netip.Addr acts as a wildcard: an exact match is still
// preferred, but failing that index 0 is returned for a non-empty list.
// It returns -1 when nothing matches. qname is currently unused.
func find(cvl []cacheValue, addr netip.Addr, qname string) (idx int) {
	for pos := range cvl {
		if cvl[pos].nsaddr == addr {
			return pos
		}
	}
	if !addr.IsValid() && len(cvl) > 0 {
		return 0
	}
	return -1
}
// getExisting returns a copy of the entry for qname selected by find with
// addr, ignoring expiry, or the zero cacheValue if there is none. It takes
// the read lock only.
func (cq *cacheQtype) getExisting(addr netip.Addr, qname string) (cv cacheValue) {
	cq.mu.RLock()
	defer cq.mu.RUnlock()
	list := cq.cache[qname]
	pos := find(list, addr, qname)
	if pos < 0 {
		return cacheValue{}
	}
	return list[pos]
}
// get looks up the cached response for qname from nameserver addr. When addr
// is the zero netip.Addr, an entry from any nameserver may be selected (see
// find). It returns the nameserver address and message of a live entry, or
// the zero netip.Addr and nil on a miss. If the selected entry has expired
// it is evicted and a miss is reported.
func (cq *cacheQtype) get(addr netip.Addr, qname string) (netip.Addr, *dns.Msg) {
	cv := cq.getExisting(addr, qname)
	if cv.Msg == nil {
		return netip.Addr{}, nil
	}
	if time.Since(cv.expires) < 0 {
		return cv.nsaddr, cv.Msg
	}

	// The entry looked expired under the read lock. RWMutex cannot be
	// upgraded, so between RUnlock and Lock another goroutine may have
	// refreshed this entry via set, or deleted/reordered entries. Re-locate
	// the exact entry we saw (by its nsaddr, not the possibly-wildcard addr)
	// and evict it only if it is still the same nameserver and still expired.
	cq.mu.Lock()
	defer cq.mu.Unlock()
	cvl := cq.cache[qname]
	idx := find(cvl, cv.nsaddr, qname)
	if idx >= 0 && cvl[idx].nsaddr == cv.nsaddr && time.Since(cvl[idx].expires) >= 0 {
		if cvl = cq.deleteLocked(cvl, idx); len(cvl) > 0 {
			cq.cache[qname] = cvl
		} else {
			delete(cq.cache, qname)
		}
	}
	return netip.Addr{}, nil
}
// clear evicts every cached entry regardless of expiry. It passes the zero
// time.Time, which clean treats as "everything is expired".
func (cq *cacheQtype) clear() {
	var everything time.Time
	cq.clean(everything)
}
// clean evicts every entry that has expired as of now and drops qnames whose
// lists become empty. A zero now evicts all entries unconditionally. An entry
// counts as expired when now is not before its expiry, matching get.
func (cq *cacheQtype) clean(now time.Time) {
	cq.mu.Lock()
	defer cq.mu.Unlock()
	for qname, cvl := range cq.cache {
		// Walk indices downward: deleteLocked moves the last element into the
		// vacated slot, and every element after i has already been examined,
		// so no entry is skipped. (The previous loop re-checked only the
		// current last element, so expired entries in front of a live one
		// were never evicted.)
		for i := len(cvl) - 1; i >= 0; i-- {
			if now.IsZero() || !now.Before(cvl[i].expires) {
				cvl = cq.deleteLocked(cvl, i)
			}
		}
		// Assigning to or deleting the current key during range is permitted.
		if len(cvl) > 0 {
			cq.cache[qname] = cvl
		} else {
			delete(cq.cache, qname)
		}
	}
}
// deleteLocked removes cvl[idx] and returns the shortened slice; the caller
// must hold cq.mu for writing and store the result back. Order is not
// preserved: the last element is moved into idx. The vacated tail slot is
// zeroed so the backing array stops referencing the removed entry.
//
// The removed *dns.Msg itself is deliberately left untouched. get returns
// that same pointer to callers, who may still be reading it after the cache
// lock is released. Clearing its Question/Answer/Ns/Extra slices here would
// race with them and leave nil RRs in their message. An unreferenced Msg is
// reclaimed by the GC without any help.
func (cq *cacheQtype) deleteLocked(cvl []cacheValue, idx int) []cacheValue {
	last := len(cvl) - 1
	cvl[idx] = cvl[last]
	cvl[last] = cacheValue{}
	return cvl[:last]
}