Skip to content

Commit

Permalink
Merge pull request #193 from ipinfo/talhahameed/be-2715-handle-adjace…
Browse files Browse the repository at this point in the history
…nt-prefixes-in-ipinfo-tool-aggregate

Handle adjacent prefixes in ipinfo tool aggregate
  • Loading branch information
UmanShahzad committed Nov 21, 2023
2 parents 32f1e7e + 2aedf33 commit d2f94f9
Show file tree
Hide file tree
Showing 3 changed files with 148 additions and 125 deletions.
7 changes: 1 addition & 6 deletions ipinfo/cmd_tool_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@ func printHelpToolAggregate() {
`Usage: %s tool aggregate [<opts>] <cidr | ip | ip-range | filepath>
Description:
Accepts IPs, IP ranges, and CIDRs, aggregating them efficiently.
Input can be IPs, IP ranges, CIDRs, and/or filepath to a file
containing any of these. Works for both IPv4 and IPv6.
Accepts IPv4 IPs and CIDRs, aggregating them efficiently.
If input contains single IPs, it tries to merge them into the input CIDRs,
otherwise they are printed to the output as they are.
Expand All @@ -37,9 +35,6 @@ Examples:
# Aggregate two CIDRs.
$ %[1]s tool aggregate 1.1.1.0/30 1.1.1.0/28
# Aggregate IP range and CIDR.
$ %[1]s tool aggregate 1.1.1.0-1.1.1.244 1.1.1.0/28
# Aggregate enteries from 2 files.
$ %[1]s tool aggregate /path/to/file1.txt /path/to/file2.txt
Expand Down
68 changes: 68 additions & 0 deletions lib/cidr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package lib

import (
"bytes"
"encoding/binary"
"math"
"net"
"sort"
)

// CIDR represens a Classless Inter-Domain Routing structure.
type CIDR struct {
IP net.IP
Network *net.IPNet
}

// newCidr creates a newCidr CIDR structure.
func newCidr(s string) *CIDR {
ip, ipnet, err := net.ParseCIDR(s)
if err != nil {
panic(err)
}
return &CIDR{
IP: ip,
Network: ipnet,
}
}

func (c *CIDR) String() string {
return c.Network.String()
}

// MaskLen returns a network mask length.
func (c *CIDR) MaskLen() uint32 {
i, _ := c.Network.Mask.Size()
return uint32(i)
}

// PrefixUint32 returns a prefix.
func (c *CIDR) PrefixUint32() uint32 {
return binary.BigEndian.Uint32(c.IP.To4())
}

// Size returns a size of a CIDR range.
func (c *CIDR) Size() int {
ones, bits := c.Network.Mask.Size()
return int(math.Pow(2, float64(bits-ones)))
}

// list returns a slice of sorted CIDR structures.
func list(s []string) []*CIDR {
out := make([]*CIDR, 0)
for _, c := range s {
out = append(out, newCidr(c))
}
sort.Sort(cidrSort(out))
return out
}

type cidrSort []*CIDR

func (s cidrSort) Len() int { return len(s) }
func (s cidrSort) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

func (s cidrSort) Less(i, j int) bool {
cmp := bytes.Compare(s[i].IP, s[j].IP)
return cmp < 0 || (cmp == 0 && s[i].MaskLen() < s[j].MaskLen())
}
198 changes: 79 additions & 119 deletions lib/cmd_tool_aggregate.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@ package lib

import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"os"
"sort"
"strings"

"github.com/spf13/pflag"
Expand Down Expand Up @@ -55,72 +53,21 @@ func CmdToolAggregate(
return nil
}

// Parses a list of CIDRs.
parseCIDRs := func(cidrs []string) []net.IPNet {
parsedCIDRs := make([]net.IPNet, 0)
for _, cidrStr := range cidrs {
_, ipNet, err := net.ParseCIDR(cidrStr)
if err != nil {
if !f.Quiet {
fmt.Printf("Invalid CIDR: %s\n", cidrStr)
}
continue
}
parsedCIDRs = append(parsedCIDRs, *ipNet)
}

return parsedCIDRs
}

// Input parser.
parseInput := func(rows []string) ([]net.IPNet, []net.IP) {
parsedCIDRs := make([]net.IPNet, 0)
parseInput := func(rows []string) ([]string, []net.IP) {
parsedCIDRs := make([]string, 0)
parsedIPs := make([]net.IP, 0)
var separator string
for _, rowStr := range rows {
if strings.ContainsAny(rowStr, ",-") {
if delim := strings.ContainsRune(rowStr, ','); delim {
separator = ","
} else {
separator = "-"
}

ipRange := strings.Split(rowStr, separator)
if len(ipRange) != 2 {
if !f.Quiet {
fmt.Printf("Invalid IP range: %s\n", rowStr)
}
continue
}

if strings.ContainsRune(rowStr, ':') {
cidrs, err := CIDRsFromIP6RangeStrRaw(rowStr)
if err == nil {
parsedCIDRs = append(parsedCIDRs, parseCIDRs(cidrs)...)
continue
} else {
if !f.Quiet {
fmt.Printf("Invalid IP range %s. Err: %v\n", rowStr, err)
}
continue
}
} else {
cidrs, err := CIDRsFromIPRangeStrRaw(rowStr)
if err == nil {
parsedCIDRs = append(parsedCIDRs, parseCIDRs(cidrs)...)
continue
} else {
if !f.Quiet {
fmt.Printf("Invalid IP range %s. Err: %v\n", rowStr, err)
}
continue
}
}
continue
} else if strings.ContainsRune(rowStr, '/') {
parsedCIDRs = append(parsedCIDRs, parseCIDRs([]string{rowStr})...)
_, ipnet, err := net.ParseCIDR(rowStr)
if err == nil && IsCIDRIPv4(ipnet) {
parsedCIDRs = append(parsedCIDRs, []string{rowStr}...)
}
continue
} else {
if ip := net.ParseIP(rowStr); ip != nil {
if ip := net.ParseIP(rowStr); IsIPv4(ip) {
parsedIPs = append(parsedIPs, ip)
} else {
if !f.Quiet {
Expand Down Expand Up @@ -165,7 +112,7 @@ func CmdToolAggregate(
}

// Vars to contain CIDRs/IPs from all input sources.
parsedCIDRs := make([]net.IPNet, 0)
parsedCIDRs := make([]string, 0)
parsedIPs := make([]net.IP, 0)

// Collect CIDRs/IPs from stdin.
Expand All @@ -187,93 +134,106 @@ func CmdToolAggregate(
rows := scanrdr(file)
file.Close()
cidrs, ips := parseInput(rows)

parsedCIDRs = append(parsedCIDRs, cidrs...)
parsedIPs = append(parsedIPs, ips...)
}

// Sort and merge collected CIDRs and IPs.
aggregatedCIDRs := aggregateCIDRs(parsedCIDRs)
adjacentCombined := combineAdjacent(stripOverlapping(list(parsedCIDRs)))

outlierIPs := make([]net.IP, 0)
length := len(aggregatedCIDRs)
for _, ip := range parsedIPs {
for i, cidr := range aggregatedCIDRs {
if cidr.Contains(ip) {
break
} else if i == length-1 {
outlierIPs = append(outlierIPs, ip)
length := len(adjacentCombined)
if length != 0 {
for _, ip := range parsedIPs {
for i, cidr := range adjacentCombined {
if cidr.Network.Contains(ip) {
break
} else if i == length-1 {
outlierIPs = append(outlierIPs, ip)
}
}
}
} else {
outlierIPs = append(outlierIPs, parsedIPs...)
}

// Print the aggregated CIDRs.
for _, r := range aggregatedCIDRs {
for _, r := range adjacentCombined {
fmt.Println(r.String())
}

// Print outliers.
// Print the outlierIPs.
for _, r := range outlierIPs {
fmt.Println(r.String())
}

return nil
}

// Helper function to aggregate IP ranges.
func aggregateCIDRs(cidrs []net.IPNet) []net.IPNet {
aggregatedCIDRs := make([]net.IPNet, 0)

// Sort CIDRs by starting IP.
sortCIDRs(cidrs)

for _, r := range cidrs {
if len(aggregatedCIDRs) == 0 {
aggregatedCIDRs = append(aggregatedCIDRs, r)
// stripOverlapping returns a slice of CIDR structures with overlapping ranges
// stripped.
func stripOverlapping(s []*CIDR) []*CIDR {
l := len(s)
for i := 0; i < l-1; i++ {
if s[i] == nil {
continue
}

last := len(aggregatedCIDRs) - 1
prev := aggregatedCIDRs[last]

if canAggregate(prev, r) {
// Merge overlapping CIDRs.
aggregatedCIDRs[last] = aggregateCIDR(prev, r)
} else {
aggregatedCIDRs = append(aggregatedCIDRs, r)
for j := i + 1; j < l; j++ {
if overlaps(s[j], s[i]) {
s[j] = nil
}
}
}

return aggregatedCIDRs
}

// Helper function to sort IP ranges by starting IP.
func sortCIDRs(ipRanges []net.IPNet) {
sort.SliceStable(ipRanges, func(i, j int) bool {
return bytes.Compare(ipRanges[i].IP, ipRanges[j].IP) < 0
})
return filter(s)
}

// Helper function to check if two CIDRs can be aggregated.
func canAggregate(r1, r2 net.IPNet) bool {
return r1.Contains(r2.IP) || r2.Contains(r1.IP)
func overlaps(a, b *CIDR) bool {
return (a.PrefixUint32() / (1 << (32 - b.MaskLen()))) ==
(b.PrefixUint32() / (1 << (32 - b.MaskLen())))
}

// Helper function to aggregate two CIDRs.
func aggregateCIDR(r1, r2 net.IPNet) net.IPNet {
mask1, _ := r1.Mask.Size()
mask2, _ := r2.Mask.Size()

ipLen := net.IPv6len * 8
if r1.IP.To4() != nil {
ipLen = net.IPv4len * 8
}
// combineAdjacent returns a slice of CIDR structures with adjacent ranges
// combined.
func combineAdjacent(s []*CIDR) []*CIDR {
for {
found := false
l := len(s)
for i := 0; i < l-1; i++ {
if s[i] == nil {
continue
}
for j := i + 1; j < l; j++ {
if s[j] == nil {
continue
}
if adjacent(s[i], s[j]) {
c := fmt.Sprintf("%s/%d", s[i].IP.String(), s[i].MaskLen()-1)
s[i] = newCidr(c)
s[j] = nil
found = true
}
}
}

// Find the common prefix length
commonPrefixLen := mask1
if mask2 < commonPrefixLen {
commonPrefixLen = mask2
if !found {
break
}
}
return filter(s)
}

commonPrefix := r1.IP.Mask(net.CIDRMask(commonPrefixLen, ipLen))
func adjacent(a, b *CIDR) bool {
return (a.MaskLen() == b.MaskLen()) &&
(a.PrefixUint32()%(2<<(32-b.MaskLen())) == 0) &&
(b.PrefixUint32()-a.PrefixUint32() == (1 << (32 - a.MaskLen())))
}

return net.IPNet{IP: commonPrefix, Mask: net.CIDRMask(commonPrefixLen, ipLen)}
func filter(s []*CIDR) []*CIDR {
out := s[:0]
for _, x := range s {
if x != nil {
out = append(out, x)
}
}
return out
}

0 comments on commit d2f94f9

Please sign in to comment.