diff --git a/benchmarks_test.go b/benchmarks_test.go
index a4264c1..f0a85e8 100644
--- a/benchmarks_test.go
+++ b/benchmarks_test.go
@@ -90,6 +90,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
 		wg.Add(1)
 		go func() {
 			defer wg.Done()
+			defer localB.Close()
 			receiveBuf := make([]byte, 2048)
 
 			for {
@@ -103,7 +104,7 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
 				atomic.AddUint64(&receivedBytes, uint64(n))
 			}
 		}()
-
+		defer localA.Close()
 		i := 0
 		for {
 			n, err := localA.Write(msgs[i])
@@ -116,7 +117,6 @@ func testSmallPackets(b *testing.B, n1, n2 net.Conn) {
 				break
 			}
 		}
-		localA.Close()
 	})
 	b.StopTimer()
 	wg.Wait()
diff --git a/deadline.go b/deadline.go
index dd2dfaf..b251c1a 100644
--- a/deadline.go
+++ b/deadline.go
@@ -32,6 +32,11 @@ func (d *pipeDeadline) set(t time.Time) {
 	d.mu.Lock()
 	defer d.mu.Unlock()
 
+	// deadline closed
+	if d.cancel == nil {
+		return
+	}
+
 	if d.timer != nil && !d.timer.Stop() {
 		<-d.cancel // Wait for the timer callback to finish and close cancel
 	}
@@ -70,6 +75,18 @@ func (d *pipeDeadline) wait() chan struct{} {
 	return d.cancel
 }
 
+// close closes the deadline. Any future calls to `set` will do nothing.
+func (d *pipeDeadline) close() {
+	d.mu.Lock()
+	defer d.mu.Unlock()
+
+	if d.timer != nil && !d.timer.Stop() {
+		<-d.cancel // Wait for the timer callback to finish and close cancel
+	}
+	d.timer = nil
+	d.cancel = nil
+}
+
 func isClosedChan(c <-chan struct{}) bool {
 	select {
 	case <-c:
diff --git a/go.mod b/go.mod
index f2c000f..dac7e03 100644
--- a/go.mod
+++ b/go.mod
@@ -7,6 +7,7 @@ require (
 	github.com/libp2p/go-libp2p-testing v0.1.2-0.20200422005655-8775583591d8
 	github.com/multiformats/go-varint v0.0.6
 	github.com/opentracing/opentracing-go v1.2.0 // indirect
+	go.uber.org/multierr v1.5.0
 	go.uber.org/zap v1.15.0 // indirect
 	golang.org/x/crypto v0.0.0-20190618222545-ea8f1a30c443 // indirect
 	google.golang.org/grpc v1.28.1
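Reviewer note, not part of the patch: the new `close()` turns the deadline into a terminal state, which is what lets the stream-side cancel paths shut a direction down for good. A minimal sketch of the intended lifecycle, using only the package's own unexported helpers (`makePipeDeadline`, `set`, `wait`, `close`); the function name and the use of the `time` package are illustrative, so this would only compile inside the package, e.g. in a test:

	// Sketch: once close() has run, set() becomes a no-op.
	func deadlineLifecycleSketch() {
		d := makePipeDeadline()
		d.set(time.Now().Add(time.Second)) // arm the timer; wait() fires in ~1s
		<-d.wait()                         // unblocks when the timer expires and closes cancel
		d.close()                          // stop the timer and nil out d.cancel
		d.set(time.Now().Add(time.Hour))   // no-op: set() now sees d.cancel == nil and returns
	}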
diff --git a/multiplex.go b/multiplex.go
index fbb80db..2ee4aad 100644
--- a/multiplex.go
+++ b/multiplex.go
@@ -111,16 +111,15 @@ func NewMultiplex(con net.Conn, initiator bool) *Multiplex {
 
 func (mp *Multiplex) newStream(id streamID, name string) (s *Stream) {
 	s = &Stream{
-		id:        id,
-		name:      name,
-		dataIn:    make(chan []byte, 8),
-		reset:     make(chan struct{}),
-		rDeadline: makePipeDeadline(),
-		wDeadline: makePipeDeadline(),
-		mp:        mp,
+		id:          id,
+		name:        name,
+		dataIn:      make(chan []byte, 8),
+		rDeadline:   makePipeDeadline(),
+		wDeadline:   makePipeDeadline(),
+		mp:          mp,
+		writeCancel: make(chan struct{}),
+		readCancel:  make(chan struct{}),
 	}
-
-	s.closedLocal, s.doCloseLocal = context.WithCancel(context.Background())
 	return
 }
 
@@ -168,7 +167,7 @@ func (mp *Multiplex) IsClosed() bool {
 	}
 }
 
-func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) error {
+func (mp *Multiplex) sendMsg(timeout, cancel <-chan struct{}, header uint64, data []byte) error {
 	buf := pool.Get(len(data) + 20)
 
 	n := 0
@@ -181,8 +180,10 @@ func (mp *Multiplex) sendMsg(done <-chan struct{}, header uint64, data []byte) e
 		return nil
 	case <-mp.shutdown:
 		return ErrShutdown
-	case <-done:
+	case <-timeout:
 		return errTimeout
+	case <-cancel:
+		return ErrStreamClosed
 	}
 }
 
@@ -321,7 +322,7 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
 	ctx, cancel := context.WithTimeout(context.Background(), NewStreamTimeout)
 	defer cancel()
 
-	err := mp.sendMsg(ctx.Done(), header, []byte(name))
+	err := mp.sendMsg(ctx.Done(), nil, header, []byte(name))
 	if err != nil {
 		return nil, err
 	}
@@ -331,23 +332,20 @@ func (mp *Multiplex) NewNamedStream(name string) (*Stream, error) {
 
 func (mp *Multiplex) cleanup() {
 	mp.closeNoWait()
+
+	// Take the channels.
 	mp.chLock.Lock()
-	defer mp.chLock.Unlock()
-	for _, msch := range mp.channels {
-		msch.clLock.Lock()
-		if !msch.closedRemote {
-			msch.closedRemote = true
-			// Cancel readers
-			close(msch.reset)
-		}
+	channels := mp.channels
+	mp.channels = nil
+	mp.chLock.Unlock()
 
-		msch.doCloseLocal()
-		msch.clLock.Unlock()
+	// Cancel any reads/writes
+	for _, msch := range channels {
+		msch.cancelRead(ErrStreamReset)
+		msch.cancelWrite(ErrStreamReset)
 	}
-	// Don't remove this nil assignment. We check if this is nil to check if
-	// the connection is closed when we already have the lock (faster than
-	// checking if the stream is closed).
-	mp.channels = nil
+
+	// And... shutdown!
 	if mp.shutdownErr == nil {
 		mp.shutdownErr = ErrShutdown
 	}
@@ -421,81 +419,43 @@ func (mp *Multiplex) handleIncoming() {
 				// This is *ok*. We forget the stream on reset.
 				continue
 			}
-			msch.clLock.Lock()
-
-			isClosed := msch.isClosed()
-
-			if !msch.closedRemote {
-				close(msch.reset)
-				msch.closedRemote = true
-			}
-
-			if !isClosed {
-				msch.doCloseLocal()
-			}
-			msch.clLock.Unlock()
-
-			msch.cancelDeadlines()
-
-			mp.chLock.Lock()
-			delete(mp.channels, ch)
-			mp.chLock.Unlock()
+			// Cancel any ongoing reads/writes.
+			msch.cancelRead(ErrStreamReset)
+			msch.cancelWrite(ErrStreamReset)
 		case closeTag:
 			if !ok {
+				// may have canceled our reads already.
 				continue
 			}
-			msch.clLock.Lock()
-
-			if msch.closedRemote {
-				msch.clLock.Unlock()
-				// Technically a bug on the other side. We
-				// should consider killing the connection.
-				continue
-			}
+			// unregister and throw away future data.
+			mp.chLock.Lock()
+			delete(mp.channels, ch)
+			mp.chLock.Unlock()
 
+			// close data channel, there will be no more data.
 			close(msch.dataIn)
-			msch.closedRemote = true
-
-			cleanup := msch.isClosed()
-			msch.clLock.Unlock()
-
-			if cleanup {
-				msch.cancelDeadlines()
-				mp.chLock.Lock()
-				delete(mp.channels, ch)
-				mp.chLock.Unlock()
-			}
+
+			// We intentionally don't cancel any deadlines, cancel reads, cancel
+			// writes, etc. We just deliver the EOF by closing the
+			// data channel, and unregister the channel so we don't
+			// receive any more data. The user still needs to call
+			// `Close()` or `Reset()`.
 		case messageTag:
 			if !ok {
-				// reset stream, return b
-				pool.Put(b)
-
-				// This is a perfectly valid case when we reset
-				// and forget about the stream.
-				log.Debugf("message for non-existant stream, dropping data: %d", ch)
-				// go mp.sendResetMsg(ch.header(resetTag), false)
-				continue
-			}
-
-			msch.clLock.Lock()
-			remoteClosed := msch.closedRemote
-			msch.clLock.Unlock()
-			if remoteClosed {
-				// closed stream, return b
+				// We're not accepting data on this stream, for
+				// some reason. It's likely that we reset it, or
+				// simply canceled reads (e.g., called Close).
 				pool.Put(b)
-
-				log.Warnf("Received data from remote after stream was closed by them. (len = %d)", len(b))
-				// go mp.sendResetMsg(msch.id.header(resetTag), false)
 				continue
 			}
 
 			recvTimeout.Reset(ReceiveTimeout)
 			select {
 			case msch.dataIn <- b:
-			case <-msch.reset:
+			case <-msch.readCancel:
+				// the user has canceled reading. walk away.
 				pool.Put(b)
 			case <-recvTimeout.C:
 				pool.Put(b)
@@ -534,7 +494,7 @@ func (mp *Multiplex) sendResetMsg(header uint64, hard bool) {
 	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
 	defer cancel()
 
-	err := mp.sendMsg(ctx.Done(), header, nil)
+	err := mp.sendMsg(ctx.Done(), nil, header, nil)
 	if err != nil && !mp.isShutdown() {
 		if hard {
 			log.Warnf("error sending reset message: %s; killing connection", err.Error())
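Reviewer note, not part of the patch: sendMsg still writes an ordinary mplex frame; only the cancellation plumbing (separate timeout and cancel channels) changed. For orientation, a sketch of the framing it performs, assuming `id.header(tag)` follows the mplex convention of packing the stream ID above a 3-bit tag; the helper name and the use of encoding/binary are illustrative, not taken from this diff:

	// Sketch: how a frame handed to the write loop is laid out.
	func encodeFrameSketch(streamID, tag uint64, data []byte) []byte {
		header := streamID<<3 | tag // e.g. messageTag for ordinary data
		buf := make([]byte, len(data)+20)
		n := binary.PutUvarint(buf, header)                // varint frame header
		n += binary.PutUvarint(buf[n:], uint64(len(data))) // varint payload length
		n += copy(buf[n:], data)                           // payload bytes
		return buf[:n]
	}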
diff --git a/multiplex_test.go b/multiplex_test.go
index 90017b2..81d484d 100644
--- a/multiplex_test.go
+++ b/multiplex_test.go
@@ -6,6 +6,8 @@ import (
 	"io/ioutil"
 	"math/rand"
 	"net"
+	"os"
+	"sync"
 	"testing"
 	"time"
 )
@@ -205,6 +207,58 @@ func TestEcho(t *testing.T) {
 	mpb.Close()
 }
 
+func TestFullClose(t *testing.T) {
+	a, b := net.Pipe()
+	mpa := NewMultiplex(a, false)
+	mpb := NewMultiplex(b, true)
+
+	mes := make([]byte, 40960)
+	acceptWait := make(chan struct{})
+	rand.Read(mes)
+	go func() {
+		s, err := mpb.Accept()
+		if err != nil {
+			t.Error(err)
+		}
+		close(acceptWait)
+
+		defer s.Close()
+
+		_, err = s.Write(mes)
+		if err != nil {
+			t.Error(err)
+		}
+	}()
+
+	s, err := mpa.NewStream()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	<-acceptWait
+
+	err = s.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if n, err := s.Write([]byte("foo")); err != ErrStreamClosed {
+		t.Fatal("expected stream closed error on write to closed stream, got", err)
+	} else if n != 0 {
+		t.Fatal("should not have written any bytes to closed stream")
+	}
+
+	// We also closed for reading, so this should fail.
+	if n, err := s.Read([]byte{0}); err != ErrStreamClosed {
+		t.Fatal("expected stream closed error on read from closed stream, got", err)
+	} else if n != 0 {
+		t.Fatal("should not have read any bytes from closed stream, got", n)
+	}
+
+	mpa.Close()
+	mpb.Close()
+}
+
 func TestHalfClose(t *testing.T) {
 	a, b := net.Pipe()
 	mpa := NewMultiplex(a, false)
@@ -216,15 +270,19 @@ func TestHalfClose(t *testing.T) {
 
 	go func() {
 		s, err := mpb.Accept()
 		if err != nil {
-			t.Fatal(err)
+			t.Error(err)
 		}
 		defer s.Close()
 
+		if err := s.CloseRead(); err != nil {
+			t.Error(err)
+		}
+
 		<-wait
 		_, err = s.Write(mes)
 		if err != nil {
-			t.Fatal(err)
+			t.Error(err)
 		}
 	}()
 
@@ -232,8 +290,9 @@ func TestHalfClose(t *testing.T) {
 	if err != nil {
 		t.Fatal(err)
 	}
+	defer s.Close()
 
-	err = s.Close()
+	err = s.CloseWrite()
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -362,6 +421,184 @@ func TestReset(t *testing.T) {
 	}
 }
 
+func TestCancelRead(t *testing.T) {
+	a, b := net.Pipe()
+
+	mpa := NewMultiplex(a, false)
+	mpb := NewMultiplex(b, true)
+
+	defer mpa.Close()
+	defer mpb.Close()
+
+	sa, err := mpa.NewStream()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sa.Reset()
+
+	sb, err := mpb.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sb.Reset()
+
+	// spin off a read
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		_, err := sa.Read([]byte{0})
+		if err != ErrStreamClosed {
+			t.Error(err)
+		}
+	}()
+	// give it a chance to start.
+	time.Sleep(time.Millisecond)
+
+	// cancel it.
+	err = sa.CloseRead()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	// It should be canceled.
+	<-done
+
+	// Writing should still succeed.
+	_, err = sa.Write([]byte("foo"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = sa.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Data should still be sent.
+	buf, err := ioutil.ReadAll(sb)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if string(buf) != "foo" {
+		t.Fatalf("expected foo, got %#v", string(buf))
+	}
+}
+
+func TestCancelWrite(t *testing.T) {
+	a, b := net.Pipe()
+
+	mpa := NewMultiplex(a, false)
+	mpb := NewMultiplex(b, true)
+
+	defer mpa.Close()
+	defer mpb.Close()
+
+	sa, err := mpa.NewStream()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sa.Reset()
+
+	sb, err := mpb.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sb.Reset()
+
+	// spin off a writer
+	var wg sync.WaitGroup
+	wg.Add(2)
+	go func() {
+		defer wg.Done()
+		for {
+			_, err := sa.Write([]byte("foo"))
+			if err != nil {
+				if err != ErrStreamClosed {
+					t.Error("unexpected error", err)
+				}
+				return
+			}
+		}
+	}()
+	// give it a chance to fill up.
+	time.Sleep(time.Millisecond)
+
+	go func() {
+		defer wg.Done()
+		// close it.
+		err = sa.CloseWrite()
+		if err != nil {
+			t.Error(err)
+		}
+	}()
+	_, err = ioutil.ReadAll(sb)
+	if err != nil {
+		t.Fatalf("expected stream to be closed correctly")
+	}
+
+	// It should be canceled.
+	wg.Wait()
+
+	// Reading should still succeed.
+	_, err = sb.Write([]byte("bar"))
+	if err != nil {
+		t.Fatal(err)
+	}
+	err = sb.Close()
+	if err != nil {
+		t.Fatal(err)
+	}
+	// Data should still be sent.
+	buf, err := ioutil.ReadAll(sa)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if string(buf) != "bar" {
+		t.Fatalf("expected bar, got %#v", string(buf))
+	}
+}
+
+func TestCancelReadAfterWrite(t *testing.T) {
+	a, b := net.Pipe()
+
+	mpa := NewMultiplex(a, false)
+	mpb := NewMultiplex(b, true)
+
+	defer mpa.Close()
+	defer mpb.Close()
+
+	sa, err := mpa.NewStream()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sa.Reset()
+
+	sb, err := mpb.Accept()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer sb.Reset()
+
+	// Write small messages till we would block.
+	sa.SetWriteDeadline(time.Now().Add(time.Millisecond))
+	for {
+		_, err = sa.Write([]byte("foo"))
+		if err != nil {
+			if os.IsTimeout(err) {
+				break
+			} else {
+				t.Fatal(err)
+			}
+		}
+	}
+
+	// Cancel inbound reads.
+	sb.CloseRead()
+
+	// We shouldn't read anything.
+	n, err := sb.Read([]byte{0})
+	if n != 0 || err != ErrStreamClosed {
+		t.Fatal("got data", err)
+	}
+}
+
 func TestResetAfterEOF(t *testing.T) {
 	a, b := net.Pipe()
 
@@ -377,7 +614,7 @@ func TestResetAfterEOF(t *testing.T) {
 	}
 
 	sb, err := mpb.Accept()
-	if err := sa.Close(); err != nil {
+	if err := sa.CloseWrite(); err != nil {
 		t.Fatal(err)
 	}
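Reviewer note, not part of the patch: TestCancelReadAfterWrite tells a deadline expiry apart from a hard failure with os.IsTimeout. That only works because the write path returns errTimeout, which (like the net package's pipe-deadline code this mirrors) is expected to expose a Timeout() bool method. A sketch of the shape such an error usually takes; the actual definition lives in deadline.go and is not shown in this diff:

	// Sketch: an error value that os.IsTimeout reports as true.
	type timeoutErrorSketch struct{}

	func (timeoutErrorSketch) Error() string   { return "deadline exceeded" }
	func (timeoutErrorSketch) Timeout() bool   { return true }
	func (timeoutErrorSketch) Temporary() bool { return true }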
diff --git a/stream.go b/stream.go
index b76f75e..4d7007d 100644
--- a/stream.go
+++ b/stream.go
@@ -8,6 +8,7 @@ import (
 	"time"
 
 	pool "github.com/libp2p/go-buffer-pool"
+	"go.uber.org/multierr"
 )
 
 var (
@@ -44,15 +45,9 @@ type Stream struct {
 
 	rDeadline, wDeadline pipeDeadline
 
-	clLock       sync.Mutex
-	closedRemote bool
-
-	// Closed when the connection is reset.
-	reset chan struct{}
-
-	// Closed when the writer is closed (reset will also be closed)
-	closedLocal  context.Context
-	doCloseLocal context.CancelFunc
+	clLock                        sync.Mutex
+	writeCancelErr, readCancelErr error
+	writeCancel, readCancel       chan struct{}
 }
 
 func (s *Stream) Name() string {
@@ -74,10 +69,6 @@ func (s *Stream) preloadData() {
 
 func (s *Stream) waitForData() error {
 	select {
-	case <-s.reset:
-		// This is the only place where it's safe to return these.
-		s.returnBuffers()
-		return ErrStreamReset
 	case read, ok := <-s.dataIn:
 		if !ok {
 			return io.EOF
@@ -85,6 +76,10 @@ func (s *Stream) waitForData() error {
 		s.extra = read
 		s.exbuf = read
 		return nil
+	case <-s.readCancel:
+		// This is the only place where it's safe to return these.
+		s.returnBuffers()
+		return s.readCancelErr
 	case <-s.rDeadline.wait():
 		return errTimeout
 	}
@@ -114,10 +109,11 @@ func (s *Stream) returnBuffers() {
 
 func (s *Stream) Read(b []byte) (int, error) {
 	select {
-	case <-s.reset:
-		return 0, ErrStreamReset
+	case <-s.readCancel:
+		return 0, s.readCancelErr
 	default:
 	}
+
 	if s.extra == nil {
 		err := s.waitForData()
 		if err != nil {
@@ -162,134 +158,112 @@ func (s *Stream) Write(b []byte) (int, error) {
 }
 
 func (s *Stream) write(b []byte) (int, error) {
-	if s.isClosed() {
-		return 0, ErrStreamClosed
+	select {
+	case <-s.writeCancel:
+		return 0, s.writeCancelErr
+	default:
 	}
 
-	err := s.mp.sendMsg(s.wDeadline.wait(), s.id.header(messageTag), b)
-
+	err := s.mp.sendMsg(s.wDeadline.wait(), s.writeCancel, s.id.header(messageTag), b)
 	if err != nil {
-		if err == context.Canceled {
-			err = ErrStreamClosed
-		}
 		return 0, err
 	}
 
 	return len(b), nil
 }
 
-func (s *Stream) isClosed() bool {
-	return s.closedLocal.Err() != nil
-}
+func (s *Stream) cancelWrite(err error) bool {
+	s.wDeadline.close()
 
-func (s *Stream) Close() error {
-	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
-	defer cancel()
+	s.clLock.Lock()
+	defer s.clLock.Unlock()
+	select {
+	case <-s.writeCancel:
+		return false
+	default:
+		s.writeCancelErr = err
+		close(s.writeCancel)
+		return true
+	}
+}
 
-	err := s.mp.sendMsg(ctx.Done(), s.id.header(closeTag), nil)
+func (s *Stream) cancelRead(err error) bool {
+	// Always unregister for reading first, even if we're already closed (or
+	// already closing). When handleIncoming calls this, it expects the
+	// stream to be unregistered by the time it returns.
+	s.mp.chLock.Lock()
+	delete(s.mp.channels, s.id)
+	s.mp.chLock.Unlock()
 
-	if s.isClosed() {
-		return nil
-	}
+	s.rDeadline.close()
 
 	s.clLock.Lock()
-	remote := s.closedRemote
-	s.clLock.Unlock()
-
-	s.doCloseLocal()
+	defer s.clLock.Unlock()
+	select {
+	case <-s.readCancel:
+		return false
+	default:
+		s.readCancelErr = err
+		close(s.readCancel)
+		return true
+	}
+}
 
-	if remote {
-		s.cancelDeadlines()
-		s.mp.chLock.Lock()
-		delete(s.mp.channels, s.id)
-		s.mp.chLock.Unlock()
+func (s *Stream) CloseWrite() error {
+	if !s.cancelWrite(ErrStreamClosed) {
+		// Check if we closed the stream _nicely_. If so, we don't need
+		// to report an error to the user.
+		if s.writeCancelErr == ErrStreamClosed {
+			return nil
+		}
+		// Closed for some other reason. Report it.
+		return s.writeCancelErr
 	}
 
+	ctx, cancel := context.WithTimeout(context.Background(), ResetStreamTimeout)
+	defer cancel()
+
+	err := s.mp.sendMsg(ctx.Done(), nil, s.id.header(closeTag), nil)
+
+	// We failed to close the stream after 2 minutes, something is probably wrong.
 	if err != nil && !s.mp.isShutdown() {
 		log.Warnf("Error closing stream: %s; killing connection", err.Error())
 		s.mp.Close()
 	}
-
 	return err
 }
 
-func (s *Stream) Reset() error {
-	s.clLock.Lock()
-
-	// Don't reset when fully closed.
-	if s.closedRemote && s.isClosed() {
-		s.clLock.Unlock()
-		return nil
-	}
-
-	// Don't reset twice.
-	select {
-	case <-s.reset:
-		s.clLock.Unlock()
-		return nil
-	default:
-	}
-
-	close(s.reset)
-	s.doCloseLocal()
-	s.closedRemote = true
-	s.cancelDeadlines()
-
-	go s.mp.sendResetMsg(s.id.header(resetTag), true)
-
-	s.clLock.Unlock()
-
-	s.mp.chLock.Lock()
-	delete(s.mp.channels, s.id)
-	s.mp.chLock.Unlock()
-
+func (s *Stream) CloseRead() error {
+	s.cancelRead(ErrStreamClosed)
 	return nil
 }
 
-func (s *Stream) cancelDeadlines() {
-	s.rDeadline.set(time.Time{})
-	s.wDeadline.set(time.Time{})
+func (s *Stream) Close() error {
+	return multierr.Combine(s.CloseRead(), s.CloseWrite())
 }
 
-func (s *Stream) SetDeadline(t time.Time) error {
-	s.clLock.Lock()
-	defer s.clLock.Unlock()
-
-	if s.closedRemote && s.isClosed() {
-		return errStreamClosed
-	}
+func (s *Stream) Reset() error {
+	s.cancelRead(ErrStreamReset)
 
-	if !s.closedRemote {
-		s.rDeadline.set(t)
+	if s.cancelWrite(ErrStreamReset) {
+		// Send a reset in the background.
+		go s.mp.sendResetMsg(s.id.header(resetTag), true)
 	}
 
-	if !s.isClosed() {
-		s.wDeadline.set(t)
-	}
+	return nil
+}
 
+func (s *Stream) SetDeadline(t time.Time) error {
+	s.rDeadline.set(t)
+	s.wDeadline.set(t)
 	return nil
 }
 
 func (s *Stream) SetReadDeadline(t time.Time) error {
-	s.clLock.Lock()
-	defer s.clLock.Unlock()
-
-	if s.closedRemote {
-		return errStreamClosed
-	}
-
 	s.rDeadline.set(t)
 	return nil
 }
 
 func (s *Stream) SetWriteDeadline(t time.Time) error {
-	s.clLock.Lock()
-	defer s.clLock.Unlock()
-
-	if s.isClosed() {
-		return errStreamClosed
-	}
-
 	s.wDeadline.set(t)
 	return nil
 }
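Reviewer note, not part of the patch: taken together, the stream now supports a real half-close. A short usage sketch of the new caller-facing semantics, assuming a connected *Stream obtained from NewStream or Accept; the helper name, the request/response framing, and the io/ioutil use are illustrative only:

	// Sketch: write a request, half-close our side, read the reply to EOF.
	func requestResponseSketch(s *Stream, req []byte) ([]byte, error) {
		if _, err := s.Write(req); err != nil {
			return nil, err
		}
		// Deliver EOF to the remote reader; our read side stays usable.
		if err := s.CloseWrite(); err != nil {
			return nil, err
		}
		resp, err := ioutil.ReadAll(s) // remote's CloseWrite surfaces here as io.EOF
		// Close() == CloseRead() + CloseWrite(); the second CloseWrite is a no-op.
		return resp, multierr.Combine(err, s.Close())
	}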