diff --git a/internal/aghstrings/set.go b/internal/aghstrings/set.go new file mode 100644 index 00000000000..81f031e1aaa --- /dev/null +++ b/internal/aghstrings/set.go @@ -0,0 +1,75 @@ +package aghstrings + +// unit is a convenient alias for struct{} +type unit = struct{} + +// Set is a set of strings. +type Set struct { + m map[string]unit +} + +// NewSet returns a new string set containing strs. +func NewSet(strs ...string) (set *Set) { + set = &Set{ + m: make(map[string]unit, len(strs)), + } + + for _, s := range strs { + set.Add(s) + } + + return set +} + +// Add adds s to the set. Add panics if the set is a nil set, just like a nil +// map does. +func (set *Set) Add(s string) { + set.m[s] = unit{} +} + +// Del deletes s from the set. Calling Del on a nil set has no effect, just +// like delete on an empty map doesn't. +func (set *Set) Del(s string) { + if set == nil { + return + } + + delete(set.m, s) +} + +// Has returns true if s is in the set. Calling Has on a nil set returns false, +// just like indexing on an empty map does. +func (set *Set) Has(s string) (ok bool) { + if set == nil { + return false + } + + _, ok = set.m[s] + + return ok +} + +// Len returns the length of the set. A nil set has a length of zero, just like +// an empty map. +func (set *Set) Len() (n int) { + if set == nil { + return 0 + } + + return len(set.m) +} + +// Values returns all values in the set. The order of the values is undefined. +// Values returns nil if the set is nil. +func (set *Set) Values() (strs []string) { + if set == nil { + return nil + } + + strs = make([]string, 0, len(set.m)) + for s := range set.m { + strs = append(strs, s) + } + + return strs +} diff --git a/internal/aghstrings/set_test.go b/internal/aghstrings/set_test.go new file mode 100644 index 00000000000..a344e8afed1 --- /dev/null +++ b/internal/aghstrings/set_test.go @@ -0,0 +1,56 @@ +package aghstrings + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSet(t *testing.T) { + const s = "a" + + t.Run("nil", func(t *testing.T) { + var set *Set + + assert.NotPanics(t, func() { + set.Del(s) + }) + + assert.NotPanics(t, func() { + assert.False(t, set.Has(s)) + }) + + assert.NotPanics(t, func() { + assert.Equal(t, 0, set.Len()) + }) + + assert.NotPanics(t, func() { + assert.Nil(t, set.Values()) + }) + + assert.Panics(t, func() { + set.Add(s) + }) + }) + + t.Run("non_nil", func(t *testing.T) { + set := NewSet() + assert.Equal(t, 0, set.Len()) + + ok := set.Has(s) + assert.False(t, ok) + + set.Add(s) + ok = set.Has(s) + assert.True(t, ok) + + assert.Equal(t, []string{s}, set.Values()) + + set.Del(s) + ok = set.Has(s) + assert.False(t, ok) + + set = NewSet(s) + assert.Equal(t, 1, set.Len()) + }) +} diff --git a/internal/dhcpd/iprange.go b/internal/dhcpd/iprange.go index 36ca28959f5..3591869a929 100644 --- a/internal/dhcpd/iprange.go +++ b/internal/dhcpd/iprange.go @@ -110,3 +110,8 @@ func (r *ipRange) offset(ip net.IP) (offset uint64, ok bool) { // construction. return offsetInt.Uint64(), true } + +// String implements the fmt.Stringer interface for *ipRange. +func (r *ipRange) String() (s string) { + return fmt.Sprintf("%s-%s", r.start, r.end) +} diff --git a/internal/dhcpd/server.go b/internal/dhcpd/server.go index 43caa308299..51395535afe 100644 --- a/internal/dhcpd/server.go +++ b/internal/dhcpd/server.go @@ -64,9 +64,12 @@ type V4ServerConf struct { leaseTime time.Duration // the time during which a dynamic lease is considered valid dnsIPAddrs []net.IP // IPv4 addresses to return to DHCP clients as DNS server addresses - routerIP net.IP // value for Option Router - subnetMask net.IPMask // value for Option SubnetMask - options []dhcpOption + + // subnet contains the DHCP server's subnet. The IP is the IP of the + // gateway. + subnet *net.IPNet + + options []dhcpOption // notify is a way to signal to other components that leases have // change. notify must be called outside of locked sections, since the diff --git a/internal/dhcpd/v4.go b/internal/dhcpd/v4.go index 5d8f5359604..1148b7e92e3 100644 --- a/internal/dhcpd/v4.go +++ b/internal/dhcpd/v4.go @@ -12,6 +12,7 @@ import ( "github.com/AdguardTeam/AdGuardHome/internal/agherr" "github.com/AdguardTeam/AdGuardHome/internal/aghnet" + "github.com/AdguardTeam/AdGuardHome/internal/aghstrings" "github.com/AdguardTeam/golibs/log" "github.com/go-ping/ping" "github.com/insomniacslk/dhcp/dhcpv4" @@ -29,6 +30,9 @@ type v4Server struct { // leased. leasedOffsets *bitSet + // leaseHosts is the set of all the hosts of all known DHCP clients. + leaseHosts *aghstrings.Set + // leases contains all dynamic and static leases. leases []*Lease @@ -49,22 +53,16 @@ func (s *v4Server) WriteDiskConfig6(c *V6ServerConf) { func (s *v4Server) ResetLeases(leases []*Lease) { s.leases = nil - r := s.conf.ipRange for _, l := range leases { - if !l.IsStatic() && !r.contains(l.IP) { - log.Debug( - "dhcpv4: skipping lease %s (%s): not within current ip range", - l.IP, - l.HWAddr, - ) - - continue - } - err := s.addLease(l) if err != nil { // TODO(a.garipov): Better error handling. - log.Error("dhcpv4: adding a lease for %s (%s): %s", l.IP, l.HWAddr, err) + log.Error( + "dhcpv4: reset: re-adding a lease for %s (%s): %s", + l.IP, + l.HWAddr, + err, + ) continue } @@ -174,17 +172,14 @@ func (s *v4Server) rmLeaseByIndex(i int) { l := s.leases[i] s.leases = append(s.leases[:i], s.leases[i+1:]...) - n = len(s.leases) - if n > 0 { - s.leases = s.leases[:n-1] - } - r := s.conf.ipRange offset, ok := r.offset(l.IP) if ok { s.leasedOffsets.set(offset, false) } + s.leaseHosts.Del(l.Hostname) + log.Debug("dhcpv4: removed lease %s (%s)", l.IP, l.HWAddr) } @@ -220,13 +215,8 @@ func (s *v4Server) rmDynamicLease(lease *Lease) (err error) { } func (s *v4Server) addStaticLease(l *Lease) (err error) { - subnet := &net.IPNet{ - IP: s.conf.routerIP, - Mask: s.conf.subnetMask, - } - - if !subnet.Contains(l.IP) { - return fmt.Errorf("subnet %s does not contain the ip %q", subnet, l.IP) + if sn := s.conf.subnet; !sn.Contains(l.IP) { + return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP) } s.leases = append(s.leases, l) @@ -237,6 +227,8 @@ func (s *v4Server) addStaticLease(l *Lease) (err error) { s.leasedOffsets.set(offset, true) } + s.leaseHosts.Add(l.Hostname) + return nil } @@ -248,6 +240,7 @@ func (s *v4Server) addDynamicLease(l *Lease) (err error) { } s.leases = append(s.leases, l) + s.leaseHosts.Add(l.Hostname) s.leasedOffsets.set(offset, true) return nil @@ -255,6 +248,11 @@ func (s *v4Server) addDynamicLease(l *Lease) (err error) { // addLease adds a dynamic or static lease. func (s *v4Server) addLease(l *Lease) (err error) { + err = s.validateLease(l) + if err != nil { + return err + } + if l.IsStatic() { return s.addStaticLease(l) } @@ -291,13 +289,13 @@ func (s *v4Server) AddStaticLease(l Lease) (err error) { return fmt.Errorf("invalid ip %q, only ipv4 is supported", l.IP) } - err = aghnet.ValidateHardwareAddress(l.HWAddr) + l.Expiry = time.Unix(leaseExpireStatic, 0) + + l.Hostname, err = normalizeHostname(l.Hostname) if err != nil { - return fmt.Errorf("validating lease: %w", err) + return err } - l.Expiry = time.Unix(leaseExpireStatic, 0) - // Perform the following actions in an anonymous function to make sure // that the lock gets unlocked before the notification step. func() { @@ -564,7 +562,8 @@ func normalizeHostname(name string) (norm string, err error) { return "", nil } - parts := strings.FieldsFunc(name, func(c rune) (ok bool) { + norm = strings.ToLower(name) + parts := strings.FieldsFunc(norm, func(c rune) (ok bool) { return c != '.' && !aghnet.IsValidHostOuterRune(c) }) @@ -580,19 +579,49 @@ func normalizeHostname(name string) (norm string, err error) { // validateHostname validates a hostname sent by the client. func (s *v4Server) validateHostname(name string) (err error) { + defer agherr.Annotate("validating hostname: %s", &err) + if name == "" { return nil } err = aghnet.ValidateDomainName(name) if err != nil { - return fmt.Errorf("validating hostname: %w", err) + return err } - // TODO(a.garipov): Add client hostname uniqueness validation either - // here or into method processRequest. This is not as easy as it might - // look like, because the process of adding and releasing a lease is - // currently non-straightforward. + if s.leaseHosts.Has(name) { + return agherr.Error("hostname exists") + } + + return nil +} + +// validateLease returns an error if the lease is invalid. +func (s *v4Server) validateLease(l *Lease) (err error) { + defer agherr.Annotate("validating lease: %s", &err) + if l == nil { + return agherr.Error("lease is nil") + } + + err = aghnet.ValidateHardwareAddress(l.HWAddr) + if err != nil { + return err + } + + err = s.validateHostname(l.Hostname) + if err != nil { + return err + } + + if sn := s.conf.subnet; !sn.Contains(l.IP) { + return fmt.Errorf("subnet %s does not contain the ip %q", sn, l.IP) + } + + r := s.conf.ipRange + if !l.IsStatic() && !r.contains(l.IP) { + return fmt.Errorf("dynamic lease range %s does not contain the ip %q", r, l.IP) + } return nil } @@ -655,15 +684,27 @@ func (s *v4Server) processRequest(req, resp *dhcpv4.DHCPv4) (lease *Lease, ok bo // Go on and assign a hostname made from the IP. } - if hostname != "" && cliHostname != hostname { - log.Debug("dhcpv4: normalized hostname %q into %q", cliHostname, hostname) - } + if hostname != "" { + if cliHostname != hostname { + log.Debug( + "dhcpv4: normalized hostname %q into %q", + cliHostname, + hostname, + ) + } - err = s.validateHostname(hostname) - if err != nil { - log.Error("dhcpv4: validating hostname for %s: %s", mac, err) + if lease.Hostname != hostname { + // Either a new lease or an old lease with a new + // hostname, so validate. + err = s.validateHostname(hostname) + if err != nil { + log.Error("dhcpv4: validating %s: %s", mac, err) - // Go on and assign a hostname made from the IP. + // Go on and assign a hostname made from + // the IP below. + hostname = "" + } + } } if hostname == "" { @@ -726,8 +767,8 @@ func (s *v4Server) process(req, resp *dhcpv4.DHCPv4) int { copy(resp.YourIPAddr, l.IP) resp.UpdateOption(dhcpv4.OptIPAddressLeaseTime(s.conf.leaseTime)) - resp.UpdateOption(dhcpv4.OptRouter(s.conf.routerIP)) - resp.UpdateOption(dhcpv4.OptSubnetMask(s.conf.subnetMask)) + resp.UpdateOption(dhcpv4.OptRouter(s.conf.subnet.IP)) + resp.UpdateOption(dhcpv4.OptSubnetMask(s.conf.subnet.Mask)) resp.UpdateOption(dhcpv4.OptDNS(s.conf.dnsIPAddrs...)) for _, opt := range s.conf.options { @@ -855,6 +896,7 @@ func (s *v4Server) Stop() { func v4Create(conf V4ServerConf) (srv DHCPServer, err error) { s := &v4Server{} s.conf = conf + s.leaseHosts = aghstrings.NewSet() // TODO(a.garipov): Don't use a disabled server in other places or just // use an interface. @@ -862,7 +904,8 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) { return s, nil } - s.conf.routerIP, err = tryTo4(s.conf.GatewayIP) + var routerIP net.IP + routerIP, err = tryTo4(s.conf.GatewayIP) if err != nil { return s, fmt.Errorf("dhcpv4: %w", err) } @@ -870,8 +913,14 @@ func v4Create(conf V4ServerConf) (srv DHCPServer, err error) { if s.conf.SubnetMask == nil { return s, fmt.Errorf("dhcpv4: invalid subnet mask: %v", s.conf.SubnetMask) } - s.conf.subnetMask = make([]byte, 4) - copy(s.conf.subnetMask, s.conf.SubnetMask.To4()) + + subnetMask := make([]byte, 4) + copy(subnetMask, s.conf.SubnetMask.To4()) + + s.conf.subnet = &net.IPNet{ + IP: routerIP, + Mask: subnetMask, + } s.conf.ipRange, err = newIPRange(conf.RangeStart, conf.RangeEnd) if err != nil { diff --git a/internal/dhcpd/v4_test.go b/internal/dhcpd/v4_test.go index 8afbdf1dda2..303b6cb8e93 100644 --- a/internal/dhcpd/v4_test.go +++ b/internal/dhcpd/v4_test.go @@ -157,7 +157,7 @@ func TestV4StaticLease_Get(t *testing.T) { assert.True(t, l.IP.Equal(resp.YourIPAddr)) assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) - assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) }) @@ -179,7 +179,7 @@ func TestV4StaticLease_Get(t *testing.T) { assert.True(t, l.IP.Equal(resp.YourIPAddr)) assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) - assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) }) @@ -246,7 +246,7 @@ func TestV4DynamicLease_Get(t *testing.T) { assert.Equal(t, s.conf.GatewayIP, router[0]) - assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) assert.Equal(t, []byte("012"), resp.Options[uint8(dhcpv4.OptionFQDN)]) @@ -271,7 +271,7 @@ func TestV4DynamicLease_Get(t *testing.T) { assert.True(t, s.conf.RangeStart.Equal(resp.YourIPAddr)) assert.True(t, s.conf.GatewayIP.Equal(resp.Router()[0])) assert.True(t, s.conf.GatewayIP.Equal(resp.ServerIdentifier())) - assert.Equal(t, s.conf.subnetMask, resp.SubnetMask()) + assert.Equal(t, s.conf.subnet.Mask, resp.SubnetMask()) assert.Equal(t, s.conf.leaseTime.Seconds(), resp.IPAddressLeaseTime(-1).Seconds()) })