/
list.go
132 lines (120 loc) · 3.27 KB
/
list.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package blocklist
import (
"bufio"
"io"
"net/http"
"strings"
"time"
"github.com/coredns/coredns/plugin"
"github.com/coredns/coredns/plugin/metrics"
"github.com/mholt/caddy"
"github.com/miekg/dns"
"github.com/prometheus/client_golang/prometheus"
)
// ListDB is the storage backend that List.Run reads from and writes to.
//
// LastFetched returns the fetch time associated with source; Run uses it to
// decide how long to wait before the first download (presumably the zero
// time when source has never been stored — confirm against implementations).
// Update stores the names downloaded from source together with the time the
// download started.
type ListDB interface {
	LastFetched(source string) time.Time
	Update(source string, fetched time.Time, names []string) error
}
// List is one remote blocklist together with its download schedule.
type List struct {
	source string // location of the list; List.Run passes it to an HTTP GET

	refresh time.Duration // wait before the next fetch after a successful one
	retry   time.Duration // wait before the next fetch after a failed one
	expire  time.Duration // not read by the code in this file; presumably a staleness limit — confirm
}
// NewList returns a new List representing the blocklist at source.
func NewList(source string) *List {
return &List{
source: source,
refresh: 2 * 24 * time.Hour,
retry: time.Hour,
expire: 7 * 24 * time.Hour,
}
}
// Prometheus collectors for the blocklist plugin; listMetrics registers them.
var (
	// entries holds, per list, the number of names accepted by the most
	// recent successful download and parse.
	entries = prometheus.NewGaugeVec(
		prometheus.GaugeOpts{
			Name:      "list_size",
			Help:      "count of names on blocklist",
			Namespace: plugin.Namespace,
			Subsystem: "blocklist",
		},
		[]string{"list"},
	)

	// fetches counts download attempts per list, labelled by outcome
	// ("OK", "http_client_error", "http_server_error" or "parse_error").
	fetches = prometheus.NewCounterVec(
		prometheus.CounterOpts{
			Name:      "fetch",
			Help:      "count of blocklist fetches",
			Namespace: plugin.Namespace,
			Subsystem: "blocklist",
		},
		[]string{"list", "result"},
	)
)
// listMetrics registers the blocklist collectors, entries first and then
// fetches, through metrics.MustRegister for the controller c.
func listMetrics(c *caddy.Controller) {
	for _, col := range []prometheus.Collector{entries, fetches} {
		metrics.MustRegister(c, col)
	}
}
// Run periodically downloads the blocklist and updates the internal database.
func (l *List) Run(db ListDB, stop <-chan struct{}, poke chan<- struct{}) {
delay := l.refresh - time.Now().Sub(db.LastFetched(l.source))
for {
if delay > 0 {
select {
case <-stop:
return
case <-time.Tick(delay):
}
}
now := time.Now()
delay = l.retry
// TODO(miki): retain etags?
resp, err := http.Get(l.source)
if err != nil {
fetches.WithLabelValues(l.source, "http_client_error").Inc()
log.Errorf("blocklist GET %q: %q", l.source, err)
continue
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
fetches.WithLabelValues(l.source, "http_server_error").Inc()
log.Errorf("blocklist GET %q: %q", l.source, resp.Status)
continue
}
blocked, err := listRead(resp.Body)
if err != nil {
fetches.WithLabelValues(l.source, "parse_error").Inc()
log.Errorf("blocklist parse %q: %q", l.source, err)
continue
}
fetches.WithLabelValues(l.source, "OK").Inc()
entries.WithLabelValues(l.source).Set(float64(len(blocked)))
if err := db.Update(l.source, now, blocked); err != nil {
log.Errorf("blocklist update %q: %q", l.source, err)
continue
}
delay = l.refresh
select {
case poke <- struct{}{}:
default:
}
}
}
// listRead parses a blocklist read from r. It returns the domain names found,
// each made fully qualified (trailing dot) with dns.Fqdn.
//
// Two line formats are understood:
//   - single column: "ads.example.com"
//   - double column, hosts-file style: "0.0.0.0 ads.example.com" (only the
//     second column is kept)
//
// Everything from a '#' to the end of a line is a comment and is ignored, so
// both full-line and trailing comments work. Lines with any other number of
// fields (including blank lines) are skipped. A name is kept only if its FQDN
// form contains more than two dots. For example, "ads.example.com." is kept,
// while "com." and "example.com." are skipped.
//
// The error is the scanner's error (nil at a clean EOF). Names parsed before an
// error are still returned alongside it.
func listRead(r io.Reader) ([]string, error) {
	var blocked []string
	scanner := bufio.NewScanner(r)
	for scanner.Scan() {
		line := scanner.Text()
		// Strip comments. This also covers indented and trailing comments. A
		// '#' can never be part of a hostname, so cutting there is safe.
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}

		var domain string
		switch flds := strings.Fields(line); len(flds) {
		case 1:
			domain = dns.Fqdn(flds[0])
		case 2:
			domain = dns.Fqdn(flds[1])
		default:
			continue // blank, comment-only, or unrecognised line
		}

		// Guard against blocking too broadly: names with two or fewer dots
		// (a TLD such as "com." or a second-level name such as
		// "example.com.") are skipped.
		// NOTE(review): this also drops second-level domains; confirm intended.
		if strings.Count(domain, ".") <= 2 {
			continue
		}
		blocked = append(blocked, domain)
	}
	return blocked, scanner.Err()
}