forked from straightdave/trunks
-
Notifications
You must be signed in to change notification settings - Fork 0
/
resolver.go
78 lines (67 loc) · 1.8 KB
/
resolver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
package resolver
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"strings"
"sync/atomic"
)
// resolver holds the DNS server addresses that replace the system-configured
// ones, the dialer used to reach them, and a counter used to rotate through
// the addresses.
type resolver struct {
	// idx is advanced with sync/atomic on every address lookup. It is kept as
	// the first field because sync/atomic only guarantees 64-bit alignment for
	// the first word of an allocated struct on 32-bit platforms.
	idx uint64

	// addrs are the DNS servers to dial, each in host:port form.
	addrs []string

	// dialer opens the connection to the selected DNS server.
	dialer *net.Dialer
}
// NewResolver creates a DNS resolver suitable for plugging into
// net.DefaultResolver. Instead of the DNS server the net package would
// normally contact, lookups are dialed to one of the given addresses.
//
// Each entry in addrs must be an IP address with an optional port number
// separated by a colon. For example, 1.2.3.4:53 and 1.2.3.4 are both valid.
// In the absence of a port number, 53 is used.
//
// An error is returned when addrs is empty or when any entry fails
// validation.
func NewResolver(addrs []string) (*net.Resolver, error) {
	if len(addrs) == 0 {
		return nil, errors.New("must specify at least one resolver address")
	}

	cleanAddrs, err := normalizeAddrs(addrs)
	if err != nil {
		return nil, err
	}

	r := &resolver{addrs: cleanAddrs, dialer: &net.Dialer{}}
	return &net.Resolver{
		// The Dial hook is set alongside PreferGo so lookups go through the
		// pure Go resolver and reach the configured servers.
		PreferGo: true,
		Dial:     r.dial,
	}, nil
}
// normalizeAddrs validates every resolver address and returns the list with a
// port present on each entry. The input slice is not modified.
//
// Accepted forms are an IP literal with or without a port: "1.2.3.4",
// "1.2.3.4:53", "::1", "[::1]" and "[::1]:53". Entries without a port get
// port 53. An error is returned for the first entry that is not a valid
// host:port pair, whose port does not fit in 16 bits, or whose host is not
// an IP address.
func normalizeAddrs(addrs []string) ([]string, error) {
	const defaultPort = "53"

	normal := make([]string, len(addrs))
	for i, addr := range addrs {
		switch {
		case net.ParseIP(addr) != nil:
			// Bare IP literal without a port. A colon check alone cannot tell
			// an IPv6 literal such as "::1" from host:port, so detect the
			// literal first; JoinHostPort adds the brackets IPv6 requires.
			addr = net.JoinHostPort(addr, defaultPort)
		case !strings.Contains(addr, ":"),
			strings.HasPrefix(addr, "[") && strings.HasSuffix(addr, "]"):
			// No port separator at all, or a bracketed literal with nothing
			// after the closing bracket: append the default port.
			addr += ":" + defaultPort
		}

		// validate addr is a valid host:port
		host, portstr, err := net.SplitHostPort(addr)
		if err != nil {
			return nil, err
		}

		// validate the port is numeric and fits in 16 bits
		if _, err := strconv.ParseUint(portstr, 10, 16); err != nil {
			return nil, err
		}

		// make sure host is an ip, not a hostname
		if net.ParseIP(host) == nil {
			return nil, fmt.Errorf("host %s is not an IP address", host)
		}

		normal[i] = addr
	}
	return normal, nil
}
// dial opens a connection to one of the configured DNS servers. The third
// parameter is the address of the DNS server being overridden, so it is
// deliberately discarded and the next address from r.address is used in its
// place.
func (r *resolver) dial(ctx context.Context, network, _ string) (net.Conn, error) {
	server := r.address()
	return r.dialer.DialContext(ctx, network, server)
}
// address picks the server for the next dial. It atomically increments the
// counter and uses the incremented value, modulo the number of addresses, as
// the index, so successive calls cycle through r.addrs.
func (r *resolver) address() string {
	count := atomic.AddUint64(&r.idx, 1)
	slot := count % uint64(len(r.addrs))
	return r.addrs[slot]
}