diff --git a/lib/srv/db/mongodb/engine.go b/lib/srv/db/mongodb/engine.go index 9d5c5bd7bd868..ecbb632bb2c27 100644 --- a/lib/srv/db/mongodb/engine.go +++ b/lib/srv/db/mongodb/engine.go @@ -53,6 +53,8 @@ type Engine struct { clientConn net.Conn // maxMessageSize is the max message size. maxMessageSize uint32 + // serverConnected specifies whether server connection has been created. + serverConnected bool } // InitializeConnection initializes the client connection. @@ -91,6 +93,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio e.Audit.OnSessionStart(e.Context, sessionCtx, nil) defer e.Audit.OnSessionEnd(e.Context, sessionCtx) + e.serverConnected = true observe() msgFromClient := common.GetMessagesFromClientMetric(sessionCtx.Database) @@ -269,7 +272,37 @@ func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol database)...) } +func (e *Engine) waitForAnyClientMessage(clientConn net.Conn) protocol.Message { + clientMessage, err := protocol.ReadMessage(clientConn, e.maxMessageSize) + if err != nil { + e.Log.Warnf("Failed to read a message for reply: %v.", err) + } + return clientMessage +} + +// replyError sends the error to client. It is currently assumed that this +// function will only be called when HandleConnection terminates. func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) { + // If an error happens during server connection, wait for a client message + // before replying to ensure the client can interpret the reply. + // The first message is usually the isMaster hello message. + if replyTo == nil && !e.serverConnected { + waitChan := make(chan protocol.Message, 1) + go func() { + waitChan <- e.waitForAnyClientMessage(clientConn) + }() + + select { + case clientMessage := <-waitChan: + replyTo = clientMessage + case <-e.Clock.After(common.DefaultMongoDBServerSelectionTimeout): + e.Log.Warnf("Timed out waiting for client message to reply err %v.", err) + // Make sure the connection is closed so waitForAnyClientMessage + // doesn't get stuck. + defer clientConn.Close() + } + } + errSend := protocol.ReplyError(clientConn, replyTo, err) if errSend != nil { e.Log.WithError(errSend).Errorf("Failed to send error message to MongoDB client: %v.", err) diff --git a/lib/srv/db/mongodb/engine_test.go b/lib/srv/db/mongodb/engine_test.go new file mode 100644 index 0000000000000..2cc62d2bcfdbe --- /dev/null +++ b/lib/srv/db/mongodb/engine_test.go @@ -0,0 +1,87 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package mongodb + +import ( + "net" + "testing" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/require" + "go.mongodb.org/mongo-driver/bson" + + "github.com/gravitational/teleport/lib/srv/db/common" + "github.com/gravitational/teleport/lib/srv/db/mongodb/protocol" +) + +func TestEngineReplyError(t *testing.T) { + connectError := trace.NotFound("user not found") + + clientMsgDoc, err := bson.Marshal(bson.M{ + "isMaster": 1, + }) + require.NoError(t, err) + clientMsg := protocol.MakeOpMsg(clientMsgDoc) + clientMsg.Header.RequestID = 123456 + + t.Run("wait for client message", func(t *testing.T) { + t.Parallel() + + e := NewEngine(common.EngineConfig{ + Clock: clockwork.NewRealClock(), + Log: logrus.StandardLogger(), + }).(*Engine) + + enginePipeEndConn, clientPipeEndConn := net.Pipe() + defer enginePipeEndConn.Close() + defer clientPipeEndConn.Close() + + go e.replyError(enginePipeEndConn, nil, connectError) + + _, err = clientPipeEndConn.Write(clientMsg.ToWire(0)) + require.NoError(t, err) + msg, err := protocol.ReadMessage(clientPipeEndConn, protocol.DefaultMaxMessageSizeBytes) + require.NoError(t, err) + require.Equal(t, clientMsg.Header.RequestID, msg.GetHeader().ResponseTo) + require.Contains(t, msg.String(), connectError.Error()) + }) + + t.Run("no wait", func(t *testing.T) { + t.Parallel() + + e := NewEngine(common.EngineConfig{ + Clock: clockwork.NewRealClock(), + Log: logrus.StandardLogger(), + }).(*Engine) + e.serverConnected = true + + enginePipeEndConn, clientPipeEndConn := net.Pipe() + defer enginePipeEndConn.Close() + defer clientPipeEndConn.Close() + + go e.replyError(enginePipeEndConn, nil, connectError) + + // There is no need to write a message and reply does not respond to a + // message. + msg, err := protocol.ReadMessage(clientPipeEndConn, protocol.DefaultMaxMessageSizeBytes) + require.NoError(t, err) + require.Equal(t, int32(0), msg.GetHeader().ResponseTo) + require.Contains(t, msg.String(), connectError.Error()) + }) +}