diff --git a/const.go b/const.go index dfb0add..e4b2bc2 100644 --- a/const.go +++ b/const.go @@ -115,8 +115,8 @@ const ( const ( // initialStreamWindow is the initial stream window size. // It's not an implementation choice, the value defined in the specification. - initialStreamWindow uint32 = 256 * 1024 - maxStreamWindow uint32 = 16 * 1024 * 1024 + initialStreamWindow = 256 * 1024 + maxStreamWindow = 16 * 1024 * 1024 ) const ( diff --git a/mux.go b/mux.go index be13e11..b1f7a64 100644 --- a/mux.go +++ b/mux.go @@ -60,6 +60,12 @@ type Config struct { // MaxMessageSize is the maximum size of a message that we'll send on a // stream. This ensures that a single stream doesn't hog a connection. MaxMessageSize uint32 + + // MemoryManager allows management of memory allocations. + // Memory is allocated: + // 1. When opening / accepting a new stream. This uses the highest priority. + // 2. When trying to increase the stream receive window. This uses a lower priority. + MemoryManager MemoryManager } // DefaultConfig is used to return a default configuration diff --git a/session.go b/session.go index f7e2453..79441e1 100644 --- a/session.go +++ b/session.go @@ -18,6 +18,21 @@ import ( pool "github.com/libp2p/go-buffer-pool" ) +// The MemoryManager allows management of memory allocations. +type MemoryManager interface { + // ReserveMemory reserves memory / buffer. + ReserveMemory(size int, prio uint8) error + // ReleaseMemory explicitly releases memory previously reserved with ReserveMemory + ReleaseMemory(size int) +} + +type nullMemoryManagerImpl struct{} + +func (n nullMemoryManagerImpl) ReserveMemory(size int, prio uint8) error { return nil } +func (n nullMemoryManagerImpl) ReleaseMemory(size int) {} + +var nullMemoryManager MemoryManager = &nullMemoryManagerImpl{} + // Session is used to wrap a reliable ordered connection and to // multiplex it into multiple streams. type Session struct { @@ -47,6 +62,8 @@ type Session struct { // reader is a buffered reader reader io.Reader + memoryManager MemoryManager + // pings is used to track inflight pings pingLock sync.Mutex pingID uint32 @@ -106,21 +123,22 @@ func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Sessio reader = bufio.NewReaderSize(reader, readBuf) } s := &Session{ - config: config, - client: client, - logger: log.New(config.LogOutput, "", log.LstdFlags), - conn: conn, - reader: reader, - streams: make(map[uint32]*Stream), - inflight: make(map[uint32]struct{}), - synCh: make(chan struct{}, config.AcceptBacklog), - acceptCh: make(chan *Stream, config.AcceptBacklog), - sendCh: make(chan []byte, 64), - pongCh: make(chan uint32, config.PingBacklog), - pingCh: make(chan uint32), - recvDoneCh: make(chan struct{}), - sendDoneCh: make(chan struct{}), - shutdownCh: make(chan struct{}), + config: config, + client: client, + logger: log.New(config.LogOutput, "", log.LstdFlags), + conn: conn, + reader: reader, + streams: make(map[uint32]*Stream), + inflight: make(map[uint32]struct{}), + synCh: make(chan struct{}, config.AcceptBacklog), + acceptCh: make(chan *Stream, config.AcceptBacklog), + sendCh: make(chan []byte, 64), + pongCh: make(chan uint32, config.PingBacklog), + pingCh: make(chan uint32), + recvDoneCh: make(chan struct{}), + sendDoneCh: make(chan struct{}), + shutdownCh: make(chan struct{}), + memoryManager: config.MemoryManager, } if client { s.nextStreamID = 1 @@ -130,6 +148,9 @@ func newSession(config *Config, conn net.Conn, client bool, readBuf int) *Sessio if config.EnableKeepAlive { s.startKeepalive() } + if s.memoryManager == nil { + s.memoryManager = nullMemoryManager + } go s.recv() go s.send() go s.measureRTT() @@ -187,6 +208,10 @@ func (s *Session) OpenStream(ctx context.Context) (*Stream, error) { return nil, s.shutdownErr } + if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil { + return nil, err + } + GET_ID: // Get an ID, and check for stream exhaustion id := atomic.LoadUint32(&s.nextStreamID) @@ -198,7 +223,7 @@ GET_ID: } // Register the stream - stream := newStream(s, id, streamInit) + stream := newStream(s, id, streamInit, initialStreamWindow) s.streamLock.Lock() s.streams[id] = stream s.inflight[id] = struct{}{} @@ -477,20 +502,20 @@ func (s *Session) sendLoop() error { // FIXME: https://github.com/libp2p/go-libp2p/issues/644 // Write coalescing is disabled for now. - //writer := pool.Writer{W: s.conn} + // writer := pool.Writer{W: s.conn} - //var writeTimeout *time.Timer - //var writeTimeoutCh <-chan time.Time - //if s.config.WriteCoalesceDelay > 0 { + // var writeTimeout *time.Timer + // var writeTimeoutCh <-chan time.Time + // if s.config.WriteCoalesceDelay > 0 { // writeTimeout = time.NewTimer(s.config.WriteCoalesceDelay) // defer writeTimeout.Stop() // writeTimeoutCh = writeTimeout.C - //} else { + // } else { // ch := make(chan time.Time) // close(ch) // writeTimeoutCh = ch - //} + // } for { // yield after processing the last message, if we've shutdown. @@ -526,7 +551,7 @@ func (s *Session) sendLoop() error { copy(buf, hdr[:]) case <-s.shutdownCh: return nil - //default: + // default: // select { // case buf = <-s.sendCh: // case <-s.shutdownCh: @@ -591,6 +616,7 @@ func (s *Session) recvLoop() error { defer close(s.recvDoneCh) var hdr header for { + // fmt.Printf("ReadFull from %#v\n", s.reader) // Read the header if _, err := io.ReadFull(s.reader, hdr[:]); err != nil { if err != io.EOF && !strings.Contains(err.Error(), "closed") && !strings.Contains(err.Error(), "reset by peer") { @@ -733,7 +759,10 @@ func (s *Session) incomingStream(id uint32) error { } // Allocate a new stream - stream := newStream(s, id, streamSYNReceived) + if err := s.memoryManager.ReserveMemory(initialStreamWindow, 255); err != nil { + return err + } + stream := newStream(s, id, streamSYNReceived, initialStreamWindow) s.streamLock.Lock() defer s.streamLock.Unlock() @@ -744,13 +773,14 @@ func (s *Session) incomingStream(id uint32) error { if sendErr := s.sendMsg(s.goAway(goAwayProtoErr), nil, nil); sendErr != nil { s.logger.Printf("[WARN] yamux: failed to send go away: %v", sendErr) } + s.memoryManager.ReleaseMemory(initialStreamWindow) return ErrDuplicateStream } if s.numIncomingStreams >= s.config.MaxIncomingStreams { // too many active streams at the same time s.logger.Printf("[WARN] yamux: MaxIncomingStreams exceeded, forcing stream reset") - delete(s.streams, id) + s.memoryManager.ReleaseMemory(initialStreamWindow) hdr := encode(typeWindowUpdate, flagRST, id, 0) return s.sendMsg(hdr, nil, nil) } @@ -766,7 +796,7 @@ func (s *Session) incomingStream(id uint32) error { default: // Backlog exceeded! RST the stream s.logger.Printf("[WARN] yamux: backlog exceeded, forcing stream reset") - delete(s.streams, id) + s.deleteStream(id) hdr := encode(typeWindowUpdate, flagRST, id, 0) return s.sendMsg(hdr, nil, nil) } @@ -788,10 +818,19 @@ func (s *Session) closeStream(id uint32) { if s.client == (id%2 == 0) { s.numIncomingStreams-- } - delete(s.streams, id) + s.deleteStream(id) s.streamLock.Unlock() } +func (s *Session) deleteStream(id uint32) { + str, ok := s.streams[id] + if !ok { + return + } + s.memoryManager.ReleaseMemory(int(str.recvWindow)) + delete(s.streams, id) +} + // establishStream is used to mark a stream that was in the // SYN Sent state as established. func (s *Session) establishStream(id uint32) { diff --git a/stream.go b/stream.go index 8213860..7ae0be3 100644 --- a/stream.go +++ b/stream.go @@ -51,7 +51,7 @@ type Stream struct { // newStream is used to construct a new stream within // a given session for an ID -func newStream(session *Session, id uint32, state streamState) *Stream { +func newStream(session *Session, id uint32, state streamState, initialWindow uint32) *Stream { s := &Stream{ id: id, session: session, @@ -62,7 +62,7 @@ func newStream(session *Session, id uint32, state streamState) *Stream { // Initialize the recvBuf with initialStreamWindow, not config.InitialStreamWindowSize. // The peer isn't allowed to send more data than initialStreamWindow until we've sent // the first window update (which will grant it up to config.InitialStreamWindowSize). - recvBuf: newSegmentedBuffer(initialStreamWindow), + recvBuf: newSegmentedBuffer(initialWindow), recvWindow: session.config.InitialStreamWindowSize, epochStart: time.Now(), recvNotifyCh: make(chan struct{}, 1), @@ -225,8 +225,10 @@ func (s *Stream) sendWindowUpdate() error { recvWindow = min(s.recvWindow*2, s.session.config.MaxStreamWindowSize) } if recvWindow > s.recvWindow { - s.recvWindow = recvWindow - _, delta = s.recvBuf.GrowTo(s.recvWindow, true) + if err := s.session.memoryManager.ReserveMemory(int(delta), 128); err == nil { + s.recvWindow = recvWindow + _, delta = s.recvBuf.GrowTo(s.recvWindow, true) + } } } s.epochStart = now