diff --git a/internal/aghnet/net.go b/internal/aghnet/net.go index bdcc6e4f3ca..b8c8c05e443 100644 --- a/internal/aghnet/net.go +++ b/internal/aghnet/net.go @@ -31,12 +31,6 @@ var ( // the IP being static is available. const ErrNoStaticIPInfo errors.Error = "no information about static ip" -// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4]. -func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) } - -// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6]. -func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) } - // IfaceHasStaticIP checks if interface is configured to have static IP address. // If it can't give a definitive answer, it returns false and an error for which // errors.Is(err, ErrNoStaticIPInfo) is true. diff --git a/internal/aghnet/net_test.go b/internal/aghnet/net_test.go index f4275c6c320..6e9e612e5b0 100644 --- a/internal/aghnet/net_test.go +++ b/internal/aghnet/net_test.go @@ -188,7 +188,7 @@ func TestBroadcastFromIPNet(t *testing.T) { } func TestCheckPort(t *testing.T) { - laddr := netip.AddrPortFrom(IPv4Localhost(), 0) + laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0) t.Run("tcp_bound", func(t *testing.T) { l, err := net.Listen("tcp", laddr.String()) diff --git a/internal/dnsforward/clientid.go b/internal/dnsforward/clientid.go index 6a111c47a30..fb5eefdae08 100644 --- a/internal/dnsforward/clientid.go +++ b/internal/dnsforward/clientid.go @@ -23,16 +23,6 @@ func ValidateClientID(id string) (err error) { return nil } -// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's -// a helper function to prevent unnecessary allocations in code like: -// -// if strings.HasSuffix(s, "." + suffix) { /* … */ } -// -// s must be longer than suffix. -func hasLabelSuffix(s, suffix string) (ok bool) { - return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.' -} - // clientIDFromClientServerName extracts and validates a ClientID. hostSrvName // is the server name of the host. cliSrvName is the server name as sent by the // client. When strict is true, and client and host server name don't match, @@ -46,7 +36,7 @@ func clientIDFromClientServerName( return "", nil } - if !hasLabelSuffix(cliSrvName, hostSrvName) { + if !netutil.IsImmediateSubdomain(cliSrvName, hostSrvName) { if !strict { return "", nil } diff --git a/internal/dnsforward/dnsforward.go b/internal/dnsforward/dnsforward.go index 6d10715334e..53091e23729 100644 --- a/internal/dnsforward/dnsforward.go +++ b/internal/dnsforward/dnsforward.go @@ -246,6 +246,7 @@ type RDNSExchanger interface { // Exchange tries to resolve the ip in a suitable way, e.g. either as // local or as external. Exchange(ip net.IP) (host string, err error) + // ResolvesPrivatePTR returns true if the RDNSExchanger is able to // resolve PTR requests for locally-served addresses. ResolvesPrivatePTR() (ok bool) @@ -261,6 +262,9 @@ const ( rDNSNotPTRErr errors.Error = "the response is not a ptr" ) +// type check +var _ RDNSExchanger = (*Server)(nil) + // Exchange implements the RDNSExchanger interface for *Server. func (s *Server) Exchange(ip net.IP) (host string, err error) { s.serverLock.RLock() @@ -675,21 +679,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { // IsBlockedClient returns true if the client is blocked by the current access // settings. -func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) { +func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) { s.serverLock.RLock() defer s.serverLock.RUnlock() blockedByIP := false - if ip != nil { - // TODO(a.garipov): Remove once we switch to netip.Addr more fully. - ipAddr, err := netutil.IPToAddrNoMapped(ip) - if err != nil { - log.Error("dnsforward: bad client ip %v: %s", ip, err) - - return false, "" - } - - blockedByIP, rule = s.access.isBlockedIP(ipAddr) + if ip != (netip.Addr{}) { + blockedByIP, rule = s.access.isBlockedIP(ip) } allowlistMode := s.access.allowlistMode() diff --git a/internal/dnsforward/filter.go b/internal/dnsforward/filter.go index 6f64d35d685..f36fd52a5e2 100644 --- a/internal/dnsforward/filter.go +++ b/internal/dnsforward/filter.go @@ -19,13 +19,13 @@ func (s *Server) beforeRequestHandler( _ *proxy.Proxy, pctx *proxy.DNSContext, ) (reply bool, err error) { - ip, _ := netutil.IPAndPortFromAddr(pctx.Addr) clientID, err := s.clientIDFromDNSContext(pctx) if err != nil { return false, fmt.Errorf("getting clientid: %w", err) } - blocked, _ := s.IsBlockedClient(ip, clientID) + addrPort := netutil.NetAddrToAddrPort(pctx.Addr) + blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID) if blocked { return s.preBlockedResponse(pctx) } diff --git a/internal/filtering/filter_test.go b/internal/filtering/filter_test.go index 99014bd0fa7..81a7929a57b 100644 --- a/internal/filtering/filter_test.go +++ b/internal/filtering/filter_test.go @@ -11,7 +11,7 @@ import ( "testing" "time" - "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -40,7 +40,7 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) { addr := l.Addr() require.IsType(t, new(net.TCPAddr), addr) - return netip.AddrPortFrom(aghnet.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port)) + return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port)) } func TestFilters(t *testing.T) { diff --git a/internal/home/clients.go b/internal/home/clients.go index fa230178b5c..f46c616eced 100644 --- a/internal/home/clients.go +++ b/internal/home/clients.go @@ -129,7 +129,7 @@ type RuntimeClientWHOISInfo struct { type clientsContainer struct { // TODO(a.garipov): Perhaps use a number of separate indices for - // different types (string, net.IP, and so on). + // different types (string, netip.Addr, and so on). list map[string]*Client // name -> client idIndex map[string]*Client // ID -> client @@ -333,7 +333,7 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) { } // exists checks if client with this IP address already exists. -func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) { +func (clients *clientsContainer) exists(ip netip.Addr, source clientSource) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() @@ -342,7 +342,7 @@ func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool return true } - rc, ok := clients.findRuntimeClientLocked(ip) + rc, ok := clients.ipToRC[ip] if !ok { return false } @@ -371,7 +371,8 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client, var artClient *querylog.Client var art bool for _, id := range ids { - c, art = clients.clientOrArtificial(net.ParseIP(id), id) + ip, _ := netip.ParseAddr(id) + c, art = clients.clientOrArtificial(ip, id) if art { artClient = c @@ -389,7 +390,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client, // records about this client besides maybe whether or not it is blocked. c is // never nil. func (clients *clientsContainer) clientOrArtificial( - ip net.IP, + ip netip.Addr, id string, ) (c *querylog.Client, art bool) { defer func() { @@ -406,13 +407,6 @@ func (clients *clientsContainer) clientOrArtificial( }, false } - if ip == nil { - // Technically should never happen, but still. - return &querylog.Client{ - Name: "", - }, true - } - var rc *RuntimeClient rc, ok = clients.findRuntimeClient(ip) if ok { @@ -492,19 +486,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) { return c, true } - ip := net.ParseIP(id) - if ip == nil { + ip, err := netip.ParseAddr(id) + if err != nil { return nil, false } for _, c = range clients.list { for _, id := range c.IDs { - _, ipnet, err := net.ParseCIDR(id) + var n netip.Prefix + n, err = netip.ParsePrefix(id) if err != nil { continue } - if ipnet.Contains(ip) { + if n.Contains(ip) { return c, true } } @@ -514,19 +509,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) { return nil, false } - macFound := clients.dhcpServer.FindMACbyIP(ip) + macFound := clients.dhcpServer.FindMACbyIP(ip.AsSlice()) if macFound == nil { return nil, false } for _, c = range clients.list { for _, id := range c.IDs { - hwAddr, err := net.ParseMAC(id) + var mac net.HardwareAddr + mac, err = net.ParseMAC(id) if err != nil { continue } - if bytes.Equal(hwAddr, macFound) { + if bytes.Equal(mac, macFound) { return c, true } } @@ -535,32 +531,18 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) { return nil, false } -// findRuntimeClientLocked finds a runtime client by their IP address. For -// internal use only. -func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) { - // TODO(a.garipov): Remove once we switch to netip.Addr more fully. - ipAddr, err := netutil.IPToAddrNoMapped(ip) - if err != nil { - log.Error("clients: bad client ip %v: %s", ip, err) - - return nil, false - } - - rc, ok = clients.ipToRC[ipAddr] - - return rc, ok -} - // findRuntimeClient finds a runtime client by their IP. -func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) { - if ip == nil { +func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) { + if ip == (netip.Addr{}) { return nil, false } clients.lock.Lock() defer clients.lock.Unlock() - return clients.findRuntimeClientLocked(ip) + rc, ok = clients.ipToRC[ip] + + return rc, ok } // check validates the client. @@ -578,14 +560,16 @@ func (clients *clientsContainer) check(c *Client) (err error) { for i, id := range c.IDs { // Normalize structured data. - var ip net.IP - var ipnet *net.IPNet - var mac net.HardwareAddr - if ip = net.ParseIP(id); ip != nil { + var ( + ip netip.Addr + n netip.Prefix + mac net.HardwareAddr + ) + + if ip, err = netip.ParseAddr(id); err == nil { c.IDs[i] = ip.String() - } else if ip, ipnet, err = net.ParseCIDR(id); err == nil { - ipnet.IP = ip - c.IDs[i] = ipnet.String() + } else if n, err = netip.ParsePrefix(id); err == nil { + c.IDs[i] = n.String() } else if mac, err = net.ParseMAC(id); err == nil { c.IDs[i] = mac.String() } else if err = dnsforward.ValidateClientID(id); err == nil { @@ -750,7 +734,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) { } // setWHOISInfo sets the WHOIS information for a client. -func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) { +func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) { clients.lock.Lock() defer clients.lock.Unlock() @@ -760,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI return } - rc, ok := clients.findRuntimeClientLocked(ip) + rc, ok := clients.ipToRC[ip] if ok { rc.WHOISInfo = wi log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi) @@ -776,32 +760,22 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI rc.WHOISInfo = wi - // TODO(a.garipov): Remove once we switch to netip.Addr more fully. - ipAddr, err := netutil.IPToAddrNoMapped(ip) - if err != nil { - log.Error("clients: bad client ip %v: %s", ip, err) - - return - } - - clients.ipToRC[ipAddr] = rc + clients.ipToRC[ip] = rc log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi) } // AddHost adds a new IP-hostname pairing. The priorities of the sources are // taken into account. ok is true if the pairing was added. -func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) { +func (clients *clientsContainer) AddHost( + ip netip.Addr, + host string, + src clientSource, +) (ok bool) { clients.lock.Lock() defer clients.lock.Unlock() - // TODO(a.garipov): Remove once we switch to netip.Addr more fully. - ipAddr, err := netutil.IPToAddrNoMapped(ip) - if err != nil { - return false, fmt.Errorf("adding host: %w", err) - } - - return clients.addHostLocked(ipAddr, host, src), nil + return clients.addHostLocked(ip, host, src) } // addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be diff --git a/internal/home/clients_test.go b/internal/home/clients_test.go index 636971ebf40..ca98352a927 100644 --- a/internal/home/clients_test.go +++ b/internal/home/clients_test.go @@ -22,8 +22,18 @@ func TestClients(t *testing.T) { clients.Init(nil, nil, nil, nil) t.Run("add_success", func(t *testing.T) { + var ( + cliNone = "1.2.3.4" + cli1 = "1.1.1.1" + cli2 = "2.2.2.2" + + cliNoneIP = netip.MustParseAddr(cliNone) + cli1IP = netip.MustParseAddr(cli1) + cli2IP = netip.MustParseAddr(cli2) + ) + c := &Client{ - IDs: []string{"1.1.1.1", "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, + IDs: []string{cli1, "1:2:3::4", "aa:aa:aa:aa:aa:aa"}, Name: "client1", } @@ -33,7 +43,7 @@ func TestClients(t *testing.T) { assert.True(t, ok) c = &Client{ - IDs: []string{"2.2.2.2"}, + IDs: []string{cli2}, Name: "client2", } @@ -42,7 +52,7 @@ func TestClients(t *testing.T) { assert.True(t, ok) - c, ok = clients.Find("1.1.1.1") + c, ok = clients.Find(cli1) require.True(t, ok) assert.Equal(t, "client1", c.Name) @@ -52,14 +62,14 @@ func TestClients(t *testing.T) { assert.Equal(t, "client1", c.Name) - c, ok = clients.Find("2.2.2.2") + c, ok = clients.Find(cli2) require.True(t, ok) assert.Equal(t, "client2", c.Name) - assert.False(t, clients.exists(net.IP{1, 2, 3, 4}, ClientSourceHostsFile)) - assert.True(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) - assert.True(t, clients.exists(net.IP{2, 2, 2, 2}, ClientSourceHostsFile)) + assert.False(t, clients.exists(cliNoneIP, ClientSourceHostsFile)) + assert.True(t, clients.exists(cli1IP, ClientSourceHostsFile)) + assert.True(t, clients.exists(cli2IP, ClientSourceHostsFile)) }) t.Run("add_fail_name", func(t *testing.T) { @@ -103,23 +113,31 @@ func TestClients(t *testing.T) { }) t.Run("update_success", func(t *testing.T) { + var ( + cliOld = "1.1.1.1" + cliNew = "1.1.1.2" + + cliOldIP = netip.MustParseAddr(cliOld) + cliNewIP = netip.MustParseAddr(cliNew) + ) + err := clients.Update("client1", &Client{ - IDs: []string{"1.1.1.2"}, + IDs: []string{cliNew}, Name: "client1", }) require.NoError(t, err) - assert.False(t, clients.exists(net.IP{1, 1, 1, 1}, ClientSourceHostsFile)) - assert.True(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) + assert.False(t, clients.exists(cliOldIP, ClientSourceHostsFile)) + assert.True(t, clients.exists(cliNewIP, ClientSourceHostsFile)) err = clients.Update("client1", &Client{ - IDs: []string{"1.1.1.2"}, + IDs: []string{cliNew}, Name: "client1-renamed", UseOwnSettings: true, }) require.NoError(t, err) - c, ok := clients.Find("1.1.1.2") + c, ok := clients.Find(cliNew) require.True(t, ok) assert.Equal(t, "client1-renamed", c.Name) @@ -132,14 +150,14 @@ func TestClients(t *testing.T) { require.Len(t, c.IDs, 1) - assert.Equal(t, "1.1.1.2", c.IDs[0]) + assert.Equal(t, cliNew, c.IDs[0]) }) t.Run("del_success", func(t *testing.T) { ok := clients.Del("client1-renamed") require.True(t, ok) - assert.False(t, clients.exists(net.IP{1, 1, 1, 2}, ClientSourceHostsFile)) + assert.False(t, clients.exists(netip.MustParseAddr("1.1.1.2"), ClientSourceHostsFile)) }) t.Run("del_fail", func(t *testing.T) { @@ -148,45 +166,33 @@ func TestClients(t *testing.T) { }) t.Run("addhost_success", func(t *testing.T) { - ip := net.IP{1, 1, 1, 1} - - ok, err := clients.AddHost(ip, "host", ClientSourceARP) - require.NoError(t, err) - + ip := netip.MustParseAddr("1.1.1.1") + ok := clients.AddHost(ip, "host", ClientSourceARP) assert.True(t, ok) - ok, err = clients.AddHost(ip, "host2", ClientSourceARP) - require.NoError(t, err) - + ok = clients.AddHost(ip, "host2", ClientSourceARP) assert.True(t, ok) - ok, err = clients.AddHost(ip, "host3", ClientSourceHostsFile) - require.NoError(t, err) - + ok = clients.AddHost(ip, "host3", ClientSourceHostsFile) assert.True(t, ok) assert.True(t, clients.exists(ip, ClientSourceHostsFile)) }) t.Run("dhcp_replaces_arp", func(t *testing.T) { - ip := net.IP{1, 2, 3, 4} - - ok, err := clients.AddHost(ip, "from_arp", ClientSourceARP) - require.NoError(t, err) - + ip := netip.MustParseAddr("1.2.3.4") + ok := clients.AddHost(ip, "from_arp", ClientSourceARP) assert.True(t, ok) assert.True(t, clients.exists(ip, ClientSourceARP)) - ok, err = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP) - require.NoError(t, err) - + ok = clients.AddHost(ip, "from_dhcp", ClientSourceDHCP) assert.True(t, ok) assert.True(t, clients.exists(ip, ClientSourceDHCP)) }) t.Run("addhost_fail", func(t *testing.T) { - ok, err := clients.AddHost(net.IP{1, 1, 1, 1}, "host1", ClientSourceRDNS) - require.NoError(t, err) + ip := netip.MustParseAddr("1.1.1.1") + ok := clients.AddHost(ip, "host1", ClientSourceRDNS) assert.False(t, ok) }) } @@ -203,7 +209,7 @@ func TestClientsWHOIS(t *testing.T) { t.Run("new_client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.255") - clients.setWHOISInfo(ip.AsSlice(), whois) + clients.setWHOISInfo(ip, whois) rc := clients.ipToRC[ip] require.NotNil(t, rc) @@ -212,12 +218,10 @@ func TestClientsWHOIS(t *testing.T) { t.Run("existing_auto-client", func(t *testing.T) { ip := netip.MustParseAddr("1.1.1.1") - ok, err := clients.AddHost(ip.AsSlice(), "host", ClientSourceRDNS) - require.NoError(t, err) - + ok := clients.AddHost(ip, "host", ClientSourceRDNS) assert.True(t, ok) - clients.setWHOISInfo(ip.AsSlice(), whois) + clients.setWHOISInfo(ip, whois) rc := clients.ipToRC[ip] require.NotNil(t, rc) @@ -234,7 +238,7 @@ func TestClientsWHOIS(t *testing.T) { require.NoError(t, err) assert.True(t, ok) - clients.setWHOISInfo(ip.AsSlice(), whois) + clients.setWHOISInfo(ip, whois) rc := clients.ipToRC[ip] require.Nil(t, rc) @@ -249,7 +253,7 @@ func TestClientsAddExisting(t *testing.T) { clients.Init(nil, nil, nil, nil) t.Run("simple", func(t *testing.T) { - ip := net.IP{1, 1, 1, 1} + ip := netip.MustParseAddr("1.1.1.1") // Add a client. ok, err := clients.Add(&Client{ @@ -260,8 +264,7 @@ func TestClientsAddExisting(t *testing.T) { assert.True(t, ok) // Now add an auto-client with the same IP. - ok, err = clients.AddHost(ip, "test", ClientSourceRDNS) - require.NoError(t, err) + ok = clients.AddHost(ip, "test", ClientSourceRDNS) assert.True(t, ok) }) diff --git a/internal/home/clientshttp.go b/internal/home/clientshttp.go index 313fd998efe..d9bb0ea9f1f 100644 --- a/internal/home/clientshttp.go +++ b/internal/home/clientshttp.go @@ -3,8 +3,8 @@ package home import ( "encoding/json" "fmt" - "net" "net/http" + "net/netip" "github.com/AdguardTeam/AdGuardHome/internal/aghhttp" ) @@ -47,8 +47,8 @@ type runtimeClientJSON struct { WHOISInfo *RuntimeClientWHOISInfo `json:"whois_info"` Name string `json:"name"` + IP netip.Addr `json:"ip"` Source clientSource `json:"source"` - IP net.IP `json:"ip"` } type clientListJSON struct { @@ -75,7 +75,7 @@ func (clients *clientsContainer) handleGetClients(w http.ResponseWriter, r *http Name: rc.Host, Source: rc.Source, - IP: ip.AsSlice(), + IP: ip, } data.RuntimeClients = append(data.RuntimeClients, cj) @@ -218,7 +218,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http break } - ip := net.ParseIP(idStr) + ip, _ := netip.ParseAddr(idStr) c, ok := clients.Find(idStr) var cj *clientJSON if !ok { @@ -240,7 +240,7 @@ func (clients *clientsContainer) handleFindClient(w http.ResponseWriter, r *http // findRuntime looks up the IP in runtime and temporary storages, like // /etc/hosts tables, DHCP leases, or blocklists. cj is guaranteed to be // non-nil. -func (clients *clientsContainer) findRuntime(ip net.IP, idStr string) (cj *clientJSON) { +func (clients *clientsContainer) findRuntime(ip netip.Addr, idStr string) (cj *clientJSON) { rc, ok := clients.findRuntimeClient(ip) if !ok { // It is still possible that the IP used to be in the runtime clients diff --git a/internal/home/control.go b/internal/home/control.go index 5e4e6df2cc8..f9e7d4d2afa 100644 --- a/internal/home/control.go +++ b/internal/home/control.go @@ -71,9 +71,7 @@ func appendDNSAddrsWithIfaces(dst []string, src []netip.Addr) (res []string, err // on, including the addresses on all interfaces in cases of unspecified IPs. func collectDNSAddresses() (addrs []string, err error) { if hosts := config.DNS.BindHosts; len(hosts) == 0 { - addr := aghnet.IPv4Localhost() - - addrs = appendDNSAddrs(addrs, addr) + addrs = appendDNSAddrs(addrs, netutil.IPv4Localhost()) } else { addrs, err = appendDNSAddrsWithIfaces(addrs, hosts) if err != nil { diff --git a/internal/home/dns.go b/internal/home/dns.go index 7d40ed351fe..1980b252392 100644 --- a/internal/home/dns.go +++ b/internal/home/dns.go @@ -8,6 +8,7 @@ import ( "os" "path/filepath" + "github.com/AdguardTeam/AdGuardHome/internal/aghalg" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" "github.com/AdguardTeam/AdGuardHome/internal/dnsforward" "github.com/AdguardTeam/AdGuardHome/internal/filtering" @@ -150,8 +151,8 @@ func isRunning() bool { } func onDNSRequest(pctx *proxy.DNSContext) { - ip, _ := netutil.IPAndPortFromAddr(pctx.Addr) - if ip == nil { + ip := netutil.NetAddrToAddrPort(pctx.Addr).Addr() + if ip == (netip.Addr{}) { // This would be quite weird if we get here. return } @@ -160,7 +161,8 @@ func onDNSRequest(pctx *proxy.DNSContext) { if srcs.RDNS && !ip.IsLoopback() { Context.rdns.Begin(ip) } - if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) { + + if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) { Context.whois.Begin(ip) } } @@ -193,11 +195,7 @@ func ipsToUDPAddrs(ips []netip.Addr, port int) (udpAddrs []*net.UDPAddr) { func generateServerConfig() (newConf dnsforward.ServerConfig, err error) { dnsConf := config.DNS - hosts := dnsConf.BindHosts - if len(hosts) == 0 { - hosts = []netip.Addr{aghnet.IPv4Localhost()} - } - + hosts := aghalg.CoalesceSlice(dnsConf.BindHosts, []netip.Addr{netutil.IPv4Localhost()}) newConf = dnsforward.ServerConfig{ UDPListenAddrs: ipsToUDPAddrs(hosts, dnsConf.Port), TCPListenAddrs: ipsToTCPAddrs(hosts, dnsConf.Port), @@ -400,15 +398,12 @@ func startDNSServer() error { const topClientsNumber = 100 // the number of clients to get for _, ip := range Context.stats.TopClientsIP(topClientsNumber) { - if ip == nil { - continue - } - srcs := config.Clients.Sources if srcs.RDNS && !ip.IsLoopback() { Context.rdns.Begin(ip) } - if srcs.WHOIS && !netutil.IsSpecialPurpose(ip) { + + if srcs.WHOIS && !netutil.IsSpecialPurposeAddr(ip) { Context.whois.Begin(ip) } } diff --git a/internal/home/home.go b/internal/home/home.go index 665c41ce417..61bdaba27f1 100644 --- a/internal/home/home.go +++ b/internal/home/home.go @@ -576,7 +576,7 @@ func checkPermissions() { } // We should check if AdGuard Home is able to bind to port 53 - err := aghnet.CheckPort("tcp", netip.AddrPortFrom(aghnet.IPv4Localhost(), defaultPortDNS)) + err := aghnet.CheckPort("tcp", netip.AddrPortFrom(netutil.IPv4Localhost(), defaultPortDNS)) if err != nil { if errors.Is(err, os.ErrPermission) { log.Fatal(`Permission check failed. diff --git a/internal/home/rdns.go b/internal/home/rdns.go index e44000b33ca..c6ce0f59636 100644 --- a/internal/home/rdns.go +++ b/internal/home/rdns.go @@ -2,7 +2,7 @@ package home import ( "encoding/binary" - "net" + "net/netip" "sync/atomic" "time" @@ -21,7 +21,7 @@ type RDNS struct { usePrivate uint32 // ipCh used to pass client's IP to rDNS workerLoop. - ipCh chan net.IP + ipCh chan netip.Addr // ipCache caches the IP addresses to be resolved by rDNS. The resolved // address stays here while it's inside clients. After leaving clients the @@ -50,7 +50,7 @@ func NewRDNS( EnableLRU: true, MaxCount: defaultRDNSCacheSize, }), - ipCh: make(chan net.IP, defaultRDNSIPChSize), + ipCh: make(chan netip.Addr, defaultRDNSIPChSize), } if usePrivate { rDNS.usePrivate = 1 @@ -80,9 +80,10 @@ func (r *RDNS) ensurePrivateCache() { // isCached returns true if ip is already cached and not expired yet. It also // caches it otherwise. -func (r *RDNS) isCached(ip net.IP) (ok bool) { +func (r *RDNS) isCached(ip netip.Addr) (ok bool) { + ipBytes := ip.AsSlice() now := uint64(time.Now().Unix()) - if expire := r.ipCache.Get(ip); len(expire) != 0 { + if expire := r.ipCache.Get(ipBytes); len(expire) != 0 { if binary.BigEndian.Uint64(expire) > now { return true } @@ -91,14 +92,14 @@ func (r *RDNS) isCached(ip net.IP) (ok bool) { // The cache entry either expired or doesn't exist. ttl := make([]byte, 8) binary.BigEndian.PutUint64(ttl, now+defaultRDNSCacheTTL) - r.ipCache.Set(ip, ttl) + r.ipCache.Set(ipBytes, ttl) return false } // Begin adds the ip to the resolving queue if it is not cached or already // resolved. -func (r *RDNS) Begin(ip net.IP) { +func (r *RDNS) Begin(ip netip.Addr) { r.ensurePrivateCache() if r.isCached(ip) || r.clients.exists(ip, ClientSourceRDNS) { @@ -107,9 +108,9 @@ func (r *RDNS) Begin(ip net.IP) { select { case r.ipCh <- ip: - log.Tracef("rdns: %q added to queue", ip) + log.Debug("rdns: %q added to queue", ip) default: - log.Tracef("rdns: queue is full") + log.Debug("rdns: queue is full") } } @@ -119,7 +120,7 @@ func (r *RDNS) workerLoop() { defer log.OnPanic("rdns") for ip := range r.ipCh { - host, err := r.exchanger.Exchange(ip) + host, err := r.exchanger.Exchange(ip.AsSlice()) if err != nil { log.Debug("rdns: resolving %q: %s", ip, err) @@ -128,8 +129,6 @@ func (r *RDNS) workerLoop() { continue } - // Don't handle any errors since AddHost doesn't return non-nil errors - // for now. - _, _ = r.clients.AddHost(ip, host, ClientSourceRDNS) + _ = r.clients.AddHost(ip, host, ClientSourceRDNS) } } diff --git a/internal/home/rdns_test.go b/internal/home/rdns_test.go index 9f90ce5ac58..8dc675b5d88 100644 --- a/internal/home/rdns_test.go +++ b/internal/home/rdns_test.go @@ -27,14 +27,14 @@ func TestRDNS_Begin(t *testing.T) { w := &bytes.Buffer{} aghtest.ReplaceLogWriter(t, w) - ip1234, ip1235 := net.IP{1, 2, 3, 4}, net.IP{1, 2, 3, 5} + ip1234, ip1235 := netip.MustParseAddr("1.2.3.4"), netip.MustParseAddr("1.2.3.5") testCases := []struct { cliIDIndex map[string]*Client - customChan chan net.IP + customChan chan netip.Addr name string wantLog string - req net.IP + ip netip.Addr wantCacheHit int wantCacheMiss int }{{ @@ -42,7 +42,7 @@ func TestRDNS_Begin(t *testing.T) { customChan: nil, name: "cached", wantLog: "", - req: ip1234, + ip: ip1234, wantCacheHit: 1, wantCacheMiss: 0, }, { @@ -50,7 +50,7 @@ func TestRDNS_Begin(t *testing.T) { customChan: nil, name: "not_cached", wantLog: "rdns: queue is full", - req: ip1235, + ip: ip1235, wantCacheHit: 0, wantCacheMiss: 1, }, { @@ -58,15 +58,15 @@ func TestRDNS_Begin(t *testing.T) { customChan: nil, name: "already_in_clients", wantLog: "", - req: ip1235, + ip: ip1235, wantCacheHit: 0, wantCacheMiss: 1, }, { cliIDIndex: map[string]*Client{}, - customChan: make(chan net.IP, 1), + customChan: make(chan netip.Addr, 1), name: "add_to_queue", wantLog: `rdns: "1.2.3.5" added to queue`, - req: ip1235, + ip: ip1235, wantCacheHit: 0, wantCacheMiss: 1, }} @@ -102,7 +102,7 @@ func TestRDNS_Begin(t *testing.T) { } t.Run(tc.name, func(t *testing.T) { - rdns.Begin(tc.req) + rdns.Begin(tc.ip) assert.Equal(t, tc.wantCacheHit, ipCache.Stats().Hit) assert.Equal(t, tc.wantCacheMiss, ipCache.Stats().Miss) assert.Contains(t, w.String(), tc.wantLog) @@ -179,8 +179,8 @@ func TestRDNS_WorkerLoop(t *testing.T) { w := &bytes.Buffer{} aghtest.ReplaceLogWriter(t, w) - localIP := net.IP{192, 168, 1, 1} - revIPv4, err := netutil.IPToReversedAddr(localIP) + localIP := netip.MustParseAddr("192.168.1.1") + revIPv4, err := netutil.IPToReversedAddr(localIP.AsSlice()) require.NoError(t, err) revIPv6, err := netutil.IPToReversedAddr(net.ParseIP("2a00:1450:400c:c06::93")) @@ -201,24 +201,24 @@ func TestRDNS_WorkerLoop(t *testing.T) { testCases := []struct { ups upstream.Upstream + cliIP netip.Addr wantLog string name string - cliIP net.IP }{{ ups: locUpstream, + cliIP: localIP, wantLog: "", name: "all_good", - cliIP: localIP, }, { ups: errUpstream, + cliIP: netip.MustParseAddr("192.168.1.2"), wantLog: `rdns: resolving "192.168.1.2": test upstream error`, name: "resolve_error", - cliIP: net.IP{192, 168, 1, 2}, }, { ups: locUpstream, + cliIP: netip.MustParseAddr("2a00:1450:400c:c06::93"), wantLog: "", name: "ipv6_good", - cliIP: net.ParseIP("2a00:1450:400c:c06::93"), }} for _, tc := range testCases { @@ -230,7 +230,7 @@ func TestRDNS_WorkerLoop(t *testing.T) { ipToRC: map[netip.Addr]*RuntimeClient{}, allTags: stringutil.NewSet(), } - ch := make(chan net.IP) + ch := make(chan netip.Addr) rdns := &RDNS{ exchanger: &rDNSExchanger{ ex: tc.ups, diff --git a/internal/home/whois.go b/internal/home/whois.go index c9834708a37..9ffee9e0d9a 100644 --- a/internal/home/whois.go +++ b/internal/home/whois.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "net/netip" "strings" "time" @@ -26,7 +27,7 @@ const ( // WHOIS - module context type WHOIS struct { clients *clientsContainer - ipChan chan net.IP + ipChan chan netip.Addr // dialContext specifies the dial function for creating unencrypted TCP // connections. @@ -51,7 +52,7 @@ func initWHOIS(clients *clientsContainer) *WHOIS { MaxCount: 10000, }), dialContext: customDialContext, - ipChan: make(chan net.IP, 255), + ipChan: make(chan netip.Addr, 255), } go w.workerLoop() @@ -192,7 +193,7 @@ func (w *WHOIS) queryAll(ctx context.Context, target string) (string, error) { } // Request WHOIS information -func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISInfo) { +func (w *WHOIS) process(ctx context.Context, ip netip.Addr) (wi *RuntimeClientWHOISInfo) { resp, err := w.queryAll(ctx, ip.String()) if err != nil { log.Debug("whois: error: %s IP:%s", err, ip) @@ -220,24 +221,25 @@ func (w *WHOIS) process(ctx context.Context, ip net.IP) (wi *RuntimeClientWHOISI } // Begin - begin requesting WHOIS info -func (w *WHOIS) Begin(ip net.IP) { +func (w *WHOIS) Begin(ip netip.Addr) { + ipBytes := ip.AsSlice() now := uint64(time.Now().Unix()) - expire := w.ipAddrs.Get([]byte(ip)) + expire := w.ipAddrs.Get(ipBytes) if len(expire) != 0 { exp := binary.BigEndian.Uint64(expire) if exp > now { return } - // TTL expired } + expire = make([]byte, 8) binary.BigEndian.PutUint64(expire, now+whoisTTL) - _ = w.ipAddrs.Set([]byte(ip), expire) + _ = w.ipAddrs.Set(ipBytes, expire) log.Debug("whois: adding %s", ip) + select { case w.ipChan <- ip: - // default: log.Debug("whois: queue is full") } diff --git a/internal/stats/stats.go b/internal/stats/stats.go index e939790775b..0ac8d9be837 100644 --- a/internal/stats/stats.go +++ b/internal/stats/stats.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "net/netip" "os" "sync" "sync/atomic" @@ -64,7 +65,7 @@ type Interface interface { // GetTopClientIP returns at most limit IP addresses corresponding to the // clients with the most number of requests. - TopClientsIP(limit uint) []net.IP + TopClientsIP(limit uint) []netip.Addr // WriteDiskConfig puts the Interface's configuration to the dc. WriteDiskConfig(dc *DiskConfig) @@ -107,8 +108,6 @@ type StatsCtx struct { filename string } -var _ Interface = &StatsCtx{} - // New creates s from conf and properly initializes it. Don't use s before // calling it's Start method. func New(conf Config) (s *StatsCtx, err error) { @@ -178,6 +177,9 @@ func withRecovered(orig *error) { *orig = errors.WithDeferred(*orig, err) } +// type check +var _ Interface = (*StatsCtx)(nil) + // Start implements the Interface interface for *StatsCtx. func (s *StatsCtx) Start() { s.initWeb() @@ -250,8 +252,8 @@ func (s *StatsCtx) WriteDiskConfig(dc *DiskConfig) { dc.Interval = atomic.LoadUint32(&s.limitHours) / 24 } -// TopClientsIP implements the Interface interface for *StatsCtx. -func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) { +// TopClientsIP implements the [Interface] interface for *StatsCtx. +func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []netip.Addr) { limit := atomic.LoadUint32(&s.limitHours) if limit == 0 { return nil @@ -271,10 +273,10 @@ func (s *StatsCtx) TopClientsIP(maxCount uint) (ips []net.IP) { } a := convertMapToSlice(m, int(maxCount)) - ips = []net.IP{} + ips = []netip.Addr{} for _, it := range a { - ip := net.ParseIP(it.Name) - if ip != nil { + ip, err := netip.ParseAddr(it.Name) + if err == nil { ips = append(ips, ip) } } diff --git a/internal/stats/stats_test.go b/internal/stats/stats_test.go index bb2cc0d8df7..cd88cbaf822 100644 --- a/internal/stats/stats_test.go +++ b/internal/stats/stats_test.go @@ -11,6 +11,7 @@ import ( "testing" "github.com/AdguardTeam/AdGuardHome/internal/stats" + "github.com/AdguardTeam/golibs/netutil" "github.com/AdguardTeam/golibs/testutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,7 +46,7 @@ func assertSuccessAndUnmarshal(t *testing.T, to any, handler http.Handler, req * } func TestStats(t *testing.T) { - cliIP := net.IP{127, 0, 0, 1} + cliIP := netutil.IPv4Localhost() cliIPStr := cliIP.String() handlers := map[string]http.Handler{} @@ -123,7 +124,7 @@ func TestStats(t *testing.T) { topClients := s.TopClientsIP(2) require.NotEmpty(t, topClients) - assert.True(t, cliIP.Equal(topClients[0])) + assert.Equal(t, cliIP, topClients[0]) }) t.Run("reset", func(t *testing.T) {