@@ -17,12 +17,8 @@
package com.google.firebase.ai.common.util

import android.media.AudioRecord
import kotlin.time.Duration.Companion.milliseconds
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.callbackFlow
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.isActive
import kotlinx.coroutines.yield

/**
* The minimum buffer size for this instance.
@@ -162,7 +162,10 @@ internal class AudioHelper(
fun build(): AudioHelper {
val playbackTrack =
AudioTrack(
AudioAttributes.Builder().setUsage(AudioAttributes.USAGE_MEDIA).setContentType(AudioAttributes.CONTENT_TYPE_SPEECH).build(),
AudioAttributes.Builder()
.setUsage(AudioAttributes.USAGE_MEDIA)
.setContentType(AudioAttributes.CONTENT_TYPE_SPEECH)
.build(),
AudioFormat.Builder()
.setSampleRate(24000)
.setChannelMask(AudioFormat.CHANNEL_OUT_MONO)
@@ -17,12 +17,17 @@
package com.google.firebase.ai.type

import android.Manifest.permission.RECORD_AUDIO
import android.annotation.SuppressLint
import android.content.pm.PackageManager
import android.media.AudioFormat
import android.media.AudioTrack
import android.os.Process
import android.os.StrictMode
import android.os.StrictMode.ThreadPolicy
import android.util.Log
import androidx.annotation.RequiresPermission
import androidx.core.content.ContextCompat
import com.google.firebase.BuildConfig
import com.google.firebase.FirebaseApp
import com.google.firebase.ai.common.JSON
import com.google.firebase.ai.common.util.CancelledCoroutineScope
@@ -33,19 +38,23 @@ import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession
import io.ktor.websocket.Frame
import io.ktor.websocket.close
import io.ktor.websocket.readBytes
import kotlinx.coroutines.CoroutineName
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.concurrent.Executors
import java.util.concurrent.ThreadFactory
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicLong
import kotlin.coroutines.CoroutineContext
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.flow
import kotlinx.coroutines.flow.flowOn
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onEach
@@ -67,11 +76,21 @@ internal constructor(
private val firebaseApp: FirebaseApp,
) {
/**
* Coroutine scope that we batch data on for [startAudioConversation].
* Coroutine scope that we batch data on for network-related behavior.
*
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
*/
private var networkScope = CancelledCoroutineScope

/**
* Coroutine scope that we batch data on for audio recording and playback.
*
* Separate from [networkScope] so that switching between dispatchers cannot cause deadlocks or
* other issues.
*
* Makes it easy to stop all the work with [stopAudioConversation] by just cancelling the scope.
*/
private var scope = CancelledCoroutineScope
private var audioScope = CancelledCoroutineScope

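For illustration, the two-scope split can be sketched in isolation. This is a minimal sketch, not the actual LiveSession code; the dispatchers below are stand-ins:

import java.util.concurrent.Executors
import kotlinx.coroutines.CoroutineName
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.cancel
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

// Sketch: network and audio work run in independent scopes backed by
// different dispatchers, so blocking or cancelling one cannot stall the
// other. Stopping the conversation is just two cancel() calls.
fun main() = runBlocking {
    val audioDispatcher = Executors.newCachedThreadPool().asCoroutineDispatcher()
    val networkScope = CoroutineScope(Dispatchers.IO + CoroutineName("Network"))
    val audioScope = CoroutineScope(audioDispatcher + CoroutineName("Audio"))

    networkScope.launch { /* send and receive websocket frames */ }
    audioScope.launch { /* record and play PCM audio */ }

    networkScope.cancel()
    audioScope.cancel()
    audioDispatcher.close()
}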
/**
* Playback audio data sent from the model.
@@ -129,16 +148,17 @@
}

FirebaseAIException.catchAsync {
if (scope.isActive) {
if (networkScope.isActive || audioScope.isActive) {
Log.w(
TAG,
"startAudioConversation called after the recording has already started. " +
"Call stopAudioConversation to close the previous connection."
)
return@catchAsync
}
// TODO: maybe it should be THREAD_PRIORITY_AUDIO anyways for playback and recording (not network though)
scope = CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Scope"))
networkScope =
CoroutineScope(blockingDispatcher + childJob() + CoroutineName("LiveSession Network"))
audioScope = CoroutineScope(audioDispatcher + childJob() + CoroutineName("LiveSession Audio"))
audioHelper = AudioHelper.build()

recordUserAudio()
@@ -158,7 +178,8 @@
FirebaseAIException.catch {
if (!startedReceiving.getAndSet(false)) return@catch

scope.cancel()
networkScope.cancel()
audioScope.cancel()
playBackQueue.clear()

audioHelper?.release()
@@ -228,7 +249,8 @@
FirebaseAIException.catch {
if (!startedReceiving.getAndSet(false)) return@catch

scope.cancel()
networkScope.cancel()
audioScope.cancel()
playBackQueue.clear()

audioHelper?.release()
@@ -325,21 +347,22 @@
audioHelper
?.listenToRecording()
?.buffer(UNLIMITED)
?.flowOn(audioDispatcher)
?.accumulateUntil(MIN_BUFFER_SIZE)
?.onEach {
sendMediaStream(listOf(MediaData(it, "audio/pcm")))
delay(0)
}
?.catch { throw FirebaseAIException.from(it) }
?.launchIn(scope)
?.launchIn(networkScope)
}
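The accumulateUntil operator is defined elsewhere (in com.google.firebase.ai.common.util) and is not shown in this diff. A plausible sketch of such an operator, assuming it concatenates PCM chunks until a minimum byte count is reached:

import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flow

// Hypothetical sketch; the real implementation may differ.
fun Flow<ByteArray>.accumulateUntil(minSize: Int): Flow<ByteArray> = flow {
    var pending = ByteArray(0)
    collect { chunk ->
        pending += chunk
        if (pending.size >= minSize) {
            emit(pending)
            pending = ByteArray(0)
        }
    }
    // A trailing chunk smaller than minSize is dropped in this sketch.
}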

/**
* Processes responses from the model during an audio conversation.
*
* Audio messages are added to [playBackQueue].
*
* Launched asynchronously on [scope].
* Launched asynchronously on [networkScope].
*
* @param functionCallHandler A callback function that is invoked whenever the server receives a
* function call.
@@ -393,18 +416,18 @@
}
}
}
.launchIn(scope)
.launchIn(networkScope)
}

/**
* Listens for playback data from the model and plays the audio.
*
* Polls [playBackQueue] for data, and calls [AudioHelper.playAudio] when data is received.
*
* Launched asynchronously on [scope].
* Launched asynchronously on [audioScope].
*/
private fun listenForModelPlayback(enableInterruptions: Boolean = false) {
scope.launch {
audioScope.launch {
while (isActive) {
val playbackData = playBackQueue.poll()
if (playbackData == null) {
@@ -490,5 +513,38 @@
AudioFormat.CHANNEL_OUT_MONO,
AudioFormat.ENCODING_PCM_16BIT
)
@SuppressLint("ThreadPoolCreation")
val audioDispatcher =
Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
}
}

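/**
 * [ThreadFactory] for audio work.
 *
 * Threads created by this factory run at [Process.THREAD_PRIORITY_AUDIO] and install a
 * [StrictMode] policy that flags network access on the thread.
 */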
internal class AudioThreadFactory : ThreadFactory {
private val threadCount = AtomicLong()
private val policy: ThreadPolicy = audioPolicy()

override fun newThread(task: Runnable?): Thread? {
val thread =
DEFAULT.newThread {
Process.setThreadPriority(Process.THREAD_PRIORITY_AUDIO)
StrictMode.setThreadPolicy(policy)
task?.run()
}
thread.name = "Firebase Audio Thread #${threadCount.andIncrement}"
return thread
}

companion object {
val DEFAULT: ThreadFactory = Executors.defaultThreadFactory()

private fun audioPolicy(): ThreadPolicy {
val builder = ThreadPolicy.Builder().detectNetwork()

if (BuildConfig.DEBUG) {
builder.penaltyDeath()
}

return builder.penaltyLog().build()
}
}
}
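An illustrative usage sketch for the factory (not part of the diff):

import java.util.concurrent.Executors
import kotlinx.coroutines.asCoroutineDispatcher
import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking

// Threads from AudioThreadFactory run at THREAD_PRIORITY_AUDIO and enforce
// a StrictMode policy that logs (and, in debug builds, kills the process
// on) network access.
fun demo() = runBlocking {
    val audioDispatcher =
        Executors.newCachedThreadPool(AudioThreadFactory()).asCoroutineDispatcher()
    val job = launch(audioDispatcher) {
        // Pure audio work belongs here; a network call on this thread would
        // trip the factory's StrictMode policy.
    }
    job.join()
    audioDispatcher.close()
}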