diff --git a/pkg/core/core_test.go b/pkg/core/core_test.go index d4c1251..6f46791 100644 --- a/pkg/core/core_test.go +++ b/pkg/core/core_test.go @@ -1,32 +1,33 @@ -package core +package core_test import ( "net" "testing" + "github.com/bschaatsbergen/cidr/pkg/core" "github.com/stretchr/testify/assert" ) func TestGetAddressCount(t *testing.T) { - IPv4CIDR, err := ParseCIDR("10.0.0.0/16") + IPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") if err != nil { t.Log(err) t.Fail() } - IPv6CIDR, err := ParseCIDR("2001:db8:1234:1a00::/106") + IPv6CIDR, err := core.ParseCIDR("2001:db8:1234:1a00::/106") if err != nil { t.Log(err) t.Fail() } - largeIPv4PrefixCIDR, err := ParseCIDR("172.16.18.0/31") + largeIPv4PrefixCIDR, err := core.ParseCIDR("172.16.18.0/31") if err != nil { t.Log(err) t.Fail() } - largestIPv4PrefixCIDR, err := ParseCIDR("172.16.18.0/32") + largestIPv4PrefixCIDR, err := core.ParseCIDR("172.16.18.0/32") if err != nil { t.Log(err) t.Fail() @@ -60,38 +61,38 @@ func TestGetAddressCount(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - count := GetAddressCount(tt.cidr) + count := core.GetAddressCount(tt.cidr) assert.Equal(t, int(tt.expectedCount), int(count), "Both address counts should be equal") }) } } func TestOverlaps(t *testing.T) { - firstIPv4CIDR, err := ParseCIDR("10.0.0.0/16") + firstIPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") if err != nil { t.Log(err) t.Fail() } - secondIPv4CIDR, err := ParseCIDR("10.0.14.0/22") + secondIPv4CIDR, err := core.ParseCIDR("10.0.14.0/22") if err != nil { t.Log(err) t.Fail() } - thirdIPv4CIDR, err := ParseCIDR("10.1.0.0/28") + thirdIPv4CIDR, err := core.ParseCIDR("10.1.0.0/28") if err != nil { t.Log(err) t.Fail() } - firstIPv6CIDR, err := ParseCIDR("2001:db8:1111:2222:1::/80") + firstIPv6CIDR, err := core.ParseCIDR("2001:db8:1111:2222:1::/80") if err != nil { t.Log(err) t.Fail() } - secondIPv6CIDR, err := ParseCIDR("2001:db8:1111:2222:1:1::/96") + secondIPv6CIDR, err := core.ParseCIDR("2001:db8:1111:2222:1:1::/96") if err != nil { t.Log(err) t.Fail() @@ -124,20 +125,20 @@ func TestOverlaps(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - overlaps := Overlaps(tt.cidrA, tt.cidrB) + overlaps := core.Overlaps(tt.cidrA, tt.cidrB) assert.Equal(t, tt.overlaps, overlaps, "Given CIDRs should overlap") }) } } func TestContainsAddress(t *testing.T) { - IPv4CIDR, err := ParseCIDR("10.0.0.0/16") + IPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") if err != nil { t.Log(err) t.Fail() } - IPv6CIDR, err := ParseCIDR("2001:db8:1234:1a00::/106") + IPv6CIDR, err := core.ParseCIDR("2001:db8:1234:1a00::/106") if err != nil { t.Log(err) t.Fail() @@ -176,57 +177,199 @@ func TestContainsAddress(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - overlaps := ContainsAddress(tt.cidr, tt.ip) + overlaps := core.ContainsAddress(tt.cidr, tt.ip) assert.Equal(t, tt.contains, overlaps, "Given IP address should be part of the given CIDR") }) } } -func TestGetNetMask(t *testing.T) { - IPv4CIDR, err := ParseCIDR("10.0.0.0/16") +func TestGetPrefixLength(t *testing.T) { + IPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") if err != nil { t.Log(err) t.Fail() } - IPv6CIDR, err := ParseCIDR("2001:db8:1234:1a00::/106") + IPv6CIDR, err := core.ParseCIDR("2001:db8:1234:1a00::/106") if err != nil { t.Log(err) t.Fail() } tests := []struct { - name string - cidr *net.IPNet - expectedNetMask net.IP + name string + netMask net.IP + expectedPrefixLength int }{ { - name: "Get the netmask of an IPv4 CIDR", - cidr: IPv4CIDR, - expectedNetMask: net.IP{0xff, 0xff, 0x0, 0x0}, + name: "Get the prefix length of an IPv4 netmask", + netMask: net.IP(IPv4CIDR.Mask), + expectedPrefixLength: 16, }, { - name: "Get the netmask of an IPv6 CIDR", - cidr: IPv6CIDR, - expectedNetMask: net.IP{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xc0, 0x0, 0x0}, + name: "Get the prefix length of an IPv6 netmask", + netMask: net.IP(IPv6CIDR.Mask), + expectedPrefixLength: 106, }, } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + prefixLength := core.GetPrefixLength(tt.netMask) + assert.Equal(t, tt.expectedPrefixLength, prefixLength, "Prefix length is not correct") + }) + } +} + +func TestGetFirstUsableIPAddress(t *testing.T) { + IPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") + if err != nil { + t.Log(err) + t.Fail() + } + + IPv6CIDR, err := core.ParseCIDR("2001:db8:1234:1a00::/106") + if err != nil { + t.Log(err) + t.Fail() + } + + tests := []struct { + name string + CIDR *net.IPNet + expectedFirstUsableIPAddress net.IP + }{ + { + name: "Get the first usable IP address of an IPv4 CIDR range", + CIDR: IPv4CIDR, + expectedFirstUsableIPAddress: net.ParseIP("10.0.0.1").To4(), + }, + { + name: "Get the first usable IP address of an IPv6 CIDR range", + CIDR: IPv6CIDR, + expectedFirstUsableIPAddress: net.ParseIP("2001:db8:1234:1a00::").To16(), + }, + } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - netMask := GetNetMask(tt.cidr) - assert.Equal(t, tt.expectedNetMask, netMask, "Netmask is not correct") + firstUsableIPAddress, err := core.GetFirstUsableIPAddress(tt.CIDR) + if err != nil { + t.Log(err) + t.Fail() + } + assert.Equal(t, tt.expectedFirstUsableIPAddress, firstUsableIPAddress, "First usable IP address is not correct") + }) + } +} + +func TestGetLastUsableIPAddress(t *testing.T) { + IPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") + if err != nil { + t.Log(err) + t.Fail() + } + + IPv6CIDR, err := core.ParseCIDR("2001:db8:1234:1a00::/106") + if err != nil { + t.Log(err) + t.Fail() + } + + tests := []struct { + name string + CIDR *net.IPNet + expectedLastUsableIPAddress net.IP + }{ + { + name: "Get the last usable IP address of an IPv4 CIDR range", + CIDR: IPv4CIDR, + expectedLastUsableIPAddress: net.ParseIP("10.0.255.254").To4(), + }, + { + name: "Get the last usable IP address of an IPv6 CIDR range", + CIDR: IPv6CIDR, + expectedLastUsableIPAddress: net.ParseIP("2001:db8:1234:1a00::3f:ffff").To16(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lastUsableIPAddress, err := core.GetLastUsableIPAddress(tt.CIDR) + if err != nil { + t.Log(err) + t.Fail() + } + assert.Equal(t, tt.expectedLastUsableIPAddress, lastUsableIPAddress, "Last usable IP address is not correct") + }) + } +} + +func TestGetBroadcastAddress(t *testing.T) { + IPv4CIDR, err := core.ParseCIDR("10.0.0.0/16") + if err != nil { + t.Log(err) + t.Fail() + } + + IPv4CIDRWithNoBroadcastAddress, err := core.ParseCIDR("10.0.0.0/31") + if err != nil { + t.Log(err) + t.Fail() + } + + IPv6CIDR, err := core.ParseCIDR("2001:db8:1234:1a00::/106") + if err != nil { + t.Log(err) + t.Fail() + } + + tests := []struct { + name string + CIDR *net.IPNet + expectedBroadcastAddress net.IP + wantErr bool + }{ + { + name: "Get the broadcast IP address of an IPv4 CIDR range", + CIDR: IPv4CIDR, + expectedBroadcastAddress: net.ParseIP("10.0.255.255").To4(), + wantErr: false, + }, + { + name: "Get the broadcast IP address of an IPv4 CIDR range that has no broadcast address", + CIDR: IPv4CIDRWithNoBroadcastAddress, + expectedBroadcastAddress: nil, + wantErr: true, + }, + { + name: "Get the broadcast IP address of an IPv6 CIDR range", + CIDR: IPv6CIDR, + expectedBroadcastAddress: nil, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + broadcastAddress, err := core.GetBroadcastAddress(tt.CIDR) + if err != nil { + assert.Equal(t, tt.wantErr, true, "Expected error when getting broadcast address, but got none") + } else { + assert.Equal(t, tt.expectedBroadcastAddress, broadcastAddress, "Broadcast IP address is not correct") + } }) } } func TestGetBaseAddress(t *testing.T) { - IPv4CIDR, err := ParseCIDR("192.168.90.4/30") + IPv4CIDR, err := core.ParseCIDR("192.168.90.4/30") if err != nil { t.Log(err) t.Fail() } - IPv6CIDR, err := ParseCIDR("4a00:db8:1234:1a00::/127") + IPv6CIDR, err := core.ParseCIDR("4a00:db8:1234:1a00::/127") if err != nil { t.Log(err) t.Fail() @@ -250,7 +393,7 @@ func TestGetBaseAddress(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - baseAddress := GetBaseAddress(tt.cidr) + baseAddress := core.GetBaseAddress(tt.cidr) assert.Equal(t, tt.expectedBaseAddress, baseAddress, "Base address is not correct") }) } @@ -290,7 +433,7 @@ func TestParseCIDR(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := ParseCIDR(tt.cidrStr) + _, err := core.ParseCIDR(tt.cidrStr) if (err != nil) != tt.wantErr { t.Errorf("ParseCIDR() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/helper/contains_test.go b/pkg/helper/contains_test.go new file mode 100644 index 0000000..e0ff9e4 --- /dev/null +++ b/pkg/helper/contains_test.go @@ -0,0 +1,43 @@ +package helper_test + +import ( + "testing" + + "github.com/bschaatsbergen/cidr/pkg/helper" +) + +func TestContainsInt(t *testing.T) { + type args struct { + ints []int + specifiedInt int + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "ContainsInt() should return true", + args: args{ + ints: []int{1, 2, 3, 4, 5}, + specifiedInt: 3, + }, + want: true, + }, + { + name: "ContainsInt() should return false", + args: args{ + ints: []int{1, 2, 3, 4, 5}, + specifiedInt: 6, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := helper.ContainsInt(tt.args.ints, tt.args.specifiedInt); got != tt.want { + t.Errorf("ContainsInt() = %v, want %v", got, tt.want) + } + }) + } +}