/
cidr.go
418 lines (355 loc) · 10.6 KB
/
cidr.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
package topdown
import (
"bytes"
"errors"
"fmt"
"math/big"
"net"
"sort"
"github.com/open-policy-agent/opa/ast"
cidrMerge "github.com/open-policy-agent/opa/internal/cidr/merge"
"github.com/open-policy-agent/opa/topdown/builtins"
)
// getNetFromOperand reads v as a string operand (reported as operand 1 on a
// type error) and parses it as CIDR notation, returning the resulting network.
func getNetFromOperand(v ast.Value) (*net.IPNet, error) {
	str, err := builtins.StringOperand(v, 1)
	if err == nil {
		var network *net.IPNet
		if _, network, err = net.ParseCIDR(string(str)); err == nil {
			return network, nil
		}
	}
	return nil, err
}
// getLastIP returns the highest address contained in cidr.
//
// For a single-address block (prefix length equal to the mask width) the
// network's own IP slice is returned unchanged. Otherwise a new slice whose
// length matches the mask width (4 bytes for a 32-bit mask, 16 for a 128-bit
// mask) is returned.
//
// An error is returned when the mask is not in canonical form (see
// net.IPMask.Size) or when the address length does not match the mask width.
func getLastIP(cidr *net.IPNet) (net.IP, error) {
	prefixLen, bits := cidr.Mask.Size()
	if prefixLen == 0 && bits == 0 {
		// non-standard mask, see https://golang.org/pkg/net/#IPMask.Size
		return nil, fmt.Errorf("CIDR mask is in non-standard format")
	}

	if prefixLen == bits {
		// Special case for single ip address ranges ex: 192.168.1.1/32
		// We can just use the starting IP as the last IP
		return cidr.IP, nil
	}

	ip := cidr.IP
	if bits == 8*net.IPv4len {
		// An IPv4 network may carry its address in the 16-byte IPv4-in-IPv6
		// form; reduce it to 4 bytes so it lines up with the 32-bit mask.
		if v4 := ip.To4(); v4 != nil {
			ip = v4
		}
	}
	if len(ip)*8 != bits {
		return nil, fmt.Errorf("CIDR address length does not match mask length")
	}

	// Use big.Int so the same arithmetic covers IPv6: last = first | (2^host - 1).
	hostBits := uint(bits - prefixLen)
	last := new(big.Int).Lsh(big.NewInt(1), hostBits)
	last.Sub(last, big.NewInt(1))
	last.Or(last, new(big.Int).SetBytes(ip))

	// FillBytes left-pads with zeros, restoring the leading zero bytes that
	// big.Int drops. The value is < 2^bits, so it always fits.
	return last.FillBytes(make(net.IP, bits/8)), nil
}
// builtinNetCIDRIntersects reports whether the CIDR blocks given as the two
// operands share at least one address.
func builtinNetCIDRIntersects(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	var nets [2]*net.IPNet
	for i := range nets {
		n, err := getNetFromOperand(operands[i].Value)
		if err != nil {
			return err
		}
		nets[i] = n
	}
	a, b := nets[0], nets[1]

	// Two prefix blocks are either nested or disjoint, so they overlap exactly
	// when one of them contains the other's starting address.
	return iter(ast.BooleanTerm(a.Contains(b.IP) || b.Contains(a.IP)))
}
// builtinNetCIDRContains reports whether the CIDR block in the first operand
// contains the second operand, which may be either a bare IP address or
// another CIDR block.
func builtinNetCIDRContains(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	outer, err := getNetFromOperand(operands[0].Value)
	if err != nil {
		return err
	}

	needle, err := builtins.StringOperand(operands[1].Value, 1)
	if err != nil {
		return err
	}

	// Try the bare IP address form first.
	if addr := net.ParseIP(string(needle)); addr != nil {
		return iter(ast.BooleanTerm(outer.Contains(addr)))
	}

	// Otherwise it must be a CIDR block.
	inner, err := getNetFromOperand(operands[1].Value)
	if err != nil {
		return fmt.Errorf("not a valid textual representation of an IP address or CIDR: %s", string(needle))
	}

	// A contains B iff A contains both the first and the last address of B.
	// Skip computing B's last address when its first is already outside A.
	if !outer.Contains(inner.IP) {
		return iter(ast.BooleanTerm(false))
	}
	last, err := getLastIP(inner)
	if err != nil {
		return err
	}
	return iter(ast.BooleanTerm(outer.Contains(last)))
}
// errNetCIDRContainsMatchElementType is returned when a collection element is
// neither a string nor a non-empty array.
var errNetCIDRContainsMatchElementType = errors.New("element must be string or non-empty array")

// getCIDRMatchTerm extracts the CIDR/IP term from a collection element: a
// string is returned as-is, and a non-empty array contributes its first
// element. Anything else yields errNetCIDRContainsMatchElementType.
func getCIDRMatchTerm(a *ast.Term) (*ast.Term, error) {
	if _, isStr := a.Value.(ast.String); isStr {
		return a, nil
	}
	if arr, isArr := a.Value.(*ast.Array); isArr && arr.Len() > 0 {
		return arr.Elem(0), nil
	}
	return nil, errNetCIDRContainsMatchElementType
}
// evalNetCIDRContainsMatchesOperand invokes iter once per CIDR/IP candidate
// found in a, passing the candidate and the key that identifies it:
//   - a string is its own candidate and key;
//   - array elements are keyed by their index;
//   - set elements are keyed by the element itself;
//   - object values are keyed by their object key.
//
// Collection elements are unwrapped with getCIDRMatchTerm; its errors are
// prefixed with the operand position. Errors from iter are returned as-is.
// Any other value type produces no calls and no error.
func evalNetCIDRContainsMatchesOperand(operand int, a *ast.Term, iter func(cidr, index *ast.Term) error) error {
	wrap := func(err error) error {
		return fmt.Errorf("operand %v: %v", operand, err)
	}

	switch coll := a.Value.(type) {
	case ast.String:
		return iter(a, a)
	case *ast.Array:
		for idx := 0; idx < coll.Len(); idx++ {
			candidate, err := getCIDRMatchTerm(coll.Elem(idx))
			if err != nil {
				return wrap(err)
			}
			if err := iter(candidate, ast.IntNumberTerm(idx)); err != nil {
				return err
			}
		}
		return nil
	case ast.Set:
		return coll.Iter(func(elem *ast.Term) error {
			candidate, err := getCIDRMatchTerm(elem)
			if err != nil {
				return wrap(err)
			}
			return iter(candidate, elem)
		})
	case ast.Object:
		return coll.Iter(func(key, val *ast.Term) error {
			candidate, err := getCIDRMatchTerm(val)
			if err != nil {
				return wrap(err)
			}
			return iter(candidate, key)
		})
	default:
		return nil
	}
}
// builtinNetCIDRContainsMatches pairs every candidate of the first operand
// with every candidate of the second and returns the set of [key1, key2]
// arrays for which the first candidate contains the second (as decided by
// builtinNetCIDRContains).
func builtinNetCIDRContainsMatches(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	matches := ast.NewSet()

	visitPair := func(outer, outerKey, inner, innerKey *ast.Term) error {
		res, err := getResult(builtinNetCIDRContains, outer, inner)
		if err != nil {
			return err
		}
		if contained, ok := res.Value.(ast.Boolean); ok && bool(contained) {
			matches.Add(ast.ArrayTerm(outerKey, innerKey))
		}
		return nil
	}

	err := evalNetCIDRContainsMatchesOperand(1, operands[0], func(cidr1, index1 *ast.Term) error {
		return evalNetCIDRContainsMatchesOperand(2, operands[1], func(cidr2, index2 *ast.Term) error {
			return visitPair(cidr1, index1, cidr2, index2)
		})
	})
	if err != nil {
		return err
	}
	return iter(ast.NewTerm(matches))
}
// builtinNetCIDRExpand returns the set of every address in the CIDR block
// given as the first operand, each formatted with net.IP.String.
//
// Before each address the builtin checks bctx.Cancel and, if cancellation
// was requested, halts with a CancelErr.
func builtinNetCIDRExpand(bctx BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	s, err := builtins.StringOperand(operands[0].Value, 1)
	if err != nil {
		return err
	}

	ip, ipNet, err := net.ParseCIDR(string(s))
	if err != nil {
		return err
	}

	result := ast.NewSet()
	for cur := ip.Mask(ipNet.Mask); ipNet.Contains(cur); {
		if bctx.Cancel != nil && bctx.Cancel.Cancelled() {
			return Halt{
				Err: &Error{
					Code:    CancelErr,
					Message: "net.cidr_expand: timed out before generating all IP addresses",
				},
			}
		}

		result.Add(ast.StringTerm(cur.String()))

		incIP(cur)
		// incIP wraps the all-ones address around to all zeros. Only a /0
		// block contains both ends of the address space, and for it the
		// wrapped address would still satisfy Contains. The loop would then
		// restart and never terminate when no cancellation is configured.
		// Reaching zero by incrementing can only mean we wrapped, so stop.
		if cur.IsUnspecified() {
			break
		}
	}
	return iter(ast.NewTerm(result))
}
// builtinNetCIDRIsValid reports whether the operand is a string that parses
// as CIDR notation. Non-string operands yield false rather than an error.
func builtinNetCIDRIsValid(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	valid := false
	if str, err := builtins.StringOperand(operands[0].Value, 1); err == nil {
		_, _, parseErr := net.ParseCIDR(string(str))
		valid = parseErr == nil
	}
	return iter(ast.BooleanTerm(valid))
}
// cidrBlockRange is the inclusive address range [First, Last] of a network.
// Network holds the originating network while the range is unmerged; merged
// ranges built by mergeCIDRs leave it nil.
type cidrBlockRange struct {
	First   *net.IP
	Last    *net.IP
	Network *net.IPNet
}

// cidrBlockRanges implements sort.Interface, ordering ranges by their last
// address and, for equal last addresses, by their first address.
type cidrBlockRanges []*cidrBlockRange

// Len implements sort.Interface.
func (c cidrBlockRanges) Len() int {
	return len(c)
}

// Swap implements sort.Interface.
func (c cidrBlockRanges) Swap(i, j int) {
	c[i], c[j] = c[j], c[i]
}

// Less implements sort.Interface: compare last IPs, then first IPs.
func (c cidrBlockRanges) Less(i, j int) bool {
	if cmp := bytes.Compare(*c[i].Last, *c[j].Last); cmp != 0 {
		return cmp < 0
	}
	// Tie-break on the first IP of c[i] versus c[j]; equal ranges are not less.
	return bytes.Compare(*c[i].First, *c[j].First) < 0
}
// builtinNetCIDRMerge reduces an array or set of IP address / CIDR strings to
// the smallest equivalent set of CIDRs: adjacent blocks are joined, blocks
// contained in others are absorbed and duplicates are dropped.
// Original Algorithm: https://github.com/netaddr/netaddr.
func builtinNetCIDRMerge(_ BuiltinContext, operands []*ast.Term, iter func(*ast.Term) error) error {
	networks := []*net.IPNet{}
	collect := func(t *ast.Term) error {
		network, err := generateIPNet(t)
		if err != nil {
			return err
		}
		networks = append(networks, network)
		return nil
	}

	switch v := operands[0].Value.(type) {
	case *ast.Array:
		for i := 0; i < v.Len(); i++ {
			if err := collect(v.Elem(i)); err != nil {
				return err
			}
		}
	case ast.Set:
		if err := v.Iter(collect); err != nil {
			return err
		}
	default:
		// NOTE(review): sets are accepted too, but the message only says array.
		return errors.New("operand must be an array")
	}

	result := ast.NewSet()
	for _, network := range evalNetCIDRMerge(networks) {
		result.Add(ast.StringTerm(network.String()))
	}
	return iter(ast.NewTerm(result))
}
// evalNetCIDRMerge converts each network to its address range, merges the
// ranges with mergeCIDRs and turns the result back into networks. Ranges that
// were not merged keep their original network; merged ranges are re-expressed
// via cidrMerge.RangeToCIDRs. An empty input returns nil.
func evalNetCIDRMerge(networks []*net.IPNet) []*net.IPNet {
	if len(networks) == 0 {
		return nil
	}

	ranges := make(cidrBlockRanges, len(networks))
	for i, network := range networks {
		first, last := cidrMerge.GetAddressRange(*network)
		ranges[i] = &cidrBlockRange{
			First:   &first,
			Last:    &last,
			Network: network,
		}
	}

	result := []*net.IPNet{}
	for _, r := range mergeCIDRs(ranges) {
		if r.Network == nil {
			// Produced by a merge: find the CIDRs that cover the new range.
			result = append(result, cidrMerge.RangeToCIDRs(*r.First, *r.Last)...)
			continue
		}
		result = append(result, r.Network)
	}
	return result
}
// generateIPNet converts a string term into a network. A bare IPv4 address is
// paired with net.IP.DefaultMask; anything that is not a bare address is
// parsed as CIDR notation. A bare IPv6 address is rejected because it has no
// prefix length.
func generateIPNet(term *ast.Term) (*net.IPNet, error) {
	str, ok := term.Value.(ast.String)
	if !ok {
		return nil, errors.New("element must be string")
	}
	text := string(str)

	ip := net.ParseIP(text)
	switch {
	case ip == nil:
		// Not a bare address; fall back to CIDR notation.
		_, network, err := net.ParseCIDR(text)
		return network, err
	case ip.To4() == nil:
		return nil, errors.New("IPv6 invalid: needs prefix length")
	default:
		// NOTE(review): DefaultMask is the classful (A/B/C) mask, not /32,
		// and the IP is not masked — confirm this is the intended semantics.
		return &net.IPNet{IP: ip, Mask: ip.DefaultMask()}, nil
	}
}
// mergeCIDRs sorts ranges in place and then, walking from the end, folds each
// range into its predecessor whenever the two overlap or are adjacent. Merged
// ranges get freshly copied First/Last IPs and a nil Network. The returned
// slice shares its backing array with the input.
func mergeCIDRs(ranges cidrBlockRanges) cidrBlockRanges {
	sort.Sort(ranges)

	clone := func(ip net.IP) net.IP {
		dup := make(net.IP, len(ip))
		copy(dup, ip)
		return dup
	}

	for cur := len(ranges) - 1; cur > 0; cur-- {
		prev := cur - 1

		// Merge only if the address just before cur's first is not past
		// prev's last, i.e. the ranges overlap or touch.
		before := cidrMerge.GetPreviousIP(*ranges[cur].First)
		if bytes.Compare(before, *ranges[prev].Last) > 0 {
			continue
		}

		lowest := ranges[cur].First
		if bytes.Compare(*ranges[prev].First, *ranges[cur].First) < 0 {
			lowest = ranges[prev].First
		}

		last := clone(*ranges[cur].Last)
		first := clone(*lowest)
		ranges[prev] = &cidrBlockRange{First: &first, Last: &last, Network: nil}

		// ranges[cur] now lives inside ranges[prev]; drop it.
		ranges = append(ranges[:cur], ranges[cur+1:]...)
	}
	return ranges
}
// incIP advances ip in place to the next address by treating it as a
// big-endian unsigned integer. The all-ones address wraps around to zero.
func incIP(ip net.IP) {
	for pos := len(ip) - 1; pos >= 0; pos-- {
		if ip[pos]++; ip[pos] != 0 {
			return
		}
	}
}
// init registers the net.cidr_* builtin implementations.
func init() {
	RegisterBuiltinFunc(ast.NetCIDRContains.Name, builtinNetCIDRContains)
	// net.cidr_overlap is served by the same implementation as net.cidr_contains.
	RegisterBuiltinFunc(ast.NetCIDROverlap.Name, builtinNetCIDRContains)
	RegisterBuiltinFunc(ast.NetCIDRContainsMatches.Name, builtinNetCIDRContainsMatches)
	RegisterBuiltinFunc(ast.NetCIDRIntersects.Name, builtinNetCIDRIntersects)
	RegisterBuiltinFunc(ast.NetCIDRExpand.Name, builtinNetCIDRExpand)
	RegisterBuiltinFunc(ast.NetCIDRMerge.Name, builtinNetCIDRMerge)
	RegisterBuiltinFunc(ast.NetCIDRIsValid.Name, builtinNetCIDRIsValid)
}