Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assist last used: remember STT and record proactively (before connected) #3755

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ class AssistActivity : BaseActivity() {

if (savedInstanceState == null) {
viewModel.onCreate(
hasPermission = hasRecordingPermission(),
serverId = if (intent.hasExtra(EXTRA_SERVER)) {
intent.getIntExtra(EXTRA_SERVER, ServerManager.SERVER_ID_ACTIVE)
} else {
Expand Down Expand Up @@ -137,9 +138,7 @@ class AssistActivity : BaseActivity() {

override fun onResume() {
super.onResume()
viewModel.setPermissionInfo(
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
viewModel.setPermissionInfo(hasRecordingPermission()) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
}

override fun onPause() {
Expand All @@ -152,4 +151,7 @@ class AssistActivity : BaseActivity() {
this.intent = intent
viewModel.onNewIntent(intent)
}

private fun hasRecordingPermission() =
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
}
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,37 @@ class AssistViewModel @Inject constructor(
var inputMode by mutableStateOf<AssistInputMode?>(null)
private set

fun onCreate(serverId: Int?, pipelineId: String?, startListening: Boolean?) {
fun onCreate(hasPermission: Boolean, serverId: Int?, pipelineId: String?, startListening: Boolean?) {
viewModelScope.launch {
this@AssistViewModel.hasPermission = hasPermission
serverId?.let {
filteredServerId = serverId
selectedServerId = serverId
}
startListening?.let { recorderAutoStart = it }

val supported = checkSupport()
if (!serverManager.isRegistered()) {
inputMode = AssistInputMode.BLOCKED
_conversation.clear()
_conversation.add(
AssistMessage(app.getString(commonR.string.not_registered), isInput = false)
)
} else if (supported == null) { // Couldn't get config
return@launch
}

if (
pipelineId == PIPELINE_LAST_USED && recorderAutoStart &&
hasPermission && hasMicrophone &&
serverManager.getServer(selectedServerId) != null &&
serverManager.integrationRepository(selectedServerId).getLastUsedPipelineSttSupport()
) {
// Start microphone recording to prevent missing voice input while doing network checks
onMicrophoneInput(proactive = true)
}

val supported = checkSupport()
if (supported != true) stopRecording()
if (supported == null) { // Couldn't get config
inputMode = AssistInputMode.BLOCKED
_conversation.clear()
_conversation.add(
Expand All @@ -86,7 +101,7 @@ class AssistViewModel @Inject constructor(
} else {
setPipeline(
when {
pipelineId == PIPELINE_LAST_USED -> serverManager.integrationRepository(selectedServerId).getLastUsedPipeline()
pipelineId == PIPELINE_LAST_USED -> serverManager.integrationRepository(selectedServerId).getLastUsedPipelineId()
pipelineId == PIPELINE_PREFERRED -> null
pipelineId?.isNotBlank() == true -> pipelineId
else -> null
Expand Down Expand Up @@ -169,15 +184,15 @@ class AssistViewModel @Inject constructor(
id = it.id,
name = it.name
)
serverManager.integrationRepository(selectedServerId).setLastUsedPipeline(it.id)
serverManager.integrationRepository(selectedServerId).setLastUsedPipeline(it.id, it.sttEngine != null)

_conversation.clear()
_conversation.add(startMessage)
clearPipelineData()
if (hasMicrophone && it.sttEngine != null) {
if (recorderAutoStart && (hasPermission || requestSilently)) {
inputMode = AssistInputMode.VOICE_INACTIVE
onMicrophoneInput()
onMicrophoneInput(proactive = null)
} else { // already requested permission once and was denied
inputMode = AssistInputMode.TEXT
}
Expand Down Expand Up @@ -219,31 +234,37 @@ class AssistViewModel @Inject constructor(

fun onTextInput(input: String) = runAssistPipeline(input)

fun onMicrophoneInput() {
/**
* Start/stop microphone input for Assist, depending on the current state.
* @param proactive true if proactive, null if not important, false if not
*/
fun onMicrophoneInput(proactive: Boolean? = false) {
if (!hasPermission) {
requestPermission?.let { it() }
return
}

if (inputMode == AssistInputMode.VOICE_ACTIVE) {
if (inputMode == AssistInputMode.VOICE_ACTIVE && proactive == false) {
stopRecording()
return
}

val recording = try {
audioRecorder.startRecording()
recorderProactive || audioRecorder.startRecording()
} catch (e: Exception) {
Log.e(TAG, "Exception while starting recording", e)
false
}

if (recording) {
setupRecorderQueue()
if (!recorderProactive) setupRecorderQueue()
inputMode = AssistInputMode.VOICE_ACTIVE
runAssistPipeline(null)
if (proactive == true) _conversation.add(AssistMessage("…", isInput = true))
if (proactive != true) runAssistPipeline(null)
} else {
_conversation.add(AssistMessage(app.getString(commonR.string.assist_error), isInput = false, isError = true))
}
recorderProactive = recording && proactive == true
}

private fun runAssistPipeline(text: String?) {
Expand All @@ -269,6 +290,9 @@ class AssistViewModel @Inject constructor(
_conversation.add(haMessage)
message = haMessage
}
if (isError && inputMode == AssistInputMode.VOICE_ACTIVE) {
stopRecording()
}
}
}
}
Expand All @@ -280,15 +304,16 @@ class AssistViewModel @Inject constructor(

fun onPermissionResult(granted: Boolean) {
hasPermission = granted
val proactive = currentPipeline == null
if (granted) {
inputMode = AssistInputMode.VOICE_INACTIVE
onMicrophoneInput()
} else if (requestSilently) { // Don't notify the user if they haven't explicitly requested
onMicrophoneInput(proactive = proactive)
} else if (requestSilently && !proactive) { // Don't notify the user if they haven't explicitly requested
inputMode = AssistInputMode.TEXT
} else {
} else if (!requestSilently) {
_conversation.add(AssistMessage(app.getString(commonR.string.assist_permission), isInput = false))
}
requestSilently = false
if (!proactive) requestSilently = false
}

fun onPause() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ abstract class AssistViewModelBase(

protected var selectedServerId = ServerManager.SERVER_ID_ACTIVE

protected var recorderProactive = false
private var recorderJob: Job? = null
private var recorderQueue: MutableList<ByteArray>? = null
protected val hasMicrophone = app.packageManager.hasSystemFeature(PackageManager.FEATURE_MICROPHONE)
Expand Down Expand Up @@ -99,8 +100,11 @@ abstract class AssistViewModelBase(
}
AssistPipelineEventType.STT_START -> {
viewModelScope.launch {
recorderQueue?.forEach { item ->
sendVoiceData(item)
binaryHandlerId?.let { id ->
// Manually loop here to avoid the queue being reset too soon
recorderQueue?.forEach { data ->
serverManager.webSocketRepository(selectedServerId).sendVoiceData(id, data)
}
}
recorderQueue = null
}
Expand Down Expand Up @@ -156,7 +160,7 @@ abstract class AssistViewModelBase(
binaryHandlerId?.let {
viewModelScope.launch {
// Launch to prevent blocking the output flow if the network is slow
serverManager.webSocketRepository().sendVoiceData(it, data)
serverManager.webSocketRepository(selectedServerId).sendVoiceData(it, data)
}
}
}
Expand Down Expand Up @@ -186,8 +190,9 @@ abstract class AssistViewModelBase(
recorderQueue = null
}
if (getInput() == AssistInputMode.VOICE_ACTIVE) {
setInput(AssistInputMode.VOICE_INACTIVE)
setInput(if (recorderProactive) AssistInputMode.BLOCKED else AssistInputMode.VOICE_INACTIVE)
}
recorderProactive = false
}

protected fun stopPlayback() = audioUrlPlayer.stop()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,11 @@ interface IntegrationRepository {
conversationId: String? = null
): Flow<AssistPipelineEvent>?

suspend fun getLastUsedPipeline(): String?
suspend fun getLastUsedPipelineId(): String?

suspend fun setLastUsedPipeline(pipelineId: String)
suspend fun getLastUsedPipelineSttSupport(): Boolean

suspend fun setLastUsedPipeline(pipelineId: String, supportsStt: Boolean)
}

@AssistedFactory
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
private const val PREF_SESSION_EXPIRE = "session_expire"
private const val PREF_TRUSTED = "trusted"
private const val PREF_SEC_WARNING_NEXT = "sec_warning_last"
private const val PREF_LAST_USED_PIPELINE = "last_used_pipeline"
private const val PREF_LAST_USED_PIPELINE_ID = "last_used_pipeline"
private const val PREF_LAST_USED_PIPELINE_STT = "last_used_pipeline_stt"
private const val TAG = "IntegrationRepository"
private const val RATE_LIMIT_URL = BuildConfig.RATE_LIMIT_URL

Expand Down Expand Up @@ -166,7 +167,8 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
localStorage.remove("${serverId}_$PREF_SESSION_EXPIRE")
localStorage.remove("${serverId}_$PREF_TRUSTED")
localStorage.remove("${serverId}_$PREF_SEC_WARNING_NEXT")
localStorage.remove("${serverId}_$PREF_LAST_USED_PIPELINE")
localStorage.remove("${serverId}_$PREF_LAST_USED_PIPELINE_ID")
localStorage.remove("${serverId}_$PREF_LAST_USED_PIPELINE_STT")
// app version and push token is device-specific
}

Expand Down Expand Up @@ -552,11 +554,16 @@ class IntegrationRepositoryImpl @AssistedInject constructor(
}
}

override suspend fun getLastUsedPipeline(): String? =
localStorage.getString("${serverId}_$PREF_LAST_USED_PIPELINE")
override suspend fun getLastUsedPipelineId(): String? =
localStorage.getString("${serverId}_$PREF_LAST_USED_PIPELINE_ID")

override suspend fun setLastUsedPipeline(pipelineId: String) =
localStorage.putString("${serverId}_$PREF_LAST_USED_PIPELINE", pipelineId)
override suspend fun getLastUsedPipelineSttSupport(): Boolean =
localStorage.getBoolean("${serverId}_$PREF_LAST_USED_PIPELINE_STT")

override suspend fun setLastUsedPipeline(pipelineId: String, supportsStt: Boolean) {
localStorage.putString("${serverId}_$PREF_LAST_USED_PIPELINE_ID", pipelineId)
localStorage.putBoolean("${serverId}_$PREF_LAST_USED_PIPELINE_STT", supportsStt)
}

override suspend fun getEntities(): List<Entity<Any>>? {
val response = webSocketRepository.getStates()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class ConversationActivity : ComponentActivity() {
super.onCreate(savedInstanceState)

lifecycleScope.launch {
val launchIntent = conversationViewModel.onCreate()
val launchIntent = conversationViewModel.onCreate(hasRecordingPermission())
if (launchIntent) {
launchVoiceInputIntent()
}
Expand All @@ -64,9 +64,7 @@ class ConversationActivity : ComponentActivity() {

override fun onResume() {
super.onResume()
conversationViewModel.setPermissionInfo(
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED
) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
conversationViewModel.setPermissionInfo(hasRecordingPermission()) { requestPermission.launch(Manifest.permission.RECORD_AUDIO) }
}

override fun onPause() {
Expand All @@ -88,6 +86,9 @@ class ConversationActivity : ComponentActivity() {
}
}

private fun hasRecordingPermission() =
ContextCompat.checkSelfPermission(this, Manifest.permission.RECORD_AUDIO) == PackageManager.PERMISSION_GRANTED

private fun launchVoiceInputIntent() {
val searchIntent = Intent(RecognizerIntent.ACTION_RECOGNIZE_SPEECH).apply {
putExtra(
Expand Down