From 03cc78cedc5657b699aef81b202cc9d139a0be59 Mon Sep 17 00:00:00 2001 From: Raymond Ho Date: Fri, 15 Sep 2023 05:21:12 -0700 Subject: [PATCH] fix: request timeout race condition (#465) --- conn.go | 21 +++++++++++++-------- conn_test.go | 26 ++++++++++++++++++++++++++ v3/conn.go | 21 +++++++++++++-------- v3/conn_test.go | 26 ++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 16 deletions(-) diff --git a/conn.go b/conn.go index 474d494b..db4667ca 100644 --- a/conn.go +++ b/conn.go @@ -288,9 +288,10 @@ func (l *Conn) Close() (err error) { l.chanMessage <- &messagePacket{Op: MessageQuit} timeoutCtx := context.Background() - if l.requestTimeout > 0 { + requestTimeout := l.getTimeout() + if requestTimeout > 0 { var cancelFunc context.CancelFunc - timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout)) + timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(requestTimeout)) defer cancelFunc() } select { @@ -316,6 +317,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) { atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } +func (l *Conn) getTimeout() int64 { + return atomic.LoadInt64(&l.requestTimeout) +} + // Returns the next available messageID func (l *Conn) nextMessageID() int64 { if messageID, ok := <-l.chanMessageID; ok { @@ -486,7 +491,7 @@ func (l *Conn) processMessages() { // If we are closing due to an error, inform anyone who // is waiting about the error. if l.IsClosing() && l.closeErr.Load() != nil { - msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout)) + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout())) } l.Debug.Printf("Closing channel for MessageID %d", messageID) close(msgCtx.responses) @@ -514,7 +519,7 @@ func (l *Conn) processMessages() { _, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s", err.Error()) - message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout)) + message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout())) close(message.Context.responses) break } @@ -524,9 +529,9 @@ func (l *Conn) processMessages() { l.messageContexts[message.MessageID] = message.Context // Add timeout if defined - if l.requestTimeout > 0 { + if l.getTimeout() > 0 { go func() { - timer := time.NewTimer(time.Duration(l.requestTimeout)) + timer := time.NewTimer(time.Duration(l.getTimeout())) defer func() { if err := recover(); err != nil { l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err) @@ -549,7 +554,7 @@ func (l *Conn) processMessages() { case MessageResponse: l.Debug.Printf("Receiving message %d", message.MessageID) if msgCtx, ok := l.messageContexts[message.MessageID]; ok { - msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout)) + msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout())) } else { l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing()) l.Debug.PrintPacket(message.Packet) @@ -559,7 +564,7 @@ func (l *Conn) processMessages() { // All reads will return immediately if msgCtx, ok := l.messageContexts[message.MessageID]; ok { l.Debug.Printf("Receiving message timeout for %d", message.MessageID) - msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout)) + msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout())) delete(l.messageContexts, message.MessageID) close(msgCtx.responses) } diff --git a/conn_test.go b/conn_test.go index bfae3e99..0f6d7181 100644 --- a/conn_test.go +++ b/conn_test.go @@ -79,6 +79,32 @@ func TestInvalidStateCloseDeadlock(t *testing.T) { conn.Close() } +func TestRequestTimeoutDeadlock(t *testing.T) { + // The do-nothing server that accepts requests and does nothing + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error connecting to localhost tcp: %v", err) + } + + // Create an Ldap connection + conn := NewConn(c, false) + conn.Start() + // trigger a race condition on accessing request timeout + n := 3 + for i := 0; i < n; i++ { + go func() { + conn.SetTimeout(time.Millisecond) + }() + } + + // Attempt to close the connection when the message handler is + // blocked or inactive + conn.Close() +} + // TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the // message handler is blocked or inactive. func TestInvalidStateSendResponseDeadlock(t *testing.T) { diff --git a/v3/conn.go b/v3/conn.go index a42a9697..4de22fda 100644 --- a/v3/conn.go +++ b/v3/conn.go @@ -288,9 +288,9 @@ func (l *Conn) Close() (err error) { l.chanMessage <- &messagePacket{Op: MessageQuit} timeoutCtx := context.Background() - if l.requestTimeout > 0 { + if l.getTimeout() > 0 { var cancelFunc context.CancelFunc - timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout)) + timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout())) defer cancelFunc() } select { @@ -316,6 +316,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) { atomic.StoreInt64(&l.requestTimeout, int64(timeout)) } +func (l *Conn) getTimeout() int64 { + return atomic.LoadInt64(&l.requestTimeout) +} + // Returns the next available messageID func (l *Conn) nextMessageID() int64 { if messageID, ok := <-l.chanMessageID; ok { @@ -486,7 +490,7 @@ func (l *Conn) processMessages() { // If we are closing due to an error, inform anyone who // is waiting about the error. if l.IsClosing() && l.closeErr.Load() != nil { - msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout)) + msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout())) } l.Debug.Printf("Closing channel for MessageID %d", messageID) close(msgCtx.responses) @@ -514,7 +518,7 @@ func (l *Conn) processMessages() { _, err := l.conn.Write(buf) if err != nil { l.Debug.Printf("Error Sending Message: %s", err.Error()) - message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout)) + message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout())) close(message.Context.responses) break } @@ -524,9 +528,10 @@ func (l *Conn) processMessages() { l.messageContexts[message.MessageID] = message.Context // Add timeout if defined - if l.requestTimeout > 0 { + requestTimeout := l.getTimeout() + if requestTimeout > 0 { go func() { - timer := time.NewTimer(time.Duration(l.requestTimeout)) + timer := time.NewTimer(time.Duration(requestTimeout)) defer func() { if err := recover(); err != nil { l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err) @@ -549,7 +554,7 @@ func (l *Conn) processMessages() { case MessageResponse: l.Debug.Printf("Receiving message %d", message.MessageID) if msgCtx, ok := l.messageContexts[message.MessageID]; ok { - msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout)) + msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout())) } else { l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing()) l.Debug.PrintPacket(message.Packet) @@ -559,7 +564,7 @@ func (l *Conn) processMessages() { // All reads will return immediately if msgCtx, ok := l.messageContexts[message.MessageID]; ok { l.Debug.Printf("Receiving message timeout for %d", message.MessageID) - msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout)) + msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout())) delete(l.messageContexts, message.MessageID) close(msgCtx.responses) } diff --git a/v3/conn_test.go b/v3/conn_test.go index bfae3e99..d0bfc0c5 100644 --- a/v3/conn_test.go +++ b/v3/conn_test.go @@ -58,6 +58,32 @@ func TestUnresponsiveConnection(t *testing.T) { } } +func TestRequestTimeoutDeadlock(t *testing.T) { + // The do-nothing server that accepts requests and does nothing + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + })) + defer ts.Close() + c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("error connecting to localhost tcp: %v", err) + } + + // Create an Ldap connection + conn := NewConn(c, false) + conn.Start() + // trigger a race condition on accessing request timeout + n := 3 + for i := 0; i < n; i++ { + go func() { + conn.SetTimeout(time.Millisecond) + }() + } + + // Attempt to close the connection when the message handler is + // blocked or inactive + conn.Close() +} + // TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the // message handler is blocked or inactive. func TestInvalidStateCloseDeadlock(t *testing.T) {