diff --git a/firebase-ai/CHANGELOG.md b/firebase-ai/CHANGELOG.md index 0d4a2333f7d..4432555a470 100644 --- a/firebase-ai/CHANGELOG.md +++ b/firebase-ai/CHANGELOG.md @@ -1,6 +1,7 @@ # Unreleased - [changed] **Breaking Change**: Removed the `candidateCount` option from `LiveGenerationConfig` +- [changed] Added better error messages to `ServiceConnectionHandshakeFailedException` # 17.3.0 diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt index a696ddd5f73..d5afca6b960 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/LiveGenerativeModel.kt @@ -32,6 +32,7 @@ import com.google.firebase.ai.type.Tool import com.google.firebase.annotations.concurrent.Blocking import com.google.firebase.appcheck.interop.InteropAppCheckTokenProvider import com.google.firebase.auth.internal.InternalAuthProvider +import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession import io.ktor.websocket.Frame import io.ktor.websocket.close import io.ktor.websocket.readBytes @@ -114,8 +115,9 @@ internal constructor( ) .toInternal() val data: String = Json.encodeToString(clientMessage) + var webSession: DefaultClientWebSocketSession? = null try { - val webSession = controller.getWebSocketSession(location) + webSession = controller.getWebSocketSession(location) webSession.send(Frame.Text(data)) val receivedJsonStr = webSession.incoming.receive().readBytes().toString(Charsets.UTF_8) val receivedJson = JSON.parseToJsonElement(receivedJsonStr) @@ -131,7 +133,10 @@ internal constructor( throw ServiceConnectionHandshakeFailedException("Unable to connect to the server") } } catch (e: ClosedReceiveChannelException) { - throw ServiceConnectionHandshakeFailedException("Channel was closed by the server", e) + val reason = webSession?.closeReason?.await() + val message = + "Channel was closed by the server.${if(reason!=null) " Details: ${reason.message}" else "" }" + throw ServiceConnectionHandshakeFailedException(message, e) } } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt index 720b2c50a63..b199698aa7b 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/APIController.kt @@ -38,7 +38,7 @@ import io.ktor.client.engine.HttpClientEngine import io.ktor.client.engine.okhttp.OkHttp import io.ktor.client.plugins.HttpTimeout import io.ktor.client.plugins.contentnegotiation.ContentNegotiation -import io.ktor.client.plugins.websocket.ClientWebSocketSession +import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession import io.ktor.client.plugins.websocket.WebSockets import io.ktor.client.plugins.websocket.webSocketSession import io.ktor.client.request.HttpRequestBuilder @@ -174,7 +174,7 @@ internal constructor( "wss://firebasevertexai.googleapis.com/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent?key=$key" } - suspend fun getWebSocketSession(location: String): ClientWebSocketSession = + suspend fun getWebSocketSession(location: String): DefaultClientWebSocketSession = client.webSocketSession(getBidiEndpoint(location)) { applyCommonHeaders() } fun generateContentStream( diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt index a91d7e4aedf..ccdc3e7fe95 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/LiveSession.kt @@ -29,7 +29,7 @@ import com.google.firebase.ai.common.util.CancelledCoroutineScope import com.google.firebase.ai.common.util.accumulateUntil import com.google.firebase.ai.common.util.childJob import com.google.firebase.annotations.concurrent.Blocking -import io.ktor.client.plugins.websocket.ClientWebSocketSession +import io.ktor.client.plugins.websocket.DefaultClientWebSocketSession import io.ktor.websocket.Frame import io.ktor.websocket.close import io.ktor.websocket.readBytes @@ -59,7 +59,7 @@ import kotlinx.serialization.json.Json @OptIn(ExperimentalSerializationApi::class) public class LiveSession internal constructor( - private val session: ClientWebSocketSession, + private val session: DefaultClientWebSocketSession, @Blocking private val blockingDispatcher: CoroutineContext, private var audioHelper: AudioHelper? = null, private val firebaseApp: FirebaseApp,