diff --git a/Sources/MongoSwift/ClientSession.swift b/Sources/MongoSwift/ClientSession.swift index 6273a4056..da8f44a68 100644 --- a/Sources/MongoSwift/ClientSession.swift +++ b/Sources/MongoSwift/ClientSession.swift @@ -70,6 +70,80 @@ public final class ClientSession { /// started the libmongoc session. internal var id: Document? + /// The server ID of the mongos this session is pinned to. A server ID of 0 indicates that the session is unpinned. + internal var serverId: UInt32? { + switch self.state { + case .notStarted, .ended: + return nil + case let .started(session, _): + return mongoc_client_session_get_server_id(session) + } + } + + /// Enum tracking the state of the transaction associated with this session. + internal enum TransactionState: String, Decodable { + /// There is no transaction in progress. + case none + /// A transaction has been started, but no operation has been sent to the server. + case starting + /// A transaction is in progress. + case inProgress + /// The transaction was committed. + case committed + /// The transaction was aborted. + case aborted + + fileprivate var mongocTransactionState: mongoc_transaction_state_t { + switch self { + case .none: + return MONGOC_TRANSACTION_NONE + case .starting: + return MONGOC_TRANSACTION_STARTING + case .inProgress: + return MONGOC_TRANSACTION_IN_PROGRESS + case .committed: + return MONGOC_TRANSACTION_COMMITTED + case .aborted: + return MONGOC_TRANSACTION_ABORTED + } + } + + fileprivate init(mongocTransactionState: mongoc_transaction_state_t) { + switch mongocTransactionState { + case MONGOC_TRANSACTION_NONE: + self = .none + case MONGOC_TRANSACTION_STARTING: + self = .starting + case MONGOC_TRANSACTION_IN_PROGRESS: + self = .inProgress + case MONGOC_TRANSACTION_COMMITTED: + self = .committed + case MONGOC_TRANSACTION_ABORTED: + self = .aborted + default: + fatalError("Unexpected transaction state: \(mongocTransactionState)") + } + } + } + + /// The transaction state of this session. + internal var transactionState: TransactionState? { + switch self.state { + case .notStarted, .ended: + return nil + case let .started(session, _): + return TransactionState(mongocTransactionState: mongoc_client_session_get_transaction_state(session)) + } + } + + /// Indicates whether or not the session is in a transaction. + internal var inTransaction: Bool { + if let transactionState = self.transactionState { + return transactionState != .none + } + return false + } + /// The most recent cluster time seen by this session. This value will be nil if either of the following are true: /// - No operations have been executed using this session and `advanceClusterTime` has not been called. /// - This session has been ended. @@ -243,7 +317,7 @@ public final class ClientSession { * - SeeAlso: * - https://docs.mongodb.com/manual/core/transactions/ */ - public func startTransaction(_ options: TransactionOptions?) -> EventLoopFuture { + public func startTransaction(options: TransactionOptions? = nil) -> EventLoopFuture { switch self.state { case .notStarted, .started: let operation = StartTransactionOperation(options: options) diff --git a/Sources/MongoSwift/MongoClient.swift b/Sources/MongoSwift/MongoClient.swift index 56383c429..7268b381f 100644 --- a/Sources/MongoSwift/MongoClient.swift +++ b/Sources/MongoSwift/MongoClient.swift @@ -3,16 +3,14 @@ import NIO import NIOConcurrencyHelpers /// Options to use when creating a `MongoClient`. -public struct ClientOptions: CodingStrategyProvider, Decodable { - // swiftlint:disable redundant_optional_initialization - +public struct ClientOptions: CodingStrategyProvider { /// Specifies the `DataCodingStrategy` to use for BSON encoding/decoding operations performed by this client and any /// databases or collections that derive from it. - public var dataCodingStrategy: DataCodingStrategy? = nil + public var dataCodingStrategy: DataCodingStrategy? /// Specifies the `DateCodingStrategy` to use for BSON encoding/decoding operations performed by this client and any /// databases or collections that derive from it. - public var dateCodingStrategy: DateCodingStrategy? = nil + public var dateCodingStrategy: DateCodingStrategy? /// The maximum number of connections that may be associated with a connection pool created by this client at a /// given time. This includes in-use and available connections. Defaults to 100. @@ -22,7 +20,7 @@ public struct ClientOptions: CodingStrategyProvider, Decodable { public var readConcern: ReadConcern? /// Specifies a ReadPreference to use for the client. - public var readPreference: ReadPreference? = nil + public var readPreference: ReadPreference? /// Determines whether the client should retry supported read operations (on by default). public var retryReads: Bool? @@ -65,7 +63,7 @@ public struct ClientOptions: CodingStrategyProvider, Decodable { /// Specifies the `UUIDCodingStrategy` to use for BSON encoding/decoding operations performed by this client and any /// databases or collections that derive from it. - public var uuidCodingStrategy: UUIDCodingStrategy? = nil + public var uuidCodingStrategy: UUIDCodingStrategy? // swiftlint:enable redundant_optional_initialization diff --git a/Sources/MongoSwift/MongoCollection+BulkWrite.swift b/Sources/MongoSwift/MongoCollection+BulkWrite.swift index 08047d2b0..827dd09e1 100644 --- a/Sources/MongoSwift/MongoCollection+BulkWrite.swift +++ b/Sources/MongoSwift/MongoCollection+BulkWrite.swift @@ -197,6 +197,13 @@ internal struct BulkWriteOperation: Operation { let opts = try encodeOptions(options: options, session: session) var insertedIds: [Int: BSON] = [:] + if session?.inTransaction == true && self.options?.writeConcern != nil { + throw InvalidArgumentError( + message: "Cannot specify a write concern on an individual helper in a " + + "transaction. Instead specify it when starting the transaction." + ) + } + let (serverId, isAcknowledged): (UInt32, Bool) = try self.collection.withMongocCollection(from: connection) { collPtr in guard let bulk = mongoc_collection_create_bulk_operation_with_opts(collPtr, opts?._bson) else { @@ -214,8 +221,22 @@ internal struct BulkWriteOperation: Operation { mongoc_bulk_operation_execute(bulk, replyPtr, &error) } - let writeConcern = WriteConcern(from: mongoc_bulk_operation_get_write_concern(bulk)) - return (serverId, writeConcern.isAcknowledged) + var writeConcernAcknowledged: Bool + if session?.inTransaction == true { + // Bulk write operations in transactions must get their write concern from the session, not from + // the `BulkWriteOptions` passed to the `bulkWrite` helper. `libmongoc` surfaces this + // implementation detail by nulling out the write concern stored on the bulk write. To sidestep + // this, we can only call `mongoc_bulk_operation_get_write_concern` out of a transaction. + // + // In a transaction, default to writeConcernAcknowledged = true. This is acceptable because + // transactions do not support unacknowledged writes. + writeConcernAcknowledged = true + } else { + let writeConcern = WriteConcern(from: mongoc_bulk_operation_get_write_concern(bulk)) + writeConcernAcknowledged = writeConcern.isAcknowledged + } + + return (serverId, writeConcernAcknowledged) } let result = try BulkWriteResult(reply: reply, insertedIds: insertedIds) diff --git a/Sources/MongoSwift/MongoCollection+Read.swift b/Sources/MongoSwift/MongoCollection+Read.swift index 210a5b7e7..cdf947a15 100644 --- a/Sources/MongoSwift/MongoCollection+Read.swift +++ b/Sources/MongoSwift/MongoCollection+Read.swift @@ -114,11 +114,11 @@ extension MongoCollection { } /** - * Gets an estimate of the count of documents in this collection using collection metadata. + * Gets an estimate of the count of documents in this collection using collection metadata. This operation cannot + * be used in a transaction. * * - Parameters: * - options: Optional `EstimatedDocumentCountOptions` to use when executing the command - * - session: Optional `ClientSession` to use when executing this command * * - Returns: * An `EventLoopFuture`. On success, contains an estimate of the count of documents in this collection. @@ -126,16 +126,12 @@ extension MongoCollection { * If the future fails, the error is likely one of the following: * - `CommandError` if an error occurs that prevents the command from executing. * - `InvalidArgumentError` if the options passed in form an invalid combination. - * - `LogicError` if the provided session is inactive. * - `LogicError` if this collection's parent client has already been closed. * - `EncodingError` if an error occurs while encoding the options to BSON. */ - public func estimatedDocumentCount( - options: EstimatedDocumentCountOptions? = nil, - session: ClientSession? = nil - ) -> EventLoopFuture { + public func estimatedDocumentCount(options: EstimatedDocumentCountOptions? = nil) -> EventLoopFuture { let operation = EstimatedDocumentCountOperation(collection: self, options: options) - return self._client.operationExecutor.execute(operation, client: self._client, session: session) + return self._client.operationExecutor.execute(operation, client: self._client, session: nil) } /** diff --git a/Sources/MongoSwift/MongoError.swift b/Sources/MongoSwift/MongoError.swift index 9f4d734cf..7d8b6ffb3 100644 --- a/Sources/MongoSwift/MongoError.swift +++ b/Sources/MongoSwift/MongoError.swift @@ -96,7 +96,7 @@ public struct InternalError: RuntimeError { /// An error thrown when encountering a connection or socket related error. /// May contain labels providing additional information on the nature of the error. public struct ConnectionError: RuntimeError, LabeledError { - internal let message: String + public let message: String public let errorLabels: [String]? @@ -173,11 +173,36 @@ public struct WriteConcernFailure: Codable { /// A description of the error. public let message: String + /// Labels that may describe the context in which this error was thrown. + public let errorLabels: [String]? + private enum CodingKeys: String, CodingKey { case code case codeName case details = "errInfo" case message = "errmsg" + case errorLabels + } + + // TODO: can remove this once SERVER-36755 is resolved + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + self.code = try container.decode(ServerErrorCode.self, forKey: .code) + self.message = try container.decode(String.self, forKey: .message) + self.codeName = try container.decodeIfPresent(String.self, forKey: .codeName) ?? "" + self.details = try container.decodeIfPresent(Document.self, forKey: .details) + self.errorLabels = try container.decodeIfPresent([String].self, forKey: .errorLabels) + } + + // TODO: can remove this once SERVER-36755 is resolved + internal init( + code: ServerErrorCode, codeName: String, details: Document?, message: String, errorLabels: [String]? = nil + ) { + self.code = code + self.codeName = codeName + self.message = message + self.details = details + self.errorLabels = errorLabels } } @@ -279,6 +304,7 @@ internal func extractMongoError(error bsonError: bson_error_t, reply: Document? // if the reply is nil or writeErrors or writeConcernErrors aren't present, then this is likely a commandError. guard let serverReply: Document = reply, !(serverReply["writeErrors"]?.arrayValue ?? []).isEmpty || + !(serverReply["writeConcernError"]?.documentValue?.keys ?? []).isEmpty || !(serverReply["writeConcernErrors"]?.arrayValue ?? []).isEmpty else { return parseMongocError(bsonError, reply: reply) } @@ -390,11 +416,14 @@ internal func extractBulkWriteError( /// Extracts a `WriteConcernError` from a server reply. private func extractWriteConcernError(from reply: Document) throws -> WriteConcernFailure? { - guard let writeConcernErrors = reply["writeConcernErrors"]?.arrayValue?.compactMap({ $0.documentValue }), - !writeConcernErrors.isEmpty else { + if let writeConcernErrors = reply["writeConcernErrors"]?.arrayValue?.compactMap({ $0.documentValue }), + !writeConcernErrors.isEmpty { + return try BSONDecoder().decode(WriteConcernFailure.self, from: writeConcernErrors[0]) + } else if let writeConcernError = reply["writeConcernError"]?.documentValue { + return try BSONDecoder().decode(WriteConcernFailure.self, from: writeConcernError) + } else { return nil } - return try BSONDecoder().decode(WriteConcernFailure.self, from: writeConcernErrors[0]) } /// Internal function used by write methods performing single writes that are implemented via the bulk API. If the diff --git a/Sources/MongoSwift/MongoSwiftVersion.swift b/Sources/MongoSwift/MongoSwiftVersion.swift index 318684e9d..152a6f17c 100644 --- a/Sources/MongoSwift/MongoSwiftVersion.swift +++ b/Sources/MongoSwift/MongoSwiftVersion.swift @@ -1,6 +1,4 @@ // Generated using Sourcery 0.16.1 — https://github.com/krzysztofzablocki/Sourcery // DO NOT EDIT - -// swiftlint:disable:previous vertical_whitespace internal let MongoSwiftVersionString = "1.0.0-rc0" diff --git a/Sources/MongoSwiftSync/ClientSession.swift b/Sources/MongoSwiftSync/ClientSession.swift index 4d0f94d1d..f996fe6ec 100644 --- a/Sources/MongoSwiftSync/ClientSession.swift +++ b/Sources/MongoSwiftSync/ClientSession.swift @@ -115,7 +115,7 @@ public final class ClientSession { * - https://docs.mongodb.com/manual/core/transactions/ */ public func startTransaction(options: TransactionOptions? = nil) throws { - try self.asyncSession.startTransaction(options).wait() + try self.asyncSession.startTransaction(options: options).wait() } /** diff --git a/Sources/MongoSwiftSync/MongoCollection+Read.swift b/Sources/MongoSwiftSync/MongoCollection+Read.swift index db055f658..854595162 100644 --- a/Sources/MongoSwiftSync/MongoCollection+Read.swift +++ b/Sources/MongoSwiftSync/MongoCollection+Read.swift @@ -94,19 +94,16 @@ extension MongoCollection { } /** - * Gets an estimate of the count of documents in this collection using collection metadata. + * Gets an estimate of the count of documents in this collection using collection metadata. This operation cannot + * be used in a transaction. * * - Parameters: * - options: Optional `EstimatedDocumentCountOptions` to use when executing the command - * - session: Optional `ClientSession` to use when executing this command * * - Returns: an estimate of the count of documents in this collection */ - public func estimatedDocumentCount( - options: EstimatedDocumentCountOptions? = nil, - session: ClientSession? = nil - ) throws -> Int { - try self.asyncColl.estimatedDocumentCount(options: options, session: session?.asyncSession).wait() + public func estimatedDocumentCount(options: EstimatedDocumentCountOptions? = nil) throws -> Int { + try self.asyncColl.estimatedDocumentCount(options: options).wait() } /** diff --git a/Tests/LinuxMain.swift b/Tests/LinuxMain.swift index 1e2f8274f..85d4a8171 100644 --- a/Tests/LinuxMain.swift +++ b/Tests/LinuxMain.swift @@ -366,6 +366,12 @@ extension SyncMongoClientTests { ] } +extension TransactionsTests { + static var allTests = [ + ("testTransactions", testTransactions), + ] +} + extension WriteConcernTests { static var allTests = [ ("testWriteConcernType", testWriteConcernType), @@ -408,5 +414,6 @@ XCTMain([ testCase(SyncChangeStreamTests.allTests), testCase(SyncClientSessionTests.allTests), testCase(SyncMongoClientTests.allTests), + testCase(TransactionsTests.allTests), testCase(WriteConcernTests.allTests), ]) diff --git a/Tests/MongoSwiftSyncTests/ClientSessionTests.swift b/Tests/MongoSwiftSyncTests/ClientSessionTests.swift index 5981558f5..6afd8c6fe 100644 --- a/Tests/MongoSwiftSyncTests/ClientSessionTests.swift +++ b/Tests/MongoSwiftSyncTests/ClientSessionTests.swift @@ -1,5 +1,5 @@ import Foundation -@testable import MongoSwift +@testable import class MongoSwift.ClientSession @testable import MongoSwiftSync import Nimble import TestsCommon @@ -7,29 +7,19 @@ import TestsCommon /// Describes an operation run on a collection that takes in a session. struct CollectionSessionOp { let name: String - let body: (MongoSwiftSync.MongoCollection, MongoSwiftSync.ClientSession?) throws -> Void + let body: (MongoCollection, MongoSwiftSync.ClientSession?) throws -> Void } /// Describes an operation run on a database that takes in a session. struct DatabaseSessionOp { let name: String - let body: (MongoSwiftSync.MongoDatabase, MongoSwiftSync.ClientSession?) throws -> Void + let body: (MongoDatabase, MongoSwiftSync.ClientSession?) throws -> Void } /// Describes an operation run on a client that takes in a session. struct ClientSessionOp { let name: String - let body: (MongoSwiftSync.MongoClient, MongoSwiftSync.ClientSession?) throws -> Void -} - -extension MongoSwiftSync.ClientSession { - var active: Bool { - self.asyncSession.active - } - - var id: Document? { - self.asyncSession.id - } + let body: (MongoClient, MongoSwiftSync.ClientSession?) throws -> Void } final class SyncClientSessionTests: MongoSwiftTestCase { @@ -53,8 +43,7 @@ final class SyncClientSessionTests: MongoSwiftTestCase { CollectionSessionOp(name: "findOne") { _ = try $0.findOne([:], session: $1) }, CollectionSessionOp(name: "aggregate") { _ = try $0.aggregate([], session: $1).next()?.get() }, CollectionSessionOp(name: "distinct") { _ = try $0.distinct(fieldName: "x", session: $1) }, - CollectionSessionOp(name: "countDocuments") { _ = try $0.countDocuments(session: $1) }, - CollectionSessionOp(name: "estimatedDocumentCount") { _ = try $0.estimatedDocumentCount(session: $1) } + CollectionSessionOp(name: "countDocuments") { _ = try $0.countDocuments(session: $1) } ] // list of write operations on MongoCollection that take in a session @@ -109,9 +98,9 @@ final class SyncClientSessionTests: MongoSwiftTestCase { /// iterate over all the different session op types, passing in the provided client/db/collection as needed. func forEachSessionOp( - client: MongoSwiftSync.MongoClient, - database: MongoSwiftSync.MongoDatabase, - collection: MongoSwiftSync.MongoCollection, + client: MongoClient, + database: MongoDatabase, + collection: MongoCollection, _ body: (SessionOp) throws -> Void ) rethrows { try (self.collectionSessionReadOps + self.collectionSessionWriteOps).forEach { op in @@ -240,7 +229,11 @@ final class SyncClientSessionTests: MongoSwiftTestCase { expect(session1.active).to(beFalse()) try self.forEachSessionOp(client: client, database: db, collection: collection) { op in - expect(try op.body(session1)).to(throwError(ClientSession.SessionInactiveError), description: op.name) + expect(try op.body(session1)).to( + throwError( + MongoSwift.ClientSession.SessionInactiveError), + description: op.name + ) } let session2 = client.startSession() @@ -252,7 +245,7 @@ final class SyncClientSessionTests: MongoSwiftTestCase { let cursor = try collection.find(session: session2) expect(cursor.next()).toNot(beNil()) session2.end() - expect(try cursor.next()?.get()).to(throwError(ClientSession.SessionInactiveError)) + expect(try cursor.next()?.get()).to(throwError(MongoSwift.ClientSession.SessionInactiveError)) } /// Sessions spec test 10: Test cursors have the same lsid in the initial find command and in subsequent getMores. diff --git a/Tests/MongoSwiftSyncTests/CommandMonitoringTests.swift b/Tests/MongoSwiftSyncTests/CommandMonitoringTests.swift index 34ddd7bae..12196d510 100644 --- a/Tests/MongoSwiftSyncTests/CommandMonitoringTests.swift +++ b/Tests/MongoSwiftSyncTests/CommandMonitoringTests.swift @@ -79,31 +79,6 @@ private struct CMTestFile: Decodable { } } -extension ReadPreference.Mode: Decodable {} - -extension ReadPreference: Decodable { - public init(from decoder: Decoder) throws { - let container = try decoder.container(keyedBy: CodingKeys.self) - let mode = try container.decode(Mode.self, forKey: .mode) - switch mode { - case .primary: - self = .primary - case .primaryPreferred: - self = .primaryPreferred - case .secondary: - self = .secondary - case .secondaryPreferred: - self = .secondaryPreferred - case .nearest: - self = .nearest - } - } - - private enum CodingKeys: String, CodingKey { - case mode - } -} - /// A struct to hold the data for a single test from a CMTestFile. private struct CMTest: Decodable { struct Operation: Decodable { diff --git a/Tests/MongoSwiftSyncTests/RetryableWritesTests.swift b/Tests/MongoSwiftSyncTests/RetryableWritesTests.swift index 3f2e8310b..1ba2afbaf 100644 --- a/Tests/MongoSwiftSyncTests/RetryableWritesTests.swift +++ b/Tests/MongoSwiftSyncTests/RetryableWritesTests.swift @@ -99,7 +99,10 @@ final class RetryableWritesTests: MongoSwiftTestCase, FailPointConfigured { var seenError: Error? do { - result = try test.operation.execute(on: .collection(collection), session: nil) + result = try test.operation.execute( + on: .collection(collection), + sessions: [:] + ) } catch { if let bulkError = error as? BulkWriteError { result = TestOperationResult(from: bulkError.result) diff --git a/Tests/MongoSwiftSyncTests/SpecTestRunner/CodableExtensions.swift b/Tests/MongoSwiftSyncTests/SpecTestRunner/CodableExtensions.swift new file mode 100644 index 000000000..9d09f3bd3 --- /dev/null +++ b/Tests/MongoSwiftSyncTests/SpecTestRunner/CodableExtensions.swift @@ -0,0 +1,101 @@ +@testable import struct MongoSwift.ReadPreference +import MongoSwiftSync + +extension DatabaseOptions: Decodable { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let readConcern = try container.decodeIfPresent(ReadConcern.self, forKey: .readConcern) + let readPreference = try container.decodeIfPresent(ReadPreference.self, forKey: .readPreference) + let writeConcern = try container.decodeIfPresent(WriteConcern.self, forKey: .writeConcern) + self.init(readConcern: readConcern, readPreference: readPreference, writeConcern: writeConcern) + } + + private enum CodingKeys: CodingKey { + case readConcern, readPreference, writeConcern + } +} + +extension CollectionOptions: Decodable { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let readConcern = try container.decodeIfPresent(ReadConcern.self, forKey: .readConcern) + let writeConcern = try container.decodeIfPresent(WriteConcern.self, forKey: .writeConcern) + self.init(readConcern: readConcern, writeConcern: writeConcern) + } + + private enum CodingKeys: CodingKey { + case readConcern, writeConcern + } +} + +extension ClientSessionOptions: Decodable { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let causalConsistency = try container.decodeIfPresent(Bool.self, forKey: .causalConsistency) + let defaultTransactionOptions = try container.decodeIfPresent( + TransactionOptions.self, + forKey: .defaultTransactionOptions + ) + self.init(causalConsistency: causalConsistency, defaultTransactionOptions: defaultTransactionOptions) + } + + private enum CodingKeys: CodingKey { + case causalConsistency, defaultTransactionOptions + } +} + +extension TransactionOptions: Decodable { + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let maxCommitTimeMS = try container.decodeIfPresent(Int64.self, forKey: .maxCommitTimeMS) + let readConcern = try container.decodeIfPresent(ReadConcern.self, forKey: .readConcern) + let readPreference = try container.decodeIfPresent(ReadPreference.self, forKey: .readPreference) + let writeConcern = try container.decodeIfPresent(WriteConcern.self, forKey: .writeConcern) + self.init( + maxCommitTimeMS: maxCommitTimeMS, + readConcern: readConcern, + readPreference: readPreference, + writeConcern: writeConcern + ) + } + + private enum CodingKeys: CodingKey { + case maxCommitTimeMS, readConcern, readPreference, writeConcern + } +} + +extension ReadPreference.Mode: Decodable {} + +extension ReadPreference: Decodable { + private enum CodingKeys: String, CodingKey { + case mode + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let mode = try container.decode(Mode.self, forKey: .mode) + self.init(mode) + } +} + +extension ClientOptions: Decodable { + private enum CodingKeys: String, CodingKey { + case retryReads, retryWrites, w, readConcernLevel, mode = "readPreference" + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + let readConcern = try? ReadConcern(container.decode(String.self, forKey: .readConcernLevel)) + let readPreference = try? ReadPreference(container.decode(ReadPreference.Mode.self, forKey: .mode)) + let retryReads = try container.decodeIfPresent(Bool.self, forKey: .retryReads) + let retryWrites = try container.decodeIfPresent(Bool.self, forKey: .retryWrites) + let writeConcern = try? WriteConcern(w: container.decode(WriteConcern.W.self, forKey: .w)) + self.init( + readConcern: readConcern, + readPreference: readPreference, + retryReads: retryReads, + retryWrites: retryWrites, + writeConcern: writeConcern + ) + } +} diff --git a/Tests/MongoSwiftSyncTests/SpecTestRunner/FailPoint.swift b/Tests/MongoSwiftSyncTests/SpecTestRunner/FailPoint.swift index 00d7e3365..b25e539a0 100644 --- a/Tests/MongoSwiftSyncTests/SpecTestRunner/FailPoint.swift +++ b/Tests/MongoSwiftSyncTests/SpecTestRunner/FailPoint.swift @@ -102,6 +102,7 @@ internal struct FailPoint: Decodable { mode: Mode, closeConnection: Bool? = nil, errorCode: Int? = nil, + errorLabels: [String]? = nil, writeConcernError: Document? = nil ) -> FailPoint { var data: Document = [ @@ -113,6 +114,9 @@ internal struct FailPoint: Decodable { if let code = errorCode { data["errorCode"] = BSON(code) } + if let labels = errorLabels { + data["errorLabels"] = .array(labels.map { .string($0) }) + } if let writeConcernError = writeConcernError { data["writeConcernError"] = .document(writeConcernError) } diff --git a/Tests/MongoSwiftSyncTests/SpecTestRunner/Match.swift b/Tests/MongoSwiftSyncTests/SpecTestRunner/Match.swift index 986158846..6432535a4 100644 --- a/Tests/MongoSwiftSyncTests/SpecTestRunner/Match.swift +++ b/Tests/MongoSwiftSyncTests/SpecTestRunner/Match.swift @@ -69,7 +69,12 @@ extension Array: Matchable where Element: Matchable { extension Document: Matchable { internal func contentMatches(expected: Document) -> Bool { for (eK, eV) in expected { - guard let aV = self[eK], aV.matches(expected: eV) else { + // If the expected document has "key": null then the actual document must either have "key": null + // or no reference to "key". + guard let aV = self[eK] else { + return eV == .null + } + guard aV.matches(expected: eV) else { return false } } diff --git a/Tests/MongoSwiftSyncTests/SpecTestRunner/SpecTest.swift b/Tests/MongoSwiftSyncTests/SpecTestRunner/SpecTest.swift index b763e8cd0..564b16243 100644 --- a/Tests/MongoSwiftSyncTests/SpecTestRunner/SpecTest.swift +++ b/Tests/MongoSwiftSyncTests/SpecTestRunner/SpecTest.swift @@ -1,4 +1,5 @@ import Foundation +@testable import struct MongoSwift.ReadPreference import MongoSwiftSync import Nimble import TestsCommon @@ -20,8 +21,23 @@ internal struct TestCommandStartedEvent: Decodable, Matchable { case type = "command_started_event" } - internal init(from event: CommandStartedEvent) { - self.command = event.command + internal init(from event: CommandStartedEvent, sessionIds: [Document: String]? = nil) { + var command = event.command + + // If command started event has "lsid": Document(...), change the value to correpond to "session0", + // "session1", etc. + if let sessionIds = sessionIds, let sessionDoc = command["lsid"]?.documentValue { + for (sessionId, sessionName) in sessionIds where sessionId == sessionDoc { + command["lsid"] = .string(sessionName) + } + } + // If command is "findAndModify" and does not have key "new", add the default value "new": false. + // This is necessary because `libmongoc` only sends a value for "new" in a command if "new": true. + if event.commandName == "findAndModify" && command["new"] == nil { + command["new"] = .bool(false) + } + + self.command = command self.databaseName = event.databaseName self.commandName = event.commandName } @@ -167,6 +183,10 @@ extension SpecTestFile { internal func populateData(using client: MongoClient) throws { let database = client.db(self.databaseName) + // Majority write concern ensures that initial data is propagated to all nodes in a replica set or sharded + // cluster. + let collectionOptions = CollectionOptions(writeConcern: try WriteConcern(w: .majority)) + try? database.drop() switch self.data { @@ -179,13 +199,13 @@ extension SpecTestFile { return } - try database.collection(collName).insertMany(docs) + try database.collection(collName, options: collectionOptions).insertMany(docs) case let .multiple(mapping): for (k, v) in mapping { guard !v.isEmpty else { continue } - try database.collection(k).insertMany(v) + try database.collection(k, options: collectionOptions).insertMany(v) } } } @@ -203,14 +223,13 @@ extension SpecTestFile { } } - try self.populateData(using: setupClient) - fileLevelLog("Executing tests from file \(self.name)...") for test in self.tests { guard skippedTestKeywords.allSatisfy({ !test.description.contains($0) }) else { print("Skipping test \(test.description)") return } + try self.populateData(using: setupClient) try test.run(parent: parent, dbName: self.databaseName, collName: self.collectionName) } } @@ -241,10 +260,27 @@ internal protocol SpecTest: Decodable { /// List of expected CommandStartedEvents. var expectations: [TestCommandStartedEvent]? { get } + + /// Document describing the return value and/or expected state of the collection after the operation is executed. + var outcome: TestOutcome? { get } + + /// Map of session names (e.g. "session0") to parameters to pass to `MongoClient.startSession()` when creating that + /// session. + var sessionOptions: [String: ClientSessionOptions]? { get } + + /// Array of session names (e.g. "session0", "session1") that the test refers to. Each session is proactively + /// started in `run()`. + static var sessionNames: [String] { get } } /// Default implementation of a test execution. extension SpecTest { + var outcome: TestOutcome? { nil } + + var sessionOptions: [String: ClientSessionOptions]? { nil } + + static var sessionNames: [String] { [] } + internal func run( parent: FailPointConfigured, dbName: String, @@ -257,37 +293,76 @@ extension SpecTest { print("Executing test: \(self.description)") - let clientOptions = self.clientOptions ?? ClientOptions(retryReads: true) + var singleMongos = true + if let useMultipleMongoses = self.useMultipleMongoses, useMultipleMongoses == true { + singleMongos = false + } + + let client = try MongoClient.makeTestClient( + MongoSwiftTestCase.getConnectionString(singleMongos: singleMongos), options: self.clientOptions + ) + let monitor = client.addCommandMonitor() + + if let collName = collName { + _ = try? client.db(dbName).createCollection(collName) + // Run the distinct command before every test to prevent `StableDbVersion` error in sharded cluster + // transactions. This workaround can be removed once SERVER-39704 is resolved. + _ = try? client.db(dbName).collection(collName).distinct(fieldName: "_id") + } if let failPoint = self.failPoint { try parent.activateFailPoint(failPoint) } defer { parent.disableActiveFailPoint() } - let client = try MongoClient.makeTestClient(options: clientOptions) - let monitor = client.addCommandMonitor() - - let db = client.db(dbName) - var collection: MongoCollection? - - if let collName = collName { - collection = db.collection(collName) + var sessions = [String: ClientSession]() + for session in Self.sessionNames { + sessions[session] = client.startSession(options: self.sessionOptions?[session]) } + var sessionIds = [Document: String]() + try monitor.captureEvents { for operation in self.operations { try operation.validateExecution( client: client, - database: db, - collection: collection, - session: nil + dbName: dbName, + collName: collName, + sessions: sessions ) } + // Keep track of the session IDs assigned to each session. + // Deinitialize each session thereby implicitly ending them. + for session in sessions.keys { + if let sessionId = sessions[session]?.id { sessionIds[sessionId] = session } + sessions[session] = nil + } + } + + let events = monitor.commandStartedEvents().map { commandStartedEvent in + TestCommandStartedEvent(from: commandStartedEvent, sessionIds: sessionIds) } - let events = monitor.commandStartedEvents().map { TestCommandStartedEvent(from: $0) } if let expectations = self.expectations { expect(events).to(match(expectations), description: self.description) } + + try self.checkOutcome(dbName: dbName, collName: collName) + } + + internal func checkOutcome(dbName: String, collName: String?) throws { + guard let outcome = self.outcome else { + return + } + guard let collName = collName else { + throw TestError(message: "outcome specifies a collection but spec test omits collection name") + } + let client = try MongoClient.makeTestClient() + let verifyColl = client.db(dbName).collection(collName) + let foundDocs = try verifyColl.find().all() + expect(foundDocs.count).to(equal(outcome.collection.data.count)) + zip(foundDocs, outcome.collection.data).forEach { + expect($0).to(sortedEqual($1), description: self.description) + } } } diff --git a/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperation.swift b/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperation.swift index ece691db4..f44523b5d 100644 --- a/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperation.swift +++ b/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperation.swift @@ -3,8 +3,42 @@ import Nimble import TestsCommon /// A enumeration of the different objects a `TestOperation` may be performed against. -enum TestOperationObject: String, Decodable { - case client, database, collection, gridfsbucket +enum TestOperationObject: RawRepresentable, Decodable { + case client, database, collection, gridfsbucket, testRunner, session(String) + + public var rawValue: String { + switch self { + case .client: + return "client" + case .database: + return "database" + case .collection: + return "collection" + case .gridfsbucket: + return "gridfsbucket" + case .testRunner: + return "testRunner" + case let .session(sessionName): + return sessionName + } + } + + public init(rawValue: String) { + switch rawValue { + case "client": + self = .client + case "database": + self = .database + case "collection": + self = .collection + case "gridfsbucket": + self = .gridfsbucket + case "testRunner": + self = .testRunner + default: + self = .session(rawValue) + } + } } /// Struct containing an operation and an expected outcome. @@ -21,8 +55,17 @@ struct TestOperationDescription: Decodable { /// Whether the operation should expect an error. let error: Bool? - public enum CodingKeys: CodingKey { - case object, result, error + /// The parameters to pass to the database used for this operation. + let databaseOptions: DatabaseOptions? + + /// The parameters to pass to the collection used for this operation. + let collectionOptions: CollectionOptions? + + /// Present only when the operation is `runCommand`. The name of the command to run. + let commandName: String? + + public enum CodingKeys: String, CodingKey { + case object, result, error, databaseOptions, collectionOptions, commandName = "command_name" } public init(from decoder: Decoder) throws { @@ -32,23 +75,32 @@ struct TestOperationDescription: Decodable { self.object = try container.decode(TestOperationObject.self, forKey: .object) self.result = try container.decodeIfPresent(TestOperationResult.self, forKey: .result) self.error = try container.decodeIfPresent(Bool.self, forKey: .error) + self.databaseOptions = try container.decodeIfPresent(DatabaseOptions.self, forKey: .databaseOptions) + self.collectionOptions = try container.decodeIfPresent(CollectionOptions.self, forKey: .collectionOptions) + self.commandName = try container.decodeIfPresent(String.self, forKey: .commandName) } + // swiftlint:disable cyclomatic_complexity + /// Runs the operation and asserts its results meet the expectation. func validateExecution( client: MongoClient, - database: MongoDatabase?, - collection: MongoCollection?, - session: ClientSession? + dbName: String, + collName: String?, + sessions: [String: ClientSession] ) throws { + let database = client.db(dbName, options: self.databaseOptions) + var collection: MongoCollection? + + if let collName = collName { + collection = database.collection(collName, options: self.collectionOptions) + } + let target: TestOperationTarget switch self.object { case .client: target = .client(client) case .database: - guard let database = database else { - throw TestError(message: "got database object but was not provided a database") - } target = .database(database) case .collection: guard let collection = collection else { @@ -57,19 +109,32 @@ struct TestOperationDescription: Decodable { target = .collection(collection) case .gridfsbucket: throw TestError(message: "gridfs tests should be skipped") + case let .session(sessionName): + guard let session = sessions[sessionName] else { + throw TestError(message: "got session object but was not provided a session") + } + target = .session(session) + case .testRunner: + target = .testRunner } do { - let result = try self.operation.execute(on: target, session: session) + let result = try self.operation.execute(on: target, sessions: sessions) expect(self.error ?? false) .to(beFalse(), description: "expected to fail but succeeded with result \(String(describing: result))") if let expectedResult = self.result { - expect(result).to(equal(expectedResult)) + expect(result?.matches(expected: expectedResult)).to(beTrue()) } } catch { - expect(self.error ?? false).to(beTrue(), description: "expected no error, got \(error)") + if case let .error(expectedErrorResult) = self.result { + try expectedErrorResult.checkErrorResult(error) + } else { + expect(self.error ?? false).to(beTrue(), description: "expected no error, got \(error)") + } } } + + // swiftlint:enable cyclomatic_complexity } /// Object in which an operation should be executed on. @@ -83,12 +148,19 @@ enum TestOperationTarget { /// Execute against the provided collection. case collection(MongoCollection) + + /// Execute against the provided session. + case session(ClientSession) + + /// Execute against the provided test runner. Operations that execute on the test runner do not correspond to API + /// methods but instead represent special test operations such as asserts. + case testRunner } /// Protocol describing the behavior of a spec test "operation" protocol TestOperation: Decodable { /// Execute the operation given the context. - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? } /// Wrapper around a `TestOperation.swift` allowing it to be decoded from a spec test. @@ -141,8 +213,34 @@ struct AnyTestOperation: Decodable, TestOperation { self.op = try container.decode(ReplaceOne.self, forKey: .arguments) case "rename": self.op = try container.decode(RenameCollection.self, forKey: .arguments) + case "startTransaction": + self.op = (try container.decodeIfPresent(StartTransaction.self, forKey: .arguments)) ?? StartTransaction() + case "createCollection": + self.op = try container.decode(CreateCollection.self, forKey: .arguments) + case "dropCollection": + self.op = try container.decode(DropCollection.self, forKey: .arguments) + case "createIndex": + self.op = try container.decode(CreateIndex.self, forKey: .arguments) + case "runCommand": + self.op = try container.decode(RunCommand.self, forKey: .arguments) + case "assertCollectionExists": + self.op = try container.decode(AssertCollectionExists.self, forKey: .arguments) + case "assertCollectionNotExists": + self.op = try container.decode(AssertCollectionNotExists.self, forKey: .arguments) + case "assertIndexExists": + self.op = try container.decode(AssertIndexExists.self, forKey: .arguments) + case "assertIndexNotExists": + self.op = try container.decode(AssertIndexNotExists.self, forKey: .arguments) + case "assertSessionPinned": + self.op = try container.decode(AssertSessionPinned.self, forKey: .arguments) + case "assertSessionUnpinned": + self.op = try container.decode(AssertSessionUnpinned.self, forKey: .arguments) + case "assertSessionTransactionState": + self.op = try container.decode(AssertSessionTransactionState.self, forKey: .arguments) + case "targetedFailPoint": + self.op = try container.decode(TargetedFailPoint.self, forKey: .arguments) case "drop": - self.op = DropCollection() + self.op = Drop() case "listDatabaseNames": self.op = ListDatabaseNames() case "listDatabases": @@ -161,6 +259,10 @@ struct AnyTestOperation: Decodable, TestOperation { self.op = ListCollectionNames() case "watch": self.op = Watch() + case "commitTransaction": + self.op = CommitTransaction() + case "abortTransaction": + self.op = AbortTransaction() case "mapReduce", "download_by_name", "download", "count": self.op = NotImplemented(name: opName) default: @@ -168,68 +270,75 @@ struct AnyTestOperation: Decodable, TestOperation { } } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { - try self.op.execute(on: target, session: session) + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + try self.op.execute(on: target, sessions: sessions) } } struct Aggregate: TestOperation { + let session: String? let pipeline: [Document] let options: AggregateOptions - private enum CodingKeys: String, CodingKey { case pipeline } + private enum CodingKeys: String, CodingKey { case session, pipeline } init(from decoder: Decoder) throws { self.options = try AggregateOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.pipeline = try container.decode([Document].self, forKey: .pipeline) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to aggregate") } return try TestOperationResult( - from: collection.aggregate(self.pipeline, options: self.options, session: session) + from: collection.aggregate(self.pipeline, options: self.options, session: sessions[self.session ?? ""]) ) } } struct CountDocuments: TestOperation { + let session: String? let filter: Document let options: CountDocumentsOptions - private enum CodingKeys: String, CodingKey { case filter } + private enum CodingKeys: String, CodingKey { case session, filter } init(from decoder: Decoder) throws { self.options = try CountDocumentsOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to count") } - return .int(try collection.countDocuments(self.filter, options: self.options, session: session)) + return .int( + try collection.countDocuments(self.filter, options: self.options, session: sessions[self.session ?? ""])) } } struct Distinct: TestOperation { + let session: String? let fieldName: String let filter: Document? let options: DistinctOptions - private enum CodingKeys: String, CodingKey { case fieldName, filter } + private enum CodingKeys: String, CodingKey { case session, fieldName, filter } init(from decoder: Decoder) throws { self.options = try DistinctOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.fieldName = try container.decode(String.self, forKey: .fieldName) self.filter = try container.decodeIfPresent(Document.self, forKey: .filter) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to distinct") } @@ -237,67 +346,77 @@ struct Distinct: TestOperation { fieldName: self.fieldName, filter: self.filter ?? [:], options: self.options, - session: session + session: sessions[self.session ?? ""] ) return .array(result) } } struct Find: TestOperation { + let session: String? let filter: Document let options: FindOptions - private enum CodingKeys: String, CodingKey { case filter } + private enum CodingKeys: String, CodingKey { case session, filter } init(from decoder: Decoder) throws { self.options = try FindOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) - self.filter = try container.decode(Document.self, forKey: .filter) + self.session = try container.decodeIfPresent(String.self, forKey: .session) + self.filter = (try container.decodeIfPresent(Document.self, forKey: .filter)) ?? Document() } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to find") } - return try TestOperationResult(from: collection.find(self.filter, options: self.options, session: session)) + return try TestOperationResult( + from: collection.find(self.filter, options: self.options, session: sessions[self.session ?? ""]) + ) } } struct FindOne: TestOperation { + let session: String? let filter: Document let options: FindOneOptions - private enum CodingKeys: String, CodingKey { case filter } + private enum CodingKeys: String, CodingKey { case session, filter } init(from decoder: Decoder) throws { self.options = try FindOneOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to findOne") } - return try TestOperationResult(from: collection.findOne(self.filter, options: self.options, session: session)) + return try TestOperationResult( + from: collection.findOne(self.filter, options: self.options, session: sessions[self.session ?? ""]) + ) } } struct UpdateOne: TestOperation { + let session: String? let filter: Document let update: Document let options: UpdateOptions - private enum CodingKeys: String, CodingKey { case filter, update } + private enum CodingKeys: String, CodingKey { case session, filter, update } init(from decoder: Decoder) throws { self.options = try UpdateOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) self.update = try container.decode(Document.self, forKey: .update) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to updateOne") } @@ -306,27 +425,29 @@ struct UpdateOne: TestOperation { filter: self.filter, update: self.update, options: self.options, - session: session + session: sessions[self.session ?? ""] ) return TestOperationResult(from: result) } } struct UpdateMany: TestOperation { + let session: String? let filter: Document let update: Document let options: UpdateOptions - private enum CodingKeys: String, CodingKey { case filter, update } + private enum CodingKeys: String, CodingKey { case session, filter, update } init(from decoder: Decoder) throws { self.options = try UpdateOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) self.update = try container.decode(Document.self, forKey: .update) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to ") } @@ -335,74 +456,85 @@ struct UpdateMany: TestOperation { filter: self.filter, update: self.update, options: self.options, - session: session + session: sessions[self.session ?? ""] ) return TestOperationResult(from: result) } } struct DeleteMany: TestOperation { + let session: String? let filter: Document let options: DeleteOptions - private enum CodingKeys: String, CodingKey { case filter } + private enum CodingKeys: String, CodingKey { case session, filter } init(from decoder: Decoder) throws { self.options = try DeleteOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to deleteMany") } - let result = try collection.deleteMany(self.filter, options: self.options, session: session) + let result = + try collection.deleteMany(self.filter, options: self.options, session: sessions[self.session ?? ""]) return TestOperationResult(from: result) } } struct DeleteOne: TestOperation { + let session: String? let filter: Document let options: DeleteOptions - private enum CodingKeys: String, CodingKey { case filter } + private enum CodingKeys: String, CodingKey { case session, filter } init(from decoder: Decoder) throws { self.options = try DeleteOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to deleteOne") } - let result = try collection.deleteOne(self.filter, options: self.options, session: session) + let result = try collection.deleteOne(self.filter, options: self.options, session: sessions[self.session ?? ""]) return TestOperationResult(from: result) } } struct InsertOne: TestOperation { + let session: String? let document: Document - func execute(on target: TestOperationTarget, session _: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to insertOne") } - return TestOperationResult(from: try collection.insertOne(self.document)) + return TestOperationResult(from: try collection.insertOne(self.document, session: sessions[self.session ?? ""])) } } struct InsertMany: TestOperation { + let session: String? let documents: [Document] - let options: InsertManyOptions + let options: InsertManyOptions? - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to insertMany") } - let result = try collection.insertMany(self.documents, options: self.options, session: session) + let result = try collection.insertMany( + self.documents, + options: self.options, + session: sessions[self.session ?? ""] + ) return TestOperationResult(from: result) } } @@ -414,19 +546,19 @@ extension WriteModel: Decodable { } private enum InsertOneKeys: CodingKey { - case document + case session, document } private enum DeleteKeys: CodingKey { - case filter + case session, filter } private enum ReplaceOneKeys: CodingKey { - case filter, replacement + case session, filter, replacement } private enum UpdateKeys: CodingKey { - case filter, update + case session, filter, update } public init(from decoder: Decoder) throws { @@ -470,33 +602,37 @@ extension WriteModel: Decodable { } struct BulkWrite: TestOperation { + let session: String? let requests: [WriteModel] - let options: BulkWriteOptions + let options: BulkWriteOptions? - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to bulk write") } - let result = try collection.bulkWrite(self.requests, options: self.options, session: session) + let result = + try collection.bulkWrite(self.requests, options: self.options, session: sessions[self.session ?? ""]) return TestOperationResult(from: result) } } struct FindOneAndUpdate: TestOperation { + let session: String? let filter: Document let update: Document let options: FindOneAndUpdateOptions - private enum CodingKeys: String, CodingKey { case filter, update } + private enum CodingKeys: String, CodingKey { case session, filter, update } init(from decoder: Decoder) throws { self.options = try FindOneAndUpdateOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) self.update = try container.decode(Document.self, forKey: .update) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to findOneAndUpdate") } @@ -504,48 +640,56 @@ struct FindOneAndUpdate: TestOperation { filter: self.filter, update: self.update, options: self.options, - session: session + session: sessions[self.session ?? ""] ) return TestOperationResult(from: doc) } } struct FindOneAndDelete: TestOperation { + let session: String? let filter: Document let options: FindOneAndDeleteOptions - private enum CodingKeys: String, CodingKey { case filter } + private enum CodingKeys: String, CodingKey { case session, filter } init(from decoder: Decoder) throws { self.options = try FindOneAndDeleteOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to findOneAndDelete") } - let result = try collection.findOneAndDelete(self.filter, options: self.options, session: session) + let result = try collection.findOneAndDelete( + self.filter, + options: self.options, + session: sessions[self.session ?? ""] + ) return TestOperationResult(from: result) } } struct FindOneAndReplace: TestOperation { + let session: String? let filter: Document let replacement: Document let options: FindOneAndReplaceOptions - private enum CodingKeys: String, CodingKey { case filter, replacement } + private enum CodingKeys: String, CodingKey { case session, filter, replacement } init(from decoder: Decoder) throws { self.options = try FindOneAndReplaceOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) self.replacement = try container.decode(Document.self, forKey: .replacement) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to findOneAndReplace") } @@ -553,27 +697,29 @@ struct FindOneAndReplace: TestOperation { filter: self.filter, replacement: self.replacement, options: self.options, - session: session + session: sessions[self.session ?? ""] ) return TestOperationResult(from: result) } } struct ReplaceOne: TestOperation { + let session: String? let filter: Document let replacement: Document let options: ReplaceOptions - private enum CodingKeys: String, CodingKey { case filter, replacement } + private enum CodingKeys: String, CodingKey { case session, filter, replacement } init(from decoder: Decoder) throws { self.options = try ReplaceOptions(from: decoder) let container = try decoder.container(keyedBy: CodingKeys.self) + self.session = try container.decodeIfPresent(String.self, forKey: .session) self.filter = try container.decode(Document.self, forKey: .filter) self.replacement = try container.decode(Document.self, forKey: .replacement) } - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to replaceOne") } @@ -581,15 +727,16 @@ struct ReplaceOne: TestOperation { filter: self.filter, replacement: self.replacement, options: self.options, - session: session + session: sessions[self.session ?? ""] )) } } struct RenameCollection: TestOperation { + let session: String? let to: String - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to renameCollection") } @@ -599,14 +746,17 @@ struct RenameCollection: TestOperation { "renameCollection": .string(databaseName + "." + collection.name), "to": .string(databaseName + "." + self.to) ] - return try TestOperationResult(from: collection._client.db("admin").runCommand(cmd, session: session)) + return try TestOperationResult( + from: collection._client.db("admin").runCommand(cmd, session: sessions[self.session ?? ""]) + ) } } -struct DropCollection: TestOperation { - func execute(on target: TestOperationTarget, session _: ClientSession?) throws -> TestOperationResult? { +struct Drop: TestOperation { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .collection(collection) = target else { - throw TestError(message: "collection not provided to dropCollection") + throw TestError(message: "collection not provided to drop") } try collection.drop() return nil @@ -614,99 +764,338 @@ struct DropCollection: TestOperation { } struct ListDatabaseNames: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .client(client) = target else { throw TestError(message: "client not provided to listDatabaseNames") } - return try .array(client.listDatabaseNames(session: session).map { .string($0) }) + return try .array(client.listDatabaseNames().map { .string($0) }) } } struct ListIndexes: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to listIndexes") } - return try TestOperationResult(from: collection.listIndexes(session: session)) + return try TestOperationResult(from: collection.listIndexes()) } } struct ListIndexNames: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to listIndexNames") } - return try .array(collection.listIndexNames(session: session).map { .string($0) }) + return try .array(collection.listIndexNames().map { .string($0) }) } } struct ListDatabases: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .client(client) = target else { throw TestError(message: "client not provided to listDatabases") } - return try TestOperationResult(from: client.listDatabases(session: session)) + return try TestOperationResult(from: client.listDatabases()) } } struct ListMongoDatabases: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .client(client) = target else { throw TestError(message: "client not provided to listDatabases") } - _ = try client.listMongoDatabases(session: session) + _ = try client.listMongoDatabases() return nil } } struct ListCollections: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .database(database) = target else { throw TestError(message: "database not provided to listCollections") } - return try TestOperationResult(from: database.listCollections(session: session)) + return try TestOperationResult(from: database.listCollections()) } } struct ListMongoCollections: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .database(database) = target else { throw TestError(message: "database not provided to listCollectionObjects") } - _ = try database.listMongoCollections(session: session) + _ = try database.listMongoCollections() return nil } } struct ListCollectionNames: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .database(database) = target else { throw TestError(message: "database not provided to listCollectionNames") } - return try .array(database.listCollectionNames(session: session).map { .string($0) }) + return try .array(database.listCollectionNames().map { .string($0) }) } } struct Watch: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { switch target { case let .client(client): - _ = try client.watch(session: session) + _ = try client.watch() case let .database(database): - _ = try database.watch(session: session) + _ = try database.watch() case let .collection(collection): - _ = try collection.watch(session: session) + _ = try collection.watch() + case .session: + throw TestError(message: "watch cannot be executed on a session") + case .testRunner: + throw TestError(message: "watch cannot be executed on the test runner") } return nil } } struct EstimatedDocumentCount: TestOperation { - func execute(on target: TestOperationTarget, session: ClientSession?) throws -> TestOperationResult? { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { guard case let .collection(collection) = target else { throw TestError(message: "collection not provided to estimatedDocumentCount") } - return try .int(collection.estimatedDocumentCount(session: session)) + return try .int(collection.estimatedDocumentCount()) + } +} + +struct StartTransaction: TestOperation { + let options: TransactionOptions + + init() { + self.options = TransactionOptions() + } + + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case let .session(session) = target else { + throw TestError(message: "session not provided to startTransaction") + } + try session.startTransaction(options: self.options) + return nil + } +} + +struct CommitTransaction: TestOperation { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case let .session(session) = target else { + throw TestError(message: "session not provided to commitTransaction") + } + try session.commitTransaction() + return nil + } +} + +struct AbortTransaction: TestOperation { + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case let .session(session) = target else { + throw TestError(message: "session not provided to abortTransaction") + } + try session.abortTransaction() + return nil + } +} + +struct CreateCollection: TestOperation { + let session: String? + let collection: String + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case let .database(database) = target else { + throw TestError(message: "database not provided to createCollection") + } + _ = try database.createCollection(self.collection, session: sessions[self.session ?? ""]) + return nil + } +} + +struct DropCollection: TestOperation { + let session: String? + let collection: String + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case let .database(database) = target else { + throw TestError(message: "database not provided to dropCollection") + } + _ = try database.collection(self.collection).drop(session: sessions[self.session ?? ""]) + return nil + } +} + +struct CreateIndex: TestOperation { + let session: String? + let name: String + let keys: Document + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case let .collection(collection) = target else { + throw TestError(message: "collection not provided to createIndex") + } + let indexOptions = IndexOptions(name: self.name) + _ = try collection.createIndex(self.keys, indexOptions: indexOptions, session: sessions[self.session ?? ""]) + return nil + } +} + +struct RunCommand: TestOperation { + let session: String? + let command: Document + let readPreference: ReadPreference? + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case let .database(database) = target else { + throw TestError(message: "database not provided to runCommand") + } + let runCommandOptions = RunCommandOptions(readPreference: self.readPreference) + let result = try database.runCommand( + self.command, + options: runCommandOptions, + session: sessions[self.session ?? ""] + ) + return TestOperationResult(from: result) + } +} + +struct AssertCollectionExists: TestOperation { + let database: String + let collection: String + + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertCollectionExists") + } + let client = try MongoClient.makeTestClient() + let collectionNames = try client.db(self.database).listCollectionNames() + expect(collectionNames).to(contain(self.collection)) + return nil + } +} + +struct AssertCollectionNotExists: TestOperation { + let database: String + let collection: String + + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertCollectionNotExists") + } + let client = try MongoClient.makeTestClient() + let collectionNames = try client.db(self.database).listCollectionNames() + expect(collectionNames).toNot(contain(self.collection)) + return nil + } +} + +struct AssertIndexExists: TestOperation { + let database: String + let collection: String + let index: String + + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertIndexExists") + } + let client = try MongoClient.makeTestClient() + let indexNames = try client.db(self.database).collection(self.collection).listIndexNames() + expect(indexNames).to(contain(self.index)) + return nil + } +} + +struct AssertIndexNotExists: TestOperation { + let database: String + let collection: String + let index: String + + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) + throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertIndexNotExists") + } + let client = try MongoClient.makeTestClient() + let indexNames = try client.db(self.database).collection(self.collection).listIndexNames() + expect(indexNames).toNot(contain(self.index)) + return nil + } +} + +struct AssertSessionPinned: TestOperation { + let session: String? + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertSessionPinned") + } + guard let serverId = sessions[self.session ?? ""]?.serverId else { + throw TestError(message: "active session not provided to assertSessionPinned") + } + expect(serverId).toNot(equal(0)) + return nil + } +} + +struct AssertSessionUnpinned: TestOperation { + let session: String? + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertSessionUnpinned") + } + guard let serverId = sessions[self.session ?? ""]?.serverId else { + throw TestError(message: "active session not provided to assertSessionPinned") + } + expect(serverId).to(equal(0)) + return nil + } +} + +struct AssertSessionTransactionState: TestOperation { + let session: String? + let state: ClientSession.TransactionState + + func execute(on target: TestOperationTarget, sessions: [String: ClientSession]) throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to assertSessionTransactionState") + } + guard let transactionState = sessions[self.session ?? ""]?.transactionState else { + throw TestError(message: "active session not provided to assertSessionTransactionState") + } + expect(transactionState).to(equal(self.state)) + return nil + } +} + +struct TargetedFailPoint: TestOperation { + let session: String? + let failPoint: Document + + func execute(on target: TestOperationTarget, sessions _: [String: ClientSession]) throws -> TestOperationResult? { + guard case .testRunner = target else { + throw TestError(message: "test runner not provided to targetedFailPoint") + } + let client = try MongoClient.makeTestClient() + try client.db("admin").runCommand(self.failPoint) + return nil } } @@ -714,7 +1103,7 @@ struct EstimatedDocumentCount: TestOperation { struct NotImplemented: TestOperation { internal let name: String - func execute(on _: TestOperationTarget, session _: ClientSession?) throws -> TestOperationResult? { + func execute(on _: TestOperationTarget, sessions _: [String: ClientSession]) throws -> TestOperationResult? { throw TestError(message: "\(self.name) not implemented in the driver, skip this test") } } diff --git a/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperationResult.swift b/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperationResult.swift index 97e0233df..d578428c6 100644 --- a/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperationResult.swift +++ b/Tests/MongoSwiftSyncTests/SpecTestRunner/TestOperationResult.swift @@ -1,7 +1,10 @@ import MongoSwiftSync +import Nimble +import TestsCommon +import XCTest -/// Enum encapsulating the possible results returned from CRUD operations. -enum TestOperationResult: Decodable, Equatable { +/// Enum encapsulating the possible results returned from test operations. +enum TestOperationResult: Decodable, Equatable, Matchable { /// Crud operation returns an int (e.g. `count`). case int(Int) @@ -14,6 +17,9 @@ enum TestOperationResult: Decodable, Equatable { /// Result of CRUD operations whose result can be represented by a `BulkWriteResult` (e.g. `InsertOne`). case bulkWrite(BulkWriteResult) + /// Result of test operations that are expected to return an error (e.g. `CommandError`, `WriteError`). + case error(ErrorResult) + public init?(from doc: Document?) { guard let doc = doc else { return nil @@ -48,6 +54,8 @@ enum TestOperationResult: Decodable, Equatable { self = .int(int) } else if let array = try? [BSON](from: decoder) { self = .array(array) + } else if let error = try? ErrorResult(from: decoder) { + self = .error(error) } else if let doc = try? Document(from: decoder) { self = .document(doc) } else { @@ -71,12 +79,33 @@ enum TestOperationResult: Decodable, Equatable { return lhsArray == rhsArray case let (.document(lhsDoc), .document(rhsDoc)): return lhsDoc.sortedEquals(rhsDoc) + case let (.error(lhsErr), .error(rhsErr)): + return lhsErr == rhsErr + default: + return false + } + } + + internal func contentMatches(expected: TestOperationResult) -> Bool { + switch (self, expected) { + case let (.bulkWrite(bw), .bulkWrite(expectedBw)): + return bw.matches(expected: expectedBw) + case let (.int(int), .int(expectedInt)): + return int.matches(expected: expectedInt) + case let (.array(array), .array(expectedArray)): + return array.matches(expected: expectedArray) + case let (.document(doc), .document(expectedDoc)): + return doc.matches(expected: expectedDoc) + case (.error, .error): + return false default: return false } } } +extension BulkWriteResult: Matchable {} + /// Protocol for allowing conversion from different result types to `BulkWriteResult`. /// This behavior is used to funnel the various CRUD results into the `.bulkWrite` `TestOperationResult` case. protocol BulkWriteResultConvertible { @@ -120,3 +149,135 @@ extension DeleteResult: BulkWriteResultConvertible { BulkWriteResult.new(deletedCount: self.deletedCount) } } + +struct ErrorResult: Equatable, Decodable { + internal var errorContains: String? + + internal var errorCodeName: String? + + internal var errorLabelsContain: [String]? + + internal var errorLabelsOmit: [String]? + + private enum CodingKeys: CodingKey { + case errorContains, errorCodeName, errorLabelsContain, errorLabelsOmit + } + + public init(from decoder: Decoder) throws { + let container = try decoder.container(keyedBy: CodingKeys.self) + + // None of the error keys must be present themselves, but at least one must. + guard !container.allKeys.isEmpty else { + throw DecodingError.valueNotFound( + ErrorResult.self, + DecodingError.Context( + codingPath: decoder.codingPath, + debugDescription: "No results found" + ) + ) + } + + self.errorContains = try container.decodeIfPresent(String.self, forKey: .errorContains) + self.errorCodeName = try container.decodeIfPresent(String.self, forKey: .errorCodeName) + self.errorLabelsContain = try container.decodeIfPresent([String].self, forKey: .errorLabelsContain) + self.errorLabelsOmit = try container.decodeIfPresent([String].self, forKey: .errorLabelsOmit) + } + + public func checkErrorResult(_ error: Error) throws { + try self.checkErrorContains(error) + try self.checkCodeName(error) + try self.checkErrorLabels(error) + } + + // swiftlint:disable cyclomatic_complexity + + internal func checkErrorContains(_ error: Error) throws { + if let errorContains = self.errorContains?.lowercased() { + if let commandError = error as? CommandError { + expect(commandError.message.lowercased()).to(contain(errorContains)) + } else if let writeError = error as? WriteError { + if let writeFailure = writeError.writeFailure { + expect(writeFailure.message.lowercased()).to(contain(errorContains)) + } + if let writeConcernFailure = writeError.writeConcernFailure { + expect(writeConcernFailure.message.lowercased()).to(contain(errorContains)) + } + } else if let bulkWriteError = error as? BulkWriteError { + if let writeFailures = bulkWriteError.writeFailures { + for writeFailure in writeFailures { + expect(writeFailure.message.lowercased()).to(contain(errorContains)) + } + } + if let writeConcernFailure = bulkWriteError.writeConcernFailure { + expect(writeConcernFailure.message.lowercased()).to(contain(errorContains)) + } + } else if let logicError = error as? LogicError { + expect(logicError.errorDescription.lowercased()).to(contain(errorContains)) + } else if let invalidArgumentError = error as? InvalidArgumentError { + expect(invalidArgumentError.errorDescription.lowercased()).to(contain(errorContains)) + } else if let connectionError = error as? ConnectionError { + expect(connectionError.message.lowercased()).to(contain(errorContains)) + } else { + XCTFail("\(error) does not contain message") + } + } + } + + // swiftlint:enable cyclomatic_complexity + + internal func checkCodeName(_ error: Error) throws { + // TODO: can remove `equal("")` references once SERVER-36755 is resolved + if let errorCodeName = self.errorCodeName { + if let commandError = error as? CommandError { + expect(commandError.codeName).to(satisfyAnyOf(equal(errorCodeName), equal(""))) + } else if let writeError = error as? WriteError { + if let writeFailure = writeError.writeFailure { + expect(writeFailure.codeName).to(satisfyAnyOf(equal(errorCodeName), equal(""))) + } + if let writeConcernFailure = writeError.writeConcernFailure { + expect(writeConcernFailure.codeName).to(satisfyAnyOf(equal(errorCodeName), equal(""))) + } + } else if let bulkWriteError = error as? BulkWriteError { + if let writeFailures = bulkWriteError.writeFailures { + for writeFailure in writeFailures { + expect(writeFailure.codeName).to(satisfyAnyOf(equal(errorCodeName), equal(""))) + } + } + if let writeConcernFailure = bulkWriteError.writeConcernFailure { + expect(writeConcernFailure.codeName).to(satisfyAnyOf(equal(errorCodeName), equal(""))) + } + } else { + XCTFail("\(error) does not contain codeName") + } + } + } + + internal func checkErrorLabels(_ error: Error) throws { + // `configureFailPoint` command correctly handles error labels in MongoDB v4.3.1+ (see SERVER-43941). + // Do not check the "RetryableWriteError" error label until the spec test requirements are updated. + let skippedErrorLabels = ["RetryableWriteError"] + + if let errorLabelsContain = self.errorLabelsContain { + guard let labeledError = error as? LabeledError else { + XCTFail("\(error) does not contain errorLabels") + return + } + for label in errorLabelsContain where !skippedErrorLabels.contains(label) { + expect(labeledError.errorLabels).to(contain(label)) + } + } + + if let errorLabelsOmit = self.errorLabelsOmit { + guard let labeledError = error as? LabeledError else { + XCTFail("\(error) does not contain errorLabels") + return + } + guard let errorLabels = labeledError.errorLabels else { + return + } + for label in errorLabelsOmit { + expect(errorLabels).toNot(contain(label)) + } + } + } +} diff --git a/Tests/MongoSwiftSyncTests/SyncChangeStreamTests.swift b/Tests/MongoSwiftSyncTests/SyncChangeStreamTests.swift index 1916d161e..cf70bc2a5 100644 --- a/Tests/MongoSwiftSyncTests/SyncChangeStreamTests.swift +++ b/Tests/MongoSwiftSyncTests/SyncChangeStreamTests.swift @@ -69,7 +69,7 @@ internal struct ChangeStreamTestOperation: Decodable { internal func execute(using client: MongoClient) throws -> TestOperationResult? { let db = client.db(self.database) let coll = db.collection(self.collection) - return try self.operation.execute(on: .collection(coll), session: nil) + return try self.operation.execute(on: .collection(coll), sessions: [:]) } } diff --git a/Tests/MongoSwiftSyncTests/SyncTestUtils.swift b/Tests/MongoSwiftSyncTests/SyncTestUtils.swift index f0e57b893..0cf8ea668 100644 --- a/Tests/MongoSwiftSyncTests/SyncTestUtils.swift +++ b/Tests/MongoSwiftSyncTests/SyncTestUtils.swift @@ -1,4 +1,5 @@ import Foundation +@testable import class MongoSwift.ClientSession @testable import MongoSwiftSync import TestsCommon @@ -190,3 +191,15 @@ extension ChangeStream { return nil } } + +extension MongoSwiftSync.ClientSession { + internal var active: Bool { self.asyncSession.active } + + internal var id: Document? { self.asyncSession.id } + + internal var serverId: UInt32? { self.asyncSession.serverId } + + internal typealias TransactionState = MongoSwift.ClientSession.TransactionState + + internal var transactionState: TransactionState? { self.asyncSession.transactionState } +} diff --git a/Tests/MongoSwiftSyncTests/TransactionsTests.swift b/Tests/MongoSwiftSyncTests/TransactionsTests.swift new file mode 100644 index 000000000..b7946f093 --- /dev/null +++ b/Tests/MongoSwiftSyncTests/TransactionsTests.swift @@ -0,0 +1,78 @@ +import Foundation +import MongoSwift +import Nimble +import TestsCommon + +/// Struct representing a single test within a spec test JSON file. +private struct TransactionsTest: SpecTest { + let description: String + + let operations: [TestOperationDescription] + + let outcome: TestOutcome? + + let skipReason: String? + + let useMultipleMongoses: Bool? + + let clientOptions: ClientOptions? + + let failPoint: FailPoint? + + let sessionOptions: [String: ClientSessionOptions]? + + let expectations: [TestCommandStartedEvent]? + + static let sessionNames: [String] = ["session0", "session1"] +} + +/// Struct representing a single transactions spec test JSON file. +private struct TransactionsTestFile: Decodable, SpecTestFile { + private enum CodingKeys: String, CodingKey { + case name, runOn, databaseName = "database_name", collectionName = "collection_name", data, tests + } + + let name: String + + let runOn: [TestRequirement]? + + let databaseName: String + + let collectionName: String? + + let data: TestData + + let tests: [TransactionsTest] +} + +final class TransactionsTests: MongoSwiftTestCase, FailPointConfigured { + var activeFailPoint: FailPoint? + + override func tearDown() { + self.disableActiveFailPoint() + } + + override func setUp() { + self.continueAfterFailure = false + } + + func testTransactions() throws { + let skippedTestKeywords = [ + "count", // old count API was deprecated before MongoDB 4.0 and is not supported by the driver + "mongos-pin-auto", // TODO: see SWIFT-774 + "mongos-recovery-token", // TODO: see SWIFT-774 + "pin-mongos", // TODO: see SWIFT-774 + "retryable-abort-errorLabels", // requires libmongoc v1.17 (see SWIFT-762) + "retryable-commit-errorLabels" // requires libmongoc v1.17 (see SWIFT-762) + ] + + let tests = try retrieveSpecTestFiles(specName: "transactions", asType: TransactionsTestFile.self) + for (_, testFile) in tests { + guard skippedTestKeywords.allSatisfy({ !testFile.name.contains($0) }) else { + fileLevelLog("Skipping tests from file \(testFile.name)...") + continue + } + try testFile.runTests(parent: self) + } + } +} diff --git a/etc/add_json_files.rb b/etc/add_json_files.rb index a4d4c0ea8..e0745a7fc 100644 --- a/etc/add_json_files.rb +++ b/etc/add_json_files.rb @@ -20,6 +20,7 @@ def make_reference(project, path) change_streams = make_reference(project, "./Tests/Specs/change-streams") dns_seedlist = make_reference(project, "./Tests/Specs/initial-dns-seedlist-discovery") auth = make_reference(project, "./Tests/Specs/auth") -mongoswift_tests_target.add_resources([crud, corpus, cm, read_write_concern, retryable_writes, retryable_reads, change_streams, dns_seedlist, auth]) +transactions = make_reference(project, "./Tests/Specs/transactions") +mongoswift_tests_target.add_resources([crud, corpus, cm, read_write_concern, retryable_writes, retryable_reads, change_streams, dns_seedlist, auth, transactions]) project.save