From 445f860ba32d4cdb7cd53267914b316ed80a7191 Mon Sep 17 00:00:00 2001 From: Prem Chaitanya Prathi Date: Wed, 26 Apr 2023 12:41:00 +0530 Subject: [PATCH] Added support for using reuseport in connection Dialing #1435 --- p2p/transport/websocket/websocket.go | 11 ++++++- p2p/transport/websocket/websocket_test.go | 39 ++++++++++++++++------- 2 files changed, 38 insertions(+), 12 deletions(-) diff --git a/p2p/transport/websocket/websocket.go b/p2p/transport/websocket/websocket.go index 554329af68..8f793fbb75 100644 --- a/p2p/transport/websocket/websocket.go +++ b/p2p/transport/websocket/websocket.go @@ -16,6 +16,7 @@ import ( "github.com/libp2p/go-libp2p/core/transport" "github.com/libp2p/go-reuseport" + reusetransport "github.com/libp2p/go-libp2p/p2p/net/reuseport" ma "github.com/multiformats/go-multiaddr" mafmt "github.com/multiformats/go-multiaddr-fmt" manet "github.com/multiformats/go-multiaddr/net" @@ -91,6 +92,7 @@ type WebsocketTransport struct { tlsClientConf *tls.Config tlsConf *tls.Config reuseport bool //reuseport is disabled by default, can be enabled by passing it as an option. + reuse reusetransport.Transport } var _ transport.Transport = (*WebsocketTransport)(nil) @@ -198,7 +200,14 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma transport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - conn, err := net.Dial(network, addr) + var conn manet.Conn + var err error + if t.UseReuseport() { + conn, err = t.reuse.Dial(raddr) + } else { + var d manet.Dialer + conn, err = d.Dial(raddr) + } if err != nil { close(localAddrChan) return nil, err diff --git a/p2p/transport/websocket/websocket_test.go b/p2p/transport/websocket/websocket_test.go index 6846015d26..9de7e3e530 100644 --- a/p2p/transport/websocket/websocket_test.go +++ b/p2p/transport/websocket/websocket_test.go @@ -552,51 +552,68 @@ func TestResolveMultiaddr(t *testing.T) { func TestListenerResusePort(t *testing.T) { laddr := ma.StringCast("/ip4/127.0.0.1/tcp/5002/ws") - fmt.Println("Starting Reuse Port test.") + //fmt.Println("Starting Reuse Port test.") var wg sync.WaitGroup var opts []Option opts = append(opts, EnableReuseport()) _, u := newUpgrader(t) tpt, err := New(u, &network.NullResourceManager{}, opts...) require.NoError(t, err) - fmt.Println("Invoking Go routines.") + //fmt.Println("Invoking Go routines.") for i := 0; i < 2; i++ { wg.Add(1) go func(index int) { - l, err := tpt.Listen(laddr) + l, err := tpt.maListen(laddr) if err != nil { fmt.Println("Failed to listen on websocket due to error ", err) } require.NoError(t, err) require.Equal(t, lastComponent(t, l.Multiaddr()), wsComponent) defer l.Close() - fmt.Println("Routine-", index, " Calling Accept...") + //fmt.Println("Routine-", index, " Calling Accept...") for j := 0; j < 2; j++ { conn, err := l.Accept() if err != nil { fmt.Println("Routine-", index, " Failed accepting connection due to error ", err) } - //require.NoError(t, err) - fmt.Println("Routine-", index, " Accepting connection ", conn) + require.NoError(t, err) + //fmt.Println("Routine-", index, " Accepting connection ", conn) defer conn.Close() + buf := make([]byte, 6) + n, err := conn.Read(buf) + if n != 6 { + t.Errorf("read %d bytes, expected 2", n) + } + require.NoError(t, err) + fmt.Println("Read bytes:", buf) } }(i) } time.Sleep(2 * time.Second) - fmt.Println("Invoking Connector Go routines.") + //fmt.Println("Invoking Connector Go routines.") for i := 0; i < 4; i++ { go func(index int) { - fmt.Println("Routine-", index, " Initiating connection ") + //fmt.Println("Routine-", index, " Initiating connection ") c, err := tpt.maDial(context.Background(), laddr) if err != nil { t.Error(err) return } + require.NoError(t, err) defer c.Close() - fmt.Println("Sleeping for 10 seconds after connection intiation") - time.Sleep(10 * time.Second) + //fmt.Println("Sleeping for 2 seconds after connection intiation") + msg := fmt.Sprintf("Hello%d", index) + n, err := c.Write([]byte(msg)) + if n != 6 { + t.Errorf("expected to write 0 bytes, wrote %d", n) + } + if err != nil { + t.Error(err) + return + } + time.Sleep(2 * time.Second) }(i) } - wg.Wait() + time.Sleep(2 * time.Second) }