-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
engine.go
310 lines (277 loc) · 10.9 KB
/
engine.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
/*
Copyright 2021 Gravitational, Inc.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package mongodb
import (
"context"
"net"
"github.com/gravitational/trace"
"github.com/prometheus/client_golang/prometheus"
"go.mongodb.org/mongo-driver/bson"
"go.mongodb.org/mongo-driver/mongo/description"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/srv/db/common"
"github.com/gravitational/teleport/lib/srv/db/common/role"
"github.com/gravitational/teleport/lib/srv/db/mongodb/protocol"
"github.com/gravitational/teleport/lib/utils"
)
// NewEngine creates a new MongoDB engine from the provided common engine
// configuration. The engine's message size limit is initialized to
// protocol.DefaultMaxMessageSizeBytes.
func NewEngine(ec common.EngineConfig) common.Engine {
	engine := Engine{
		EngineConfig:   ec,
		maxMessageSize: protocol.DefaultMaxMessageSizeBytes,
	}
	return &engine
}
// Engine is the MongoDB database service engine. It accepts client
// connections arriving over the reverse tunnel from the proxy and relays
// traffic between the proxy and the MongoDB database instance.
//
// Implements common.Engine.
type Engine struct {
	// EngineConfig holds the configuration shared by database engines.
	common.EngineConfig
	// clientConn is the connection to the incoming client. It is set by
	// InitializeConnection.
	clientConn net.Conn
	// maxMessageSize is the message size limit handed to the protocol read
	// functions. It may be overwritten by the server's handshake response.
	maxMessageSize uint32
	// serverConnected reports whether the connection to the database server
	// has been established. replyError consults it to decide whether it must
	// wait for a client message before replying.
	serverConnected bool
}
// InitializeConnection stores the incoming client connection on the engine so
// later calls can read from and reply to it. The session argument is unused
// and the method always returns nil.
func (e *Engine) InitializeConnection(clientConn net.Conn, _ *common.Session) error {
	e.clientConn = clientConn
	return nil
}
// SendError reports err to the connected client in a format MongoDB clients
// understand. Nil errors and errors accepted by utils.IsOKNetworkError are
// not sent.
func (e *Engine) SendError(err error) {
	if err == nil || utils.IsOKNetworkError(err) {
		return
	}
	e.replyError(e.clientConn, nil, err)
}
// HandleConnection serves a single MongoDB client connection that arrived
// from the proxy over the reverse tunnel.
//
// It authorizes the session, connects to the MongoDB server, emits the
// session start/end audit events and then relays client messages to the
// server one at a time, parsing the protocol along the way. It returns only
// when authorization, the server connection, reading a client message or
// handling a message fails.
func (e *Engine) HandleConnection(ctx context.Context, sessionCtx *common.Session) error {
	observe := common.GetConnectionSetupTimeObserver(sessionCtx.Database)

	// Verify the user is allowed to access the database at all.
	if err := e.authorizeConnection(ctx, sessionCtx); err != nil {
		return trace.Wrap(err, "error authorizing database access")
	}

	// Dial the MongoDB server; closeFn releases the connection on return.
	serverConn, closeFn, err := e.connect(ctx, sessionCtx)
	if err != nil {
		return trace.Wrap(err, "error connecting to the database")
	}
	defer closeFn()

	e.Audit.OnSessionStart(e.Context, sessionCtx, nil)
	defer e.Audit.OnSessionEnd(e.Context, sessionCtx)

	// From here on replyError no longer needs to wait for a client message.
	e.serverConnected = true
	observe()

	fromClient := common.GetMessagesFromClientMetric(sessionCtx.Database)
	fromServer := common.GetMessagesFromServerMetric(sessionCtx.Database)

	// Relay loop: read one client message, run its full roundtrip, repeat.
	for {
		msg, err := protocol.ReadMessage(e.clientConn, e.maxMessageSize)
		if err != nil {
			return trace.Wrap(err)
		}
		if err := e.handleClientMessage(ctx, sessionCtx, msg, e.clientConn, serverConn, fromClient, fromServer); err != nil {
			return trace.Wrap(err)
		}
	}
}
// handleClientMessage runs the roundtrip for a single client message. The
// possible outcomes are:
//  1. The command is rejected by the user's role: nothing is sent to the
//     server and an error reply is written to the client.
//  2. The common case: the message is forwarded to the server and the
//     server's reply is written back to the client.
//  3. The client message expects no reply: the function returns right after
//     forwarding it.
//  4. The server sends several messages in a row: all of them are forwarded
//     to the client before returning.
//
// Both counters are incremented once per message seen from the respective
// side.
func (e *Engine) handleClientMessage(ctx context.Context, sessionCtx *common.Session, clientMessage protocol.Message, clientConn net.Conn, serverConn driver.Connection, msgFromClient prometheus.Counter, msgFromServer prometheus.Counter) error {
	msgFromClient.Inc()

	// RBAC check (also records the command in the audit log). On denial the
	// client gets an error reply and only a failure to send it is returned.
	if err := e.authorizeClientMessage(sessionCtx, clientMessage); err != nil {
		return protocol.ReplyError(clientConn, clientMessage, err)
	}

	// Authorized: forward the message to the server.
	if err := serverConn.WriteWireMessage(ctx, clientMessage.GetBytes()); err != nil {
		return trace.Wrap(err)
	}

	// Fire-and-forget client messages get no reply from the server.
	if clientMessage.MoreToCome(nil) {
		return nil
	}

	// Forward server replies until the server signals it has nothing more.
	for firstReply := true; ; firstReply = false {
		reply, err := protocol.ReadServerMessage(ctx, serverConn, e.maxMessageSize)
		if err != nil {
			return trace.Wrap(err)
		}
		msgFromServer.Inc()

		// Only the first reply to a handshake carries the server settings
		// used to configure the engine.
		if firstReply && protocol.IsHandshake(clientMessage) {
			e.processHandshakeResponse(ctx, reply)
		}

		if _, err := clientConn.Write(reply.GetBytes()); err != nil {
			return trace.Wrap(err)
		}

		if !reply.MoreToCome(clientMessage) {
			return nil
		}
	}
}
// processHandshakeResponse inspects the server's reply to a handshake message
// and updates the engine settings derived from it. Currently that is only
// maxMessageSize.
//
// Replies of type OP_REPLY and OP_MSG are understood; any other message type,
// or an OP_REPLY without documents, is logged as a warning and ignored. The
// ctx parameter is currently unused.
func (e *Engine) processHandshakeResponse(ctx context.Context, respMessage protocol.Message) {
	var rawMessage bson.Raw
	switch resp := respMessage.(type) {
	// OP_REPLY is used on legacy handshake messages (deprecated on MongoDB 5.0).
	case *protocol.MessageOpReply:
		if len(resp.Documents) == 0 {
			e.Log.Warn("Empty MongoDB handshake response.")
			return
		}

		// Handshake messages are always the first document on a reply.
		rawMessage = bson.Raw(resp.Documents[0])
	// OP_MSG is used on modern handshake messages.
	case *protocol.MessageOpMsg:
		rawMessage = bson.Raw(resp.BodySection.Document)
	default:
		// Warnf (not Warn) is required for the %T verb to be expanded.
		e.Log.Warnf("Unable to process MongoDB handshake response. Unexpected message type %T", respMessage)
		return
	}

	// Use the driver's server description to parse the handshake message. The
	// address is not validated and won't be used by the engine.
	serverDescription := description.NewServer("", rawMessage)

	// Only overwrite the engine configuration if the handshake has a value
	// set; otherwise keep the current limit.
	if serverDescription.MaxMessageSize > 0 {
		e.maxMessageSize = serverDescription.MaxMessageSize
	}
}
// authorizeConnection performs the authorization check for a MongoDB
// connection that is about to be established. When access is denied, a
// failed session start audit event is emitted before the error is returned.
func (e *Engine) authorizeConnection(ctx context.Context, sessionCtx *common.Session) error {
	authPref, err := e.Auth.GetAuthPreference(ctx)
	if err != nil {
		return trace.Wrap(err)
	}
	accessState := sessionCtx.GetAccessState(authPref)

	// Only the database user is matched at connection time. MongoDB sends
	// the database name with each protocol message (query, update, etc.),
	// so the database is checked per message as it arrives from the client.
	userMatcher := services.NewDatabaseUserMatcher(sessionCtx.Database, sessionCtx.DatabaseUser)
	if err := sessionCtx.Checker.CheckAccess(sessionCtx.Database, accessState, userMatcher); err != nil {
		e.Audit.OnSessionStart(e.Context, sessionCtx, err)
		return trace.Wrap(err)
	}
	return nil
}
// authorizeClientMessage decides whether the user may run the given MongoDB
// command and records the command, together with the outcome of the check,
// as a query audit event.
//
// Each MongoDB command carries the database it runs against, which is
// checked against the databases allowed by the user's role. If the database
// cannot be extracted from the message, the error is returned right away and
// no audit event is emitted.
func (e *Engine) authorizeClientMessage(sessionCtx *common.Session, message protocol.Message) error {
	// Every client message is expected to carry database information.
	dbName, err := message.GetDatabase()
	if err != nil {
		return trace.Wrap(err)
	}

	checkErr := e.checkClientMessage(sessionCtx, message, dbName)

	// The audit event's arguments are captured here; the event itself is
	// emitted when the function returns.
	defer e.Audit.OnQuery(e.Context, sessionCtx, common.Query{
		Database: dbName,
		Query:    message.String(),
		Error:    checkErr,
	})

	return trace.Wrap(checkErr)
}
// checkClientMessage runs the role-based access check for a single client
// message against the given database name.
//
// OP_KILL_CURSORS messages are checked against the database user only.
// Commands that deal with authentication are always rejected. Every other
// command is checked against the database user and the database name.
func (e *Engine) checkClientMessage(sessionCtx *common.Session, message protocol.Message, database string) error {
	// Both access checks below use an MFA-verified access state.
	accessState := services.AccessState{MFAVerified: true}

	// Legacy OP_KILL_CURSORS command doesn't contain database information,
	// so only the database user can be matched.
	switch message.(type) {
	case *protocol.MessageOpKillCursors:
		userMatcher := services.NewDatabaseUserMatcher(sessionCtx.Database, sessionCtx.DatabaseUser)
		return sessionCtx.Checker.CheckAccess(sessionCtx.Database, accessState, userMatcher)
	}

	command, err := message.GetCommand()
	if err != nil {
		return trace.Wrap(err)
	}

	// Authentication-related commands are never allowed from the client.
	if command == "authenticate" || command == "saslStart" || command == "saslContinue" || command == "logout" {
		return trace.AccessDenied("access denied")
	}

	// Everything else is authorized against the allowed databases.
	matchers := role.DatabaseRoleMatchers(
		sessionCtx.Database,
		sessionCtx.DatabaseUser,
		database)
	return sessionCtx.Checker.CheckAccess(sessionCtx.Database, accessState, matchers...)
}
// waitForAnyClientMessage reads the next message from clientConn. A read
// failure is logged as a warning and whatever message value the read
// produced is returned to the caller unchanged.
func (e *Engine) waitForAnyClientMessage(clientConn net.Conn) protocol.Message {
	msg, readErr := protocol.ReadMessage(clientConn, e.maxMessageSize)
	if readErr != nil {
		e.Log.Warnf("Failed to read a message for reply: %v.", readErr)
	}
	return msg
}
// replyError delivers err to the client as a reply to replyTo. It is
// currently assumed that this function is only called when HandleConnection
// terminates.
//
// A failure to send the reply is logged and not returned.
func (e *Engine) replyError(clientConn net.Conn, replyTo protocol.Message, err error) {
	// When the failure happened before the server connection was set up and
	// there is no message to answer yet, wait for one from the client so it
	// can interpret the reply. The first message is usually the isMaster
	// hello message.
	if replyTo == nil && !e.serverConnected {
		received := make(chan protocol.Message, 1)
		go func() {
			received <- e.waitForAnyClientMessage(clientConn)
		}()

		select {
		case replyTo = <-received:
		case <-e.Clock.After(common.DefaultMongoDBServerSelectionTimeout):
			e.Log.Warnf("Timed out waiting for client message to reply err %v.", err)
			// Close the connection on return so the reader goroutine in
			// waitForAnyClientMessage doesn't get stuck.
			defer clientConn.Close()
		}
	}

	if errSend := protocol.ReplyError(clientConn, replyTo, err); errSend != nil {
		e.Log.WithError(errSend).Errorf("Failed to send error message to MongoDB client: %v.", err)
	}
}