/
plugin_allow_ip.go
157 lines (147 loc) · 4.23 KB
/
plugin_allow_ip.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
package main
import (
"errors"
"fmt"
"io"
"net"
"strings"
"time"
iradix "github.com/hashicorp/go-immutable-radix"
"github.com/jedisct1/dlog"
"github.com/miekg/dns"
)
// PluginAllowedIP is the "allow_ip" plugin. It holds the rule set loaded by
// Init and the optional log sink used by Eval.
type PluginAllowedIP struct {
	// allowedPrefixes stores the wildcard rules (those written with a
	// trailing "*"), keyed by the text that precedes the wildcard.
	allowedPrefixes *iradix.Tree
	// allowedIPs stores the rules written without a wildcard; only the keys
	// are consulted.
	allowedIPs map[string]interface{}
	// logger receives one line per allowed response. It stays nil when no
	// log file is configured.
	logger io.Writer
	// format selects the log line layout; Eval understands "tsv" and "ltsv".
	format string
}
// Name returns the short identifier of this plugin.
func (plugin *PluginAllowedIP) Name() string {
	return "allow_ip"
}
// Description returns a one-line, human-readable summary of the plugin.
func (plugin *PluginAllowedIP) Description() string {
	return "Allows DNS queries containing specific IP addresses"
}
// Init loads the allowed IP rules from proxy.allowedIPFile and, when
// proxy.allowedIPLogFile is set, prepares the logger used by Eval.
//
// Each non-empty line (after comment stripping) is one rule:
//   - a complete IPv4/IPv6 address is stored in its canonical textual form
//     (net.IP.String), which is the same form Eval derives from the answer
//     records, so alternative spellings of an address still match;
//   - a rule ending in "*" is a prefix rule: the wildcard and one optional
//     trailing "." or ":" are removed and the remainder goes into the
//     prefix tree;
//   - any other text is stored as-is (lowercased) as an exact-match key.
//
// Malformed rules are reported through dlog with their 1-based line number
// and skipped; they do not abort loading. The only error returned is the
// one produced while reading the rules file.
func (plugin *PluginAllowedIP) Init(proxy *Proxy) error {
	dlog.Noticef("Loading the set of allowed IP rules from [%s]", proxy.allowedIPFile)
	lines, err := ReadTextFile(proxy.allowedIPFile)
	if err != nil {
		return err
	}
	plugin.allowedPrefixes = iradix.New()
	plugin.allowedIPs = make(map[string]interface{})
	for idx, line := range strings.Split(lines, "\n") {
		// range yields a 0-based index; people count lines from 1.
		lineNo := idx + 1
		line = TrimAndStripInlineComments(line)
		if len(line) == 0 {
			continue
		}
		ip := net.ParseIP(line)
		trailingStar := strings.HasSuffix(line, "*")
		if len(line) < 2 || (ip != nil && trailingStar) {
			dlog.Errorf("Suspicious allowed IP rule [%s] at line %d", line, lineNo)
			continue
		}
		if ip != nil {
			// A complete address: key it by its canonical form so that
			// spellings such as "2001:0DB8::1" or "fe80::" match what Eval
			// looks up. Stripping the trailing ":" below would otherwise
			// corrupt addresses that end in "::".
			plugin.allowedIPs[ip.String()] = true
			continue
		}
		if trailingStar {
			line = line[:len(line)-1]
		}
		if strings.HasSuffix(line, ":") || strings.HasSuffix(line, ".") {
			line = line[:len(line)-1]
		}
		if len(line) == 0 {
			dlog.Errorf("Empty allowed IP rule at line %d", lineNo)
			continue
		}
		if strings.Contains(line, "*") {
			dlog.Errorf("Invalid rule: [%s] - wildcards can only be used as a suffix at line %d", line, lineNo)
			continue
		}
		line = strings.ToLower(line)
		if trailingStar {
			plugin.allowedPrefixes, _, _ = plugin.allowedPrefixes.Insert([]byte(line), 0)
		} else {
			plugin.allowedIPs[line] = true
		}
	}
	if len(proxy.allowedIPLogFile) == 0 {
		// Logging is optional; leave plugin.logger nil.
		return nil
	}
	plugin.logger = Logger(proxy.logMaxSize, proxy.logMaxAge, proxy.logMaxBackups, proxy.allowedIPLogFile)
	plugin.format = proxy.allowedIPFormat
	return nil
}
// Drop is a no-op for this plugin and always returns nil.
func (plugin *PluginAllowedIP) Drop() error {
	return nil
}
// Reload is a no-op for this plugin and always returns nil.
func (plugin *PluginAllowedIP) Reload() error {
	return nil
}
// Eval scans the answer section of msg for class-IN A and AAAA records and
// checks each address against the allowed rules: first the exact-match set,
// then the prefix rules. On the first match it sets
// pluginsState.sessionData["whitelisted"] to true and, when a logger is
// configured, writes one log line describing the match.
//
// Responses with no answers, or with no matching address, are left
// untouched. Eval returns nil in every case handled here.
func (plugin *PluginAllowedIP) Eval(pluginsState *PluginsState, msg *dns.Msg) error {
	answers := msg.Answer
	if len(answers) == 0 {
		return nil
	}
	allowed, reason, ipStr := false, "", ""
	for _, answer := range answers {
		header := answer.Header()
		if header.Class != dns.ClassINET {
			continue
		}
		switch header.Rrtype {
		case dns.TypeA:
			ipStr = answer.(*dns.A).A.String()
		case dns.TypeAAAA:
			ipStr = answer.(*dns.AAAA).AAAA.String() // IPv4-mapped IPv6 addresses are converted to IPv4
		default:
			continue
		}
		if _, found := plugin.allowedIPs[ipStr]; found {
			allowed, reason = true, ipStr
			break
		}
		if prefix, found := plugin.matchAllowedPrefix(ipStr); found {
			allowed, reason = true, prefix+"*"
			break
		}
	}
	if !allowed {
		return nil
	}
	pluginsState.sessionData["whitelisted"] = true
	if plugin.logger == nil {
		return nil
	}
	return plugin.logAllowedAnswer(pluginsState, ipStr, reason)
}

// matchAllowedPrefix returns the longest stored prefix rule that matches
// ipStr on a component boundary: the prefix must either equal ipStr or be
// followed in ipStr by "." or ":".
//
// The longest stored prefix is tried first. If it does not end on a
// boundary, shorter stored prefixes are tried in turn, so a rule such as
// "192.168.1*" cannot hide the rule "192.168.*" for address 192.168.10.5.
func (plugin *PluginAllowedIP) matchAllowedPrefix(ipStr string) (string, bool) {
	root := plugin.allowedPrefixes.Root()
	search := []byte(ipStr)
	for {
		match, _, found := root.LongestPrefix(search)
		if !found {
			return "", false
		}
		// match is a prefix of search, itself a prefix of ipStr, so the
		// index below is in range whenever the lengths differ.
		if len(match) == len(ipStr) || ipStr[len(match)] == '.' || ipStr[len(match)] == ':' {
			return string(match), true
		}
		if len(match) == 0 {
			return "", false
		}
		// Retry with a strictly shorter key to reach the next-longest rule.
		search = search[:len(match)-1]
	}
}

// logAllowedAnswer writes one log line for an allowed response in the
// configured format ("tsv" or "ltsv"). Nothing is written for client
// protocols other than "udp", "tcp" and "local_doh". It returns an error
// only when no logger has been set up.
func (plugin *PluginAllowedIP) logAllowedAnswer(pluginsState *PluginsState, ipStr, reason string) error {
	if plugin.logger == nil {
		return errors.New("Log file not initialized")
	}
	var clientIPStr string
	switch pluginsState.clientProto {
	case "udp":
		clientIPStr = (*pluginsState.clientAddr).(*net.UDPAddr).IP.String()
	case "tcp", "local_doh":
		clientIPStr = (*pluginsState.clientAddr).(*net.TCPAddr).IP.String()
	default:
		// Ignore internal flow.
		return nil
	}
	qName := pluginsState.qName
	var line string
	switch plugin.format {
	case "tsv":
		now := time.Now()
		year, month, day := now.Date()
		hour, minute, second := now.Clock()
		tsStr := fmt.Sprintf("[%d-%02d-%02d %02d:%02d:%02d]", year, int(month), day, hour, minute, second)
		line = fmt.Sprintf(
			"%s\t%s\t%s\t%s\t%s\n",
			tsStr,
			clientIPStr,
			StringQuote(qName),
			StringQuote(ipStr),
			StringQuote(reason),
		)
	case "ltsv":
		line = fmt.Sprintf("time:%d\thost:%s\tqname:%s\tip:%s\tmessage:%s\n", time.Now().Unix(), clientIPStr, StringQuote(qName), StringQuote(ipStr), StringQuote(reason))
	default:
		dlog.Fatalf("Unexpected log format: [%s]", plugin.format)
	}
	// Logging is best effort: a failed write must not affect the response.
	_, _ = plugin.logger.Write([]byte(line))
	return nil
}