Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class APIService {

/// A `TokenProvider` that gets the access token for the current account and can refresh it when
/// necessary.
private let accountTokenProvider: AccountTokenProvider
internal let accountTokenProvider: AccountTokenProvider

/// A builder for building an `HTTPService`.
private let httpServiceBuilder: HTTPServiceBuilder
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// MARK: - RefreshableAPIService

/// API service protocol to refresh tokens.
protocol RefreshableAPIService { // sourcery: AutoMockable
/// Refreshes the access token by using the refresh token to acquire a new access token.
///
func refreshAccessToken() async throws
}

extension APIService: RefreshableAPIService {
func refreshAccessToken() async throws {
try await accountTokenProvider.refreshToken()
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import BitwardenKitMocks
import TestHelpers
import XCTest

@testable import BitwardenShared

// MARK: - RefreshableAPIServiceTests

class RefreshableAPIServiceTests: BitwardenTestCase {
// MARK: Properties

var accountTokenProvider: MockAccountTokenProvider!
var subject: RefreshableAPIService!

// MARK: Setup & Teardown

override func setUp() {
super.setUp()

accountTokenProvider = MockAccountTokenProvider()
subject = APIService(
accountTokenProvider: accountTokenProvider,
environmentService: MockEnvironmentService(),
flightRecorder: MockFlightRecorder(),
stateService: MockStateService(),
tokenService: MockTokenService()
)
}

override func tearDown() {
super.tearDown()

accountTokenProvider = nil
subject = nil
}

// MARK: Tests

/// `refreshAccessToken()` calls the token provider to refresh the token.
@MainActor
func test_refreshAccessToken() async throws {
try await subject.refreshAccessToken()

XCTAssertTrue(accountTokenProvider.refreshTokenCalled)
}

/// `refreshAccessToken()` throws when the token provider to refresh the token throws.
@MainActor
func test_refreshAccessToken_throws() async throws {
accountTokenProvider.refreshTokenResult = .failure(BitwardenTestError.example)
await assertAsyncThrows(error: BitwardenTestError.example) {
try await subject.refreshAccessToken()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@ import BitwardenKit
class MockAccountTokenProvider: AccountTokenProvider {
var delegate: AccountTokenProviderDelegate?
var getTokenResult: Result<String, Error> = .success("ACCESS_TOKEN")
var refreshTokenCalled = false
var refreshTokenResult: Result<Void, Error> = .success(())

func getToken() async throws -> String {
try getTokenResult.get()
}

func refreshToken() async throws {
refreshTokenCalled = true
try refreshTokenResult.get()
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ class DefaultNotificationService: NotificationService {
/// The API service used to make notification requests.
private let notificationAPIService: NotificationAPIService

/// The API service used to refresh tokens.
private let refreshableApiService: RefreshableAPIService

/// The service used by the application to manage account state.
private let stateService: StateService

Expand All @@ -115,6 +118,7 @@ class DefaultNotificationService: NotificationService {
/// - authService: The service used by the application to handle authentication tasks.
/// - errorReporter: The service used by the application to report non-fatal errors.
/// - notificationAPIService: The API service used to make notification requests.
/// - refreshableApiService: The API service used to refresh tokens.
/// - stateService: The service used by the application to manage account state.
/// - syncService: The service used to handle syncing vault data with the API.
init(
Expand All @@ -123,6 +127,7 @@ class DefaultNotificationService: NotificationService {
authService: AuthService,
errorReporter: ErrorReporter,
notificationAPIService: NotificationAPIService,
refreshableApiService: RefreshableAPIService,
stateService: StateService,
syncService: SyncService
) {
Expand All @@ -131,6 +136,7 @@ class DefaultNotificationService: NotificationService {
self.authService = authService
self.errorReporter = errorReporter
self.notificationAPIService = notificationAPIService
self.refreshableApiService = refreshableApiService
self.stateService = stateService
self.syncService = syncService
}
Expand Down Expand Up @@ -208,6 +214,7 @@ class DefaultNotificationService: NotificationService {
.syncVault:
try await syncService.fetchSync(forceSync: false)
case .syncOrgKeys:
try await refreshableApiService.refreshAccessToken()
try await syncService.fetchSync(forceSync: true)
case .logOut:
guard let data: UserNotification = notificationData.data() else { return }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ class NotificationServiceTests: BitwardenTestCase { // swiftlint:disable:this ty
// MARK: Properties

var appSettingsStore: MockAppSettingsStore!
var refreshableApiService: MockRefreshableAPIService!
var authRepository: MockAuthRepository!
var authService: MockAuthService!
var client: MockHTTPClient!
Expand All @@ -30,6 +31,7 @@ class NotificationServiceTests: BitwardenTestCase { // swiftlint:disable:this ty
delegate = MockNotificationServiceDelegate()
errorReporter = MockErrorReporter()
notificationAPIService = APIService(client: client)
refreshableApiService = MockRefreshableAPIService()
stateService = MockStateService()
syncService = MockSyncService()

Expand All @@ -39,6 +41,7 @@ class NotificationServiceTests: BitwardenTestCase { // swiftlint:disable:this ty
authService: authService,
errorReporter: errorReporter,
notificationAPIService: notificationAPIService,
refreshableApiService: refreshableApiService,
stateService: stateService,
syncService: syncService
)
Expand All @@ -54,6 +57,7 @@ class NotificationServiceTests: BitwardenTestCase { // swiftlint:disable:this ty
delegate = nil
errorReporter = nil
notificationAPIService = nil
refreshableApiService = nil
stateService = nil
subject = nil
syncService = nil
Expand Down Expand Up @@ -349,9 +353,33 @@ class NotificationServiceTests: BitwardenTestCase { // swiftlint:disable:this ty
await subject.messageReceived(message, notificationDismissed: nil, notificationTapped: nil)

// Confirm the results.
XCTAssertTrue(refreshableApiService.refreshAccessTokenCalled)
XCTAssertTrue(syncService.didFetchSync)
}

/// `messageReceived(_:notificationDismissed:notificationTapped:)` doesn't sync on
/// `.syncOrgKeys` when refreshing the token fails.
func test_messageReceived_syncOrgKeysRefreshThrows() async throws {
// Set up the mock data.
stateService.setIsAuthenticated()
appSettingsStore.appId = "10"
let message: [AnyHashable: Any] = [
"data": [
"type": NotificationType.syncOrgKeys.rawValue,
"payload": "anything",
],
]
refreshableApiService.refreshAccessTokenThrowableError = BitwardenTestError.example

// Test.
await subject.messageReceived(message, notificationDismissed: nil, notificationTapped: nil)

// Confirm the results.
XCTAssertTrue(refreshableApiService.refreshAccessTokenCalled)
XCTAssertFalse(syncService.didFetchSync)
XCTAssertEqual(errorReporter.errors as? [BitwardenTestError], [.example])
}

/// `messageReceived(_:notificationDismissed:notificationTapped:)` handles messages appropriately.
func test_messageReceived_fetchSync() async throws {
// Set up the mock data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,7 @@ public class ServiceContainer: Services { // swiftlint:disable:this type_body_le
authService: authService,
errorReporter: errorReporter,
notificationAPIService: apiService,
refreshableApiService: apiService,
stateService: stateService,
syncService: syncService
)
Expand Down