From 6680a401051b23a940dcc255a679ddcbc2cbd27a Mon Sep 17 00:00:00 2001 From: Christina Holland Date: Fri, 3 May 2024 15:10:46 -0700 Subject: [PATCH 1/3] allow user to set location --- packages/vertexai/src/api.ts | 9 ++++++--- packages/vertexai/src/index.ts | 5 +++-- packages/vertexai/src/public-types.ts | 8 ++++++++ .../vertexai/src/requests/stream-reader.test.ts | 17 +++++++++-------- packages/vertexai/src/service.ts | 8 ++++---- 5 files changed, 30 insertions(+), 17 deletions(-) diff --git a/packages/vertexai/src/api.ts b/packages/vertexai/src/api.ts index 19d250ccb79..3702a5ff17f 100644 --- a/packages/vertexai/src/api.ts +++ b/packages/vertexai/src/api.ts @@ -20,7 +20,7 @@ import { Provider } from '@firebase/component'; import { getModularInstance } from '@firebase/util'; import { DEFAULT_LOCATION, VERTEX_TYPE } from './constants'; import { VertexAIService } from './service'; -import { VertexAI } from './public-types'; +import { VertexAI, VertexAIOptions } from './public-types'; import { ERROR_FACTORY, VertexError } from './errors'; import { ModelParams, RequestOptions } from './types'; import { GenerativeModel } from './models/generative-model'; @@ -42,13 +42,16 @@ declare module '@firebase/component' { * * @param app - The {@link @firebase/app#FirebaseApp} to use. */ -export function getVertexAI(app: FirebaseApp = getApp()): VertexAI { +export function getVertexAI( + app: FirebaseApp = getApp(), + options?: VertexAIOptions +): VertexAI { app = getModularInstance(app); // Dependencies const vertexProvider: Provider<'vertex'> = _getProvider(app, VERTEX_TYPE); return vertexProvider.getImmediate({ - identifier: DEFAULT_LOCATION + identifier: options?.location || DEFAULT_LOCATION }); } diff --git a/packages/vertexai/src/index.ts b/packages/vertexai/src/index.ts index 403d690a370..d110a0744ca 100644 --- a/packages/vertexai/src/index.ts +++ b/packages/vertexai/src/index.ts @@ -37,11 +37,12 @@ function registerVertex(): void { _registerComponent( new Component( VERTEX_TYPE, - container => { + (container, { instanceIdentifier: location }) => { // getImmediate for FirebaseApp will always succeed const app = container.getProvider('app').getImmediate(); const appCheckProvider = container.getProvider('app-check-internal'); - return new VertexAIService(app, appCheckProvider); + console.log(location); + return new VertexAIService(app, appCheckProvider, { location }); }, ComponentType.PUBLIC ).setMultipleInstances(true) diff --git a/packages/vertexai/src/public-types.ts b/packages/vertexai/src/public-types.ts index a662580da86..64eeccba371 100644 --- a/packages/vertexai/src/public-types.ts +++ b/packages/vertexai/src/public-types.ts @@ -30,3 +30,11 @@ export interface VertexAI { app: FirebaseApp; location: string; } + +/** + * Options when initializing the Firebase Vertex AI SDK. + * @public + */ +export interface VertexAIOptions { + location?: string; +} diff --git a/packages/vertexai/src/requests/stream-reader.test.ts b/packages/vertexai/src/requests/stream-reader.test.ts index e942bb9f2b6..ae6d9fd33e4 100644 --- a/packages/vertexai/src/requests/stream-reader.test.ts +++ b/packages/vertexai/src/requests/stream-reader.test.ts @@ -32,7 +32,8 @@ import { FinishReason, GenerateContentResponse, HarmCategory, - HarmProbability + HarmProbability, + SafetyRating } from '../types'; use(sinonChai); @@ -229,7 +230,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, probability: HarmProbability.LOW - } + } as SafetyRating ] } } @@ -256,7 +257,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_HARASSMENT, probability: HarmProbability.NEGLIGIBLE - } + } as SafetyRating ] } ], @@ -266,7 +267,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, probability: HarmProbability.LOW - } + } as SafetyRating ] } }, @@ -284,7 +285,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_HARASSMENT, probability: HarmProbability.NEGLIGIBLE - } + } as SafetyRating ], citationMetadata: { citations: [ @@ -304,7 +305,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, probability: HarmProbability.HIGH - } + } as SafetyRating ] } }, @@ -322,7 +323,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, probability: HarmProbability.MEDIUM - } + } as SafetyRating ], citationMetadata: { citations: [ @@ -348,7 +349,7 @@ describe('aggregateResponses', () => { { category: HarmCategory.HARM_CATEGORY_HATE_SPEECH, probability: HarmProbability.HIGH - } + } as SafetyRating ] } } diff --git a/packages/vertexai/src/service.ts b/packages/vertexai/src/service.ts index 6925deb2a8c..a061fc4ad65 100644 --- a/packages/vertexai/src/service.ts +++ b/packages/vertexai/src/service.ts @@ -16,7 +16,7 @@ */ import { FirebaseApp, _FirebaseService } from '@firebase/app'; -import { VertexAI } from './public-types'; +import { VertexAI, VertexAIOptions } from './public-types'; import { AppCheckInternalComponentName, FirebaseAppCheckInternal @@ -30,12 +30,12 @@ export class VertexAIService implements VertexAI, _FirebaseService { constructor( public app: FirebaseApp, - appCheckProvider?: Provider + appCheckProvider?: Provider, + public options?: VertexAIOptions ) { const appCheck = appCheckProvider?.getImmediate({ optional: true }); this.appCheck = appCheck || null; - // TODO: add in user-set location option when that feature is available - this.location = DEFAULT_LOCATION; + this.location = this.options?.location || DEFAULT_LOCATION; } _delete(): Promise { From fac2760843c11451b2ebc4529c82677b0b95675e Mon Sep 17 00:00:00 2001 From: Christina Holland Date: Fri, 3 May 2024 15:19:36 -0700 Subject: [PATCH 2/3] test --- packages/vertexai/src/service.test.ts | 43 +++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 packages/vertexai/src/service.test.ts diff --git a/packages/vertexai/src/service.test.ts b/packages/vertexai/src/service.test.ts new file mode 100644 index 00000000000..7abe6a4019d --- /dev/null +++ b/packages/vertexai/src/service.test.ts @@ -0,0 +1,43 @@ +/** + * @license + * Copyright 2024 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { DEFAULT_LOCATION } from './constants'; +import { VertexAIService } from './service'; +import { expect } from 'chai'; + +const fakeApp = { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project' + } +}; + +describe('VertexAIService', () => { + it('uses default location if not specified', () => { + const vertexAI = new VertexAIService(fakeApp); + expect(vertexAI.location).to.equal(DEFAULT_LOCATION); + }); + it('uses custom location if specified', () => { + const vertexAI = new VertexAIService( + fakeApp, + /* appCheckProvider */ undefined, + { location: 'somewhere' } + ); + expect(vertexAI.location).to.equal('somewhere'); + }); +}); From a49dc31404f444744cf97cfd98e0c9d862789bc9 Mon Sep 17 00:00:00 2001 From: Christina Holland Date: Mon, 6 May 2024 10:17:27 -0700 Subject: [PATCH 3/3] remove console.log --- packages/vertexai/src/index.ts | 1 - 1 file changed, 1 deletion(-) diff --git a/packages/vertexai/src/index.ts b/packages/vertexai/src/index.ts index d110a0744ca..9a0c717b9ee 100644 --- a/packages/vertexai/src/index.ts +++ b/packages/vertexai/src/index.ts @@ -41,7 +41,6 @@ function registerVertex(): void { // getImmediate for FirebaseApp will always succeed const app = container.getProvider('app').getImmediate(); const appCheckProvider = container.getProvider('app-check-internal'); - console.log(location); return new VertexAIService(app, appCheckProvider, { location }); }, ComponentType.PUBLIC