Skip to content

Commit

Permalink
home: fix dns address fallback
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Mar 24, 2021
1 parent 5d0d32b commit 7b1dfa6
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 51 deletions.
90 changes: 80 additions & 10 deletions internal/home/control.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,75 @@ func httpError(w http.ResponseWriter, code int, format string, args ...interface
http.Error(w, text, code)
}

// ---------------
// dns run control
// ---------------
func addDNSAddress(dnsAddresses *[]string, addr net.IP) {
hostport := addr.String()
if config.DNS.Port != 53 {
hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
// appendDNSAddrs is a convenient helper for appending a formatted form of DNS
// addresses to a slice of strings.
func appendDNSAddrs(dst []string, addrs ...net.IP) (res []string) {
for _, addr := range addrs {
hostport := addr.String()
if config.DNS.Port != 53 {
hostport = net.JoinHostPort(hostport, strconv.Itoa(config.DNS.Port))
}

dst = append(dst, hostport)
}

return dst
}

// appendDNSAddrsWithIfaces formats and appends all DNS addresses from src to
// dst. It also adds the IP addresses of all network interfaces if src contains
// an unspecified IP addresss.
func appendDNSAddrsWithIfaces(dst []string, src []net.IP) (res []string, err error) {
ifacesAdded := false
for _, h := range src {
if !h.IsUnspecified() || ifacesAdded {
dst = appendDNSAddrs(dst, h)
}

// Add addresses of all network interfaces for addresses like
// "0.0.0.0" and "::".
var ifaces []*aghnet.NetInterface
ifaces, err = aghnet.GetValidNetInterfacesForWeb()
if err != nil {
return nil, fmt.Errorf("cannot get network interfaces: %w", err)
}

for _, iface := range ifaces {
dst = appendDNSAddrs(dst, iface.Addresses...)
}

ifacesAdded = true
}

return dst, nil
}

// collectDNSAddresses returns the list of DNS addresses the server is listening
// on, including the addresses on all interfaces in cases of unspecified IPs.
func collectDNSAddresses() (addrs []string, err error) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 {
addrs = appendDNSAddrs(addrs, net.IP{127, 0, 0, 1})
} else {
addrs, err = appendDNSAddrsWithIfaces(addrs, hosts)
if err != nil {
return nil, fmt.Errorf("collecting dns addresses: %w", err)
}
}

de := getDNSEncryption()
if de.https != "" {
addrs = append(addrs, de.https)
}

if de.tls != "" {
addrs = append(addrs, de.tls)
}

if de.quic != "" {
addrs = append(addrs, de.quic)
}
*dnsAddresses = append(*dnsAddresses, hostport)

return addrs, nil
}

// statusResponse is a response for /control/status endpoint.
Expand All @@ -60,8 +120,17 @@ type statusResponse struct {
}

func handleStatus(w http.ResponseWriter, _ *http.Request) {
dnsAddrs, err := collectDNSAddresses()
if err != nil {
// Don't add a lot of formatting, since the error is already
// wrapped by collectDNSAddresses.
httpError(w, http.StatusInternalServerError, "%s", err)

return
}

resp := statusResponse{
DNSAddrs: getDNSAddresses(),
DNSAddrs: dnsAddrs,
DNSPort: config.DNS.Port,
HTTPPort: config.BindPort,
IsRunning: isRunning(),
Expand All @@ -82,9 +151,10 @@ func handleStatus(w http.ResponseWriter, _ *http.Request) {
}

w.Header().Set("Content-Type", "application/json")
err := json.NewEncoder(w).Encode(resp)
err = json.NewEncoder(w).Encode(resp)
if err != nil {
httpError(w, http.StatusInternalServerError, "Unable to write response json: %s", err)

return
}
}
Expand Down
43 changes: 2 additions & 41 deletions internal/home/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"strconv"

"github.com/AdguardTeam/AdGuardHome/internal/agherr"
"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/AdGuardHome/internal/dnsfilter"
"github.com/AdguardTeam/AdGuardHome/internal/dnsforward"
"github.com/AdguardTeam/AdGuardHome/internal/querylog"
Expand Down Expand Up @@ -144,10 +143,8 @@ func ipsToUDPAddrs(ips []net.IP, port int) (udpAddrs []*net.UDPAddr) {
func generateServerConfig() (newConf dnsforward.ServerConfig, err error) {
dnsConf := config.DNS
hosts := dnsConf.BindHosts
for i, h := range hosts {
if h.IsUnspecified() {
hosts[i] = net.IP{127, 0, 0, 1}
}
if len(hosts) == 0 {
hosts = []net.IP{{127, 0, 0, 1}}
}

newConf = dnsforward.ServerConfig{
Expand Down Expand Up @@ -268,42 +265,6 @@ func getDNSEncryption() (de dnsEncryption) {
return de
}

// Get the list of DNS addresses the server is listening on
func getDNSAddresses() (dnsAddrs []string) {
if hosts := config.DNS.BindHosts; len(hosts) == 0 || hosts[0].IsUnspecified() {
ifaces, e := aghnet.GetValidNetInterfacesForWeb()
if e != nil {
log.Error("Couldn't get network interfaces: %v", e)
return []string{}
}

for _, iface := range ifaces {
for _, addr := range iface.Addresses {
addDNSAddress(&dnsAddrs, addr)
}
}
} else {
for _, h := range hosts {
addDNSAddress(&dnsAddrs, h)
}
}

de := getDNSEncryption()
if de.https != "" {
dnsAddrs = append(dnsAddrs, de.https)
}

if de.tls != "" {
dnsAddrs = append(dnsAddrs, de.tls)
}

if de.quic != "" {
dnsAddrs = append(dnsAddrs, de.quic)
}

return dnsAddrs
}

// applyAdditionalFiltering adds additional client information and settings if
// the client has them.
func applyAdditionalFiltering(clientAddr net.IP, clientID string, setts *dnsfilter.RequestFilteringSettings) {
Expand Down

0 comments on commit 7b1dfa6

Please sign in to comment.