From 56c56c22a6b93798006f57aa3336e815d8de9a23 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Mon, 6 Nov 2023 16:11:50 -0500 Subject: [PATCH 1/3] Fix an issue `tsh db connect ` does not give reason on connection errors --- lib/srv/db/mongodb/engine.go | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 9d5c5bd7bd868..1759f187ab029 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -19,6 +19,7 @@ package mongodb import ( "context" "net" + "sync/atomic" "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" @@ -53,6 +54,8 @@ type Engine struct { clientConn net.Conn // maxMessageSize is the max message size. maxMessageSize uint32 + // serverConnected specifies whether server connection has been created. + serverConnected atomic.Bool } // InitializeConnection initializes the client connection. @@ -91,6 +94,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) + e.serverConnected.Store(true) observe() msgFromClient := common.GetMessagesFromClientMetric(sessionCtx.Database) @@ -269,7 +273,35 @@ func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol database)...) } +func (e *Engine) waitForAnyClientMessage(clientConn net.Conn, waitChan chan protocol.Message) { + clientMessage, err := protocol.ReadMessage(clientConn, e.maxMessageSize) + if err != nil { + e.Log.Warnf("Failed to read a message for reply: %v.", err) + waitChan <- nil + } else { + waitChan <- clientMessage + } +} + func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) { + // If an error happens during server connection, wait for a client message + // before replying to ensure the client can interpret the reply. + // The first message is usually the isMaster hello message. + if replyTo == nil && !e.serverConnected.Load() { + waitChan := make(chan protocol.Message, 1) + go e.waitForAnyClientMessage(clientConn, waitChan) + + select { + case clientMessage := <-waitChan: + replyTo = clientMessage + case <-e.Clock.After(common.DefaultMongoDBServerSelectionTimeout): + e.Log.Warnf("Timed out waiting for client message to reply err %v.", err) + // Make sure the connection is closed so waitForAnyClientMessage + // doesn't get stuck. + defer clientConn.Close() + } + } + errSend := protocol.ReplyError(clientConn, replyTo, err) if errSend != nil { e.Log.WithError(errSend).Errorf("Failed to send error message to MongoDB client: %v.", err) From e299286c2d6e0205b9813034115f676361e74126 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Fri, 10 Nov 2023 12:51:48 -0500 Subject: [PATCH 2/3] add ut --- lib/srv/db/mongodb/engine.go | 10 ++-- lib/srv/db/mongodb/engine_test.go | 79 +++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 5 deletions(-) create mode 100644 lib/srv/db/mongodb/engine_test.go diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 1759f187ab029..abdacd8b6fd05 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -273,14 +273,12 @@ func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol database)...) } -func (e *Engine) waitForAnyClientMessage(clientConn net.Conn, waitChan chan protocol.Message) { +func (e *Engine) waitForAnyClientMessage(clientConn net.Conn) protocol.Message { clientMessage, err := protocol.ReadMessage(clientConn, e.maxMessageSize) if err != nil { e.Log.Warnf("Failed to read a message for reply: %v.", err) - waitChan <- nil - } else { - waitChan <- clientMessage } + return clientMessage } func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) { @@ -289,7 +287,9 @@ func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err e // The first message is usually the isMaster hello message. if replyTo == nil && !e.serverConnected.Load() { waitChan := make(chan protocol.Message, 1) - go e.waitForAnyClientMessage(clientConn, waitChan) + go func() { + waitChan <- e.waitForAnyClientMessage(clientConn) + }() select { case clientMessage := <-waitChan: diff --git a/lib/srv/db/mongodb/engine_test.go b/lib/srv/db/mongodb/engine_test.go new file mode 100644 index 0000000000000..2312514bf62b9 --- /dev/null +++ b/lib/srv/db/mongodb/engine_test.go @@ -0,0 +1,79 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mongodb + +import ( + "net" + "testing" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" + + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/mongodb/protocol" +) + +func TestEngineReplyError(t *testing.T) { + connectError := trace.NotFound("user not found") + + clientMsgDoc, err := bson.Marshal(bson.M{ + "isMaster": 1, + }) + require.NoError(t, err) + clientMsg := protocol.MakeOpMsg(clientMsgDoc) + clientMsg.Header.RequestID = 123456 + + e := NewEngine(common.EngineConfig{ + Clock: clockwork.NewRealClock(), + Log: logrus.StandardLogger(), + }).(*Engine) + + t.Run("wait for client message", func(t *testing.T) { + enginePipeEndConn, clientPipeEndConn := net.Pipe() + defer enginePipeEndConn.Close() + defer clientPipeEndConn.Close() + + go e.replyError(enginePipeEndConn, nil, connectError) + + _, err = clientPipeEndConn.Write(clientMsg.ToWire(0)) + require.NoError(t, err) + msg, err := protocol.ReadMessage(clientPipeEndConn, protocol.DefaultMaxMessageSizeBytes) + require.NoError(t, err) + require.Equal(t, clientMsg.Header.RequestID, msg.GetHeader().ResponseTo) + require.Contains(t, msg.String(), connectError.Error()) + }) + + t.Run("no wait", func(t *testing.T) { + e.serverConnected.Store(true) + + enginePipeEndConn, clientPipeEndConn := net.Pipe() + defer enginePipeEndConn.Close() + defer clientPipeEndConn.Close() + + go e.replyError(enginePipeEndConn, nil, connectError) + + // There is no need to write a message and reply does not respond to a + // message. + msg, err := protocol.ReadMessage(clientPipeEndConn, protocol.DefaultMaxMessageSizeBytes) + require.NoError(t, err) + require.Equal(t, int32(0), msg.GetHeader().ResponseTo) + require.Contains(t, msg.String(), connectError.Error()) + }) +} From 2546fd33961e7a75185bf4265daa9bbb1b8228c7 Mon Sep 17 00:00:00 2001 From: STeve Huang Date: Mon, 20 Nov 2023 09:21:06 -0500 Subject: [PATCH 3/3] downgrade serverConnected to bool --- lib/srv/db/mongodb/engine.go | 9 +++++---- lib/srv/db/mongodb/engine_test.go | 20 ++++++++++++++------ 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index abdacd8b6fd05..ecbb632bb2c27 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -19,7 +19,6 @@ package mongodb import ( "context" "net" - "sync/atomic" "github.com/gravitational/trace" "github.com/prometheus/client_golang/prometheus" @@ -55,7 +54,7 @@ type Engine struct { // maxMessageSize is the max message size. maxMessageSize uint32 // serverConnected specifies whether server connection has been created. - serverConnected atomic.Bool + serverConnected bool } // InitializeConnection initializes the client connection. @@ -94,7 +93,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) - e.serverConnected.Store(true) + e.serverConnected = true observe() msgFromClient := common.GetMessagesFromClientMetric(sessionCtx.Database) @@ -281,11 +280,13 @@ func (e *Engine) waitForAnyClientMessage(clientConn net.Conn) protocol.Message { return clientMessage } +// replyError sends the error to client. It is currently assumed that this +// function will only be called when HandleConnection terminates. func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) { // If an error happens during server connection, wait for a client message // before replying to ensure the client can interpret the reply. // The first message is usually the isMaster hello message. - if replyTo == nil && !e.serverConnected.Load() { + if replyTo == nil && !e.serverConnected { waitChan := make(chan protocol.Message, 1) go func() { waitChan <- e.waitForAnyClientMessage(clientConn) diff --git a/lib/srv/db/mongodb/engine_test.go b/lib/srv/db/mongodb/engine_test.go index 2312514bf62b9..2cc62d2bcfdbe 100644 --- a/lib/srv/db/mongodb/engine_test.go +++ b/lib/srv/db/mongodb/engine_test.go @@ -40,12 +40,14 @@ func TestEngineReplyError(t *testing.T) { clientMsg := protocol.MakeOpMsg(clientMsgDoc) clientMsg.Header.RequestID = 123456 - e := NewEngine(common.EngineConfig{ - Clock: clockwork.NewRealClock(), - Log: logrus.StandardLogger(), - }).(*Engine) - t.Run("wait for client message", func(t *testing.T) { + t.Parallel() + + e := NewEngine(common.EngineConfig{ + Clock: clockwork.NewRealClock(), + Log: logrus.StandardLogger(), + }).(*Engine) + enginePipeEndConn, clientPipeEndConn := net.Pipe() defer enginePipeEndConn.Close() defer clientPipeEndConn.Close() @@ -61,7 +63,13 @@ func TestEngineReplyError(t *testing.T) { }) t.Run("no wait", func(t *testing.T) { - e.serverConnected.Store(true) + t.Parallel() + + e := NewEngine(common.EngineConfig{ + Clock: clockwork.NewRealClock(), + Log: logrus.StandardLogger(), + }).(*Engine) + e.serverConnected = true enginePipeEndConn, clientPipeEndConn := net.Pipe() defer enginePipeEndConn.Close()