Skip to content

Commit

Permalink
dnsforward: fix upstream test
Browse files Browse the repository at this point in the history
  • Loading branch information
EugeneOne1 committed Jan 13, 2022
1 parent 1458600 commit 0de155b
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 136 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go

### Fixed

- Poor testing of domain-specific upstream servers ([#4074]).
- Omitted aliases of hosts specified by another line within the OS's hosts file
([#4079]).

Expand All @@ -37,6 +38,8 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go
- Go 1.16 support.

[#3057]: https://github.com/AdguardTeam/AdGuardHome/issues/3057
[#4074]: https://github.com/AdguardTeam/AdGuardHome/issues/4074
[#4079]: https://github.com/AdguardTeam/AdGuardHome/issues/4079



Expand Down Expand Up @@ -82,7 +85,6 @@ TODO(a.garipov): Remove this deprecation, if v0.108.0 is released before the Go
[#4008]: https://github.com/AdguardTeam/AdGuardHome/issues/4008
[#4016]: https://github.com/AdguardTeam/AdGuardHome/issues/4016
[#4027]: https://github.com/AdguardTeam/AdGuardHome/issues/4027
[#4079]: https://github.com/AdguardTeam/AdGuardHome/issues/4079



Expand Down
133 changes: 58 additions & 75 deletions internal/dnsforward/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"fmt"
"net"
"net/http"
"strconv"
"strings"
"time"

Expand Down Expand Up @@ -192,22 +191,23 @@ func (req *dnsConfig) checkCacheTTL() bool {

func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
req := dnsConfig{}
dec := json.NewDecoder(r.Body)
if err := dec.Decode(&req); err != nil {
err := json.NewDecoder(r.Body).Decode(&req)
if err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "json Encode: %s", err)

return
}

if req.Upstreams != nil {
if err := ValidateUpstreams(*req.Upstreams); err != nil {
if err = ValidateUpstreams(*req.Upstreams); err != nil {
aghhttp.Error(r, w, http.StatusBadRequest, "wrong upstreams specification: %s", err)

return
}
}

if errBoot, err := req.checkBootstrap(); err != nil {
var errBoot string
if errBoot, err = req.checkBootstrap(); err != nil {
aghhttp.Error(
r,
w,
Expand All @@ -220,19 +220,16 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
return
}

if !req.checkBlockingMode() {
switch {
case !req.checkBlockingMode():
aghhttp.Error(r, w, http.StatusBadRequest, "blocking_mode: incorrect value")

return
}

if !req.checkUpstreamsMode() {
case !req.checkUpstreamsMode():
aghhttp.Error(r, w, http.StatusBadRequest, "upstream_mode: incorrect value")

return
}

if !req.checkCacheTTL() {
case !req.checkCacheTTL():
aghhttp.Error(
r,
w,
Expand All @@ -241,13 +238,15 @@ func (s *Server) handleSetConfig(w http.ResponseWriter, r *http.Request) {
)

return
default:
// Go on.
}

restart := s.setConfig(req)
s.conf.ConfigModified()

if restart {
if err := s.Reconfigure(nil); err != nil {
if err = s.Reconfigure(nil); err != nil {
aghhttp.Error(r, w, http.StatusInternalServerError, "%s", err)
}
}
Expand Down Expand Up @@ -387,14 +386,14 @@ func ValidateUpstreams(upstreams []string) (err error) {

var defaultUpstreamFound bool
for _, u := range upstreams {
var ok bool
ok, err = validateUpstream(u)
var useDefault bool
useDefault, err = validateUpstream(u)
if err != nil {
return err
}

if !defaultUpstreamFound {
defaultUpstreamFound = ok
defaultUpstreamFound = useDefault
}
}

Expand All @@ -407,87 +406,74 @@ func ValidateUpstreams(upstreams []string) (err error) {

var protocols = []string{"tls://", "https://", "tcp://", "sdns://", "quic://"}

func validateUpstream(u string) (bool, error) {
func validateUpstream(u string) (useDefault bool, err error) {
// Check if the user tries to specify upstream for domain.
u, useDefault, err := separateUpstream(u)
var isDomainSpec bool
u, isDomainSpec, err = separateUpstream(u)
if err != nil {
return useDefault, err
return !isDomainSpec, err
}

// The special server address '#' means "use the default servers"
if u == "#" && !useDefault {
// The special server address '#' means that default server must be used.
if useDefault = !isDomainSpec; u == "#" && isDomainSpec {
return useDefault, nil
}

// Check if the upstream has a valid protocol prefix
// Check if the upstream has a valid protocol prefix.
//
// TODO(e.burkov): Validate the domain name.
for _, proto := range protocols {
if strings.HasPrefix(u, proto) {
return useDefault, nil
}
}

// Return error if the upstream contains '://' without any valid protocol
if strings.Contains(u, "://") {
return useDefault, fmt.Errorf("wrong protocol")
return useDefault, errors.Error("wrong protocol")
}

// Check if upstream is valid plain DNS
return useDefault, checkPlainDNS(u)
// Check if upstream is either an IP or IP with port.
if net.ParseIP(u) != nil {
return useDefault, nil
} else if _, err = netutil.ParseIPPort(u); err != nil {
return useDefault, err
}

return useDefault, nil
}

// separateUpstream returns the upstream without the specified domains.
// useDefault is true when a default upstream must be used.
func separateUpstream(upstreamStr string) (upstream string, useDefault bool, err error) {
defer func() { err = errors.Annotate(err, "bad upstream for domain spec %q: %w", upstreamStr) }()

// isDomainSpec is true when the upstream is domains-specific.
func separateUpstream(upstreamStr string) (upstream string, isDomainSpec bool, err error) {
if !strings.HasPrefix(upstreamStr, "[/") {
return upstreamStr, true, nil
return upstreamStr, false, nil
}
defer func() { err = errors.Annotate(err, "bad upstream for domain %q: %w", upstreamStr) }()

parts := strings.Split(upstreamStr[2:], "/]")
if len(parts) != 2 {
return "", false, errors.Error("duplicated separator")
switch len(parts) {
case 2:
// Go on.
case 1:
return "", false, errors.Error("missing separator")
default:
return "", true, errors.Error("duplicated separator")
}

domains := parts[0]
upstream = parts[1]
var domains string
domains, upstream = parts[0], parts[1]
for i, host := range strings.Split(domains, "/") {
if host == "" {
continue
}

err = netutil.ValidateDomainName(host)
if err != nil {
return "", false, fmt.Errorf("domain at index %d: %w", i, err)
return "", true, fmt.Errorf("domain at index %d: %w", i, err)
}
}

return upstream, false, nil
}

// checkPlainDNS checks if host is plain DNS
func checkPlainDNS(upstream string) error {
// Check if host is ip without port
if net.ParseIP(upstream) != nil {
return nil
}

// Check if host is ip with port
ip, port, err := net.SplitHostPort(upstream)
if err != nil {
return err
}

if net.ParseIP(ip) == nil {
return fmt.Errorf("%s is not a valid IP", ip)
}

_, err = strconv.ParseInt(port, 0, 64)
if err != nil {
return fmt.Errorf("%s is not a valid port: %w", port, err)
}

return nil
return upstream, true, nil
}

// excFunc is a signature of function to check if upstream exchanges correctly.
Expand Down Expand Up @@ -515,12 +501,8 @@ func checkDNSUpstreamExc(u upstream.Upstream) (err error) {

if len(reply.Answer) != 1 {
return fmt.Errorf("wrong response")
}

if t, ok := reply.Answer[0].(*dns.A); ok {
if !net.IPv4(8, 8, 8, 8).Equal(t.A) {
return fmt.Errorf("wrong response")
}
} else if a, ok := reply.Answer[0].(*dns.A); !ok || !a.A.Equal(net.IP{8, 8, 8, 8}) {
return fmt.Errorf("wrong response")
}

return nil
Expand Down Expand Up @@ -555,7 +537,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun

// Separate upstream from domains list.
var useDefault bool
if input, useDefault, err = separateUpstream(input); err != nil {
if useDefault, err = validateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}

Expand All @@ -564,15 +546,16 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
return nil
}

if _, err = validateUpstream(input); err != nil {
if input, _, err = separateUpstream(input); err != nil {
return fmt.Errorf("wrong upstream format: %w", err)
}

if len(bootstrap) == 0 {
bootstrap = defaultBootstrap
}

log.Debug("checking if dns server %q works...", input)
log.Debug("checking if upstream %s works", input)

var u upstream.Upstream
u, err = upstream.AddressToUpstream(input, &upstream.Options{
Bootstrap: bootstrap,
Expand All @@ -586,7 +569,7 @@ func checkDNS(input string, bootstrap []string, timeout time.Duration, ef excFun
return fmt.Errorf("upstream %q fails to exchange: %w", input, err)
}

log.Debug("dns %s works OK", input)
log.Debug("upstream %s is ok", input)

return nil
}
Expand Down Expand Up @@ -620,9 +603,9 @@ func (s *Server) handleTestUpstreamDNS(w http.ResponseWriter, r *http.Request) {
err = checkDNS(host, bootstraps, timeout, checkPrivateUpstreamExc)
if err != nil {
log.Info("%v", err)
// TODO(e.burkov): If passed upstream have already
// written an error above, we rewriting the error for
// it. These cases should be handled properly instead.
// TODO(e.burkov): If passed upstream have already written an error
// above, we rewriting the error for it. These cases should be
// handled properly instead.
result[host] = err.Error()

continue
Expand Down
Loading

0 comments on commit 0de155b

Please sign in to comment.