Skip to content

Commit

Permalink
fix: request timeout race condition (#465)
Browse files Browse the repository at this point in the history
  • Loading branch information
raymonstah committed Sep 15, 2023
1 parent 81e9beb commit 03cc78c
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 16 deletions.
21 changes: 13 additions & 8 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,10 @@ func (l *Conn) Close() (err error) {
l.chanMessage <- &messagePacket{Op: MessageQuit}

timeoutCtx := context.Background()
if l.requestTimeout > 0 {
requestTimeout := l.getTimeout()

This comment has been minimized.

Copy link
@cpuschma

cpuschma Sep 15, 2023

Member

Just noticed both the v3 and root version are not equal unfortunately. This slipped through @johnweldon. I'll open a PR to sync both directories again, there are several differences.

if requestTimeout > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(requestTimeout))
defer cancelFunc()
}
select {
Expand All @@ -316,6 +317,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}

func (l *Conn) getTimeout() int64 {
return atomic.LoadInt64(&l.requestTimeout)
}

// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if messageID, ok := <-l.chanMessageID; ok {
Expand Down Expand Up @@ -486,7 +491,7 @@ func (l *Conn) processMessages() {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
Expand Down Expand Up @@ -514,7 +519,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
close(message.Context.responses)
break
}
Expand All @@ -524,9 +529,9 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context

// Add timeout if defined
if l.requestTimeout > 0 {
if l.getTimeout() > 0 {
go func() {
timer := time.NewTimer(time.Duration(l.requestTimeout))
timer := time.NewTimer(time.Duration(l.getTimeout()))
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
Expand All @@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
} else {
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
Expand All @@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
Expand Down
26 changes: 26 additions & 0 deletions conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,32 @@ func TestInvalidStateCloseDeadlock(t *testing.T) {
conn.Close()
}

func TestRequestTimeoutDeadlock(t *testing.T) {
// The do-nothing server that accepts requests and does nothing
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error connecting to localhost tcp: %v", err)
}

// Create an Ldap connection
conn := NewConn(c, false)
conn.Start()
// trigger a race condition on accessing request timeout
n := 3
for i := 0; i < n; i++ {
go func() {
conn.SetTimeout(time.Millisecond)
}()
}

// Attempt to close the connection when the message handler is
// blocked or inactive
conn.Close()
}

// TestInvalidStateSendResponseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateSendResponseDeadlock(t *testing.T) {
Expand Down
21 changes: 13 additions & 8 deletions v3/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -288,9 +288,9 @@ func (l *Conn) Close() (err error) {
l.chanMessage <- &messagePacket{Op: MessageQuit}

timeoutCtx := context.Background()
if l.requestTimeout > 0 {
if l.getTimeout() > 0 {
var cancelFunc context.CancelFunc
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.requestTimeout))
timeoutCtx, cancelFunc = context.WithTimeout(timeoutCtx, time.Duration(l.getTimeout()))
defer cancelFunc()
}
select {
Expand All @@ -316,6 +316,10 @@ func (l *Conn) SetTimeout(timeout time.Duration) {
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
}

func (l *Conn) getTimeout() int64 {
return atomic.LoadInt64(&l.requestTimeout)
}

// Returns the next available messageID
func (l *Conn) nextMessageID() int64 {
if messageID, ok := <-l.chanMessageID; ok {
Expand Down Expand Up @@ -486,7 +490,7 @@ func (l *Conn) processMessages() {
// If we are closing due to an error, inform anyone who
// is waiting about the error.
if l.IsClosing() && l.closeErr.Load() != nil {
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{Error: l.closeErr.Load().(error)}, time.Duration(l.getTimeout()))
}
l.Debug.Printf("Closing channel for MessageID %d", messageID)
close(msgCtx.responses)
Expand Down Expand Up @@ -514,7 +518,7 @@ func (l *Conn) processMessages() {
_, err := l.conn.Write(buf)
if err != nil {
l.Debug.Printf("Error Sending Message: %s", err.Error())
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.requestTimeout))
message.Context.sendResponse(&PacketResponse{Error: fmt.Errorf("unable to send request: %s", err)}, time.Duration(l.getTimeout()))
close(message.Context.responses)
break
}
Expand All @@ -524,9 +528,10 @@ func (l *Conn) processMessages() {
l.messageContexts[message.MessageID] = message.Context

// Add timeout if defined
if l.requestTimeout > 0 {
requestTimeout := l.getTimeout()
if requestTimeout > 0 {
go func() {
timer := time.NewTimer(time.Duration(l.requestTimeout))
timer := time.NewTimer(time.Duration(requestTimeout))
defer func() {
if err := recover(); err != nil {
l.err = fmt.Errorf("ldap: recovered panic in RequestTimeout: %v", err)
Expand All @@ -549,7 +554,7 @@ func (l *Conn) processMessages() {
case MessageResponse:
l.Debug.Printf("Receiving message %d", message.MessageID)
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, nil}, time.Duration(l.getTimeout()))
} else {
l.err = fmt.Errorf("ldap: received unexpected message %d, %v", message.MessageID, l.IsClosing())
l.Debug.PrintPacket(message.Packet)
Expand All @@ -559,7 +564,7 @@ func (l *Conn) processMessages() {
// All reads will return immediately
if msgCtx, ok := l.messageContexts[message.MessageID]; ok {
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.requestTimeout))
msgCtx.sendResponse(&PacketResponse{message.Packet, NewError(ErrorNetwork, errors.New("ldap: connection timed out"))}, time.Duration(l.getTimeout()))
delete(l.messageContexts, message.MessageID)
close(msgCtx.responses)
}
Expand Down
26 changes: 26 additions & 0 deletions v3/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ func TestUnresponsiveConnection(t *testing.T) {
}
}

func TestRequestTimeoutDeadlock(t *testing.T) {
// The do-nothing server that accepts requests and does nothing
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts.Close()
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
if err != nil {
t.Fatalf("error connecting to localhost tcp: %v", err)
}

// Create an Ldap connection
conn := NewConn(c, false)
conn.Start()
// trigger a race condition on accessing request timeout
n := 3
for i := 0; i < n; i++ {
go func() {
conn.SetTimeout(time.Millisecond)
}()
}

// Attempt to close the connection when the message handler is
// blocked or inactive
conn.Close()
}

// TestInvalidStateCloseDeadlock tests that we do not enter deadlock when the
// message handler is blocked or inactive.
func TestInvalidStateCloseDeadlock(t *testing.T) {
Expand Down

0 comments on commit 03cc78c

Please sign in to comment.