Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v14] Fix an issue tsh db connect <mongodb> does not give reason on connection errors #34910

Merged
merged 3 commits into from Nov 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
33 changes: 33 additions & 0 deletions lib/srv/db/mongodb/engine.go
Expand Up @@ -53,6 +53,8 @@ type Engine struct {
clientConn net.Conn
// maxMessageSize is the max message size.
maxMessageSize uint32
// serverConnected specifies whether server connection has been created.
serverConnected bool
}

// InitializeConnection initializes the client connection.
Expand Down Expand Up @@ -91,6 +93,7 @@ func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Sessio
e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

e.serverConnected = true
observe()

msgFromClient := common.GetMessagesFromClientMetric(sessionCtx.Database)
Expand Down Expand Up @@ -269,7 +272,37 @@ func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol
database)...)
}

func (e *Engine) waitForAnyClientMessage(clientConn net.Conn) protocol.Message {
clientMessage, err := protocol.ReadMessage(clientConn, e.maxMessageSize)
if err != nil {
e.Log.Warnf("Failed to read a message for reply: %v.", err)
}
return clientMessage
}

// replyError sends the error to client. It is currently assumed that this
// function will only be called when HandleConnection terminates.
func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) {
// If an error happens during server connection, wait for a client message
// before replying to ensure the client can interpret the reply.
// The first message is usually the isMaster hello message.
if replyTo == nil && !e.serverConnected {
waitChan := make(chan protocol.Message, 1)
go func() {
waitChan <- e.waitForAnyClientMessage(clientConn)
}()

select {
case clientMessage := <-waitChan:
replyTo = clientMessage
case <-e.Clock.After(common.DefaultMongoDBServerSelectionTimeout):
e.Log.Warnf("Timed out waiting for client message to reply err %v.", err)
// Make sure the connection is closed so waitForAnyClientMessage
// doesn't get stuck.
defer clientConn.Close()
}
}

errSend := protocol.ReplyError(clientConn, replyTo, err)
if errSend != nil {
e.Log.WithError(errSend).Errorf("Failed to send error message to MongoDB client: %v.", err)
Expand Down
87 changes: 87 additions & 0 deletions lib/srv/db/mongodb/engine_test.go
@@ -0,0 +1,87 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package mongodb

import (
"net"
"testing"

"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
"github.com/stretchr/testify/require"
"go.mongodb.org/mongo-driver/bson"

"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/mongodb/protocol"
)

func TestEngineReplyError(t *testing.T) {
connectError := trace.NotFound("user not found")

clientMsgDoc, err := bson.Marshal(bson.M{
"isMaster": 1,
})
require.NoError(t, err)
clientMsg := protocol.MakeOpMsg(clientMsgDoc)
clientMsg.Header.RequestID = 123456

t.Run("wait for client message", func(t *testing.T) {
t.Parallel()

e := NewEngine(common.EngineConfig{
Clock: clockwork.NewRealClock(),
Log: logrus.StandardLogger(),
}).(*Engine)

enginePipeEndConn, clientPipeEndConn := net.Pipe()
defer enginePipeEndConn.Close()
defer clientPipeEndConn.Close()

go e.replyError(enginePipeEndConn, nil, connectError)

_, err = clientPipeEndConn.Write(clientMsg.ToWire(0))
require.NoError(t, err)
msg, err := protocol.ReadMessage(clientPipeEndConn, protocol.DefaultMaxMessageSizeBytes)
require.NoError(t, err)
require.Equal(t, clientMsg.Header.RequestID, msg.GetHeader().ResponseTo)
require.Contains(t, msg.String(), connectError.Error())
})

t.Run("no wait", func(t *testing.T) {
t.Parallel()

e := NewEngine(common.EngineConfig{
Clock: clockwork.NewRealClock(),
Log: logrus.StandardLogger(),
}).(*Engine)
e.serverConnected = true

enginePipeEndConn, clientPipeEndConn := net.Pipe()
defer enginePipeEndConn.Close()
defer clientPipeEndConn.Close()

go e.replyError(enginePipeEndConn, nil, connectError)

// There is no need to write a message and reply does not respond to a
// message.
msg, err := protocol.ReadMessage(clientPipeEndConn, protocol.DefaultMaxMessageSizeBytes)
require.NoError(t, err)
require.Equal(t, int32(0), msg.GetHeader().ResponseTo)
require.Contains(t, msg.String(), connectError.Error())
})
}