diff --git a/src/client-side-encryption/auto_encrypter.ts b/src/client-side-encryption/auto_encrypter.ts index d6eb0826c8..981a698837 100644 --- a/src/client-side-encryption/auto_encrypter.ts +++ b/src/client-side-encryption/auto_encrypter.ts @@ -7,9 +7,9 @@ import { import { deserialize, type Document, serialize } from '../bson'; import { type CommandOptions, type ProxyOptions } from '../cmap/connection'; import { getMongoDBClientEncryption } from '../deps'; -import { type AnyError, MongoRuntimeError } from '../error'; +import { MongoRuntimeError } from '../error'; import { MongoClient, type MongoClientOptions } from '../mongo_client'; -import { type Callback, MongoDBCollectionNamespace } from '../utils'; +import { MongoDBCollectionNamespace } from '../utils'; import * as cryptoCallbacks from './crypto_callbacks'; import { MongoCryptInvalidArgumentError } from './errors'; import { MongocryptdManager } from './mongocryptd_manager'; @@ -396,133 +396,66 @@ export class AutoEncrypter { * * This function is a no-op when bypassSpawn is set or the crypt shared library is used. */ - init(callback: Callback) { + async init(): Promise { if (this._bypassMongocryptdAndCryptShared || this.cryptSharedLibVersionInfo) { - return callback(); + return; } if (!this._mongocryptdManager) { - return callback( - new MongoRuntimeError( - 'Reached impossible state: mongocryptdManager is undefined when neither bypassSpawn nor the shared lib are specified.' - ) + throw new MongoRuntimeError( + 'Reached impossible state: mongocryptdManager is undefined when neither bypassSpawn nor the shared lib are specified.' ); } if (!this._mongocryptdClient) { - return callback( - new MongoRuntimeError( - 'Reached impossible state: mongocryptdClient is undefined when neither bypassSpawn nor the shared lib are specified.' - ) + throw new MongoRuntimeError( + 'Reached impossible state: mongocryptdClient is undefined when neither bypassSpawn nor the shared lib are specified.' ); } - const _callback = (err?: AnyError, res?: MongoClient) => { - if ( - err && - err.message && - (err.message.match(/timed out after/) || err.message.match(/ENOTFOUND/)) - ) { - callback( - new MongoRuntimeError( - 'Unable to connect to `mongocryptd`, please make sure it is running or in your PATH for auto-spawn', - { cause: err } - ) - ); - return; - } - - callback(err, res); - }; - if (this._mongocryptdManager.bypassSpawn) { - this._mongocryptdClient.connect().then( - result => { - return _callback(undefined, result); - }, - error => { - _callback(error, undefined); - } - ); - return; + if (!this._mongocryptdManager.bypassSpawn) { + await this._mongocryptdManager.spawn(); } - this._mongocryptdManager.spawn(() => { - if (!this._mongocryptdClient) { - return callback( - new MongoRuntimeError( - 'Reached impossible state: mongocryptdClient is undefined after spawning libmongocrypt.' - ) + try { + const client = await this._mongocryptdClient.connect(); + return client; + } catch (error) { + const { message } = error; + if (message && (message.match(/timed out after/) || message.match(/ENOTFOUND/))) { + throw new MongoRuntimeError( + 'Unable to connect to `mongocryptd`, please make sure it is running or in your PATH for auto-spawn', + { cause: error } ); } - this._mongocryptdClient.connect().then( - result => { - return _callback(undefined, result); - }, - error => { - _callback(error, undefined); - } - ); - }); + throw error; + } } /** * Cleans up the `_mongocryptdClient`, if present. */ - teardown(force: boolean, callback: Callback) { - if (this._mongocryptdClient) { - this._mongocryptdClient.close(force).then( - result => { - return callback(undefined, result); - }, - error => { - callback(error); - } - ); - } else { - callback(); - } + async teardown(force: boolean): Promise { + await this._mongocryptdClient?.close(force); } - encrypt(ns: string, cmd: Document, callback: Callback): void; - encrypt( - ns: string, - cmd: Document, - options: CommandOptions, - callback: Callback - ): void; /** * Encrypt a command for a given namespace. */ - encrypt( + async encrypt( ns: string, cmd: Document, - options?: CommandOptions | Callback, - callback?: Callback - ) { - callback = typeof options === 'function' ? options : callback; - - if (callback == null) { - throw new MongoCryptInvalidArgumentError('Callback must be provided'); - } - - options = typeof options === 'function' ? {} : options; - - // If `bypassAutoEncryption` has been specified, don't encrypt + options: CommandOptions = {} + ): Promise { if (this._bypassEncryption) { - callback(undefined, cmd); - return; + // If `bypassAutoEncryption` has been specified, don't encrypt + return cmd; } const commandBuffer = Buffer.isBuffer(cmd) ? cmd : serialize(cmd, options); - let context; - try { - context = this._mongocrypt.makeEncryptionContext( - MongoDBCollectionNamespace.fromString(ns).db, - commandBuffer - ); - } catch (err) { - callback(err, undefined); - return; - } + const context = this._mongocrypt.makeEncryptionContext( + MongoDBCollectionNamespace.fromString(ns).db, + commandBuffer + ); context.id = this._contextCounter++; context.ns = ns; @@ -534,34 +467,16 @@ export class AutoEncrypter { proxyOptions: this._proxyOptions, tlsOptions: this._tlsOptions }); - stateMachine.execute(this, context, callback); + return stateMachine.execute(this, context); } /** * Decrypt a command response */ - decrypt( - response: Uint8Array, - options: CommandOptions | Callback, - callback?: Callback - ) { - callback = typeof options === 'function' ? options : callback; - - if (callback == null) { - throw new MongoCryptInvalidArgumentError('Callback must be provided'); - } - - options = typeof options === 'function' ? {} : options; - + async decrypt(response: Uint8Array | Document, options: CommandOptions = {}): Promise { const buffer = Buffer.isBuffer(response) ? response : serialize(response, options); - let context; - try { - context = this._mongocrypt.makeDecryptionContext(buffer); - } catch (err) { - callback(err, undefined); - return; - } + const context = this._mongocrypt.makeDecryptionContext(buffer); context.id = this._contextCounter++; @@ -572,16 +487,11 @@ export class AutoEncrypter { }); const decorateResult = this[kDecorateResult]; - stateMachine.execute(this, context, function (error?: Error, result?: Document) { - // Only for testing/internal usage - if (!error && result && decorateResult) { - const error = decorateDecryptionResult(result, response); - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - if (error) return callback!(error); - } - // eslint-disable-next-line @typescript-eslint/no-non-null-assertion - callback!(error, result); - }); + const result = await stateMachine.execute(this, context); + if (decorateResult) { + decorateDecryptionResult(result, response); + } + return result; } /** @@ -621,14 +531,14 @@ function decorateDecryptionResult( decrypted: Document & { [kDecoratedKeys]?: Array }, original: Document, isTopLevelDecorateCall = true -): Error | void { +): void { if (isTopLevelDecorateCall) { // The original value could have been either a JS object or a BSON buffer if (Buffer.isBuffer(original)) { original = deserialize(original); } if (Buffer.isBuffer(decrypted)) { - return new MongoRuntimeError('Expected result of decryption to be deserialized BSON object'); + throw new MongoRuntimeError('Expected result of decryption to be deserialized BSON object'); } } @@ -647,10 +557,10 @@ function decorateDecryptionResult( writable: false }); } - // this is defined in the preceeding if-statement + // this is defined in the preceding if-statement // eslint-disable-next-line @typescript-eslint/no-non-null-assertion decrypted[kDecoratedKeys]!.push(k); - // Do not recurse into this decrypted value. It could be a subdocument/array, + // Do not recurse into this decrypted value. It could be a sub-document/array, // in which case there is no original value associated with its subfields. continue; } diff --git a/src/client-side-encryption/client_encryption.ts b/src/client-side-encryption/client_encryption.ts index 66e15aaa49..dbd946ed36 100644 --- a/src/client-side-encryption/client_encryption.ts +++ b/src/client-side-encryption/client_encryption.ts @@ -13,7 +13,7 @@ import { type FindCursor } from '../cursor/find_cursor'; import { type Db } from '../db'; import { getMongoDBClientEncryption } from '../deps'; import { type MongoClient } from '../mongo_client'; -import { type Filter } from '../mongo_types'; +import { type Filter, type WithId } from '../mongo_types'; import { type CreateCollectionOptions } from '../operations/create_collection'; import { type DeleteResult } from '../operations/delete'; import { MongoDBCollectionNamespace } from '../utils'; @@ -202,7 +202,7 @@ export class ClientEncryption { tlsOptions: this._tlsOptions }); - const dataKey = await stateMachine.executeAsync(this, context); + const dataKey = await stateMachine.execute(this, context); const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString( this._keyVaultNamespace @@ -246,7 +246,7 @@ export class ClientEncryption { async rewrapManyDataKey( filter: Filter, options: ClientEncryptionRewrapManyDataKeyProviderOptions - ) { + ): Promise<{ bulkWriteResult?: BulkWriteResult }> { let keyEncryptionKeyBson = undefined; if (options) { const keyEncryptionKey = Object.assign({ provider: options.provider }, options.masterKey); @@ -259,8 +259,8 @@ export class ClientEncryption { tlsOptions: this._tlsOptions }); - const dataKey = await stateMachine.executeAsync<{ v: DataKey[] }>(this, context); - if (!dataKey || dataKey.v.length === 0) { + const { v: dataKeys } = await stateMachine.execute<{ v: DataKey[] }>(this, context); + if (dataKeys.length === 0) { return {}; } @@ -268,7 +268,7 @@ export class ClientEncryption { this._keyVaultNamespace ); - const replacements = dataKey.v.map( + const replacements = dataKeys.map( (key: DataKey): AnyBulkWriteOperation => ({ updateOne: { filter: { _id: key._id }, @@ -386,7 +386,7 @@ export class ClientEncryption { * } * ``` */ - async getKeyByAltName(keyAltName: string) { + async getKeyByAltName(keyAltName: string): Promise | null> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString( this._keyVaultNamespace ); @@ -417,7 +417,7 @@ export class ClientEncryption { * } * ``` */ - async addKeyAltName(_id: Binary, keyAltName: string) { + async addKeyAltName(_id: Binary, keyAltName: string): Promise | null> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString( this._keyVaultNamespace ); @@ -457,7 +457,7 @@ export class ClientEncryption { * } * ``` */ - async removeKeyAltName(_id: Binary, keyAltName: string) { + async removeKeyAltName(_id: Binary, keyAltName: string): Promise | null> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString( this._keyVaultNamespace ); @@ -640,7 +640,7 @@ export class ClientEncryption { tlsOptions: this._tlsOptions }); - const { v } = await stateMachine.executeAsync<{ v: T }>(this, context); + const { v } = await stateMachine.execute<{ v: T }>(this, context); return v; } @@ -719,7 +719,7 @@ export class ClientEncryption { }); const context = this._mongoCrypt.makeExplicitEncryptionContext(valueBuffer, contextOptions); - const result = await stateMachine.executeAsync<{ v: Binary }>(this, context); + const result = await stateMachine.execute<{ v: Binary }>(this, context); return result.v; } } diff --git a/src/client-side-encryption/common.js b/src/client-side-encryption/common.js deleted file mode 100644 index 1d7dd2fe32..0000000000 --- a/src/client-side-encryption/common.js +++ /dev/null @@ -1,59 +0,0 @@ -/* eslint-disable */ - -export function maybeCallback(promiseFn, callback) { - const promise = promiseFn(); - if (callback == null) { - return promise; - } - - promise.then( - result => process.nextTick(callback, undefined, result), - error => process.nextTick(callback, error) - ); - return; -} - -/** - * @ignore - * A helper function. Invokes a function that takes a callback as the final - * parameter. If a callback is supplied, then it is passed to the function. - * If not, a Promise is returned that resolves/rejects with the result of the - * callback - * @param {Function} [callback] an optional callback. - * @param {Function} fn A function that takes a callback - * @returns {Promise|void} Returns nothing if a callback is supplied, else returns a Promise. - */ -export function promiseOrCallback(callback, fn) { - if (typeof callback === 'function') { - fn(function (err) { - if (err != null) { - try { - callback(err); - } catch (error) { - return process.nextTick(() => { - throw error; - }); - } - return; - } - - callback.apply(this, arguments); - }); - - return; - } - - return new Promise((resolve, reject) => { - fn(function (err, res) { - if (err != null) { - return reject(err); - } - - if (arguments.length > 2) { - return resolve(Array.prototype.slice.call(arguments, 1)); - } - - resolve(res); - }); - }); -} diff --git a/src/client-side-encryption/mongocryptd_manager.ts b/src/client-side-encryption/mongocryptd_manager.ts index cf69f0fd7b..499f2aab29 100644 --- a/src/client-side-encryption/mongocryptd_manager.ts +++ b/src/client-side-encryption/mongocryptd_manager.ts @@ -1,6 +1,6 @@ import type { ChildProcess } from 'child_process'; -import { type Callback } from '../utils'; +import { MongoNetworkTimeoutError } from '../error'; import { type AutoEncryptionExtraOptions } from './auto_encrypter'; /** @@ -42,7 +42,7 @@ export class MongocryptdManager { * Will check to see if a mongocryptd is up. If it is not up, it will attempt * to spawn a mongocryptd in a detached process, and then wait for it to be up. */ - spawn(callback: Callback) { + async spawn(): Promise { const cmdName = this.spawnPath || 'mongocryptd'; // eslint-disable-next-line @typescript-eslint/no-var-requires @@ -73,7 +73,24 @@ export class MongocryptdManager { // unref child to remove handle from event loop this._child.unref(); + } - process.nextTick(callback); + /** + * @returns the result of `fn` or rejects with an error. + */ + async withRespawn(fn: () => Promise): ReturnType { + try { + const result = await fn(); + return result; + } catch (err) { + // If we are not bypassing spawning, then we should retry once on a MongoTimeoutError (server selection error) + const shouldSpawn = err instanceof MongoNetworkTimeoutError && !this.bypassSpawn; + if (!shouldSpawn) { + throw err; + } + } + await this.spawn(); + const result = await fn(); + return result; } } diff --git a/src/client-side-encryption/state_machine.ts b/src/client-side-encryption/state_machine.ts index bf3963ffb1..7d5dc23bf8 100644 --- a/src/client-side-encryption/state_machine.ts +++ b/src/client-side-encryption/state_machine.ts @@ -1,8 +1,7 @@ -import * as fs from 'fs'; +import * as fs from 'fs/promises'; import { type MongoCryptContext, type MongoCryptKMSRequest } from 'mongodb-client-encryption'; import * as net from 'net'; import * as tls from 'tls'; -import { promisify } from 'util'; import { type BSONSerializeOptions, @@ -13,9 +12,8 @@ import { } from '../bson'; import { type ProxyOptions } from '../cmap/connection'; import { getSocks, type SocksLib } from '../deps'; -import { MongoNetworkTimeoutError } from '../error'; import { type MongoClient, type MongoClientOptions } from '../mongo_client'; -import { BufferPool, type Callback, MongoDBCollectionNamespace } from '../utils'; +import { BufferPool, MongoDBCollectionNamespace } from '../utils'; import { type DataKey } from './client_encryption'; import { MongoCryptError } from './errors'; import { type MongocryptdManager } from './mongocryptd_manager'; @@ -153,176 +151,130 @@ export class StateMachine { private bsonOptions = pluckBSONSerializeOptions(options) ) {} - executeAsync(executor: StateMachineExecutable, context: MongoCryptContext): Promise { - // @ts-expect-error The callback version allows undefined for the result, but we'll never actually have an undefined result without an error. - return promisify(this.execute.bind(this))(executor, context); - } - /** * Executes the state machine according to the specification */ - execute( + async execute( executor: StateMachineExecutable, - context: MongoCryptContext, - callback: Callback - ) { + context: MongoCryptContext + ): Promise { const keyVaultNamespace = executor._keyVaultNamespace; const keyVaultClient = executor._keyVaultClient; const metaDataClient = executor._metaDataClient; const mongocryptdClient = executor._mongocryptdClient; const mongocryptdManager = executor._mongocryptdManager; + let result: T | null = null; + + while (context.state !== MONGOCRYPT_CTX_DONE && context.state !== MONGOCRYPT_CTX_ERROR) { + debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`); - debug(`[context#${context.id}] ${stateToString.get(context.state) || context.state}`); - switch (context.state) { - case MONGOCRYPT_CTX_NEED_MONGO_COLLINFO: { - const filter = deserialize(context.nextMongoOperation()); - if (!metaDataClient) { - return callback( - new MongoCryptError( + switch (context.state) { + case MONGOCRYPT_CTX_NEED_MONGO_COLLINFO: { + const filter = deserialize(context.nextMongoOperation()); + if (!metaDataClient) { + throw new MongoCryptError( 'unreachable state machine state: entered MONGOCRYPT_CTX_NEED_MONGO_COLLINFO but metadata client is undefined' - ) - ); - } - this.fetchCollectionInfo(metaDataClient, context.ns, filter, (err, collInfo) => { - if (err) { - return callback(err); + ); } + const collInfo = await this.fetchCollectionInfo(metaDataClient, context.ns, filter); if (collInfo) { context.addMongoOperationResponse(collInfo); } context.finishMongoOperation(); - this.execute(executor, context, callback); - }); - - return; - } + break; + } - case MONGOCRYPT_CTX_NEED_MONGO_MARKINGS: { - const command = context.nextMongoOperation(); - if (!mongocryptdClient) { - return callback( - new MongoCryptError( + case MONGOCRYPT_CTX_NEED_MONGO_MARKINGS: { + const command = context.nextMongoOperation(); + if (!mongocryptdClient) { + throw new MongoCryptError( 'unreachable state machine state: entered MONGOCRYPT_CTX_NEED_MONGO_MARKINGS but mongocryptdClient is undefined' - ) - ); - } - this.markCommand(mongocryptdClient, context.ns, command, (err, markedCommand) => { - if (err || !markedCommand) { - // If we are not bypassing spawning, then we should retry once on a MongoTimeoutError (server selection error) - if ( - err instanceof MongoNetworkTimeoutError && - mongocryptdManager && - !mongocryptdManager.bypassSpawn - ) { - mongocryptdManager.spawn(() => { - // TODO: should we be shadowing the variables here? - this.markCommand(mongocryptdClient, context.ns, command, (err, markedCommand) => { - if (err || !markedCommand) return callback(err); - - context.addMongoOperationResponse(markedCommand); - context.finishMongoOperation(); - - this.execute(executor, context, callback); - }); - }); - return; - } - return callback(err); + ); } - context.addMongoOperationResponse(markedCommand); - context.finishMongoOperation(); - this.execute(executor, context, callback); - }); + // When we are using the shared library, we don't have a mongocryptd manager. + const markedCommand: Uint8Array = mongocryptdManager + ? await mongocryptdManager.withRespawn( + this.markCommand.bind(this, mongocryptdClient, context.ns, command) + ) + : await this.markCommand(mongocryptdClient, context.ns, command); - return; - } + context.addMongoOperationResponse(markedCommand); + context.finishMongoOperation(); + break; + } - case MONGOCRYPT_CTX_NEED_MONGO_KEYS: { - const filter = context.nextMongoOperation(); - this.fetchKeys(keyVaultClient, keyVaultNamespace, filter, (err, keys) => { - if (err || !keys) return callback(err); - keys.forEach(key => { + case MONGOCRYPT_CTX_NEED_MONGO_KEYS: { + const filter = context.nextMongoOperation(); + const keys = await this.fetchKeys(keyVaultClient, keyVaultNamespace, filter); + + if (keys.length === 0) { + // This is kind of a hack. For `rewrapManyDataKey`, we have tests that + // guarantee that when there are no matching keys, `rewrapManyDataKey` returns + // nothing. We also have tests for auto encryption that guarantee for `encrypt` + // we return an error when there are no matching keys. This error is generated in + // subsequent iterations of the state machine. + // Some apis (`encrypt`) throw if there are no filter matches and others (`rewrapManyDataKey`) + // do not. We set the result manually here, and let the state machine continue. `libmongocrypt` + // will inform us if we need to error by setting the state to `MONGOCRYPT_CTX_ERROR` but + // otherwise we'll return `{ v: [] }`. + result = { v: [] } as any as T; + } + for await (const key of keys) { context.addMongoOperationResponse(serialize(key)); - }); + } context.finishMongoOperation(); - this.execute(executor, context, callback); - }); - return; - } + break; + } - case MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS: { - executor - .askForKMSCredentials() - .then(kmsProviders => { - context.provideKMSProviders(serialize(kmsProviders)); - this.execute(executor, context, callback); - }) - .catch(err => { - callback(err); - }); - - return; - } + case MONGOCRYPT_CTX_NEED_KMS_CREDENTIALS: { + const kmsProviders = await executor.askForKMSCredentials(); + context.provideKMSProviders(serialize(kmsProviders)); + break; + } - case MONGOCRYPT_CTX_NEED_KMS: { - const promises = []; + case MONGOCRYPT_CTX_NEED_KMS: { + const requests = Array.from(this.requests(context)); + await Promise.all(requests); - let request; - while ((request = context.nextKMSRequest())) { - promises.push(this.kmsRequest(request)); + context.finishKMSRequests(); + break; } - Promise.all(promises) - .then(() => { - context.finishKMSRequests(); - this.execute(executor, context, callback); - }) - .catch(err => { - callback(err); - }); + case MONGOCRYPT_CTX_READY: { + const finalizedContext = context.finalize(); + // @ts-expect-error finalize can change the state, check for error + if (context.state === MONGOCRYPT_CTX_ERROR) { + const message = context.status.message || 'Finalization error'; + throw new MongoCryptError(message); + } + result = deserialize(finalizedContext, this.options) as T; + break; + } - return; + default: + throw new MongoCryptError(`Unknown state: ${context.state}`); } + } - // terminal states - case MONGOCRYPT_CTX_READY: { - const finalizedContext = context.finalize(); - // TODO: Maybe rework the logic here so that instead of doing - // the callback here, finalize stores the result, and then - // we wait to MONGOCRYPT_CTX_DONE to do the callback - // @ts-expect-error finalize can change the state, check for error - if (context.state === MONGOCRYPT_CTX_ERROR) { - const message = context.status.message || 'Finalization error'; - callback(new MongoCryptError(message)); - return; - } - callback(undefined, deserialize(finalizedContext, this.options) as T); - return; - } - case MONGOCRYPT_CTX_ERROR: { - const message = context.status.message; - callback( - new MongoCryptError( - message ?? - 'unidentifiable error in MongoCrypt - received an error status from `libmongocrypt` but received no error message.' - ) + if (context.state === MONGOCRYPT_CTX_ERROR || result == null) { + const message = context.status.message; + if (!message) { + debug( + `unidentifiable error in MongoCrypt - received an error status from \`libmongocrypt\` but received no error message.` ); - return; } - - case MONGOCRYPT_CTX_DONE: - callback(); - return; - - default: - callback(new MongoCryptError(`Unknown state: ${context.state}`)); - return; + throw new MongoCryptError( + message ?? + 'unidentifiable error in MongoCrypt - received an error status from `libmongocrypt` but received no error message.' + ); } + + return result; } /** @@ -341,11 +293,11 @@ export class StateMachine { const message = request.message; // TODO(NODE-3959): We can adopt `for-await on(socket, 'data')` with logic to control abort - // eslint-disable-next-line no-async-promise-executor, @typescript-eslint/no-misused-promises + // eslint-disable-next-line @typescript-eslint/no-misused-promises, no-async-promise-executor return new Promise(async (resolve, reject) => { const buffer = new BufferPool(); - /* eslint-disable prefer-const */ + // eslint-disable-next-line prefer-const let socket: net.Socket; let rawSocket: net.Socket; @@ -409,7 +361,11 @@ export class StateMachine { if (providerTlsOptions) { const error = this.validateTlsOptions(kmsProvider, providerTlsOptions); if (error) reject(error); - this.setTlsOptions(providerTlsOptions, options); + try { + await this.setTlsOptions(providerTlsOptions, options); + } catch (error) { + return onerror(error); + } } } socket = tls.connect(options, () => { @@ -435,6 +391,16 @@ export class StateMachine { }); } + *requests(context: MongoCryptContext) { + for ( + let request = context.nextKMSRequest(); + request != null; + request = context.nextKMSRequest() + ) { + yield this.kmsRequest(request); + } + } + /** * Validates the provided TLS options are secure. * @@ -461,13 +427,16 @@ export class StateMachine { * @param tlsOptions - The client TLS options for the provider. * @param options - The existing connection options. */ - setTlsOptions(tlsOptions: ClientEncryptionTlsOptions, options: tls.ConnectionOptions) { + async setTlsOptions( + tlsOptions: ClientEncryptionTlsOptions, + options: tls.ConnectionOptions + ): Promise { if (tlsOptions.tlsCertificateKeyFile) { - const cert = fs.readFileSync(tlsOptions.tlsCertificateKeyFile); + const cert = await fs.readFile(tlsOptions.tlsCertificateKeyFile); options.cert = options.key = cert; } if (tlsOptions.tlsCAFile) { - options.ca = fs.readFileSync(tlsOptions.tlsCAFile); + options.ca = await fs.readFile(tlsOptions.tlsCAFile); } if (tlsOptions.tlsCertificateKeyFilePassword) { options.passphrase = tlsOptions.tlsCertificateKeyFilePassword; @@ -485,30 +454,23 @@ export class StateMachine { * @param filter - A filter for the listCollections command * @param callback - Invoked with the info of the requested collection, or with an error */ - fetchCollectionInfo( + async fetchCollectionInfo( client: MongoClient, ns: string, - filter: Document, - callback: Callback - ) { + filter: Document + ): Promise { const { db } = MongoDBCollectionNamespace.fromString(ns); - client + const collections = await client .db(db) .listCollections(filter, { promoteLongs: false, promoteValues: false }) - .toArray() - .then( - collections => { - const info = collections.length > 0 ? serialize(collections[0]) : null; - return callback(undefined, info); - }, - err => { - callback(err); - } - ); + .toArray(); + + const info = collections.length > 0 ? serialize(collections[0]) : null; + return info; } /** @@ -519,27 +481,14 @@ export class StateMachine { * @param command - The command to execute. * @param callback - Invoked with the serialized and marked bson command, or with an error */ - markCommand( - client: MongoClient, - ns: string, - command: Uint8Array, - callback: Callback - ) { + async markCommand(client: MongoClient, ns: string, command: Uint8Array): Promise { const options = { promoteLongs: false, promoteValues: false }; const { db } = MongoDBCollectionNamespace.fromString(ns); const rawCommand = deserialize(command, options); - client - .db(db) - .command(rawCommand, options) - .then( - response => { - return callback(undefined, serialize(response, this.bsonOptions)); - }, - err => { - callback(err); - } - ); + const response = await client.db(db).command(rawCommand, options); + + return serialize(response, this.bsonOptions); } /** @@ -553,24 +502,15 @@ export class StateMachine { fetchKeys( client: MongoClient, keyVaultNamespace: string, - filter: Uint8Array, - callback: Callback> - ) { + filter: Uint8Array + ): Promise> { const { db: dbName, collection: collectionName } = MongoDBCollectionNamespace.fromString(keyVaultNamespace); - client + return client .db(dbName) .collection(collectionName, { readConcern: { level: 'majority' } }) .find(deserialize(filter)) - .toArray() - .then( - keys => { - return callback(undefined, keys); - }, - err => { - callback(err); - } - ); + .toArray(); } } diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 5afb8a522b..0159db9d8a 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -598,32 +598,37 @@ export class CryptoConnection extends Connection { ? cmd.indexes.map((index: { key: Map }) => index.key) : null; - autoEncrypter.encrypt(ns.toString(), cmd, options, (err, encrypted) => { - if (err || encrypted == null) { - callback(err, null); - return; - } - - // Replace the saved values - if (sort != null && (cmd.find || cmd.findAndModify)) { - encrypted.sort = sort; - } - if (indexKeys != null && cmd.createIndexes) { - for (const [offset, index] of indexKeys.entries()) { - // @ts-expect-error `encrypted` is a generic "command", but we've narrowed for only `createIndexes` commands here - encrypted.indexes[offset].key = index; + autoEncrypter.encrypt(ns.toString(), cmd, options).then( + encrypted => { + // Replace the saved values + if (sort != null && (cmd.find || cmd.findAndModify)) { + encrypted.sort = sort; } - } - - super.command(ns, encrypted, options, (err, response) => { - if (err || response == null) { - callback(err, response); - return; + if (indexKeys != null && cmd.createIndexes) { + for (const [offset, index] of indexKeys.entries()) { + // @ts-expect-error `encrypted` is a generic "command", but we've narrowed for only `createIndexes` commands here + encrypted.indexes[offset].key = index; + } } - autoEncrypter.decrypt(response, options, callback); - }); - }); + super.command(ns, encrypted, options, (err, response) => { + if (err || response == null) { + callback(err, response); + return; + } + + autoEncrypter.decrypt(response, options).then( + res => callback(undefined, res), + err => callback(err) + ); + }); + }, + err => { + if (err) { + callback(err, null); + } + } + ); } } diff --git a/src/encrypter.ts b/src/encrypter.ts index f03e8d6c1a..db3413bae7 100644 --- a/src/encrypter.ts +++ b/src/encrypter.ts @@ -1,3 +1,5 @@ +import { callbackify } from 'util'; + import { AutoEncrypter, type AutoEncryptionOptions } from './client-side-encryption/auto_encrypter'; import { MONGO_CLIENT_EVENTS } from './constants'; import { getMongoDBClientEncryption } from './deps'; @@ -101,19 +103,19 @@ export class Encrypter { } } - close(client: MongoClient, force: boolean, callback: Callback): void { - // TODO(NODE-5422): add typescript support - this.autoEncrypter.teardown(!!force, (e: any) => { - const internalClient = this[kInternalClient]; - if (internalClient != null && client !== internalClient) { - internalClient.close(force).then( - () => callback(), - error => callback(error) - ); - return; - } - callback(e); - }); + closeCallback(client: MongoClient, force: boolean, callback: Callback) { + callbackify(this.close.bind(this))(client, force, callback); + } + + async close(client: MongoClient, force: boolean): Promise { + const maybeError: Error | void = await this.autoEncrypter.teardown(!!force).catch(e => e); + const internalClient = this[kInternalClient]; + if (internalClient != null && client !== internalClient) { + return internalClient.close(force); + } + if (maybeError) { + throw maybeError; + } } static checkForMongoCrypt(): void { diff --git a/src/mongo_client.ts b/src/mongo_client.ts index 64a3305eda..3e9a4c504b 100644 --- a/src/mongo_client.ts +++ b/src/mongo_client.ts @@ -497,8 +497,7 @@ export class MongoClient extends TypedEventEmitter { }; if (this.autoEncrypter) { - const initAutoEncrypter = promisify(callback => this.autoEncrypter?.init(callback)); - await initAutoEncrypter(); + await this.autoEncrypter?.init(); await topologyConnect(); await options.encrypter.connectInternalClient(); } else { @@ -559,7 +558,7 @@ export class MongoClient extends TypedEventEmitter { if (error) return reject(error); const { encrypter } = this[kOptions]; if (encrypter) { - return encrypter.close(this, force, error => { + return encrypter.closeCallback(this, force, error => { if (error) return reject(error); resolve(); }); diff --git a/test/integration/node-specific/auto_encrypter.test.ts b/test/integration/node-specific/auto_encrypter.test.ts index e8da0fc8d9..e4868b44d5 100644 --- a/test/integration/node-specific/auto_encrypter.test.ts +++ b/test/integration/node-specific/auto_encrypter.test.ts @@ -3,7 +3,6 @@ import { spawnSync } from 'child_process'; import * as fs from 'fs'; import { dirname, resolve } from 'path'; import * as sinon from 'sinon'; -import { promisify } from 'util'; /* eslint-disable @typescript-eslint/no-restricted-imports */ import { AutoEncrypter } from '../../../src/client-side-encryption/auto_encrypter'; @@ -54,9 +53,7 @@ describe('crypt_shared library', function () { await client.connect(); }); afterEach(async () => { - await promisify(cb => - autoEncrypter ? autoEncrypter.teardown(true, cb) : cb(undefined, undefined) - )(); + await autoEncrypter?.teardown(true); await client?.close(); }); const sandbox = sinon.createSandbox(); @@ -103,7 +100,7 @@ describe('crypt_shared library', function () { 'should autoSpawn a mongocryptd on init by default', { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { autoEncrypter = new AutoEncrypter(client, { keyVaultNamespace: 'admin.datakeys', kmsProviders: { @@ -117,18 +114,15 @@ describe('crypt_shared library', function () { const localMcdm = autoEncrypter._mongocryptdManager; sandbox.spy(localMcdm, 'spawn'); - autoEncrypter.init(err => { - if (err) return done(err); - expect(localMcdm.spawn).to.have.been.calledOnce; - done(); - }); + await autoEncrypter.init(); + expect(localMcdm.spawn).to.have.been.calledOnce; } ); it( 'should not attempt to kick off mongocryptd on a normal error', { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { let called = false; StateMachine.prototype.markCommand.callsFake((client, ns, filter, callback) => { if (!called) { @@ -150,24 +144,20 @@ describe('crypt_shared library', function () { expect(autoEncrypter).to.have.property('cryptSharedLibVersionInfo', null); const localMcdm = autoEncrypter._mongocryptdManager; - autoEncrypter.init(err => { - if (err) return done(err); + await autoEncrypter.init(); - sandbox.spy(localMcdm, 'spawn'); + sandbox.spy(localMcdm, 'spawn'); - autoEncrypter.encrypt('test.test', TEST_COMMAND, err => { - expect(localMcdm.spawn).to.not.have.been.called; - expect(err).to.be.an.instanceOf(Error); - done(); - }); - }); + const err = await autoEncrypter.encrypt('test.test', TEST_COMMAND).catch(e => e); + expect(localMcdm.spawn).to.not.have.been.called; + expect(err).to.be.an.instanceOf(Error); } ); it( 'should restore the mongocryptd and retry once if a MongoNetworkTimeoutError is experienced', { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { let called = false; StateMachine.prototype.markCommand.callsFake((client, ns, filter, callback) => { if (!called) { @@ -189,24 +179,19 @@ describe('crypt_shared library', function () { expect(autoEncrypter).to.have.property('cryptSharedLibVersionInfo', null); const localMcdm = autoEncrypter._mongocryptdManager; - autoEncrypter.init(err => { - if (err) return done(err); + await autoEncrypter.init(); - sandbox.spy(localMcdm, 'spawn'); + sandbox.spy(localMcdm, 'spawn'); - autoEncrypter.encrypt('test.test', TEST_COMMAND, err => { - expect(localMcdm.spawn).to.have.been.calledOnce; - expect(err).to.not.exist; - done(); - }); - }); + await autoEncrypter.encrypt('test.test', TEST_COMMAND); + expect(localMcdm.spawn).to.have.been.calledOnce; } ); it( 'should propagate error if MongoNetworkTimeoutError is experienced twice in a row', { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { let counter = 2; StateMachine.prototype.markCommand.callsFake((client, ns, filter, callback) => { if (counter) { @@ -228,24 +213,20 @@ describe('crypt_shared library', function () { expect(autoEncrypter).to.have.property('cryptSharedLibVersionInfo', null); const localMcdm = autoEncrypter._mongocryptdManager; - autoEncrypter.init(err => { - if (err) return done(err); + await autoEncrypter.init(); - sandbox.spy(localMcdm, 'spawn'); + sandbox.spy(localMcdm, 'spawn'); - autoEncrypter.encrypt('test.test', TEST_COMMAND, err => { - expect(localMcdm.spawn).to.have.been.calledOnce; - expect(err).to.be.an.instanceof(MongoNetworkTimeoutError); - done(); - }); - }); + const err = await autoEncrypter.encrypt('test.test', TEST_COMMAND).catch(e => e); + expect(localMcdm.spawn).to.have.been.calledOnce; + expect(err).to.be.an.instanceof(MongoNetworkTimeoutError); } ); it( 'should return a useful message if mongocryptd fails to autospawn', { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { autoEncrypter = new AutoEncrypter(client, { keyVaultNamespace: 'admin.datakeys', kmsProviders: { @@ -258,15 +239,11 @@ describe('crypt_shared library', function () { }); expect(autoEncrypter).to.have.property('cryptSharedLibVersionInfo', null); - sandbox.stub(MongocryptdManager.prototype, 'spawn').callsFake(callback => { - callback(); - }); + sandbox.stub(MongocryptdManager.prototype, 'spawn').resolves(); - autoEncrypter.init(err => { - expect(err).to.exist; - expect(err).to.be.instanceOf(MongoError); - done(); - }); + const err = await autoEncrypter.init().catch(e => e); + expect(err).to.exist; + expect(err).to.be.instanceOf(MongoError); } ); }); @@ -289,7 +266,7 @@ describe('crypt_shared library', function () { it( `should not spawn mongocryptd on startup if ${opt} is true`, { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { autoEncrypter = new AutoEncrypter(client, encryptionOptions); const localMcdm = autoEncrypter._mongocryptdManager || { @@ -299,11 +276,8 @@ describe('crypt_shared library', function () { }; sandbox.spy(localMcdm, 'spawn'); - autoEncrypter.init(err => { - expect(err).to.not.exist; - expect(localMcdm.spawn).to.have.a.callCount(0); - done(); - }); + await autoEncrypter.init(); + expect(localMcdm.spawn).to.have.a.callCount(0); } ); }); @@ -311,7 +285,7 @@ describe('crypt_shared library', function () { it( 'should not spawn a mongocryptd or retry on a server selection error if mongocryptdBypassSpawn: true', { requires: { clientSideEncryption: true, predicate: cryptShared('disabled') } }, - function (done) { + async function () { let called = false; const timeoutError = new MongoNetworkTimeoutError('msg'); StateMachine.prototype.markCommand.callsFake((client, ns, filter, callback) => { @@ -338,17 +312,12 @@ describe('crypt_shared library', function () { const localMcdm = autoEncrypter._mongocryptdManager; sandbox.spy(localMcdm, 'spawn'); - autoEncrypter.init(err => { - expect(err).to.not.exist; - expect(localMcdm.spawn).to.not.have.been.called; + await autoEncrypter.init(); + expect(localMcdm.spawn).to.not.have.been.called; - autoEncrypter.encrypt('test.test', TEST_COMMAND, (err, response) => { - expect(localMcdm.spawn).to.not.have.been.called; - expect(response).to.not.exist; - expect(err).to.equal(timeoutError); - done(); - }); - }); + const err = await autoEncrypter.encrypt('test.test', TEST_COMMAND).catch(e => e); + expect(localMcdm.spawn).to.not.have.been.called; + expect(err).to.equal(timeoutError); } ); }); diff --git a/test/unit/client-side-encryption/auto_encrypter.test.ts b/test/unit/client-side-encryption/auto_encrypter.test.ts index 965bce6e3d..33c62ec7bc 100644 --- a/test/unit/client-side-encryption/auto_encrypter.test.ts +++ b/test/unit/client-side-encryption/auto_encrypter.test.ts @@ -10,7 +10,7 @@ import { MongocryptdManager } from '../../../src/client-side-encryption/mongocry import { StateMachine } from '../../../src/client-side-encryption/state_machine'; // eslint-disable-next-line @typescript-eslint/no-restricted-imports import { MongoClient } from '../../../src/mongo_client'; -import { BSON } from '../../mongodb'; +import { BSON, type DataKey } from '../../mongodb'; import * as requirements from './requirements.helper'; const bson = BSON; @@ -53,32 +53,23 @@ describe('AutoEncrypter', function () { return Promise.resolve(); }); - sandbox - .stub(StateMachine.prototype, 'fetchCollectionInfo') - .callsFake((client, ns, filter, callback) => { - callback(null, MOCK_COLLINFO_RESPONSE); - }); + sandbox.stub(StateMachine.prototype, 'fetchCollectionInfo').resolves(MOCK_COLLINFO_RESPONSE); - sandbox - .stub(StateMachine.prototype, 'markCommand') - .callsFake((client, ns, command, callback) => { - if (ENABLE_LOG_TEST) { - const response = bson.deserialize(MOCK_MONGOCRYPTD_RESPONSE); - response.schemaRequiresEncryption = false; - - ENABLE_LOG_TEST = false; // disable test after run - callback(null, bson.serialize(response)); - return; - } + sandbox.stub(StateMachine.prototype, 'markCommand').callsFake(() => { + if (ENABLE_LOG_TEST) { + const response = bson.deserialize(MOCK_MONGOCRYPTD_RESPONSE); + response.schemaRequiresEncryption = false; - callback(null, MOCK_MONGOCRYPTD_RESPONSE); - }); + ENABLE_LOG_TEST = false; // disable test after run + return Promise.resolve(bson.serialize(response)); + } - sandbox.stub(StateMachine.prototype, 'fetchKeys').callsFake((client, ns, filter, callback) => { - // mock data is already serialized, our action deals with the result of a cursor - const deserializedKey = bson.deserialize(MOCK_KEYDOCUMENT_RESPONSE); - callback(null, [deserializedKey]); + return Promise.resolve(MOCK_MONGOCRYPTD_RESPONSE); }); + + sandbox + .stub(StateMachine.prototype, 'fetchKeys') + .resolves([bson.deserialize(MOCK_KEYDOCUMENT_RESPONSE) as DataKey]); }); afterEach(() => { @@ -135,7 +126,7 @@ describe('AutoEncrypter', function () { }); }); - it('should support `bypassAutoEncryption`', function (done) { + it('should support `bypassAutoEncryption`', async function () { const client = new MockClient(); const autoEncrypter = new AutoEncrypter(client, { bypassAutoEncryption: true, @@ -151,15 +142,12 @@ describe('AutoEncrypter', function () { } }); - autoEncrypter.encrypt('test.test', { test: 'command' }, (err, encrypted) => { - expect(err).to.not.exist; - expect(encrypted).to.eql({ test: 'command' }); - done(); - }); + const encrypted = await autoEncrypter.encrypt('test.test', { test: 'command' }); + expect(encrypted).to.eql({ test: 'command' }); }); describe('state machine', function () { - it('should decrypt mock data', function (done) { + it('should decrypt mock data', async function () { const input = readExtendedJsonToBuffer(`${__dirname}/data/encrypted-document.json`); const client = new MockClient() as MongoClient; const mc = new AutoEncrypter(client, { @@ -173,16 +161,13 @@ describe('AutoEncrypter', function () { local: { key: Buffer.alloc(96) } } }); - mc.decrypt(input, (err, decrypted) => { - if (err) return done(err); - expect(decrypted).to.eql({ filter: { find: 'test', ssn: '457-55-5462' } }); - expect(decrypted).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - expect(decrypted.filter).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - done(); - }); + const decrypted = await mc.decrypt(input); + expect(decrypted).to.eql({ filter: { find: 'test', ssn: '457-55-5462' } }); + expect(decrypted).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); + expect(decrypted.filter).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); }); - it('should decrypt mock data and mark decrypted items if enabled for testing', function (done) { + it('should decrypt mock data and mark decrypted items if enabled for testing', async function () { const input = readExtendedJsonToBuffer(`${__dirname}/data/encrypted-document.json`); const nestedInput = readExtendedJsonToBuffer( `${__dirname}/data/encrypted-document-nested.json` @@ -200,30 +185,23 @@ describe('AutoEncrypter', function () { } }); mc[Symbol.for('@@mdb.decorateDecryptionResult')] = true; - mc.decrypt(input, (err, decrypted) => { - if (err) return done(err); - expect(decrypted).to.eql({ filter: { find: 'test', ssn: '457-55-5462' } }); - expect(decrypted).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - expect(decrypted.filter[Symbol.for('@@mdb.decryptedKeys')]).to.eql(['ssn']); - - // The same, but with an object containing different data types as the input - mc.decrypt({ a: [null, 1, { c: new bson.Binary('foo', 1) }] }, (err, decrypted) => { - if (err) return done(err); - expect(decrypted).to.eql({ a: [null, 1, { c: new bson.Binary('foo', 1) }] }); - expect(decrypted).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - - // The same, but with nested data inside the decrypted input - mc.decrypt(nestedInput, (err, decrypted) => { - if (err) return done(err); - expect(decrypted).to.eql({ nested: { x: { y: 1234 } } }); - expect(decrypted[Symbol.for('@@mdb.decryptedKeys')]).to.eql(['nested']); - expect(decrypted.nested).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - expect(decrypted.nested.x).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - expect(decrypted.nested.x.y).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); - done(); - }); - }); - }); + let decrypted = await mc.decrypt(input); + expect(decrypted).to.eql({ filter: { find: 'test', ssn: '457-55-5462' } }); + expect(decrypted).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); + expect(decrypted.filter[Symbol.for('@@mdb.decryptedKeys')]).to.eql(['ssn']); + + // The same, but with an object containing different data types as the input + decrypted = await mc.decrypt({ a: [null, 1, { c: new bson.Binary('foo', 1) }] }); + expect(decrypted).to.eql({ a: [null, 1, { c: new bson.Binary('foo', 1) }] }); + expect(decrypted).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); + + // The same, but with nested data inside the decrypted input + decrypted = await mc.decrypt(nestedInput); + expect(decrypted).to.eql({ nested: { x: { y: 1234 } } }); + expect(decrypted[Symbol.for('@@mdb.decryptedKeys')]).to.eql(['nested']); + expect(decrypted.nested).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); + expect(decrypted.nested.x).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); + expect(decrypted.nested.x.y).to.not.have.property(Symbol.for('@@mdb.decryptedKeys')); }); context('when the aws sdk is installed', function () { @@ -247,7 +225,7 @@ describe('AutoEncrypter', function () { process.env.AWS_SECRET_ACCESS_KEY = originalSecretAccessKey; }); - it('should decrypt mock data with KMS credentials from the environment', function (done) { + it('should decrypt mock data with KMS credentials from the environment', async function () { const input = readExtendedJsonToBuffer(`${__dirname}/data/encrypted-document.json`); const client = new MockClient(); const mc = new AutoEncrypter(client, { @@ -260,11 +238,8 @@ describe('AutoEncrypter', function () { aws: {} } }); - mc.decrypt(input, (err, decrypted) => { - if (err) return done(err); - expect(decrypted).to.eql({ filter: { find: 'test', ssn: '457-55-5462' } }); - done(); - }); + const decrypted = await mc.decrypt(input); + expect(decrypted).to.eql({ filter: { find: 'test', ssn: '457-55-5462' } }); }); }); @@ -289,7 +264,7 @@ describe('AutoEncrypter', function () { process.env.AWS_SECRET_ACCESS_KEY = originalSecretAccessKey; }); - it('errors without the optional sdk credential provider', function (done) { + it('errors without the optional sdk credential provider', async function () { const input = readExtendedJsonToBuffer(`${__dirname}/data/encrypted-document.json`); const client = new MockClient(); const mc = new AutoEncrypter(client, { @@ -302,16 +277,14 @@ describe('AutoEncrypter', function () { aws: {} } }); - mc.decrypt(input, err => { - expect(err.message).to.equal( - 'client not configured with KMS provider necessary to decrypt' - ); - done(); - }); + const error = await mc.decrypt(input).catch(e => e); + expect(error.message).to.equal( + 'client not configured with KMS provider necessary to decrypt' + ); }); }); - it('should encrypt mock data', function (done) { + it('should encrypt mock data', async function () { const client = new MockClient(); const mc = new AutoEncrypter(client, { keyVaultNamespace: 'admin.datakeys', @@ -325,31 +298,28 @@ describe('AutoEncrypter', function () { } }); - mc.encrypt('test.test', TEST_COMMAND, (err, encrypted) => { - if (err) return done(err); - const expected = EJSON.parse( - JSON.stringify({ - find: 'test', - filter: { - ssn: { - $binary: { - base64: - 'AWFhYWFhYWFhYWFhYWFhYWECRTOW9yZzNDn5dGwuqsrJQNLtgMEKaujhs9aRWRp+7Yo3JK8N8jC8P0Xjll6C1CwLsE/iP5wjOMhVv1KMMyOCSCrHorXRsb2IKPtzl2lKTqQ=', - subType: '6' - } + const encrypted = await mc.encrypt('test.test', TEST_COMMAND); + const expected = EJSON.parse( + JSON.stringify({ + find: 'test', + filter: { + ssn: { + $binary: { + base64: + 'AWFhYWFhYWFhYWFhYWFhYWECRTOW9yZzNDn5dGwuqsrJQNLtgMEKaujhs9aRWRp+7Yo3JK8N8jC8P0Xjll6C1CwLsE/iP5wjOMhVv1KMMyOCSCrHorXRsb2IKPtzl2lKTqQ=', + subType: '6' } } - }) - ); + } + }) + ); - expect(encrypted).to.containSubset(expected); - done(); - }); + expect(encrypted).to.containSubset(expected); }); }); describe('logging', function () { - it('should allow registration of a log handler', function (done) { + it('should allow registration of a log handler', async function () { ENABLE_LOG_TEST = true; let loggerCalled = false; @@ -370,20 +340,17 @@ describe('AutoEncrypter', function () { } }); - mc.encrypt('test.test', TEST_COMMAND, (err, encrypted) => { - if (err) return done(err); - const expected = EJSON.parse( - JSON.stringify({ - find: 'test', - filter: { - ssn: '457-55-5462' - } - }) - ); + const encrypted = await mc.encrypt('test.test', TEST_COMMAND); + const expected = EJSON.parse( + JSON.stringify({ + find: 'test', + filter: { + ssn: '457-55-5462' + } + }) + ); - expect(encrypted).to.containSubset(expected); - done(); - }); + expect(encrypted).to.containSubset(expected); }); }); diff --git a/test/unit/client-side-encryption/client_encryption.test.ts b/test/unit/client-side-encryption/client_encryption.test.ts index 5e2586990a..c83383d4e4 100644 --- a/test/unit/client-side-encryption/client_encryption.test.ts +++ b/test/unit/client-side-encryption/client_encryption.test.ts @@ -53,7 +53,7 @@ describe('ClientEncryption', function () { }); // stubbed out for AWS unit testing below - sandbox.stub(StateMachine.prototype, 'fetchKeys').callsFake((client, ns, filter, cb) => { + sandbox.stub(StateMachine.prototype, 'fetchKeys').callsFake((client, ns, filter) => { filter = deserialize(filter); const keyIds = filter.$or[0]._id.$in.map(key => key.toString('hex')); const fileNames = keyIds.map(keyId => @@ -62,7 +62,7 @@ describe('ClientEncryption', function () { const contents = fileNames.map(filename => EJSON.parse(fs.readFileSync(filename, { encoding: 'utf-8' })) ); - cb(null, contents); + return Promise.resolve(contents); }); }); diff --git a/test/unit/client-side-encryption/common.test.js b/test/unit/client-side-encryption/common.test.js deleted file mode 100644 index 93232d9779..0000000000 --- a/test/unit/client-side-encryption/common.test.js +++ /dev/null @@ -1,95 +0,0 @@ -'use strict'; - -const { expect } = require('chai'); -// eslint-disable-next-line no-restricted-modules -const { maybeCallback } = require('../../../src/client-side-encryption/common'); - -describe('maybeCallback()', () => { - it('should accept two arguments', () => { - expect(maybeCallback).to.have.lengthOf(2); - }); - - describe('when handling an error case', () => { - it('should pass the error to the callback provided', done => { - const superPromiseRejection = Promise.reject(new Error('fail')); - const result = maybeCallback( - () => superPromiseRejection, - (error, result) => { - try { - expect(result).to.not.exist; - expect(error).to.be.instanceOf(Error); - return done(); - } catch (assertionError) { - return done(assertionError); - } - } - ); - expect(result).to.be.undefined; - }); - - it('should return the rejected promise to the caller when no callback is provided', async () => { - const superPromiseRejection = Promise.reject(new Error('fail')); - const returnedPromise = maybeCallback(() => superPromiseRejection, undefined); - expect(returnedPromise).to.equal(superPromiseRejection); - // @ts-expect-error: There is no overload to change the return type not be nullish, - // and we do not want to add one in fear of making it too easy to neglect adding the callback argument - const thrownError = await returnedPromise.catch(error => error); - expect(thrownError).to.be.instanceOf(Error); - }); - - it('should not modify a rejection error promise', async () => { - class MyError extends Error {} - const driverError = Object.freeze(new MyError()); - const rejection = Promise.reject(driverError); - // @ts-expect-error: There is no overload to change the return type not be nullish, - // and we do not want to add one in fear of making it too easy to neglect adding the callback argument - const thrownError = await maybeCallback(() => rejection, undefined).catch(error => error); - expect(thrownError).to.be.equal(driverError); - }); - - it('should not modify a rejection error when passed to callback', done => { - class MyError extends Error {} - const driverError = Object.freeze(new MyError()); - const rejection = Promise.reject(driverError); - maybeCallback( - () => rejection, - error => { - try { - expect(error).to.exist; - expect(error).to.equal(driverError); - done(); - } catch (assertionError) { - done(assertionError); - } - } - ); - }); - }); - - describe('when handling a success case', () => { - it('should pass the result and undefined error to the callback provided', done => { - const superPromiseSuccess = Promise.resolve(2); - - const result = maybeCallback( - () => superPromiseSuccess, - (error, result) => { - try { - expect(error).to.be.undefined; - expect(result).to.equal(2); - done(); - } catch (assertionError) { - done(assertionError); - } - } - ); - expect(result).to.be.undefined; - }); - - it('should return the resolved promise to the caller when no callback is provided', async () => { - const superPromiseSuccess = Promise.resolve(2); - const result = maybeCallback(() => superPromiseSuccess); - expect(result).to.equal(superPromiseSuccess); - expect(await result).to.equal(2); - }); - }); -}); diff --git a/test/unit/client-side-encryption/state_machine.test.ts b/test/unit/client-side-encryption/state_machine.test.ts index f1a7042d46..e84b7c1f18 100644 --- a/test/unit/client-side-encryption/state_machine.test.ts +++ b/test/unit/client-side-encryption/state_machine.test.ts @@ -1,11 +1,12 @@ import { expect } from 'chai'; import { EventEmitter, once } from 'events'; -import * as fs from 'fs'; +import * as fs from 'fs/promises'; import { type MongoCryptKMSRequest } from 'mongodb-client-encryption'; import * as net from 'net'; import * as sinon from 'sinon'; import { setTimeout } from 'timers'; import * as tls from 'tls'; +import { promisify } from 'util'; // eslint-disable-next-line @typescript-eslint/no-restricted-imports import { StateMachine } from '../../../src/client-side-encryption/state_machine'; @@ -176,22 +177,24 @@ describe('StateMachine', function () { const buffer = Buffer.from('foobar'); let connectOptions; - it('sets the cert and key options in the tls connect options', function (done) { - this.sinon.stub(fs, 'readFileSync').callsFake(fileName => { + it('sets the cert and key options in the tls connect options', async function () { + this.sinon.stub(fs, 'readFile').callsFake(fileName => { expect(fileName).to.equal('test.pem'); - return buffer; + return Promise.resolve(buffer); }); this.sinon.stub(tls, 'connect').callsFake((options, callback) => { connectOptions = options; this.fakeSocket = new MockSocket(callback); return this.fakeSocket; }); - stateMachine.kmsRequest(request).then(function () { - expect(connectOptions.cert).to.equal(buffer); - expect(connectOptions.key).to.equal(buffer); - done(); - }); + const kmsRequestPromise = stateMachine.kmsRequest(request); + + await promisify(setTimeout)(0); this.fakeSocket.emit('data', Buffer.alloc(0)); + + await kmsRequestPromise; + expect(connectOptions.cert).to.equal(buffer); + expect(connectOptions.key).to.equal(buffer); }); }); @@ -203,21 +206,23 @@ describe('StateMachine', function () { const buffer = Buffer.from('foobar'); let connectOptions; - it('sets the ca options in the tls connect options', function (done) { - this.sinon.stub(fs, 'readFileSync').callsFake(fileName => { + it('sets the ca options in the tls connect options', async function () { + this.sinon.stub(fs, 'readFile').callsFake(fileName => { expect(fileName).to.equal('test.pem'); - return buffer; + return Promise.resolve(buffer); }); this.sinon.stub(tls, 'connect').callsFake((options, callback) => { connectOptions = options; this.fakeSocket = new MockSocket(callback); return this.fakeSocket; }); - stateMachine.kmsRequest(request).then(function () { - expect(connectOptions.ca).to.equal(buffer); - done(); - }); + const kmsRequestPromise = stateMachine.kmsRequest(request); + + await promisify(setTimeout)(0); this.fakeSocket.emit('data', Buffer.alloc(0)); + + await kmsRequestPromise; + expect(connectOptions.ca).to.equal(buffer); }); }); @@ -228,17 +233,19 @@ describe('StateMachine', function () { const request = new MockRequest(Buffer.from('foobar'), -1); let connectOptions; - it('sets the passphrase option in the tls connect options', function (done) { + it('sets the passphrase option in the tls connect options', async function () { this.sinon.stub(tls, 'connect').callsFake((options, callback) => { connectOptions = options; this.fakeSocket = new MockSocket(callback); return this.fakeSocket; }); - stateMachine.kmsRequest(request).then(function () { - expect(connectOptions.passphrase).to.equal('test'); - done(); - }); + const kmsRequestPromise = stateMachine.kmsRequest(request); + + await promisify(setTimeout)(0); this.fakeSocket.emit('data', Buffer.alloc(0)); + + await kmsRequestPromise; + expect(connectOptions.passphrase).to.equal('test'); }); }); });