diff --git a/README.md b/README.md index 6858e41..8b5470b 100644 --- a/README.md +++ b/README.md @@ -48,7 +48,7 @@ This also works with a IPv6 CIDR range, for example: ``` $ cidr count 2001:db8:1234:1a00::/106 -4194304 +4194302 ``` Or with a large prefix like a point-to-point link CIDR range: diff --git a/cmd/contains.go b/cmd/contains.go index 0a99dff..0f7fc49 100644 --- a/cmd/contains.go +++ b/cmd/contains.go @@ -34,7 +34,7 @@ var ( fmt.Println("See 'cidr contains -h' for help and examples") os.Exit(1) } - ip := core.ParseIP(args[1]) + ip := net.ParseIP(args[1]) if ip == nil { fmt.Printf("error: invalid IP address: %s\n", args[1]) fmt.Println("See 'cidr contains -h' for help and examples") diff --git a/go.mod b/go.mod index 54ca2a2..43e3761 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,14 @@ go 1.19 require github.com/spf13/cobra v1.5.0 +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) + require ( github.com/inconshreveable/mousetrap v1.0.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/stretchr/testify v1.8.0 ) diff --git a/go.sum b/go.sum index 0d85248..4031008 100644 --- a/go.sum +++ b/go.sum @@ -1,10 +1,23 @@ github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/inconshreveable/mousetrap v1.0.0 h1:Z8tu5sraLXCXIcARxBp/8cbvlwVa7Z1NHg9XEKhtSvM= github.com/inconshreveable/mousetrap v1.0.0/go.mod h1:PxqpIevigyE2G7u3NXJIT2ANytuPF1OarO4DADm73n8= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/spf13/cobra v1.5.0 h1:X+jTBEBqF0bHN+9cSMgmfuvv2VHJ9ezmFNf9Y/XstYU= github.com/spf13/cobra v1.5.0/go.mod h1:dWXEIy2H428czQCjInthrTRUg7yKbok+2Qi/yBIJoUM= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0 h1:pSgiaMZlXftHpm5L7V1+rVB+AZJydKsMxsQBIJw4PKk= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/pkg/core/core.go b/pkg/core/core.go index aee7ed7..a13c4ea 100644 --- a/pkg/core/core.go +++ b/pkg/core/core.go @@ -11,13 +11,15 @@ func AddressCount(network *net.IPNet) uint64 { if network.Mask != nil { // Handle edge cases switch prefixLen { - case 32: return 1 - case 31: return 2 + case 32: + return 1 + case 31: + return 2 } } // Remember to subtract the network address and broadcast address - return 1 << (uint64(bits) - uint64(prefixLen)) - 2 + return 1<<(uint64(bits)-uint64(prefixLen)) - 2 } func ParseCIDR(network string) (*net.IPNet, error) { @@ -28,10 +30,6 @@ func ParseCIDR(network string) (*net.IPNet, error) { return ip, err } -func ParseIP(ip string) net.IP { - return net.ParseIP(ip) -} - func ContainsAddress(network *net.IPNet, ip net.IP) bool { return network.Contains(ip) } diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go new file mode 100644 index 0000000..6c454ee --- /dev/null +++ b/pkg/core/core_test.go @@ -0,0 +1,183 @@ +package core + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestAddressCount(t *testing.T) { + IPv4CIDR, err := ParseCIDR("10.0.0.0/16") + if err != nil { + t.Log(err) + t.Fail() + } + + IPv6CIDR, err := ParseCIDR("2001:db8:1234:1a00::/106") + if err != nil { + t.Log(err) + t.Fail() + } + + largeIPv4PrefixCIDR, err := ParseCIDR("172.16.18.0/31") + if err != nil { + t.Log(err) + t.Fail() + } + + largestIPv4PrefixCIDR, err := ParseCIDR("172.16.18.0/32") + if err != nil { + t.Log(err) + t.Fail() + } + + tests := []struct { + name string + cidr *net.IPNet + expectedCount uint64 + }{ + { + name: "Return the count of all distinct host addresses in a common IPv4 CIDR", + cidr: IPv4CIDR, + expectedCount: 65534, + }, + { + name: "Return the count of all distinct host addresses in a common IPv6 CIDR", + cidr: IPv6CIDR, + expectedCount: 4194302, + }, + { + name: "Return the count of all distinct host addresses in an uncommon (large prefix) IPv4 CIDR", + cidr: largeIPv4PrefixCIDR, + expectedCount: 2, + }, + { + name: "Return the count of all distinct host addresses in an uncommon (largest prefix) IPv4 CIDR", + cidr: largestIPv4PrefixCIDR, + expectedCount: 1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + count := AddressCount(tt.cidr) + assert.Equal(t, int(tt.expectedCount), int(count), "Both address counts should be equal") + }) + } +} + +func TestOverlaps(t *testing.T) { + firstIPv4CIDR, err := ParseCIDR("10.0.0.0/16") + if err != nil { + t.Log(err) + t.Fail() + } + + secondIPv4CIDR, err := ParseCIDR("10.0.14.0/22") + if err != nil { + t.Log(err) + t.Fail() + } + + thirdIPv4CIDR, err := ParseCIDR("10.1.0.0/28") + if err != nil { + t.Log(err) + t.Fail() + } + + firstIPv6CIDR, err := ParseCIDR("2001:db8:1111:2222:1::/80") + if err != nil { + t.Log(err) + t.Fail() + } + + secondIPv6CIDR, err := ParseCIDR("2001:db8:1111:2222:1:1::/96") + if err != nil { + t.Log(err) + t.Fail() + } + + tests := []struct { + name string + cidrA *net.IPNet + cidrB *net.IPNet + overlaps bool + }{ + { + name: "2 IPv4 CIDR ranges should overlap", + cidrA: firstIPv4CIDR, + cidrB: secondIPv4CIDR, + overlaps: true, + }, + { + name: "2 IPv4 CIDR ranges should NOT overlap", + cidrA: firstIPv4CIDR, + cidrB: thirdIPv4CIDR, + overlaps: false, + }, + { + name: "2 IPv6 CIDR ranges should overlap", + cidrA: firstIPv6CIDR, + cidrB: secondIPv6CIDR, + overlaps: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + overlaps := Overlaps(tt.cidrA, tt.cidrB) + assert.Equal(t, tt.overlaps, overlaps, "Given CIDRs should overlap") + }) + } +} + +func TestContainsAddress(t *testing.T) { + IPv4CIDR, err := ParseCIDR("10.0.0.0/16") + if err != nil { + t.Log(err) + t.Fail() + } + + IPv6CIDR, err := ParseCIDR("2001:db8:1234:1a00::/106") + if err != nil { + t.Log(err) + t.Fail() + } + + tests := []struct { + name string + cidr *net.IPNet + ip net.IP + contains bool + }{ + { + name: "IPv4 CIDR that does contain an IPv4 IP", + cidr: IPv4CIDR, + ip: net.ParseIP("10.0.14.5"), + contains: true, + }, + { + name: "IPv4 CIDR that does NOT contain an IPv4 IP", + cidr: IPv4CIDR, + ip: net.ParseIP("10.1.55.5"), + contains: false, + }, + { + name: "IPv6 CIDR that does contain an IPv6 IP", + cidr: IPv6CIDR, + ip: net.ParseIP("2001:db8:1234:1a00::"), + contains: true, + }, + { + name: "IPv6 CIDR that does NOT contain an IPv6 IP", + cidr: IPv6CIDR, + ip: net.ParseIP("2001:af1:1222:1a20::"), + contains: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + overlaps := ContainsAddress(tt.cidr, tt.ip) + assert.Equal(t, tt.contains, overlaps, "Given IP address should be part of the given CIDR") + }) + } +}