diff --git a/pkg/cloud_provider/file/fake.go b/pkg/cloud_provider/file/fake.go index 21c84aa0b..8d4852fd6 100644 --- a/pkg/cloud_provider/file/fake.go +++ b/pkg/cloud_provider/file/fake.go @@ -73,9 +73,9 @@ func (manager *fakeServiceManager) GetInstance(ctx context.Context, obj *Service } } -func (manager *fakeServiceManager) ListInstances(ctx context.Context, parent string) ([]*ServiceInstance, error) { +func (manager *fakeServiceManager) ListInstances(ctx context.Context, obj *ServiceInstance) ([]*ServiceInstance, error) { instances := []*ServiceInstance{ - &ServiceInstance{ + { Project: "test-project", Location: "test-location", Name: "test", @@ -84,7 +84,7 @@ func (manager *fakeServiceManager) ListInstances(ctx context.Context, parent str ReservedIpRange: "192.168.92.32/29", }, }, - &ServiceInstance{ + { Project: "test-project", Location: "test-location", Name: "test", diff --git a/pkg/cloud_provider/file/file.go b/pkg/cloud_provider/file/file.go index 6744a5141..91f7be497 100644 --- a/pkg/cloud_provider/file/file.go +++ b/pkg/cloud_provider/file/file.go @@ -56,7 +56,7 @@ type Service interface { CreateInstance(ctx context.Context, obj *ServiceInstance) (*ServiceInstance, error) DeleteInstance(ctx context.Context, obj *ServiceInstance) error GetInstance(ctx context.Context, obj *ServiceInstance) (*ServiceInstance, error) - ListInstances(ctx context.Context, parent string) ([]*ServiceInstance, error) + ListInstances(ctx context.Context, obj *ServiceInstance) ([]*ServiceInstance, error) } type gcfsServiceManager struct { @@ -232,11 +232,14 @@ func (manager *gcfsServiceManager) DeleteInstance(ctx context.Context, obj *Serv return nil } -func (manager *gcfsServiceManager) ListInstances(ctx context.Context, parent string) ([]*ServiceInstance, error) { - instances, err := manager.instancesService.List(parent).Context(ctx).Do() +// ListInstances returns a list of active instances in a project at a specific location +func (manager *gcfsServiceManager) ListInstances(ctx context.Context, obj *ServiceInstance) ([]*ServiceInstance, error) { + // Calling cloud provider service to get list of active instances. - indicates we are looking for instances in all the locations for a project + instances, err := manager.instancesService.List(locationURI(obj.Project, "-")).Context(ctx).Do() if err != nil { return nil, err } + var activeInstances []*ServiceInstance for _, activeInstance := range instances.Instances { serviceInstance, err := cloudInstanceToServiceInstance(activeInstance) diff --git a/pkg/csi_driver/controller.go b/pkg/csi_driver/controller.go index 409b13136..c489582e3 100644 --- a/pkg/csi_driver/controller.go +++ b/pkg/csi_driver/controller.go @@ -19,7 +19,6 @@ package driver import ( "fmt" "strings" - "sync" csi "github.com/container-storage-interface/spec/lib/go/csi/v0" "github.com/golang/glog" @@ -52,7 +51,7 @@ const ( paramTier = "tier" paramLocation = "location" paramNetwork = "network" - paramReservedIPV4CIDR = "cidr" + paramReservedIPV4CIDR = "reserved-ipv4-cidr" ) // controllerServer handles volume provisioning @@ -61,14 +60,14 @@ type controllerServer struct { } type controllerServerConfig struct { - driver *GCFSDriver - fileService file.Service - metaService metadata.Service - reservedIPRanges map[string]bool - mutex sync.Mutex + driver *GCFSDriver + fileService file.Service + metaService metadata.Service + ipAllocator *util.IPAllocator } func newControllerServer(config *controllerServerConfig) csi.ControllerServer { + config.ipAllocator = util.NewIPAllocator(make(map[string]bool)) return &controllerServer{config: config} } @@ -85,16 +84,7 @@ func (s *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolu if err := s.config.driver.validateVolumeCapabilities(req.GetVolumeCapabilities()); err != nil { return nil, status.Error(codes.InvalidArgument, err.Error()) } - s.config.mutex.Lock() - instances, err := s.config.fileService.ListInstances(ctx, "-") - if err != nil { - return nil, status.Error(codes.Aborted, err.Error()) - } - s.config.reservedIPRanges = make(map[string]bool) - for _, instance := range instances { - s.config.reservedIPRanges[instance.Network.ReservedIpRange] = true - } capBytes := getRequestCapacity(req.GetCapacityRange()) glog.V(5).Infof("Using capacity bytes %q for volume %q", capBytes, name) @@ -114,16 +104,64 @@ func (s *controllerServer) CreateVolume(ctx context.Context, req *csi.CreateVolu return nil, status.Error(codes.AlreadyExists, err.Error()) } } else { + // If we are creating a new instance, we need pick an unused /29 range from reserved-ipv4-cidr + // If the param was not provided, we default reservedIPRange to "" and cloud provider takes care of the allocation + if reservedIPV4CIDR, ok := req.GetParameters()[paramReservedIPV4CIDR]; ok { + validCIDR := s.config.ipAllocator.ValidateCIDR(reservedIPV4CIDR) + if !validCIDR { + return nil, fmt.Errorf("invalid reserved-ipv4-cidr %s", reservedIPV4CIDR) + } + + reservedIPRange, err := s.reserveIPRange(ctx, newFiler, reservedIPV4CIDR) + + // Possible cases are 1) CreateInstanceAborted, 2)CreateInstance running in background + // The ListInstances response will contain the reservedIPRange if the operation was started + // In case of abort, the /29 IP is released and available for reservation + defer s.config.ipAllocator.ReleaseIPRange(reservedIPRange) + if err != nil { + return nil, err + } + + // Adding the reserved IP range to the instance object + newFiler.Network.ReservedIpRange = reservedIPRange + } + // Create the instance filer, err = s.config.fileService.CreateInstance(ctx, newFiler) if err != nil { return nil, status.Error(codes.Internal, err.Error()) } } - s.config.mutex.Unlock() return &csi.CreateVolumeResponse{Volume: fileInstanceToCSIVolume(filer, modeInstance)}, nil } +// reserveIPRange returns the available IP in the cidr +func (s *controllerServer) reserveIPRange(ctx context.Context, filer *file.ServiceInstance, cidr string) (string, error) { + cloudInstancesReservedIPRanges, err := s.getCloudInstancesReservedIPRanges(ctx, filer) + if err != nil { + return "", err + } + unreservedIPBlock, err := s.config.ipAllocator.GetUnreservedIPRange(cidr, cloudInstancesReservedIPRanges) + if err != nil { + return "", err + } + return unreservedIPBlock, nil +} + +//getCloudInstancesReservedIPRanges gets the list of reservedIPRanges from cloud instances +func (s *controllerServer) getCloudInstancesReservedIPRanges(ctx context.Context, filer *file.ServiceInstance) (map[string]bool, error) { + instances, err := s.config.fileService.ListInstances(ctx, filer) + if err != nil { + return nil, status.Error(codes.Aborted, err.Error()) + } + // Initialize an empty reserved list. It will be populated with all the reservedIPRanges obtained from the cloud instances + cloudInstancesReservedIPRanges := make(map[string]bool) + for _, instance := range instances { + cloudInstancesReservedIPRanges[instance.Network.ReservedIpRange] = true + } + return cloudInstancesReservedIPRanges, nil +} + // DeleteVolume deletes a GCFS instance func (s *controllerServer) DeleteVolume(ctx context.Context, req *csi.DeleteVolumeRequest) (*csi.DeleteVolumeResponse, error) { glog.V(4).Infof("DeleteVolume called with request %v", *req) @@ -225,7 +263,6 @@ func (s *controllerServer) generateNewFileInstance(name string, capBytes int64, tier := defaultTier network := defaultNetwork location := s.config.metaService.GetZone() - reservedIPV4CIDR := "" // Validate parameters (case-insensitive). for k, v := range params { switch strings.ToLower(k) { @@ -236,15 +273,11 @@ func (s *controllerServer) generateNewFileInstance(name string, capBytes int64, location = v case paramNetwork: network = v - case paramReservedIPV4CIDR: - reservedIPV4CIDR, err := util.GetUnreservedIPBlock(s.config.reservedIPRanges, v) - if err != nil { - return nil, err - } - if reservedIPV4CIDR == "" { - return nil, fmt.Errorf("Invalid unreserved IP block received for cidr %s", v) - } + // Ignore the cidr flag as it is not passed to the cloud provider + // It will be used to get unreserved IP in the reserveIPV4Range function + case paramReservedIPV4CIDR: + continue case "csiprovisionersecretname", "csiprovisionersecretnamespace": default: return nil, fmt.Errorf("invalid parameter %q", k) @@ -256,8 +289,7 @@ func (s *controllerServer) generateNewFileInstance(name string, capBytes int64, Location: location, Tier: tier, Network: file.Network{ - Name: network, - ReservedIpRange: reservedIPV4CIDR, + Name: network, }, Volume: file.Volume{ Name: newInstanceVolume, diff --git a/pkg/csi_driver/controller_test.go b/pkg/csi_driver/controller_test.go index cf6ca4af3..7c2692c5a 100644 --- a/pkg/csi_driver/controller_test.go +++ b/pkg/csi_driver/controller_test.go @@ -28,12 +28,13 @@ import ( ) const ( - testProject = "test-project" - testLocation = "test-location" - testIp = "test-ip" - testCSIVolume = "test-csi" - testVolumeId = "modeInstance/test-location/test-csi/vol1" - testBytes = 1 * util.Tb + testProject = "test-project" + testLocation = "test-location" + testIp = "test-ip" + testCSIVolume = "test-csi" + testVolumeId = "modeInstance/test-location/test-csi/vol1" + testReservedIPV4CIDR = "192.168.92.0/26" + testBytes = 1 * util.Tb ) func initTestController(t *testing.T) csi.ControllerServer { diff --git a/pkg/util/ip_reservation.go b/pkg/util/ip_reservation.go index a2e7b5d1e..a9aa62051 100644 --- a/pkg/util/ip_reservation.go +++ b/pkg/util/ip_reservation.go @@ -17,71 +17,154 @@ limitations under the License. package util import ( - "bytes" "fmt" "net" + "sync" ) -// GetUnreservedIPBlock returns an unreserved /29 IP block. It accepts the list of currently reserved -// IPs and the requested CIDR as arguments and returns the /29 IP available in that CIDR +const ( + incrementStep29IPRange = 8 + reservationRange = "/29" + byteMax = 255 +) + +// IPAllocator struct consists of shared resources that are used to keep track of the /29 IPRanges currently reserved by service instances +type IPAllocator struct { + // pendingIPRanges set maintains the set of IP ranges that have been reserved by the service instances but pending reservation in the cloud instances + // The key is a IP range currently reserved by a service instance e.g(192.168.92.0/29). Value is a bool to implement map as a set + pendingIPRanges map[string]bool + + // pendingIPRangesMutex is used to synchronize access to the pendingIPRanges set to prevent data races + pendingIPRangesMutex sync.Mutex +} + +// NewIPAllocator is the constructor to initialize the IPAllocator object +// Argument pendingIPRanges map[string]bool is a set of IP ranges currently reserved by service instances but pending reservation in the cloud instances +func NewIPAllocator(pendingIPRanges map[string]bool) *IPAllocator { + return &IPAllocator{ + pendingIPRanges: pendingIPRanges, + } +} + +// holdIPRange adds a particular IP range in the pendingIPRanges set +// Argument ipRange string is an IPV4 range which needs put in pendingIPRanges +func (ipAllocator *IPAllocator) holdIPRange(ipRange string) { + ipAllocator.pendingIPRanges[ipRange] = true +} + +// ReleaseIPRange releases the pending IPRange +// Argument ipRange string is an IPV4 range which needs to be released +func (ipAllocator *IPAllocator) ReleaseIPRange(ipRange string) { + ipAllocator.pendingIPRangesMutex.Lock() + defer ipAllocator.pendingIPRangesMutex.Unlock() + delete(ipAllocator.pendingIPRanges, ipRange) +} + +// ValidateCIDR function validates whether a particular cidr is a valid IP range +// Argument cidr string is is a CIDR range that needs to be validated +func (ipAllocator *IPAllocator) ValidateCIDR(cidr string) bool { + _, ipnet, err := net.ParseCIDR(cidr) + if err != nil { + return false + } + // The cidr network id size needs to be maximum 29 bits + cidrSize, _ := ipnet.Mask.Size() + return cidrSize <= 29 + +} + +// GetUnreservedIPRange returns an unreserved /29 IP block. +// cidr: Provided cidr address in which we need to look for an unreserved /29 IP Block +// cloudInstancesReservedIPRanges: All the used IP ranges in the cloud instances +// All the used IP ranges in the service instances not updated in cloud instances is extracted from the pendingIPRanges list in the IPAllocator +// Finally a final reservedIPRange list is created by merging these two lists // Potential error cases: 1) No /29 Block in the CIDR is unreserved -// Parsing the CIDR resulted in an error -func GetUnreservedIPBlock(reservedIPRanges map[string]bool, cidr string) (string, error) { +// 2) Parsing the CIDR resulted in an error +func (ipAllocator *IPAllocator) GetUnreservedIPRange(cidr string, cloudInstancesReservedIPRanges map[string]bool) (string, error) { ip, ipnet, err := net.ParseCIDR(cidr) if err != nil { return "", fmt.Errorf("Unable to parse CIDR %s", cidr) } - buffer := bytes.NewBufferString("") - for cidrIPBlock := cloneIP(ip.Mask(ipnet.Mask)); ipnet.Contains(cidrIPBlock) && err == nil; err = incrementIP(cidrIPBlock, 8) { + + var reservedIPRanges = make(map[string]bool) + + // The final reserved list is obtained by combining the cloudInstancesReservedIPRanges list and the pendingIPRanges list in the ipAllocator + for cloudInstancesReservedIPRange := range cloudInstancesReservedIPRanges { + reservedIPRanges[cloudInstancesReservedIPRange] = true + } + + // Lock is placed here so that the pendingIPRanges list captures all the IPs pending reservation in the cloud instances + ipAllocator.pendingIPRangesMutex.Lock() + defer ipAllocator.pendingIPRangesMutex.Unlock() + for reservedIPRange := range ipAllocator.pendingIPRanges { + reservedIPRanges[reservedIPRange] = true + } + + for cidrIPBlock := cloneIP(ip.Mask(ipnet.Mask)); ipnet.Contains(cidrIPBlock) && err == nil; cidrIPBlock, err = incrementIP(cidrIPBlock, incrementStep29IPRange) { overLap := false - buffer.WriteString(cidrIPBlock.String()) - buffer.WriteString("/29") + ipRange := fmt.Sprint(cidrIPBlock.String(), reservationRange) for reservedIPRange := range reservedIPRanges { - err = validateCIDROverlap(buffer.String(), reservedIPRange) + // Find if the current IP range in the CIDR overlaps with any of the reserved IP ranges. If not, this can be returned + overLap, err = isOverlap(ipRange, reservedIPRange) + // Error while processing cidr if err != nil { - overLap = true + return "", fmt.Errorf("Error while parsing cidr to determine overlap between %s and %s", cidrIPBlock.String(), reservedIPRange) + } + if overLap { break } } if !overLap { - return buffer.String(), nil + ipAllocator.holdIPRange(ipRange) + return ipRange, nil } - buffer.Reset() } + + // No unreserved IP range available in the entire CIDR range since we did not return return "", fmt.Errorf("All of the /29 IP ranges in the cidr %s are reserved", cidr) } -// validateCIDROverlap checks if two cidrs have any overlapping IPs -func validateCIDROverlap(cidr1 string, cidr2 string) error { +// isOverlap checks if two cidrs have any overlapping IPs +func isOverlap(cidr1 string, cidr2 string) (bool, error) { _, ipnet1, err := net.ParseCIDR(cidr1) + // Invalid CIDRs are considered as overlapping if err != nil { - return err + return true, fmt.Errorf("Invalid cidr %s provided", cidr1) } _, ipnet2, err := net.ParseCIDR(cidr2) if err != nil { - return err - } - - if ipnet1.Contains(ipnet2.IP) || ipnet2.Contains(ipnet1.IP) { - return fmt.Errorf("The cidr ranges %s and %s overlap", cidr1, cidr2) + return true, fmt.Errorf("Invalid cidr %s provided", cidr2) } - return nil + return ipnet1.Contains(ipnet2.IP) || ipnet2.Contains(ipnet1.IP), nil } -// Increment the given IP value by the provided step -func incrementIP(ip net.IP, step byte) error { - if uint8(255)-uint8(ip[len(ip)-1]) < uint8(step) { - return fmt.Errorf("IP overflow occured when incrementing ip %s with step %d", ip.String(), step) +// Increment the given IP value by the provided step. The step is a byte with maximum value maximum byte value +func incrementIP(ip net.IP, step byte) (net.IP, error) { + incrementedIP := cloneIP(ip) + incrementedIP = incrementedIP.To4() + + // Step can be added directly to the Least Significant Byte and we can return the result + if incrementedIP[3] < byteMax-step { + incrementedIP[3] += step + return incrementedIP, nil } - for j := len(ip) - 1; j >= 0; j-- { - ip[j] += step - if ip[j] > 0 { - break + + // Step addition in the Least Significant Byte resulted in overflow + // Propogating the carry addition to the higher order bytes and calculating value of the current byte + incrementedIP[3] = incrementedIP[3] - byteMax + step - 1 + + for ipByte := 2; ipByte >= 0; ipByte-- { + // Rollover occurs when value changes from maximum byte value to 0 as maximum propagated carry is 1 + if incrementedIP[ipByte] != byteMax { + incrementedIP[ipByte]++ + return incrementedIP, nil } + incrementedIP[ipByte] = 0 } - return nil + return nil, fmt.Errorf("IP range overflowed while incrementing IP %s by step %d", ip.String(), step) + } // Clone the provided IP and return the copy diff --git a/pkg/util/ip_reservation_test.go b/pkg/util/ip_reservation_test.go index 3c9419df0..34046d1f3 100644 --- a/pkg/util/ip_reservation_test.go +++ b/pkg/util/ip_reservation_test.go @@ -21,93 +21,237 @@ import ( "testing" ) -func TestAllIPBlocksAvailable(t *testing.T) { - ips := [8]string{"192.168.92.32/29", "192.168.92.40/29", "192.168.92.48/29", "192.168.92.56/29"} - - cidr := "192.168.92.32/27" - - ip, err := GetUnreservedIPBlock(make(map[string]bool), cidr) - if err != nil { - t.Errorf(err.Error()) - } - if ip != ips[0] { - t.Errorf("Expected IP %s to be released from the free IP pool but got IP %s", ips[0], ip) +func initTestIPAllocator() *IPAllocator { + pendingIPRanges := make(map[string]bool) + return &IPAllocator{ + pendingIPRanges: pendingIPRanges, } } -func TestSomeIPBlocksAvailable(t *testing.T) { - ips := [8]string{"192.168.92.32/29", "192.168.92.40/29", "192.168.92.48/29", "192.168.92.56/29"} - - cidr := "192.168.92.32/27" - reservedIPBlocks := make(map[string]bool) - for i := 0; i < 2; i++ { - reservedIPBlocks[ips[i]] = true +func TestValidateCIDR(t *testing.T) { + testCIDR := "192.168.92.0/29" + ipAllocator := initTestIPAllocator() + valid := ipAllocator.ValidateCIDR(testCIDR) + if !valid { + t.Errorf("Valid CIDR %s evaluated to be invalid", testCIDR) } - ip, err := GetUnreservedIPBlock(reservedIPBlocks, cidr) - if err != nil { - t.Errorf(err.Error()) + testCIDR = "192.168.92.0/30" + valid = ipAllocator.ValidateCIDR(testCIDR) + if valid { + t.Errorf("Invalid CIDR %s evaluated to be valid", testCIDR) } - if ip != ips[2] { - t.Errorf("Expected IP %s to be released from the free IP pool but got IP %s", ips[2], ip) + + testCIDR = "192.168.92.0" + valid = ipAllocator.ValidateCIDR(testCIDR) + if valid { + t.Errorf("Invalid CIDR %s evaluated to be invalid", testCIDR) } } -func TestNoIPBlocksAvailable(t *testing.T) { - ips := [8]string{"192.168.92.33/29", "192.168.92.42/29", "192.168.92.49/29", "192.168.92.58/29"} - - cidr := "192.168.92.32/27" - reservedIPBlocks := make(map[string]bool) - for _, ip := range ips { - reservedIPBlocks[ip] = true +func TestGetUnReservedIPRange(t *testing.T) { + // Using IPs which are not the beginning IPs of /29 CIDRs to evaluate the edge case + ips := [8]string{"192.168.92.3/29", "192.168.92.10/29", "192.168.92.20/29", "192.168.92.28/29"} + cidr := "192.168.92.0/27" + cases := []struct { + name string + expected string + pendingIPRanges map[string]bool + cloudProviderReservedIPRanges map[string]bool + errorExpected bool + }{ + { + name: "0/4 /29 IP ranges used", + expected: "192.168.92.0/29", + pendingIPRanges: make(map[string]bool), + cloudProviderReservedIPRanges: make(map[string]bool), + errorExpected: false, + }, { + name: "1/4 /29 IP ranges used", + expected: "192.168.92.8/29", + pendingIPRanges: map[string]bool{ + ips[0]: true, + }, + cloudProviderReservedIPRanges: make(map[string]bool), + errorExpected: false, + }, + { + name: "2/4 /29 IP ranges used", + expected: "192.168.92.16/29", + pendingIPRanges: map[string]bool{ + ips[0]: true, + }, + cloudProviderReservedIPRanges: map[string]bool{ + ips[1]: true, + }, + errorExpected: false, + }, + { + name: "3/4 /29 IP ranges used", + expected: "192.168.92.24/29", + pendingIPRanges: map[string]bool{ + ips[0]: true, + ips[2]: true, + }, + cloudProviderReservedIPRanges: map[string]bool{ + ips[1]: true, + }, + errorExpected: false, + }, + { + name: "All /29 IP ranges used", + expected: "", + pendingIPRanges: map[string]bool{ + ips[0]: true, + ips[2]: true, + }, + cloudProviderReservedIPRanges: map[string]bool{ + ips[1]: true, + ips[3]: true, + }, + errorExpected: true, + }, } - ip, err := GetUnreservedIPBlock(reservedIPBlocks, cidr) - if err == nil { - t.Errorf("Expected error as all IPs in cidr %s had been reserved but got IP %s as unreserved", cidr, ip) + for _, test := range cases { + ipAllocator := initTestIPAllocator() + ipAllocator.pendingIPRanges = test.pendingIPRanges + ipRange, err := ipAllocator.GetUnreservedIPRange(cidr, test.cloudProviderReservedIPRanges) + if err != nil && !test.errorExpected { + t.Errorf("test %q failed: got error %s, expected %s", test.name, err.Error(), test.expected) + } else if err == nil && test.errorExpected { + t.Errorf("test %q failed: got reserved IP range %s, expected error", test.name, ipRange) + } else if ipRange != test.expected { + t.Errorf("test %q failed: got reserved IP range %s, expected %s", test.name, ipRange, test.expected) + } } } - -func TestValidateCIDR(t *testing.T) { - cidr1 := "192.168.92.32/27" - cidr2 := "192.168.92.48/26" - - err := validateCIDROverlap(cidr1, cidr2) - if err == nil { - t.Errorf("Expected error as cidr %s overlaps with cidr %s", cidr1, cidr2) +func TestValidateCIDROverlap(t *testing.T) { + cases := []struct { + name string + cidr1 string + cidr2 string + expected bool + errorExpected bool + }{ + { + name: "Overlapping CIDRs", + cidr1: "192.168.92.0/29", + cidr2: "192.168.92.48/26", + expected: true, + errorExpected: false, + }, + { + name: "Non overlapping CIDRs", + cidr1: "192.168.92.0/29", + cidr2: "192.168.22.67/26", + expected: false, + errorExpected: false, + }, + { + name: "Non overlapping CIDRs with same cidr size", + cidr1: "192.168.92.247/29", + cidr2: "192.168.92.248/29", + expected: false, + errorExpected: false, + }, + { + name: "Overlapping CIDRs with same cidr size", + cidr1: "192.168.92.249/29", + cidr2: "192.168.92.255/29", + expected: true, + errorExpected: false, + }, + { + name: "Invalid CIDR provided", + cidr1: "192.168.92.0", + cidr2: "192.168.22.67/26", + errorExpected: true, + }, } - cidr2 = "192.168.22.67/26" - err = validateCIDROverlap(cidr1, cidr2) - if err != nil { - t.Errorf("Got overlapping error for non overlapping cidrs %s and %s", cidr1, cidr2) + for _, test := range cases { + overlap, err := isOverlap(test.cidr1, test.cidr2) + if err != nil && !test.errorExpected { + t.Errorf("test %q failed: got error %s, expected cidr overlap between %s and %s to be %t", test.name, err.Error(), test.cidr1, test.cidr2, test.expected) + } else if err == nil && test.errorExpected { + t.Errorf("test %q failed: got cidr overlap value %t, expected error", test.name, overlap) + } else if !test.errorExpected && overlap != test.expected { + t.Errorf("test %q failed: got overlap for cidr %s and %s as %t, expected %t", test.name, test.cidr1, test.cidr2, test.expected, test.expected) + } } - } func TestIncrementIP(t *testing.T) { - currentIP := "192.168.92.32" - nextIP := "192.168.92.40" - ip := net.ParseIP(currentIP) - incrementIP(ip, 8) - if ip.String() != nextIP { - t.Errorf("Error while incrementing IP expected %s but got %s", nextIP, ip.String()) + cases := []struct { + name string + currentIP string + step byte + expected string + errorExpected bool + }{ + { + name: "Valid IP increment without carry forward to significant bytes", + currentIP: "192.168.92.32", + step: 7, + expected: "192.168.92.39", + errorExpected: false, + }, + { + name: "Valid IP increment without carry forward to significant bytes", + currentIP: "192.168.92.32", + step: 143, + expected: "192.168.92.175", + errorExpected: false, + }, { + name: "Valid increment with carry forward to significant bytes", + currentIP: "192.168.255.253", + step: 7, + expected: "192.169.0.4", + errorExpected: false, + }, { + name: "Valid increment with carry forward to significant bytes", + currentIP: "192.255.255.253", + step: 255, + expected: "193.0.0.252", + errorExpected: false, + }, + { + name: "Valid increment with carry forward to significant bytes", + currentIP: "0.255.255.255", + step: 3, + expected: "1.0.0.2", + errorExpected: false, + }, + { + name: "Invalid increment", + currentIP: "255.255.255.253", + step: 167, + errorExpected: true, + }, } - currentIP = "255.255.255.254" - ip = net.ParseIP(currentIP) - step := 8 - err := incrementIP(ip, byte(step)) - if err == nil { - t.Errorf("IP Overflow not caught for IP %s increment by %d", ip.String(), 8) + for _, test := range cases { + currentIP := net.ParseIP(test.currentIP) + incrementedIP, err := incrementIP(currentIP, test.step) + + if err != nil && !test.errorExpected { + t.Errorf("test %q failed: got error %s, expected %s", test.name, err.Error(), test.expected) + } else if err == nil && test.errorExpected { + t.Errorf("test %q failed: got reserved IP range %s, expected error", test.name, incrementedIP.String()) + } else if !test.errorExpected && incrementedIP.String() != test.expected { + t.Errorf("test %q failed: got incremented IP %s, expected %s", test.name, incrementedIP.String(), test.expected) + } } } - func TestCloneIP(t *testing.T) { originalIP := net.ParseIP("192.168.92.32") cloneIP := cloneIP(originalIP) if cloneIP.String() != originalIP.String() { t.Errorf("Error while cloning IP %s", originalIP.String()) } + if &originalIP == &cloneIP { + t.Errorf("Clone function returned the original object") + } }