Skip to content

Commit

Permalink
all: move more client code to netip.Addr
Browse files Browse the repository at this point in the history
  • Loading branch information
ainar-g committed Nov 9, 2022
1 parent 98af0e0 commit 1934ea1
Show file tree
Hide file tree
Showing 17 changed files with 164 additions and 210 deletions.
6 changes: 0 additions & 6 deletions internal/aghnet/net.go
Expand Up @@ -31,12 +31,6 @@ var (
// the IP being static is available.
const ErrNoStaticIPInfo errors.Error = "no information about static ip"

// IPv4Localhost returns 127.0.0.1, which returns true for [netip.Addr.Is4].
func IPv4Localhost() (ip netip.Addr) { return netip.AddrFrom4([4]byte{127, 0, 0, 1}) }

// IPv6Localhost returns ::1, which returns true for [netip.Addr.Is6].
func IPv6Localhost() (ip netip.Addr) { return netip.AddrFrom16([16]byte{15: 1}) }

// IfaceHasStaticIP checks if interface is configured to have static IP address.
// If it can't give a definitive answer, it returns false and an error for which
// errors.Is(err, ErrNoStaticIPInfo) is true.
Expand Down
2 changes: 1 addition & 1 deletion internal/aghnet/net_test.go
Expand Up @@ -188,7 +188,7 @@ func TestBroadcastFromIPNet(t *testing.T) {
}

func TestCheckPort(t *testing.T) {
laddr := netip.AddrPortFrom(IPv4Localhost(), 0)
laddr := netip.AddrPortFrom(netutil.IPv4Localhost(), 0)

t.Run("tcp_bound", func(t *testing.T) {
l, err := net.Listen("tcp", laddr.String())
Expand Down
12 changes: 1 addition & 11 deletions internal/dnsforward/clientid.go
Expand Up @@ -23,16 +23,6 @@ func ValidateClientID(id string) (err error) {
return nil
}

// hasLabelSuffix returns true if s ends with suffix preceded by a dot. It's
// a helper function to prevent unnecessary allocations in code like:
//
// if strings.HasSuffix(s, "." + suffix) { /* … */ }
//
// s must be longer than suffix.
func hasLabelSuffix(s, suffix string) (ok bool) {
return strings.HasSuffix(s, suffix) && s[len(s)-len(suffix)-1] == '.'
}

// clientIDFromClientServerName extracts and validates a ClientID. hostSrvName
// is the server name of the host. cliSrvName is the server name as sent by the
// client. When strict is true, and client and host server name don't match,
Expand All @@ -46,7 +36,7 @@ func clientIDFromClientServerName(
return "", nil
}

if !hasLabelSuffix(cliSrvName, hostSrvName) {
if !netutil.IsImmediateSubdomain(cliSrvName, hostSrvName) {
if !strict {
return "", nil
}
Expand Down
18 changes: 7 additions & 11 deletions internal/dnsforward/dnsforward.go
Expand Up @@ -246,6 +246,7 @@ type RDNSExchanger interface {
// Exchange tries to resolve the ip in a suitable way, e.g. either as
// local or as external.
Exchange(ip net.IP) (host string, err error)

// ResolvesPrivatePTR returns true if the RDNSExchanger is able to
// resolve PTR requests for locally-served addresses.
ResolvesPrivatePTR() (ok bool)
Expand All @@ -261,6 +262,9 @@ const (
rDNSNotPTRErr errors.Error = "the response is not a ptr"
)

// type check
var _ RDNSExchanger = (*Server)(nil)

// Exchange implements the RDNSExchanger interface for *Server.
func (s *Server) Exchange(ip net.IP) (host string, err error) {
s.serverLock.RLock()
Expand Down Expand Up @@ -675,21 +679,13 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {

// IsBlockedClient returns true if the client is blocked by the current access
// settings.
func (s *Server) IsBlockedClient(ip net.IP, clientID string) (blocked bool, rule string) {
func (s *Server) IsBlockedClient(ip netip.Addr, clientID string) (blocked bool, rule string) {
s.serverLock.RLock()
defer s.serverLock.RUnlock()

blockedByIP := false
if ip != nil {
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("dnsforward: bad client ip %v: %s", ip, err)

return false, ""
}

blockedByIP, rule = s.access.isBlockedIP(ipAddr)
if ip != (netip.Addr{}) {
blockedByIP, rule = s.access.isBlockedIP(ip)
}

allowlistMode := s.access.allowlistMode()
Expand Down
4 changes: 2 additions & 2 deletions internal/dnsforward/filter.go
Expand Up @@ -19,13 +19,13 @@ func (s *Server) beforeRequestHandler(
_ *proxy.Proxy,
pctx *proxy.DNSContext,
) (reply bool, err error) {
ip, _ := netutil.IPAndPortFromAddr(pctx.Addr)
clientID, err := s.clientIDFromDNSContext(pctx)
if err != nil {
return false, fmt.Errorf("getting clientid: %w", err)
}

blocked, _ := s.IsBlockedClient(ip, clientID)
addrPort := netutil.NetAddrToAddrPort(pctx.Addr)
blocked, _ := s.IsBlockedClient(addrPort.Addr(), clientID)
if blocked {
return s.preBlockedResponse(pctx)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/filtering/filter_test.go
Expand Up @@ -11,7 +11,7 @@ import (
"testing"
"time"

"github.com/AdguardTeam/AdGuardHome/internal/aghnet"
"github.com/AdguardTeam/golibs/netutil"
"github.com/AdguardTeam/golibs/testutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -40,7 +40,7 @@ func serveFiltersLocally(t *testing.T, fltContent []byte) (ipp netip.AddrPort) {
addr := l.Addr()
require.IsType(t, new(net.TCPAddr), addr)

return netip.AddrPortFrom(aghnet.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
return netip.AddrPortFrom(netutil.IPv4Localhost(), uint16(addr.(*net.TCPAddr).Port))
}

func TestFilters(t *testing.T) {
Expand Down
102 changes: 38 additions & 64 deletions internal/home/clients.go
Expand Up @@ -129,7 +129,7 @@ type RuntimeClientWHOISInfo struct {

type clientsContainer struct {
// TODO(a.garipov): Perhaps use a number of separate indices for
// different types (string, net.IP, and so on).
// different types (string, netip.Addr, and so on).
list map[string]*Client // name -> client
idIndex map[string]*Client // ID -> client

Expand Down Expand Up @@ -333,7 +333,7 @@ func (clients *clientsContainer) onDHCPLeaseChanged(flags int) {
}

// exists checks if client with this IP address already exists.
func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool) {
func (clients *clientsContainer) exists(ip netip.Addr, source clientSource) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()

Expand All @@ -342,7 +342,7 @@ func (clients *clientsContainer) exists(ip net.IP, source clientSource) (ok bool
return true
}

rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok := clients.ipToRC[ip]
if !ok {
return false
}
Expand Down Expand Up @@ -371,7 +371,8 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
var artClient *querylog.Client
var art bool
for _, id := range ids {
c, art = clients.clientOrArtificial(net.ParseIP(id), id)
ip, _ := netip.ParseAddr(id)
c, art = clients.clientOrArtificial(ip, id)
if art {
artClient = c

Expand All @@ -389,7 +390,7 @@ func (clients *clientsContainer) findMultiple(ids []string) (c *querylog.Client,
// records about this client besides maybe whether or not it is blocked. c is
// never nil.
func (clients *clientsContainer) clientOrArtificial(
ip net.IP,
ip netip.Addr,
id string,
) (c *querylog.Client, art bool) {
defer func() {
Expand All @@ -406,13 +407,6 @@ func (clients *clientsContainer) clientOrArtificial(
}, false
}

if ip == nil {
// Technically should never happen, but still.
return &querylog.Client{
Name: "",
}, true
}

var rc *RuntimeClient
rc, ok = clients.findRuntimeClient(ip)
if ok {
Expand Down Expand Up @@ -492,19 +486,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return c, true
}

ip := net.ParseIP(id)
if ip == nil {
ip, err := netip.ParseAddr(id)
if err != nil {
return nil, false
}

for _, c = range clients.list {
for _, id := range c.IDs {
_, ipnet, err := net.ParseCIDR(id)
var n netip.Prefix
n, err = netip.ParsePrefix(id)
if err != nil {
continue
}

if ipnet.Contains(ip) {
if n.Contains(ip) {
return c, true
}
}
Expand All @@ -514,19 +509,20 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}

macFound := clients.dhcpServer.FindMACbyIP(ip)
macFound := clients.dhcpServer.FindMACbyIP(ip.AsSlice())
if macFound == nil {
return nil, false
}

for _, c = range clients.list {
for _, id := range c.IDs {
hwAddr, err := net.ParseMAC(id)
var mac net.HardwareAddr
mac, err = net.ParseMAC(id)
if err != nil {
continue
}

if bytes.Equal(hwAddr, macFound) {
if bytes.Equal(mac, macFound) {
return c, true
}
}
Expand All @@ -535,32 +531,18 @@ func (clients *clientsContainer) findLocked(id string) (c *Client, ok bool) {
return nil, false
}

// findRuntimeClientLocked finds a runtime client by their IP address. For
// internal use only.
func (clients *clientsContainer) findRuntimeClientLocked(ip net.IP) (rc *RuntimeClient, ok bool) {
// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("clients: bad client ip %v: %s", ip, err)

return nil, false
}

rc, ok = clients.ipToRC[ipAddr]

return rc, ok
}

// findRuntimeClient finds a runtime client by their IP.
func (clients *clientsContainer) findRuntimeClient(ip net.IP) (rc *RuntimeClient, ok bool) {
if ip == nil {
func (clients *clientsContainer) findRuntimeClient(ip netip.Addr) (rc *RuntimeClient, ok bool) {
if ip == (netip.Addr{}) {
return nil, false
}

clients.lock.Lock()
defer clients.lock.Unlock()

return clients.findRuntimeClientLocked(ip)
rc, ok = clients.ipToRC[ip]

return rc, ok
}

// check validates the client.
Expand All @@ -578,14 +560,16 @@ func (clients *clientsContainer) check(c *Client) (err error) {

for i, id := range c.IDs {
// Normalize structured data.
var ip net.IP
var ipnet *net.IPNet
var mac net.HardwareAddr
if ip = net.ParseIP(id); ip != nil {
var (
ip netip.Addr
n netip.Prefix
mac net.HardwareAddr
)

if ip, err = netip.ParseAddr(id); err == nil {
c.IDs[i] = ip.String()
} else if ip, ipnet, err = net.ParseCIDR(id); err == nil {
ipnet.IP = ip
c.IDs[i] = ipnet.String()
} else if n, err = netip.ParsePrefix(id); err == nil {
c.IDs[i] = n.String()
} else if mac, err = net.ParseMAC(id); err == nil {
c.IDs[i] = mac.String()
} else if err = dnsforward.ValidateClientID(id); err == nil {
Expand Down Expand Up @@ -750,7 +734,7 @@ func (clients *clientsContainer) Update(name string, c *Client) (err error) {
}

// setWHOISInfo sets the WHOIS information for a client.
func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISInfo) {
func (clients *clientsContainer) setWHOISInfo(ip netip.Addr, wi *RuntimeClientWHOISInfo) {
clients.lock.Lock()
defer clients.lock.Unlock()

Expand All @@ -760,7 +744,7 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI
return
}

rc, ok := clients.findRuntimeClientLocked(ip)
rc, ok := clients.ipToRC[ip]
if ok {
rc.WHOISInfo = wi
log.Debug("clients: set whois info for runtime client %s: %+v", rc.Host, wi)
Expand All @@ -776,32 +760,22 @@ func (clients *clientsContainer) setWHOISInfo(ip net.IP, wi *RuntimeClientWHOISI

rc.WHOISInfo = wi

// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
log.Error("clients: bad client ip %v: %s", ip, err)

return
}

clients.ipToRC[ipAddr] = rc
clients.ipToRC[ip] = rc

log.Debug("clients: set whois info for runtime client with ip %s: %+v", ip, wi)
}

// AddHost adds a new IP-hostname pairing. The priorities of the sources are
// taken into account. ok is true if the pairing was added.
func (clients *clientsContainer) AddHost(ip net.IP, host string, src clientSource) (ok bool, err error) {
func (clients *clientsContainer) AddHost(
ip netip.Addr,
host string,
src clientSource,
) (ok bool) {
clients.lock.Lock()
defer clients.lock.Unlock()

// TODO(a.garipov): Remove once we switch to netip.Addr more fully.
ipAddr, err := netutil.IPToAddrNoMapped(ip)
if err != nil {
return false, fmt.Errorf("adding host: %w", err)
}

return clients.addHostLocked(ipAddr, host, src), nil
return clients.addHostLocked(ip, host, src)
}

// addHostLocked adds a new IP-hostname pairing. clients.lock is expected to be
Expand Down

0 comments on commit 1934ea1

Please sign in to comment.