diff --git a/README.md b/README.md index 811dd69..0d50e0c 100644 --- a/README.md +++ b/README.md @@ -84,12 +84,35 @@ func main() { ## Name Resolution -Go has a built-in name resolver that sidesteps CGO (e.g. `getaddrinfo(3)`) -calls. +There are two methods available for resolving a set of IP addresses +for a hostname. -This library will automatically configure the `net.DefaultResolver` -from the standard library to use the `Dial` function from this library. -You just need the following import somewhere: +### getaddrinfo + +The `sock_getaddrinfo` host function is used to implement name resolution. +This requires WasmEdge, or a WasmEdge compatible WASI layer +(e.g. [wasi-go](http://github.com/stealthrocket/wasi-go)). + +When using this method, the standard library resolver **will not work**. You +_cannot_ use `net.DefaultResolver`, `net.LookupIP`, etc. with this approach +because the standard library does not allow us to patch it with an alternative +implementation. + +Note that `sock_getaddrinfo` may block! + +### Pure Go Resolver + +The pure Go name resolver is not currently enabled for GOOS=wasip1. + +The following series of CLs will change this: https://go-review.googlesource.com/c/go/+/500576. +This will hopefully land in Go v1.22 in ~February 2024. + +If you're using a version of Go that has the CL's included, you can +instruct this library to use the pure Go resolver by including the +`purego` build tag. + +The library will then automatically configure the `net.DefaultResolver`. +All you need is the following import somewhere in your application: ```go import _ "github.com/stealthrocket/net" @@ -97,6 +120,3 @@ import _ "github.com/stealthrocket/net" You should then be able to use the lookup functions from the standard library (e.g. `net.LookupIP(host)`). - -Note that name resolution currently depends on the following series of CLs: -https://go-review.googlesource.com/c/go/+/500576 diff --git a/dial_wasip1.go b/dial_wasip1.go index 943dc8b..bfbc96c 100644 --- a/dial_wasip1.go +++ b/dial_wasip1.go @@ -11,8 +11,6 @@ import ( ) func init() { - net.DefaultResolver.Dial = DialContext - if t, ok := http.DefaultTransport.(*http.Transport); ok { t.DialContext = DialContext } diff --git a/lookup_wasip1.go b/lookup_wasip1.go index d4b4934..5966b1d 100644 --- a/lookup_wasip1.go +++ b/lookup_wasip1.go @@ -3,67 +3,63 @@ package net import ( "fmt" "net" + + "github.com/stealthrocket/net/syscall" ) func lookupAddr(context, network, address string) (net.Addr, error) { + var hints syscall.AddrInfo switch network { case "tcp", "tcp4", "tcp6": + hints.SocketType = syscall.SOCK_STREAM + hints.Protocol = syscall.IPPROTO_TCP case "udp", "udp4", "udp6": + hints.SocketType = syscall.SOCK_DGRAM + hints.Protocol = syscall.IPPROTO_UDP case "unix", "unixgram": return &net.UnixAddr{Name: address, Net: network}, nil default: return nil, fmt.Errorf("not implemented: %s", network) } - host, portstr, err := net.SplitHostPort(address) - if err != nil { - return nil, err + switch network { + case "tcp", "udp": + hints.Family = syscall.AF_UNSPEC + case "tcp4", "udp4": + hints.Family = syscall.AF_INET + case "tcp6", "udp6": + hints.Family = syscall.AF_INET6 } - port, err := net.LookupPort(network, portstr) + hostname, service, err := net.SplitHostPort(address) if err != nil { return nil, err } - if host == "" { - if context == "listen" { - switch network { - case "tcp", "tcp4": - return &net.TCPAddr{IP: net.IPv4zero, Port: port}, nil - case "tcp6": - return &net.TCPAddr{IP: net.IPv6zero, Port: port}, nil - } - } - return nil, fmt.Errorf("invalid address %q for %s", address, context) + if context == "listen" && hostname == "" { + hints.Flags |= syscall.AI_PASSIVE } - ips, err := net.LookupIP(host) + + results := make([]syscall.AddrInfo, 16) + n, err := syscall.Getaddrinfo(hostname, service, hints, results) if err != nil { return nil, err } - if network == "tcp" || network == "tcp4" { - for _, ip := range ips { - if len(ip) == net.IPv4len { - return &net.TCPAddr{IP: ip, Port: port}, nil - } + results = results[:n] + for _, r := range results { + var ip net.IP + var port int + switch a := r.Address.(type) { + case *syscall.SockaddrInet4: + ip = a.Addr[:] + port = a.Port + case *syscall.SockaddrInet6: + ip = a.Addr[:] + port = a.Port } - } - if network == "tcp" || network == "tcp6" { - for _, ip := range ips { - if len(ip) == net.IPv6len { - return &net.TCPAddr{IP: ip, Port: port}, nil - } - } - } - if network == "udp" || network == "udp4" { - for _, ip := range ips { - if len(ip) == net.IPv4len { - return &net.UDPAddr{IP: ip, Port: port}, nil - } - } - } - if network == "udp" || network == "udp6" { - for _, ip := range ips { - if len(ip) == net.IPv6len { - return &net.UDPAddr{IP: ip, Port: port}, nil - } + switch network { + case "tcp", "tcp4", "tcp6": + return &net.TCPAddr{IP: ip, Port: port}, nil + case "udp", "udp4", "udp6": + return &net.UDPAddr{IP: ip, Port: port}, nil } } - return nil, fmt.Errorf("cannot listen on %q", host) + return nil, fmt.Errorf("lookup failed: %q", address) } diff --git a/lookup_wasip1_purego.go b/lookup_wasip1_purego.go new file mode 100644 index 0000000..fe81ec4 --- /dev/null +++ b/lookup_wasip1_purego.go @@ -0,0 +1,75 @@ +//go:build wasip1 && purego + +package net + +import ( + "fmt" + "net" +) + +func init() { + net.DefaultResolver.Dial = DialContext +} + +func lookupAddr(context, network, address string) (net.Addr, error) { + switch network { + case "tcp", "tcp4", "tcp6": + case "udp", "udp4", "udp6": + case "unix", "unixgram": + return &net.UnixAddr{Name: address, Net: network}, nil + default: + return nil, fmt.Errorf("not implemented: %s", network) + } + hostname, service, err := net.SplitHostPort(address) + if err != nil { + return nil, err + } + port, err := net.LookupPort(network, service) + if err != nil { + return nil, err + } + if hostname == "" { + if context == "listen" { + switch network { + case "tcp", "tcp4": + return &net.TCPAddr{IP: net.IPv4zero, Port: port}, nil + case "tcp6": + return &net.TCPAddr{IP: net.IPv6zero, Port: port}, nil + } + } + return nil, fmt.Errorf("invalid address %q for %s", address, context) + } + ips, err := net.LookupIP(hostname) + if err != nil { + return nil, err + } + if network == "tcp" || network == "tcp4" { + for _, ip := range ips { + if len(ip) == net.IPv4len { + return &net.TCPAddr{IP: ip, Port: port}, nil + } + } + } + if network == "tcp" || network == "tcp6" { + for _, ip := range ips { + if len(ip) == net.IPv6len { + return &net.TCPAddr{IP: ip, Port: port}, nil + } + } + } + if network == "udp" || network == "udp4" { + for _, ip := range ips { + if len(ip) == net.IPv4len { + return &net.UDPAddr{IP: ip, Port: port}, nil + } + } + } + if network == "udp" || network == "udp6" { + for _, ip := range ips { + if len(ip) == net.IPv6len { + return &net.UDPAddr{IP: ip, Port: port}, nil + } + } + } + return nil, fmt.Errorf("lookup failed: %q", address) +} diff --git a/syscall/net_wasip1.go b/syscall/net_wasip1.go index 8fc67fb..24de460 100644 --- a/syscall/net_wasip1.go +++ b/syscall/net_wasip1.go @@ -5,19 +5,20 @@ package syscall import ( + "encoding/binary" "runtime" "unsafe" ) const ( - _ = iota + AF_UNSPEC = iota AF_INET AF_INET6 AF_UNIX ) const ( - _ = iota + SOCK_ANY = iota SOCK_DGRAM SOCK_STREAM ) @@ -34,6 +35,19 @@ const ( SO_BROADCAST ) +const ( + AI_PASSIVE = 1 << iota + _ + AI_NUMERICHOST + AI_NUMERICSERV +) + +const ( + IPPROTO_IP = iota + IPPROTO_TCP + IPPROTO_UDP +) + type Sockaddr interface { sockaddr() (unsafe.Pointer, error) sockport() int @@ -126,6 +140,18 @@ func sock_getlocaladdr(fd int32, addr unsafe.Pointer, port unsafe.Pointer) Errno //go:wasmimport wasi_snapshot_preview1 sock_getpeeraddr func sock_getpeeraddr(fd int32, addr unsafe.Pointer, port unsafe.Pointer) Errno +//go:wasmimport wasi_snapshot_preview1 sock_getaddrinfo +func sock_getaddrinfo( + node unsafe.Pointer, + nodeLen uint32, + service unsafe.Pointer, + serviceLen uint32, + hints unsafe.Pointer, + res unsafe.Pointer, + maxResLen uint32, + resLen unsafe.Pointer, +) uint32 + func Socket(proto, sotype, unused int) (fd int, err error) { var newfd int32 errno := sock_open(int32(proto), int32(sotype), unsafe.Pointer(&newfd)) @@ -211,3 +237,112 @@ func anyToSockaddr(rsa *RawSockaddrAny, port int) (Sockaddr, error) { return nil, ENOTSUP } } + +// https://github.com/WasmEdge/WasmEdge/blob/434e1fb4690/thirdparty/wasi/api.hpp#L1885 +type addrInfo struct { + ai_flags uint16 + ai_family uint8 + ai_socktype uint8 + ai_protocol uint32 + ai_addrlen uint32 + ai_addr uintptr32 // *sockAddr + ai_canonname uintptr32 // null-terminated string + ai_canonnamelen uint32 + ai_next uintptr32 // *addrInfo +} + +type sockAddr struct { + sa_family uint32 + sa_data_len uint32 + sa_data uintptr32 + _ [4]byte +} + +type AddrInfo struct { + Flags int + Family int + SocketType int + Protocol int + Address Sockaddr + CanonicalName string + + addrInfo + sockAddr + sockData [26]byte + cannoname [30]byte + inet4addr SockaddrInet4 + inet6addr SockaddrInet6 +} + +func Getaddrinfo(name, service string, hints AddrInfo, results []AddrInfo) (int, error) { + // For compatibility with WasmEdge, make sure strings are null-terminated. + if len(name) > 0 && name[len(name)-1] != 0 { + name = string(append([]byte(name), 0)) + } + if len(service) > 0 && service[len(service)-1] != 0 { + service = string(append([]byte(service), 0)) + } + + hints.addrInfo = addrInfo{ + ai_flags: uint16(hints.Flags), + ai_family: uint8(hints.Family), + ai_socktype: uint8(hints.SocketType), + ai_protocol: uint32(hints.Protocol), + } + for i := range results { + results[i].sockAddr = sockAddr{ + sa_family: 0, + sa_data_len: uint32(unsafe.Sizeof(AddrInfo{}.sockData)), + sa_data: uintptr32(uintptr(unsafe.Pointer(&results[i].sockData))), + } + results[i].addrInfo = addrInfo{ + ai_flags: 0, + ai_family: 0, + ai_socktype: 0, + ai_protocol: 0, + ai_addrlen: uint32(unsafe.Sizeof(sockAddr{})), + ai_addr: uintptr32(uintptr(unsafe.Pointer(&results[i].sockAddr))), + ai_canonname: uintptr32(uintptr(unsafe.Pointer(&results[i].cannoname))), + ai_canonnamelen: uint32(unsafe.Sizeof(AddrInfo{}.cannoname)), + } + if i > 0 { + results[i-1].addrInfo.ai_next = uintptr32(uintptr(unsafe.Pointer(&results[i-1].addrInfo))) + } + } + + resPtr := uintptr32(uintptr(unsafe.Pointer(&results[0].addrInfo))) + + var n uint32 + errno := sock_getaddrinfo( + unsafe.Pointer(unsafe.StringData(name)), + uint32(len(name)), + unsafe.Pointer(unsafe.StringData(service)), + uint32(len(service)), + unsafe.Pointer(&hints.addrInfo), + unsafe.Pointer(&resPtr), + uint32(len(results)), + unsafe.Pointer(&n), + ) + if errno != 0 { + return 0, errnoErr(Errno(errno)) + } + + for i := range results[:n] { + r := &results[i] + port := binary.BigEndian.Uint16(results[i].sockData[:2]) + switch results[i].sockAddr.sa_family { + case AF_INET: + r.inet4addr.Port = int(port) + copy(r.inet4addr.Addr[:], results[i].sockData[2:]) + r.Address = &r.inet4addr + case AF_INET6: + r.inet6addr.Port = int(port) + r.Address = &r.inet6addr + copy(r.inet4addr.Addr[:], results[i].sockData[2:]) + default: + r.Address = nil + } + // TODO: canonical names + } + return int(n), nil +}