diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..f346156 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,33 @@ +name: build +on: + push: + tags: + - v* + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-${{ github.event.number || github.ref }} + cancel-in-progress: true + +jobs: + test: + name: Go Test + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - uses: actions/setup-go@v4 + with: + go-version-file: go.mod + check-latest: true + + - name: Install Go tip + run: | + curl -sL https://storage.googleapis.com/go-build-snap/go/linux-amd64/$(git ls-remote https://github.com/golang/go.git HEAD | awk '{print $1;}').tar.gz -o gotip.tar.gz + ls -lah gotip.tar.gz + mkdir -p $HOME/gotip + tar -C $HOME/gotip -xzf gotip.tar.gz + - run: make test GO=$HOME/gotip/bin/go GOPATH=$HOME/gotip diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..e6a7f7e --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib +wasirun + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Emacs +*~ diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..39308c9 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +.PHONY: test lint wasirun + +GOPATH ?= $(shell $(GO) env GOPATH) +wasirun = $(GOPATH)/bin/wasirun + +wasip1.test: go.mod $(wildcard wasip1/*.go) + GOARCH=wasm GOOS=wasip1 $(GO) test -c ./wasip1 + +test: wasirun wasip1.test + $(wasirun) wasip1.test -test.v + +wasirun: $(wasirun) + +$(wasirun): + $(GO) install github.com/stealthrocket/wasi-go/cmd/wasirun@latest diff --git a/go.mod b/go.mod index ab42d04..446b338 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module github.com/stealthrocket/net go 1.20 + +require golang.org/x/net v0.10.0 // indirect diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..ac972c8 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= diff --git a/wasip1/dial_wasip1.go b/wasip1/dial_wasip1.go index 19cea86..7f81f40 100644 --- a/wasip1/dial_wasip1.go +++ b/wasip1/dial_wasip1.go @@ -9,37 +9,33 @@ import ( // Dial connects to the address on the named network. func Dial(network, address string) (net.Conn, error) { + return DialContext(context.Background(), network, address) +} + +// DialContext is a variant of Dial that accepts a context. +func DialContext(ctx context.Context, network, address string) (net.Conn, error) { addr, err := lookupAddr("dial", network, address) if err != nil { addr := &netAddr{network, address} return nil, dialErr(addr, err) } - conn, err := dialAddr(addr) + conn, err := dialAddr(ctx, addr) if err != nil { return nil, dialErr(addr, err) } return conn, nil } -// DialContext is a variant of Dial that accepts a context. -func DialContext(ctx context.Context, network, address string) (net.Conn, error) { - select { - case <-ctx.Done(): - addr := &netAddr{network, address} - return nil, dialErr(addr, context.Cause(ctx)) - default: - return Dial(network, address) - } -} - func dialErr(addr net.Addr, err error) error { return newOpError("dial", addr, err) } -func dialAddr(addr net.Addr) (net.Conn, error) { +func dialAddr(ctx context.Context, addr net.Addr) (net.Conn, error) { proto := family(addr) - sotype := socketType(addr) - + sotype, err := socketType(addr) + if err != nil { + return nil, os.NewSyscallError("socket", err) + } fd, err := socket(proto, sotype, 0) if err != nil { return nil, os.NewSyscallError("socket", err) @@ -49,7 +45,6 @@ func dialAddr(addr net.Addr) (net.Conn, error) { syscall.Close(fd) return nil, os.NewSyscallError("setnonblock", err) } - if sotype == SOCK_DGRAM && proto != AF_UNIX { if err := setsockopt(fd, SOL_SOCKET, SO_BROADCAST, 1); err != nil { syscall.Close(fd) @@ -61,7 +56,6 @@ func dialAddr(addr net.Addr) (net.Conn, error) { if err != nil { return nil, os.NewSyscallError("connect", err) } - var inProgress bool switch err := connect(fd, connectAddr); err { case nil: @@ -80,32 +74,50 @@ func dialAddr(addr net.Addr) (net.Conn, error) { if err != nil { return nil, err } - rawConnErr := rawConn.Write(func(fd uintptr) bool { - var value int - value, err = getsockopt(int(fd), SOL_SOCKET, SO_ERROR) - if err != nil { - return true // done + + errch := make(chan error) + go func() { + var err error + rawConnErr := rawConn.Write(func(fd uintptr) bool { + var value int + value, err = getsockopt(int(fd), SOL_SOCKET, SO_ERROR) + if err != nil { + return true // done + } + switch syscall.Errno(value) { + case syscall.EINPROGRESS, syscall.EINTR: + return false // continue + case syscall.EISCONN: + err = nil + return true + case syscall.Errno(0): + // The net poller can wake up spuriously. Check that we are + // are really connected. + _, err := getpeername(int(fd)) + return err == nil + default: + return true + } + }) + if err == nil { + err = rawConnErr } - switch syscall.Errno(value) { - case syscall.EINPROGRESS, syscall.EINTR: - return false // continue - case syscall.EISCONN: - err = nil - return true - case syscall.Errno(0): - // The net poller can wake up spuriously. Check that we are - // are really connected. - _, err := getpeername(int(fd)) - return err == nil - default: - return true + errch <- err + }() + + select { + case err := <-errch: + if err != nil { + return nil, os.NewSyscallError("connect", err) } - }) - if err == nil { - err = rawConnErr - } - if err != nil { - return nil, os.NewSyscallError("connect", err) + case <-ctx.Done(): + // This should interrupt the async connect operation handled by the + // goroutine. + f.Close() + // Wait for the goroutine to complete, we can safely discard the + // error here because we don't care about the socket anymore. + <-errch + return nil, context.Cause(ctx) } } @@ -113,63 +125,5 @@ func dialAddr(addr net.Addr) (net.Conn, error) { if err != nil { return nil, err } - - // TODO: get local+peer address; wrap FileConn to implement LocalAddr() and RemoteAddr() - return c, nil -} - -func family(addr net.Addr) int { - var ip net.IP - switch a := addr.(type) { - case *net.UnixAddr: - return AF_UNIX - case *net.TCPAddr: - ip = a.IP - case *net.UDPAddr: - ip = a.IP - case *net.IPAddr: - ip = a.IP - } - if ip.To4() != nil { - return AF_INET - } else if len(ip) == net.IPv6len { - return AF_INET6 - } - return AF_INET -} - -func socketType(addr net.Addr) int { - switch addr.Network() { - case "tcp", "unix": - return SOCK_STREAM - case "udp", "unixgram": - return SOCK_DGRAM - default: - panic("not implemented") - } -} - -func socketAddress(addr net.Addr) (sockaddr, error) { - var ip net.IP - var port int - switch a := addr.(type) { - case *net.UnixAddr: - return &sockaddrUnix{name: a.Name}, nil - case *net.TCPAddr: - ip, port = a.IP, a.Port - case *net.UDPAddr: - ip, port = a.IP, a.Port - case *net.IPAddr: - ip = a.IP - } - if ipv4 := ip.To4(); ipv4 != nil { - return &sockaddrInet4{addr: ([4]byte)(ipv4), port: port}, nil - } else if len(ip) == net.IPv6len { - return &sockaddrInet6{addr: ([16]byte)(ip), port: port}, nil - } else { - return nil, &net.AddrError{ - Err: "unsupported address type", - Addr: addr.String(), - } - } + return makeConn(c) } diff --git a/wasip1/listen_wasip1.go b/wasip1/listen_wasip1.go index 855bc2f..976efe8 100644 --- a/wasip1/listen_wasip1.go +++ b/wasip1/listen_wasip1.go @@ -25,7 +25,11 @@ func listenErr(addr net.Addr, err error) error { } func listenAddr(addr net.Addr) (net.Listener, error) { - fd, err := socket(family(addr), socketType(addr), 0) + sotype, err := socketType(addr) + if err != nil { + return nil, os.NewSyscallError("socket", err) + } + fd, err := socket(family(addr), sotype, 0) if err != nil { return nil, os.NewSyscallError("socket", err) } @@ -39,22 +43,26 @@ func listenAddr(addr net.Addr) (net.Listener, error) { return nil, os.NewSyscallError("setsockopt", err) } - listenAddr, err := socketAddress(addr) + bindAddr, err := socketAddress(addr) if err != nil { return nil, os.NewSyscallError("bind", err) } - - if err := bind(fd, listenAddr); err != nil { + if err := bind(fd, bindAddr); err != nil { syscall.Close(fd) return nil, os.NewSyscallError("bind", err) } - const backlog = 64 // TODO: configurable? if err := listen(fd, backlog); err != nil { syscall.Close(fd) return nil, os.NewSyscallError("listen", err) } + sockaddr, err := getsockname(fd) + if err != nil { + syscall.Close(fd) + return nil, os.NewSyscallError("getsockname", err) + } + f := os.NewFile(uintptr(fd), "") defer f.Close() @@ -62,6 +70,12 @@ func listenAddr(addr net.Addr) (net.Listener, error) { if err != nil { return nil, err } + switch l.(type) { + case *net.UnixListener: + addr = sockaddrToUnixAddr(sockaddr) + case *net.TCPListener: + addr = sockaddrToTCPAddr(sockaddr) + } return &listener{l, addr}, nil } @@ -75,8 +89,7 @@ func (l *listener) Accept() (net.Conn, error) { if err != nil { return nil, err } - // TODO: get local+peer address; wrap Conn to implement LocalAddr() and RemoteAddr() - return c, nil + return makeConn(c) } func (l *listener) Addr() net.Addr { diff --git a/wasip1/net_wasip1.go b/wasip1/net_wasip1.go index ab067f3..bb9470a 100644 --- a/wasip1/net_wasip1.go +++ b/wasip1/net_wasip1.go @@ -5,6 +5,7 @@ import ( "errors" "net" "net/http" + "syscall" ) func dialResolverNotSupported(ctx context.Context, network, address string) (net.Conn, error) { @@ -41,3 +42,153 @@ type netAddr struct{ network, address string } func (na *netAddr) Network() string { return na.address } func (na *netAddr) String() string { return na.address } + +func family(addr net.Addr) int { + var ip net.IP + switch a := addr.(type) { + case *net.UnixAddr: + return AF_UNIX + case *net.TCPAddr: + ip = a.IP + case *net.UDPAddr: + ip = a.IP + case *net.IPAddr: + ip = a.IP + } + if ip.To4() != nil { + return AF_INET + } else if len(ip) == net.IPv6len { + return AF_INET6 + } + return AF_INET +} + +func socketType(addr net.Addr) (int, error) { + switch addr.Network() { + case "tcp", "unix": + return SOCK_STREAM, nil + case "udp", "unixgram": + return SOCK_DGRAM, nil + default: + return -1, syscall.EPROTOTYPE + } +} + +func socketAddress(addr net.Addr) (sockaddr, error) { + var ip net.IP + var port int + switch a := addr.(type) { + case *net.UnixAddr: + return &sockaddrUnix{name: a.Name}, nil + case *net.TCPAddr: + ip, port = a.IP, a.Port + case *net.UDPAddr: + ip, port = a.IP, a.Port + case *net.IPAddr: + ip = a.IP + } + if ipv4 := ip.To4(); ipv4 != nil { + return &sockaddrInet4{addr: ([4]byte)(ipv4), port: port}, nil + } else if len(ip) == net.IPv6len { + return &sockaddrInet6{addr: ([16]byte)(ip), port: port}, nil + } else { + return nil, &net.AddrError{ + Err: "unsupported address type", + Addr: addr.String(), + } + } +} + +type conn struct { + net.Conn + laddr net.Addr + raddr net.Addr +} + +func (c *conn) LocalAddr() net.Addr { return c.laddr } +func (c *conn) RemoteAddr() net.Addr { return c.raddr } + +// In Go 1.21, the net package cannot initialize the local and remote addresses +// of network connections. For this reason, we use this function to retreive the +// addresses and return a wrapped net.Conn with LocalAddr/RemoteAddr implemented. +func makeConn(c net.Conn) (net.Conn, error) { + syscallConn, ok := c.(syscall.Conn) + if !ok { + return c, nil + } + rawConn, err := syscallConn.SyscallConn() + if err != nil { + c.Close() + return nil, err + } + var laddr net.Addr + var raddr net.Addr + rawConnErr := rawConn.Control(func(fd uintptr) { + var addr sockaddr + var peer sockaddr + if addr, err = getsockname(int(fd)); err != nil { + return + } + if peer, err = getpeername(int(fd)); err != nil { + return + } + switch c.(type) { + case *net.UnixConn: + laddr = sockaddrToUnixAddr(addr) + raddr = sockaddrToUnixAddr(peer) + case *net.UDPConn: + laddr = sockaddrToUDPAddr(addr) + raddr = sockaddrToUDPAddr(peer) + case *net.TCPConn: + laddr = sockaddrToTCPAddr(addr) + raddr = sockaddrToTCPAddr(peer) + } + }) + if err == nil { + err = rawConnErr + } + if err != nil { + c.Close() + return nil, err + } + return &conn{c, laddr, raddr}, nil +} + +func sockaddrToUnixAddr(addr sockaddr) net.Addr { + switch a := addr.(type) { + case *sockaddrUnix: + return &net.UnixAddr{ + Net: "unix", + Name: a.name, + } + default: + return nil + } +} + +func sockaddrToTCPAddr(addr sockaddr) net.Addr { + ip, port := sockaddrIPAndPort(addr) + return &net.TCPAddr{ + IP: ip, + Port: port, + } +} + +func sockaddrToUDPAddr(addr sockaddr) net.Addr { + ip, port := sockaddrIPAndPort(addr) + return &net.UDPAddr{ + IP: ip, + Port: port, + } +} + +func sockaddrIPAndPort(addr sockaddr) (net.IP, int) { + switch a := addr.(type) { + case *sockaddrInet4: + return net.IP(a.addr[:]), a.port + case *sockaddrInet6: + return net.IP(a.addr[:]), a.port + default: + return nil, 0 + } +} diff --git a/wasip1/net_wasip1_test.go b/wasip1/net_wasip1_test.go new file mode 100644 index 0000000..caa86b2 --- /dev/null +++ b/wasip1/net_wasip1_test.go @@ -0,0 +1,71 @@ +package wasip1_test + +import ( + "net" + "testing" + + "github.com/stealthrocket/net/wasip1" + "golang.org/x/net/nettest" +) + +func TestConn(t *testing.T) { + // TODO: for now only the TCP tests pass due to limitations in Go 1.21, see: + // https://github.com/golang/go/blob/39effbc105f5c54117a6011af3c48e3c8f14eca9/src/net/file_wasip1.go#L33-L55 + // + // Once https://go-review.googlesource.com/c/go/+/500578 is merged, we will + // be able to test udp and unix networks as well. + tests := []struct { + network string + address string + }{ + { + network: "tcp", + address: ":0", + }, + { + network: "tcp4", + address: ":0", + }, + { + network: "tcp6", + address: ":0", + }, + } + + for _, test := range tests { + t.Run(test.network, func(t *testing.T) { + nettest.TestConn(t, func() (c1, c2 net.Conn, stop func(), err error) { + l, err := wasip1.Listen(test.network, test.address) + if err != nil { + return nil, nil, nil, err + } + defer l.Close() + + conns := make(chan net.Conn, 1) + errch := make(chan error, 1) + go func() { + c, err := l.Accept() + if err != nil { + errch <- err + } else { + conns <- c + } + }() + + address := l.Addr() + c1, err = wasip1.Dial(address.Network(), address.String()) + if err != nil { + return nil, nil, nil, err + } + + select { + case c2 := <-conns: + return c1, c2, func() { c1.Close(); c2.Close() }, nil + case err := <-errch: + c1.Close() + return nil, nil, nil, err + } + }) + }) + } +} diff --git a/wasip1/syscall_wasmedge_wasip1.go b/wasip1/syscall_wasmedge_wasip1.go index 521306f..696bcac 100644 --- a/wasip1/syscall_wasmedge_wasip1.go +++ b/wasip1/syscall_wasmedge_wasip1.go @@ -71,14 +71,14 @@ func (s *sockaddrInet4) sockport() int { } type sockaddrInet6 struct { - port int - ZoneId uint32 - addr [16]byte - raw addressBuffer + port int + zone uint32 + addr [16]byte + raw addressBuffer } func (s *sockaddrInet6) sockaddr() (unsafe.Pointer, error) { - if s.ZoneId != 0 { + if s.zone != 0 { return nil, syscall.ENOTSUP } s.raw.bufLen = 16