-
Notifications
You must be signed in to change notification settings - Fork 71
/
blocklist.go
125 lines (108 loc) · 2.41 KB
/
blocklist.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
package blocklist
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
"io"
"net"
"sync"
"github.com/cenkalti/rain/internal/blocklist/stree"
)
// errNotIPv4Address is returned by parseCIDR when the parsed network address
// or its mask is not 4 bytes long.
var errNotIPv4Address = errors.New("address is not ipv4")
// Blocklist holds a list of IP ranges in a Segment Tree structure for faster lookups.
type Blocklist struct {
// Logger, when non-nil, is called for every input line that cannot be
// parsed while Reload is loading rules.
Logger Logger
// tree is the structure queried by Blocked; Reload replaces it.
tree stree.Stree
// m is locked by Len, Blocked and Reload around their use of tree and count.
m sync.RWMutex
// count is the number of rules stored by the last successful Reload.
count int
}
// Logger is a printf-style logging function: it receives a format string
// followed by the values to be formatted.
type Logger func(format string, v ...interface{})
// New allocates an empty Blocklist and returns a pointer to it.
func New() *Blocklist {
	return new(Blocklist)
}
// Len reports how many rules the Blocklist currently holds.
func (b *Blocklist) Len() int {
	b.m.RLock()
	n := b.count
	b.m.RUnlock()
	return n
}
// Blocked reports whether ip falls inside one of the ranges held by the
// Blocklist. An address that has no 4-byte form is never reported as blocked.
func (b *Blocklist) Blocked(ip net.IP) bool {
	b.m.RLock()
	defer b.m.RUnlock()

	v4 := ip.To4()
	if v4 == nil {
		return false
	}
	// The tree is keyed by the address read as a big-endian 32-bit integer.
	key := stree.ValueType(binary.BigEndian.Uint32(v4))
	return b.tree.Contains(key)
}
// Reload replaces the rules of the Blocklist with the ones read from r and
// returns the number of rules that were loaded.
//
// The input is read and parsed before the lock is taken, so concurrent calls
// to Blocked and Len are only held up for the final swap and not for as long
// as r takes to deliver its data. If loading fails, the rules currently in
// place are left untouched and the error is returned.
func (b *Blocklist) Reload(r io.Reader) (int, error) {
	tree, n, err := load(r, b.Logger)
	if err != nil {
		return n, err
	}

	b.m.Lock()
	defer b.m.Unlock()
	b.tree = *tree
	b.count = n
	return n, nil
}
// load reads rules from r, one per line, and returns a built tree together
// with the number of rules added to it.
//
// Lines that are empty after trimming surrounding whitespace, and lines whose
// first non-space character is '#', are skipped. A line that cannot be parsed
// is reported through logger (when logger is non-nil) and otherwise ignored.
// An error is returned when the scanner fails, or when at least one line was
// rejected and no line was accepted.
func load(r io.Reader, logger Logger) (*stree.Stree, int, error) {
	var (
		tree     stree.Stree
		count    int
		sawError bool
	)

	sc := bufio.NewScanner(r)
	for sc.Scan() {
		line := bytes.TrimSpace(sc.Bytes())
		if len(line) == 0 || line[0] == '#' {
			continue
		}

		rng, err := parseCIDR(line)
		if err == nil {
			tree.AddRange(stree.ValueType(rng.first), stree.ValueType(rng.last))
			count++
			continue
		}

		sawError = true
		if logger != nil {
			logger("cannot parse blocklist line (%q): %q", string(line), err.Error())
		}
	}
	if err := sc.Err(); err != nil {
		return nil, 0, err
	}

	if sawError && count == 0 {
		// Every non-comment line was rejected, so the stream is most likely
		// not in the expected format. Require at least one valid rule before
		// treating the load as successful.
		return nil, 0, errors.New("no valid rules")
	}

	tree.Build()
	return &tree, count, nil
}
// ipRange is a pair of IPv4 addresses, each stored as the big-endian 32-bit
// integer value of its 4 bytes.
type ipRange struct {
	first uint32 // lowest address of the range
	last  uint32 // highest address of the range
}
// parseCIDR parses b as a network in CIDR notation and returns the first and
// last addresses of that network as big-endian 32-bit values.
//
// An error from net.ParseCIDR is returned unchanged. errNotIPv4Address is
// returned when the parsed network address or its mask is not 4 bytes long.
// In both error cases the returned range is the zero value.
func parseCIDR(b []byte) (r ipRange, err error) {
	_, network, err := net.ParseCIDR(string(b))
	switch {
	case err != nil:
		return ipRange{}, err
	case len(network.IP) != 4, len(network.Mask) != 4:
		return ipRange{}, errNotIPv4Address
	}

	start := binary.BigEndian.Uint32(network.IP)
	// Inverting the mask leaves only the host bits set; OR-ing them onto the
	// network address gives the highest address in the network.
	hostBits := ^binary.BigEndian.Uint32(network.Mask)
	return ipRange{first: start, last: start | hostBits}, nil
}