diff --git a/.changeset/rare-hats-know.md b/.changeset/rare-hats-know.md new file mode 100644 index 00000000000..68fafc7849d --- /dev/null +++ b/.changeset/rare-hats-know.md @@ -0,0 +1,5 @@ +--- +'@firebase/ai': patch +--- + +Fix logic for merging default `onDeviceParams` with user-provided `onDeviceParams`. diff --git a/packages/ai/src/methods/chrome-adapter-browser.test.ts b/packages/ai/src/methods/chrome-adapter-browser.test.ts index 5d5b2344ab6..e37a08bf1a9 100644 --- a/packages/ai/src/methods/chrome-adapter-browser.test.ts +++ b/packages/ai/src/methods/chrome-adapter-browser.test.ts @@ -78,6 +78,63 @@ describe('ChromeAdapter', () => { expectedInputs: [{ type: 'image' }] }); }); + it('sets image as expected input type by default even if other onDeviceParams params are set', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.AVAILABLE) + } as LanguageModel; + const availabilityStub = stub( + languageModelProvider, + 'availability' + ).resolves(Availability.AVAILABLE); + const adapter = new ChromeAdapterImpl( + languageModelProvider, + InferenceMode.PREFER_ON_DEVICE, + { + promptOptions: {} + } + ); + await adapter.isAvailable({ + contents: [ + { + role: 'user', + parts: [{ text: 'hi' }] + } + ] + }); + expect(availabilityStub).to.have.been.calledWith({ + expectedInputs: [{ type: 'image' }] + }); + }); + it('sets image as expected input type by default even if other createOptions params are set', async () => { + const languageModelProvider = { + availability: () => Promise.resolve(Availability.AVAILABLE) + } as LanguageModel; + const availabilityStub = stub( + languageModelProvider, + 'availability' + ).resolves(Availability.AVAILABLE); + const adapter = new ChromeAdapterImpl( + languageModelProvider, + InferenceMode.PREFER_ON_DEVICE, + { + createOptions: { + topK: 22 + } + } + ); + await adapter.isAvailable({ + contents: [ + { + role: 'user', + parts: [{ text: 'hi' }] + } + ] + }); + expect(availabilityStub).to.have.been.calledWith({ + topK: 22, + expectedInputs: [{ type: 'image' }] + }); + }); it('honors explicitly set expected inputs', async () => { const languageModelProvider = { availability: () => Promise.resolve(Availability.AVAILABLE) diff --git a/packages/ai/src/methods/chrome-adapter.ts b/packages/ai/src/methods/chrome-adapter.ts index a0ab509e335..839276814bb 100644 --- a/packages/ai/src/methods/chrome-adapter.ts +++ b/packages/ai/src/methods/chrome-adapter.ts @@ -31,11 +31,15 @@ import { ChromeAdapter } from '../types/chrome-adapter'; import { Availability, LanguageModel, + LanguageModelExpected, LanguageModelMessage, LanguageModelMessageContent, LanguageModelMessageRole } from '../types/language-model'; +// Defaults to support image inputs for convenience. +const defaultExpectedInputs: LanguageModelExpected[] = [{ type: 'image' }]; + /** * Defines an inference "backend" that uses Chrome's on-device model, * and encapsulates logic for detecting when on-device inference is @@ -47,16 +51,28 @@ export class ChromeAdapterImpl implements ChromeAdapter { private isDownloading = false; private downloadPromise: Promise | undefined; private oldSession: LanguageModel | undefined; + onDeviceParams: OnDeviceParams = { + createOptions: { + expectedInputs: defaultExpectedInputs + } + }; constructor( public languageModelProvider: LanguageModel, public mode: InferenceMode, - public onDeviceParams: OnDeviceParams = { - createOptions: { - // Defaults to support image inputs for convenience. - expectedInputs: [{ type: 'image' }] + onDeviceParams?: OnDeviceParams + ) { + if (onDeviceParams) { + this.onDeviceParams = onDeviceParams; + if (!this.onDeviceParams.createOptions) { + this.onDeviceParams.createOptions = { + expectedInputs: defaultExpectedInputs + }; + } else if (!this.onDeviceParams.createOptions.expectedInputs) { + this.onDeviceParams.createOptions.expectedInputs = + defaultExpectedInputs; } } - ) {} + } /** * Checks if a given request can be made on-device.