diff --git a/Sources/Store.swift b/Sources/Store.swift index 71b90f4..69e8c2d 100644 --- a/Sources/Store.swift +++ b/Sources/Store.swift @@ -12,8 +12,7 @@ public class Store: Publisher { didSet { queue.sync { if state != oldValue { - stateCurrentValueSubject.send(state) - statePassthroughSubject.send(state) + stateSubject.send(state) } } } @@ -22,15 +21,12 @@ public class Store: Publisher { public init(_ state: StoreState, dispatcher: Dispatcher, - storeController: StoreController, - defaultPublisherMode: DefaultPublisherMode = .currentValue) { + storeController: StoreController) { self.initialState = state self.dispatcher = dispatcher - self.stateCurrentValueSubject = .init(state) - self.statePassthroughSubject = .init() + self.stateSubject = .init(state) self.storeController = storeController self.state = state - self.defaultPublisherMode = defaultPublisherMode } /** @@ -57,8 +53,7 @@ public class Store: Publisher { } public func replayOnce() { - stateCurrentValueSubject.send(state) - statePassthroughSubject.send(state) + stateSubject.send(state) dispatcher.stateWasReplayed(state: state) } @@ -71,56 +66,85 @@ public class Store: Publisher { publisher.receive(subscriber: subscriber) } - public var publisher: StorePublisher { - switch defaultPublisherMode { - case .passthrough: - return passthroughPublisher - - case .currentValue: - return currentValuePublisher - } - } - - public var passthroughPublisher: StorePublisher { - .init(subject: statePassthroughSubject) - } - - public var currentValuePublisher: StorePublisher { - .init(subject: stateCurrentValueSubject) + public var publisher: Publishers.StoreStatePublisher { + .init(upstream: stateSubject) } /// Scope a task from the state and receive only new updated since subscription. - public func scope(_ transform: @escaping (StoreState) -> T) -> AnyPublisher { - passthroughPublisher - .map(transform) - .removeDuplicates() - .eraseToAnyPublisher() + public func scope(_ transform: @escaping (StoreState) -> T) -> Publishers.StoreScopePublisher { + Publishers.StoreScopePublisher(upstream: stateSubject.map(transform), + initialValue: transform(state)) } - private var stateCurrentValueSubject: CurrentValueSubject - private var statePassthroughSubject: PassthroughSubject + private var stateSubject: CurrentValueSubject private let queue = DispatchQueue(label: "atomic state") - private let defaultPublisherMode: DefaultPublisherMode } -public extension Store { - enum DefaultPublisherMode { - case passthrough - case currentValue +public extension Publishers { + class StoreStatePublisher: Publisher { + public typealias Upstream = any Subject + public typealias Output = StoreState + public typealias Failure = Never + + private let upstream: Upstream + + internal init(upstream: Upstream) { + self.upstream = upstream + } + + public func receive(subscriber: S) where Failure == S.Failure, Output == S.Input { + upstream.subscribe(subscriber) + } } - class StorePublisher: Publisher { - public typealias Output = StoreState + struct StoreScopePublisher: Publisher { + public typealias Upstream = any Publisher + public typealias Output = StoreTask public typealias Failure = Never - private var subject: any Subject + private let upstream: Upstream + private let initialValue: StoreTask - internal init(subject: any Subject) { - self.subject = subject + internal init(upstream: Upstream, initialValue: StoreTask) { + self.upstream = upstream + self.initialValue = initialValue } public func receive(subscriber: S) where Failure == S.Failure, Output == S.Input { - subject.subscribe(subscriber) + upstream.subscribe(Inner(downstream: subscriber, initialValue: initialValue)) + } + } +} + +extension Publishers.StoreScopePublisher { + private class Inner: Subscriber + where Downstream.Input == Output, Downstream.Failure == Never, Output == StoreTask { + public typealias Input = Output + public typealias Failure = Never + + let combineIdentifier = CombineIdentifier() + private let downstream: Downstream + private var lastValue: StoreTask + + fileprivate init(downstream: Downstream, initialValue: StoreTask) { + self.downstream = downstream + self.lastValue = initialValue + } + + func receive(subscription: Subscription) { + downstream.receive(subscription: subscription) + } + + func receive(_ input: Output) -> Subscribers.Demand { + if input == lastValue { + return .none + } + self.lastValue = input + return downstream.receive(input) + } + + func receive(completion: Subscribers.Completion) { + downstream.receive(completion: completion) } } } diff --git a/Tests/ReducerTests.swift b/Tests/ReducerTests.swift index 0b0ee06..6d59f74 100644 --- a/Tests/ReducerTests.swift +++ b/Tests/ReducerTests.swift @@ -69,11 +69,11 @@ final class ReducerTests: XCTestCase { XCTAssertEqual(store.state, initialState) } - func test_subscribe_state_changes_with_initial_value() { + func test_subscribe_state_changes() { var cancellables = Set() let dispatcher = Dispatcher() let initialState = TestStateWithOneTask() - let store = Store(initialState, dispatcher: dispatcher, storeController: TestStoreController(), defaultPublisherMode: .currentValue) + let store = Store(initialState, dispatcher: dispatcher, storeController: TestStoreController()) let expectation1 = XCTestExpectation(description: "Subscription Emits 1") let expectation2 = XCTestExpectation(description: "Subscription Emits 2") @@ -97,38 +97,4 @@ final class ReducerTests: XCTestCase { dispatcher.dispatch(TestAction(counter: 2)) wait(for: [expectation1, expectation2], timeout: 5.0) } - - func test_subscribe_state_changes_without_initial_value() { - var cancellables = Set() - let dispatcher = Dispatcher() - let initialState = TestStateWithOneTask() - let store = Store(initialState, dispatcher: dispatcher, storeController: TestStoreController(), defaultPublisherMode: .passthrough) - let expectation = XCTestExpectation(description: "Subscription Emits") - - store - .reducerGroup() - .store(in: &cancellables) - - dispatcher.dispatch(TestAction(counter: 1)) - - DispatchQueue.main.asyncAfter(deadline: .now() + 1) { - // Only gets the action with counter == 2. - store - .map(\.counter) - .sink { counter in - if counter == 1 { - XCTFail("counter == 1 should not be emmited because this is a stateless subscription") - } - if counter == 2 { - expectation.fulfill() - } - } - .store(in: &cancellables) - - // Send action with counter == 2, this action should be caught by the two subscriptions - dispatcher.dispatch(TestAction(counter: 2)) - } - - wait(for: [expectation], timeout: 5.0) - } } diff --git a/Tests/StoreTests.swift b/Tests/StoreTests.swift index 3bb19ea..cccaef4 100644 --- a/Tests/StoreTests.swift +++ b/Tests/StoreTests.swift @@ -3,7 +3,7 @@ import Combine import XCTest class StoreTests: XCTestCase { - func test_scope() { + func test_scope_with_initial_state() { var cancellables = Set() let expectation = XCTestExpectation(description: "Scope usage check") expectation.expectedFulfillmentCount = 2 @@ -37,4 +37,45 @@ class StoreTests: XCTestCase { XCTAssertTrue(counterValue == 2) } + + func test_scope_with_initial_change_after_subscriptions() { + var cancellables = Set() + let expectation = XCTestExpectation(description: "Scope usage check") + expectation.expectedFulfillmentCount = 2 + let dispatcher = Dispatcher() + let initialState = TestStateWithTwoTasks() + let store = Store(initialState, dispatcher: dispatcher, storeController: TestStoreController()) + + var counterValue = 0 + // SCOPING.... + store + .scope { $0.testTask1 } + .sink { _ in + expectation.fulfill() + counterValue += 1 + } + .store(in: &cancellables) + + Thread.sleep(forTimeInterval: 1) + + // THIS NOT PASS: change task2 + store.state = TestStateWithTwoTasks(testTask1: store.state.testTask1, + testTask2: .success(10)) + + // THIS PASSES: change task1 + store.state = TestStateWithTwoTasks(testTask1: .success(6), + testTask2: store.state.testTask2) + + // THIS NOT PASS: change task2 + store.state = TestStateWithTwoTasks(testTask1: store.state.testTask1, + testTask2: .success(2)) + + // THIS PASSES: change task1 + store.state = TestStateWithTwoTasks(testTask1: .success(7), + testTask2: store.state.testTask2) + + wait(for: [expectation], timeout: 5.0) + + XCTAssertTrue(counterValue == 2) + } }