// Connection.swift
// Forked from orlandos-nl/MongoKitten.
import BSON
import Foundation
import MongoCore
import NIO
import Logging
import Metrics
import NIOConcurrencyHelpers
#if canImport(NIOTransportServices) && os(iOS)
import Network
import NIOTransportServices
#else
import NIOSSL
#endif
/// The outcome of one handshake (`isMaster`) round trip with the server.
///
/// Stores the moment the request was sent, the moment this value was built
/// (used as the receive time), and the decoded server handshake.
public struct MongoHandshakeResult {
    /// When the handshake request was sent.
    public let sent: Date

    /// When this result was created, i.e. when the reply was processed.
    public let received: Date

    /// The decoded handshake reply.
    public let handshake: ServerHandshake

    /// Seconds elapsed between `sent` and `received`.
    public var interval: Double {
        return received.timeIntervalSince(sent)
    }

    /// Wraps a handshake reply, stamping `received` with the current date.
    ///
    /// - Parameters:
    ///   - sent: The moment the handshake request was sent.
    ///   - handshake: The decoded server reply.
    init(sentAt sent: Date, handshake: ServerHandshake) {
        self.handshake = handshake
        self.sent = sent
        self.received = Date()
    }
}
/// A single connection to a MongoDB server, backed by a NIO `Channel`.
///
/// Requests are written to the channel by `executeMessage(_:)`; replies are
/// delivered through `MongoClientContext` using the request's ID.
public final actor MongoConnection: @unchecked Sendable {
    /// The server API version passed at construction, if any.
    public let serverApi: ServerApi?

    /// The NIO channel
    internal let channel: Channel

    /// The logger owned by the client context.
    public nonisolated var logger: Logger { context.logger }

    /// Timer labelled for query metrics; created when `isMetricsEnabled` turns on
    /// and cleared when it turns off.
    var queryTimer: Metrics.Timer?

    /// The most recent handshake recorded by `doHandshake`.
    public internal(set) var lastHeartbeat: MongoHandshakeResult?

    /// How long `executeMessage(_:)` waits for a reply before failing with
    /// `.queryTimeout`. `nil` disables the timeout.
    public var queryTimeout: TimeAmount? = .seconds(30)

    /// Enables or disables the query metrics timer.
    public var isMetricsEnabled = false {
        didSet {
            if isMetricsEnabled {
                // Only create a timer on a false -> true transition. Re-assigning
                // `true` must keep the existing timer rather than discarding it.
                if !oldValue {
                    queryTimer = Metrics.Timer(label: "org.openkitten.mongokitten.core.queries")
                }
            } else {
                queryTimer = nil
            }
        }
    }

    /// A LIFO (Last In, First Out) holder for sessions
    public let sessionManager: MongoSessionManager

    /// The session manager's implicit client session.
    public nonisolated var implicitSession: MongoClientSession {
        return sessionManager.implicitClientSession
    }

    /// The identifier of `implicitSession`.
    public nonisolated var implicitSessionId: SessionIdentifier {
        return implicitSession.sessionId
    }

    /// The current request ID, used to generate unique identifiers for MongoDB commands
    private let currentRequestId: NIOAtomic<Int32> = .makeAtomic(value: 0)

    internal let context: MongoClientContext

    /// The handshake stored on the client context, if one has been performed.
    public var serverHandshake: ServerHandshake? {
        get async { await context.serverHandshake }
    }

    /// Completes when the underlying channel closes.
    public nonisolated var closeFuture: EventLoopFuture<Void> {
        return channel.closeFuture
    }

    /// The event loop the underlying channel runs on.
    public nonisolated var eventLoop: EventLoop { return channel.eventLoop }

    /// The channel's buffer allocator.
    public var allocator: ByteBufferAllocator { return channel.allocator }

    /// Atomic flag; presumably MongoDB's legacy `slaveOk` read preference —
    /// NOTE(review): confirm against its readers elsewhere in the module.
    public let slaveOk = NIOAtomic<Bool>.makeAtomic(value: false)

    /// Returns a request ID from the per-connection counter and increments the
    /// counter by one.
    internal func nextRequestId() -> Int32 {
        return currentRequestId.add(1)
    }

    /// Creates a connection that can communicate with MongoDB over a channel
    public init(channel: Channel, context: MongoClientContext, sessionManager: MongoSessionManager = .init(), api: ServerApi? = nil) {
        self.sessionManager = sessionManager
        self.channel = channel
        self.context = context
        self.serverApi = api
    }

    /// Installs the MongoDB wire-protocol parser on `channel`'s pipeline.
    public static func addHandlers(to channel: Channel, context: MongoClientContext) -> EventLoopFuture<Void> {
        let parser = ClientConnectionParser(context: context)
        return channel.pipeline.addHandler(ByteToMessageHandler(parser))
    }

    /// Connects to the first host in `settings` on a newly created one-thread
    /// event loop group, then performs the handshake and authentication.
    ///
    /// NOTE(review): the event loop group created here is never shut down by this
    /// type, so its thread outlives the connection. Consider tying shutdown to
    /// `closeFuture`, or letting callers supply a group.
    public static func connect(
        settings: ConnectionSettings,
        logger: Logger = Logger(label: "org.openkitten.mongokitten.connection"),
        resolver: Resolver? = nil,
        clientDetails: MongoClientDetails? = nil,
        api: ServerApi? = nil
    ) async throws -> MongoConnection {
        #if canImport(NIOTransportServices) && os(iOS)
        // Forward `api` on this branch too; it was previously dropped on iOS.
        return try await connect(settings: settings, logger: logger, onGroup: NIOTSEventLoopGroup(loopCount: 1), resolver: resolver, clientDetails: clientDetails, api: api)
        #else
        return try await connect(settings: settings, logger: logger, onGroup: MultiThreadedEventLoopGroup(numberOfThreads: 1), resolver: resolver, clientDetails: clientDetails, api: api)
        #endif
    }

    /// Connects to the first host in `settings` on `group`, optionally wrapping the
    /// channel in TLS, then performs the handshake and authentication.
    ///
    /// - Throws: `MongoError(.cannotConnect, reason: .noHostSpecified)` when
    ///   `settings.hosts` is empty, or any connection, TLS, or authentication error.
    ///   If authentication fails, the channel is closed before the error is rethrown.
    internal static func connect(
        settings: ConnectionSettings,
        logger: Logger = Logger(label: "org.openkitten.mongokitten.connection"),
        onGroup group: _MongoPlatformEventLoopGroup,
        resolver: Resolver? = nil,
        clientDetails: MongoClientDetails? = nil,
        sessionManager: MongoSessionManager = .init(),
        api: ServerApi? = nil
    ) async throws -> MongoConnection {
        let context = MongoClientContext(logger: logger)

        #if canImport(NIOTransportServices) && os(iOS)
        // NOTE(review): on this path only `useSSL` is honoured; the custom CA
        // settings are not applied to `NWProtocolTLS.Options`.
        var bootstrap = NIOTSConnectionBootstrap(group: group)
        if settings.useSSL {
            bootstrap = bootstrap.tlsOptions(NWProtocolTLS.Options())
        }
        #else
        let bootstrap = ClientBootstrap(group: group)
            .resolver(resolver)
        #endif

        guard let host = settings.hosts.first else {
            logger.critical("Cannot connect to MongoDB: No host specified")
            throw MongoError(.cannotConnect, reason: .noHostSpecified)
        }

        let channel = try await bootstrap
            .channelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEADDR), value: 1)
            .channelInitializer { channel in
                // Must be the exact negation of the condition that selects
                // ClientBootstrap + NIOSSL above. The previous
                // `!canImport(...) && !os(iOS)` skipped TLS, for example on macOS
                // with NIOTransportServices available, leaving `useSSL` ignored.
                #if !(canImport(NIOTransportServices) && os(iOS))
                if settings.useSSL {
                    do {
                        var configuration = TLSConfiguration.clientDefault

                        if let caCert = settings.sslCaCertificate {
                            configuration.trustRoots = NIOSSLTrustRoots.certificates([caCert])
                        } else if let caCertPath = settings.sslCaCertificatePath {
                            configuration.trustRoots = NIOSSLTrustRoots.file(caCertPath)
                        }

                        let handler = try NIOSSLClientHandler(context: NIOSSLContext(configuration: configuration), serverHostname: host.hostname)
                        return channel.pipeline.addHandler(handler).flatMap {
                            return MongoConnection.addHandlers(to: channel, context: context)
                        }
                    } catch {
                        return channel.eventLoop.makeFailedFuture(error)
                    }
                }
                #endif

                return MongoConnection.addHandlers(to: channel, context: context)
            }.connect(host: host.hostname, port: host.port).get()

        let connection = MongoConnection(
            channel: channel,
            context: context,
            sessionManager: sessionManager,
            api: api
        )

        do {
            try await connection.authenticate(
                clientDetails: clientDetails,
                using: settings.authentication,
                to: settings.authenticationSource ?? "admin"
            )
        } catch {
            // Close the channel now instead of relying on `deinit` timing.
            await connection.close()
            throw error
        }

        return connection
    }

    /// Executes a MongoDB `isMaster`
    ///
    /// Sends `IsMaster` to the administrative command namespace without a session,
    /// records the round trip in `lastHeartbeat`, and returns the decoded reply.
    ///
    /// - Parameters:
    ///   - clientDetails: Client metadata included in the `IsMaster` command.
    ///   - credentials: When `.auto(user, _)`, a `"<db>.<user>"` namespace is sent.
    ///   - authenticationDatabase: The database prefix for that namespace.
    ///
    /// - SeeAlso: https://github.com/mongodb/specifications/blob/master/source/mongodb-handshake/handshake.rst
    public func doHandshake(
        clientDetails: MongoClientDetails?,
        credentials: ConnectionSettings.Authentication,
        authenticationDatabase: String = "admin"
    ) async throws -> ServerHandshake {
        let userNamespace: String?
        if case .auto(let user, _) = credentials {
            userNamespace = "\(authenticationDatabase).\(user)"
        } else {
            userNamespace = nil
        }

        // NO session must be used here: https://github.com/mongodb/specifications/blob/master/source/sessions/driver-sessions.rst#when-opening-and-authenticating-a-connection
        // Forced on the current connection
        let sent = Date()
        let result = try await executeCodable(
            IsMaster(
                clientDetails: clientDetails,
                userNamespace: userNamespace
            ),
            decodeAs: ServerHandshake.self,
            namespace: .administrativeCommand,
            sessionId: nil
        )

        self.lastHeartbeat = MongoHandshakeResult(sentAt: sent, handshake: result)
        return result
    }

    /// Performs the handshake, stores it on the context, then authenticates
    /// against `authenticationDatabase` with `credentials`.
    public func authenticate(
        clientDetails: MongoClientDetails?,
        using credentials: ConnectionSettings.Authentication,
        to authenticationDatabase: String = "admin"
    ) async throws {
        let handshake = try await doHandshake(
            clientDetails: clientDetails,
            credentials: credentials,
            authenticationDatabase: authenticationDatabase
        )

        await self.context.setServerHandshake(to: handshake)

        try await self.authenticate(to: authenticationDatabase, serverHandshake: handshake, with: credentials)
    }

    /// Writes `message` to the channel and awaits the reply registered for its
    /// request ID.
    ///
    /// - Throws: `MongoError(.queryFailure, reason: .connectionClosed)` if the context
    ///   has recorded an error (the channel is closed first); any channel write
    ///   error; `MongoError(.queryTimeout)` if no reply arrives within `queryTimeout`.
    func executeMessage<Request: MongoRequestMessage>(_ message: Request) async throws -> MongoServerReply {
        if await self.context.didError {
            channel.close(mode: .all, promise: nil)
            throw MongoError(.queryFailure, reason: .connectionClosed)
        }

        let promise = self.eventLoop.makePromise(of: MongoServerReply.self)
        await self.context.setReplyCallback(forRequestId: message.header.requestId, completing: promise)

        var buffer = self.channel.allocator.buffer(capacity: Int(message.header.messageLength))
        message.write(to: &buffer)

        do {
            try await self.channel.writeAndFlush(buffer)
        } catch {
            // The request never went out, so no reply will complete the promise.
            // Fail it so the pending promise is not leaked.
            promise.fail(error)
            throw error
        }

        if let queryTimeout = queryTimeout {
            // Schedule the timeout on the channel's event loop and cancel it once a
            // reply (or failure) arrives. A detached sleeping Task would otherwise
            // linger for the full timeout after every query.
            let timeout = self.eventLoop.scheduleTask(in: queryTimeout) {
                promise.fail(MongoError(.queryTimeout, reason: nil))
            }
            promise.futureResult.whenComplete { _ in
                timeout.cancel()
            }
        }

        return try await promise.futureResult.get()
    }

    /// Closes the underlying channel, ignoring any close error.
    public func close() async {
        _ = try? await self.channel.close()
    }

    deinit {
        channel.close(mode: .all, promise: nil)
    }
}