diff --git a/Sources/SwiftBSON/BSONDocument.swift b/Sources/SwiftBSON/BSONDocument.swift index 5ae84a2..9e8830d 100644 --- a/Sources/SwiftBSON/BSONDocument.swift +++ b/Sources/SwiftBSON/BSONDocument.swift @@ -74,12 +74,27 @@ public struct BSONDocument { * - SeeAlso: http://bsonspec.org/ */ public init(fromBSON bson: ByteBuffer) throws { - let storage = BSONDocumentStorage(bson) + let storage = BSONDocumentStorage(bson.slice()) try storage.validate() - self = BSONDocument(fromUnsafeBSON: storage) + self.storage = storage + } + + /** + * Initialize a new `BSONDocument` from the provided BSON data without validating the elements. The first four + * bytes must accurately reflect the length of the buffer, however. + * + * If invalid BSON data is provided, undefined behavior or server-side errors may occur when using the + * resultant `BSONDocument`. + * + * - Throws: `BSONError.InvalidArgumentError` if the provided BSON's length does not match the encoded length. + */ + public init(fromBSONWithoutValidatingElements bson: ByteBuffer) throws { + let storage = BSONDocumentStorage(bson) + try self.init(fromBSONWithoutValidatingElements: storage) } - internal init(fromUnsafeBSON storage: BSONDocumentStorage) { + internal init(fromBSONWithoutValidatingElements storage: BSONDocumentStorage) throws { + try storage.validateLength() self.storage = storage } @@ -129,11 +144,7 @@ public struct BSONDocument { /// The keys in this `BSONDocument`. public var keys: [String] { - do { - return try BSONDocumentIterator.getKeys(from: self.storage.buffer) - } catch { - fatalError("Failed to retrieve keys for document") - } + BSONDocumentIterator.getKeys(from: self.storage.buffer) } /// The values in this `BSONDocument`. @@ -141,11 +152,7 @@ public struct BSONDocument { /// The number of (key, value) pairs stored at the top level of this document. 
public var count: Int { - do { - return try BSONDocumentIterator.getKeys(from: self.storage.buffer).count - } catch { - return 0 - } + BSONDocumentIterator.getKeys(from: self.storage.buffer).count } /// A copy of the `ByteBuffer` backing this document, containing raw BSON data. As `ByteBuffer`s implement @@ -158,7 +165,8 @@ public struct BSONDocument { /// Returns a `Boolean` indicating whether this `BSONDocument` contains the provided key. public func hasKey(_ key: String) -> Bool { - (try? BSONDocumentIterator.find(key: key, in: self)) != nil + let it = self.makeIterator() + return it.findValue(forKey: key) != nil } /** @@ -174,19 +182,10 @@ public struct BSONDocument { */ public subscript(key: String) -> BSON? { get { - do { - return try BSONDocumentIterator.find(key: key, in: self)?.value - } catch { - fatalError("Error looking up key \(key) in document: \(error)") - } + BSONDocumentIterator.find(key: key, in: self)?.value } set { - // The only time this would crash is document too big error - do { - return try self.set(key: key, to: newValue) - } catch { - fatalError("Failed to set \(key) to \(String(describing: newValue)): \(error)") - } + self.set(key: key, to: newValue) } } @@ -251,10 +250,9 @@ public struct BSONDocument { ) } newStorage.buffer.writeBytes(suffix) + newStorage.encodedLength = newSize - var document = BSONDocument(fromUnsafeBSON: newStorage) - document.storage.encodedLength = newSize - return document + return try BSONDocument(fromBSONWithoutValidatingElements: newStorage) } /// Appends the provided key value pair without checking to see if the key already exists. @@ -271,8 +269,8 @@ public struct BSONDocument { * Sets a BSON element with the corresponding key * if element.value is nil the element is deleted from the BSON */ - internal mutating func set(key: String, to value: BSON?) throws { - guard let range = try BSONDocumentIterator.findByteRange(for: key, in: self) else { + private mutating func set(key: String, to value: BSON?) 
{ + guard let range = BSONDocumentIterator.findByteRange(for: key, in: self) else { guard let value = value else { // no-op: key does not exist and the value is nil return @@ -285,18 +283,18 @@ public struct BSONDocument { let suffixLength = self.storage.encodedLength - range.endIndex guard - let prefix = self.storage.buffer.getBytes(at: 0, length: prefixLength), + var prefix = self.storage.buffer.getSlice(at: 0, length: prefixLength), let suffix = self.storage.buffer.getBytes(at: range.endIndex, length: suffixLength) else { - throw BSONError.InternalError( - message: "Cannot slice buffer from " + + fatalError( + "Cannot slice buffer from " + "0 to len \(range.startIndex) and from \(range.endIndex) " + "to len \(suffixLength) : \(self.storage.buffer)" ) } var newStorage = BSONDocumentStorage() - newStorage.buffer.writeBytes(prefix) + newStorage.buffer.writeBuffer(&prefix) var newSize = self.storage.encodedLength - (range.endIndex - range.startIndex) if let value = value { @@ -305,7 +303,7 @@ public struct BSONDocument { newSize += size guard newSize <= BSON_MAX_SIZE else { - throw BSONError.DocumentTooLargeError(value: value.bsonValue, forKey: key) + fatalError(BSONError.DocumentTooLargeError(value: value.bsonValue, forKey: key).message) } } @@ -389,8 +387,12 @@ public struct BSONDocument { return totalBytes } - internal func validate() throws { - // Pull apart the underlying binary into [KeyValuePair], should reveal issues + /// Verify that the encoded length matches the actual length of the buffer and that the buffer + /// isn't too small or too large. 
+ /// + /// - Throws: `BSONError.InvalidArgumentError` if validation fails + /// + internal func validateLength() throws { guard let encodedLength = self.buffer.getInteger(at: 0, endianness: .little, as: Int32.self) else { throw BSONError.InvalidArgumentError(message: "Validation Failed: Cannot read encoded length") } @@ -403,11 +405,16 @@ public struct BSONDocument { guard encodedLength == self.buffer.readableBytes else { throw BSONError.InvalidArgumentError( - message: "BSONDocument's encoded byte length is \(encodedLength), however the" + + message: "BSONDocument's encoded byte length is \(encodedLength), however the " + "buffer has \(self.buffer.readableBytes) readable bytes" ) } + } + internal func validate() throws { + try self.validateLength() + + // Pull apart the underlying binary into [KeyValuePair], should reveal issues var keySet = Set() let iter = BSONDocumentIterator(over: self.buffer) do { @@ -528,7 +535,7 @@ extension BSONDocument: BSONValue { throw BSONError.InternalError(message: "Cannot read document contents") } - return .document(BSONDocument(fromUnsafeBSON: BSONDocument.BSONDocumentStorage(bytes))) + return .document(try BSONDocument(fromBSONWithoutValidatingElements: BSONDocument.BSONDocumentStorage(bytes))) } internal func write(to buffer: inout ByteBuffer) { diff --git a/Sources/SwiftBSON/BSONDocumentIterator.swift b/Sources/SwiftBSON/BSONDocumentIterator.swift index 8677b25..1b2c463 100644 --- a/Sources/SwiftBSON/BSONDocumentIterator.swift +++ b/Sources/SwiftBSON/BSONDocumentIterator.swift @@ -21,13 +21,12 @@ public class BSONDocumentIterator: IteratorProtocol { } /// Advances to the next element and returns it, or nil if no next element exists. + /// Returns nil if invalid BSON is encountered. public func next() -> BSONDocument.KeyValuePair? 
{ - // The only time this would crash is when the document is incorrectly formatted - do { - return try self.nextThrowing() - } catch { - fatalError("Failed to iterate to next: \(error)") - } + // soft fail on read error by returning nil. + // this should only be possible if invalid BSON bytes were provided via + // BSONDocument.init(fromBSONWithoutValidatingElements:) + try? self.nextThrowing() } /** @@ -49,12 +48,14 @@ public class BSONDocumentIterator: IteratorProtocol { /// Get the next key in the iterator, if there is one. /// This method should only be used for iterating through the keys. It advances to the beginning of the next /// element, meaning the element associated with the last returned key cannot be accessed via this iterator. - private func nextKey() throws -> String? { - guard let type = try self.readNextType() else { + /// Returns nil if invalid BSON is encountered. + private func nextKey() -> String? { + guard let type = try? self.readNextType(), let key = try? self.buffer.readCString() else { + return nil + } + guard self.skipNextValue(type: type) else { return nil } - let key = try self.buffer.readCString() - try self.skipNextValue(type: type) return key } @@ -95,126 +96,197 @@ public class BSONDocumentIterator: IteratorProtocol { return bsonType } + /// Search for the value associated with the given key, returning its type if found and nil otherwise. + /// This moves the iterator right up to the first byte of the value. + /// Returns nil if invalid BSON is encountered. + internal func findValue(forKey key: String) -> BSONType? { + guard !self.exhausted else { + return nil + } + + let keyUTF8 = key.utf8 + + while true { + var bsonType = BSONType.invalid + let matchResult = self.buffer.readWithUnsafeReadableBytes { buffer -> (Int, Bool?) 
in + var matched = true + + var keyIter = keyUTF8.makeIterator() + for (i, byte) in buffer.enumerated() { + // first byte is type of element + guard i != 0 else { + guard let type = BSONType(rawValue: byte), type != .invalid else { + return (1, nil) + } + bsonType = type + continue + } + + guard byte != 0 else { + // hit the null terminator + return (i + 1, matched && keyIter.next() == nil) + } + + // if matched the key so far, check the next character + if matched { + guard let keyByte = keyIter.next() else { + matched = false + continue + } + matched = byte == keyByte + } + } + + // unterminated C string, so we read the whole buffer + return (buffer.count, nil) + } + + guard let matched = matchResult else { + // encountered invalid BSON, just return nil + return nil + } + + guard matched else { + guard self.skipNextValue(type: bsonType) else { + return nil + } + continue + } + + return bsonType + } + } + /// Finds an element with the specified key in the document. Returns nil if the key is not found. - internal static func find(key: String, in document: BSONDocument) throws -> BSONDocument.KeyValuePair? { + /// Returns nil if invalid BSON is encountered when trying to find the key or read the value. + internal static func find(key: String, in document: BSONDocument) -> BSONDocument.KeyValuePair? { let iter = document.makeIterator() - while let type = try iter.readNextType() { - let foundKey = try iter.buffer.readCString() - if foundKey == key { - // the map contains a value for every valid BSON type. - // swiftlint:disable:next force_unwrapping - let bson = try BSON.allBSONTypes[type]!.read(from: &iter.buffer) - return (key: key, value: bson) - } - try iter.skipNextValue(type: type) + guard let bsonType = iter.findValue(forKey: key) else { + return nil + } + // the map contains a value for every valid BSON type. + // swiftlint:disable:next force_unwrapping + guard let bson = try? 
BSON.allBSONTypes[bsonType]!.read(from: &iter.buffer) else { + return nil + } + return (key: key, value: bson) + } + + /// Move the reader index for the underlying buffer forward by the provided amount if possible. + /// Returns true if the index was moved successfully and false otherwise. + /// + /// This will only fail if the underlying buffer contains invalid BSON. + private func moveReaderIndexSafely(forwardBy amount: Int) -> Bool { + guard amount > 0 && self.buffer.readerIndex + amount <= self.buffer.writerIndex else { + return false } - return nil + self.buffer.moveReaderIndex(forwardBy: amount) + return true } /// Given the type of the encoded value starting at self.buffer.readerIndex, advances the reader index to the index /// after the end of the element. - internal func skipNextValue(type: BSONType) throws { + /// + /// Returns false if invalid BSON is encountered while trying to skip, returns true otherwise. + internal func skipNextValue(type: BSONType) -> Bool { switch type { case .invalid: - throw BSONIterationError(message: "encountered invalid BSON type") + return false case .undefined, .null, .minKey, .maxKey: // no data stored, nothing to skip. 
- return + break case .bool: - self.buffer.moveReaderIndex(forwardBy: 1) + return self.moveReaderIndexSafely(forwardBy: 1) case .double, .int64, .timestamp, .datetime: - self.buffer.moveReaderIndex(forwardBy: 8) + return self.moveReaderIndexSafely(forwardBy: 8) case .objectID: - self.buffer.moveReaderIndex(forwardBy: 12) + return self.moveReaderIndexSafely(forwardBy: 12) case .int32: - self.buffer.moveReaderIndex(forwardBy: 4) + return self.moveReaderIndexSafely(forwardBy: 4) case .string, .code, .symbol: guard let strLength = buffer.readInteger(endianness: .little, as: Int32.self) else { - throw BSONError.InternalError(message: "Failed to read encoded string length") + return false } - self.buffer.moveReaderIndex(forwardBy: Int(strLength)) + return self.moveReaderIndexSafely(forwardBy: Int(strLength)) case .regex: - _ = try self.buffer.readCString() - _ = try self.buffer.readCString() + do { + _ = try self.buffer.readCString() + _ = try self.buffer.readCString() + } catch { + return false + } case .binary: guard let dataLength = buffer.readInteger(endianness: .little, as: Int32.self) else { - throw BSONError.InternalError(message: "Failed to read encoded binary data length") + return false } - self.buffer.moveReaderIndex(forwardBy: Int(dataLength) + 1) // +1 for the binary subtype. + return self.moveReaderIndexSafely(forwardBy: Int(dataLength) + 1) // +1 for the binary subtype. case .document, .array, .codeWithScope: guard let embeddedDocLength = buffer.readInteger(endianness: .little, as: Int32.self) else { - throw BSONError.InternalError(message: "Failed to read encoded document length") + return false } // -4 because the encoded length includes the bytes necessary to store the length itself. 
- self.buffer.moveReaderIndex(forwardBy: Int(embeddedDocLength) - 4) + return self.moveReaderIndexSafely(forwardBy: Int(embeddedDocLength) - 4) case .dbPointer: // initial string guard let strLength = buffer.readInteger(endianness: .little, as: Int32.self) else { - throw BSONError.InternalError(message: "Failed to read encoded string length") + return false } - self.buffer.moveReaderIndex(forwardBy: Int(strLength)) - // 12 bytes of data - self.buffer.moveReaderIndex(forwardBy: 12) + return self.moveReaderIndexSafely(forwardBy: Int(strLength) + 12) case .decimal128: - self.buffer.moveReaderIndex(forwardBy: 16) + return self.moveReaderIndexSafely(forwardBy: 16) } + + return true } /// Finds the key in the underlying buffer, and returns the [startIndex, endIndex) containing the corresponding /// element. - internal static func findByteRange(for searchKey: String, in document: BSONDocument) throws -> Range? { + /// Returns nil if invalid BSON is encountered. + internal static func findByteRange(for searchKey: String, in document: BSONDocument) -> Range? { let iter = document.makeIterator() - while true { - let startIndex = iter.buffer.readerIndex - guard let type = try iter.readNextType() else { - return nil - } - let foundKey = try iter.buffer.readCString() - try iter.skipNextValue(type: type) + guard let type = iter.findValue(forKey: searchKey) else { + return nil + } - if foundKey == searchKey { - let endIndex = iter.buffer.readerIndex - return startIndex.. [String] { + /// If invalid BSON is encountered while retrieving the keys, any valid keys seen up to that point are returned. + internal static func getKeys(from buffer: ByteBuffer) -> [String] { let iter = BSONDocumentIterator(over: buffer) var keys = [String]() - while let key = try iter.nextKey() { + while let key = iter.nextKey() { keys.append(key) } return keys } - /// Retrieves an unordered list of the keys in the provided document buffer. 
- internal static func getKeySet(from buffer: ByteBuffer) throws -> Set { - let iter = BSONDocumentIterator(over: buffer) - var keySet: Set = [] - while let key = try iter.nextKey() { - keySet.insert(key) - } - return keySet - } - // uses an iterator to copy (key, value) pairs of the provided document from range [startIndex, endIndex) into a new // document. starts at the startIndex-th pair and ends at the end of the document or the (endIndex-1)th index, // whichever comes first. + // If invalid BSON is encountered before getting to the ith element, a new, empty document will be returned. + // If invalid BSON is encountered while iterating over elements included in the subsequence, a document containing + // the elements in the subsequence that came before the invalid BSON will be returned. internal static func subsequence( of doc: BSONDocument, startIndex: Int = 0, @@ -234,23 +306,26 @@ public class BSONDocumentIterator: IteratorProtocol { break } _ = try iter.buffer.readCString() - try iter.skipNextValue(type: type) - } - - var newDoc = BSONDocument() - - for _ in startIndex..