diff --git a/p2p/http/libp2phttp.go b/p2p/http/libp2phttp.go index 5ff025e3ab..e5078f772e 100644 --- a/p2p/http/libp2phttp.go +++ b/p2p/http/libp2phttp.go @@ -61,7 +61,7 @@ type WellKnownHandler struct { // streamHostListen returns a net.Listener that listens on libp2p streams for HTTP/1.1 messages. func streamHostListen(streamHost host.Host) (net.Listener, error) { - return gostream.Listen(streamHost, ProtocolIDForMultistreamSelect) + return gostream.Listen(streamHost, ProtocolIDForMultistreamSelect, gostream.IgnoreEOF()) } func (h *WellKnownHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { diff --git a/p2p/http/libp2phttp_test.go b/p2p/http/libp2phttp_test.go index a444c6e209..e7e66bb4cd 100644 --- a/p2p/http/libp2phttp_test.go +++ b/p2p/http/libp2phttp_test.go @@ -719,3 +719,39 @@ func TestServerLegacyWellKnownResource(t *testing.T) { } } + +func TestResponseWriterShouldNotHaveCancelledContext(t *testing.T) { + h, err := libp2p.New() + require.NoError(t, err) + defer h.Close() + httpHost := libp2phttp.Host{StreamHost: h} + go httpHost.Serve() + defer httpHost.Close() + + closeNotifyCh := make(chan bool, 1) + httpHost.SetHTTPHandlerAtPath("/test", "/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Legacy code uses this to check if the connection was closed + //lint:ignore SA1019 This is a test to assert we do the right thing since Go HTTP stdlib depends on this. + ch := w.(http.CloseNotifier).CloseNotify() + select { + case <-ch: + closeNotifyCh <- true + case <-time.After(100 * time.Millisecond): + closeNotifyCh <- false + } + w.WriteHeader(http.StatusOK) + })) + + clientH, err := libp2p.New() + require.NoError(t, err) + defer clientH.Close() + clientHost := libp2phttp.Host{StreamHost: clientH} + + rt, err := clientHost.NewConstrainedRoundTripper(peer.AddrInfo{ID: h.ID(), Addrs: h.Addrs()}) + require.NoError(t, err) + httpClient := &http.Client{Transport: rt} + _, err = httpClient.Get("/") + require.NoError(t, err) + + require.False(t, <-closeNotifyCh) +} diff --git a/p2p/net/gostream/conn.go b/p2p/net/gostream/conn.go index 991dd2ff96..6959b6cbe0 100644 --- a/p2p/net/gostream/conn.go +++ b/p2p/net/gostream/conn.go @@ -2,6 +2,7 @@ package gostream import ( "context" + "io" "net" "github.com/libp2p/go-libp2p/core/host" @@ -14,11 +15,20 @@ import ( // libp2p streams. type conn struct { network.Stream + ignoreEOF bool +} + +func (c *conn) Read(b []byte) (int, error) { + n, err := c.Stream.Read(b) + if err != nil && c.ignoreEOF && err == io.EOF { + return n, nil + } + return n, err } // newConn creates a conn given a libp2p stream -func newConn(s network.Stream) net.Conn { - return &conn{s} +func newConn(s network.Stream, ignoreEOF bool) net.Conn { + return &conn{s, ignoreEOF} } // LocalAddr returns the local network address. @@ -39,5 +49,5 @@ func Dial(ctx context.Context, h host.Host, pid peer.ID, tag protocol.ID) (net.C if err != nil { return nil, err } - return newConn(s), nil + return newConn(s, false), nil } diff --git a/p2p/net/gostream/listener.go b/p2p/net/gostream/listener.go index 250e688050..f1146b0617 100644 --- a/p2p/net/gostream/listener.go +++ b/p2p/net/gostream/listener.go @@ -18,6 +18,10 @@ type listener struct { tag protocol.ID cancel func() streamCh chan network.Stream + // ignoreEOF is a flag that tells the listener to return conns that ignore EOF errors. + // Necessary because the default responsewriter will consider a connection closed if it reads EOF. + // But when on streams, it's fine for us to read EOF, but still be able to write. + ignoreEOF bool } // Accept returns the next a connection to this listener. @@ -26,7 +30,7 @@ type listener struct { func (l *listener) Accept() (net.Conn, error) { select { case s := <-l.streamCh: - return newConn(s), nil + return newConn(s, l.ignoreEOF), nil case <-l.ctx.Done(): return nil, l.ctx.Err() } @@ -48,7 +52,7 @@ func (l *listener) Addr() net.Addr { // Listen provides a standard net.Listener ready to accept "connections". // Under the hood, these connections are libp2p streams tagged with the // given protocol.ID. -func Listen(h host.Host, tag protocol.ID) (net.Listener, error) { +func Listen(h host.Host, tag protocol.ID, opts ...ListenerOption) (net.Listener, error) { ctx, cancel := context.WithCancel(context.Background()) l := &listener{ @@ -58,6 +62,11 @@ func Listen(h host.Host, tag protocol.ID) (net.Listener, error) { tag: tag, streamCh: make(chan network.Stream), } + for _, opt := range opts { + if err := opt(l); err != nil { + return nil, err + } + } h.SetStreamHandler(tag, func(s network.Stream) { select { @@ -69,3 +78,12 @@ func Listen(h host.Host, tag protocol.ID) (net.Listener, error) { return l, nil } + +type ListenerOption func(*listener) error + +func IgnoreEOF() ListenerOption { + return func(l *listener) error { + l.ignoreEOF = true + return nil + } +}