From b60216223801bc09f1e3eeab0003c1e896db5547 Mon Sep 17 00:00:00 2001 From: Bailey Pearson Date: Wed, 8 Nov 2023 10:23:59 -0700 Subject: [PATCH] refactor(NODE-5696): add async-iterator based socket helpers (#3896) Co-authored-by: Neal Beeken --- src/cmap/command_monitoring_events.ts | 10 +- src/cmap/commands.ts | 73 ++- src/cmap/connection.ts | 547 +++++++++++++++++- src/cmap/message_stream.ts | 28 +- src/cmap/wire_protocol/compression.ts | 73 +++ src/index.ts | 15 +- .../cmap/command_monitoring_events.test.js | 18 +- test/unit/cmap/commands.test.js | 26 +- test/unit/cmap/connection.test.ts | 23 +- test/unit/cmap/message_stream.test.ts | 4 +- test/unit/cmap/modern_connection.test.ts | 425 ++++++++++++++ test/unit/commands.test.ts | 133 ++++- test/unit/index.test.ts | 12 +- 13 files changed, 1292 insertions(+), 95 deletions(-) create mode 100644 test/unit/cmap/modern_connection.test.ts diff --git a/src/cmap/command_monitoring_events.ts b/src/cmap/command_monitoring_events.ts index ada41b4304..b08ce723eb 100644 --- a/src/cmap/command_monitoring_events.ts +++ b/src/cmap/command_monitoring_events.ts @@ -7,7 +7,7 @@ import { LEGACY_HELLO_COMMAND_CAMEL_CASE } from '../constants'; import { calculateDurationInMs, deepCopy } from '../utils'; -import { Msg, type Query, type WriteProtocolMessageType } from './commands'; +import { OpMsgRequest, type OpQueryRequest, type WriteProtocolMessageType } from './commands'; import type { Connection } from './connection'; /** @@ -181,8 +181,8 @@ const HELLO_COMMANDS = new Set(['hello', LEGACY_HELLO_COMMAND, LEGACY_HELLO_COMM // helper methods const extractCommandName = (commandDoc: Document) => Object.keys(commandDoc)[0]; -const namespace = (command: Query) => command.ns; -const collectionName = (command: Query) => command.ns.split('.')[1]; +const namespace = (command: OpQueryRequest) => command.ns; +const collectionName = (command: OpQueryRequest) => command.ns.split('.')[1]; const maybeRedact = (commandName: string, commandDoc: Document, result: Error | Document) => SENSITIVE_COMMANDS.has(commandName) || (HELLO_COMMANDS.has(commandName) && commandDoc.speculativeAuthenticate) @@ -220,7 +220,7 @@ const OP_QUERY_KEYS = [ /** Extract the actual command from the query, possibly up-converting if it's a legacy format */ function extractCommand(command: WriteProtocolMessageType): Document { - if (command instanceof Msg) { + if (command instanceof OpMsgRequest) { return deepCopy(command.command); } @@ -283,7 +283,7 @@ function extractReply(command: WriteProtocolMessageType, reply?: Document) { return reply; } - if (command instanceof Msg) { + if (command instanceof OpMsgRequest) { return deepCopy(reply.result ? reply.result : reply); } diff --git a/src/cmap/commands.ts b/src/cmap/commands.ts index 15ecbfc8d6..ee1e7b6a7f 100644 --- a/src/cmap/commands.ts +++ b/src/cmap/commands.ts @@ -4,7 +4,13 @@ import { MongoInvalidArgumentError, MongoRuntimeError } from '../error'; import { ReadPreference } from '../read_preference'; import type { ClientSession } from '../sessions'; import type { CommandOptions } from './connection'; -import { OP_MSG, OP_QUERY } from './wire_protocol/constants'; +import { + compress, + Compressor, + type CompressorName, + uncompressibleCommands +} from './wire_protocol/compression'; +import { OP_COMPRESSED, OP_MSG, OP_QUERY } from './wire_protocol/constants'; // Incrementing request id let _requestId = 0; @@ -25,7 +31,7 @@ const SHARD_CONFIG_STALE = 4; const AWAIT_CAPABLE = 8; /** @internal */ -export type WriteProtocolMessageType = Query | Msg; +export type WriteProtocolMessageType = OpQueryRequest | OpMsgRequest; /** @internal */ export interface OpQueryOptions extends CommandOptions { @@ -52,7 +58,7 @@ export interface OpQueryOptions extends CommandOptions { * QUERY **************************************************************/ /** @internal */ -export class Query { +export class OpQueryRequest { ns: string; numberToSkip: number; numberToReturn: number; @@ -96,7 +102,7 @@ export class Query { this.numberToSkip = options.numberToSkip || 0; this.numberToReturn = options.numberToReturn || 0; this.returnFieldSelector = options.returnFieldSelector || undefined; - this.requestId = Query.getRequestId(); + this.requestId = options.requestId ?? OpQueryRequest.getRequestId(); // special case for pre-3.2 find commands, delete ASAP this.pre32Limit = options.pre32Limit; @@ -285,7 +291,7 @@ export interface OpResponseOptions extends BSONSerializeOptions { } /** @internal */ -export class Response { +export class OpQueryResponse { parsed: boolean; raw: Buffer; data: Buffer; @@ -472,7 +478,7 @@ export interface OpMsgOptions { } /** @internal */ -export class Msg { +export class OpMsgRequest { requestId: number; serializeFunctions: boolean; ignoreUndefined: boolean; @@ -502,7 +508,7 @@ export class Msg { this.options = options ?? {}; // Additional options - this.requestId = options.requestId ? options.requestId : Msg.getRequestId(); + this.requestId = options.requestId ? options.requestId : OpMsgRequest.getRequestId(); // Serialization option this.serializeFunctions = @@ -580,7 +586,7 @@ export class Msg { } /** @internal */ -export class BinMsg { +export class OpMsgResponse { parsed: boolean; raw: Buffer; data: Buffer; @@ -709,3 +715,54 @@ export class BinMsg { return { utf8: { writeErrors: false } }; } } + +const MESSAGE_HEADER_SIZE = 16; +const COMPRESSION_DETAILS_SIZE = 9; // originalOpcode + uncompressedSize, compressorID + +/** + * @internal + * + * An OP_COMPRESSED request wraps either an OP_QUERY or OP_MSG message. + */ +export class OpCompressedRequest { + constructor( + private command: WriteProtocolMessageType, + private options: { zlibCompressionLevel: number; agreedCompressor: CompressorName } + ) {} + + // Return whether a command contains an uncompressible command term + // Will return true if command contains no uncompressible command terms + static canCompress(command: WriteProtocolMessageType) { + const commandDoc = command instanceof OpMsgRequest ? command.command : command.query; + const commandName = Object.keys(commandDoc)[0]; + return !uncompressibleCommands.has(commandName); + } + + async toBin(): Promise { + const concatenatedOriginalCommandBuffer = Buffer.concat(this.command.toBin()); + // otherwise, compress the message + const messageToBeCompressed = concatenatedOriginalCommandBuffer.slice(MESSAGE_HEADER_SIZE); + + // Extract information needed for OP_COMPRESSED from the uncompressed message + const originalCommandOpCode = concatenatedOriginalCommandBuffer.readInt32LE(12); + + // Compress the message body + const compressedMessage = await compress(this.options, messageToBeCompressed); + // Create the msgHeader of OP_COMPRESSED + const msgHeader = Buffer.alloc(MESSAGE_HEADER_SIZE); + msgHeader.writeInt32LE( + MESSAGE_HEADER_SIZE + COMPRESSION_DETAILS_SIZE + compressedMessage.length, + 0 + ); // messageLength + msgHeader.writeInt32LE(this.command.requestId, 4); // requestID + msgHeader.writeInt32LE(0, 8); // responseTo (zero) + msgHeader.writeInt32LE(OP_COMPRESSED, 12); // opCode + + // Create the compression details of OP_COMPRESSED + const compressionDetails = Buffer.alloc(COMPRESSION_DETAILS_SIZE); + compressionDetails.writeInt32LE(originalCommandOpCode, 0); // originalOpcode + compressionDetails.writeInt32LE(messageToBeCompressed.length, 4); // Size of the uncompressed compressedMessage, excluding the MsgHeader + compressionDetails.writeUInt8(Compressor[this.options.agreedCompressor], 8); // compressorID + return [msgHeader, compressionDetails, compressedMessage]; + } +} diff --git a/src/cmap/connection.ts b/src/cmap/connection.ts index 67dc42ee7f..d505496519 100644 --- a/src/cmap/connection.ts +++ b/src/cmap/connection.ts @@ -1,3 +1,5 @@ +import { once } from 'events'; +import { on } from 'stream'; import { clearTimeout, setTimeout } from 'timers'; import { promisify } from 'util'; @@ -18,6 +20,7 @@ import { MongoMissingDependencyError, MongoNetworkError, MongoNetworkTimeoutError, + MongoParseError, MongoRuntimeError, MongoServerError, MongoWriteConcernError @@ -27,6 +30,7 @@ import { type CancellationToken, TypedEventEmitter } from '../mongo_types'; import type { ReadPreferenceLike } from '../read_preference'; import { applySession, type ClientSession, updateSessionFromResponse } from '../sessions'; import { + BufferPool, calculateDurationInMs, type Callback, HostAddress, @@ -43,11 +47,19 @@ import { CommandStartedEvent, CommandSucceededEvent } from './command_monitoring_events'; -import { type BinMsg, Msg, Query, type Response, type WriteProtocolMessageType } from './commands'; +import { + OpCompressedRequest, + OpMsgRequest, + type OpMsgResponse, + OpQueryRequest, + type OpQueryResponse, + type WriteProtocolMessageType +} from './commands'; import type { Stream } from './connect'; import type { ClientMetadata } from './handshake/client_metadata'; import { MessageStream, type OperationDescription } from './message_stream'; import { StreamDescription, type StreamDescriptionOptions } from './stream_description'; +import { decompressResponse } from './wire_protocol/compression'; import { getReadPreference, isSharded } from './wire_protocol/shared'; /** @internal */ @@ -324,7 +336,7 @@ export class Connection extends TypedEventEmitter { }, 1).unref(); // No need for this timer to hold the event loop open } - onMessage(message: BinMsg | Response) { + onMessage(message: OpMsgResponse | OpQueryResponse) { const delayedTimeoutId = this[kDelayedTimeoutId]; if (delayedTimeoutId != null) { clearTimeout(delayedTimeoutId); @@ -540,8 +552,8 @@ export class Connection extends TypedEventEmitter { ); const message = shouldUseOpMsg - ? new Msg(ns.db, cmd, commandOptions) - : new Query(ns.db, cmd, commandOptions); + ? new OpMsgRequest(ns.db, cmd, commandOptions) + : new OpQueryRequest(ns.db, cmd, commandOptions); try { write(this, message, commandOptions, callback); @@ -757,3 +769,530 @@ function write( operationDescription.cb(); } } + +/** in-progress connection layer */ + +/** @internal */ +export class ModernConnection extends TypedEventEmitter { + id: number | ''; + address: string; + socketTimeoutMS: number; + monitorCommands: boolean; + /** Indicates that the connection (including underlying TCP socket) has been closed. */ + closed: boolean; + lastHelloMS?: number; + serverApi?: ServerApi; + helloOk?: boolean; + commandAsync: ( + ns: MongoDBNamespace, + cmd: Document, + options: CommandOptions | undefined + ) => Promise; + /** @internal */ + authContext?: AuthContext; + + /**@internal */ + [kDelayedTimeoutId]: NodeJS.Timeout | null; + /** @internal */ + [kDescription]: StreamDescription; + /** @internal */ + [kGeneration]: number; + /** @internal */ + [kLastUseTime]: number; + /** @internal */ + [kQueue]: Map; + /** @internal */ + [kMessageStream]: MessageStream; + /** @internal */ + socket: Stream; + /** @internal */ + [kHello]: Document | null; + /** @internal */ + [kClusterTime]: Document | null; + + /** @event */ + static readonly COMMAND_STARTED = COMMAND_STARTED; + /** @event */ + static readonly COMMAND_SUCCEEDED = COMMAND_SUCCEEDED; + /** @event */ + static readonly COMMAND_FAILED = COMMAND_FAILED; + /** @event */ + static readonly CLUSTER_TIME_RECEIVED = CLUSTER_TIME_RECEIVED; + /** @event */ + static readonly CLOSE = CLOSE; + /** @event */ + static readonly MESSAGE = MESSAGE; + /** @event */ + static readonly PINNED = PINNED; + /** @event */ + static readonly UNPINNED = UNPINNED; + + constructor(stream: Stream, options: ConnectionOptions) { + super(); + + this.commandAsync = promisify( + ( + ns: MongoDBNamespace, + cmd: Document, + options: CommandOptions | undefined, + callback: Callback + ) => this.command(ns, cmd, options, callback as any) + ); + + this.id = options.id; + this.address = streamIdentifier(stream, options); + this.socketTimeoutMS = options.socketTimeoutMS ?? 0; + this.monitorCommands = options.monitorCommands; + this.serverApi = options.serverApi; + this.closed = false; + this[kHello] = null; + this[kClusterTime] = null; + + this[kDescription] = new StreamDescription(this.address, options); + this[kGeneration] = options.generation; + this[kLastUseTime] = now(); + + // setup parser stream and message handling + this[kQueue] = new Map(); + this[kMessageStream] = new MessageStream({ + ...options, + maxBsonMessageSize: this.hello?.maxBsonMessageSize + }); + this.socket = stream; + + this[kDelayedTimeoutId] = null; + + this[kMessageStream].on('message', message => this.onMessage(message)); + this[kMessageStream].on('error', error => this.onError(error)); + this.socket.on('close', () => this.onClose()); + this.socket.on('timeout', () => this.onTimeout()); + this.socket.on('error', () => { + /* ignore errors, listen to `close` instead */ + }); + + // hook the message stream up to the passed in stream + this.socket.pipe(this[kMessageStream]); + this[kMessageStream].pipe(this.socket); + } + + get description(): StreamDescription { + return this[kDescription]; + } + + get hello(): Document | null { + return this[kHello]; + } + + // the `connect` method stores the result of the handshake hello on the connection + set hello(response: Document | null) { + this[kDescription].receiveResponse(response); + this[kDescription] = Object.freeze(this[kDescription]); + + // TODO: remove this, and only use the `StreamDescription` in the future + this[kHello] = response; + } + + // Set the whether the message stream is for a monitoring connection. + set isMonitoringConnection(value: boolean) { + this[kMessageStream].isMonitoringConnection = value; + } + + get isMonitoringConnection(): boolean { + return this[kMessageStream].isMonitoringConnection; + } + + get serviceId(): ObjectId | undefined { + return this.hello?.serviceId; + } + + get loadBalanced(): boolean { + return this.description.loadBalanced; + } + + get generation(): number { + return this[kGeneration] || 0; + } + + set generation(generation: number) { + this[kGeneration] = generation; + } + + get idleTime(): number { + return calculateDurationInMs(this[kLastUseTime]); + } + + get clusterTime(): Document | null { + return this[kClusterTime]; + } + + get stream(): Stream { + return this.socket; + } + + get hasSessionSupport(): boolean { + return this.description.logicalSessionTimeoutMinutes != null; + } + + get supportsOpMsg(): boolean { + return ( + this.description != null && + maxWireVersion(this as any as Connection) >= 6 && + !this.description.__nodejs_mock_server__ + ); + } + + markAvailable(): void { + this[kLastUseTime] = now(); + } + + onError(error: Error) { + this.cleanup(true, error); + } + + onClose() { + const message = `connection ${this.id} to ${this.address} closed`; + this.cleanup(true, new MongoNetworkError(message)); + } + + onTimeout() { + this[kDelayedTimeoutId] = setTimeout(() => { + const message = `connection ${this.id} to ${this.address} timed out`; + const beforeHandshake = this.hello == null; + this.cleanup(true, new MongoNetworkTimeoutError(message, { beforeHandshake })); + }, 1).unref(); // No need for this timer to hold the event loop open + } + + onMessage(message: OpMsgResponse | OpQueryResponse) { + const delayedTimeoutId = this[kDelayedTimeoutId]; + if (delayedTimeoutId != null) { + clearTimeout(delayedTimeoutId); + this[kDelayedTimeoutId] = null; + } + + const socketTimeoutMS = this.socket.timeout ?? 0; + this.socket.setTimeout(0); + + // always emit the message, in case we are streaming + this.emit('message', message); + let operationDescription = this[kQueue].get(message.responseTo); + + if (!operationDescription && this.isMonitoringConnection) { + // This is how we recover when the initial hello's requestId is not + // the responseTo when hello responses have been skipped: + + // First check if the map is of invalid size + if (this[kQueue].size > 1) { + this.cleanup(true, new MongoRuntimeError(INVALID_QUEUE_SIZE)); + } else { + // Get the first orphaned operation description. + const entry = this[kQueue].entries().next(); + if (entry.value != null) { + const [requestId, orphaned]: [number, OperationDescription] = entry.value; + // If the orphaned operation description exists then set it. + operationDescription = orphaned; + // Remove the entry with the bad request id from the queue. + this[kQueue].delete(requestId); + } + } + } + + if (!operationDescription) { + return; + } + + const callback = operationDescription.cb; + + // SERVER-45775: For exhaust responses we should be able to use the same requestId to + // track response, however the server currently synthetically produces remote requests + // making the `responseTo` change on each response + this[kQueue].delete(message.responseTo); + if ('moreToCome' in message && message.moreToCome) { + // If the operation description check above does find an orphaned + // description and sets the operationDescription then this line will put one + // back in the queue with the correct requestId and will resolve not being able + // to find the next one via the responseTo of the next streaming hello. + this[kQueue].set(message.requestId, operationDescription); + this.socket.setTimeout(socketTimeoutMS); + } + + try { + // Pass in the entire description because it has BSON parsing options + message.parse(operationDescription); + } catch (err) { + // If this error is generated by our own code, it will already have the correct class applied + // if it is not, then it is coming from a catastrophic data parse failure or the BSON library + // in either case, it should not be wrapped + callback(err); + return; + } + + if (message.documents[0]) { + const document: Document = message.documents[0]; + const session = operationDescription.session; + if (session) { + updateSessionFromResponse(session, document); + } + + if (document.$clusterTime) { + this[kClusterTime] = document.$clusterTime; + this.emit(Connection.CLUSTER_TIME_RECEIVED, document.$clusterTime); + } + + if (document.writeConcernError) { + callback(new MongoWriteConcernError(document.writeConcernError, document), document); + return; + } + + if (document.ok === 0 || document.$err || document.errmsg || document.code) { + callback(new MongoServerError(document)); + return; + } + } + + callback(undefined, message.documents[0]); + } + + destroy(options: DestroyOptions, callback?: Callback): void { + if (this.closed) { + process.nextTick(() => callback?.()); + return; + } + if (typeof callback === 'function') { + this.once('close', () => process.nextTick(() => callback())); + } + + // load balanced mode requires that these listeners remain on the connection + // after cleanup on timeouts, errors or close so we remove them before calling + // cleanup. + this.removeAllListeners(Connection.PINNED); + this.removeAllListeners(Connection.UNPINNED); + const message = `connection ${this.id} to ${this.address} closed`; + this.cleanup(options.force, new MongoNetworkError(message)); + } + + /** + * A method that cleans up the connection. When `force` is true, this method + * forcibly destroys the socket. + * + * If an error is provided, any in-flight operations will be closed with the error. + * + * This method does nothing if the connection is already closed. + */ + private cleanup(force: boolean, error?: Error): void { + if (this.closed) { + return; + } + + this.closed = true; + + const completeCleanup = () => { + for (const op of this[kQueue].values()) { + op.cb(error); + } + + this[kQueue].clear(); + + this.emit(Connection.CLOSE); + }; + + this.socket.removeAllListeners(); + this[kMessageStream].removeAllListeners(); + + this[kMessageStream].destroy(); + + if (force) { + this.socket.destroy(); + completeCleanup(); + return; + } + + if (!this.socket.writableEnded) { + this.socket.end(() => { + this.socket.destroy(); + completeCleanup(); + }); + } else { + completeCleanup(); + } + } + + command( + ns: MongoDBNamespace, + command: Document, + options: CommandOptions | undefined, + callback: Callback + ): void { + let cmd = { ...command }; + + const readPreference = getReadPreference(options); + const session = options?.session; + + let clusterTime = this.clusterTime; + + if (this.serverApi) { + const { version, strict, deprecationErrors } = this.serverApi; + cmd.apiVersion = version; + if (strict != null) cmd.apiStrict = strict; + if (deprecationErrors != null) cmd.apiDeprecationErrors = deprecationErrors; + } + + if (this.hasSessionSupport && session) { + if ( + session.clusterTime && + clusterTime && + session.clusterTime.clusterTime.greaterThan(clusterTime.clusterTime) + ) { + clusterTime = session.clusterTime; + } + + const err = applySession(session, cmd, options); + if (err) { + return callback(err); + } + } else if (session?.explicit) { + return callback(new MongoCompatibilityError('Current topology does not support sessions')); + } + + // if we have a known cluster time, gossip it + if (clusterTime) { + cmd.$clusterTime = clusterTime; + } + + if ( + // @ts-expect-error ModernConnections cannot be passed as connections + isSharded(this) && + !this.supportsOpMsg && + readPreference && + readPreference.mode !== 'primary' + ) { + cmd = { + $query: cmd, + $readPreference: readPreference.toJSON() + }; + } + + const commandOptions: Document = Object.assign( + { + numberToSkip: 0, + numberToReturn: -1, + checkKeys: false, + // This value is not overridable + secondaryOk: readPreference.secondaryOk() + }, + options + ); + + const message = this.supportsOpMsg + ? new OpMsgRequest(ns.db, cmd, commandOptions) + : new OpQueryRequest(ns.db, cmd, commandOptions); + + try { + write(this as any as Connection, message, commandOptions, callback); + } catch (err) { + callback(err); + } + } +} + +const kDefaultMaxBsonMessageSize = 1024 * 1024 * 16 * 4; + +/** + * @internal + * + * This helper reads chucks of data out of a socket and buffers them until it has received a + * full wire protocol message. + * + * By itself, produces an infinite async generator of wire protocol messages and consumers must end + * the stream by calling `return` on the generator. + * + * Note that `for-await` loops call `return` automatically when the loop is exited. + */ +export async function* readWireProtocolMessages( + connection: ModernConnection +): AsyncGenerator { + const bufferPool = new BufferPool(); + const maxBsonMessageSize = connection.hello?.maxBsonMessageSize ?? kDefaultMaxBsonMessageSize; + for await (const [chunk] of on(connection.socket, 'data')) { + bufferPool.append(chunk); + const sizeOfMessage = bufferPool.getInt32(); + + if (sizeOfMessage == null) { + continue; + } + + if (sizeOfMessage < 0) { + throw new MongoParseError(`Invalid message size: ${sizeOfMessage}`); + } + + if (sizeOfMessage > maxBsonMessageSize) { + throw new MongoParseError( + `Invalid message size: ${sizeOfMessage}, max allowed: ${maxBsonMessageSize}` + ); + } + + if (sizeOfMessage > bufferPool.length) { + continue; + } + + yield bufferPool.read(sizeOfMessage); + } +} + +/** + * @internal + * + * Writes an OP_MSG or OP_QUERY request to the socket, optionally compressing the command. This method + * waits until the socket's buffer has emptied (the Nodejs socket `drain` event has fired). + */ +export async function writeCommand( + connection: ModernConnection, + command: WriteProtocolMessageType, + options: Partial> +): Promise { + const drained = once(connection.socket, 'drain'); + const finalCommand = + options.agreedCompressor === 'none' || !OpCompressedRequest.canCompress(command) + ? command + : new OpCompressedRequest(command, { + agreedCompressor: options.agreedCompressor ?? 'none', + zlibCompressionLevel: options.zlibCompressionLevel ?? 0 + }); + const buffer = Buffer.concat(await finalCommand.toBin()); + connection.socket.push(buffer); + await drained; +} + +/** + * @internal + * + * Returns an async generator that yields full wire protocol messages from the underlying socket. This function + * yields messages until `moreToCome` is false or not present in a response, or the caller cancels the request + * by calling `return` on the generator. + * + * Note that `for-await` loops call `return` automatically when the loop is exited. + */ +export async function* readMany( + connection: ModernConnection +): AsyncGenerator { + for await (const message of readWireProtocolMessages(connection)) { + const response = await decompressResponse(message); + yield response; + + if (!('moreToCome' in response) || !response.moreToCome) { + return; + } + } +} + +/** + * @internal + * + * Reads a single wire protocol message out of a connection. + */ +export async function read(connection: ModernConnection): Promise { + for await (const value of readMany(connection)) { + return value; + } + + throw new MongoRuntimeError('unable to read message off of connection'); +} diff --git a/src/cmap/message_stream.ts b/src/cmap/message_stream.ts index e90b34650f..42d16ae26e 100644 --- a/src/cmap/message_stream.ts +++ b/src/cmap/message_stream.ts @@ -5,19 +5,13 @@ import { MongoDecompressionError, MongoParseError } from '../error'; import type { ClientSession } from '../sessions'; import { BufferPool, type Callback } from '../utils'; import { - BinMsg, type MessageHeader, - Msg, - Response, + OpCompressedRequest, + OpMsgResponse, + OpQueryResponse, type WriteProtocolMessageType } from './commands'; -import { - compress, - Compressor, - type CompressorName, - decompress, - uncompressibleCommands -} from './wire_protocol/compression'; +import { compress, Compressor, type CompressorName, decompress } from './wire_protocol/compression'; import { OP_COMPRESSED, OP_MSG } from './wire_protocol/constants'; const MESSAGE_HEADER_SIZE = 16; @@ -85,7 +79,7 @@ export class MessageStream extends Duplex { operationDescription: OperationDescription ): void { const agreedCompressor = operationDescription.agreedCompressor ?? 'none'; - if (agreedCompressor === 'none' || !canCompress(command)) { + if (agreedCompressor === 'none' || !OpCompressedRequest.canCompress(command)) { const data = command.toBin(); this.push(Array.isArray(data) ? Buffer.concat(data) : data); return; @@ -128,14 +122,6 @@ export class MessageStream extends Duplex { } } -// Return whether a command contains an uncompressible command term -// Will return true if command contains no uncompressible command terms -function canCompress(command: WriteProtocolMessageType) { - const commandDoc = command instanceof Msg ? command.command : command.query; - const commandName = Object.keys(commandDoc)[0]; - return !uncompressibleCommands.has(commandName); -} - function processIncomingData(stream: MessageStream, callback: Callback): void { const buffer = stream[kBuffer]; const sizeOfMessage = buffer.getInt32(); @@ -179,7 +165,7 @@ function processIncomingData(stream: MessageStream, callback: Callback): return false; }; - let ResponseType = messageHeader.opCode === OP_MSG ? BinMsg : Response; + let ResponseType = messageHeader.opCode === OP_MSG ? OpMsgResponse : OpQueryResponse; if (messageHeader.opCode !== OP_COMPRESSED) { const messageBody = message.subarray(MESSAGE_HEADER_SIZE); @@ -205,7 +191,7 @@ function processIncomingData(stream: MessageStream, callback: Callback): const compressedBuffer = message.slice(MESSAGE_HEADER_SIZE + 9); // recalculate based on wrapped opcode - ResponseType = messageHeader.opCode === OP_MSG ? BinMsg : Response; + ResponseType = messageHeader.opCode === OP_MSG ? OpMsgResponse : OpQueryResponse; decompress(compressorID, compressedBuffer).then( messageBody => { if (messageBody.length !== messageHeader.length) { diff --git a/src/cmap/wire_protocol/compression.ts b/src/cmap/wire_protocol/compression.ts index 6e55268c54..74cca5da5f 100644 --- a/src/cmap/wire_protocol/compression.ts +++ b/src/cmap/wire_protocol/compression.ts @@ -4,6 +4,15 @@ import * as zlib from 'zlib'; import { LEGACY_HELLO_COMMAND } from '../../constants'; import { getSnappy, getZstdLibrary, type SnappyLib, type ZStandard } from '../../deps'; import { MongoDecompressionError, MongoInvalidArgumentError } from '../../error'; +import { + type MessageHeader, + OpCompressedRequest, + OpMsgResponse, + OpQueryResponse, + type WriteProtocolMessageType +} from '../commands'; +import { type OperationDescription } from '../message_stream'; +import { OP_COMPRESSED, OP_MSG } from './constants'; /** @public */ export const Compressor = Object.freeze({ @@ -124,3 +133,67 @@ function loadZstd() { zstd = getZstdLibrary(); } } + +const MESSAGE_HEADER_SIZE = 16; + +/** + * @internal + * + * Compresses an OP_MSG or OP_QUERY message, if compression is configured. This method + * also serializes the command to BSON. + */ +export async function compressCommand( + command: WriteProtocolMessageType, + description: OperationDescription +): Promise { + const finalCommand = + description.agreedCompressor === 'none' || !OpCompressedRequest.canCompress(command) + ? command + : new OpCompressedRequest(command, { + agreedCompressor: description.agreedCompressor ?? 'none', + zlibCompressionLevel: description.zlibCompressionLevel ?? 0 + }); + const data = await finalCommand.toBin(); + return Buffer.concat(data); +} + +/** + * @internal + * + * Decompresses an OP_MSG or OP_QUERY response from the server, if compression is configured. + * + * This method does not parse the response's BSON. + */ +export async function decompressResponse( + message: Buffer +): Promise { + const messageHeader: MessageHeader = { + length: message.readInt32LE(0), + requestId: message.readInt32LE(4), + responseTo: message.readInt32LE(8), + opCode: message.readInt32LE(12) + }; + + if (messageHeader.opCode !== OP_COMPRESSED) { + const ResponseType = messageHeader.opCode === OP_MSG ? OpMsgResponse : OpQueryResponse; + const messageBody = message.subarray(MESSAGE_HEADER_SIZE); + return new ResponseType(message, messageHeader, messageBody); + } + + const header: MessageHeader = { + ...messageHeader, + fromCompressed: true, + opCode: message.readInt32LE(MESSAGE_HEADER_SIZE), + length: message.readInt32LE(MESSAGE_HEADER_SIZE + 4) + }; + const compressorID = message[MESSAGE_HEADER_SIZE + 8]; + const compressedBuffer = message.slice(MESSAGE_HEADER_SIZE + 9); + + // recalculate based on wrapped opcode + const ResponseType = header.opCode === OP_MSG ? OpMsgResponse : OpQueryResponse; + const messageBody = await decompress(compressorID, compressedBuffer); + if (messageBody.length !== header.length) { + throw new MongoDecompressionError('Message body and message header must be the same length'); + } + return new ResponseType(message, header, messageBody); +} diff --git a/src/index.ts b/src/index.ts index 280a6e829a..0c013a3b3b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -250,14 +250,15 @@ export type { OIDCRequestFunction } from './cmap/auth/mongodb_oidc'; export type { - BinMsg, MessageHeader, - Msg, + OpCompressedRequest, OpMsgOptions, + OpMsgRequest, + OpMsgResponse, OpQueryOptions, + OpQueryRequest, + OpQueryResponse, OpResponseOptions, - Query, - Response, WriteProtocolMessageType } from './cmap/commands'; export type { LEGAL_TCP_SOCKET_OPTIONS, LEGAL_TLS_SOCKET_OPTIONS, Stream } from './cmap/connect'; @@ -267,7 +268,11 @@ export type { ConnectionEvents, ConnectionOptions, DestroyOptions, - ProxyOptions + ModernConnection, + ProxyOptions, + read, + readMany, + writeCommand } from './cmap/connection'; export type { CloseOptions, diff --git a/test/unit/cmap/command_monitoring_events.test.js b/test/unit/cmap/command_monitoring_events.test.js index 7790d1e56b..08d06b0ed2 100644 --- a/test/unit/cmap/command_monitoring_events.test.js +++ b/test/unit/cmap/command_monitoring_events.test.js @@ -1,15 +1,15 @@ 'use strict'; -const { Msg, Query } = require('../../mongodb'); +const { OpQueryRequest, OpMsgRequest } = require('../../mongodb'); const { CommandStartedEvent } = require('../../mongodb'); const { expect } = require('chai'); describe('Command Monitoring Events - unit/cmap', function () { const commands = [ - new Query('admin', { a: { b: 10 }, $query: { b: 10 } }, {}), - new Query('hello', { a: { b: 10 }, $query: { b: 10 } }, {}), - new Msg('admin', { b: { c: 20 } }, {}), - new Msg('hello', { b: { c: 20 } }, {}), + new OpQueryRequest('admin', { a: { b: 10 }, $query: { b: 10 } }, {}), + new OpQueryRequest('hello', { a: { b: 10 }, $query: { b: 10 } }, {}), + new OpMsgRequest('admin', { b: { c: 20 } }, {}), + new OpMsgRequest('hello', { b: { c: 20 } }, {}), { ns: 'admin.$cmd', query: { $query: { a: 16 } } }, { ns: 'hello there', f1: { h: { a: 52, b: { c: 10, d: [1, 2, 3, 5] } } } } ]; @@ -17,7 +17,7 @@ describe('Command Monitoring Events - unit/cmap', function () { for (const command of commands) { it(`should make a deep copy of object of type: ${command.constructor.name}`, () => { const ev = new CommandStartedEvent({ id: 'someId', address: 'someHost' }, command); - if (command instanceof Query) { + if (command instanceof OpQueryRequest) { if (command.ns === 'admin.$cmd') { expect(ev.command !== command.query.$query).to.equal(true); for (const k in command.query.$query) { @@ -29,7 +29,7 @@ describe('Command Monitoring Events - unit/cmap', function () { expect(ev.command.filter[k]).to.deep.equal(command.query.$query[k]); } } - } else if (command instanceof Msg) { + } else if (command instanceof OpMsgRequest) { expect(ev.command !== command.command).to.equal(true); expect(ev.command).to.deep.equal(command.command); } else if (typeof command === 'object') { @@ -48,7 +48,7 @@ describe('Command Monitoring Events - unit/cmap', function () { it('should wrap a basic query option', function () { const db = 'test1'; - const query = new Query( + const query = new OpQueryRequest( `${db}`, { testCmd: 1, @@ -68,7 +68,7 @@ describe('Command Monitoring Events - unit/cmap', function () { it('should upconvert a Query wrapping a command into the corresponding command', function () { const db = 'admin'; - const query = new Query( + const query = new OpQueryRequest( `${db}`, { $query: { diff --git a/test/unit/cmap/commands.test.js b/test/unit/cmap/commands.test.js index 33161747ea..496a4b5e30 100644 --- a/test/unit/cmap/commands.test.js +++ b/test/unit/cmap/commands.test.js @@ -1,5 +1,5 @@ const { expect } = require('chai'); -const { Response } = require('../../mongodb'); +const { OpQueryResponse } = require('../../mongodb'); describe('commands', function () { describe('Response', function () { @@ -16,7 +16,7 @@ describe('commands', function () { const body = Buffer.from([]); it('throws an exception', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(() => response.parse()).to.throw(RangeError, /outside buffer bounds/); }); }); @@ -33,7 +33,7 @@ describe('commands', function () { body.writeInt32LE(-1, 16); it('throws an exception', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(() => response.parse()).to.throw(RangeError, /Invalid array length/); }); }); @@ -54,7 +54,7 @@ describe('commands', function () { it('does not throw an exception', function () { let error; try { - new Response(message, header, body); + new OpQueryResponse(message, header, body); } catch (err) { error = err; } @@ -62,47 +62,47 @@ describe('commands', function () { }); it('initializes the documents to an empty array', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.documents).to.be.empty; }); it('does not set the responseFlags', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.responseFlags).to.be.undefined; }); it('does not set the cursorNotFound flag', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.cursorNotFound).to.be.undefined; }); it('does not set the cursorId', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.cursorId).to.be.undefined; }); it('does not set startingFrom', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.startingFrom).to.be.undefined; }); it('does not set numberReturned', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.numberReturned).to.be.undefined; }); it('does not set queryFailure', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.queryFailure).to.be.undefined; }); it('does not set shardConfigStale', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.shardConfigStale).to.be.undefined; }); it('does not set awaitCapable', function () { - const response = new Response(message, header, body); + const response = new OpQueryResponse(message, header, body); expect(response.awaitCapable).to.be.undefined; }); }); diff --git a/test/unit/cmap/connection.test.ts b/test/unit/cmap/connection.test.ts index 1a8ada7af6..9c0de2e542 100644 --- a/test/unit/cmap/connection.test.ts +++ b/test/unit/cmap/connection.test.ts @@ -7,21 +7,22 @@ import { setTimeout } from 'timers'; import { promisify } from 'util'; import { - BinMsg, type ClientMetadata, connect, Connection, hasSessionSupport, type HostAddress, isHello, + type MessageHeader, MessageStream, MongoNetworkError, MongoNetworkTimeoutError, MongoRuntimeError, - Msg, ns, type OperationDescription, - Query + OpMsgRequest, + OpMsgResponse, + OpQueryRequest } from '../../mongodb'; import * as mock from '../../tools/mongodb-mock/index'; import { generateOpMsgBuffer, getSymbolFrom } from '../../tools/utils'; @@ -311,7 +312,7 @@ describe('new Connection()', function () { }; const msgBody = msg.subarray(16); - const message = new BinMsg(msg, msgHeader, msgBody); + const message = new OpMsgResponse(msg, msgHeader, msgBody); connection.onMessage(message); }); @@ -351,7 +352,7 @@ describe('new Connection()', function () { }; const msgBody = msg.subarray(16); - const message = new BinMsg(msg, msgHeader, msgBody); + const message = new OpMsgResponse(msg, msgHeader, msgBody); connection.onMessage(message); }); @@ -381,7 +382,7 @@ describe('new Connection()', function () { }; const msgBody = msg.subarray(16); - const message = new BinMsg(msg, msgHeader, msgBody); + const message = new OpMsgResponse(msg, msgHeader, msgBody); expect(() => { connection.onMessage(message); }).to.not.throw(); @@ -426,7 +427,7 @@ describe('new Connection()', function () { }; const msgBody = msg.subarray(16); - const message = new BinMsg(msg, msgHeader, msgBody); + const message = new OpMsgResponse(msg, msgHeader, msgBody); connection.onMessage(message); }); @@ -507,7 +508,7 @@ describe('new Connection()', function () { }; const msgBody = msg.subarray(16); msgBody.writeInt32LE(0, 0); // OPTS_MORE_TO_COME - connection.onMessage(new BinMsg(msg, msgHeader, msgBody)); + connection.onMessage(new OpMsgResponse(msg, msgHeader, msgBody)); // timeout is still reset expect(connection.stream).to.have.property('timeout', 0); }); @@ -524,7 +525,7 @@ describe('new Connection()', function () { const msgBody = msg.subarray(16); msgBody.writeInt32LE(2, 0); // OPTS_MORE_TO_COME connection[getSymbolFrom(connection, 'queue')].set(0, { cb: () => null }); - connection.onMessage(new BinMsg(msg, msgHeader, msgBody)); + connection.onMessage(new OpMsgResponse(msg, msgHeader, msgBody)); // timeout is still set expect(connection.stream).to.have.property('timeout', 1); }); @@ -1079,7 +1080,7 @@ describe('new Connection()', function () { } expect(writeCommandSpy).to.have.been.called; - expect(writeCommandSpy.firstCall.args[0] instanceof Msg).to.equal(true); + expect(writeCommandSpy.firstCall.args[0] instanceof OpMsgRequest).to.equal(true); }); }); @@ -1131,7 +1132,7 @@ describe('new Connection()', function () { } expect(writeCommandSpy).to.have.been.called; - expect(writeCommandSpy.firstCall.args[0] instanceof Query).to.equal(true); + expect(writeCommandSpy.firstCall.args[0] instanceof OpQueryRequest).to.equal(true); }); }); }); diff --git a/test/unit/cmap/message_stream.test.ts b/test/unit/cmap/message_stream.test.ts index c6bc8f1660..3887d1a4f1 100644 --- a/test/unit/cmap/message_stream.test.ts +++ b/test/unit/cmap/message_stream.test.ts @@ -2,7 +2,7 @@ import { expect } from 'chai'; import { on, once } from 'events'; import { Readable, Writable } from 'stream'; -import { LEGACY_HELLO_COMMAND, MessageStream, Msg } from '../../mongodb'; +import { LEGACY_HELLO_COMMAND, MessageStream, OpMsgRequest } from '../../mongodb'; import { bufferToStream, generateOpMsgBuffer } from '../../tools/utils'; describe('MessageStream', function () { @@ -139,7 +139,7 @@ describe('MessageStream', function () { const messageStream = new MessageStream(); messageStream.pipe(writeableStream); - const command = new Msg('admin', { [LEGACY_HELLO_COMMAND]: 1 }, { requestId: 3 }); + const command = new OpMsgRequest('admin', { [LEGACY_HELLO_COMMAND]: 1 }, { requestId: 3 }); messageStream.writeCommand(command, { started: 0, command: true, diff --git a/test/unit/cmap/modern_connection.test.ts b/test/unit/cmap/modern_connection.test.ts new file mode 100644 index 0000000000..c4405bcffb --- /dev/null +++ b/test/unit/cmap/modern_connection.test.ts @@ -0,0 +1,425 @@ +import { expect } from 'chai'; +import * as sinon from 'sinon'; +import { EventEmitter } from 'stream'; +import { setTimeout } from 'timers/promises'; + +// eslint-disable-next-line @typescript-eslint/no-restricted-imports +import * as compression from '../../../src/cmap/wire_protocol/compression'; +import { + decompressResponse, + LEGACY_HELLO_COMMAND, + MongoDecompressionError, + MongoParseError, + OP_COMPRESSED, + OP_MSG, + OpCompressedRequest, + OpMsgRequest, + OpMsgResponse, + type OpQueryResponse, + read, + readMany, + writeCommand +} from '../../mongodb'; + +class MockSocket extends EventEmitter { + buffer: Buffer[] = []; + push(...args: Buffer[]) { + this.buffer.push(...args); + } +} + +class MockModernConnection { + socket = new MockSocket(); +} + +describe('writeCommand', () => { + context('when compression is disabled', () => { + it('pushes an uncompressed command into the socket buffer', async () => { + const command = new OpMsgRequest('db', { find: 1 }, { requestId: 1 }); + const connection = new MockModernConnection(); + const prom = writeCommand(connection as any, command, { + agreedCompressor: 'none' + }); + + connection.socket.emit('drain'); + await prom; + + const [buffer] = connection.socket.buffer; + expect(buffer).to.exist; + const opCode = buffer.readInt32LE(12); + + expect(opCode).to.equal(OP_MSG); + }); + }); + + context('when compression is enabled', () => { + context('when the command is compressible', () => { + it('pushes a compressed command into the socket buffer', async () => { + const command = new OpMsgRequest('db', { find: 1 }, { requestId: 1 }); + const connection = new MockModernConnection(); + const prom = writeCommand(connection as any, command, { + agreedCompressor: 'snappy' + }); + + connection.socket.emit('drain'); + await prom; + + const [buffer] = connection.socket.buffer; + expect(buffer).to.exist; + const opCode = buffer.readInt32LE(12); + + expect(opCode).to.equal(OP_COMPRESSED); + }); + }); + context('when the command is not compressible', () => { + it('pushes an uncompressed command into the socket buffer', async () => { + const command = new OpMsgRequest('db', { [LEGACY_HELLO_COMMAND]: 1 }, { requestId: 1 }); + const connection = new MockModernConnection(); + const prom = writeCommand(connection as any, command, { + agreedCompressor: 'snappy' + }); + + connection.socket.emit('drain'); + await prom; + + const [buffer] = connection.socket.buffer; + expect(buffer).to.exist; + const opCode = buffer.readInt32LE(12); + + expect(opCode).to.equal(OP_MSG); + }); + }); + }); + context('when a `drain` event is not emitted from the underlying socket', () => { + it('never resolves', async () => { + const connection = new MockModernConnection(); + const promise = writeCommand(connection, new OpMsgRequest('db', { ping: 1 }, {}), { + agreedCompressor: 'none' + }); + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect(result).to.equal('timeout'); + }); + }); + + context('when a `drain` event is emitted from the underlying socket', () => { + it('resolves', async () => { + const connection = new MockModernConnection(); + const promise = writeCommand(connection, new OpMsgRequest('db', { ping: 1 }, {}), { + agreedCompressor: 'none' + }); + connection.socket.emit('drain'); + const result = await Promise.race([promise, setTimeout(5000, 'timeout', { ref: false })]); + expect(result).to.be.undefined; + }); + }); +}); + +describe('decompressResponse()', () => { + context('when the message is not compressed', () => { + let message: Buffer; + let response: OpMsgResponse | OpQueryResponse; + let spy; + beforeEach(async () => { + message = Buffer.concat(new OpMsgRequest('db', { find: 1 }, { requestId: 1 }).toBin()); + spy = sinon.spy(compression, 'decompress'); + + response = await decompressResponse(message); + }); + afterEach(() => sinon.restore()); + it('returns a wire protocol message', () => { + expect(response).to.be.instanceOf(OpMsgResponse); + }); + it('does not attempt decompression', () => { + expect(spy).not.to.have.been.called; + }); + }); + + context('when the message is compressed', () => { + let message: Buffer; + let response: OpMsgResponse | OpQueryResponse; + beforeEach(async () => { + const msg = new OpMsgRequest('db', { find: 1 }, { requestId: 1 }); + message = Buffer.concat( + await new OpCompressedRequest(msg, { + zlibCompressionLevel: 0, + agreedCompressor: 'snappy' + }).toBin() + ); + + response = await decompressResponse(message); + }); + + it('returns a wire protocol message', () => { + expect(response).to.be.instanceOf(OpMsgResponse); + }); + it('correctly decompresses the message', () => { + response.parse({}); + expect(response.documents[0]).to.deep.equal({ $db: 'db', find: 1 }); + }); + + context( + 'when the compressed message does not match the compression metadata in the header', + () => { + beforeEach(async () => { + const msg = new OpMsgRequest('db', { find: 1 }, { requestId: 1 }); + message = Buffer.concat( + await new OpCompressedRequest(msg, { + zlibCompressionLevel: 0, + agreedCompressor: 'snappy' + }).toBin() + ); + message.writeInt32LE( + 100, + 16 + 4 // message header size + offset to length + ); // write an invalid message length into the header + }); + it('throws a MongoDecompressionError', async () => { + const error = await decompressResponse(message).catch(e => e); + expect(error).to.be.instanceOf(MongoDecompressionError); + }); + } + ); + }); +}); + +describe('read()', () => { + let connection: MockModernConnection; + let message: Buffer; + + beforeEach(() => { + connection = new MockModernConnection(); + message = Buffer.concat(new OpMsgRequest('db', { ping: 1 }, { requestId: 1 }).toBin()); + }); + it('does not resolve if there are no data events', async () => { + const promise = read(connection); + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect(result).to.equal('timeout'); + }); + + it('does not resolve until there is a complete message', async () => { + const promise = read(connection); + { + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect(result, 'received data on empty socket').to.equal('timeout'); + } + + { + connection.socket.emit('data', message.slice(0, 10)); + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect( + result, + 'received data when only part of message was emitted from the socket' + ).to.equal('timeout'); + } + + { + connection.socket.emit('data', message.slice(10)); + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect(result, 'expected OpMsgResponse - got timeout instead').to.be.instanceOf( + OpMsgResponse + ); + } + }); + + it('removes all event listeners from the socket after a message is received', async () => { + const promise = read(connection); + + connection.socket.emit('data', message); + await promise; + + expect(connection.socket.listenerCount('data')).to.equal(0); + }); + + it('when `moreToCome` is set in the response, it only returns one message', async () => { + message = Buffer.concat( + new OpMsgRequest('db', { ping: 1 }, { requestId: 1, moreToCome: true }).toBin() + ); + + const promise = read(connection); + + connection.socket.emit('data', message); + await promise; + + expect(connection.socket.listenerCount('data')).to.equal(0); + }); + + context('when reading an invalid message', () => { + context('when the message < 0', () => { + it('throws a mongo parse error', async () => { + message.writeInt32LE(-1); + const promise = read(connection).catch(e => e); + + connection.socket.emit('data', message); + const error = await promise; + expect(error).to.be.instanceof(MongoParseError); + }); + }); + + context('when the message length > max bson message size', () => { + it('throws a mongo parse error', async () => { + message.writeInt32LE(1024 * 1024 * 16 * 4 + 1); + const promise = read(connection).catch(e => e); + + connection.socket.emit('data', message); + const error = await promise; + expect(error).to.be.instanceof(MongoParseError); + }); + }); + }); + + context('when compression is enabled', () => { + it('returns a decompressed message', async () => { + const message = Buffer.concat( + await new OpCompressedRequest( + new OpMsgRequest('db', { ping: 1 }, { requestId: 1, moreToCome: true }), + { zlibCompressionLevel: 0, agreedCompressor: 'snappy' } + ).toBin() + ); + + const promise = read(connection); + + connection.socket.emit('data', message); + const result = await promise; + + expect(result).to.be.instanceOf(OpMsgResponse); + }); + }); +}); + +describe('readMany()', () => { + let connection: MockModernConnection; + let message: Buffer; + + beforeEach(() => { + connection = new MockModernConnection(); + message = Buffer.concat(new OpMsgRequest('db', { ping: 1 }, { requestId: 1 }).toBin()); + }); + it('does not resolve if there are no data events', async () => { + const generator = readMany(connection); + const result = await Promise.race([ + generator.next(), + setTimeout(1000, 'timeout', { ref: false }) + ]); + expect(result).to.equal('timeout'); + }); + + it('does not resolve until there is a complete message', async () => { + const generator = readMany(connection); + const promise = generator.next(); + { + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect(result, 'received data on empty socket').to.equal('timeout'); + } + + { + connection.socket.emit('data', message.slice(0, 10)); + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect( + result, + 'received data when only part of message was emitted from the socket' + ).to.equal('timeout'); + } + + { + connection.socket.emit('data', message.slice(10)); + const result = await Promise.race([promise, setTimeout(1000, 'timeout', { ref: false })]); + expect(result.value, 'expected OpMsgResponse - got timeout instead').to.be.instanceOf( + OpMsgResponse + ); + } + }); + + it('when moreToCome is set, it does not remove `data` listeners after receiving a message', async () => { + const generator = readMany(connection); + const promise = generator.next(); + message = Buffer.concat( + new OpMsgRequest('db', { ping: 1 }, { requestId: 1, moreToCome: true }).toBin() + ); + connection.socket.emit('data', message); + + const { value: response } = await promise; + + expect(response).to.be.instanceOf(OpMsgResponse); + expect(connection.socket.listenerCount('data')).to.equal(1); + }); + + it('returns messages until `moreToCome` is false', async () => { + const generator = readMany(connection); + + for ( + let i = 0, + message = Buffer.concat( + new OpMsgRequest('db', { ping: 1 }, { requestId: 1, moreToCome: true }).toBin() + ); + i < 3; + ++i + ) { + const promise = generator.next(); + connection.socket.emit('data', message); + const { value: response } = await promise; + expect(response, `response ${i} was not OpMsgResponse`).to.be.instanceOf(OpMsgResponse); + expect( + connection.socket.listenerCount('data'), + `listener count for ${i} was non-zero` + ).to.equal(1); + } + + const message = Buffer.concat( + new OpMsgRequest('db', { ping: 1 }, { requestId: 1, moreToCome: false }).toBin() + ); + const promise = generator.next(); + connection.socket.emit('data', message); + const { value: response } = await promise; + expect(response, `response was not OpMsgResponse`).to.be.instanceOf(OpMsgResponse); + expect(connection.socket.listenerCount('data')).to.equal(1); + + await generator.next(); + expect(connection.socket.listenerCount('data')).to.equal(0); + }); + + context('when reading an invalid message', () => { + context('when the message < 0', () => { + it('throws a mongo parse error', async () => { + message.writeInt32LE(-1); + const promise = readMany(connection) + .next() + .catch(e => e); + + connection.socket.emit('data', message); + const error = await promise; + expect(error).to.be.instanceof(MongoParseError); + }); + }); + + context('when the message length > max bson message size', () => { + it('throws a mongo parse error', async () => { + message.writeInt32LE(1024 * 1024 * 16 * 4 + 1); + const promise = readMany(connection) + .next() + .catch(e => e); + + connection.socket.emit('data', message); + const error = await promise; + expect(error).to.be.instanceof(MongoParseError); + }); + }); + }); + + context('when compression is enabled', () => { + it('returns a decompressed message', async () => { + const message = Buffer.concat( + await new OpCompressedRequest(new OpMsgRequest('db', { ping: 1 }, { requestId: 1 }), { + zlibCompressionLevel: 0, + agreedCompressor: 'snappy' + }).toBin() + ); + + const generator = readMany(connection); + const promise = generator.next(); + connection.socket.emit('data', message); + const { value: response } = await promise; + + expect(response).to.be.instanceOf(OpMsgResponse); + }); + }); +}); diff --git a/test/unit/commands.test.ts b/test/unit/commands.test.ts index c4c5738f39..80bc75877e 100644 --- a/test/unit/commands.test.ts +++ b/test/unit/commands.test.ts @@ -1,8 +1,21 @@ -import { BSONError } from 'bson'; +import { BSONError, deserialize } from 'bson'; import { expect } from 'chai'; - -import * as BSON from '../mongodb'; -import { BinMsg, type MessageHeader } from '../mongodb'; +import * as sinon from 'sinon'; + +// eslint-disable-next-line @typescript-eslint/no-restricted-imports +import * as compression from '../../src/cmap/wire_protocol/compression'; +import { + compress, + Compressor, + type MessageHeader, + OP_MSG, + OP_QUERY, + OpCompressedRequest, + OpMsgRequest, + OpMsgResponse, + OpQueryRequest, + uncompressibleCommands +} from '../mongodb'; const msgHeader: MessageHeader = { length: 735, @@ -45,12 +58,12 @@ describe('BinMsg BSON utf8 validation', () => { // this is a sanity check to make sure nothing unexpected is happening in the deserialize method itself const options = { validation: { utf8: { writeErrors: false } as const } }; - const deserializerCall = () => BSON.deserialize(invalidUtf8ErrorMsgDeserializeInput, options); + const deserializerCall = () => deserialize(invalidUtf8ErrorMsgDeserializeInput, options); expect(deserializerCall()).to.deep.equals(invalidUtf8InWriteErrorsJSON); }); context('when enableUtf8Validation option is not specified', () => { - const binMsgInvalidUtf8ErrorMsg = new BinMsg( + const binMsgInvalidUtf8ErrorMsg = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyInvalidUtf8WriteErrors @@ -62,7 +75,7 @@ describe('BinMsg BSON utf8 validation', () => { }); it('validates keys other than the writeErrors key', () => { - const binMsgAnotherKeyWithInvalidUtf8 = new BinMsg( + const binMsgAnotherKeyWithInvalidUtf8 = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyNKeyWithInvalidUtf8 @@ -75,7 +88,7 @@ describe('BinMsg BSON utf8 validation', () => { }); context('when validation is disabled', () => { - const binMsgInvalidUtf8ErrorMsg = new BinMsg( + const binMsgInvalidUtf8ErrorMsg = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyInvalidUtf8WriteErrors @@ -87,7 +100,7 @@ describe('BinMsg BSON utf8 validation', () => { }); it('does not validate keys other than the writeErrors key', () => { - const binMsgAnotherKeyWithInvalidUtf8 = new BinMsg( + const binMsgAnotherKeyWithInvalidUtf8 = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyNKeyWithInvalidUtf8 @@ -100,7 +113,7 @@ describe('BinMsg BSON utf8 validation', () => { }); it('disables validation by default for writeErrors if no validation specified', () => { - const binMsgInvalidUtf8ErrorMsg = new BinMsg( + const binMsgInvalidUtf8ErrorMsg = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyInvalidUtf8WriteErrors @@ -118,7 +131,7 @@ describe('BinMsg BSON utf8 validation', () => { context('utf8 validation enabled', () => { const options = { enableUtf8Validation: true }; it('does not validate the writeErrors key', () => { - const binMsgInvalidUtf8ErrorMsg = new BinMsg( + const binMsgInvalidUtf8ErrorMsg = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyInvalidUtf8WriteErrors @@ -131,7 +144,7 @@ describe('BinMsg BSON utf8 validation', () => { }); it('validates keys other than the writeErrors key', () => { - const binMsgAnotherKeyWithInvalidUtf8 = new BinMsg( + const binMsgAnotherKeyWithInvalidUtf8 = new OpMsgResponse( Buffer.alloc(0), msgHeader, msgBodyNKeyWithInvalidUtf8 @@ -143,3 +156,99 @@ describe('BinMsg BSON utf8 validation', () => { }); }); }); + +describe('class OpCompressedRequest', () => { + context('canCompress()', () => { + for (const command of uncompressibleCommands) { + it(`returns true when the command is ${command}`, () => { + const msg = new OpMsgRequest('db', { [command]: 1 }, {}); + expect(OpCompressedRequest.canCompress(msg)).to.be.false; + }); + } + + it(`returns true for a compressable command`, () => { + const msg = new OpMsgRequest('db', { find: 1 }, {}); + expect(OpCompressedRequest.canCompress(msg)).to.be.true; + }); + }); + + context('toBin()', async () => { + for (const protocol of [OpMsgRequest, OpQueryRequest]) { + context(`when ${protocol.name} is used`, () => { + let msg; + const serializedFindCommand = Buffer.concat( + new protocol('db', { find: 1 }, { requestId: 1 }).toBin() + ); + let expectedCompressedCommand; + let compressedCommand; + + beforeEach(async () => { + msg = new protocol('db', { find: 1 }, { requestId: 1 }); + expectedCompressedCommand = await compress( + { agreedCompressor: 'snappy', zlibCompressionLevel: 0 }, + serializedFindCommand.slice(16) + ); + compressedCommand = await new OpCompressedRequest(msg, { + agreedCompressor: 'snappy', + zlibCompressionLevel: 0 + }).toBin(); + }); + afterEach(() => sinon.restore()); + + it('returns an array of buffers', async () => { + expect(compressedCommand).to.be.a('array'); + expect(compressedCommand).to.have.lengthOf(3); + }); + + it('constructs a new message header for the request', async () => { + const messageHeader = compressedCommand[0]; + expect(messageHeader.byteLength, 'message header is incorrect length').to.equal(16); + expect( + messageHeader.readInt32LE(), + 'message header reports incorrect message length' + ).to.equal(16 + 9 + expectedCompressedCommand.length); + expect(messageHeader.readInt32LE(4), 'requestId incorrect').to.equal(1); + expect(messageHeader.readInt32LE(8), 'responseTo incorrect').to.equal(0); + expect(messageHeader.readInt32LE(12), 'opcode is not OP_COMPRESSED').to.equal(2012); + }); + + it('constructs the compression details for the request', async () => { + const compressionDetails = compressedCommand[1]; + expect(compressionDetails.byteLength, 'incorrect length').to.equal(9); + expect(compressionDetails.readInt32LE(), 'op code incorrect').to.equal( + protocol === OpMsgRequest ? OP_MSG : OP_QUERY + ); + expect( + compressionDetails.readInt32LE(4), + 'uncompressed message length incorrect' + ).to.equal(serializedFindCommand.length - 16); + expect(compressionDetails.readUint8(8), 'compressor incorrect').to.equal( + Compressor['snappy'] + ); + }); + + it('compresses the command', async () => { + const compressedMessage = compressedCommand[2]; + expect(compressedMessage).to.deep.equal(expectedCompressedCommand); + }); + + it('respects the zlib compression level', async () => { + const spy = sinon.spy(compression, 'compress'); + const [messageHeader] = await new OpCompressedRequest(msg, { + agreedCompressor: 'snappy', + zlibCompressionLevel: 3 + }).toBin(); + + expect(messageHeader.readInt32LE(12), 'opcode is not OP_COMPRESSED').to.equal(2012); + + expect(spy).to.have.been.called; + + expect(spy.args[0][0]).to.deep.equal({ + agreedCompressor: 'snappy', + zlibCompressionLevel: 3 + }); + }); + }); + } + }); +}); diff --git a/test/unit/index.test.ts b/test/unit/index.test.ts index d44f97f9ff..d826dbc52d 100644 --- a/test/unit/index.test.ts +++ b/test/unit/index.test.ts @@ -3,7 +3,7 @@ import { expect } from 'chai'; // Exception to the import from mongodb rule we're unit testing our public API // eslint-disable-next-line @typescript-eslint/no-restricted-imports import * as mongodb from '../../src/index'; -import { alphabetically, sorted } from '../tools/utils'; +import { setDifference } from '../mongodb'; /** * TS-NODE Adds these keys but they are undefined, they are not present when you import from lib @@ -133,10 +133,12 @@ const EXPECTED_EXPORTS = [ ]; describe('mongodb entrypoint', () => { - it('should export all and only the expected keys in expected_exports', () => { - expect(sorted(Object.keys(mongodb), alphabetically)).to.deep.equal( - sorted(EXPECTED_EXPORTS, alphabetically) - ); + it('exports all the expected keys', () => { + expect(setDifference(EXPECTED_EXPORTS, Object.keys(mongodb))).to.be.empty; + }); + + it('exports only the expected keys', () => { + expect(setDifference(Object.keys(mongodb), EXPECTED_EXPORTS)).to.be.empty; }); it('should export keys added by ts-node as undefined', () => {