From 31e9537d2265a265e7bbdca00ff128e7e521aa8c Mon Sep 17 00:00:00 2001 From: Eren Kabakci Date: Mon, 27 Apr 2020 18:07:07 +0200 Subject: [PATCH 1/3] Overhaul token refresh mechanism and discard dispatchGroup --- Sources/SessionPublisherProtocol.swift | 2 +- .../AuthenticatedWebService.swift | 127 +++++----- .../AuthenticationTokenProvidable.swift | 9 +- Sources/WebService/WebService.swift | 28 +- Tests/MockSession.swift | 22 +- .../AsyncTokenRefreshTests.swift | 239 +++++++----------- .../AuthenticatedWebServiceTests.swift | 67 +++-- fusion.xcodeproj/project.pbxproj | 2 +- 8 files changed, 242 insertions(+), 254 deletions(-) diff --git a/Sources/SessionPublisherProtocol.swift b/Sources/SessionPublisherProtocol.swift index 170d349..c326bbc 100644 --- a/Sources/SessionPublisherProtocol.swift +++ b/Sources/SessionPublisherProtocol.swift @@ -32,7 +32,7 @@ public protocol SessionPublisherProtocol: AnyObject { extension URLSession: SessionPublisherProtocol { public func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), Error> { self.dataTaskPublisher(for: request) - .receive(on: DispatchQueue.main) + .subscribe(on: DispatchQueue.global()) .mapError { NetworkError.urlError($0) } .eraseToAnyPublisher() } diff --git a/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift b/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift index 5a1635a..551fe18 100644 --- a/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift +++ b/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift @@ -33,10 +33,10 @@ public enum AuthorizationHeaderScheme: String { public struct AuthenticatedWebServiceConfiguration { let authorizationHeaderScheme: AuthorizationHeaderScheme - let refreshTriggerErrors: [Error] + let refreshTriggerErrors: [NetworkError] public init(authorizationHeaderScheme: AuthorizationHeaderScheme = .none, - refreshTriggerErrors: [Error] = [NetworkError.unauthorized]) { + refreshTriggerErrors: [NetworkError] = [NetworkError.unauthorized]) { self.authorizationHeaderScheme = authorizationHeaderScheme self.refreshTriggerErrors = refreshTriggerErrors } @@ -56,84 +56,79 @@ open class AuthenticatedWebService: WebService { self.configuration = configuration super.init(urlSession: urlSession) } - + override public func execute(urlRequest: URLRequest) -> AnyPublisher where T : Decodable { var urlRequest = urlRequest - var currentAccessToken: String? - authenticationQueue.sync { - currentAccessToken = self.tokenProvider.accessToken.value + func appendTokenAndExecute(accessToken: AccessToken) -> AnyPublisher { + urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization") + return super.execute(urlRequest: urlRequest) + .subscribe(on: DispatchQueue.global()) + .eraseToAnyPublisher() } - - guard let accessToken = currentAccessToken else { - return Fail(error: NetworkError.unauthorized).eraseToAnyPublisher() + + return Deferred { + self.tokenProvider.accessToken + .compactMap { $0 } + .setFailureType(to: Error.self) + .flatMap { accessToken -> AnyPublisher in + return appendTokenAndExecute(accessToken: accessToken) + } + }.catch { [weak self] error -> AnyPublisher in + guard let self = self else { + return Fail(error: NetworkError.unknown).eraseToAnyPublisher() + } + + if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){ + return self.retrySynchronizedTokenRefresh() + .flatMap { + appendTokenAndExecute(accessToken: $0) + }.eraseToAnyPublisher() + } + return Fail(error: error).eraseToAnyPublisher() } - - urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization") - - return super.execute(urlRequest: urlRequest) - .catch { [weak self] error -> AnyPublisher in - guard let self = self else { - return Fail(error: NetworkError.unknown).eraseToAnyPublisher() - } - - if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){ - self.retrySynchronizedTokenRefresh() - - return self.execute(urlRequest: urlRequest) - .delay(for: 0.2, scheduler: self.authenticationQueue) - .eraseToAnyPublisher() - } - return Fail(error: error).eraseToAnyPublisher() - }.eraseToAnyPublisher() + .receive(on: DispatchQueue.main) + .eraseToAnyPublisher() } - - + override public func execute(urlRequest: URLRequest) -> AnyPublisher { var urlRequest = urlRequest - var currentAccessToken: String? - authenticationQueue.sync { - currentAccessToken = self.tokenProvider.accessToken.value + func appendTokenAndExecute(accessToken: AccessToken) -> AnyPublisher { + urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization") + return super.execute(urlRequest: urlRequest) + .subscribe(on: DispatchQueue.global()) + .eraseToAnyPublisher() } - guard let accessToken = currentAccessToken else { - return Fail(error: NetworkError.unauthorized).eraseToAnyPublisher() + return Deferred { + self.tokenProvider.accessToken + .compactMap { $0 } + .setFailureType(to: Error.self) + .flatMap { accessToken -> AnyPublisher in + return appendTokenAndExecute(accessToken: accessToken) + } + }.catch { [weak self] error -> AnyPublisher in + guard let self = self else { + return Fail(error: NetworkError.unknown).eraseToAnyPublisher() + } + + if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){ + return self.retrySynchronizedTokenRefresh() + .flatMap { + appendTokenAndExecute(accessToken: $0) + }.eraseToAnyPublisher() + } + return Fail(error: error).eraseToAnyPublisher() } - - urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization") - - return super.execute(urlRequest: urlRequest) - .catch { [weak self] error -> AnyPublisher in - guard let self = self else { - return Fail(error: NetworkError.unknown).eraseToAnyPublisher() - } - - if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){ - self.retrySynchronizedTokenRefresh() - - return self.execute(urlRequest: urlRequest) - .delay(for: 0.2, scheduler: self.authenticationQueue) - .eraseToAnyPublisher() - } - return Fail(error: error).eraseToAnyPublisher() - }.eraseToAnyPublisher() + .receive(on: DispatchQueue.main) + .eraseToAnyPublisher() } - private func retrySynchronizedTokenRefresh() { - let dispatchGroup = DispatchGroup() - dispatchGroup.enter() - - authenticationQueue.sync(flags: .barrier) { - self.tokenProvider.invalidateAccessToken() - self.tokenProvider.reissueAccessToken() - .sink(receiveCompletion: { _ in - dispatchGroup.leave() - }, - receiveValue: { _ in }) - .store(in: &self.subscriptions) - dispatchGroup.wait() - } + private func retrySynchronizedTokenRefresh() -> AnyPublisher { + tokenProvider.invalidateAccessToken() + return tokenProvider.reissueAccessToken() + .eraseToAnyPublisher() } } diff --git a/Sources/WebService/AuthenticatedWebService/AuthenticationTokenProvidable.swift b/Sources/WebService/AuthenticatedWebService/AuthenticationTokenProvidable.swift index 9d5a1b7..ae62ea5 100644 --- a/Sources/WebService/AuthenticatedWebService/AuthenticationTokenProvidable.swift +++ b/Sources/WebService/AuthenticatedWebService/AuthenticationTokenProvidable.swift @@ -25,10 +25,13 @@ import Combine import Foundation +public typealias AccessToken = String +public typealias RefreshToken = String + public protocol AuthenticationTokenProvidable: AnyObject { - var accessToken: CurrentValueSubject { get } - var refreshToken: CurrentValueSubject { get } - func reissueAccessToken() -> AnyPublisher + var accessToken: CurrentValueSubject { get } + var refreshToken: CurrentValueSubject { get } + func reissueAccessToken() -> AnyPublisher func invalidateAccessToken() func invalidateRefreshToken() } diff --git a/Sources/WebService/WebService.swift b/Sources/WebService/WebService.swift index cd32183..8813f25 100644 --- a/Sources/WebService/WebService.swift +++ b/Sources/WebService/WebService.swift @@ -13,7 +13,7 @@ // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. -// +// // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE @@ -35,7 +35,7 @@ CustomDecodable{ public let jsonDecoder: JSONDecoder = JSONDecoder() private let session: SessionPublisherProtocol @ThreadSafe open var subscriptions = Set() - + public init(urlSession: SessionPublisherProtocol = URLSession(configuration: URLSessionConfiguration.ephemeral, delegate: nil, delegateQueue: nil)) { @@ -49,9 +49,11 @@ CustomDecodable{ return Fail(error: NetworkError.unknown).eraseToAnyPublisher() } return CurrentValueSubject((output.data, httpResponse)).eraseToAnyPublisher() - }.eraseToAnyPublisher() + } + .receive(on: DispatchQueue.main) + .eraseToAnyPublisher() } - + public func execute(urlRequest: URLRequest) -> AnyPublisher where T : Decodable { Deferred { Future { [weak self] promise in @@ -62,8 +64,9 @@ CustomDecodable{ var urlRequest = urlRequest urlRequest.appendAdditionalHeaders(headers: self.defaultHttpHeaders) - + self.rawResponse(urlRequest: urlRequest) + .subscribe(on: DispatchQueue.global()) .tryMap { try self.mapHttpResponseCodes(output: $0) @@ -80,9 +83,11 @@ CustomDecodable{ receiveValue: { promise(.success($0)) }) .store(in: &self.subscriptions) } - }.eraseToAnyPublisher() + } + .receive(on: DispatchQueue.main) + .eraseToAnyPublisher() } - + public func execute(urlRequest: URLRequest) -> AnyPublisher { Deferred { Future { [weak self] promise in @@ -90,11 +95,12 @@ CustomDecodable{ promise(.failure(NetworkError.unknown)) return } - + var urlRequest = urlRequest urlRequest.appendAdditionalHeaders(headers: self.defaultHttpHeaders) - + self.rawResponse(urlRequest: urlRequest) + .subscribe(on: DispatchQueue.global()) .tryMap { try self.mapHttpResponseCodes(output: $0) return @@ -107,7 +113,9 @@ CustomDecodable{ receiveValue: { promise(.success($0)) }) .store(in: &self.subscriptions) } - }.eraseToAnyPublisher() + } + .receive(on: DispatchQueue.main) + .eraseToAnyPublisher() } open func mapHttpResponseCodes(output: (data:Data, response: HTTPURLResponse)) throws { diff --git a/Tests/MockSession.swift b/Tests/MockSession.swift index 760a21a..0a7d564 100644 --- a/Tests/MockSession.swift +++ b/Tests/MockSession.swift @@ -35,16 +35,18 @@ open class MockSession: SessionPublisherProtocol { public func dataTaskPublisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), Error> { methodCallStack.append(#function) finalUrlRequest = urlRequest - return Future<(data: Data, response: URLResponse), Error> { promise in - usleep(20) - if let successResponse = self.result?.0 { - promise(.success((successResponse.0, - HTTPURLResponse(url: URL(string: "foo.com")!, - statusCode: successResponse.1, - httpVersion: nil, - headerFields: nil)!))) - } else if let errorResponse = self.result?.1 { - promise(.failure(NetworkError.urlError(errorResponse))) + return Deferred { + Future<(data: Data, response: URLResponse), Error> { promise in + usleep(20) + if let successResponse = self.result?.0 { + promise(.success((successResponse.0, + HTTPURLResponse(url: URL(string: "foo.com")!, + statusCode: successResponse.1, + httpVersion: nil, + headerFields: nil)!))) + } else if let errorResponse = self.result?.1 { + promise(.failure(NetworkError.urlError(errorResponse))) + } } }.eraseToAnyPublisher() } diff --git a/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift b/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift index 9d8347b..03f2790 100644 --- a/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift +++ b/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift @@ -12,160 +12,113 @@ import EntwineTest @testable import fusion class AsyncTokenRefreshTests: XCTestCase { - private var session: MockAuthenticatedServiceSession! - private var tokenProvider: MockTokenProvider! - private var webService: AuthenticatedWebService! - private var subscriptions = Set() - private let encoder = JSONEncoder() - - override func setUp() { - super.setUp() - session = MockAuthenticatedServiceSession() - tokenProvider = MockTokenProvider() - webService = AuthenticatedWebService(urlSession: session, - tokenProvider: tokenProvider) - - let encodedData = try! self.encoder.encode(["id": "value"]) - tokenProvider.accessToken - .sink(receiveValue: { - if $0 == "newToken" { - print("Change session response to 200") - self.session.result = ((encodedData, 200), nil) - } - }).store(in: &subscriptions) - } - - func test_givenAuthenticatedWebService_whenParallelRequests_andTokenRefreshAttempts_thenShouldWaitForTokenRefreshing() { - let request = URLRequest(url: URL(string: "foo.com")!) - let expectation1 = self.expectation(description: "parallel reqeust test has failed") - let expectation2 = self.expectation(description: "parallel reqeust test has failed") - - tokenProvider.accessToken.send("invalidToken") - session.result = ((Data(), 401), nil) - - DispatchQueue.global(qos: .default).async { - print("Request 1 started") - self.webService.execute(urlRequest: request) - .sink(receiveCompletion: { - if case .finished = $0 { - expectation1.fulfill() - print("Request 1 finished") - } - else { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { _ in - print("Request 1 received value") - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()", "invalidateAccessToken()", "reissueAccessToken()"]) - }) - .store(in: &self.subscriptions) + private var session: MockAuthenticatedServiceSession! + private var tokenProvider: MockTokenProvider! + private var webService: AuthenticatedWebService! + private var subscriptions = Set() + private let encoder = JSONEncoder() + + override func setUp() { + super.setUp() + session = MockAuthenticatedServiceSession() + tokenProvider = MockTokenProvider() + subscriptions = Set() + webService = AuthenticatedWebService(urlSession: session, + tokenProvider: tokenProvider) + + let encodedData = try! self.encoder.encode(["id": "value"]) + tokenProvider.accessToken + .sink(receiveValue: { + if $0 == "newToken" { + print("Change session response to 200") + self.session.result = ((encodedData, 200), nil) } - - DispatchQueue.global(qos: .default).async { - print("Request 2 started") - self.webService.execute(urlRequest: request) - .sink(receiveCompletion: { - if case .finished = $0 { - print("Request 2 finished") - expectation2.fulfill() - } - else { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { (_: SampleResponse) in - print("Request 2 received value") - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()", "invalidateAccessToken()", "reissueAccessToken()"]) - }) - .store(in: &self.subscriptions) - } - - waitForExpectations(timeout: 5) + }).store(in: &subscriptions) + } + + func test_givenAuthenticatedWebService_whenContinuousRequests_andTokenRefreshAttempts_thenShouldWaitForTokenRefreshing() { + let request = URLRequest(url: URL(string: "foo.com")!) + let expectation1 = self.expectation(description: "parallel request test has failed") + let expectation2 = self.expectation(description: "parallel request test has failed") + + tokenProvider.accessToken.send("invalidToken") + session.result = ((Data(), 401), nil) + + func fireRequest2() { + print("Request 2 started") + self.webService.execute(urlRequest: request) + .subscribe(on: DispatchQueue.global()) + .receive(on: DispatchQueue.main) + .sink(receiveCompletion: { + if case .finished = $0 { + print("Request 2 finished") + } + else { + XCTFail("should not receive failure since the token is refreshed") + } + }, + receiveValue: { (_: SampleResponse) in + print("Request 2 received value") + XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + expectation2.fulfill() + }) + .store(in: &self.subscriptions) } - func test_givenAuthenticatedWebService_whenContinuousRequests_andTokenRefreshAttempts_thenShouldWaitForTokenRefreshing() { - let request = URLRequest(url: URL(string: "foo.com")!) - let expectation1 = self.expectation(description: "parallel reqeust test has failed") - let expectation2 = self.expectation(description: "parallel reqeust test has failed") - - tokenProvider.accessToken.send("invalidToken") - session.result = ((Data(), 401), nil) - - DispatchQueue.global(qos: .default).async { - print("Request 1 started") - self.webService.execute(urlRequest: request) - .sink(receiveCompletion: { - if case .finished = $0 { - expectation1.fulfill() - print("Request 1 finished") - } - else { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { _ in - print("Request 1 received value") - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) - }) - .store(in: &self.subscriptions) - } - - DispatchQueue.global(qos: .default).async { - sleep(1) - print("Request 2 started") - self.webService.execute(urlRequest: request) - .sink(receiveCompletion: { - if case .finished = $0 { - print("Request 2 finished") - expectation2.fulfill() - } - else { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { (_: SampleResponse) in - print("Request 2 received value") - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) - }) - .store(in: &self.subscriptions) - } - - waitForExpectations(timeout: 5) + DispatchQueue.global().async { + print("Request 1 started") + self.webService.execute(urlRequest: request) + .sink(receiveCompletion: { + if case .finished = $0 { + expectation1.fulfill() + print("Request 1 finished") + fireRequest2() + } + else { + XCTFail("should not receive failure since the token is refreshed") + } + }, + receiveValue: { _ in + print("Request 1 received value") + XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + }) + .store(in: &self.subscriptions) } + + waitForExpectations(timeout: 5) + } } private class MockAuthenticatedServiceSession: MockSession { - override func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), Error> { - return super.dataTaskPublisher(for: request) - } + override func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), Error> { + return super.dataTaskPublisher(for: request) + } } private class MockTokenProvider: AuthenticationTokenProvidable { - private(set) var methodCallStack = [String]() - var accessToken: CurrentValueSubject = CurrentValueSubject(nil) - var refreshToken: CurrentValueSubject = CurrentValueSubject(nil) - - func reissueAccessToken() -> AnyPublisher { - let asyncRefreshStream = PassthroughSubject() - - // replicate a slow & asnyc token refresh - DispatchQueue.global(qos: .default).async { - sleep(2) - self.accessToken.send("newToken") - asyncRefreshStream.send(completion: .finished) - self.methodCallStack.append(#function) + private(set) var methodCallStack = [String]() + var accessToken: CurrentValueSubject = CurrentValueSubject(nil) + var refreshToken: CurrentValueSubject = CurrentValueSubject(nil) + + func reissueAccessToken() -> AnyPublisher { + // replicate a slow & asnyc token refresh + sleep(2) + self.accessToken.send("newToken") + self.methodCallStack.append(#function) + + return Deferred { + Future { promise in + promise(.success("newToken")) } + }.eraseToAnyPublisher() + } - return asyncRefreshStream.eraseToAnyPublisher() - } + func invalidateAccessToken() { + accessToken.send(nil) + methodCallStack.append(#function) + } - func invalidateAccessToken() { - accessToken.send(nil) - methodCallStack.append(#function) - } - - func invalidateRefreshToken() { - methodCallStack.append(#function) - } + func invalidateRefreshToken() { + methodCallStack.append(#function) + } } diff --git a/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift b/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift index 46d1e72..d1f462a 100644 --- a/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift +++ b/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift @@ -56,10 +56,17 @@ class AuthenticatedWebServiceTests: XCTestCase { tokenProvider: tokenProvider) session.result = ((Data(), 200), nil) tokenProvider.accessToken.value = "someToken" + let expectation = self.expectation(description: "No authorization header scheme test has failed") - _ = webService.execute(urlRequest: request).sink(receiveCompletion: { _ in }, receiveValue: { _ in }) + webService.execute(urlRequest: request) + .sink(receiveCompletion: { _ in }, + receiveValue: { _ in + XCTAssertEqual(self.session.finalUrlRequest?.allHTTPHeaderFields?["Authorization"], "someToken") + expectation.fulfill() + }) + .store(in: &subscriptions) - XCTAssertEqual(session.finalUrlRequest?.allHTTPHeaderFields?["Authorization"], "someToken") + waitForExpectations(timeout: 0.5) } func test_givenAuthenticatedWebService_whenAuthorizationHeaderSchemeBasic_shouldAppendBasicHeader() { @@ -69,10 +76,17 @@ class AuthenticatedWebServiceTests: XCTestCase { configuration: AuthenticatedWebServiceConfiguration(authorizationHeaderScheme: .basic)) session.result = ((Data(), 200), nil) tokenProvider.accessToken.value = "someToken" + let expectation = self.expectation(description: "Basic authorization header scheme test has failed") - _ = webService.execute(urlRequest: request).sink(receiveCompletion: { _ in }, receiveValue: { _ in }) + webService.execute(urlRequest: request) + .sink(receiveCompletion: { _ in }, + receiveValue: { _ in + XCTAssertEqual(self.session.finalUrlRequest?.allHTTPHeaderFields?["Authorization"], "Basic someToken") + expectation.fulfill() + }) + .store(in: &subscriptions) - XCTAssertEqual(session.finalUrlRequest?.allHTTPHeaderFields?["Authorization"], "Basic someToken") + waitForExpectations(timeout: 0.5) } func test_givenAuthenticatedWebService_whenAuthorizationHeaderSchemeBearer_shouldAppendBearerHeader() { @@ -84,10 +98,17 @@ class AuthenticatedWebServiceTests: XCTestCase { let encodedJSON = try! encoder.encode(["name": "value"]) session.result = ((encodedJSON, 200), nil) tokenProvider.accessToken.value = "someToken" + let expectation = self.expectation(description: "Bearer authorization header scheme test has failed") - _ = webService.execute(urlRequest: request).sink(receiveCompletion: { _ in }, receiveValue: { (_: SampleResponse) in }) + webService.execute(urlRequest: request) + .sink(receiveCompletion: { _ in }, + receiveValue: { _ in + XCTAssertEqual(self.session.finalUrlRequest?.allHTTPHeaderFields?["Authorization"], "Bearer someToken") + expectation.fulfill() + }) + .store(in: &subscriptions) - XCTAssertEqual(session.finalUrlRequest?.allHTTPHeaderFields?["Authorization"], "Bearer someToken") + waitForExpectations(timeout: 0.5) } func test_givenAuthenticatedWebService_whenParallelRequestsFired_thenShouldNotRaceForTokenRefresh() { @@ -118,34 +139,33 @@ class AuthenticatedWebServiceTests: XCTestCase { // given second call, has an invalid token testScheduler.schedule(after: 200) { // Demonstrate two parallel requests not racing each other to refresh the token + print("invalid token is set") self.tokenProvider.accessToken.value = "invalidToken" self.webService.execute(urlRequest: request) .sink(receiveCompletion: { - if case .finished = $0 { - expectation1.fulfill() - } - else { + if case .failure = $0 { XCTFail("should not receive failure since the token is refreshed") } }, receiveValue: { _ in XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + expectation1.fulfill() }) .store(in: &self.subscriptions) // a parallel call should succesfully execute since the token is refreshed by the previous call self.webService.execute(urlRequest: request) + .receive(on: testScheduler) .sink(receiveCompletion: { - if case .finished = $0 { - expectation2.fulfill() - } else { + if case .failure = $0 { XCTFail("should not receive failure since the token is refreshed") } }, receiveValue: { (_: SampleResponse) in // Reissuing is called only once even though there are two parallel calls XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + expectation2.fulfill() }) .store(in: &self.subscriptions) } @@ -181,6 +201,7 @@ class AuthenticatedWebServiceTests: XCTestCase { // first call should refresh the token with a valid one and override the previously set invalid token testScheduler.schedule(after: 200) { self.webService.execute(urlRequest: request) + .receive(on: testScheduler) .sink(receiveCompletion: { if case .finished = $0 { expectation1.fulfill() @@ -195,12 +216,13 @@ class AuthenticatedWebServiceTests: XCTestCase { .store(in: &self.subscriptions) } - // second call should execute normally even if it has an invalid token, since the previous call is already refreshing the token for this one as wel + // second call should execute normally even if it has an invalid token, since the previous call is already refreshing the token for this one as well testScheduler.schedule(after: 220) { // Demonstrate two consecutive requests not racing each other to refresh the token self.tokenProvider.accessToken.value = "invalidToken2" self.webService.execute(urlRequest: request) + .receive(on: testScheduler) .sink(receiveCompletion: { if case .finished = $0 { expectation2.fulfill() @@ -289,13 +311,18 @@ private class MockAuthenticatedServiceSession: MockSession { private class MockTokenProvider: AuthenticationTokenProvidable { private(set) var methodCallStack = [String]() - var accessToken: CurrentValueSubject = CurrentValueSubject(nil) - var refreshToken: CurrentValueSubject = CurrentValueSubject(nil) + var accessToken: CurrentValueSubject = CurrentValueSubject(nil) + var refreshToken: CurrentValueSubject = CurrentValueSubject(nil) - func reissueAccessToken() -> AnyPublisher { - accessToken.send("newToken") - methodCallStack.append(#function) - return Empty(completeImmediately: true).eraseToAnyPublisher() + func reissueAccessToken() -> AnyPublisher { + self.accessToken.send("newToken") + self.methodCallStack.append(#function) + + return Deferred { + Future { promise in + promise(.success("newToken")) + } + }.eraseToAnyPublisher() } func invalidateAccessToken() { diff --git a/fusion.xcodeproj/project.pbxproj b/fusion.xcodeproj/project.pbxproj index 04df52b..9b58c0d 100644 --- a/fusion.xcodeproj/project.pbxproj +++ b/fusion.xcodeproj/project.pbxproj @@ -47,7 +47,7 @@ 3A71BE7023F35911009D092E /* WebServiceHttpStatusCodeTests.swift */ = {isa = PBXFileReference; indentWidth = 2; lastKnownFileType = sourcecode.swift; path = WebServiceHttpStatusCodeTests.swift; sourceTree = ""; tabWidth = 2; }; 3A71BE7223F35DA4009D092E /* WebServiceStreamTests.swift */ = {isa = PBXFileReference; indentWidth = 2; lastKnownFileType = sourcecode.swift; path = WebServiceStreamTests.swift; sourceTree = ""; tabWidth = 2; }; 3A7DD7CA23F449F2001AED72 /* AuthenticatedWebServiceTests.swift */ = {isa = PBXFileReference; indentWidth = 2; lastKnownFileType = sourcecode.swift; path = AuthenticatedWebServiceTests.swift; sourceTree = ""; tabWidth = 2; }; - 3AC7B89A23FD695500BCE0FE /* AsyncTokenRefreshTests.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = AsyncTokenRefreshTests.swift; sourceTree = ""; }; + 3AC7B89A23FD695500BCE0FE /* AsyncTokenRefreshTests.swift */ = {isa = PBXFileReference; indentWidth = 2; lastKnownFileType = sourcecode.swift; path = AsyncTokenRefreshTests.swift; sourceTree = ""; tabWidth = 2; }; 3AC7B89C23FD977D00BCE0FE /* ThreadSafePropertyWrapper.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = ThreadSafePropertyWrapper.swift; sourceTree = ""; }; 3AE34B1723F2056B00B453D7 /* fusion.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = fusion.framework; sourceTree = BUILT_PRODUCTS_DIR; }; 3AE34B1B23F2056B00B453D7 /* Info.plist */ = {isa = PBXFileReference; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = ""; }; From c38025974cb820e72a8b9382c58c991fa1e4020e Mon Sep 17 00:00:00 2001 From: Eren Kabakci Date: Tue, 28 Apr 2020 12:00:15 +0200 Subject: [PATCH 2/3] Change testing approach drastically --- Sources/NetworkError.swift | 1 + .../AuthenticatedWebService.swift | 10 +- .../AsyncTokenRefreshTests.swift | 7 +- .../AuthenticatedWebServiceTests.swift | 171 +++++++++--------- .../xcshareddata/xcschemes/fusion.xcscheme | 8 +- 5 files changed, 102 insertions(+), 95 deletions(-) diff --git a/Sources/NetworkError.swift b/Sources/NetworkError.swift index 4eb4c7e..0037c73 100644 --- a/Sources/NetworkError.swift +++ b/Sources/NetworkError.swift @@ -30,6 +30,7 @@ public enum NetworkError: Error, Equatable { case urlError(URLError?) case parsingFailure case corruptUrl + case timeout case unauthorized case forbidden case generic(HttpStatusCode) diff --git a/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift b/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift index 551fe18..afbb2ac 100644 --- a/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift +++ b/Sources/WebService/AuthenticatedWebService/AuthenticatedWebService.swift @@ -74,7 +74,9 @@ open class AuthenticatedWebService: WebService { .flatMap { accessToken -> AnyPublisher in return appendTokenAndExecute(accessToken: accessToken) } - }.catch { [weak self] error -> AnyPublisher in + } + .timeout(10, scheduler: DispatchQueue.main, customError: { NetworkError.timeout }) + .catch { [weak self] error -> AnyPublisher in guard let self = self else { return Fail(error: NetworkError.unknown).eraseToAnyPublisher() } @@ -107,8 +109,10 @@ open class AuthenticatedWebService: WebService { .setFailureType(to: Error.self) .flatMap { accessToken -> AnyPublisher in return appendTokenAndExecute(accessToken: accessToken) - } - }.catch { [weak self] error -> AnyPublisher in + }.eraseToAnyPublisher() + } + .timeout(10, scheduler: DispatchQueue.main, customError: { NetworkError.timeout }) + .catch { [weak self] error -> AnyPublisher in guard let self = self else { return Fail(error: NetworkError.unknown).eraseToAnyPublisher() } diff --git a/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift b/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift index 03f2790..d97d147 100644 --- a/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift +++ b/Tests/WebService/AuthenticatedWebService/AsyncTokenRefreshTests.swift @@ -20,6 +20,7 @@ class AsyncTokenRefreshTests: XCTestCase { override func setUp() { super.setUp() + subscriptions = Set() session = MockAuthenticatedServiceSession() tokenProvider = MockTokenProvider() subscriptions = Set() @@ -29,7 +30,7 @@ class AsyncTokenRefreshTests: XCTestCase { let encodedData = try! self.encoder.encode(["id": "value"]) tokenProvider.accessToken .sink(receiveValue: { - if $0 == "newToken" { + if $0 == "newAsyncToken" { print("Change session response to 200") self.session.result = ((encodedData, 200), nil) } @@ -103,12 +104,12 @@ private class MockTokenProvider: AuthenticationTokenProvidable { func reissueAccessToken() -> AnyPublisher { // replicate a slow & asnyc token refresh sleep(2) - self.accessToken.send("newToken") + self.accessToken.send("newAsyncToken") self.methodCallStack.append(#function) return Deferred { Future { promise in - promise(.success("newToken")) + promise(.success("newAsyncToken")) } }.eraseToAnyPublisher() } diff --git a/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift b/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift index d1f462a..e718971 100644 --- a/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift +++ b/Tests/WebService/AuthenticatedWebService/AuthenticatedWebServiceTests.swift @@ -31,11 +31,12 @@ class AuthenticatedWebServiceTests: XCTestCase { private var session: MockAuthenticatedServiceSession! private var tokenProvider: MockTokenProvider! private var webService: AuthenticatedWebService! - private var subscriptions = Set() - private let encoder = JSONEncoder() + private var subscriptions: Set! + private var encoder = JSONEncoder() override func setUp() { super.setUp() + subscriptions = Set() session = MockAuthenticatedServiceSession() tokenProvider = MockTokenProvider() webService = AuthenticatedWebService(urlSession: session, @@ -111,63 +112,55 @@ class AuthenticatedWebServiceTests: XCTestCase { waitForExpectations(timeout: 0.5) } - func test_givenAuthenticatedWebService_whenParallelRequestsFired_thenShouldNotRaceForTokenRefresh() { + func test_givenAuthenticatedWebService_whenTimeoutWithInitialTokenlessState_andHavingInvalidToken_thenNextRequestShouldRefreshTokenOnce() { let testScheduler = TestScheduler(initialClock: 0) let request = URLRequest(url: URL(string: "foo.com")!) let expectation1 = self.expectation(description: "authentication stream test expectation1") let expectation2 = self.expectation(description: "authentication stream test expectation2") - testScheduler.schedule(after: 100) { - // Mimicking successful api response but no token case - self.session.result = ((Data(), 200), nil) - - // First call should fail since there is no access token yet - self.webService.execute(urlRequest: request) - .sink(receiveCompletion: { - if case let .failure(error as NetworkError) = $0 { - XCTAssertEqual(error, NetworkError.unauthorized) - XCTAssertEqual(self.tokenProvider.methodCallStack, []) - } - }, - receiveValue: { _ in - XCTFail("No value should be received") - }) - .store(in: &self.subscriptions) - self.session.result = ((Data(), 401), nil) + testScheduler.schedule(after: 200) { + DispatchQueue.global().asyncAfter(deadline: .now(), execute: { + print("initial tokenless state") + // Mimicking successful api response but no token case + self.session.result = ((Data(), 200), nil) + + // First call should fail since there is no access token yet + self.webService.execute(urlRequest: request) + .receive(on: DispatchQueue.main) + .sink(receiveCompletion: { + if case let .failure(error as NetworkError) = $0 { + XCTAssertEqual(error, NetworkError.timeout) + XCTAssertEqual(self.tokenProvider.methodCallStack, []) + expectation1.fulfill() + } + }, + receiveValue: { _ in + XCTFail("No value should be received") + }) + .store(in: &self.subscriptions) + self.session.result = ((Data(), 401), nil) + }) } - // given second call, has an invalid token - testScheduler.schedule(after: 200) { - // Demonstrate two parallel requests not racing each other to refresh the token - print("invalid token is set") - self.tokenProvider.accessToken.value = "invalidToken" - - self.webService.execute(urlRequest: request) - .sink(receiveCompletion: { - if case .failure = $0 { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { _ in - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) - expectation1.fulfill() - }) - .store(in: &self.subscriptions) - - // a parallel call should succesfully execute since the token is refreshed by the previous call - self.webService.execute(urlRequest: request) - .receive(on: testScheduler) - .sink(receiveCompletion: { - if case .failure = $0 { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { (_: SampleResponse) in - // Reissuing is called only once even though there are two parallel calls - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) - expectation2.fulfill() - }) - .store(in: &self.subscriptions) + testScheduler.schedule(after: 400) { + DispatchQueue.global().asyncAfter(deadline: .now() + 12, execute: { + // Demonstrate two parallel requests not racing each other to refresh the token + print("invalid token is set") + self.tokenProvider.accessToken.value = "invalidToken" + + self.webService.execute(urlRequest: request) + .receive(on: DispatchQueue.main) + .sink(receiveCompletion: { + if case .failure = $0 { + XCTFail("should not receive failure since the token is refreshed") + } + }, + receiveValue: { _ in + XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + expectation2.fulfill() + }) + .store(in: &self.subscriptions) + }) } let subscriber = testScheduler.createTestableSubscriber(String?.self, Never.self) @@ -175,14 +168,14 @@ class AuthenticatedWebServiceTests: XCTestCase { testScheduler.resume() - waitForExpectations(timeout: 2) + waitForExpectations(timeout: 15) let expected: TestSequence = [ (0, .subscription), (0, .input(nil)), - (200, .input("invalidToken")), - (200, .input(nil)), - (200, .input("newToken"))] + (400, .input("invalidToken")), + (400, .input(nil)), + (400, .input("newToken"))] XCTAssertEqual(expected, subscriber.recordedOutput) } @@ -190,8 +183,8 @@ class AuthenticatedWebServiceTests: XCTestCase { func test_givenAuthenticatedWebService_whenContinousRequestsFired_thenShouldNotRaceForTokenRefresh() { let testScheduler = TestScheduler(initialClock: 0) let request = URLRequest(url: URL(string: "foo.com")!) - let expectation1 = self.expectation(description: "authentication stream test expectation1") - let expectation2 = self.expectation(description: "authentication stream test expectation2") + let expectation3 = self.expectation(description: "authentication stream test expectation3") + let expectation4 = self.expectation(description: "authentication stream test expectation4") testScheduler.schedule(after: 100) { self.tokenProvider.accessToken.value = "invalidToken" @@ -200,41 +193,43 @@ class AuthenticatedWebServiceTests: XCTestCase { // first call should refresh the token with a valid one and override the previously set invalid token testScheduler.schedule(after: 200) { + DispatchQueue.global().asyncAfter(deadline: .now(), execute: { self.webService.execute(urlRequest: request) - .receive(on: testScheduler) + .receive(on: DispatchQueue.main) .sink(receiveCompletion: { - if case .finished = $0 { - expectation1.fulfill() - } else { + if case .failure = $0 { XCTFail("should not receive failure since the token is refreshed") } }, receiveValue: { (_: SampleResponse) in // Reissuing is called only once even though there are two parallel calls XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + expectation3.fulfill() + fireRequest2() }) .store(in: &self.subscriptions) + }) } // second call should execute normally even if it has an invalid token, since the previous call is already refreshing the token for this one as well - testScheduler.schedule(after: 220) { - // Demonstrate two consecutive requests not racing each other to refresh the token - self.tokenProvider.accessToken.value = "invalidToken2" - - self.webService.execute(urlRequest: request) - .receive(on: testScheduler) - .sink(receiveCompletion: { - if case .finished = $0 { - expectation2.fulfill() - } - else { - XCTFail("should not receive failure since the token is refreshed") - } - }, - receiveValue: { _ in - XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) - }) - .store(in: &self.subscriptions) + func fireRequest2() { + DispatchQueue.global().asyncAfter(deadline: .now() + 5, execute: { + // Demonstrate two consecutive requests not racing each other to refresh the token + self.tokenProvider.accessToken.value = "invalidToken2" + + self.webService.execute(urlRequest: request) + .receive(on: DispatchQueue.main) + .sink(receiveCompletion: { + if case .failure = $0 { + XCTFail("should not receive failure since the token is refreshed") + } + }, + receiveValue: { _ in + XCTAssertEqual(self.tokenProvider.methodCallStack, ["invalidateAccessToken()", "reissueAccessToken()"]) + expectation4.fulfill() + }) + .store(in: &self.subscriptions) + }) } let subscriber = testScheduler.createTestableSubscriber(String?.self, Never.self) @@ -242,7 +237,7 @@ class AuthenticatedWebServiceTests: XCTestCase { testScheduler.resume() - waitForExpectations(timeout: 2) + waitForExpectations(timeout: 10) let expected: TestSequence = [ (0, .subscription), @@ -250,7 +245,7 @@ class AuthenticatedWebServiceTests: XCTestCase { (100, .input("invalidToken")), (200, .input(nil)), (200, .input("newToken")), - (220, .input("invalidToken2")),] + (200, .input("invalidToken2"))] XCTAssertEqual(expected, subscriber.recordedOutput) } @@ -315,13 +310,13 @@ private class MockTokenProvider: AuthenticationTokenProvidable { var refreshToken: CurrentValueSubject = CurrentValueSubject(nil) func reissueAccessToken() -> AnyPublisher { - self.accessToken.send("newToken") - self.methodCallStack.append(#function) + self.accessToken.send("newToken") + self.methodCallStack.append(#function) - return Deferred { - Future { promise in - promise(.success("newToken")) - } + return Deferred { + Future { promise in + promise(.success("newToken")) + } }.eraseToAnyPublisher() } diff --git a/fusion.xcodeproj/xcshareddata/xcschemes/fusion.xcscheme b/fusion.xcodeproj/xcshareddata/xcschemes/fusion.xcscheme index 39db054..0d40d7e 100644 --- a/fusion.xcodeproj/xcshareddata/xcschemes/fusion.xcscheme +++ b/fusion.xcodeproj/xcshareddata/xcschemes/fusion.xcscheme @@ -30,7 +30,8 @@ codeCoverageEnabled = "YES"> + skipped = "NO" + parallelizable = "YES"> + + + + From 8ac9e975fa2e2e3f2d47dfafa19fd3ad4c364ffb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20Kabak=C3=A7=C4=B1?= Date: Tue, 28 Apr 2020 12:26:35 +0200 Subject: [PATCH 3/3] Update swift.yml --- .github/workflows/swift.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/swift.yml b/.github/workflows/swift.yml index c730fdf..a1dddd5 100644 --- a/.github/workflows/swift.yml +++ b/.github/workflows/swift.yml @@ -15,4 +15,4 @@ jobs: - name: Build run: swift build -v - name: Run tests - run: swift test -v + run: xcodebuild test -project fusion.xcodeproj -scheme fusion -destination 'platform=iOS Simulator,name=iPhone 11'