Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions firebase-dataconnect/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# Unreleased

- [changed] Internal refactor for managing Auth and App Check tokens
([#7184](https://github.com/firebase/firebase-android-sdk/pull/7184))

# 17.1.0

- [fixed] Addressed minor reference documentation issues (#7399)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.google.firebase.annotations.DeferredApi
import com.google.firebase.appcheck.AppCheckTokenResult
import com.google.firebase.appcheck.interop.AppCheckTokenListener
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.dataconnect.core.DataConnectAppCheck.GetAppCheckTokenResult
import com.google.firebase.dataconnect.core.Globals.toScrubbedAccessToken
import com.google.firebase.dataconnect.core.LoggerGlobals.debug
import kotlinx.coroutines.CoroutineDispatcher
Expand All @@ -32,7 +33,7 @@ internal class DataConnectAppCheck(
blockingDispatcher: CoroutineDispatcher,
logger: Logger,
) :
DataConnectCredentialsTokenManager<InteropAppCheckTokenProvider>(
DataConnectCredentialsTokenManager<InteropAppCheckTokenProvider, GetAppCheckTokenResult>(
deferredProvider = deferredAppCheckTokenProvider,
parentCoroutineScope = parentCoroutineScope,
blockingDispatcher = blockingDispatcher,
Expand All @@ -48,7 +49,9 @@ internal class DataConnectAppCheck(
provider.removeAppCheckTokenListener(appCheckTokenListener)

override suspend fun getToken(provider: InteropAppCheckTokenProvider, forceRefresh: Boolean) =
provider.getToken(forceRefresh).await().let { GetTokenResult(it.token) }
provider.getToken(forceRefresh).await().let { GetAppCheckTokenResult(it.token) }

data class GetAppCheckTokenResult(override val token: String?) : GetTokenResult

private class AppCheckTokenListenerImpl(private val logger: Logger) : AppCheckTokenListener {
override fun onAppCheckTokenChanged(tokenResult: AppCheckTokenResult) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package com.google.firebase.dataconnect.core
import com.google.firebase.annotations.DeferredApi
import com.google.firebase.auth.internal.IdTokenListener
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.dataconnect.core.DataConnectAuth.GetAuthTokenResult
import com.google.firebase.dataconnect.core.Globals.toScrubbedAccessToken
import com.google.firebase.dataconnect.core.LoggerGlobals.debug
import com.google.firebase.internal.InternalTokenResult
Expand All @@ -32,7 +33,7 @@ internal class DataConnectAuth(
blockingDispatcher: CoroutineDispatcher,
logger: Logger,
) :
DataConnectCredentialsTokenManager<InternalAuthProvider>(
DataConnectCredentialsTokenManager<InternalAuthProvider, GetAuthTokenResult>(
deferredProvider = deferredAuthProvider,
parentCoroutineScope = parentCoroutineScope,
blockingDispatcher = blockingDispatcher,
Expand All @@ -48,7 +49,9 @@ internal class DataConnectAuth(
provider.removeIdTokenListener(idTokenListener)

override suspend fun getToken(provider: InternalAuthProvider, forceRefresh: Boolean) =
provider.getAccessToken(forceRefresh).await().let { GetTokenResult(it.token) }
provider.getAccessToken(forceRefresh).await().let { GetAuthTokenResult(it.token) }

data class GetAuthTokenResult(override val token: String?) : GetTokenResult

private class IdTokenListenerImpl(private val logger: Logger) : IdTokenListener {
override fun onIdTokenChanged(tokenResult: InternalTokenResult) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import com.google.firebase.annotations.DeferredApi
import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider
import com.google.firebase.auth.internal.InternalAuthProvider
import com.google.firebase.dataconnect.DataConnectException
import com.google.firebase.dataconnect.core.DataConnectCredentialsTokenManager.GetTokenResult
import com.google.firebase.dataconnect.core.Globals.toScrubbedAccessToken
import com.google.firebase.dataconnect.core.LoggerGlobals.debug
import com.google.firebase.dataconnect.core.LoggerGlobals.warn
Expand Down Expand Up @@ -52,7 +53,7 @@ import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch

/** Base class that shares logic for managing the Auth token and AppCheck token. */
internal sealed class DataConnectCredentialsTokenManager<T : Any>(
internal sealed class DataConnectCredentialsTokenManager<T : Any, R : GetTokenResult>(
private val deferredProvider: com.google.firebase.inject.Deferred<T>,
parentCoroutineScope: CoroutineScope,
private val blockingDispatcher: CoroutineDispatcher,
Expand All @@ -75,13 +76,13 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
}
)

private sealed interface State<out T> {
private sealed interface State<out T, out R : GetTokenResult> {

/**
* State indicating that the object has just been created and [initialize] has not yet been
* called.
*/
object New : State<Nothing>
object New : State<Nothing, Nothing>

/**
* State indicating that [initialize] has been invoked but the token provider is not (yet?)
Expand All @@ -93,33 +94,33 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
}

/** State indicating that [close] has been invoked. */
object Closed : State<Nothing>
object Closed : State<Nothing, Nothing>

sealed interface StateWithForceTokenRefresh<out T> : State<T> {
sealed interface StateWithForceTokenRefresh<out T> : State<T, Nothing> {
/** The value to specify for `forceRefresh` on the next invocation of [getToken]. */
val forceTokenRefresh: Boolean
}

sealed interface StateWithProvider<out T> : State<T> {
sealed interface StateWithProvider<out T, out R : GetTokenResult> : State<T, R> {
/** The token provider, [InternalAuthProvider] or [InteropAppCheckTokenProvider] */
val provider: T
}

/** State indicating that there is no outstanding "get token" request. */
data class Idle<T>(override val provider: T, override val forceTokenRefresh: Boolean) :
StateWithProvider<T>, StateWithForceTokenRefresh<T>
StateWithProvider<T, Nothing>, StateWithForceTokenRefresh<T>

/** State indicating that there _is_ an outstanding "get token" request. */
data class Active<out T>(
data class Active<out T, out R : GetTokenResult>(
override val provider: T,

/** The job that is performing the "get token" request. */
val job: Deferred<SequencedReference<Result<GetTokenResult>>>
) : StateWithProvider<T>
val job: Deferred<SequencedReference<Result<R>>>
) : StateWithProvider<T, R>
}

/** The current state of this object. */
private val state = MutableStateFlow<State<T>>(State.New)
private val state = MutableStateFlow<State<T, R>>(State.New)

/**
* Adds the token listener to the given provider.
Expand All @@ -139,7 +140,7 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
* Starts an asynchronous task to get a new access token from the given provider, forcing a token
* refresh if and only if `forceRefresh` is true.
*/
protected abstract suspend fun getToken(provider: T, forceRefresh: Boolean): GetTokenResult
protected abstract suspend fun getToken(provider: T, forceRefresh: Boolean): R

/**
* Initializes this object.
Expand Down Expand Up @@ -274,7 +275,7 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
invocationId: String,
provider: T,
forceRefresh: Boolean
): State.Active<T> {
): State.Active<T, R> {
val coroutineName =
CoroutineName(
"$instanceId 535gmcvv5a $invocationId getToken(" +
Expand All @@ -296,14 +297,14 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
* @throws DataConnectException if [close] has been called or is called while the operation is in
* progress.
*/
suspend fun getToken(requestId: String): String? {
suspend fun getToken(requestId: String): R? {
val invocationId = "gat" + Random.nextAlphanumericString(length = 8)
logger.debug { "$invocationId getToken(requestId=$requestId)" }
while (true) {
val attemptSequenceNumber = nextSequenceNumber()
val oldState = state.value

val newState: State.Active<T> =
val newState: State.Active<T, R> =
when (oldState) {
is State.New ->
throw IllegalStateException("initialize() must be called before getToken()")
Expand Down Expand Up @@ -381,11 +382,12 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
}
}

val accessToken = sequencedResult!!.ref.getOrThrow().token
val tokenResult: R = sequencedResult!!.ref.getOrThrow()
logger.debug {
"$invocationId getToken() returns retrieved token: ${accessToken?.toScrubbedAccessToken()}"
"$invocationId getToken() returns retrieved token: " +
tokenResult.token?.toScrubbedAccessToken()
}
return accessToken
return tokenResult
}
}

Expand Down Expand Up @@ -440,16 +442,17 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
* strong reference to the [DataConnectCredentialsTokenManager] instance indefinitely, in the case
* that the callback never occurs.
*/
private class DeferredProviderHandlerImpl<T : Any>(
private val weakCredentialsTokenManagerRef: WeakReference<DataConnectCredentialsTokenManager<T>>
private class DeferredProviderHandlerImpl<T : Any, R : GetTokenResult>(
private val weakCredentialsTokenManagerRef:
WeakReference<DataConnectCredentialsTokenManager<T, R>>
) : DeferredHandler<T> {
override fun handle(provider: Provider<T>) {
weakCredentialsTokenManagerRef.get()?.onProviderAvailable(provider.get())
}
}

private class CredentialsTokenManagerClosedException(
tokenProvider: DataConnectCredentialsTokenManager<*>
tokenProvider: DataConnectCredentialsTokenManager<*, *>
) :
DataConnectException(
"DataConnectCredentialsTokenManager ${tokenProvider.instanceId} was closed (code cqrbq4zfvy)"
Expand All @@ -458,7 +461,9 @@ internal sealed class DataConnectCredentialsTokenManager<T : Any>(
private class GetTokenCancelledException(cause: Throwable) :
DataConnectException("getToken() was cancelled, likely by close() (code rqdd4jam9d)", cause)

protected data class GetTokenResult(val token: String?)
interface GetTokenResult {
val token: String?
}

private companion object {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,8 @@ internal class DataConnectGrpcMetadata(
if (appId.isNotBlank()) {
it.put(gmpAppIdHeader, appId)
}
if (authToken !== null) {
it.put(firebaseAuthTokenHeader, authToken)
}
if (appCheckToken !== null) {
it.put(firebaseAppCheckTokenHeader, appCheckToken)
}
authToken?.token?.let { token -> it.put(firebaseAuthTokenHeader, token) }
appCheckToken?.token?.let { token -> it.put(firebaseAppCheckTokenHeader, token) }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ import io.kotest.assertions.withClue
import io.kotest.matchers.collections.shouldContain
import io.kotest.matchers.collections.shouldContainExactly
import io.kotest.matchers.nulls.shouldBeNull
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.types.shouldBeSameInstanceAs
import io.kotest.property.Arb
Expand Down Expand Up @@ -302,7 +303,7 @@ class DataConnectAuthUnitTest {

val result = dataConnectAuth.getToken(requestId)

withClue("result=$result") { result shouldBe accessToken }
withClue("result=$result") { result.shouldNotBeNull().token shouldBe accessToken }
mockLogger.shouldHaveLoggedExactlyOneMessageContaining(requestId)
mockLogger.shouldHaveLoggedExactlyOneMessageContaining(
"returns retrieved token: ${accessToken.toScrubbedAccessToken()}"
Expand Down Expand Up @@ -363,7 +364,7 @@ class DataConnectAuthUnitTest {
dataConnectAuth.forceRefresh()
val result = dataConnectAuth.getToken(requestId)

withClue("result=$result") { result shouldBe accessToken }
withClue("result=$result") { result.shouldNotBeNull().token shouldBe accessToken }
verify(exactly = 1) { mockInternalAuthProvider.getAccessToken(true) }
verify(exactly = 0) { mockInternalAuthProvider.getAccessToken(false) }
mockLogger.shouldHaveLoggedExactlyOneMessageContaining(requestId)
Expand Down Expand Up @@ -419,7 +420,7 @@ class DataConnectAuthUnitTest {
taskForToken(accessTokenGenerator.next().also { tokens.add(it) })
}

val results = List(5) { dataConnectAuth.getToken(requestId) }
val results = List(5) { dataConnectAuth.getToken(requestId)?.token }

results shouldContainExactly tokens
}
Expand Down Expand Up @@ -447,7 +448,7 @@ class DataConnectAuthUnitTest {
}
}

val actualTokens = jobs.map { it.await() }
val actualTokens = jobs.map { it.await()?.token }
actualTokens.forEachIndexed { index, token ->
withClue("actualTokens[$index]") { tokens shouldContain token }
}
Expand Down Expand Up @@ -481,7 +482,7 @@ class DataConnectAuthUnitTest {

val result = dataConnectAuth.getToken(requestId)

withClue("result=$result") { result shouldBe tokens.last() }
withClue("result=$result") { result.shouldNotBeNull().token shouldBe tokens.last() }
verify(exactly = 2) { mockInternalAuthProvider.getAccessToken(true) }
verify(exactly = 1) { mockInternalAuthProvider.getAccessToken(false) }
mockLogger.shouldHaveLoggedAtLeastOneMessageContaining("retrying due to needs token refresh")
Expand All @@ -496,11 +497,7 @@ class DataConnectAuthUnitTest {
advanceUntilIdle()
val invocationCount = AtomicInteger(0)
val tokens = CopyOnWriteArrayList<String>()
val getTokenJob2 =
async(start = CoroutineStart.LAZY) {
val accessToken = dataConnectAuth.getToken(requestId)
accessToken
}
val getTokenJob2 = async(start = CoroutineStart.LAZY) { dataConnectAuth.getToken(requestId) }
coEvery { mockInternalAuthProvider.getAccessToken(any()) } coAnswers
{
if (invocationCount.getAndIncrement() == 0) {
Expand All @@ -509,16 +506,15 @@ class DataConnectAuthUnitTest {
getTokenJob2.start()
advanceUntilIdle()
}
val rv = taskForToken(accessTokenGenerator.next().also { tokens.add(it) })
rv
taskForToken(accessTokenGenerator.next().also { tokens.add(it) })
}

val result1 = dataConnectAuth.getToken(requestId)
withClue("getTokenJob2.isActive") { getTokenJob2.isActive shouldBe true }
val result2 = getTokenJob2.await()

withClue("result1=$result1") { result1 shouldBe tokens[0] }
withClue("result2=$result2") { result2 shouldBe tokens[1] }
withClue("result1=$result1") { result1.shouldNotBeNull().token shouldBe tokens[0] }
withClue("result2=$result2") { result2.shouldNotBeNull().token shouldBe tokens[1] }
verify(exactly = 2) { mockInternalAuthProvider.getAccessToken(false) }
verify(exactly = 0) { mockInternalAuthProvider.getAccessToken(true) }
mockLogger.shouldHaveLoggedExactlyOneMessageContaining("got an old result; retrying")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import androidx.test.ext.junit.runners.AndroidJUnit4
import com.google.firebase.dataconnect.BuildConfig
import com.google.firebase.dataconnect.FirebaseDataConnect.CallerSdkType
import com.google.firebase.dataconnect.testutil.FirebaseAppUnitTestingRule
import com.google.firebase.dataconnect.testutil.property.arbitrary.appCheckTokenResult
import com.google.firebase.dataconnect.testutil.property.arbitrary.authTokenResult
import com.google.firebase.dataconnect.testutil.property.arbitrary.dataConnect
import com.google.firebase.dataconnect.testutil.property.arbitrary.dataConnectGrpcMetadata
import io.grpc.Metadata
Expand Down Expand Up @@ -175,8 +177,8 @@ class DataConnectGrpcMetadataUnitTest {
@Test
fun `should include x-firebase-auth-token when the auth token is not null`() = runTest {
val dataConnectAuth: DataConnectAuth = mockk()
val accessToken = Arb.dataConnect.accessToken().next()
coEvery { dataConnectAuth.getToken(any()) } returns accessToken
val authTokenResult = Arb.dataConnect.authTokenResult().next()
coEvery { dataConnectAuth.getToken(any()) } returns authTokenResult
val dataConnectGrpcMetadata =
Arb.dataConnect
.dataConnectGrpcMetadata(dataConnectAuth = Arb.constant(dataConnectAuth))
Expand All @@ -189,7 +191,7 @@ class DataConnectGrpcMetadataUnitTest {
metadata.asClue {
it.keys() shouldContain "x-firebase-auth-token"
val metadataKey = Metadata.Key.of("x-firebase-auth-token", Metadata.ASCII_STRING_MARSHALLER)
it.get(metadataKey) shouldBe accessToken
it.get(metadataKey) shouldBe authTokenResult.token
}
}

Expand All @@ -212,9 +214,9 @@ class DataConnectGrpcMetadataUnitTest {

@Test
fun `should include x-firebase-appcheck when the AppCheck token is not null`() = runTest {
val accessToken = Arb.dataConnect.accessToken().next()
val appCheckTokenResult = Arb.dataConnect.appCheckTokenResult().next()
val dataConnectAppCheck: DataConnectAppCheck = mockk {
coEvery { getToken(any()) } returns accessToken
coEvery { getToken(any()) } returns appCheckTokenResult
}
val dataConnectGrpcMetadata =
Arb.dataConnect
Expand All @@ -228,7 +230,7 @@ class DataConnectGrpcMetadataUnitTest {
metadata.asClue {
it.keys() shouldContain "x-firebase-appcheck"
val metadataKey = Metadata.Key.of("x-firebase-appcheck", Metadata.ASCII_STRING_MARSHALLER)
it.get(metadataKey) shouldBe accessToken
it.get(metadataKey) shouldBe appCheckTokenResult.token
}
}

Expand Down
Loading