-
Notifications
You must be signed in to change notification settings - Fork 53
/
dns.go
122 lines (91 loc) · 2.26 KB
/
dns.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
package resolver
import (
"context"
"fmt"
"net"
"strings"
"github.com/miekg/dns"
)
type (
	// DNS is a small resolver. Questions claimed by its handler are answered
	// locally; all other queries are forwarded to an upstream server.
	DNS struct {
		client   dns.Client    // forwards unclaimed queries to upstream
		handler  DNSHandler    // consulted first for every question
		mux      *dns.ServeMux // dispatches incoming queries to ServeDNS
		server   *dns.Server   // serves queries arriving on the PacketConn
		soa      dns.RR        // authority record attached to local answers
		upstream string        // address passed to client.Exchange
	}

	// DNSHandler resolves a question locally.
	//
	// typ is the record type name (as found in dns.TypeToString, e.g. "A").
	// host is the query name without its trailing dot. Returning true claims
	// the question: a non-empty string is used as the record data, and an
	// empty string yields an empty NOERROR reply. Returning false forwards
	// the query upstream.
	DNSHandler func(typ, host string) (string, bool)
)
// NewDNS builds a DNS server that reads queries from conn.
//
// Each question is first offered to handler. Questions the handler does not
// claim are forwarded to upstream over UDP. The returned server is not
// running yet; call ListenAndServe to start it. An error is returned only if
// the built-in SOA record fails to parse.
func NewDNS(conn net.PacketConn, handler DNSHandler, upstream string) (*DNS, error) {
	fmt.Printf("ns=dns at=new upstream=%s\n", upstream)

	// SOA placed in the authority section of locally answered responses.
	soa, err := dns.NewRR("$ORIGIN .\n$TTL 0\n@ SOA ns.convox. support.convox.com. 2020010100 0 0 0 0")
	if err != nil {
		return nil, err
	}

	router := dns.NewServeMux()

	d := &DNS{
		client:   dns.Client{Net: "udp"},
		handler:  handler,
		mux:      router,
		soa:      soa,
		upstream: upstream,
	}
	d.server = &dns.Server{PacketConn: conn, Handler: router}

	// Register d for the root zone so the mux hands every query to ServeDNS.
	router.Handle(".", d)

	return d, nil
}
// ListenAndServe serves queries on the PacketConn supplied to NewDNS.
//
// Despite its name it does not open a socket. It activates the
// already-bound connection through dns.Server.ActivateAndServe and returns
// whatever that call returns.
func (d *DNS) ListenAndServe() error {
	fmt.Print("ns=dns at=serve\n")
	srv := d.server
	return srv.ActivateAndServe()
}
// Shutdown stops the server started by ListenAndServe.
//
// ctx bounds how long to wait for the serve loop to exit. If ctx is done
// first, the wait is abandoned and ctx.Err() is returned. Calling plain
// Server.Shutdown() would ignore ctx and wait without bound.
func (d *DNS) Shutdown(ctx context.Context) error {
	fmt.Printf("ns=dns at=shutdown\n")
	return d.server.ShutdownContext(ctx)
}
// ServeDNS implements dns.Handler for every incoming query.
//
// Only the first question in r is considered. Its type name and query name
// (trailing dot removed) are offered to d.handler:
//   - If the handler claims the question with a non-empty answer, the reply
//     carries one record built from "<name> <type> <answer>". It also sets
//     the authoritative flag and puts d.soa in the authority section.
//   - If the handler claims the question with an empty answer, the reply is
//     an empty NOERROR response.
//   - Otherwise the original request is forwarded to d.upstream with d.client
//     and the upstream response is written back unchanged.
//
// The reply is SERVFAIL (via dnsError) for a request with no question, an
// answer that does not parse as a resource record, or a failed upstream
// exchange.
func (d *DNS) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
	if len(r.Question) < 1 {
		dnsError(w, r, fmt.Errorf("invalid question"))
		return
	}

	q := r.Question[0]
	// Qtypes missing from TypeToString map to "".
	typ := dns.TypeToString[q.Qtype]
	question := strings.TrimSuffix(q.Name, ".")

	fmt.Printf("ns=dns at=question type=%s question=%q\n", typ, question)

	a := &dns.Msg{}
	a.Compress = false
	a.RecursionAvailable = true
	// RFC 3225 section 3: the DO bit of the query MUST be copied into the
	// response. Mirror the client's bit instead of always advertising DNSSEC
	// support for these unsigned local answers.
	if opt := r.IsEdns0(); opt != nil {
		a.SetEdns0(4096, opt.Do())
	}
	a.SetReply(r)

	if answer, ok := d.handler(typ, question); ok {
		fmt.Printf("ns=dns at=answer type=%s question=%q answer=%q\n", typ, question, answer)

		if answer != "" {
			rr, err := dns.NewRR(fmt.Sprintf("%s %s %s", question, typ, answer))
			if err != nil {
				dnsError(w, r, err)
				return
			}
			a.Answer = append(a.Answer, rr)
			a.Authoritative = true
			a.Ns = []dns.RR{d.soa}
		}

		if err := w.WriteMsg(a); err != nil {
			fmt.Printf("ns=dns at=write error=%s\n", err)
		}
		return
	}

	fmt.Printf("ns=dns at=forward type=%s question=%q\n", typ, question)

	rs, _, err := d.client.Exchange(r, d.upstream)
	if err != nil {
		dnsError(w, r, err)
		return
	}

	if err := w.WriteMsg(rs); err != nil {
		fmt.Printf("ns=dns at=write error=%s\n", err)
	}
}
// dnsError logs err and replies to r with a SERVFAIL response.
//
// Sending the reply is best-effort: a failure to write it is not reported.
func dnsError(w dns.ResponseWriter, r *dns.Msg, err error) {
	fmt.Printf("ns=dns at=error error=%s\n", err)

	reply := new(dns.Msg).SetRcode(r, dns.RcodeServerFailure)
	_ = w.WriteMsg(reply)
}