forked from rekby/lets-proxy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
dns.go
146 lines (130 loc) · 3.48 KB
/
dns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package main
import (
"context"
"errors"
"github.com/Sirupsen/logrus"
"github.com/miekg/dns"
"net"
"strings"
"sync"
)
var (
	// allowedDomainChars is a lookup table indexed by byte value: an entry is
	// true when that byte may appear in a domain name. It has 256 entries so
	// that every possible byte value is a valid index; a shorter table makes
	// a lookup with byte 0xFF panic with an out-of-range index.
	allowedDomainChars [256]bool
)

// init marks the ASCII digits, ASCII letters, '.' and '-' as allowed bytes.
// All other entries keep their zero value (false).
func init() {
	for _, b := range []byte("1234567890qwertyuiopasdfghjklzxcvbnmQWERTYUIOPASDFGHJKLZXCVBNM.-") {
		allowedDomainChars[b] = true
	}
}
// domainValidName performs a basic syntactic check of a domain name.
//
// It returns nil when domain is non-empty, does not start with '.' or '-',
// does not end with '-', and consists only of bytes marked as allowed in
// allowedDomainChars. Otherwise it returns an error naming the first rule
// that failed.
func domainValidName(domain string) error {
	if len(domain) == 0 {
		return errors.New("Zero length domain name")
	}
	if domain[0] == '.' || domain[0] == '-' {
		return errors.New("Bad start symbol")
	}
	if domain[len(domain)-1] == '-' {
		return errors.New("Bad end symbol")
	}
	for i := 0; i < len(domain); i++ {
		c := domain[i]
		// Check the index against the table length first: a byte value that
		// falls outside the table is rejected as a bad symbol instead of
		// causing an index-out-of-range panic.
		if int(c) >= len(allowedDomainChars) || !allowedDomainChars[c] {
			return errors.New("Bad symbol")
		}
	}
	return nil
}
// domainHasLocalIP reports whether every address found for domain is one of
// the addresses returned by getAllowIPs.
//
// The domain is resolved concurrently through net.LookupIP and through A and
// AAAA queries sent to four hard-coded public DNS servers. The function
// returns false as soon as any answer contains an address that ipContains
// does not find in the allowed list, and also when no answer contained any
// address at all. It returns true only when at least one address was seen and
// all seen addresses are allowed.
//
// ctx is only forwarded to getIPsFromDNS.
func domainHasLocalIP(ctx context.Context, domain string) bool {
	results := make(chan []net.IP, 1)

	// However the function returns, keep receiving until the channel is
	// closed so that no lookup goroutine stays blocked on its send. This
	// means the return waits for all outstanding lookups to finish.
	defer func() {
		for range results {
		}
	}()

	var pending sync.WaitGroup

	// Lookup through net.LookupIP. A failed lookup is logged and contributes
	// no result.
	pending.Add(1)
	go func() {
		defer pending.Done()
		ips, err := net.LookupIP(domain)
		if err != nil {
			logrus.Warnf("Can't local lookup ip for domain '%v': %v", domain, err)
		} else {
			results <- ips
		}
		logrus.Debugf("Receive answer from local lookup for domain '%v' ips: '%v'", domain, ips)
	}()

	// The direct DNS queries are made with a name that ends with a dot.
	fqdn := domain
	if !strings.HasSuffix(fqdn, ".") {
		fqdn += "."
	}

	servers := []string{
		// google 1
		"8.8.8.8:53",
		// google 2 (ipv6)
		"[2001:4860:4860::8844]:53",
		// yandex 1
		"77.88.8.8:53",
		// yandex 2 (ipv6)
		"[2a02:6b8:0:1::feed:0ff]:53",
	}
	recordTypes := []uint16{dns.TypeA, dns.TypeAAAA}

	// One goroutine per (server, record type) pair.
	for _, server := range servers {
		for _, recordType := range recordTypes {
			pending.Add(1)
			go func(server string, recordType uint16) {
				defer pending.Done()
				results <- getIPsFromDNS(ctx, fqdn, server, recordType)
			}(server, recordType)
		}
	}

	// Close the channel once every lookup has reported, which ends the
	// receive loops below and in the deferred drain.
	go func() {
		pending.Wait()
		close(results)
	}()

	allowed := getAllowIPs()
	sawAddress := false
	for ips := range results {
		sawAddress = sawAddress || len(ips) > 0
		for _, ip := range ips {
			if ipContains(allowed, ip) {
				continue
			}
			// The domain resolves to an address outside the allowed list.
			logrus.Debugf("Domain have ip of other server. Domain '%v', Domain ips: '%v', Server ips: '%v'", domain, ips, allowed)
			return false
		}
	}

	if sawAddress {
		return true
	}
	logrus.Infof("Doesn't found ip addresses for domain '%v'", domain)
	return false
}
// getIPsFromDNS sends one query of recordType for domain to dnsServer and
// returns the addresses taken from the answer records of that same type.
//
// Only dns.TypeA and dns.TypeAAAA records are collected; for any other
// recordType the result is empty. nil is returned when the exchange fails or
// when the answer ID differs from the query ID; both cases are logged at info
// level.
//
// NOTE(review): ctx is accepted but never used in this function, so
// cancelling it does not interrupt the query.
func getIPsFromDNS(ctx context.Context, domain, dnsServer string, recordType uint16) []net.IP {
	typeName := dns.TypeToString[recordType]

	var client dns.Client
	query := new(dns.Msg)
	query.Id = dns.Id()
	query.SetQuestion(domain, recordType)

	reply, _, err := client.Exchange(query, dnsServer)
	if err != nil {
		logrus.Infof("Error from dns server '%v' for domain '%v', record type '%v': %v", dnsServer, domain, typeName, err)
		return nil
	}
	if reply.Id != query.Id {
		logrus.Infof("Error answer ID from dns server '%v' for domain '%v', record type '%v', %v != %v", dnsServer, domain, typeName, query.Id, reply.Id)
		return nil
	}

	var found []net.IP
	for _, rr := range reply.Answer {
		// Skip records of any other type than the one asked for.
		if rr.Header().Rrtype != recordType {
			continue
		}
		// The record type equals recordType here, so switching on the
		// requested type selects the matching concrete record.
		switch recordType {
		case dns.TypeA:
			found = append(found, rr.(*dns.A).A)
		case dns.TypeAAAA:
			found = append(found, rr.(*dns.AAAA).AAAA)
		}
	}

	logrus.Debugf("Receive answer from dns server '%v' for domain '%v' record type '%v' ips: '%v'", dnsServer, domain, typeName, found)
	return found
}