Skip to content

Commit

Permalink
Merge pull request #23 from cherrot/feature-dns-over-https
Browse files Browse the repository at this point in the history
DNS-over-HTTPS (DoH) Support
  • Loading branch information
cherrot committed Feb 28, 2021
2 parents 20344fe + 470108f commit a753928
Show file tree
Hide file tree
Showing 14 changed files with 364 additions and 110 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unittest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ on:

jobs:

build:
unittest:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
Expand Down
30 changes: 19 additions & 11 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,37 +4,39 @@ import (
"time"

"github.com/miekg/dns"

"github.com/cherrot/gochinadns/doh"
)

type Client struct {
*clientOptions
UDPCli *dns.Client
TCPCli *dns.Client
DoHCli *doh.Client
}

func NewClient(opts ...ClientOption) *Client {
o := newClientOptions()
o := new(clientOptions)
for _, f := range opts {
f(o)
}
return &Client{
clientOptions: o,
UDPCli: &dns.Client{Timeout: o.Timeout, Net: "udp"},
TCPCli: &dns.Client{Timeout: o.Timeout, Net: "tcp"},
DoHCli: doh.NewClient(
doh.WithTimeout(o.Timeout),
doh.WithSkipQueryMySelf(o.DoHSkipQuerySelf),
),
}
}

type clientOptions struct {
Timeout time.Duration // Timeout for one DNS query
UDPMaxSize int // Max message size for UDP queries
TCPOnly bool // Use TCP only
Mutation bool // Enable DNS pointer mutation for trusted servers
}

func newClientOptions() *clientOptions {
return &clientOptions{
Timeout: time.Second,
}
Timeout time.Duration // Timeout for one DNS query
UDPMaxSize int // Max message size for UDP queries
TCPOnly bool // Use TCP only
Mutation bool // Enable DNS pointer mutation for trusted servers
DoHSkipQuerySelf bool
}

type ClientOption func(*clientOptions)
Expand Down Expand Up @@ -62,3 +64,9 @@ func WithMutation(b bool) ClientOption {
o.Mutation = b
}
}

func WithDoHSkipQuerySelf(skip bool) ClientOption {
return func(o *clientOptions) {
o.DoHSkipQuerySelf = skip
}
}
4 changes: 2 additions & 2 deletions cmd/chinadns/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ var (
flagMutation = flag.Bool("m", false, "Enable compression pointer mutation in DNS queries.")
flagBidirectional = flag.Bool("d", true, "Drop results of trusted servers which containing IPs in China. (Bidirectional mode.)")
flagReusePort = flag.Bool("reuse-port", true, "Enable SO_REUSEPORT to gain some performance optimization. Need Linux>=3.9")
flagTimeout = flag.Duration("timeout", time.Second, "DNS request timeout")
flagTimeout = flag.Duration("timeout", 2*time.Second, "DNS request timeout")
flagDelay = flag.Float64("y", 0.1, "Delay (in seconds) to query another DNS server when no reply received.")
flagTestDomains = flag.String("test-domains", "www.qq.com", "Domain names to test DNS connection health, separated by comma.")
flagCHNList = flag.String("c", "./china.list", "Path to China route list. Both IPv4 and IPv6 are supported. See http://ipverse.net")
Expand All @@ -37,7 +37,7 @@ func init() {
"Protocols are dialed in order left to right. Rightmost protocol will only be dialed if the leftmost fails.\n"+
"Protocols will override force-tcp flag. "+
"If empty, protocol defaults to udp+tcp (tcp if force-tcp is set) and port defaults to 53.\n"+
"Examples: udp@8.8.8.8,udp+tcp@127.0.0.1:5353,1.1.1.1")
"Examples: 8.8.8.8,udp@127.0.0.1:5353,udp+tcp@1.1.1.1, doh@https://cloudflare-dns.com/dns-query")
flag.Var(&flagTrustedResolvers, "trusted-servers", "Comma separated list of servers which (located in China but) can be trusted. \n"+
"Uses the same format as -s.")
}
Expand Down
1 change: 1 addition & 0 deletions cmd/chinadns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func main() {
gochinadns.WithTCPOnly(*flagForceTCP),
gochinadns.WithMutation(*flagMutation),
gochinadns.WithTimeout(*flagTimeout),
gochinadns.WithDoHSkipQuerySelf(true),
}

client := gochinadns.NewClient(copts...)
Expand Down
7 changes: 3 additions & 4 deletions cmd/lookup/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ var (

func usage() {
fmt.Fprintf(flag.CommandLine.Output(), "Usage: %s [options] [proto[+proto]]@server www.domain.com\n", os.Args[0])
// TODO: supported schemas
fmt.Fprintln(flag.CommandLine.Output(), "Where proto being one of: udp, tcp.")
fmt.Fprintln(flag.CommandLine.Output(), "Where proto being one of: ", gochinadns.SupportedProtocols())
fmt.Fprintln(flag.CommandLine.Output(), "\nOptions:")
flag.PrintDefaults()
}
Expand Down Expand Up @@ -61,7 +60,7 @@ func main() {

fmt.Println(r)
fmt.Println(";; Query time:", rtt)
fmt.Println(";; SERVER:", resolver.Addr)
fmt.Println(";; SERVER:", resolver)
}

func parseArgs(args []string) (question string, resolver *gochinadns.Resolver) {
Expand All @@ -78,7 +77,7 @@ func parseArgs(args []string) (question string, resolver *gochinadns.Resolver) {
question = arg
}
}
if resolver.Addr == "" {
if resolver.GetAddr() == "" {
config, err := dns.ClientConfigFromFile("/etc/resolv.conf")
if err != nil {
logrus.Fatalln(err)
Expand Down
112 changes: 112 additions & 0 deletions doh/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package doh

import (
"encoding/base64"
"errors"
"io"
"net/http"
"net/url"
"time"

"github.com/miekg/dns"
"github.com/sirupsen/logrus"
)

const DoHMediaType = "application/dns-message"

var ErrQueryMyself = errors.New("not allowed to query myself")

type clientOptions struct {
Timeout time.Duration
SkipQueryMyself bool
}

type ClientOption func(*clientOptions)

// WithTimeout set a DNS query timeout
func WithTimeout(t time.Duration) ClientOption {
return func(o *clientOptions) {
o.Timeout = t
}
}

// WithSkipQueryMySelf controls whether sending DNS request of DoH server's domain to itself.
// Suppose we have a DoH client for https://dns.google.com/query, when this option is set,
// a DNS request whose question is dns.google.com will get an ErrQueryMyself.
// This is useful when this client acts as a local DNS resolver.
func WithSkipQueryMySelf(skip bool) ClientOption {
return func(o *clientOptions) {
o.SkipQueryMyself = skip
}
}

type Client struct {
opt *clientOptions
cli *http.Client
}

func NewClient(opts ...ClientOption) *Client {
o := new(clientOptions)
for _, f := range opts {
f(o)
}
return &Client{
opt: o,
cli: &http.Client{
Timeout: o.Timeout,
},
}
}

func (c *Client) Exchange(req *dns.Msg, address string) (r *dns.Msg, rtt time.Duration, err error) {
var (
buf, b64 []byte
begin = time.Now()
origID = req.Id
)

if c.opt.SkipQueryMyself {
u, e := url.Parse(address)
if e != nil {
return nil, 0, e
}
if req.Question[0].Name == dns.Fqdn(u.Hostname()) {
return nil, 0, ErrQueryMyself
}
}

// Set DNS ID as zero accoreding to RFC8484 (cache friendly)
req.Id = 0
buf, err = req.Pack()
b64 = make([]byte, base64.RawURLEncoding.EncodedLen(len(buf)))
if err != nil {
return
}
base64.RawURLEncoding.Encode(b64, buf)

// No need to use hreq.URL.Query()
uri := address + "?dns=" + string(b64)
logrus.Debugln("DoH request:", uri)
hreq, _ := http.NewRequest("GET", uri, nil)
hreq.Header.Add("Accept", DoHMediaType)
resp, err := c.cli.Do(hreq)
if err != nil {
return
}
defer resp.Body.Close()

content, err := io.ReadAll(resp.Body)
if err != nil {
return
}
if resp.StatusCode != http.StatusOK {
err = errors.New("DoH query failed: " + string(content))
return
}

r = new(dns.Msg)
err = r.Unpack(content)
r.Id = origID
rtt = time.Since(begin)
return
}
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module github.com/cherrot/gochinadns
go 1.16

require (
github.com/goodhosts/hostsfile v0.0.7
github.com/miekg/dns v1.1.35
github.com/sirupsen/logrus v1.7.0
github.com/yl2chen/cidranger v1.0.2
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dimchansky/utfbom v1.1.0 h1:FcM3g+nofKgUteL8dm/UpdRXNC9KmADgTpLKsu0TRo4=
github.com/dimchansky/utfbom v1.1.0/go.mod h1:rO41eb7gLfo8SF1jd9F8HplJm1Fewwi4mQvIirEdv+8=
github.com/goodhosts/hostsfile v0.0.7 h1:5yBaORuv1dybDhDRju32bQQ1l4iHKJs+h6GIgFV4qJQ=
github.com/goodhosts/hostsfile v0.0.7/go.mod h1:MAfdBdP0f9MVmfhmNP4EjQxPu7J/WnncHv8p/J8hkLs=
github.com/miekg/dns v1.1.35 h1:oTfOaDH+mZkdcgdIjH6yBajRGtIwcwcaR+rt23ZSrJs=
github.com/miekg/dns v1.1.35/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
50 changes: 50 additions & 0 deletions hosts/lookup.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
package hosts

import (
"net"

"github.com/goodhosts/hostsfile"
"github.com/sirupsen/logrus"
)

var h hostsfile.Hosts

func init() {
var err error
if h, err = hostsfile.NewHosts(); err != nil {
logrus.WithError(err).Warnln("Fail to parse local hosts file.")
}
}

func Lookup(host string) net.IP {
i := getHostnamePosition(host)
if i == -1 {
return nil
}
return net.ParseIP(h.Lines[i].IP)
}

// copied from package hostsfile
func getHostnamePosition(host string) int {
for i := range h.Lines {
line := h.Lines[i]
if !line.IsComment() && line.Raw != "" {
if itemInSlice(host, line.Hosts) {
return i
}
}
}

return -1
}

// copied from package hostsfile
func itemInSlice(item string, list []string) bool {
for _, i := range list {
if i == item {
return true
}
}

return false
}
24 changes: 19 additions & 5 deletions lookup.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,15 @@ func (c *Client) lookupNormal(req *dns.Msg, server *Resolver) (reply *dns.Msg, r
return
}
logger.WithError(err).Error("Fail to send TCP query.")
case "doh":
logger.Debug("Query upstream doh")
reply, rtt, err = c.DoHCli.Exchange(req, server.GetAddr())
if err == nil {
return
}
logger.WithError(err).Error("Fail to send DoH query.")
default:
logger.Errorf("No available protocols for resolver %s", server)
logger.Errorf("Protocol %s is unsupported in normal method.", protocol)
return
}
}
Expand All @@ -61,7 +68,6 @@ func (c *Client) lookupNormal(req *dns.Msg, server *Resolver) (reply *dns.Msg, r

// lookupMutation does the same as lookupNormal, with pointer mutation for DNS query.
// DNS Compression: https://tools.ietf.org/html/rfc1035#section-4.1.4
// DNS compression pointer mutation: https://gist.github.com/klzgrad/f124065c0616022b65e5#file-sendmsg-c-L30-L63
func (c *Client) lookupMutation(req *dns.Msg, server *Resolver) (reply *dns.Msg, rtt time.Duration, err error) {
logger := logrus.WithFields(logrus.Fields{
"question": questionString(&req.Question[0]),
Expand All @@ -73,7 +79,7 @@ func (c *Client) lookupMutation(req *dns.Msg, server *Resolver) (reply *dns.Msg,
if err != nil {
return nil, 0, fmt.Errorf("fail to pack request: %v", err.Error())
}
buffer = mutateQuestion(buffer)
buffer = mutateQuestion2(buffer)

// FIXME: may cause unexpected timeout (especially in `proto1+proto2@addr` case)
t := time.Now()
Expand Down Expand Up @@ -101,8 +107,15 @@ func (c *Client) lookupMutation(req *dns.Msg, server *Resolver) (reply *dns.Msg,
return
}
logger.WithError(err).Error("Fail to send TCP mutation query.")
case "doh":
logger.Debug("Query upstream doh")
reply, rtt, err = c.DoHCli.Exchange(req, server.GetAddr())
if err == nil {
return
}
logger.WithError(err).Error("Fail to send DoH query.")
default:
logger.Errorf("No available protocols for resolver %s", server)
logger.Errorf("Protocol %s is unsupported in mutation method.", protocol)
return
}
}
Expand Down Expand Up @@ -167,6 +180,7 @@ func cleanEdns0(req *dns.Msg) {
}
}

// DNS compression pointer mutation: https://gist.github.com/klzgrad/f124065c0616022b65e5#file-sendmsg-c-L30-L63
//nolint:deadcode,unused
func mutateQuestion(raw []byte) []byte {
length := len(raw)
Expand All @@ -192,7 +206,7 @@ func mutateQuestion(raw []byte) []byte {
return mutation
}

// add a "pointer" question. does not work now.
// black magic, works on limited resolvers (tested on Google and CloudFlare)
//nolint:deadcode,unused
func mutateQuestion2(raw []byte) []byte {
length := len(raw)
Expand Down
Loading

0 comments on commit a753928

Please sign in to comment.