diff --git a/p2p/transport/websocket/addrs_test.go b/p2p/transport/websocket/addrs_test.go index 3c5ba502a9..d262eedbad 100644 --- a/p2p/transport/websocket/addrs_test.go +++ b/p2p/transport/websocket/addrs_test.go @@ -69,7 +69,8 @@ func TestConvertWebsocketMultiaddrToNetAddr(t *testing.T) { } func TestListeningOnDNSAddr(t *testing.T) { - ln, err := newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) + wt := &WebsocketTransport{} + ln, err := wt.newListener(ma.StringCast("/dns/localhost/tcp/0/ws"), nil) require.NoError(t, err) addr := ln.Multiaddr() first, rest := ma.SplitFirst(addr) diff --git a/p2p/transport/websocket/listener.go b/p2p/transport/websocket/listener.go index d7a1b885b8..d9e27a2e18 100644 --- a/p2p/transport/websocket/listener.go +++ b/p2p/transport/websocket/listener.go @@ -14,7 +14,7 @@ import ( ) type listener struct { - nl net.Listener + nl manet.Listener server http.Server // The Go standard library sets the http.Server.TLSConfig no matter if this is a WS or WSS, // so we can't rely on checking if server.TLSConfig is set. @@ -40,7 +40,7 @@ func (pwma *parsedWebsocketMultiaddr) toMultiaddr() ma.Multiaddr { // newListener creates a new listener from a raw net.Listener. // tlsConf may be nil (for unencrypted websockets). -func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { +func (t *WebsocketTransport) newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { parsed, err := parseWebsocketMultiaddr(a) if err != nil { return nil, err @@ -50,11 +50,16 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { return nil, fmt.Errorf("cannot listen on wss address %s without a tls.Config", a) } - lnet, lnaddr, err := manet.DialArgs(parsed.restMultiaddr) - if err != nil { - return nil, err + var nl manet.Listener + if !t.UseReuseport() { + nl, err = manet.Listen(a) + } else { + nl, err = t.reuse.Listen(a) + // Fallback to regular listener in case of an error. + if err != nil { + nl, err = manet.Listen(a) + } } - nl, err := net.Listen(lnet, lnaddr) if err != nil { return nil, err } @@ -88,10 +93,11 @@ func newListener(a ma.Multiaddr, tlsConf *tls.Config) (*listener, error) { func (l *listener) serve() { defer close(l.closed) + list := manet.NetListener(l.nl) if !l.isWss { - l.server.Serve(l.nl) + l.server.Serve(list) } else { - l.server.ServeTLS(l.nl, "", "") + l.server.ServeTLS(list, "", "") } } diff --git a/p2p/transport/websocket/reuseport.go b/p2p/transport/websocket/reuseport.go new file mode 100644 index 0000000000..ea8bee7af8 --- /dev/null +++ b/p2p/transport/websocket/reuseport.go @@ -0,0 +1,9 @@ +package websocket + +import ( + "github.com/libp2p/go-reuseport" +) + +func reuseportIsAvailable() bool { + return reuseport.Available() +} diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 5142ca97a1..f3f4059916 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -11,6 +11,7 @@ import ( "github.com/libp2p/go-libp2p/core/network" "github.com/libp2p/go-libp2p/core/peer" "github.com/libp2p/go-libp2p/core/transport" + "github.com/libp2p/go-libp2p/p2p/net/reuseport" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" @@ -80,6 +81,13 @@ func WithTLSConfig(conf *tls.Config) Option { } } +func EnableReuseport() Option { + return func(t *WebsocketTransport) error { + t.enableReuseport = true + return nil + } +} + // WebsocketTransport is the actual go-libp2p transport type WebsocketTransport struct { upgrader transport.Upgrader @@ -87,6 +95,9 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config + + enableReuseport bool // Explicitly enable reuseport. + reuse reuseport.Transport } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -188,6 +199,32 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma } isWss := wsurl.Scheme == "wss" dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second} + dialer.NetDialContext = func(ctx context.Context, network string, address string) (net.Conn, error) { + + tcpAddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + maddr, err := manet.FromNetAddr(tcpAddr) + if err != nil { + return nil, err + } + + var conn manet.Conn + if t.UseReuseport() { + conn, err = t.reuse.DialContext(ctx, maddr) + } else { + var d manet.Dialer + conn, err = d.DialContext(ctx, maddr) + } + if err != nil { + return nil, err + } + + return conn, nil + } + if isWss { sni := "" sni, err = raddr.ValueForProtocol(ma.P_SNI) @@ -202,12 +239,29 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma ipAddr := wsurl.Host // Setting the NetDial because we already have the resolved IP address, so we don't want to do another resolution. // We set the `.Host` to the sni field so that the host header gets properly set. - dialer.NetDial = func(network, address string) (net.Conn, error) { + dialer.NetDialContext = func(ctx context.Context, network, address string) (net.Conn, error) { tcpAddr, err := net.ResolveTCPAddr(network, ipAddr) if err != nil { return nil, err } - return net.DialTCP("tcp", nil, tcpAddr) + + maddr, err := manet.FromNetAddr(tcpAddr) + if err != nil { + return nil, err + } + + var conn manet.Conn + if t.UseReuseport() { + conn, err = t.reuse.DialContext(ctx, maddr) + } else { + var d manet.Dialer + conn, err = d.DialContext(ctx, maddr) + } + if err != nil { + return nil, err + } + + return conn, nil } wsurl.Host = sni + ":" + wsurl.Port() } else { @@ -229,7 +283,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma } func (t *WebsocketTransport) maListen(a ma.Multiaddr) (manet.Listener, error) { - l, err := newListener(a, t.tlsConf) + l, err := t.newListener(a, t.tlsConf) if err != nil { return nil, err } @@ -244,3 +298,8 @@ func (t *WebsocketTransport) Listen(a ma.Multiaddr) (transport.Listener, error) } return &transportListener{Listener: t.upgrader.UpgradeListener(t, malist)}, nil } + +// UseReuseport returns true if reuseport is enabled and available. +func (t *WebsocketTransport) UseReuseport() bool { + return t.enableReuseport && reuseportIsAvailable() +} diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 2023ee3528..c26c8f83e5 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -15,7 +15,9 @@ import ( "math/big" "net" "net/http" + "runtime" "strings" + "sync" "testing" "time" @@ -32,6 +34,7 @@ import ( ttransport "github.com/libp2p/go-libp2p/p2p/transport/testsuite" ma "github.com/multiformats/go-multiaddr" + manet "github.com/multiformats/go-multiaddr/net" "github.com/stretchr/testify/require" ) @@ -549,3 +552,151 @@ func TestResolveMultiaddr(t *testing.T) { }) } } + +func TestReusePortOnDial(t *testing.T) { + + // Create an endpoint that will accept connections. + // We'll use this to verify that the party initiating the connection reused port. + serverID, cu := newUpgrader(t) + transport, err := New(cu, &network.NullResourceManager{}) + require.NoError(t, err) + + server, err := transport.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer server.Close() + + // Create an endpoint that will initiate connection. + _, u := newUpgrader(t) + tpt, err := New(u, &network.NullResourceManager{}, EnableReuseport()) + require.NoError(t, err) + + // Start listening. + listener, err := tpt.Listen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + defer listener.Close() + + // Take a note of the multiaddress on which we listen. This should be the address from which we dial too. + expectedAddr := listener.Multiaddr() + + done := make(chan struct{}) + go func() { + defer close(done) + + conn, err := server.Accept() + require.NoError(t, err) + defer conn.Close() + + // The meat of this test - verify that the connection was received from the same port as the listen port recorded above. + remote := conn.RemoteMultiaddr() + require.Equal(t, expectedAddr, remote) + }() + + conn, err := tpt.Dial(context.Background(), server.Multiaddr(), serverID) + require.NoError(t, err) + defer conn.Close() + + <-done +} + +func TestReusePortOnListen(t *testing.T) { + + const ( + // how many connections we try to establish. + connectionCount = 20 + ) + + // Create an endpoint that will accept connections. + // We'll use this to verify that the party initiating the connection reused port. + _, cu := newUpgrader(t) + tpt, err := New(cu, &network.NullResourceManager{}, EnableReuseport()) + require.NoError(t, err) + + listener1, err := tpt.maListen(ma.StringCast("/ip4/127.0.0.1/tcp/0/ws")) + require.NoError(t, err) + + // Get the port on which we should start the second listener + addr, ok := listener1.Addr().(*net.TCPAddr) + require.True(t, ok) + + port := addr.Port + listener2, err := tpt.maListen(ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/tcp/%v/ws", port))) + require.NoError(t, err) + + listeners := []manet.Listener{listener1, listener2} + + // Record which listener accepted how many connections. + requestCount := make(map[int]int) + var lock sync.Mutex + + var connsHandled sync.WaitGroup + connsHandled.Add(connectionCount) + // For both listeners spin up goroutines to accept incoming connections. + for i, listener := range listeners { + for j := 0; j < connectionCount; j++ { + go func(index int, listener manet.Listener) { + + conn, err := listener.Accept() + if err != nil { + // Stop condition - this happens when the listener is closed. + require.ErrorIs(t, err, transport.ErrListenerClosed) + return + } + defer conn.Close() + connsHandled.Done() + + // Record which listener accepted the connection. + lock.Lock() + defer lock.Unlock() + requestCount[index]++ + }(i, listener) + } + } + + // Create a different transport as you cannot self-dial using reuseport. + tpt2, err := New(cu, &network.NullResourceManager{}) + require.NoError(t, err) + + var dialers sync.WaitGroup + dialers.Add(connectionCount) + + for i := 0; i < connectionCount; i++ { + go func() { + defer dialers.Done() + conn, err := tpt2.maDial(context.Background(), listener1.Multiaddr()) + require.NoError(t, err) + defer conn.Close() + }() + } + + // Wait for all dialers to complete. + dialers.Wait() + + // Wait for listeners to complete their part. + connsHandled.Wait() + + // Cancel listeners to unblock any further pending accepts. + listener1.Close() + listener2.Close() + + // For Windows we can't make any assumptions with regards to connection distribution: + // "Once the second socket has successfully bound, the behavior for all sockets bound to that port is indeterminate. + // For example, if all of the sockets on the same port provide TCP service, any incoming TCP connection requests over + // the port cannot be guaranteed to be handled by the correct socket — the behavior is non-deterministic." + // => https://learn.microsoft.com/en-us/windows/win32/winsock/using-so-reuseaddr-and-so-exclusiveaddruse + + // For MacOS (FreeBSD) it's the last socket to bind that receives the connections. Anegdotal evidence but: + // "Ironically it's the BSD semantics which support seamless server restarts. In my tests OS X's behavior (which I presume + // is identical to FreeBSD and other BSDs) is that the last socket to bind is the only one to receive new connections." + // => https://lwn.net/Articles/542629/ + // On FreeBSD it's the SO_REUSEPORT_LB variant that provides load balancing. + + // For Linux only - verify that both listeners handled some connections. + if runtime.GOOS == "linux" { + // We're not trying to verify an even distribution as it's not a perfect world. + require.NotZero(t, requestCount[0], "first listener accepted no connections") + require.NotZero(t, requestCount[1], "second listener accepted no connections") + } + + total := requestCount[0] + requestCount[1] + require.Equal(t, connectionCount, total, "not all requests were handled") +}