Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions firebase-ai/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Unreleased

- [fixed] Fixed an issue causing streaming chat interactions to drop thought signatures. (#7562)
- [feature] Added support for server templates via `TemplateGenerativeModel` and
`TemplateImagenModel`. (#7503)

Expand Down
24 changes: 15 additions & 9 deletions firebase-ai/src/main/kotlin/com/google/firebase/ai/Chat.kt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import com.google.firebase.ai.type.GenerateContentResponse
import com.google.firebase.ai.type.ImagePart
import com.google.firebase.ai.type.InlineDataPart
import com.google.firebase.ai.type.InvalidStateException
import com.google.firebase.ai.type.Part
import com.google.firebase.ai.type.TextPart
import com.google.firebase.ai.type.content
import java.util.LinkedList
Expand Down Expand Up @@ -133,6 +134,7 @@ public class Chat(
val bitmaps = LinkedList<Bitmap>()
val inlineDataParts = LinkedList<InlineDataPart>()
val text = StringBuilder()
val parts = mutableListOf<Part>()

/**
* TODO: revisit when images and inline data are returned. This will cause issues with how
Expand All @@ -147,22 +149,17 @@ public class Chat(
is ImagePart -> bitmaps.add(part.image)
is InlineDataPart -> inlineDataParts.add(part)
}
parts.add(part)
}
}
.onCompletion {
lock.release()
if (it == null) {
val content =
content("model") {
for (bitmap in bitmaps) {
image(bitmap)
}
for (inlineDataPart in inlineDataParts) {
inlineData(inlineDataPart.inlineData, inlineDataPart.mimeType)
}
if (text.isNotBlank()) {
text(text.toString())
}
setParts(
parts.filterNot { part -> part is TextPart && !part.hasContent() }.toMutableList()
)
}

history.add(prompt)
Expand Down Expand Up @@ -224,3 +221,12 @@ public class Chat(
}
}
}

/**
* Returns true if the [TextPart] contains any content, either in its [TextPart.text] property or
* its [TextPart.thoughtSignature] property.
*/
private fun TextPart.hasContent(): Boolean {
if (text.isNotEmpty()) return true
return !thoughtSignature.isNullOrBlank()
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,29 @@ package com.google.firebase.ai

import com.google.firebase.ai.type.BlockReason
import com.google.firebase.ai.type.FinishReason
import com.google.firebase.ai.type.FunctionCallPart
import com.google.firebase.ai.type.PromptBlockedException
import com.google.firebase.ai.type.ResponseStoppedException
import com.google.firebase.ai.type.ServerException
import com.google.firebase.ai.type.content
import com.google.firebase.ai.util.goldenDevAPIStreamingFile
import io.kotest.assertions.throwables.shouldThrow
import io.kotest.matchers.collections.shouldBeEmpty
import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.nulls.shouldNotBeNull
import io.kotest.matchers.shouldBe
import io.kotest.matchers.string.shouldStartWith
import io.ktor.client.engine.mock.toByteArray
import io.ktor.client.request.HttpRequestData
import io.ktor.http.HttpStatusCode
import kotlin.time.Duration.Companion.seconds
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.withTimeout
import kotlinx.serialization.json.Json
import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive
import org.junit.Test
import org.junit.runner.RunWith
import org.robolectric.RobolectricTestRunner
Expand Down Expand Up @@ -85,6 +96,86 @@ internal class DevAPIStreamingSnapshotTests {
}
}

@Test
fun `success call with thought summary and signature`() =
goldenDevAPIStreamingFile(
"streaming-success-thinking-function-call-thought-summary-signature.txt"
) {
val responses = model.generateContentStream("prompt")

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.isEmpty() shouldBe false
val functionCallResponse = responseList.find { it.functionCalls.isNotEmpty() }
functionCallResponse.shouldNotBeNull()
functionCallResponse.functionCalls.first().let {
it.thoughtSignature.shouldNotBeNull()
it.thoughtSignature.shouldStartWith("CiIBVKhc7vB")
}
}
}

@Test
fun `chat call with history including thought summary and signature`() {
var capturedRequest: HttpRequestData? = null
goldenDevAPIStreamingFile(
"streaming-success-thinking-function-call-thought-summary-signature.txt",
requestHandler = { capturedRequest = it }
) {
val chat = model.startChat()
val firstPrompt = content { text("first prompt") }
val secondPrompt = content { text("second prompt") }
val responses = chat.sendMessageStream(firstPrompt)

withTimeout(testTimeout) {
val responseList = responses.toList()
responseList.shouldNotBeEmpty()

chat.history.let { history ->
history.contains(firstPrompt)
val functionCallPart =
history.flatMap { it.parts }.first { it is FunctionCallPart } as FunctionCallPart
functionCallPart.let {
it.thoughtSignature.shouldNotBeNull()
it.thoughtSignature.shouldStartWith("CiIBVKhc7vB")
}
}

// Reset the request so we can be sure we capture the latest version
capturedRequest = null

// We don't care about the response, only the request
val unused = chat.sendMessageStream(secondPrompt).toList()

// Make sure the history contains all prompts seen so far
chat.history.contains(firstPrompt)
chat.history.contains(secondPrompt)

// Put the captured request into a `val` to enable smart casting
val request = capturedRequest
request.shouldNotBeNull()
val bodyAsString = request.body.toByteArray().decodeToString()
bodyAsString.shouldNotBeNull()

val rootElement = Json.parseToJsonElement(bodyAsString).jsonObject

// Traverse the tree: contents -> parts -> thoughtSignature
val contents = rootElement["contents"]?.jsonArray

val signature =
contents?.firstNotNullOfOrNull { content ->
content.jsonObject["parts"]?.jsonArray?.firstNotNullOfOrNull { part ->
// resulting value is a JsonPrimitive, so we use .content to get the string
part.jsonObject["thoughtSignature"]?.jsonPrimitive?.content
}
}

signature.shouldNotBeNull()
signature.shouldStartWith("CiIBVKhc7vB")
}
}
}

@Test
fun `prompt blocked for safety`() =
goldenDevAPIStreamingFile("streaming-failure-prompt-blocked-safety.txt") {
Expand Down
26 changes: 22 additions & 4 deletions firebase-ai/src/test/java/com/google/firebase/ai/util/tests.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ import io.kotest.matchers.collections.shouldNotBeEmpty
import io.kotest.matchers.nulls.shouldNotBeNull
import io.ktor.client.engine.mock.MockEngine
import io.ktor.client.engine.mock.respond
import io.ktor.client.request.HttpRequestData
import io.ktor.http.HttpHeaders
import io.ktor.http.HttpStatusCode
import io.ktor.http.headersOf
import io.ktor.utils.io.ByteChannel
import io.ktor.utils.io.close
import io.ktor.utils.io.writeFully
import java.io.File
import kotlinx.coroutines.launch
Expand Down Expand Up @@ -103,6 +103,7 @@ internal fun commonTest(
status: HttpStatusCode = HttpStatusCode.OK,
requestOptions: RequestOptions = RequestOptions(),
backend: GenerativeBackend = GenerativeBackend.vertexAI(),
requestHandler: (HttpRequestData) -> Unit = {},
block: CommonTest,
) = doBlocking {
val channel = ByteChannel(autoFlush = true)
Expand All @@ -115,6 +116,7 @@ internal fun commonTest(
"gemini-pro",
requestOptions,
MockEngine {
requestHandler(it)
respond(channel, status, headersOf(HttpHeaders.ContentType, "application/json"))
},
TEST_CLIENT_ID,
Expand Down Expand Up @@ -144,12 +146,13 @@ internal fun goldenStreamingFile(
name: String,
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
backend: GenerativeBackend = GenerativeBackend.vertexAI(),
requestHandler: (HttpRequestData) -> Unit,
block: CommonTest,
) = doBlocking {
val goldenFile = loadGoldenFile(name)
val messages = goldenFile.readLines().filter { it.isNotBlank() }

commonTest(httpStatusCode, backend = backend) {
commonTest(httpStatusCode, backend = backend, requestHandler = requestHandler) {
launch {
for (message in messages) {
channel.writeFully("$message$SSE_SEPARATOR".toByteArray())
Expand All @@ -175,8 +178,15 @@ internal fun goldenStreamingFile(
internal fun goldenVertexStreamingFile(
name: String,
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
requestHandler: (HttpRequestData) -> Unit = {},
block: CommonTest,
) = goldenStreamingFile("vertexai/$name", httpStatusCode, block = block)
) =
goldenStreamingFile(
"vertexai/$name",
httpStatusCode,
requestHandler = requestHandler,
block = block
)

/**
* A variant of [goldenStreamingFile] for testing the developer api
Expand All @@ -192,8 +202,16 @@ internal fun goldenVertexStreamingFile(
internal fun goldenDevAPIStreamingFile(
name: String,
httpStatusCode: HttpStatusCode = HttpStatusCode.OK,
requestHandler: (HttpRequestData) -> Unit = {},
block: CommonTest,
) = goldenStreamingFile("googleai/$name", httpStatusCode, GenerativeBackend.googleAI(), block)
) =
goldenStreamingFile(
"googleai/$name",
httpStatusCode,
GenerativeBackend.googleAI(),
requestHandler,
block
)

/**
* A variant of [commonTest] for performing snapshot tests.
Expand Down
Loading