Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ internal constructor(
GenerateImageRequest.ImagenParameters(
sampleCount = generationConfig?.numberOfImages ?: 1,
includeRaiReason = true,
includeSafetyAttributes = true,
addWatermark = generationConfig?.addWatermark,
personGeneration = safetySettings?.personFilterLevel?.internalVal,
negativePrompt = generationConfig?.negativePrompt,
Expand Down Expand Up @@ -206,6 +207,7 @@ internal constructor(
GenerateImageRequest.ImagenParameters(
sampleCount = generationConfig?.numberOfImages ?: 1,
includeRaiReason = true,
includeSafetyAttributes = true,
addWatermark = generationConfig?.addWatermark,
personGeneration = safetySettings?.personFilterLevel?.internalVal,
negativePrompt = generationConfig?.negativePrompt,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ internal data class GenerateImageRequest(
internal data class ImagenParameters(
val sampleCount: Int,
val includeRaiReason: Boolean,
val includeSafetyAttributes: Boolean,
val storageUri: String?,
val negativePrompt: String?,
val aspectRatio: String?,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
internal fun toPublicInline() =
ImagenGenerationResponse(
images = predictions.filter { it.mimeType != null }.map { it.toPublicInline() },
null,
predictions.firstNotNullOfOrNull { it.raiFilteredReason },
)
}

Expand All @@ -52,10 +52,24 @@ internal constructor(public val images: List<T>, public val filteredReason: Stri
val gcsUri: String? = null,
val mimeType: String? = null,
val raiFilteredReason: String? = null,
val safetyAttributes: ImagenSafetyAttributes? = null,
) {
internal fun toPublicInline() =
ImagenInlineImage(Base64.decode(bytesBase64Encoded!!, Base64.NO_WRAP), mimeType!!)

internal fun toPublicGCS() = ImagenGCSImage(gcsUri!!, mimeType!!)
}

@Serializable
internal data class ImagenSafetyAttributes(
val categories: List<String>? = null,
val scores: List<Double>? = null
) {
internal fun toPublic(): Map<String, Double> {
if (categories == null || scores == null) {
return emptyMap()
}
return categories.zip(scores).toMap()
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,10 @@ import kotlinx.serialization.Serializable
*/
@PublicPreviewAPI
public class ImagenInlineImage
internal constructor(public val data: ByteArray, public val mimeType: String) {
internal constructor(
public val data: ByteArray,
public val mimeType: String,
) {

/**
* Returns the image as an Android OS native [Bitmap] so that it can be saved or sent to the UI.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -591,6 +591,15 @@ internal class VertexAIUnarySnapshotTests {
}
}

@Test
fun `generateImages should contain safety data`() =
goldenVertexUnaryFile("unary-success-generate-images-safety_info.json") {
withTimeout(testTimeout) {
val response = imagenModel.generateImages("prompt")
// There is no public API, but if it parses then success
}
}

@Test
fun `google search grounding metadata is parsed correctly`() =
goldenVertexUnaryFile("unary-success-google-search-grounding.json") {
Expand Down