Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Sources/MongoSwift/BSON/AnyBSONValue.swift
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public struct AnyBSONValue: Codable, Equatable, Hashable {
}

public static func == (lhs: AnyBSONValue, rhs: AnyBSONValue) -> Bool {
return bsonEquals(lhs.value, rhs.value)
return lhs.value.bsonEquals(rhs.value)
}

/**
Expand Down
11 changes: 11 additions & 0 deletions Sources/MongoSwift/BSON/BSONEncoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,10 @@ private class MutableArray: BSONValue {
required convenience init(from decoder: Decoder) throws {
fatalError("`MutableArray` is not meant to be initialized from a `Decoder`")
}

func bsonEquals(_ other: BSONValue?) -> Bool {
return self.array.bsonEquals(other)
}
}

/// A private class wrapping a Swift dictionary so we can pass it by reference
Expand Down Expand Up @@ -799,6 +803,13 @@ private class MutableDictionary: BSONValue {

init() {}

func bsonEquals(_ other: BSONValue?) -> Bool {
guard let otherDict = other as? MutableDictionary else {
return false
}
return otherDict.keys == self.keys && otherDict.values.bsonEquals(self.values)
}

/// methods required by the BSONValue protocol that we don't actually need/use. MutableDictionary
/// is just a BSONValue to simplify usage alongside true BSONValues within the encoder.
public static func from(iterator iter: DocumentIterator) -> Self {
Expand Down
96 changes: 51 additions & 45 deletions Sources/MongoSwift/BSON/BSONValue.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,29 +56,53 @@ public protocol BSONValue {
var bsonType: BSONType { get }

/**
* Given the `DocumentStorage` backing a `Document`, appends this `BSONValue` to the end.
*
* - Parameters:
* - storage: A `DocumentStorage` to write to.
* - key: A `String`, the key under which to store the value.
*
* - Throws:
* - `RuntimeError.internalError` if the `DocumentStorage` would exceed the maximum size by encoding this
* key-value pair.
* - `UserError.logicError` if the value is an `Array` and it contains a non-`BSONValue` element.
*/
* Given the `DocumentStorage` backing a `Document`, appends this `BSONValue` to the end.
*
* - Parameters:
* - storage: A `DocumentStorage` to write to.
* - key: A `String`, the key under which to store the value.
*
* - Throws:
* - `RuntimeError.internalError` if the `DocumentStorage` would exceed the maximum size by encoding this
* key-value pair.
* - `UserError.logicError` if the value is an `Array` and it contains a non-`BSONValue` element.
*/
func encode(to storage: DocumentStorage, forKey key: String) throws

/**
* Given a `DocumentIterator` known to have a next value of this type,
* initializes the value.
*
* - Throws: `UserError.logicError` if the current type of the `DocumentIterator` does not correspond to the
* associated type of this `BSONValue`.
*/
* Function to test equality with another `BSONValue`. This function tests for exact BSON equality.
* This means that differing types with equivalent value are not equivalent.
*
* e.g.
* 4.0 (Double) != 4 (Int)
*
* - Parameters:
* - other: The right-hand-side `BSONValue` to compare.
*
* - Returns: `true` if `self` is equal to `rhs`, `false` otherwise.
*/
func bsonEquals(_ other: BSONValue?) -> Bool

/**
* Given a `DocumentIterator` known to have a next value of this type,
* initializes the value.
*
* - Throws: `UserError.logicError` if the current type of the `DocumentIterator` does not correspond to the
* associated type of this `BSONValue`.
*/
static func from(iterator iter: DocumentIterator) throws -> Self
}

extension BSONValue where Self: Equatable {
/// Default implementation of `bsonEquals` for `BSONValue`s that conform to `Equatable`.
public func bsonEquals(_ other: BSONValue?) -> Bool {
guard let otherAsSelf = other as? Self else {
return false
}
return self == otherAsSelf
}
}

/// An extension of `Array` to represent the BSON array type.
extension Array: BSONValue {
public var bsonType: BSONType { return .array }
Expand Down Expand Up @@ -128,6 +152,13 @@ extension Array: BSONValue {
throw bsonTooLargeError(value: self, forKey: key)
}
}

public func bsonEquals(_ other: BSONValue?) -> Bool {
guard let otherArr = other as? [BSONValue], let selfArr = self as? [BSONValue] else {
return false
}
return self.count == otherArr.count && zip(selfArr, otherArr).allSatisfy { lhs, rhs in lhs.bsonEquals(rhs) }
}
}

/// A struct to represent the BSON null type.
Expand Down Expand Up @@ -1083,8 +1114,6 @@ public struct BSONUndefined: BSONValue, Equatable, Codable {
}
}

// See https://github.com/realm/SwiftLint/issues/461
// swiftlint:disable cyclomatic_complexity
/**
* A helper function to test equality between two `BSONValue`s. This function tests for exact BSON equality.
* This means that differing types with equivalent value are not equivalent.
Expand All @@ -1101,33 +1130,9 @@ public struct BSONUndefined: BSONValue, Equatable, Codable {
*
* - Returns: `true` if `lhs` is equal to `rhs`, `false` otherwise.
*/
@available(*, deprecated, message: "Use lhs.bsonEquals(rhs) instead")
public func bsonEquals(_ lhs: BSONValue, _ rhs: BSONValue) -> Bool {
switch (lhs, rhs) {
case let (l as Int, r as Int): return l == r
case let (l as Int32, r as Int32): return l == r
case let (l as Int64, r as Int64): return l == r
case let (l as Double, r as Double): return l == r
case let (l as Decimal128, r as Decimal128): return l == r
case let (l as Bool, r as Bool): return l == r
case let (l as String, r as String): return l == r
case let (l as RegularExpression, r as RegularExpression): return l == r
case let (l as Timestamp, r as Timestamp): return l == r
case let (l as Date, r as Date): return l == r
case (_ as MinKey, _ as MinKey): return true
case (_ as MaxKey, _ as MaxKey): return true
case let (l as ObjectId, r as ObjectId): return l == r
case let (l as CodeWithScope, r as CodeWithScope): return l == r
case let (l as Binary, r as Binary): return l == r
case (_ as BSONNull, _ as BSONNull): return true
case let (l as Document, r as Document): return l == r
case let (l as [BSONValue], r as [BSONValue]): // TODO: SWIFT-242
return l.count == r.count && zip(l, r).reduce(true, { prev, next in prev && bsonEquals(next.0, next.1) })
case (_ as [Any], _ as [Any]): return false
case let (l as Symbol, r as Symbol): return l == r
case let (l as DBPointer, r as DBPointer): return l == r
case (_ as BSONUndefined, _ as BSONUndefined): return true
default: return false
}
return lhs.bsonEquals(rhs)
}

/**
Expand All @@ -1140,6 +1145,7 @@ public func bsonEquals(_ lhs: BSONValue, _ rhs: BSONValue) -> Bool {
*
* - Returns: True if lhs is equal to rhs, false otherwise.
*/
@available(*, deprecated, message: "use lhs?.bsonEquals(rhs) instead")
public func bsonEquals(_ lhs: BSONValue?, _ rhs: BSONValue?) -> Bool {
guard let left = lhs, let right = rhs else {
return lhs == nil && rhs == nil
Expand Down
4 changes: 2 additions & 2 deletions Tests/MongoSwiftTests/BSONValueTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -85,15 +85,15 @@ final class BSONValueTests: MongoSwiftTestCase {
)
// Check that when an array contains non-BSONValues, we return false
let arr = [[String: Int]()]
expect(bsonEquals(arr, arr)).to(beFalse())
expect(arr.bsonEquals(arr)).to(beFalse())

// Different types
expect(4).toNot(bsonEqual("swift"))

// Arrays of different sizes should not be equal
let b0: [BSONValue] = [1, 2]
let b1: [BSONValue] = [1, 2, 3]
expect(bsonEquals(b0, b1)).to(beFalse())
expect(b0.bsonEquals(b1)).to(beFalse())
}

/// Test object for ObjectIdRoundTrip
Expand Down
2 changes: 1 addition & 1 deletion Tests/MongoSwiftTests/TestUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ internal func bsonEqual(_ expectedValue: BSONValue?) -> Predicate<BSONValue> {
case (nil, nil), (_, nil):
return PredicateResult(status: .fail, message: msg)
case let (expected?, actual?):
let matches = bsonEquals(expected, actual)
let matches = expected.bsonEquals(actual)
return PredicateResult(bool: matches, message: msg)
}
}
Expand Down