diff --git a/docs/readme.md b/docs/readme.md index fcf67dfe5..9a3d33620 100644 --- a/docs/readme.md +++ b/docs/readme.md @@ -114,12 +114,14 @@ Release Notes - Analytics: Added `SetDefaultEventParameters()` which allows developers to specify a list of parameters that will be set on every event logged. - Analytics: Added a new `LogEvent()` that take in a IEnumerable of - parameters. + parameters. + - Firebase AI: Added support for using + [Server Prompt Templates](https://firebase.google.com/docs/ai-logic/server-prompt-templates/get-started). ### 13.5.0 - Changes - Firebase AI: Add support for receiving Live API Transcripts. - - Storage: Add support for Firebase Storage emulator via `UseEmulator`. + - Storage: Add support for Firebase Storage emulator via `UseEmulator`. The `UseEmulator` method should be called before invoking any other methods on a new instance of Storage. Default port is 9199. diff --git a/firebaseai/src/FirebaseAI.cs b/firebaseai/src/FirebaseAI.cs index 1f43195af..259e28183 100644 --- a/firebaseai/src/FirebaseAI.cs +++ b/firebaseai/src/FirebaseAI.cs @@ -225,6 +225,28 @@ public ImagenModel GetImagenModel( return new ImagenModel(_firebaseApp, _backend, modelName, generationConfig, safetySettings, requestOptions); } + + /// + /// Initializes a `TemplateGenerativeModel` with the given parameters. + /// + /// Configuration parameters for sending requests to the backend. + /// The initialized `TemplateGenerativeModel` instance. + public TemplateGenerativeModel GetTemplateGenerativeModel( + RequestOptions? requestOptions = null) + { + return new TemplateGenerativeModel(_firebaseApp, _backend, requestOptions); + } + + /// + /// Initializes a `TemplateImagenModel` with the given parameters. + /// + /// Configuration parameters for sending requests to the backend. + /// The initialized `TemplateImagenModel` instance. + public TemplateImagenModel GetTemplateImagenModel( + RequestOptions? requestOptions = null) + { + return new TemplateImagenModel(_firebaseApp, _backend, requestOptions); + } } } diff --git a/firebaseai/src/Imagen/ImagenModel.cs b/firebaseai/src/Imagen/ImagenModel.cs index 54b48f872..3f69f9ed1 100644 --- a/firebaseai/src/Imagen/ImagenModel.cs +++ b/firebaseai/src/Imagen/ImagenModel.cs @@ -45,6 +45,10 @@ public class ImagenModel private readonly HttpClient _httpClient; + /// + /// Intended for internal use only. + /// Use `FirebaseAI.GetImagenModel` instead to ensure proper initialization and configuration of the `ImagenModel`. + /// internal ImagenModel(FirebaseApp firebaseApp, FirebaseAI.Backend backend, string modelName, @@ -157,4 +161,72 @@ private Dictionary MakeGenerateImagenRequestAsDictionary( } } + /// + /// Represents a remote Imagen model with the ability to generate images using server template prompts. + /// + public class TemplateImagenModel + { + private readonly FirebaseApp _firebaseApp; + private readonly FirebaseAI.Backend _backend; + + private readonly HttpClient _httpClient; + + /// + /// Intended for internal use only. + /// Use `FirebaseAI.GetTemplateImagenModel` instead to ensure proper initialization and configuration of the `TemplateImagenModel`. + /// + internal TemplateImagenModel(FirebaseApp firebaseApp, + FirebaseAI.Backend backend, RequestOptions? requestOptions = null) + { + _firebaseApp = firebaseApp; + _backend = backend; + + // Create a HttpClient using the timeout requested, or the default one. + _httpClient = new HttpClient() + { + Timeout = requestOptions?.Timeout ?? RequestOptions.DefaultTimeout + }; + } + + /// + /// Generates images using the Template Imagen model and returns them as inline data. + /// + /// The id of the server prompt template to use. + /// Any input parameters expected by the server prompt template. + /// An optional token to cancel the operation. + /// The generated content response from the model. + /// Thrown when an error occurs during content generation. + public async Task> GenerateImagesAsync( + string templateId, IDictionary inputs, CancellationToken cancellationToken = default) + { + HttpRequestMessage request = new(HttpMethod.Post, + HttpHelpers.GetTemplateURL(_firebaseApp, _backend, templateId) + ":templatePredict"); + + // Set the request headers + await HttpHelpers.SetRequestHeaders(request, _firebaseApp); + + // Set the content + Dictionary jsonDict = new() + { + ["inputs"] = inputs + }; + string bodyJson = Json.Serialize(jsonDict); + request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json"); + +#if FIREBASE_LOG_REST_CALLS + UnityEngine.Debug.Log("Request:\n" + bodyJson); +#endif + + var response = await _httpClient.SendAsync(request, cancellationToken); + await HttpHelpers.ValidateHttpResponse(response); + + string result = await response.Content.ReadAsStringAsync(); + +#if FIREBASE_LOG_REST_CALLS + UnityEngine.Debug.Log("Response:\n" + result); +#endif + + return ImagenGenerationResponse.FromJson(result); + } + } } diff --git a/firebaseai/src/Internal/HttpHelpers.cs b/firebaseai/src/Internal/HttpHelpers.cs index 4c34252dc..075213293 100644 --- a/firebaseai/src/Internal/HttpHelpers.cs +++ b/firebaseai/src/Internal/HttpHelpers.cs @@ -23,6 +23,8 @@ namespace Firebase.AI.Internal // Helper functions to help handling the Http calls. internal static class HttpHelpers { + internal static readonly string StreamPrefix = "data: "; + // Get the URL to use for the rest calls based on the backend. internal static string GetURL(FirebaseApp firebaseApp, FirebaseAI.Backend backend, string modelName) @@ -46,6 +48,25 @@ internal static string GetURL(FirebaseApp firebaseApp, } } + internal static string GetTemplateURL(FirebaseApp firebaseApp, + FirebaseAI.Backend backend, string templateId) + { + var projectUrl = "https://firebasevertexai.googleapis.com/v1beta" + + $"/projects/{firebaseApp.Options.ProjectId}"; + if (backend.Provider == FirebaseAI.Backend.InternalProvider.VertexAI) + { + return $"{projectUrl}/locations/{backend.Location}/templates/{templateId}"; + } + else if (backend.Provider == FirebaseAI.Backend.InternalProvider.GoogleAI) + { + return $"{projectUrl}/templates/{templateId}"; + } + else + { + throw new NotSupportedException($"Missing support for backend: {backend.Provider}"); + } + } + internal static async Task SetRequestHeaders(HttpRequestMessage request, FirebaseApp firebaseApp) { request.Headers.Add("x-goog-api-key", firebaseApp.Options.ApiKey); diff --git a/firebaseai/src/TemplateGenerativeModel.cs b/firebaseai/src/TemplateGenerativeModel.cs new file mode 100644 index 000000000..e9c7cbb4c --- /dev/null +++ b/firebaseai/src/TemplateGenerativeModel.cs @@ -0,0 +1,177 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System; +using System.Collections.Generic; +using System.Net.Http; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using Google.MiniJSON; +using Firebase.AI.Internal; +using System.Linq; +using System.Runtime.CompilerServices; +using System.IO; + +namespace Firebase.AI +{ + /// + /// A type that represents a remote multimodal model (like Gemini), with the ability to generate + /// content based on defined server prompt templates. + /// + public class TemplateGenerativeModel + { + private readonly FirebaseApp _firebaseApp; + private readonly FirebaseAI.Backend _backend; + + private readonly HttpClient _httpClient; + + /// + /// Intended for internal use only. + /// Use `FirebaseAI.GetTemplateGenerativeModel` instead to ensure proper + /// initialization and configuration of the `TemplateGenerativeModel`. + /// + internal TemplateGenerativeModel(FirebaseApp firebaseApp, + FirebaseAI.Backend backend, + RequestOptions? requestOptions = null) + { + _firebaseApp = firebaseApp; + _backend = backend; + + // Create a HttpClient using the timeout requested, or the default one. + _httpClient = new HttpClient() + { + Timeout = requestOptions?.Timeout ?? RequestOptions.DefaultTimeout + }; + } + + /// + /// Generates new content by calling into a server prompt template. + /// + /// The id of the server prompt template to use. + /// Any input parameters expected by the server prompt template. + /// An optional token to cancel the operation. + /// The generated content response from the model. + /// Thrown when an error occurs during content generation. + public Task GenerateContentAsync( + string templateId, IDictionary inputs, + CancellationToken cancellationToken = default) + { + return GenerateContentAsyncInternal(templateId, inputs, null, cancellationToken); + } + + /// + /// Generates new content as a stream by calling into a server prompt template. + /// + /// The id of the server prompt template to use. + /// Any input parameters expected by the server prompt template. + /// An optional token to cancel the operation. + /// A stream of generated content responses from the model. + /// Thrown when an error occurs during content generation. + public IAsyncEnumerable GenerateContentStreamAsync( + string templateId, IDictionary inputs, + CancellationToken cancellationToken = default) + { + return GenerateContentStreamAsyncInternal(templateId, inputs, null, cancellationToken); + } + + private string MakeGenerateContentRequest(IDictionary inputs, + IEnumerable chatHistory) + { + var jsonDict = new Dictionary() + { + ["inputs"] = inputs + }; + if (chatHistory != null) + { + jsonDict["history"] = chatHistory.Select(t => t.ToJson()).ToList(); + } + return Json.Serialize(jsonDict); + } + + private async Task GenerateContentAsyncInternal( + string templateId, IDictionary inputs, + IEnumerable chatHistory, + CancellationToken cancellationToken) + { + HttpRequestMessage request = new(HttpMethod.Post, + HttpHelpers.GetTemplateURL(_firebaseApp, _backend, templateId) + ":templateGenerateContent"); + + // Set the request headers + await HttpHelpers.SetRequestHeaders(request, _firebaseApp); + + // Set the content + string bodyJson = MakeGenerateContentRequest(inputs, chatHistory); + request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json"); + +#if FIREBASE_LOG_REST_CALLS + UnityEngine.Debug.Log("Request:\n" + bodyJson); +#endif + + var response = await _httpClient.SendAsync(request, cancellationToken); + await HttpHelpers.ValidateHttpResponse(response); + + string result = await response.Content.ReadAsStringAsync(); + +#if FIREBASE_LOG_REST_CALLS + UnityEngine.Debug.Log("Response:\n" + result); +#endif + + return GenerateContentResponse.FromJson(result, _backend.Provider); + } + + private async IAsyncEnumerable GenerateContentStreamAsyncInternal( + string templateId, IDictionary inputs, + IEnumerable chatHistory, + [EnumeratorCancellation] CancellationToken cancellationToken) + { + HttpRequestMessage request = new(HttpMethod.Post, + HttpHelpers.GetTemplateURL(_firebaseApp, _backend, templateId) + ":templateStreamGenerateContent?alt=sse"); + + // Set the request headers + await HttpHelpers.SetRequestHeaders(request, _firebaseApp); + + // Set the content + string bodyJson = MakeGenerateContentRequest(inputs, chatHistory); + request.Content = new StringContent(bodyJson, Encoding.UTF8, "application/json"); + +#if FIREBASE_LOG_REST_CALLS + UnityEngine.Debug.Log("Request:\n" + bodyJson); +#endif + + var response = await _httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead, cancellationToken); + await HttpHelpers.ValidateHttpResponse(response); + + // We are expecting a Stream as the response, so handle that. + using var stream = await response.Content.ReadAsStreamAsync(); + using var reader = new StreamReader(stream); + + string line; + while ((line = await reader.ReadLineAsync()) != null) + { + // Only pass along strings that begin with the expected prefix. + if (line.StartsWith(HttpHelpers.StreamPrefix)) + { +#if FIREBASE_LOG_REST_CALLS + UnityEngine.Debug.Log("Streaming Response:\n" + line); +#endif + + yield return GenerateContentResponse.FromJson(line[HttpHelpers.StreamPrefix.Length..], _backend.Provider); + } + } + } + } +} diff --git a/firebaseai/src/TemplateGenerativeModel.cs.meta b/firebaseai/src/TemplateGenerativeModel.cs.meta new file mode 100644 index 000000000..51762f35f --- /dev/null +++ b/firebaseai/src/TemplateGenerativeModel.cs.meta @@ -0,0 +1,11 @@ +fileFormatVersion: 2 +guid: 3a969b36e561242e3bda360d90ef7b68 +MonoImporter: + externalObjects: {} + serializedVersion: 2 + defaultReferences: [] + executionOrder: 0 + icon: {instanceID: 0} + userData: + assetBundleName: + assetBundleVariant: diff --git a/firebaseai/testapp/Assets/Firebase/Sample/FirebaseAI/UIHandlerAutomated.cs b/firebaseai/testapp/Assets/Firebase/Sample/FirebaseAI/UIHandlerAutomated.cs index ba09496a4..edcbdee6c 100644 --- a/firebaseai/testapp/Assets/Firebase/Sample/FirebaseAI/UIHandlerAutomated.cs +++ b/firebaseai/testapp/Assets/Firebase/Sample/FirebaseAI/UIHandlerAutomated.cs @@ -81,6 +81,9 @@ protected override void Start() TestIncludeThoughts, TestCodeExecution, TestUrlContext, + TestTemplateGenerateContent, + TestTemplateGenerateContentStream, + TestTemplateImagenGenerateImage, }; // Set of tests that only run the single time. Func[] singleTests = { @@ -201,12 +204,12 @@ private bool ValidProbability(float value) // The model name to use for the tests. private readonly string TestModelName = "gemini-2.0-flash"; - private FirebaseAI GetFirebaseAI(Backend backend) + private FirebaseAI GetFirebaseAI(Backend backend, string location = "us-central1") { return backend switch { Backend.GoogleAI => FirebaseAI.GetInstance(FirebaseAI.Backend.GoogleAI()), - Backend.VertexAI => FirebaseAI.GetInstance(FirebaseAI.Backend.VertexAI()), + Backend.VertexAI => FirebaseAI.GetInstance(FirebaseAI.Backend.VertexAI(location)), _ => throw new ArgumentOutOfRangeException(nameof(backend), backend, "Unhandled Backend type"), }; @@ -810,7 +813,7 @@ async Task TestGenerateImage(Backend backend) // Test generating an image via Imagen. async Task TestImagenGenerateImage(Backend backend) { - var model = GetFirebaseAI(backend).GetImagenModel("imagen-3.0-generate-002"); + var model = GetFirebaseAI(backend).GetImagenModel("imagen-4.0-generate-001"); var response = await model.GenerateImagesAsync( "Generate an image of a cartoon dog."); @@ -832,7 +835,7 @@ async Task TestImagenGenerateImage(Backend backend) async Task TestImagenGenerateImageOptions(Backend backend) { var model = GetFirebaseAI(backend).GetImagenModel( - modelName: "imagen-3.0-generate-002", + modelName: "imagen-4.0-generate-001", generationConfig: new ImagenGenerationConfig( // negativePrompt and addWatermark are not supported on this version of the model. numberOfImages: 2, @@ -965,6 +968,81 @@ async Task TestUrlContext(Backend backend) } } + async Task TestTemplateGenerateContent(Backend backend) + { + var model = GetFirebaseAI(backend, "global").GetTemplateGenerativeModel(); + + var inputs = new Dictionary() + { + ["customerName"] = "Jane" + }; + var response = await model.GenerateContentAsync("input-system-instructions", inputs); + + string result = response.Text; + Assert("Response text was missing", !string.IsNullOrWhiteSpace(result)); + } + + async Task TestTemplateGenerateContentStream(Backend backend) + { + var model = GetFirebaseAI(backend, "global").GetTemplateGenerativeModel(); + + var inputs = new Dictionary() + { + ["customerName"] = "Jane" + }; + var responseStream = model.GenerateContentStreamAsync("input-system-instructions", inputs); + + // We combine all the text, just in case the keyword got cut between two responses. + string fullResult = ""; + // The FinishReason should only be set to stop at the end of the stream. + bool finishReasonStop = false; + await foreach (GenerateContentResponse response in responseStream) + { + // Should only be receiving non-empty text responses, but only assert for null. + string text = response.Text; + Assert("Received null text from the stream.", text != null); + if (string.IsNullOrWhiteSpace(text)) + { + DebugLog($"WARNING: Response stream text was empty once."); + } + + Assert("Previous FinishReason was stop, but received more", !finishReasonStop); + if (response.Candidates.First().FinishReason == FinishReason.Stop) + { + finishReasonStop = true; + } + + fullResult += text; + } + + Assert("Response text was missing", !string.IsNullOrWhiteSpace(fullResult)); + Assert("Finished without seeing FinishReason.Stop", finishReasonStop); + } + + async Task TestTemplateImagenGenerateImage(Backend backend) + { + var model = GetFirebaseAI(backend).GetTemplateImagenModel(); + + var inputs = new Dictionary() + { + ["prompt"] = "flowers", + }; + var response = await model.GenerateImagesAsync( + "imagen-generation-basic", inputs); + + // We can't easily test if the image is correct, but can check other random data. + AssertEq("FilteredReason", response.FilteredReason, null); + AssertEq("Image Count", response.Images.Count, 1); + + AssertEq($"Image MimeType", response.Images[0].MimeType, "image/png"); + + var texture = response.Images[0].AsTexture2D(); + Assert($"Image as Texture2D", texture != null); + // By default the image should be Square 1x1, so check for that. + Assert($"Image Height > 0", texture.height > 0); + AssertEq($"Image Height = Width", texture.height, texture.width); + } + // Test providing a file from a GCS bucket (Firebase Storage) to the model. async Task TestReadFile() {