/
handler.go
155 lines (130 loc) · 2.95 KB
/
handler.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package handler
import (
"encoding/base64"
"errors"
"io"
"net/http"
"strings"
"unsafe"
"github.com/caddyserver/caddy/v2"
"github.com/caddyserver/caddy/v2/modules/caddyhttp"
"github.com/imgk/memory-go"
"github.com/miekg/dns"
"github.com/imgk/caddy-dnsproxy/app"
)
func init() {
caddy.RegisterModule(Handler{})
}
const (
	// DefaultPrefix is the URL path prefix that Provision assigns when the
	// handler's Prefix is left empty. "/dns-query" is the path used by the
	// RFC 8484 examples.
	DefaultPrefix = "/dns-query"
)
// Handler is a Caddy HTTP middleware that answers DNS-over-HTTPS requests
// arriving under Prefix by forwarding them to the upstream provided by the
// app registered as app.CaddyAppID.
type Handler struct {
	// Prefix is the URL path prefix this handler answers on. Requests whose
	// path does not start with it are passed to the next handler. Provision
	// sets it to DefaultPrefix when it is left empty.
	Prefix string `json:"prefix,omitempty"`
	// up is the upstream resolver. Provision obtains it from the
	// app.CaddyAppID app. It is not part of the JSON configuration.
	up app.Upstream
}
// CaddyModule returns the module information Caddy uses to register the
// handler. The handler is registered under the ID
// "http.handlers.dns_over_https", and fresh instances are returned as
// *Handler.
func (Handler) CaddyModule() caddy.ModuleInfo {
	var info caddy.ModuleInfo
	info.ID = "http.handlers.dns_over_https"
	info.New = func() caddy.Module {
		return &Handler{}
	}
	return info
}
// Provision prepares the handler for use. It applies DefaultPrefix when no
// prefix is configured and binds the handler to the app registered as
// app.CaddyAppID, which must implement app.Upstream.
//
// It returns the error from loading that app, an error if the app is
// reported as not configured, or an error if the app does not implement
// app.Upstream.
func (m *Handler) Provision(ctx caddy.Context) error {
	if m.Prefix == "" {
		m.Prefix = DefaultPrefix
	}

	// Load (and provision, on first use) the app exactly once, and surface
	// its error instead of discarding it.
	mod, err := ctx.App(app.CaddyAppID)
	if err != nil {
		return err
	}
	// NOTE(review): per Caddy's docs ctx.App instantiates an empty app when
	// none is configured, so after a successful ctx.App this check presumably
	// never fires. To really reject an unconfigured app it would have to run
	// before ctx.App. Confirm against the Caddy version in use.
	if ctx.AppIfConfigured(app.CaddyAppID) == nil {
		return errors.New(app.CaddyAppID + " is not configured")
	}

	// Checked assertion: a mismatched app is a configuration error, not a
	// reason to panic during provisioning.
	up, ok := mod.(app.Upstream)
	if !ok {
		return errors.New(app.CaddyAppID + " app does not implement app.Upstream")
	}
	m.up = up
	return nil
}
// ServeHTTP routes DNS-over-HTTPS traffic. GET and POST requests whose path
// starts with m.Prefix are resolved by this handler. Every other request,
// including other methods under the prefix, is handed to next unchanged.
func (m *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyhttp.Handler) error {
	underPrefix := strings.HasPrefix(r.URL.Path, m.Prefix)
	switch {
	case underPrefix && r.Method == http.MethodGet:
		return m.serveGet(w, r)
	case underPrefix && r.Method == http.MethodPost:
		return m.servePost(w, r)
	default:
		return next.ServeHTTP(w, r)
	}
}
// serveGet handles a DoH GET request. The DNS query is read from the first
// "dns" URL query parameter and decoded as unpadded base64url
// (base64.RawURLEncoding) into a scratch buffer obtained from memory.Alloc.
// The decoded message is then passed to response.
//
// Client mistakes are reported as caddyhttp.HandlerError:
//   - 400 when the parameter is missing or is not valid base64url.
//   - 414 when it would decode to more bytes than the buffer holds.
func (m *Handler) serveGet(w http.ResponseWriter, r *http.Request) error {
	ss, ok := r.URL.Query()["dns"]
	if !ok || len(ss) < 1 {
		return caddyhttp.Error(http.StatusBadRequest, errors.New("no dns query"))
	}
	query := ss[0]

	ptr, buf := memory.Alloc[byte](dns.MaxMsgSize)
	defer memory.Free(ptr)

	// Encoding.Decode requires the caller to size dst. An oversized,
	// client-controlled parameter would otherwise index past the end of buf
	// and panic.
	if base64.RawURLEncoding.DecodedLen(len(query)) > len(buf) {
		return caddyhttp.Error(http.StatusRequestURITooLong, errors.New("dns query too large"))
	}

	// View the string's bytes without copying. The data pointer is read as
	// a *byte rather than round-tripped through uintptr, which keeps the
	// conversion within the rules of package unsafe. Decode only reads src,
	// so aliasing immutable string memory is safe.
	// NOTE(review): on Go 1.20+ this can be
	// unsafe.Slice(unsafe.StringData(query), len(query)).
	src := unsafe.Slice(*(**byte)(unsafe.Pointer(&query)), len(query))
	n, err := base64.RawURLEncoding.Decode(buf, src)
	if err != nil {
		return caddyhttp.Error(http.StatusBadRequest, err)
	}
	return m.response(w, buf, n)
}
// servePost handles a DoH POST request. The request body is the raw DNS
// message. It is read into a scratch buffer obtained from memory.Alloc and
// passed to response.
//
// Body problems are reported as caddyhttp.HandlerError:
//   - 413 when the body does not fit in the buffer.
//   - 400 for any other error while reading the body.
func (m *Handler) servePost(w http.ResponseWriter, r *http.Request) error {
	ptr, buf := memory.Alloc[byte](dns.MaxMsgSize)
	defer memory.Free(ptr)

	// Read the DNS message from the request body.
	n, err := Buffer(buf).ReadFrom(r.Body)
	if err != nil {
		if errors.Is(err, io.ErrShortBuffer) {
			return caddyhttp.Error(http.StatusRequestEntityTooLarge, err)
		}
		return caddyhttp.Error(http.StatusBadRequest, err)
	}
	return m.response(w, buf, int(n))
}
// response unpacks the DNS query in buf[:n], exchanges it with the upstream
// and writes the packed answer to w.
//
// buf is reused as scratch space for packing the answer, so its previous
// contents are overwritten.
//
// Failures are reported as caddyhttp.HandlerError:
//   - 400 when the query cannot be parsed.
//   - 502 when the upstream exchange fails or yields no message.
//   - 500 when the answer cannot be packed.
//
// An error from writing to w is returned unchanged.
func (m *Handler) response(w http.ResponseWriter, buf []byte, n int) error {
	// Parse the DNS message.
	query := new(dns.Msg)
	if err := query.Unpack(buf[:n]); err != nil {
		return caddyhttp.Error(http.StatusBadRequest, err)
	}

	// Ask the upstream for an answer.
	answer, err := m.up.Exchange(query)
	if err != nil {
		return caddyhttp.Error(http.StatusBadGateway, err)
	}
	if answer == nil {
		return caddyhttp.Error(http.StatusBadGateway, errors.New("upstream returned no dns message"))
	}

	out, err := answer.PackBuffer(buf)
	if err != nil {
		return caddyhttp.Error(http.StatusInternalServerError, err)
	}

	// RFC 8484 §6: DoH payloads use the application/dns-message media type.
	// Clients may refuse an answer that is not labelled with it.
	w.Header().Set("Content-Type", "application/dns-message")
	_, err = w.Write(out)
	return err
}
var _ caddyhttp.MiddlewareHandler = (*Handler)(nil)
// Buffer is a fixed-capacity byte slice that can be filled from an
// io.Reader. Unlike bytes.Buffer it never grows. Data beyond len(b) is an
// error.
type Buffer []byte

// ReadFrom reads from r into b until r reports io.EOF or another error.
// It returns the number of bytes stored at the start of b.
//
// io.EOF is not reported as an error. A payload of exactly len(b) bytes is
// accepted. If r holds more than len(b) bytes, ReadFrom returns
// io.ErrShortBuffer after consuming one byte past the end of b. Any other
// read error is returned unchanged together with the bytes read so far.
//
// NOTE(review): like the io helpers, a reader that keeps returning (0, nil)
// makes this loop spin.
func (b Buffer) ReadFrom(r io.Reader) (n int64, err error) {
	filled := 0
	for filled < len(b) {
		nr, er := r.Read(b[filled:])
		if nr > 0 {
			filled += nr
		}
		if er != nil {
			if errors.Is(er, io.EOF) {
				return int64(filled), nil
			}
			return int64(filled), er
		}
	}

	// b is full. Probe for one more byte to tell an exact fit apart from an
	// oversized payload.
	var probe [1]byte
	for {
		nr, er := r.Read(probe[:])
		if nr > 0 {
			return int64(filled), io.ErrShortBuffer
		}
		if er != nil {
			if errors.Is(er, io.EOF) {
				return int64(filled), nil
			}
			return int64(filled), er
		}
	}
}