-
Notifications
You must be signed in to change notification settings - Fork 22
/
inet.go
149 lines (135 loc) · 3.28 KB
/
inet.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
package scalar
import (
"encoding"
"fmt"
"net"
"strings"
"github.com/apache/arrow/go/v16/arrow"
"github.com/cloudquery/plugin-sdk/v4/types"
)
// Inet is a scalar that holds an IP network (an address together with its
// mask) as a *net.IPNet, plus a flag recording whether the value is set.
type Inet struct {
// Valid is true when Value holds a meaningful value; when it is false the
// scalar is treated as null (see String).
Valid bool
// Value is the stored network. It is not cleared when Valid becomes false.
Value *net.IPNet
}
// IsValid reports whether the scalar currently holds a value (the Valid flag).
func (s *Inet) IsValid() bool {
return s.Valid
}
// DataType returns the Arrow data type associated with this scalar, which is
// always types.ExtensionTypes.Inet.
func (*Inet) DataType() arrow.DataType {
return types.ExtensionTypes.Inet
}
// Equal reports whether rhs is an *Inet whose Valid flag matches the
// receiver's and whose Value renders to the same string.
//
// A nil rhs, or an rhs of any other concrete type, is never equal.
func (s *Inet) Equal(rhs Scalar) bool {
	if rhs == nil {
		return false
	}

	other, isInet := rhs.(*Inet)
	if !isInet {
		return false
	}

	// Validity must agree before the values are even looked at.
	if s.Valid != other.Valid {
		return false
	}

	// Networks are compared through their textual form.
	return s.Value.String() == other.Value.String()
}
// String renders the scalar as text: the stored network's own string form
// when the scalar is valid, and nullValueStr otherwise.
func (s *Inet) String() string {
	if s.Valid {
		return s.Value.String()
	}
	return nullValueStr
}
// Get returns the stored *net.IPNet wrapped in an any. The Valid flag is not
// consulted, so callers should check IsValid first. Because the result wraps a
// typed pointer, it is a non-nil interface even when Value itself is nil.
func (s *Inet) Get() any {
return s.Value
}
// Set stores val in the scalar, converting it to a *net.IPNet.
//
// Accepted inputs:
//   - nil: no-op, the scalar is left exactly as it was.
//   - Scalar: an invalid scalar marks s invalid; a valid one is unwrapped via
//     Get and set recursively.
//   - net.IPNet, *net.IPNet: stored as-is.
//   - net.IP, *net.IP: stored with a mask covering every bit of the address
//     (/32 for a 4-byte IP, /128 for a 16-byte IP). An empty IP is a no-op.
//   - string, *string: parsed as CIDR notation first, then as a bare IP
//     address (which gets a /32 or /128 mask). For CIDR input the host
//     address is kept rather than the network address.
//   - encoding.TextMarshaler, fmt.Stringer: converted to text and parsed as a
//     string.
//   - anything underlyingPtrType is able to unwrap.
//
// A nil *net.IPNet, *net.IP or *string marks the scalar invalid.
//
// On failure a *ValidationError carrying the Inet extension type is returned.
func (s *Inet) Set(val any) error {
	if val == nil {
		return nil
	}

	if sc, ok := val.(Scalar); ok {
		if !sc.IsValid() {
			s.Valid = false
			return nil
		}
		return s.Set(sc.Get())
	}

	switch value := val.(type) {
	case net.IPNet:
		s.Value = &value
	case net.IP:
		if len(value) == 0 {
			return nil
		}
		// Treat a bare address as a single-host network.
		bitCount := len(value) * 8
		mask := net.CIDRMask(bitCount, bitCount)
		s.Value = &net.IPNet{Mask: mask, IP: value}
	case string:
		ip, ipnet, err := net.ParseCIDR(value)
		if err != nil {
			// Not CIDR notation: fall back to a bare address and give it a
			// full-length mask.
			ip = net.ParseIP(value)
			if ip == nil {
				return &ValidationError{Type: types.ExtensionTypes.Inet, Msg: "cannot parse string as IP", Value: value}
			}

			if ipv4 := maybeGetIPv4(value, ip); ipv4 != nil {
				ipnet = &net.IPNet{IP: ipv4, Mask: net.CIDRMask(32, 32)}
			} else {
				ipnet = &net.IPNet{IP: ip, Mask: net.CIDRMask(128, 128)}
			}
		} else {
			// ParseCIDR puts the network address in ipnet.IP; keep the
			// address that was actually supplied instead.
			ipnet.IP = ip
			if ipv4 := maybeGetIPv4(value, ipnet.IP); ipv4 != nil {
				ipnet.IP = ipv4
				if len(ipnet.Mask) == 16 {
					ipnet.Mask = ipnet.Mask[12:] // Not sure this is ever needed.
				}
			}
		}

		s.Value = ipnet
	case *net.IPNet:
		if value == nil {
			s.Valid = false
			return nil
		}
		return s.Set(*value)
	case *net.IP:
		if value == nil {
			s.Valid = false
			return nil
		}
		return s.Set(*value)
	case *string:
		if value == nil {
			s.Valid = false
			return nil
		}
		return s.Set(*value)
	default:
		if tv, ok := value.(encoding.TextMarshaler); ok {
			text, err := tv.MarshalText()
			if err != nil {
				// Report the Inet type: this scalar is the one rejecting the value.
				return &ValidationError{Type: types.ExtensionTypes.Inet, Msg: "cannot marshal text", Err: err, Value: val}
			}
			return s.Set(string(text))
		}
		if sv, ok := value.(fmt.Stringer); ok {
			return s.Set(sv.String())
		}
		if originalSrc, ok := underlyingPtrType(val); ok {
			return s.Set(originalSrc)
		}
		return &ValidationError{Type: types.ExtensionTypes.Inet, Msg: noConversion, Value: value}
	}

	// Only the cases that assigned s.Value directly fall through to here.
	s.Valid = true
	return nil
}
// maybeGetIPv4 returns the 4-byte form of ip when the original textual input
// was written as an IPv4 address, and nil otherwise.
//
// Parsing a string with net.ParseIP() and friends yields a 16-byte slice for
// IPv4 and IPv6 addresses alike. Shrinking genuine IPv4 addresses to 4 bytes
// via To4() lets consumers tell the families apart by length, and makes it
// explicit that the value is IPv4 rather than IPv6 or IPv4-mapped IPv6.
//
// Input containing a colon is treated as IPv6 and left alone: To4() would
// also collapse an IPv4-mapped IPv6 address into IPv4, and the two forms
// behave differently in some cases.
func maybeGetIPv4(input string, ip net.IP) net.IP {
	looksLikeIPv6 := strings.Contains(input, ":")
	if !looksLikeIPv6 {
		return ip.To4()
	}
	return nil
}