Skip to content

Commit

Permalink
Overhaul Refresh Token Retry Mechanism (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
erenkabakci committed Apr 28, 2020
1 parent 4fead08 commit 8512cd2
Show file tree
Hide file tree
Showing 11 changed files with 326 additions and 331 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/swift.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,4 @@ jobs:
- name: Build
run: swift build -v
- name: Run tests
run: swift test -v
run: xcodebuild test -project fusion.xcodeproj -scheme fusion -destination 'platform=iOS Simulator,name=iPhone 11'
1 change: 1 addition & 0 deletions Sources/NetworkError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ public enum NetworkError: Error, Equatable {
case urlError(URLError?)
case parsingFailure
case corruptUrl
case timeout
case unauthorized
case forbidden
case generic(HttpStatusCode)
Expand Down
2 changes: 1 addition & 1 deletion Sources/SessionPublisherProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ public protocol SessionPublisherProtocol: AnyObject {
extension URLSession: SessionPublisherProtocol {
public func dataTaskPublisher(for request: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), Error> {
self.dataTaskPublisher(for: request)
.receive(on: DispatchQueue.main)
.subscribe(on: DispatchQueue.global())
.mapError { NetworkError.urlError($0) }
.eraseToAnyPublisher()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public enum AuthorizationHeaderScheme: String {

public struct AuthenticatedWebServiceConfiguration {
let authorizationHeaderScheme: AuthorizationHeaderScheme
let refreshTriggerErrors: [Error]
let refreshTriggerErrors: [NetworkError]

public init(authorizationHeaderScheme: AuthorizationHeaderScheme = .none,
refreshTriggerErrors: [Error] = [NetworkError.unauthorized]) {
refreshTriggerErrors: [NetworkError] = [NetworkError.unauthorized]) {
self.authorizationHeaderScheme = authorizationHeaderScheme
self.refreshTriggerErrors = refreshTriggerErrors
}
Expand All @@ -56,84 +56,83 @@ open class AuthenticatedWebService: WebService {
self.configuration = configuration
super.init(urlSession: urlSession)
}

override public func execute<T>(urlRequest: URLRequest) -> AnyPublisher<T, Error> where T : Decodable {
var urlRequest = urlRequest
var currentAccessToken: String?

authenticationQueue.sync {
currentAccessToken = self.tokenProvider.accessToken.value
func appendTokenAndExecute(accessToken: AccessToken) -> AnyPublisher<T, Error> {
urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization")
return super.execute(urlRequest: urlRequest)
.subscribe(on: DispatchQueue.global())
.eraseToAnyPublisher()
}

return Deferred {
self.tokenProvider.accessToken
.compactMap { $0 }
.setFailureType(to: Error.self)
.flatMap { accessToken -> AnyPublisher<T, Error> in
return appendTokenAndExecute(accessToken: accessToken)
}
}

guard let accessToken = currentAccessToken else {
return Fail<T, Error>(error: NetworkError.unauthorized).eraseToAnyPublisher()
.timeout(10, scheduler: DispatchQueue.main, customError: { NetworkError.timeout })
.catch { [weak self] error -> AnyPublisher<T, Error> in
guard let self = self else {
return Fail<T, Error>(error: NetworkError.unknown).eraseToAnyPublisher()
}

if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){
return self.retrySynchronizedTokenRefresh()
.flatMap {
appendTokenAndExecute(accessToken: $0)
}.eraseToAnyPublisher()
}
return Fail<T, Error>(error: error).eraseToAnyPublisher()
}

urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization")

return super.execute(urlRequest: urlRequest)
.catch { [weak self] error -> AnyPublisher<T, Error> in
guard let self = self else {
return Fail<T, Error>(error: NetworkError.unknown).eraseToAnyPublisher()
}

if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){
self.retrySynchronizedTokenRefresh()

return self.execute(urlRequest: urlRequest)
.delay(for: 0.2, scheduler: self.authenticationQueue)
.eraseToAnyPublisher()
}
return Fail<T, Error>(error: error).eraseToAnyPublisher()
}.eraseToAnyPublisher()
.receive(on: DispatchQueue.main)
.eraseToAnyPublisher()
}



override public func execute(urlRequest: URLRequest) -> AnyPublisher<Void, Error> {
var urlRequest = urlRequest
var currentAccessToken: String?

authenticationQueue.sync {
currentAccessToken = self.tokenProvider.accessToken.value
func appendTokenAndExecute(accessToken: AccessToken) -> AnyPublisher<Void, Error> {
urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization")
return super.execute(urlRequest: urlRequest)
.subscribe(on: DispatchQueue.global())
.eraseToAnyPublisher()
}

guard let accessToken = currentAccessToken else {
return Fail<Void, Error>(error: NetworkError.unauthorized).eraseToAnyPublisher()
return Deferred {
self.tokenProvider.accessToken
.compactMap { $0 }
.setFailureType(to: Error.self)
.flatMap { accessToken -> AnyPublisher<Void, Error> in
return appendTokenAndExecute(accessToken: accessToken)
}.eraseToAnyPublisher()
}
.timeout(10, scheduler: DispatchQueue.main, customError: { NetworkError.timeout })
.catch { [weak self] error -> AnyPublisher<Void, Error> in
guard let self = self else {
return Fail<Void, Error>(error: NetworkError.unknown).eraseToAnyPublisher()
}

if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){
return self.retrySynchronizedTokenRefresh()
.flatMap {
appendTokenAndExecute(accessToken: $0)
}.eraseToAnyPublisher()
}
return Fail<Void, Error>(error: error).eraseToAnyPublisher()
}

urlRequest.setValue(self.configuration.authorizationHeaderScheme.rawValue + accessToken, forHTTPHeaderField: "Authorization")

return super.execute(urlRequest: urlRequest)
.catch { [weak self] error -> AnyPublisher<Void, Error> in
guard let self = self else {
return Fail<Void, Error>(error: NetworkError.unknown).eraseToAnyPublisher()
}

if self.configuration.refreshTriggerErrors.contains(where: { return $0.reflectedString == error.reflectedString }){
self.retrySynchronizedTokenRefresh()

return self.execute(urlRequest: urlRequest)
.delay(for: 0.2, scheduler: self.authenticationQueue)
.eraseToAnyPublisher()
}
return Fail<Void, Error>(error: error).eraseToAnyPublisher()
}.eraseToAnyPublisher()
.receive(on: DispatchQueue.main)
.eraseToAnyPublisher()
}

private func retrySynchronizedTokenRefresh() {
let dispatchGroup = DispatchGroup()
dispatchGroup.enter()

authenticationQueue.sync(flags: .barrier) {
self.tokenProvider.invalidateAccessToken()
self.tokenProvider.reissueAccessToken()
.sink(receiveCompletion: { _ in
dispatchGroup.leave()
},
receiveValue: { _ in })
.store(in: &self.subscriptions)
dispatchGroup.wait()
}
private func retrySynchronizedTokenRefresh() -> AnyPublisher<AccessToken, Error> {
tokenProvider.invalidateAccessToken()
return tokenProvider.reissueAccessToken()
.eraseToAnyPublisher()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import Combine
import Foundation

public typealias AccessToken = String
public typealias RefreshToken = String

public protocol AuthenticationTokenProvidable: AnyObject {
var accessToken: CurrentValueSubject<String?, Never> { get }
var refreshToken: CurrentValueSubject<String?, Never> { get }
func reissueAccessToken() -> AnyPublisher<Never, Error>
var accessToken: CurrentValueSubject<AccessToken?, Never> { get }
var refreshToken: CurrentValueSubject<RefreshToken?, Never> { get }
func reissueAccessToken() -> AnyPublisher<AccessToken, Error>
func invalidateAccessToken()
func invalidateRefreshToken()
}
28 changes: 18 additions & 10 deletions Sources/WebService/WebService.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
//
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
Expand All @@ -35,7 +35,7 @@ CustomDecodable{
public let jsonDecoder: JSONDecoder = JSONDecoder()
private let session: SessionPublisherProtocol
@ThreadSafe open var subscriptions = Set<AnyCancellable>()

public init(urlSession: SessionPublisherProtocol = URLSession(configuration: URLSessionConfiguration.ephemeral,
delegate: nil,
delegateQueue: nil)) {
Expand All @@ -49,9 +49,11 @@ CustomDecodable{
return Fail(error: NetworkError.unknown).eraseToAnyPublisher()
}
return CurrentValueSubject((output.data, httpResponse)).eraseToAnyPublisher()
}.eraseToAnyPublisher()
}
.receive(on: DispatchQueue.main)
.eraseToAnyPublisher()
}

public func execute<T>(urlRequest: URLRequest) -> AnyPublisher<T, Error> where T : Decodable {
Deferred {
Future { [weak self] promise in
Expand All @@ -62,8 +64,9 @@ CustomDecodable{

var urlRequest = urlRequest
urlRequest.appendAdditionalHeaders(headers: self.defaultHttpHeaders)

self.rawResponse(urlRequest: urlRequest)
.subscribe(on: DispatchQueue.global())
.tryMap {
try self.mapHttpResponseCodes(output: $0)

Expand All @@ -80,21 +83,24 @@ CustomDecodable{
receiveValue: { promise(.success($0)) })
.store(in: &self.subscriptions)
}
}.eraseToAnyPublisher()
}
.receive(on: DispatchQueue.main)
.eraseToAnyPublisher()
}

public func execute(urlRequest: URLRequest) -> AnyPublisher<Void, Error> {
Deferred {
Future { [weak self] promise in
guard let self = self else {
promise(.failure(NetworkError.unknown))
return
}

var urlRequest = urlRequest
urlRequest.appendAdditionalHeaders(headers: self.defaultHttpHeaders)

self.rawResponse(urlRequest: urlRequest)
.subscribe(on: DispatchQueue.global())
.tryMap {
try self.mapHttpResponseCodes(output: $0)
return
Expand All @@ -107,7 +113,9 @@ CustomDecodable{
receiveValue: { promise(.success($0)) })
.store(in: &self.subscriptions)
}
}.eraseToAnyPublisher()
}
.receive(on: DispatchQueue.main)
.eraseToAnyPublisher()
}

open func mapHttpResponseCodes(output: (data:Data, response: HTTPURLResponse)) throws {
Expand Down
22 changes: 12 additions & 10 deletions Tests/MockSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,18 @@ open class MockSession: SessionPublisherProtocol {
public func dataTaskPublisher(for urlRequest: URLRequest) -> AnyPublisher<(data: Data, response: URLResponse), Error> {
methodCallStack.append(#function)
finalUrlRequest = urlRequest
return Future<(data: Data, response: URLResponse), Error> { promise in
usleep(20)
if let successResponse = self.result?.0 {
promise(.success((successResponse.0,
HTTPURLResponse(url: URL(string: "foo.com")!,
statusCode: successResponse.1,
httpVersion: nil,
headerFields: nil)!)))
} else if let errorResponse = self.result?.1 {
promise(.failure(NetworkError.urlError(errorResponse)))
return Deferred {
Future<(data: Data, response: URLResponse), Error> { promise in
usleep(20)
if let successResponse = self.result?.0 {
promise(.success((successResponse.0,
HTTPURLResponse(url: URL(string: "foo.com")!,
statusCode: successResponse.1,
httpVersion: nil,
headerFields: nil)!)))
} else if let errorResponse = self.result?.1 {
promise(.failure(NetworkError.urlError(errorResponse)))
}
}
}.eraseToAnyPublisher()
}
Expand Down

0 comments on commit 8512cd2

Please sign in to comment.