diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt index b82113bbf91..bbadb9199a1 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/ImagenModel.kt @@ -177,6 +177,7 @@ internal constructor( GenerateImageRequest.ImagenParameters( sampleCount = generationConfig?.numberOfImages ?: 1, includeRaiReason = true, + includeSafetyAttributes = true, addWatermark = generationConfig?.addWatermark, personGeneration = safetySettings?.personFilterLevel?.internalVal, negativePrompt = generationConfig?.negativePrompt, @@ -206,6 +207,7 @@ internal constructor( GenerateImageRequest.ImagenParameters( sampleCount = generationConfig?.numberOfImages ?: 1, includeRaiReason = true, + includeSafetyAttributes = true, addWatermark = generationConfig?.addWatermark, personGeneration = safetySettings?.personFilterLevel?.internalVal, negativePrompt = generationConfig?.negativePrompt, diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt index 3f8e0ae079d..43a6f648aad 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/common/Request.kt @@ -93,6 +93,7 @@ internal data class GenerateImageRequest( internal data class ImagenParameters( val sampleCount: Int, val includeRaiReason: Boolean, + val includeSafetyAttributes: Boolean, val storageUri: String?, val negativePrompt: String?, val aspectRatio: String?, diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt index 67f13cff199..f93854851e6 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenGenerationResponse.kt @@ -42,7 +42,7 @@ internal constructor(public val images: List, public val filteredReason: Stri internal fun toPublicInline() = ImagenGenerationResponse( images = predictions.filter { it.mimeType != null }.map { it.toPublicInline() }, - null, + predictions.firstNotNullOfOrNull { it.raiFilteredReason }, ) } @@ -52,10 +52,24 @@ internal constructor(public val images: List, public val filteredReason: Stri val gcsUri: String? = null, val mimeType: String? = null, val raiFilteredReason: String? = null, + val safetyAttributes: ImagenSafetyAttributes? = null, ) { internal fun toPublicInline() = ImagenInlineImage(Base64.decode(bytesBase64Encoded!!, Base64.NO_WRAP), mimeType!!) internal fun toPublicGCS() = ImagenGCSImage(gcsUri!!, mimeType!!) } + + @Serializable + internal data class ImagenSafetyAttributes( + val categories: List? = null, + val scores: List? = null + ) { + internal fun toPublic(): Map { + if (categories == null || scores == null) { + return emptyMap() + } + return categories.zip(scores).toMap() + } + } } diff --git a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt index fc6033d5390..a83b29ec135 100644 --- a/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt +++ b/firebase-ai/src/main/kotlin/com/google/firebase/ai/type/ImagenInlineImage.kt @@ -31,7 +31,10 @@ import kotlinx.serialization.Serializable */ @PublicPreviewAPI public class ImagenInlineImage -internal constructor(public val data: ByteArray, public val mimeType: String) { +internal constructor( + public val data: ByteArray, + public val mimeType: String, +) { /** * Returns the image as an Android OS native [Bitmap] so that it can be saved or sent to the UI. diff --git a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt index e34a1e1db0d..a57f2f14c96 100644 --- a/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt +++ b/firebase-ai/src/test/java/com/google/firebase/ai/VertexAIUnarySnapshotTests.kt @@ -591,6 +591,15 @@ internal class VertexAIUnarySnapshotTests { } } + @Test + fun `generateImages should contain safety data`() = + goldenVertexUnaryFile("unary-success-generate-images-safety_info.json") { + withTimeout(testTimeout) { + val response = imagenModel.generateImages("prompt") + // There is no public API, but if it parses then success + } + } + @Test fun `google search grounding metadata is parsed correctly`() = goldenVertexUnaryFile("unary-success-google-search-grounding.json") {