diff --git a/.cursor/rules/ai/ai-package.md b/.cursor/rules/ai/ai-package.md new file mode 100644 index 0000000000..4358aa3361 --- /dev/null +++ b/.cursor/rules/ai/ai-package.md @@ -0,0 +1,869 @@ +# React Native Firebase AI Package - Porting Rules from Firebase JS SDK + +This document describes the **intentional differences** between the React Native Firebase AI package and the Firebase JS SDK AI package. Use this to distinguish between known architectural differences and actual missing features that need to be ported. + +--- + +## ๐ŸŽฏ Purpose + +When comparing this React Native Firebase AI package with the Firebase JS SDK AI package (from `firebase/firebase-js-sdk`), this file documents all **expected and correct** differences. Any difference NOT listed here is potentially a missing feature that should be evaluated for porting. + +This cursor rules file lives in `.cursor/rules/` and helps AI assistants understand the intentional architectural differences between the web and mobile implementations. + +### ๐Ÿ“ฆ Setup for Comparison +To effectively use this file for feature parity checks: +1. Clone both repositories locally: + - `git clone https://github.com/invertase/react-native-firebase.git` + - `git clone https://github.com/firebase/firebase-js-sdk.git` +2. When asking AI to compare, reference both package directories +3. Use these cursor rules to filter out known differences + +--- + +## ๐Ÿ“ Structural Differences (INTENTIONAL - Do Not Port) + +### 1. Source Code Location +- **Firebase JS SDK**: Uses `src/` folder +- **React Native Firebase**: Uses `lib/` folder +- **Reason**: React Native Firebase convention + +### 2. Entry Point Architecture +- **Firebase JS SDK**: + - `src/index.ts` - Component registration and exports + - `src/api.ts` - Public API exports (getAI, getGenerativeModel, etc.) +- **React Native Firebase**: + - `lib/index.ts` - Combined entry point with all exports +- **Reason**: RN doesn't use Firebase's component registration system + +### 3. Test File Organization +- **Firebase JS SDK**: Tests live alongside source in `src/` (e.g., `src/api.test.ts`, `src/models/ai-model.test.ts`) +- **React Native Firebase**: Tests in `__tests__/` at package root +- **Reason**: React Native Firebase testing convention + +**IMPORTANT**: Tests MUST be ported and kept in sync! +- File names match between packages (e.g., `src/api.test.ts` โ†’ `__tests__/api.test.ts`) +- RN uses different testing tools (Jest with React Native setup vs Karma/Mocha) +- Test logic and coverage should match, even if test utilities differ +- When porting features, ALWAYS port the corresponding tests + +### Testing Tools Differences + +**Firebase JS SDK uses:** +- Karma test runner +- Mocha/Chai for assertions +- Browser-based test environment +- Tests run alongside source code + +**React Native Firebase uses:** +- Jest test framework +- Jest assertions and matchers +- React Native test environment +- Tests in dedicated `__tests__/` directory + +**When porting tests:** +- Convert Mocha `describe/it` to Jest (mostly compatible) +- Replace Chai assertions with Jest matchers +- Remove browser-specific test setup +- Keep test file names identical for traceability +- Maintain same test coverage and logic +- **IMPORTANT**: Follow ESLint requirements to ensure tests pass linting (see below) + +### 4. ESLint Requirements for Tests + +**CRITICAL**: All test files MUST pass ESLint. The project uses Mocha ESLint plugin even though tests use Jest. + +**Required imports:** +```typescript +import { describe, expect, it, jest } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +``` + +**Key rules to follow:** + +1. **Import Jest globals explicitly** - Do NOT rely on global types: + ```typescript + // โœ… CORRECT + import { describe, expect, it, jest } from '@jest/globals'; + + // โŒ WRONG - Will cause "Cannot find name 'describe'" errors + // (no import) + ``` + +2. **Use regular functions, NOT arrow functions** (mocha/no-mocha-arrows): + ```typescript + // โœ… CORRECT + describe('MyTest', function () { + it('does something', function () { + // test code + }); + }); + + // โŒ WRONG - Violates mocha/no-mocha-arrows + describe('MyTest', () => { + it('does something', () => { + // test code + }); + }); + ``` + +3. **Type assertions for test objects**: + ```typescript + // โœ… CORRECT - Cast to proper RN types + const fakeAI: AI = { + app: { + name: 'DEFAULT', + options: { apiKey: 'key' } + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', + }; + + // For complex mocks needing double casting: + const mockService = { + ...fakeAI, + appCheck: { getToken: mockFn } + } as unknown as AIService; + + // โŒ WRONG - Using @ts-ignore + const fakeAI: AI = { + app: { name: 'DEFAULT' }, + // @ts-ignore + } as AI; + ``` + +4. **Mock function types** (for TypeScript inference): + ```typescript + // โœ… CORRECT - Explicit type annotation + const mockFn = jest + .fn<() => Promise<{ token: string }>>() + .mockResolvedValue({ token: 'value' }); + + // Format on multiple lines for readability (Prettier requirement) + ``` + +5. **Common imports needed**: + - `jest` - For jest.fn(), jest.spyOn(), etc. + - `afterEach` - For cleanup + - `beforeEach` - For setup + - `afterAll` / `beforeAll` - For suite setup/teardown + +**Example complete test file structure:** +```typescript +import { describe, expect, it, jest, afterEach } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { AI } from '../lib/public-types'; +import { VertexAIBackend } from '../lib/backend'; + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { apiKey: 'key', projectId: 'proj', appId: 'app' }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', +}; + +describe('MyFeature', function () { + afterEach(function () { + jest.clearAllMocks(); + }); + + it('does something', function () { + expect(true).toBe(true); + }); + + it('handles async', async function () { + const mockFn = jest + .fn<() => Promise>() + .mockResolvedValue('result'); + const result = await mockFn(); + expect(result).toBe('result'); + }); +}); +``` + +### 5. Integration Tests Location +- **Firebase JS SDK**: Has `integration/` folder with integration tests +- **React Native Firebase**: Has `e2e/` folder for end-to-end tests +- **Reason**: Different testing approaches for web vs mobile + +### 6. WebSocket Test Mocking Differences + +**Firebase JS SDK WebSocket Tests:** +```typescript +// Uses DOM types available in browser test environment +class MockWebSocket { + private listeners: Map> = new Map(); + + addEventListener(type: string, listener: EventListener): void { + // ... + } + + triggerMessage(data: unknown): void { + this.dispatchEvent(new MessageEvent('message', { data })); + } +} +``` + +**React Native Firebase WebSocket Tests:** +```typescript +// Avoids DOM types not available in React Native +class MockWebSocket { + private listeners: Map void>> = new Map(); + + addEventListener(type: string, listener: (event: any) => void): void { + // ... + } + + triggerMessage(data: unknown): void { + const event = new Event('message'); + (event as any).data = data; + this.dispatchEvent(event); + } +} +``` + +**Key Differences:** +- **EventListener type**: JS SDK uses `EventListener` (DOM type), RN uses `(event: any) => void` +- **MessageEvent**: JS SDK uses `MessageEvent` constructor, RN creates basic `Event` and attaches data property +- **Reason**: DOM types (`EventListener`, `MessageEvent`) are not available in React Native test environment + +**When Porting WebSocket Tests:** +- โœ… Replace `EventListener` type with `(event: any) => void` +- โœ… Replace `new MessageEvent()` with `new Event()` + manual data property +- โœ… Keep mock behavior identical, just adapt the types + +--- + +## ๐Ÿ”ง Firebase Component System (INTENTIONAL - Do Not Port) + +### Component Registration +**Firebase JS SDK has:** +```typescript +// src/index.ts +_registerComponent( + new Component(AI_TYPE, factory, ComponentType.PUBLIC).setMultipleInstances(true) +); +registerVersion(name, version); +``` + +**React Native Firebase does NOT have:** +- Component registration system +- `factory-browser.ts` or `factory-node.ts` files +- `_registerComponent` calls +- Multiple instance management via components + +**Reason**: React Native Firebase uses a different initialization pattern without the component system. The `getAI()` function directly returns an AI instance instead of using providers. + +### AIService Differences +**Firebase JS SDK AIService:** +- Implements `_FirebaseService` interface +- Has `_delete()` method +- Has `chromeAdapterFactory` constructor parameter +- Has options getter/setter methods +- Uses `Provider` and `Provider` + +**React Native Firebase AIService:** +- Simpler class without `_FirebaseService` interface +- No `_delete()` method +- No `chromeAdapterFactory` +- No options property management +- Direct `FirebaseAuthTypes.Module` and `FirebaseAppCheckTypes.Module` types + +**Reason**: Different dependency injection and lifecycle management patterns. + +--- + +## ๐ŸŒ Browser-Specific Features (INTENTIONAL - Do Not Port) + +### Chrome On-Device AI / Hybrid Mode +**Firebase JS SDK has:** +- `src/methods/chrome-adapter.ts` - Chrome's on-device AI integration +- `src/types/chrome-adapter.ts` - ChromeAdapter interface +- `src/types/language-model.ts` - Chrome Prompt API types +- `HybridParams` type with `mode`, `onDeviceParams`, `inCloudParams` +- `InferenceMode` enum with on-device/in-cloud options +- `getGenerativeModel()` accepts `HybridParams | ModelParams` + +**React Native Firebase does NOT have:** +- Any chrome-adapter related files +- HybridParams type +- On-device AI functionality +- `getGenerativeModel()` only accepts `ModelParams` + +**Reason**: Chrome's on-device AI is browser-specific and not available in React Native. This is a web-only feature. + +### Browser-Specific APIs Not Used +- No `window` object references (except in comments/docs) +- No `document` object usage (except in JSDoc examples) +- No Service Workers +- No localStorage/sessionStorage +- No Web Components +- No DOM manipulation + +**Reason**: React Native doesn't have these browser APIs. + +--- + +## ๐Ÿ”Œ Polyfills (UNIQUE TO REACT NATIVE - Do Not Remove) + +### React Native Requires Polyfills +**React Native Firebase has:** +```typescript +// lib/polyfills.ts +import { polyfillGlobal } from 'react-native/Libraries/Utilities/PolyfillFunctions'; +import { ReadableStream } from 'web-streams-polyfill/dist/ponyfill'; +import { fetch, Headers, Request, Response } from 'react-native-fetch-api'; +import 'text-encoding'; // TextEncoder/TextDecoder + +polyfillGlobal('fetch', ...) // with reactNative: { textStreaming: true } +polyfillGlobal('ReadableStream', ...) +``` + +**Plus:** +- `lib/types/polyfills.d.ts` - Type declarations for polyfills +- Custom `RequestInit` interface extension for `reactNative.textStreaming` + +**Firebase JS SDK does NOT need:** +- Any polyfills (native browser/Node support) +- Special fetch configuration + +**Reason**: React Native doesn't have native fetch streaming, ReadableStream, or TextEncoder. These are critical for AI streaming responses. + +**IMPORTANT**: When porting new features, ensure they use these polyfilled APIs, not browser-native ones. + +--- + +## ๐Ÿ”Œ WebSocket Implementation Differences (INTENTIONAL - Adapt When Porting) + +### React Native WebSocket Limitations + +**Firebase JS SDK WebSocket Handler:** +```typescript +// src/websocket.ts +connect(url: string): Promise { + this.ws = new WebSocket(url); + this.ws.binaryType = 'blob'; // Set binary type to blob + // ... +} + +// Message handler expects Blob +const messageListener = (event: MessageEvent): void => { + const data = await event.data.text(); // Assumes Blob + // ... +}; +``` + +**React Native Firebase WebSocket Handler:** +```typescript +// lib/websocket.ts +connect(url: string): Promise { + this.ws = new WebSocket(url); + // Note: binaryType is not supported in React Native's WebSocket implementation. + // We handle ArrayBuffer, Blob, and string data types in the message listener instead. + // ... +} + +// Message handler detects data type dynamically +const messageListener = async (event: any): Promise => { + let data: string; + + if (event.data instanceof Blob) { + // Browser environment + data = await event.data.text(); + } else if (event.data instanceof ArrayBuffer) { + // React Native environment - binary data comes as ArrayBuffer + const decoder = new TextDecoder('utf-8'); + data = decoder.decode(event.data); + } else if (typeof event.data === 'string') { + // String data in all environments + data = event.data; + } + // ... +}; +``` + +**Key Differences:** +1. **No `binaryType` property**: React Native's WebSocket doesn't support setting `binaryType = 'blob'` +2. **ArrayBuffer in RN**: Binary data arrives as `ArrayBuffer` in React Native, not `Blob` +3. **Runtime type detection**: Must check `event.data` type at runtime instead of configuring upfront +4. **TextDecoder usage**: Need to manually decode ArrayBuffer to string using TextDecoder + +### WebSocket URL Construction + +**Firebase JS SDK:** +```typescript +// Uses standard URL class +const url = new URL(`wss://${domain}/path`); +url.searchParams.set('key', apiKey); +return url.toString(); +``` + +**React Native Firebase:** +```typescript +// lib/requests/request.ts - WebSocketUrl class +toString(): string { + // Manually construct URL to avoid React Native URL API issues + const baseUrl = `wss://${DEFAULT_DOMAIN}`; + const pathname = this.pathname; + const queryString = `key=${encodeURIComponent(this.apiSettings.apiKey)}`; + + return `${baseUrl}${pathname}?${queryString}`; +} +``` + +**Reason**: +- React Native has URL API quirks/limitations, so we manually construct WebSocket URLs +- Manual string concatenation is more reliable than URL class in RN environment + +### When Porting WebSocket Features + +**DO:** +- โœ… Remove or comment out `binaryType` assignments +- โœ… Add runtime type checking for `event.data` (Blob | ArrayBuffer | string) +- โœ… Use TextDecoder for ArrayBuffer conversion +- โœ… Manually construct WebSocket URLs with string concatenation +- โœ… Test on both iOS and Android (they may behave slightly differently) + +**DON'T:** +- โŒ Assume `binaryType` can be set +- โŒ Assume binary data will be Blob +- โŒ Use URL class for WebSocket URL construction +- โŒ Remove ArrayBuffer handling code + +--- + +## ๐Ÿ“ฆ Package Dependencies + +### Firebase JS SDK Dependencies +```json +{ + "@firebase/app": "0.x", + "@firebase/component": "0.7.0", + "@firebase/logger": "0.5.0", + "@firebase/util": "1.13.0", + "@firebase/app-check-interop-types": "0.3.3", + "@firebase/auth-interop-types": "...", + "tslib": "^2.1.0" +} +``` + +### React Native Firebase Dependencies +```json +{ + "@react-native-firebase/app": "23.5.0", + "react-native-fetch-api": "^3.0.0", + "web-streams-polyfill": "^4.2.0", + "text-encoding": "^0.7.0" +} +``` + +**Key Differences:** +- RN uses `@react-native-firebase/app` instead of `@firebase/app` +- RN doesn't use `@firebase/component`, `@firebase/util`, or interop types +- RN has polyfill dependencies +- RN uses `FirebaseAuthTypes` and `FirebaseAppCheckTypes` from their respective RN packages + +--- + +## ๐Ÿ”„ Import Pattern Differences + +### Firebase JS SDK Imports +```typescript +import { FirebaseApp, getApp, _getProvider } from '@firebase/app'; +import { Provider } from '@firebase/component'; +import { getModularInstance } from '@firebase/util'; +import { FirebaseAuthInternal } from '@firebase/auth-interop-types'; +import { FirebaseAppCheckInternal } from '@firebase/app-check-interop-types'; +``` + +### React Native Firebase Imports +```typescript +import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; +import { FirebaseAuthTypes } from '@react-native-firebase/auth'; +import { FirebaseAppCheckTypes } from '@react-native-firebase/app-check'; +``` + +**When porting code:** +- Replace `@firebase/*` imports with `@react-native-firebase/*` equivalents +- Replace `FirebaseApp` type with `ReactNativeFirebase.FirebaseApp` +- Remove `Provider`, `Component`, `getModularInstance`, interop types +- Use direct RN module types instead of internal/interop types + +--- + +## ๐Ÿ—๏ธ Helper Files + +### Firebase JS SDK Has +- `src/helpers.ts` - `encodeInstanceIdentifier()`, `decodeInstanceIdentifier()` for component system +- These are used with the component provider system + +### React Native Firebase +- Does NOT have `helpers.ts` file +- Doesn't need instance identifier encoding/decoding + +**Reason**: No component provider system in RN. + +--- + +## ๐Ÿ“‹ Public Types & Interface Differences + +### AIOptions Interface +**Firebase JS SDK:** +```typescript +export interface AIOptions { + backend?: Backend; + useLimitedUseAppCheckTokens?: boolean; +} +``` + +**React Native Firebase:** +```typescript +export interface AIOptions { + backend?: Backend; + useLimitedUseAppCheckTokens?: boolean; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} +``` + +**Difference**: RN version includes `appCheck` and `auth` directly in options because it doesn't use providers. + +### AI Interface +**Firebase JS SDK:** +```typescript +export interface AI { + app: FirebaseApp; + backend: Backend; + options?: AIOptions; + location: string; // deprecated +} +``` + +**React Native Firebase:** +```typescript +export interface AI { + app: ReactNativeFirebase.FirebaseApp; + backend: Backend; + options?: Omit; + location: string; + appCheck?: FirebaseAppCheckTypes.Module | null; + auth?: FirebaseAuthTypes.Module | null; +} +``` + +**Differences**: +- Different app type +- RN includes `appCheck` and `auth` directly +- Options type excludes `backend` in RN + +--- + +## ๐Ÿšซ Web-Only Entry Points (Do Not Port) + +### Node.js Specific Files +**Firebase JS SDK has:** +- `src/index.node.ts` - Node.js specific entry point +- Separate build targets for Node vs Browser + +**React Native Firebase:** +- Single entry point (`lib/index.ts`) +- No platform-specific entry points + +**Reason**: React Native is a unified mobile runtime, not split between browser/Node like web SDK. + +--- + +## ๐ŸŽจ Build & Configuration Differences + +### Firebase JS SDK +- Uses Rollup for building +- Has `rollup.config.js` +- Has `api-extractor.json` for API documentation +- Multiple export targets (browser, node, esm, cjs) + +### React Native Firebase +- Uses `react-native-builder-bob` for building +- Configured in `package.json` +- Outputs to `dist/commonjs/` and `dist/module/` +- TypeScript definitions in `dist/typescript/` + +**Reason**: Different build tools and conventions for web vs React Native packages. + +--- + +## โœ… When Porting New Features - Checklist + +Use this checklist when identifying new APIs in Firebase JS SDK to port: + +### 1. **Identify the Feature** + - What files are involved in JS SDK? + - What public APIs are exposed? + - What types are exported? + +### 2. **Check for Browser-Specific Code** + - โŒ Skip if uses Chrome on-device AI (Hybrid mode, ChromeAdapter) + - โŒ Skip if uses `window`, `document`, Service Workers + - โŒ Skip if uses Web-specific APIs unavailable in RN + - โœ… Proceed if uses only polyfilled APIs (fetch, ReadableStream, TextEncoder) + +### 3. **Verify Dependencies** + - Do all dependencies have RN equivalents? + - Can it work with the polyfilled environment? + - Does it need any native modules? + +### 4. **Map the Files** + - `src/` โ†’ `lib/` + - `src/*.test.ts` โ†’ `__tests__/*.test.ts` + - Update imports: `@firebase/*` โ†’ `@react-native-firebase/*` + +### 5. **Adapt the Implementation** + - Remove component provider logic + - Replace Firebase app types with RN types + - Ensure fetch uses `reactNative: { textStreaming: true }` + - Remove `_FirebaseService` interface usage + - Add to `lib/index.ts` exports (not `api.ts`) + +### 6. **Update Types** + - Replace `FirebaseApp` with `ReactNativeFirebase.FirebaseApp` + - Use direct module types (not interop types) + - Add any new types to appropriate files in `lib/types/` + +### 7. **Verify Polyfills** + - Ensure it works with polyfilled fetch + - Ensure it works with polyfilled ReadableStream + - Test streaming functionality + +### 8. **Adapt WebSocket Code** (If applicable) + - Remove `binaryType` property assignments (not supported in RN) + - Add runtime type detection for message data (Blob | ArrayBuffer | string) + - Use TextDecoder for ArrayBuffer to string conversion + - Manually construct WebSocket URLs (avoid URL class) + - Test on both iOS and Android + +### 9. **Testing** (CRITICAL - Always Port Tests) + - Port tests from `src/*.test.ts` to `__tests__/*.test.ts` + - Maintain same file names for easy identification + - Adapt test utilities for React Native testing environment: + - Replace Karma/Mocha patterns with Jest + - Replace browser-specific mocks + - Use React Native Firebase test utilities + - **Follow ESLint requirements** (see ยง 4. ESLint Requirements for Tests): + - Import Jest globals: `import { describe, expect, it, jest } from '@jest/globals'` + - Use regular functions, NOT arrow functions for `describe()`/`it()` + - Cast test objects: `as ReactNativeFirebase.FirebaseApp` + - Type mock functions: `jest.fn<() => Promise>()` + - **Adapt WebSocket test mocks** (see ยง 6. WebSocket Test Mocking Differences): + - Replace `EventListener` type with `(event: any) => void` + - Replace `MessageEvent` with `Event` + manual data property assignment + - Ensure test coverage matches JS SDK + - **Verify tests pass linting**: Check with linter before committing + - Add e2e tests if needed in `e2e/` + - **Never skip porting tests** - they're critical for maintaining quality + +--- + +## ๐Ÿ“ Example: How to Compare Packages + +### Step 1: Compare Exports +```bash +# Check what's exported in JS SDK api.ts +grep "^export" /path/to/firebase-js-sdk/packages/ai/src/api.ts + +# Check what's exported in RN index.ts +grep "^export" /path/to/react-native-firebase/packages/ai/lib/index.ts +``` + +### Step 2: Compare Model Classes +```bash +# List model classes in JS SDK +ls /path/to/firebase-js-sdk/packages/ai/src/models/ + +# List model classes in RN Firebase +ls /path/to/react-native-firebase/packages/ai/lib/models/ +``` + +### Step 3: Filter Known Differences +- **Ignore** chrome-adapter related differences +- **Ignore** factory files, component registration +- **Ignore** hybrid/on-device features +- **Evaluate** new model classes, methods, types + +### Step 4: Identify Real Gaps +Any difference NOT documented in these cursor rules is a potential feature gap. + +--- + +## ๐ŸŽฏ Summary of Known Differences + +| Feature | Firebase JS SDK | React Native Firebase | Reason | +|---------|----------------|----------------------|--------| +| **Source Folder** | `src/` | `lib/` | Convention | +| **Entry Point** | `index.ts` + `api.ts` | `index.ts` only | No component system | +| **Tests Location** | `src/` | `__tests__/` | Convention (tests still ported!) | +| **Component Registration** | โœ… Has | โŒ Doesn't have | Different initialization | +| **Factory Files** | โœ… Has | โŒ Doesn't have | No component system | +| **Chrome Adapter** | โœ… Has | โŒ Doesn't have | Browser-only feature | +| **Hybrid Mode** | โœ… Has | โŒ Doesn't have | Browser-only feature | +| **HybridParams** | โœ… Has | โŒ Doesn't have | Browser-only feature | +| **Polyfills** | โŒ Doesn't need | โœ… Requires | RN environment | +| **Dependencies** | `@firebase/*` | `@react-native-firebase/*` | Different ecosystem | +| **AIService** | Complex (with _FirebaseService) | Simple | Different architecture | +| **Type Files** | chrome-adapter.ts, language-model.ts | polyfills.d.ts | Platform-specific | +| **Node Entry** | โœ… Has index.node.ts | โŒ Doesn't have | Unified runtime | +| **WebSocket binaryType** | โœ… Sets to 'blob' | โŒ Not supported | RN WebSocket limitation | +| **WebSocket Data** | Expects Blob | Runtime detection (Blob/ArrayBuffer/string) | RN sends ArrayBuffer | +| **WebSocket URL** | Uses URL class | Manual string construction | RN URL API issues | +| **WebSocket Test Mocks** | Uses EventListener, MessageEvent | Uses function types, Event + data | DOM types unavailable in RN | + +--- + +## ๐Ÿ”„ Version Tracking + +- **Firebase JS SDK Version**: 2.6.0 (as of this comparison) +- **React Native Firebase Version**: 23.5.0 +- **Last Comparison Date**: 2025-11-19 + +When updating, check Firebase JS SDK changelog for new features and re-evaluate what needs porting. + +--- + +## ๐Ÿ’ก Tips for AI-Assisted Porting + +When asking AI to compare packages: + +**Good prompt:** +> "Compare firebase-js-sdk/packages/ai with react-native-firebase/packages/ai. Ignore differences documented in the cursor rules. What new features in JS SDK need to be ported?" + +**Bad prompt:** +> "What are all the differences between these packages?" +> (This will list all the intentional differences too) + +**Focus areas for comparison:** +- New model classes (e.g., LiveGenerativeModel, TemplateGenerativeModel) +- New methods on existing models +- New types/interfaces for public APIs +- New request/response types +- New exported functions from api.ts + +**Ignore for comparison:** +- Component registration code +- Factory files +- Chrome adapter/hybrid features +- Helper files for providers +- Build configuration differences +- Test file locations (but NOT test content - tests must be ported!) +- WebSocket `binaryType` differences (RN uses runtime detection instead) +- WebSocket URL construction methods (RN uses manual string building) +- WebSocket test mock types (RN uses function types, not DOM EventListener/MessageEvent) + +--- + +## ๐Ÿ”„ Incremental Porting Workflow (FOR AI ASSISTANTS) + +When asked to port features from Firebase JS SDK to React Native Firebase AI, follow this **incremental, commit-per-feature** approach: + +### Core Workflow Rules + +1. **ONE feature at a time** - Never implement multiple features in one session +2. **ALWAYS show the plan first** - Present analysis before any implementation +3. **WAIT for approval** - Don't proceed without explicit permission +4. **PAUSE after implementation** - Show changes, then wait for USER to write commit message and commit +5. **RESPECT known differences** - Skip anything documented in this rules file + +### Feature Identification Priority + +When comparing packages, identify missing features in this order: +1. **High Priority**: Core API functions (getLiveGenerativeModel, etc.) +2. **Medium Priority**: Model classes (LiveGenerativeModel, TemplateGenerativeModel) +3. **Low Priority**: Helper methods, utilities, optimizations + +### Implementation Steps (Per Feature) + +#### Step 1: Analysis & Proposal +``` +Present: +- Feature name and description +- Files involved in JS SDK +- Required adaptations for RN +- Browser-specific checks (skip if found) +- Estimated complexity +- Ask: "Ready to implement?" +``` + +#### Step 2: Implementation (After Approval Only) +``` +Execute: +- Create files in correct locations (src/ โ†’ lib/) +- Adapt imports (@firebase/* โ†’ @react-native-firebase/*) +- Remove browser-specific code +- Use polyfilled APIs correctly +- Update lib/index.ts exports +- Port tests to __tests__/ (MUST DO - keep same filenames) +- Adapt tests for React Native/Jest environment +``` + +#### Step 3: Review (USER Commits) +``` +Show: +- Summary of all changes +- List of files created/modified +- Any notable adaptations made + +Then: WAIT for user to write commit message and commit +Do NOT suggest commit messages +Do NOT proceed to next feature until user says they've committed +``` + +### Example Interaction Pattern + +``` +AI: "Found 3 missing features. Start with LiveGenerativeModel (High Priority)?" +User: "yes" + +AI: "Analysis: [shows scope] Ready to implement?" +User: "yes" + +AI: [implements] "โœ… Complete. Changes: +- Created lib/models/live-generative-model.ts +- Created lib/methods/live-session.ts +- Created lib/websocket.ts +- Updated lib/index.ts exports +- Ported __tests__/live-generative-model.test.ts +- Ported __tests__/live-session.test.ts +Review and commit when ready." + +User: [reviews, writes commit, commits] "committed" + +AI: "Next feature: TemplateGenerativeModel. Proceed?" +``` + +### Error Handling + +If you discover: +- **Browser-specific code**: Stop, explain why, ask for guidance +- **Missing dependencies**: List needed packages, ask for approval +- **Unclear adaptations**: Present options, let user decide + +### DO NOT: +- โŒ Implement multiple features at once +- โŒ Proceed without showing the plan +- โŒ Suggest or write commit messages +- โŒ Continue to next feature before user confirms they've committed +- โŒ Skip browser-specific checks +- โŒ Skip porting tests (tests are MANDATORY) + +--- + +## ๐Ÿ“š Related Documentation + +- [React Native Firebase AI Documentation](https://rnfirebase.io/ai/usage) +- [Firebase JS SDK AI Package](https://github.com/firebase/firebase-js-sdk/tree/main/packages/ai) +- [React Native Firebase Architecture](https://rnfirebase.io/) + +--- + +**Last Updated**: 2025-11-19 +**Maintained By**: React Native Firebase Team + diff --git a/.cursor/rules/ai/porting-workflow.md b/.cursor/rules/ai/porting-workflow.md new file mode 100644 index 0000000000..1132b26962 --- /dev/null +++ b/.cursor/rules/ai/porting-workflow.md @@ -0,0 +1,318 @@ +# Feature Porting Workflow + +Quick reference guide for porting features from Firebase JS SDK to React Native Firebase AI using AI assistance. + +--- + +## ๐Ÿš€ Quick Start + +**To begin porting, use this command:** + +``` +Follow @porting-workflow.md + +Firebase JS SDK location: [YOUR_PATH]/firebase-js-sdk/packages/ai + +Start at Step 1: Discovery +``` + +The workflow will: +1. Compare packages using rules from `ai-package.md` +2. List missing features by priority +3. Guide you through porting one feature at a time + +--- + +## ๐ŸŽฏ Three-Step Process + +### **Step 1: Discovery** - What's Missing? + +**When user provides Firebase JS SDK path, execute the following:** + +Compare Firebase JS SDK AI package with React Native Firebase AI package. + +**Paths:** +- Firebase JS SDK: [user-provided path] +- React Native Firebase: Current workspace `packages/ai/` + +**Instructions:** +1. Read `/.cursor/rules/ai/ai-package.md` to understand known differences +2. Compare the packages and identify missing features +3. Categorize by priority: + - ๐Ÿ”ด **HIGH**: Core API functions (e.g., getLiveGenerativeModel, getTemplateGenerativeModel) + - ๐ŸŸก **MEDIUM**: Model classes (e.g., LiveGenerativeModel, TemplateGenerativeModel) + - ๐ŸŸข **LOW**: Helper methods, utilities, optimizations + - โšช **SKIP**: Browser-specific features (Chrome adapter, Hybrid mode) +4. For each missing feature, specify: + - Feature name + - Portability (portable vs browser-specific) + - Files involved in JS SDK + - Brief description + +**Output format:** +``` +Found X missing features: + +๐Ÿ”ด HIGH PRIORITY: +1. [Feature] - [Description] (portable/browser-specific) + Files: [list] + +๐ŸŸก MEDIUM PRIORITY: +1. [Feature] - [Description] (portable/browser-specific) + Files: [list] + +๐ŸŸข LOW PRIORITY: +... + +โšช SKIP (Browser-only): +1. [Feature] - [Reason] +``` + +Then ask: "Which feature would you like to port first?" + +### **Step 2: Port One Feature** - Incremental Implementation + +**When user says "port [FEATURE_NAME]" or selects a feature, execute:** + +#### **Phase 1 - Analysis** + +Analyze the feature for porting: + +1. **Files involved in JS SDK:** + - List all source files + - List all test files + +2. **Required RN adaptations:** + - Import changes (`@firebase/*` โ†’ `@react-native-firebase/*`) + - Type changes (`FirebaseApp` โ†’ `ReactNativeFirebase.FirebaseApp`) + - Polyfill requirements (fetch, ReadableStream, TextEncoder) + - Component system removal (if applicable) + +3. **Browser-specific checks:** + - Scan for `window`, `document`, DOM APIs + - Check for Service Workers, localStorage + - Identify if Chrome adapter or Hybrid mode related + +4. **Complexity estimate:** + - Low: Simple file port, minimal changes + - Medium: Multiple files, some adaptations needed + - High: Complex dependencies, significant adaptations + +5. **Dependencies:** + - List any new packages needed + - Verify RN compatibility + +**Then ask:** "Ready to implement? (yes/no)" + +#### **Phase 2 - Implementation** (After user approval only) + +Execute these steps: + +1. **Create source files:** + - Port from `src/` to `lib/` + - Apply all adaptations from Phase 1 + - Ensure polyfills are used correctly + +2. **Update exports:** + - Add to `lib/index.ts` + - Add types to `lib/public-types.ts` if needed + +3. **Port tests:** + - Port from `src/*.test.ts` to `__tests__/*.test.ts` + - Keep identical filenames + - Convert Mocha/Chai to Jest + - Remove browser-specific test utilities + - Ensure coverage matches JS SDK + +4. **Show all changes:** + - List every file created/modified + - Note key adaptations made + +#### **Phase 3 - Review** (USER commits) + +1. **Summarize:** + ``` + โœ… Complete. Changes made: + + Source files: + - Created lib/[...] + - Updated lib/index.ts + + Test files: + - Ported __tests__/[...] (adapted for Jest) + + Key adaptations: + - [List any notable changes] + ``` + +2. **Wait for user:** + - Say: "Review and commit when ready." + - Do NOT suggest commit messages + - Do NOT proceed to next feature + - Wait for user to say "committed" or similar + +### **Step 3: Continue** - Next Feature + +**When user says "committed", "next", or similar:** + +1. Reference the original discovery list +2. Ask: "Next feature: [FEATURE_NAME] (Priority: [HIGH/MEDIUM/LOW]). Proceed?" +3. If user says yes, go to Step 2 with that feature +4. If user says no or picks different feature, go to Step 2 with their choice +5. If all features done, congratulate and summarize what was ported + +--- + +## โœ… Quality Checklist (Before Committing) + +- [ ] All imports updated (`@firebase/*` โ†’ `@react-native-firebase/*`) +- [ ] Polyfills used correctly (fetch, ReadableStream, TextEncoder) +- [ ] No browser-specific APIs (`window`, `document`, etc.) +- [ ] Exports added to `lib/index.ts` +- [ ] Types exported from `lib/public-types.ts` if needed +- [ ] Tests ported to `__tests__/` with same filenames +- [ ] Tests adapted for Jest (from Karma/Mocha) +- [ ] Test coverage matches JS SDK +- [ ] Files in correct locations (`src/` โ†’ `lib/`) + +--- + +## ๐Ÿšฆ Example Session + +``` +USER: Follow @porting-workflow.md + Firebase JS SDK location: /Users/me/firebase-js-sdk/packages/ai + Start at Step 1: Discovery + +AI: [Reads ai-package.md rules, compares packages] + + Found 3 missing features: + + ๐Ÿ”ด HIGH PRIORITY: + 1. Live API (getLiveGenerativeModel, LiveGenerativeModel, LiveSession) + - Real-time bidirectional communication via WebSocket + - Files: src/api.ts, src/models/live-generative-model.ts, + src/methods/live-session.ts, src/websocket.ts + - Portable โœ“ + + ๐ŸŸก MEDIUM PRIORITY: + 1. Template Models (getTemplateGenerativeModel, TemplateGenerativeModel) + - Server-side template execution + - Files: src/api.ts, src/models/template-generative-model.ts + - Portable โœ“ + + โšช SKIP (Browser-only): + 1. Hybrid Mode - Chrome on-device AI (not available in RN) + + Which feature would you like to port first? + +USER: Live API + +AI: [Phase 1 - Analysis] + + Live API Analysis: + + Files involved: + - src/api.ts (add getLiveGenerativeModel export) + - src/models/live-generative-model.ts + - src/methods/live-session.ts, live-session-helpers.ts + - src/websocket.ts + - src/types/live-responses.ts + - Tests: 3 test files + + Required adaptations: + - Import: @firebase/app โ†’ @react-native-firebase/app + - WebSocket: Native API available in RN โœ“ + - Polyfills: Uses fetch (already polyfilled) โœ“ + + Browser-specific checks: + - No window/document usage โœ“ + - WebSocket is native in RN โœ“ + + Complexity: Medium-High (multiple files, WebSocket handling) + + Ready to implement? (yes/no) + +USER: yes + +AI: [Phase 2 - Implementation] + [Creates all files...] + + โœ… Complete. Changes made: + + Source files: + - Created lib/models/live-generative-model.ts + - Created lib/methods/live-session.ts + - Created lib/methods/live-session-helpers.ts + - Created lib/websocket.ts + - Created lib/types/live-responses.ts + - Updated lib/index.ts (added getLiveGenerativeModel, LiveGenerativeModel exports) + + Test files: + - Ported __tests__/live-generative-model.test.ts (adapted for Jest) + - Ported __tests__/live-session.test.ts (adapted for Jest) + - Ported __tests__/websocket.test.ts (adapted for Jest) + + Key adaptations: + - Changed imports from @firebase/* to @react-native-firebase/* + - Removed component provider logic + - WebSocket uses native RN WebSocket API + + Review and commit when ready. + +USER: [reviews code, writes commit message, commits] + committed + +AI: Next feature: Template Models (Priority: MEDIUM). Proceed? + +USER: yes + +[Process repeats...] +``` + +--- + +## ๐ŸŽฏ Key Principles + +1. **One feature = One commit** - Clean git history +2. **Show before doing** - Always present plan first +3. **Wait for approval** - User stays in control +4. **Pause between features** - Time to review and commit +5. **Respect known differences** - Don't port browser-specific code + +--- + +## ๐Ÿ“š Related Files + +- `/.cursor/rules/ai/ai-package.md` - Known differences and porting rules +- `/.cursor/rules/ai/porting-workflow.md` - This file +- `README.md` - Package documentation + +--- + +## ๐Ÿ“ Usage Summary + +**Single command to start:** + +``` +Follow @porting-workflow.md + +Firebase JS SDK location: [YOUR_PATH]/firebase-js-sdk/packages/ai + +Start at Step 1: Discovery +``` + +**The AI will:** +- โœ… Read the porting rules automatically +- โœ… Compare packages and list missing features +- โœ… Guide you through porting one feature at a time +- โœ… Wait for your approval at each step +- โœ… Pause after each feature for you to commit + +**You control:** +- Which features to port +- When to proceed with implementation +- Commit messages and timing + +That's it! ๐Ÿš€ + diff --git a/package.json b/package.json index b3adfc991b..714a5ab4d8 100644 --- a/package.json +++ b/package.json @@ -57,6 +57,7 @@ "tests:macos:pod:install": "cd tests && rm -f macos/Podfile.lock && cd macos && pod install", "tests:macos:manual": "cd tests && yarn react-native run-macos", "tests:macos:test-cover": "cd tests && npx jet --target=macos --coverage", + "format:js": "prettier --write \"packages/**/*.{js,ts,tsx}\"", "format:markdown": "prettier --write \"docs/**/*.md\"" }, "devDependencies": { diff --git a/packages/ai/__tests__/count-tokens.test.ts b/packages/ai/__tests__/count-tokens.test.ts index 850e7928f5..96a47ef86a 100644 --- a/packages/ai/__tests__/count-tokens.test.ts +++ b/packages/ai/__tests__/count-tokens.test.ts @@ -18,11 +18,10 @@ import { describe, expect, it, afterEach, jest, beforeEach } from '@jest/globals import { BackendName, getMockResponse } from './test-utils/mock-response'; import * as request from '../lib/requests/request'; import { countTokens } from '../lib/methods/count-tokens'; -import { CountTokensRequest, RequestOptions } from '../lib/types'; +import { CountTokensRequest } from '../lib/types'; import { ApiSettings } from '../lib/types/internal'; import { Task } from '../lib/requests/request'; import { GoogleAIBackend } from '../lib/backend'; -import { SpiedFunction } from 'jest-mock'; import { mapCountTokensRequest } from '../lib/googleai-mappers'; const fakeApiSettings: ApiSettings = { @@ -59,12 +58,14 @@ describe('countTokens()', () => { expect(result.totalTokens).toBe(6); expect(result.totalBillableCharacters).toBe(16); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.stringContaining('contents'), - undefined, ); }); @@ -83,12 +84,14 @@ describe('countTokens()', () => { expect(result.promptTokensDetails?.[0]?.modality).toBe('IMAGE'); expect(result.promptTokensDetails?.[0]?.tokenCount).toBe(1806); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.stringContaining('contents'), - undefined, ); }); @@ -104,12 +107,14 @@ describe('countTokens()', () => { expect(result.totalTokens).toBe(258); expect(result).not.toHaveProperty('totalBillableCharacters'); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.COUNT_TOKENS, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.stringContaining('contents'), - undefined, ); }); @@ -130,16 +135,7 @@ describe('countTokens()', () => { }); describe('googleAI', () => { - let makeRequestStub: SpiedFunction< - ( - model: string, - task: Task, - apiSettings: ApiSettings, - stream: boolean, - body: string, - requestOptions?: RequestOptions, - ) => Promise - >; + let makeRequestStub: jest.SpiedFunction; beforeEach(() => { makeRequestStub = jest.spyOn(request, 'makeRequest'); @@ -155,12 +151,14 @@ describe('countTokens()', () => { await countTokens(fakeGoogleAIApiSettings, 'model', fakeRequestParams); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.COUNT_TOKENS, - fakeGoogleAIApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.COUNT_TOKENS, + apiSettings: fakeGoogleAIApiSettings, + stream: false, + requestOptions: undefined, + }), JSON.stringify(mapCountTokensRequest(fakeRequestParams, 'model')), - undefined, ); }); }); diff --git a/packages/ai/__tests__/generate-content.test.ts b/packages/ai/__tests__/generate-content.test.ts index a47692192e..3f38b6ecf7 100644 --- a/packages/ai/__tests__/generate-content.test.ts +++ b/packages/ai/__tests__/generate-content.test.ts @@ -91,12 +91,14 @@ describe('generateContent()', () => { const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); expect(await result.response.text()).toContain('Mountain View, California'); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.stringContaining('contents'), - undefined, ); }); @@ -112,12 +114,14 @@ describe('generateContent()', () => { expect(result.response.text()).toContain('Use Freshly Ground Coffee'); expect(result.response.text()).toContain('30 minutes of brewing'); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -137,12 +141,14 @@ describe('generateContent()', () => { expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.modality).toEqual('TEXT'); expect(result.response.usageMetadata?.candidatesTokensDetails?.[0]?.tokenCount).toEqual(76); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -155,12 +161,14 @@ describe('generateContent()', () => { expect(result.response.text()).toContain('Some information cited from an external source'); expect(result.response.candidates?.[0]!.citationMetadata?.citations.length).toBe(3); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -177,12 +185,14 @@ describe('generateContent()', () => { expect(() => result.response.text()).toThrow('SAFETY'); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -197,12 +207,14 @@ describe('generateContent()', () => { const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); expect(() => result.response.text()).toThrow('SAFETY'); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -214,12 +226,14 @@ describe('generateContent()', () => { const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); expect(result.response.text()).toBe(''); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -234,12 +248,14 @@ describe('generateContent()', () => { const result = await generateContent(fakeApiSettings, 'model', fakeRequestParams); expect(result.response.text()).toContain('Some text'); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: undefined, + }), expect.anything(), - undefined, ); }); @@ -318,12 +334,14 @@ describe('generateContent()', () => { await generateContent(fakeGoogleAIApiSettings, 'model', fakeGoogleAIRequestParams); expect(makeRequestStub).toHaveBeenCalledWith( - 'model', - Task.GENERATE_CONTENT, - fakeGoogleAIApiSettings, - false, + expect.objectContaining({ + model: 'model', + task: Task.GENERATE_CONTENT, + apiSettings: fakeGoogleAIApiSettings, + stream: false, + requestOptions: undefined, + }), JSON.stringify(mapGenerateContentRequest(fakeGoogleAIRequestParams)), - undefined, ); }); }); diff --git a/packages/ai/__tests__/generative-model.test.ts b/packages/ai/__tests__/generative-model.test.ts index 7d29c501c8..7835469e37 100644 --- a/packages/ai/__tests__/generative-model.test.ts +++ b/packages/ai/__tests__/generative-model.test.ts @@ -65,12 +65,14 @@ describe('GenerativeModel', () => { .mockResolvedValue(mockResponse as Response); await genModel.generateContent('hello'); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), expect.stringMatching(new RegExp(`myfunc|be friendly|${FunctionCallingMode.NONE}`)), - {}, ); makeRequestStub.mockRestore(); }); @@ -90,12 +92,14 @@ describe('GenerativeModel', () => { .mockResolvedValue(mockResponse as Response); await genModel.generateContent('hello'); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), expect.stringContaining('be friendly'), - {}, ); makeRequestStub.mockRestore(); }); @@ -137,12 +141,14 @@ describe('GenerativeModel', () => { systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }, }); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), expect.stringMatching(new RegExp(`be formal|otherfunc|${FunctionCallingMode.AUTO}`)), - {}, ); makeRequestStub.mockRestore(); }); @@ -196,12 +202,14 @@ describe('GenerativeModel', () => { .mockResolvedValue(mockResponse as Response); await genModel.startChat().sendMessage('hello'); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), expect.stringMatching(new RegExp(`myfunc|be friendly|${FunctionCallingMode.NONE}`)), - {}, ); makeRequestStub.mockRestore(); }); @@ -221,12 +229,14 @@ describe('GenerativeModel', () => { .mockResolvedValue(mockResponse as Response); await genModel.startChat().sendMessage('hello'); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), expect.stringContaining('be friendly'), - {}, ); makeRequestStub.mockRestore(); }); @@ -262,12 +272,14 @@ describe('GenerativeModel', () => { }) .sendMessage('hello'); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.GENERATE_CONTENT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.GENERATE_CONTENT, + apiSettings: expect.anything(), + stream: false, + requestOptions: {}, + }), expect.stringMatching(new RegExp(`otherfunc|be formal|${FunctionCallingMode.AUTO}`)), - {}, ); makeRequestStub.mockRestore(); }); @@ -280,12 +292,14 @@ describe('GenerativeModel', () => { .mockResolvedValue(mockResponse as Response); await genModel.countTokens('hello'); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.COUNT_TOKENS, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.COUNT_TOKENS, + apiSettings: expect.anything(), + stream: false, + requestOptions: undefined, + }), expect.stringContaining('hello'), - undefined, ); makeRequestStub.mockRestore(); }); diff --git a/packages/ai/__tests__/imagen-model.test.ts b/packages/ai/__tests__/imagen-model.test.ts index e1194d0397..9198581364 100644 --- a/packages/ai/__tests__/imagen-model.test.ts +++ b/packages/ai/__tests__/imagen-model.test.ts @@ -63,20 +63,24 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: expect.anything(), + stream: false, + requestOptions: undefined, + }), expect.stringMatching(new RegExp(`"prompt":"${prompt}"`)), - undefined, ); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: expect.anything(), + stream: false, + requestOptions: undefined, + }), expect.stringContaining(`"sampleCount":1`), - undefined, ); }); @@ -106,28 +110,34 @@ describe('ImagenModel', () => { const prompt = 'A photorealistic image of a toy boat at sea.'; await imagenModel.generateImages(prompt); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: expect.anything(), + stream: false, + requestOptions: undefined, + }), expect.stringContaining(`"negativePrompt":"${imagenModel.generationConfig?.negativePrompt}"`), - undefined, ); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: expect.anything(), + stream: false, + requestOptions: undefined, + }), expect.stringContaining(`"sampleCount":${imagenModel.generationConfig?.numberOfImages}`), - undefined, ); expect(makeRequestStub).toHaveBeenCalledWith( - 'publishers/google/models/my-model', - request.Task.PREDICT, - expect.anything(), - false, + expect.objectContaining({ + model: 'publishers/google/models/my-model', + task: request.Task.PREDICT, + apiSettings: expect.anything(), + stream: false, + requestOptions: undefined, + }), expect.stringContaining(`"aspectRatio":"${imagenModel.generationConfig?.aspectRatio}"`), - undefined, ); }); diff --git a/packages/ai/__tests__/live-generative-model.test.ts b/packages/ai/__tests__/live-generative-model.test.ts new file mode 100644 index 0000000000..8190d8a896 --- /dev/null +++ b/packages/ai/__tests__/live-generative-model.test.ts @@ -0,0 +1,181 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, beforeEach, afterEach, jest } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { AI } from '../lib/public-types'; +import { LiveSession } from '../lib/methods/live-session'; +import { WebSocketHandler } from '../lib/websocket'; +import { GoogleAIBackend } from '../lib/backend'; +import { LiveGenerativeModel } from '../lib/models/live-generative-model'; +import { AIError } from '../lib/errors'; + +// A controllable mock for the WebSocketHandler interface +class MockWebSocketHandler implements WebSocketHandler { + connect = jest.fn<() => Promise>().mockResolvedValue(undefined); + send = jest.fn<(data: string | ArrayBuffer) => void>(); + close = jest.fn<() => Promise>().mockResolvedValue(undefined); + + private serverMessages: unknown[] = []; + private generatorController: { + resolve: () => void; + promise: Promise; + } | null = null; + + async *listen(): AsyncGenerator { + while (true) { + if (this.serverMessages.length > 0) { + yield this.serverMessages.shift(); + } else { + const promise = new Promise(resolve => { + this.generatorController = { resolve, promise: null! }; + }); + await promise; + } + } + } + + // Test method to simulate a message from the server + simulateServerMessage(message: object): void { + this.serverMessages.push(message); + if (this.generatorController) { + this.generatorController.resolve(); + this.generatorController = null; + } + } +} + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new GoogleAIBackend(), + location: 'us-central1', +}; + +describe('LiveGenerativeModel', function () { + let mockHandler: MockWebSocketHandler; + + beforeEach(function () { + mockHandler = new MockWebSocketHandler(); + jest.useFakeTimers(); + }); + + afterEach(function () { + jest.restoreAllMocks(); + jest.useRealTimers(); + }); + + it('connect() should call handler.connect and send setup message', async function () { + const model = new LiveGenerativeModel(fakeAI, { model: 'my-model' }, mockHandler); + const connectPromise = model.connect(); + + // Ensure connect was called before simulating server response + expect(mockHandler.connect).toHaveBeenCalledTimes(1); + + // Wait for the setup message to be sent + await jest.runAllTimersAsync(); + + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const setupMessage = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(setupMessage.setup.model).toContain('my-model'); + + // Simulate successful handshake and resolve the promise + mockHandler.simulateServerMessage({ setupComplete: true }); + const session = await connectPromise; + expect(session).toBeInstanceOf(LiveSession); + await session.close(); + }); + + it('connect() should throw if handshake fails', async function () { + const model = new LiveGenerativeModel(fakeAI, { model: 'my-model' }, mockHandler); + const connectPromise = model.connect(); + + // Wait for setup message + await jest.runAllTimersAsync(); + + // Simulate a failed handshake + mockHandler.simulateServerMessage({ error: 'handshake failed' }); + await expect(connectPromise).rejects.toThrow(AIError); + await expect(connectPromise).rejects.toThrow(/Server connection handshake failed/); + }); + + it('connect() should pass through connection errors', async function () { + (mockHandler.connect as jest.Mock<() => Promise>).mockRejectedValue( + new Error('Connection refused'), + ); + const model = new LiveGenerativeModel(fakeAI, { model: 'my-model' }, mockHandler); + await expect(model.connect()).rejects.toThrow('Connection refused'); + }); + + it('connect() should pass through setup parameters correctly', async function () { + const model = new LiveGenerativeModel( + fakeAI, + { + model: 'gemini-pro', + generationConfig: { temperature: 0.8 }, + systemInstruction: { role: 'system', parts: [{ text: 'Be a pirate' }] }, + }, + mockHandler, + ); + const connectPromise = model.connect(); + + // Wait for setup message + await jest.runAllTimersAsync(); + + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData.setup.generationConfig).toEqual({ temperature: 0.8 }); + expect(sentData.setup.systemInstruction.parts[0].text).toBe('Be a pirate'); + mockHandler.simulateServerMessage({ setupComplete: true }); + await connectPromise; + }); + + it('connect() should deconstruct generationConfig to send transcription configs in top level setup', async function () { + const model = new LiveGenerativeModel( + fakeAI, + { + model: 'gemini-pro', + generationConfig: { + temperature: 0.8, + inputAudioTranscription: {}, + outputAudioTranscription: {}, + }, + systemInstruction: { role: 'system', parts: [{ text: 'Be a pirate' }] }, + }, + mockHandler, + ); + const connectPromise = model.connect(); + + // Wait for setup message + await jest.runAllTimersAsync(); + + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + // inputAudioTranscription and outputAudioTranscription should be at the top-level setup message, + // rather than in the generationConfig. + expect(sentData.setup.generationConfig).toEqual({ temperature: 0.8 }); + expect(sentData.setup.inputAudioTranscription).toEqual({}); + expect(sentData.setup.outputAudioTranscription).toEqual({}); + expect(sentData.setup.systemInstruction.parts[0].text).toBe('Be a pirate'); + mockHandler.simulateServerMessage({ setupComplete: true }); + await connectPromise; + }); +}); diff --git a/packages/ai/__tests__/live-session.test.ts b/packages/ai/__tests__/live-session.test.ts new file mode 100644 index 0000000000..1a638e38a4 --- /dev/null +++ b/packages/ai/__tests__/live-session.test.ts @@ -0,0 +1,355 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, expect, it, beforeEach, jest } from '@jest/globals'; +import { + FunctionResponse, + LiveResponseType, + LiveServerContent, + LiveServerToolCall, + LiveServerToolCallCancellation, +} from '../lib/types'; +import { LiveSession } from '../lib/methods/live-session'; +import { WebSocketHandler } from '../lib/websocket'; +import { AIError } from '../lib/errors'; +import { logger } from '../lib/logger'; +import { ReadableStream } from 'web-streams-polyfill'; + +class MockWebSocketHandler implements WebSocketHandler { + connect = jest.fn<() => Promise>().mockResolvedValue(undefined); + send = jest.fn<(data: string | ArrayBuffer) => void>(); + close = jest.fn<() => Promise>().mockResolvedValue(undefined); + + private messageQueue: unknown[] = []; + private streamClosed = false; + private listenerPromiseResolver: (() => void) | null = null; + + async *listen(): AsyncGenerator { + while (!this.streamClosed) { + if (this.messageQueue.length > 0) { + yield this.messageQueue.shift(); + } else { + // Wait until a new message is pushed or the stream is ended. + await new Promise(resolve => { + this.listenerPromiseResolver = resolve; + }); + } + } + } + + simulateServerMessage(message: object): void { + this.messageQueue.push(message); + if (this.listenerPromiseResolver) { + // listener is waiting for our message + this.listenerPromiseResolver(); + this.listenerPromiseResolver = null; + } + } + + endStream(): void { + this.streamClosed = true; + if (this.listenerPromiseResolver) { + this.listenerPromiseResolver(); + this.listenerPromiseResolver = null; + } + } +} + +describe('LiveSession', function () { + let mockHandler: MockWebSocketHandler; + let session: LiveSession; + let serverMessagesGenerator: AsyncGenerator; + + beforeEach(function () { + mockHandler = new MockWebSocketHandler(); + serverMessagesGenerator = mockHandler.listen(); + session = new LiveSession(mockHandler, serverMessagesGenerator); + }); + + describe('send()', function () { + it('should format and send a valid text message', async function () { + await session.send('Hello there'); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData).toEqual({ + clientContent: { + turns: [{ role: 'user', parts: [{ text: 'Hello there' }] }], + turnComplete: true, + }, + }); + }); + + it('should format and send a message with an array of Parts', async function () { + const parts = [ + { text: 'Part 1' }, + { inlineData: { mimeType: 'image/png', data: 'base64==' } }, + ]; + await session.send(parts); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData.clientContent.turns[0].parts).toEqual(parts); + }); + }); + + describe('sendTextRealtime()', function () { + it('should send a correctly formatted realtimeInput message', async function () { + const text = 'foo'; + await session.sendTextRealtime(text); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData).toEqual({ + realtimeInput: { text }, + }); + }); + }); + + describe('sendAudioRealtime()', function () { + it('should send a correctly formatted realtimeInput message', async function () { + const blob = { data: 'abcdef', mimeType: 'audio/pcm' }; + await session.sendAudioRealtime(blob); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData).toEqual({ + realtimeInput: { audio: blob }, + }); + }); + }); + + describe('sendVideoRealtime()', function () { + it('should send a correctly formatted realtimeInput message', async function () { + const blob = { data: 'abcdef', mimeType: 'image/jpeg' }; + await session.sendVideoRealtime(blob); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData).toEqual({ + realtimeInput: { video: blob }, + }); + }); + }); + + describe('sendMediaChunks()', function () { + it('should send a correctly formatted realtimeInput message', async function () { + const chunks = [{ data: 'base64', mimeType: 'audio/webm' }]; + await session.sendMediaChunks(chunks); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData).toEqual({ + realtimeInput: { mediaChunks: chunks }, + }); + }); + }); + + describe('sendMediaStream()', function () { + it('should send multiple chunks from a stream', async function () { + const stream = new ReadableStream({ + start(controller) { + controller.enqueue({ data: 'chunk1', mimeType: 'audio/webm' }); + controller.enqueue({ data: 'chunk2', mimeType: 'audio/webm' }); + controller.close(); + }, + }); + + await session.sendMediaStream(stream); + + expect(mockHandler.send).toHaveBeenCalledTimes(2); + const firstCall = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + const secondCall = JSON.parse((mockHandler.send as jest.Mock).mock.calls[1]![0] as string); + expect(firstCall.realtimeInput.mediaChunks[0].data).toBe('chunk1'); + expect(secondCall.realtimeInput.mediaChunks[0].data).toBe('chunk2'); + }); + + it('should re-throw an AIError if the stream reader throws', async function () { + const errorStream = new ReadableStream({ + pull(controller) { + controller.error(new Error('Stream failed!')); + }, + }); + const promise = session.sendMediaStream(errorStream); + await expect(promise).rejects.toThrow(AIError); + await expect(promise).rejects.toThrow(/Stream failed!/); + }); + }); + + describe('sendFunctionResponses()', function () { + it('should send all function responses', async function () { + const functionResponses: FunctionResponse[] = [ + { + id: 'function-call-1', + name: 'function-name', + response: { + result: 'foo', + }, + }, + { + id: 'function-call-2', + name: 'function-name-2', + response: { + result: 'bar', + }, + }, + ]; + await session.sendFunctionResponses(functionResponses); + expect(mockHandler.send).toHaveBeenCalledTimes(1); + const sentData = JSON.parse((mockHandler.send as jest.Mock).mock.calls[0]![0] as string); + expect(sentData).toEqual({ + toolResponse: { + functionResponses, + }, + }); + }); + }); + + describe('receive()', function () { + it('should correctly parse and transform all server message types', async function () { + const receivePromise = (async () => { + const responses = []; + for await (const response of session.receive()) { + responses.push(response); + } + return responses; + })(); + + mockHandler.simulateServerMessage({ + serverContent: { modelTurn: { parts: [{ text: 'response 1' }] } }, + }); + mockHandler.simulateServerMessage({ + toolCall: { functionCalls: [{ name: 'test_func' }] }, + }); + mockHandler.simulateServerMessage({ + toolCallCancellation: { functionIds: ['123'] }, + }); + mockHandler.simulateServerMessage({ + serverContent: { turnComplete: true }, + }); + await new Promise(r => setTimeout(() => r(), 10)); // Wait for the listener to process messages + mockHandler.endStream(); + + const responses = await receivePromise; + expect(responses).toHaveLength(4); + expect(responses[0]).toEqual({ + type: LiveResponseType.SERVER_CONTENT, + modelTurn: { parts: [{ text: 'response 1' }] }, + } as LiveServerContent); + expect(responses[1]).toEqual({ + type: LiveResponseType.TOOL_CALL, + functionCalls: [{ name: 'test_func' }], + } as LiveServerToolCall); + expect(responses[2]).toEqual({ + type: LiveResponseType.TOOL_CALL_CANCELLATION, + functionIds: ['123'], + } as LiveServerToolCallCancellation); + }); + + it('should log a warning and skip messages that are not objects', async function () { + const loggerSpy = jest.spyOn(logger, 'warn'); + const receivePromise = (async () => { + const responses = []; + for await (const response of session.receive()) { + responses.push(response); + } + return responses; + })(); + + mockHandler.simulateServerMessage(null as unknown as object); + mockHandler.simulateServerMessage('not an object' as unknown as object); + await new Promise(r => setTimeout(() => r(), 10)); // Wait for the listener to process messages + mockHandler.endStream(); + + const responses = await receivePromise; + expect(responses).toHaveLength(0); + expect(loggerSpy).toHaveBeenCalledTimes(2); + expect(loggerSpy).toHaveBeenCalledWith( + expect.stringContaining('Received an invalid message'), + ); + + loggerSpy.mockRestore(); + }); + + it('should log a warning and skip objects of unknown type', async function () { + const loggerSpy = jest.spyOn(logger, 'warn'); + const receivePromise = (async () => { + const responses = []; + for await (const response of session.receive()) { + responses.push(response); + } + return responses; + })(); + + mockHandler.simulateServerMessage({ unknownType: { data: 'test' } }); + await new Promise(r => setTimeout(() => r(), 10)); // Wait for the listener to process messages + mockHandler.endStream(); + + const responses = await receivePromise; + expect(responses).toHaveLength(0); + expect(loggerSpy).toHaveBeenCalledTimes(1); + expect(loggerSpy).toHaveBeenCalledWith( + expect.stringContaining('Received an unknown message type'), + ); + + loggerSpy.mockRestore(); + }); + }); + + describe('close()', function () { + it('should call the handler, set the isClosed flag, and be idempotent', async function () { + expect(session.isClosed).toBe(false); + await session.close(); + expect(mockHandler.close).toHaveBeenCalledTimes(1); + expect(session.isClosed).toBe(true); + + // Call again to test idempotency + await session.close(); + expect(mockHandler.close).toHaveBeenCalledTimes(1); // Should not be called again + }); + + it('should terminate an active receive() loop', async function () { + const received: unknown[] = []; + const receivePromise = (async () => { + for await (const msg of session.receive()) { + received.push(msg); + } + })(); + + mockHandler.simulateServerMessage({ + serverContent: { modelTurn: { parts: [{ text: 'one' }] } }, + }); + // Allow the first message to be processed + await new Promise(r => setTimeout(r, 10)); + expect(received).toHaveLength(1); + + await session.close(); + mockHandler.endStream(); // End the mock stream + + await receivePromise; // This should now resolve + + // No more messages should have been processed + expect(received).toHaveLength(1); + }); + + it('methods should throw after session is closed', async function () { + await session.close(); + await expect(session.send('test')).rejects.toThrow(AIError); + await expect(session.send('test')).rejects.toThrow(/closed/); + await expect(session.sendMediaChunks([])).rejects.toThrow(AIError); + await expect(session.sendMediaChunks([])).rejects.toThrow(/closed/); + const generator = session.receive(); + const nextPromise = generator.next(); + await expect(nextPromise).rejects.toThrow(AIError); + await expect(nextPromise).rejects.toThrow(/closed/); + }); + }); +}); diff --git a/packages/ai/__tests__/request.test.ts b/packages/ai/__tests__/request.test.ts index 3e5e58e415..d00fbabc12 100644 --- a/packages/ai/__tests__/request.test.ts +++ b/packages/ai/__tests__/request.test.ts @@ -249,10 +249,12 @@ describe('request methods', () => { ok: true, } as Response); const response = await makeRequest( - 'models/model-name', - Task.GENERATE_CONTENT, - fakeApiSettings, - false, + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + }, '', ); expect(fetchMock).toHaveBeenCalledTimes(1); @@ -267,9 +269,18 @@ describe('request methods', () => { } as Response); try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, '', { - timeout: 180000, - }); + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + requestOptions: { + timeout: 180000, + }, + }, + '', + ); } catch (e) { expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); expect((e as AIError).customErrorData?.status).toBe(500); @@ -287,7 +298,15 @@ describe('request methods', () => { statusText: 'Server Error', } as Response); try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + }, + '', + ); } catch (e) { expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); expect((e as AIError).customErrorData?.status).toBe(500); @@ -305,7 +324,15 @@ describe('request methods', () => { json: () => Promise.resolve({ error: { message: 'extra info' } }), } as Response); try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + }, + '', + ); } catch (e) { expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); expect((e as AIError).customErrorData?.status).toBe(500); @@ -336,7 +363,15 @@ describe('request methods', () => { }), } as Response); try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + }, + '', + ); } catch (e) { expect((e as AIError).code).toBe(AIErrorCode.FETCH_ERROR); expect((e as AIError).customErrorData?.status).toBe(500); @@ -356,7 +391,15 @@ describe('request methods', () => { ); const fetchMock = jest.spyOn(globalThis, 'fetch').mockResolvedValue(mockResponse as Response); try { - await makeRequest('models/model-name', Task.GENERATE_CONTENT, fakeApiSettings, false, ''); + await makeRequest( + { + model: 'models/model-name', + task: Task.GENERATE_CONTENT, + apiSettings: fakeApiSettings, + stream: false, + }, + '', + ); } catch (e) { expect((e as AIError).code).toBe(AIErrorCode.API_NOT_ENABLED); expect((e as AIError).message).toContain('my-project'); diff --git a/packages/ai/__tests__/template-generative-model.test.ts b/packages/ai/__tests__/template-generative-model.test.ts new file mode 100644 index 0000000000..1f1ddee409 --- /dev/null +++ b/packages/ai/__tests__/template-generative-model.test.ts @@ -0,0 +1,93 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, expect, it, afterEach, jest } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { AI } from '../lib/public-types'; +import { VertexAIBackend } from '../lib/backend'; +import { TemplateGenerativeModel } from '../lib/models/template-generative-model'; +import * as generateContentMethods from '../lib/methods/generate-content'; + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', +}; + +const TEMPLATE_ID = 'my-template'; +const TEMPLATE_VARS = { a: 1, b: '2' }; + +describe('TemplateGenerativeModel', function () { + afterEach(function () { + jest.restoreAllMocks(); + }); + + describe('constructor', function () { + it('should initialize _apiSettings correctly', function () { + const model = new TemplateGenerativeModel(fakeAI); + expect(model._apiSettings.apiKey).toBe('key'); + expect(model._apiSettings.project).toBe('my-project'); + expect(model._apiSettings.appId).toBe('my-appid'); + }); + }); + + describe('generateContent', function () { + it('should call templateGenerateContent with correct parameters', async function () { + const templateGenerateContentSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContent') + .mockResolvedValue({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + + await model.generateContent(TEMPLATE_ID, TEMPLATE_VARS); + + expect(templateGenerateContentSpy).toHaveBeenCalledTimes(1); + expect(templateGenerateContentSpy).toHaveBeenCalledWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 5000 }, + ); + }); + }); + + describe('generateContentStream', function () { + it('should call templateGenerateContentStream with correct parameters', async function () { + const templateGenerateContentStreamSpy = jest + .spyOn(generateContentMethods, 'templateGenerateContentStream') + .mockResolvedValue({} as any); + const model = new TemplateGenerativeModel(fakeAI, { timeout: 5000 }); + + await model.generateContentStream(TEMPLATE_ID, TEMPLATE_VARS); + + expect(templateGenerateContentStreamSpy).toHaveBeenCalledTimes(1); + expect(templateGenerateContentStreamSpy).toHaveBeenCalledWith( + model._apiSettings, + TEMPLATE_ID, + { inputs: TEMPLATE_VARS }, + { timeout: 5000 }, + ); + }); + }); +}); diff --git a/packages/ai/__tests__/template-imagen-model.test.ts b/packages/ai/__tests__/template-imagen-model.test.ts new file mode 100644 index 0000000000..e4ff6e855b --- /dev/null +++ b/packages/ai/__tests__/template-imagen-model.test.ts @@ -0,0 +1,134 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, expect, it, afterEach, jest } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { AI, AIErrorCode } from '../lib/public-types'; +import { VertexAIBackend } from '../lib/backend'; +import { TemplateImagenModel } from '../lib/models/template-imagen-model'; +import { AIError } from '../lib/errors'; +import * as request from '../lib/requests/request'; +import { ServerPromptTemplateTask } from '../lib/requests/request'; + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', +}; + +const TEMPLATE_ID = 'my-imagen-template'; +const TEMPLATE_VARS = { a: 1, b: '2' }; + +describe('TemplateImagenModel', function () { + afterEach(function () { + jest.restoreAllMocks(); + }); + + describe('constructor', function () { + it('should initialize _apiSettings correctly', function () { + const model = new TemplateImagenModel(fakeAI); + expect(model._apiSettings.apiKey).toBe('key'); + expect(model._apiSettings.project).toBe('my-project'); + expect(model._apiSettings.appId).toBe('my-appid'); + }); + }); + + describe('generateImages', function () { + it('should call makeRequest with correct parameters', async function () { + const makeRequestSpy = jest.spyOn(request, 'makeRequest').mockResolvedValue({ + json: () => + Promise.resolve({ + predictions: [ + { + bytesBase64Encoded: + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + mimeType: 'image/png', + }, + ], + }), + } as Response); + const model = new TemplateImagenModel(fakeAI, { timeout: 5000 }); + + await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(makeRequestSpy).toHaveBeenCalledTimes(1); + expect(makeRequestSpy).toHaveBeenCalledWith( + { + task: ServerPromptTemplateTask.TEMPLATE_PREDICT, + templateId: TEMPLATE_ID, + apiSettings: model._apiSettings, + stream: false, + requestOptions: { timeout: 5000 }, + }, + JSON.stringify({ inputs: TEMPLATE_VARS }), + ); + }); + + it('should return the result of handlePredictResponse', async function () { + const mockPrediction = { + bytesBase64Encoded: + 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==', + mimeType: 'image/png', + }; + jest.spyOn(request, 'makeRequest').mockResolvedValue({ + json: () => Promise.resolve({ predictions: [mockPrediction] }), + } as Response); + + const model = new TemplateImagenModel(fakeAI); + const result = await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(result.images).toEqual([mockPrediction]); + }); + + it('should throw an AIError if the prompt is blocked', async function () { + const error = new AIError(AIErrorCode.FETCH_ERROR, 'Request failed'); + jest.spyOn(request, 'makeRequest').mockRejectedValue(error); + + const model = new TemplateImagenModel(fakeAI); + await expect(model.generateImages(TEMPLATE_ID, TEMPLATE_VARS)).rejects.toThrow(error); + }); + + it('should handle responses with filtered images', async function () { + const mockPrediction = { + bytesBase64Encoded: 'iVBOR...ggg==', + mimeType: 'image/png', + }; + const filteredReason = 'This image was filtered for safety reasons.'; + jest.spyOn(request, 'makeRequest').mockResolvedValue({ + json: () => + Promise.resolve({ + predictions: [mockPrediction, { raiFilteredReason: filteredReason }], + }), + } as Response); + + const model = new TemplateImagenModel(fakeAI); + const result = await model.generateImages(TEMPLATE_ID, TEMPLATE_VARS); + + expect(result.images).toHaveLength(1); + expect(result.images[0]).toEqual(mockPrediction); + expect(result.filteredReason).toBe(filteredReason); + }); + }); +}); diff --git a/packages/ai/__tests__/utils.test.ts b/packages/ai/__tests__/utils.test.ts new file mode 100644 index 0000000000..10d3485b7a --- /dev/null +++ b/packages/ai/__tests__/utils.test.ts @@ -0,0 +1,162 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import { describe, expect, it, jest } from '@jest/globals'; +import { type ReactNativeFirebase } from '@react-native-firebase/app'; +import { AI, AIErrorCode } from '../lib/public-types'; +import { AIError } from '../lib/errors'; +import { VertexAIBackend } from '../lib/backend'; +import { AIService } from '../lib/service'; +import { initApiSettings } from '../lib/models/utils'; + +const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + appId: 'my-appid', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', +}; + +describe('initApiSettings', function () { + it('calls regular app check token when option is set', async function () { + const getTokenMock = jest + .fn<() => Promise<{ token: string }>>() + .mockResolvedValue({ token: 'mock-token' }); + const getLimitedUseTokenMock = jest + .fn<() => Promise<{ token: string }>>() + .mockResolvedValue({ token: 'mock-limited-token' }); + + const apiSettings = initApiSettings({ + ...fakeAI, + options: { useLimitedUseAppCheckTokens: false }, + appCheck: { + getToken: getTokenMock, + getLimitedUseToken: getLimitedUseTokenMock, + }, + } as unknown as AIService); + + if (apiSettings?.getAppCheckToken) { + await apiSettings.getAppCheckToken(); + } + + expect(getTokenMock).toHaveBeenCalled(); + expect(getLimitedUseTokenMock).not.toHaveBeenCalled(); + }); + + it('calls limited use token when option is set', async function () { + const getTokenMock = jest + .fn<() => Promise<{ token: string }>>() + .mockResolvedValue({ token: 'mock-token' }); + const getLimitedUseTokenMock = jest + .fn<() => Promise<{ token: string }>>() + .mockResolvedValue({ token: 'mock-limited-token' }); + + const apiSettings = initApiSettings({ + ...fakeAI, + options: { useLimitedUseAppCheckTokens: true }, + appCheck: { + getToken: getTokenMock, + getLimitedUseToken: getLimitedUseTokenMock, + }, + } as unknown as AIService); + + if (apiSettings?.getAppCheckToken) { + await apiSettings.getAppCheckToken(); + } + + expect(getTokenMock).not.toHaveBeenCalled(); + expect(getLimitedUseTokenMock).toHaveBeenCalled(); + }); + + it('throws if not passed an api key', function () { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + projectId: 'my-project', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', + }; + + expect(() => { + initApiSettings(fakeAI); + }).toThrow(AIError); + + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).toBe(AIErrorCode.NO_API_KEY); + } + }); + + it('throws if not passed a project ID', function () { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', + }; + + expect(() => { + initApiSettings(fakeAI); + }).toThrow(AIError); + + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).toBe(AIErrorCode.NO_PROJECT_ID); + } + }); + + it('throws if not passed an app ID', function () { + const fakeAI: AI = { + app: { + name: 'DEFAULT', + automaticDataCollectionEnabled: true, + options: { + apiKey: 'key', + projectId: 'my-project', + }, + } as ReactNativeFirebase.FirebaseApp, + backend: new VertexAIBackend('us-central1'), + location: 'us-central1', + }; + + expect(() => { + initApiSettings(fakeAI); + }).toThrow(AIError); + + try { + initApiSettings(fakeAI); + } catch (e) { + expect((e as AIError).code).toBe(AIErrorCode.NO_APP_ID); + } + }); +}); diff --git a/packages/ai/__tests__/websocket.test.ts b/packages/ai/__tests__/websocket.test.ts new file mode 100644 index 0000000000..aa871b7d30 --- /dev/null +++ b/packages/ai/__tests__/websocket.test.ts @@ -0,0 +1,254 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { describe, expect, it, beforeEach, afterEach, jest } from '@jest/globals'; +import { WebSocketHandlerImpl } from '../lib/websocket'; +import { AIError } from '../lib/errors'; + +class MockWebSocket { + static CONNECTING = 0; + static OPEN = 1; + static CLOSING = 2; + static CLOSED = 3; + + readyState: number = MockWebSocket.CONNECTING; + sentMessages: Array = []; + url: string; + binaryType: string = 'blob'; + private listeners: Map void>> = new Map(); + + constructor(url: string) { + this.url = url; + } + + send(data: string | ArrayBuffer): void { + if (this.readyState !== MockWebSocket.OPEN) { + throw new Error('WebSocket is not in OPEN state'); + } + this.sentMessages.push(data); + } + + close(): void { + if (this.readyState === MockWebSocket.CLOSED || this.readyState === MockWebSocket.CLOSING) { + return; + } + this.readyState = MockWebSocket.CLOSING; + setTimeout(() => { + this.readyState = MockWebSocket.CLOSED; + this.dispatchEvent(new Event('close')); + }, 10); + } + + addEventListener(type: string, listener: (event: any) => void): void { + if (!this.listeners.has(type)) { + this.listeners.set(type, new Set()); + } + this.listeners.get(type)!.add(listener); + } + + removeEventListener(type: string, listener: (event: any) => void): void { + this.listeners.get(type)?.delete(listener); + } + + dispatchEvent(event: Event): void { + this.listeners.get(event.type)?.forEach(listener => listener(event)); + } + + triggerOpen(): void { + this.readyState = MockWebSocket.OPEN; + this.dispatchEvent(new Event('open')); + } + + triggerMessage(data: unknown): void { + const event = new Event('message'); + (event as any).data = data; + this.dispatchEvent(event); + } + + triggerError(): void { + this.dispatchEvent(new Event('error')); + } +} + +describe('WebSocketHandlerImpl', function () { + let handler: WebSocketHandlerImpl; + let mockWebSocket: MockWebSocket; + + beforeEach(function () { + // @ts-ignore - Mock WebSocket in global scope + global.WebSocket = jest.fn((url: string) => { + mockWebSocket = new MockWebSocket(url); + return mockWebSocket as unknown as WebSocket; + }) as unknown as typeof WebSocket; + + // Set WebSocket constants on the global mock + // @ts-ignore + global.WebSocket.CONNECTING = 0; + // @ts-ignore + global.WebSocket.OPEN = 1; + // @ts-ignore + global.WebSocket.CLOSING = 2; + // @ts-ignore + global.WebSocket.CLOSED = 3; + + jest.useFakeTimers(); + handler = new WebSocketHandlerImpl(); + }); + + afterEach(function () { + jest.restoreAllMocks(); + jest.useRealTimers(); + }); + + describe('connect()', function () { + it('should resolve on open event', async function () { + const connectPromise = handler.connect('ws://test-url'); + expect(global.WebSocket).toHaveBeenCalledWith('ws://test-url'); + + await jest.advanceTimersByTimeAsync(1); + mockWebSocket.triggerOpen(); + + await expect(connectPromise).resolves.toBeUndefined(); + }); + + it('should reject on error event', async function () { + const connectPromise = handler.connect('ws://test-url'); + await jest.advanceTimersByTimeAsync(1); + mockWebSocket.triggerError(); + + await expect(connectPromise).rejects.toThrow(AIError); + await expect(connectPromise).rejects.toThrow(/Error event raised on WebSocket/); + }); + }); + + describe('listen()', function () { + beforeEach(async function () { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + }); + + it('should yield multiple messages as they arrive', async function () { + const generator = handler.listen(); + + const received: unknown[] = []; + const listenPromise = (async () => { + for await (const msg of generator) { + received.push(msg); + } + })(); + + // Use advanceTimersByTimeAsync to allow the consumer to start listening + await jest.advanceTimersByTimeAsync(1); + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 1 })])); + + await jest.advanceTimersByTimeAsync(10); + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 2 })])); + + await jest.advanceTimersByTimeAsync(5); + mockWebSocket.close(); + await jest.runAllTimersAsync(); // Let timers finish + + await listenPromise; // Wait for the consumer to finish + + expect(received).toEqual([{ foo: 1 }, { foo: 2 }]); + }); + + it('should buffer messages that arrive before the consumer calls .next()', async function () { + const generator = handler.listen(); + + // Create a promise that will consume the generator in a separate async context + const received: unknown[] = []; + const consumptionPromise = (async () => { + for await (const message of generator) { + received.push(message); + } + })(); + + await jest.advanceTimersByTimeAsync(1); + + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 1 })])); + mockWebSocket.triggerMessage(new Blob([JSON.stringify({ foo: 2 })])); + + await jest.advanceTimersByTimeAsync(1); + mockWebSocket.close(); + await jest.runAllTimersAsync(); + + await consumptionPromise; + + expect(received).toEqual([{ foo: 1 }, { foo: 2 }]); + }); + }); + + describe('close()', function () { + it('should be idempotent and not throw if called multiple times', async function () { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + + const closePromise1 = handler.close(); + await jest.runAllTimersAsync(); + await closePromise1; + + await expect(handler.close()).resolves.toBeUndefined(); + }); + + it('should wait for the onclose event before resolving', async function () { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + + let closed = false; + const closePromise = handler.close().then(() => { + closed = true; + }); + + // The promise should not have resolved yet + await jest.advanceTimersByTimeAsync(5); + expect(closed).toBe(false); + + // Now, let the mock's setTimeout for closing run, which triggers onclose + await jest.advanceTimersByTimeAsync(10); + + await expect(closePromise).resolves.toBeUndefined(); + expect(closed).toBe(true); + }); + }); + + describe('Interaction between listen() and close()', function () { + it('should allow close() to take precedence and resolve correctly, while also terminating the listener', async function () { + const connectPromise = handler.connect('ws://test'); + mockWebSocket.triggerOpen(); + await connectPromise; + + const generator = handler.listen(); + const listenPromise = (async () => { + for await (const _ of generator) { + } + })(); + + const closePromise = handler.close(); + + await jest.runAllTimersAsync(); + + await expect(closePromise).resolves.toBeUndefined(); + await expect(listenPromise).resolves.toBeUndefined(); + + expect(mockWebSocket.readyState).toBe(MockWebSocket.CLOSED); + }); + }); +}); diff --git a/packages/ai/lib/backend.ts b/packages/ai/lib/backend.ts index 7209828122..2eaec59448 100644 --- a/packages/ai/lib/backend.ts +++ b/packages/ai/lib/backend.ts @@ -15,7 +15,7 @@ * limitations under the License. */ -import { DEFAULT_LOCATION } from './constants'; +import { DEFAULT_API_VERSION, DEFAULT_LOCATION } from './constants'; import { BackendType } from './public-types'; /** @@ -39,6 +39,16 @@ export abstract class Backend { protected constructor(type: BackendType) { this.backendType = type; } + + /** + * @internal + */ + abstract _getModelPath(project: string, model: string): string; + + /** + * @internal + */ + abstract _getTemplatePath(project: string, templateId: string): string; } /** @@ -56,6 +66,20 @@ export class GoogleAIBackend extends Backend { constructor() { super(BackendType.GOOGLE_AI); } + + /** + * @internal + */ + _getModelPath(project: string, model: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/${model}`; + } + + /** + * @internal + */ + _getTemplatePath(project: string, templateId: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/templates/${templateId}`; + } } /** @@ -89,4 +113,18 @@ export class VertexAIBackend extends Backend { this.location = location; } } + + /** + * @internal + */ + _getModelPath(project: string, model: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/${model}`; + } + + /** + * @internal + */ + _getTemplatePath(project: string, templateId: string): string { + return `/${DEFAULT_API_VERSION}/projects/${project}/locations/${this.location}/templates/${templateId}`; + } } diff --git a/packages/ai/lib/index.ts b/packages/ai/lib/index.ts index 7c8c476de4..6dfce8da33 100644 --- a/packages/ai/lib/index.ts +++ b/packages/ai/lib/index.ts @@ -18,18 +18,34 @@ import './polyfills'; import { getApp, ReactNativeFirebase } from '@react-native-firebase/app'; import { Backend, GoogleAIBackend, VertexAIBackend } from './backend'; -import { AIErrorCode, ModelParams, RequestOptions } from './types'; +import { AIErrorCode, LiveModelParams, ModelParams, RequestOptions } from './types'; import { AI, AIOptions, ImagenModelParams } from './public-types'; import { AIError } from './errors'; import { GenerativeModel } from './models/generative-model'; -import { AIModel, ImagenModel } from './models'; +import { + AIModel, + ImagenModel, + LiveGenerativeModel, + TemplateGenerativeModel, + TemplateImagenModel, +} from './models'; +import { WebSocketHandlerImpl } from './websocket'; export * from './public-types'; export { ChatSession } from './methods/chat-session'; +export { LiveSession } from './methods/live-session'; export * from './requests/schema-builder'; export { ImagenImageFormat } from './requests/imagen-image-format'; export { Backend, GoogleAIBackend, VertexAIBackend } from './backend'; -export { GenerativeModel, AIError, AIModel, ImagenModel }; +export { + GenerativeModel, + AIError, + AIModel, + ImagenModel, + LiveGenerativeModel, + TemplateGenerativeModel, + TemplateImagenModel, +}; /** * Returns the default {@link AI} instance that is associated with the provided @@ -125,3 +141,56 @@ export function getImagenModel( } return new ImagenModel(ai, modelParams, requestOptions); } + +/** + * Returns a {@link LiveGenerativeModel} class for real-time, bidirectional communication. + * + * The Live API is only supported in modern browser windows and Node >= 22. + * + * @param ai - An {@link AI} instance. + * @param modelParams - Parameters to use when setting up a {@link LiveSession}. + * @throws If the `apiKey` or `projectId` fields are missing in your + * Firebase config. + * + * @beta + */ +export function getLiveGenerativeModel(ai: AI, modelParams: LiveModelParams): LiveGenerativeModel { + if (!modelParams.model) { + throw new AIError( + AIErrorCode.NO_MODEL, + `Must provide a model name for getLiveGenerativeModel. Example: getLiveGenerativeModel(ai, { model: 'my-model-name' })`, + ); + } + const webSocketHandler = new WebSocketHandlerImpl(); + return new LiveGenerativeModel(ai, modelParams, webSocketHandler); +} + +/** + * Returns a {@link TemplateGenerativeModel} class for executing server-side Gemini templates. + * + * @param ai - An {@link AI} instance. + * @param requestOptions - Additional options to use when making requests. + * + * @beta + */ +export function getTemplateGenerativeModel( + ai: AI, + requestOptions?: RequestOptions, +): TemplateGenerativeModel { + return new TemplateGenerativeModel(ai, requestOptions); +} + +/** + * Returns a {@link TemplateImagenModel} class for executing server-side Imagen templates. + * + * @param ai - An {@link AI} instance. + * @param requestOptions - Additional options to use when making requests. + * + * @beta + */ +export function getTemplateImagenModel( + ai: AI, + requestOptions?: RequestOptions, +): TemplateImagenModel { + return new TemplateImagenModel(ai, requestOptions); +} diff --git a/packages/ai/lib/methods/count-tokens.ts b/packages/ai/lib/methods/count-tokens.ts index baa2a84589..ebb9e31ee6 100644 --- a/packages/ai/lib/methods/count-tokens.ts +++ b/packages/ai/lib/methods/count-tokens.ts @@ -49,12 +49,14 @@ export async function countTokens( } const response = await makeRequest( - model, - Task.COUNT_TOKENS, - apiSettings, - false, + { + model, + task: Task.COUNT_TOKENS, + apiSettings, + stream: false, + requestOptions, + }, body, - requestOptions, ); return response.json(); } diff --git a/packages/ai/lib/methods/generate-content.ts b/packages/ai/lib/methods/generate-content.ts index 901117c708..9d1f5c8a2f 100644 --- a/packages/ai/lib/methods/generate-content.ts +++ b/packages/ai/lib/methods/generate-content.ts @@ -22,7 +22,7 @@ import { GenerateContentStreamResult, RequestOptions, } from '../types'; -import { Task, makeRequest } from '../requests/request'; +import { Task, makeRequest, ServerPromptTemplateTask } from '../requests/request'; import { createEnhancedContentResponse } from '../requests/response-helpers'; import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; @@ -48,12 +48,14 @@ export async function generateContentStream( params = GoogleAIMapper.mapGenerateContentRequest(params); } const response = await makeRequest( - model, - Task.STREAM_GENERATE_CONTENT, - apiSettings, - /* stream */ true, + { + model, + task: Task.STREAM_GENERATE_CONTENT, + apiSettings, + stream: true, + requestOptions, + }, JSON.stringify(params), - requestOptions, ); return processStream(response, apiSettings); } @@ -78,12 +80,14 @@ export async function generateContent( params = GoogleAIMapper.mapGenerateContentRequest(params); } const response = await makeRequest( - model, - Task.GENERATE_CONTENT, - apiSettings, - /* stream */ false, + { + model, + task: Task.GENERATE_CONTENT, + apiSettings, + stream: false, + requestOptions, + }, JSON.stringify(params), - requestOptions, ); const generateContentResponse = await processGenerateContentResponse(response, apiSettings); const enhancedResponse = createEnhancedContentResponse(generateContentResponse); @@ -110,3 +114,67 @@ async function processGenerateContentResponse( return responseJson; } } + +/** + * Generates content from a template with the given ID and variables. + * + * @param apiSettings The {@link ApiSettings} to use for the request. + * @param templateId The ID of the server-side template to execute. + * @param templateParams The parameters to populate the template with. + * @param requestOptions The {@link RequestOptions} to use for the request. + * @returns The {@link GenerateContentResult} from the request. + * + * @beta + */ +export async function templateGenerateContent( + apiSettings: ApiSettings, + templateId: string, + templateParams: object, + requestOptions?: RequestOptions, +): Promise { + const response = await makeRequest( + { + task: ServerPromptTemplateTask.TEMPLATE_GENERATE_CONTENT, + templateId, + apiSettings, + stream: false, + requestOptions, + }, + JSON.stringify(templateParams), + ); + const generateContentResponse = await processGenerateContentResponse(response, apiSettings); + const enhancedResponse = createEnhancedContentResponse(generateContentResponse); + return { + response: enhancedResponse, + }; +} + +/** + * Generates a content stream from a template with the given ID and variables. + * + * @param apiSettings The {@link ApiSettings} to use for the request. + * @param templateId The ID of the server-side template to execute. + * @param templateParams The parameters to populate the template with. + * @param requestOptions The {@link RequestOptions} to use for the request. + * @returns The {@link GenerateContentStreamResult} from the request. + * + * @beta + */ +export async function templateGenerateContentStream( + apiSettings: ApiSettings, + templateId: string, + templateParams: object, + requestOptions?: RequestOptions, +): Promise { + const response = await makeRequest( + { + task: ServerPromptTemplateTask.TEMPLATE_STREAM_GENERATE_CONTENT, + templateId, + apiSettings, + stream: true, + requestOptions, + }, + JSON.stringify(templateParams), + ); + return processStream(response, apiSettings); +} diff --git a/packages/ai/lib/methods/live-session.ts b/packages/ai/lib/methods/live-session.ts new file mode 100644 index 0000000000..f3cd2d099d --- /dev/null +++ b/packages/ai/lib/methods/live-session.ts @@ -0,0 +1,344 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + AIErrorCode, + FunctionResponse, + GenerativeContentBlob, + LiveResponseType, + LiveServerContent, + LiveServerToolCall, + LiveServerToolCallCancellation, + Part, +} from '../types'; +import { formatNewContent } from '../requests/request-helpers'; +import { AIError } from '../errors'; +import { WebSocketHandler } from '../websocket'; +import { logger } from '../logger'; +import { + _LiveClientContent, + _LiveClientRealtimeInput, + _LiveClientToolResponse, +} from '../types/live-responses'; +import { ReadableStream } from 'web-streams-polyfill'; + +/** + * Represents an active, real-time, bidirectional conversation with the model. + * + * This class should only be instantiated by calling {@link LiveGenerativeModel.connect}. + * + * @beta + */ +export class LiveSession { + /** + * Indicates whether this Live session is closed. + * + * @beta + */ + isClosed = false; + /** + * Indicates whether this Live session is being controlled by an `AudioConversationController`. + * + * @beta + */ + inConversation = false; + + /** + * @internal + */ + constructor( + private webSocketHandler: WebSocketHandler, + private serverMessages: AsyncGenerator, + ) {} + + /** + * Sends content to the server. + * + * @param request - The message to send to the model. + * @param turnComplete - Indicates if the turn is complete. Defaults to false. + * @throws If this session has been closed. + * + * @beta + */ + async send(request: string | Array, turnComplete = true): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + const newContent = formatNewContent(request); + + const message: _LiveClientContent = { + clientContent: { + turns: [newContent], + turnComplete, + }, + }; + this.webSocketHandler.send(JSON.stringify(message)); + } + + /** + * Sends text to the server in realtime. + * + * @example + * ```javascript + * liveSession.sendTextRealtime("Hello, how are you?"); + * ``` + * + * @param text - The text data to send. + * @throws If this session has been closed. + * + * @beta + */ + async sendTextRealtime(text: string): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + const message: _LiveClientRealtimeInput = { + realtimeInput: { + text, + }, + }; + this.webSocketHandler.send(JSON.stringify(message)); + } + + /** + * Sends audio data to the server in realtime. + * + * @remarks The server requires that the audio data is base64-encoded 16-bit PCM at 16kHz + * little-endian. + * + * @example + * ```javascript + * // const pcmData = ... base64-encoded 16-bit PCM at 16kHz little-endian. + * const blob = { mimeType: "audio/pcm", data: pcmData }; + * liveSession.sendAudioRealtime(blob); + * ``` + * + * @param blob - The base64-encoded PCM data to send to the server in realtime. + * @throws If this session has been closed. + * + * @beta + */ + async sendAudioRealtime(blob: GenerativeContentBlob): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + const message: _LiveClientRealtimeInput = { + realtimeInput: { + audio: blob, + }, + }; + this.webSocketHandler.send(JSON.stringify(message)); + } + + /** + * Sends video data to the server in realtime. + * + * @remarks The server requires that the video is sent as individual video frames at 1 FPS. It + * is recommended to set `mimeType` to `image/jpeg`. + * + * @example + * ```javascript + * // const videoFrame = ... base64-encoded JPEG data + * const blob = { mimeType: "image/jpeg", data: videoFrame }; + * liveSession.sendVideoRealtime(blob); + * ``` + * @param blob - The base64-encoded video data to send to the server in realtime. + * @throws If this session has been closed. + * + * @beta + */ + async sendVideoRealtime(blob: GenerativeContentBlob): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + const message: _LiveClientRealtimeInput = { + realtimeInput: { + video: blob, + }, + }; + this.webSocketHandler.send(JSON.stringify(message)); + } + + /** + * Sends function responses to the server. + * + * @param functionResponses - The function responses to send. + * @throws If this session has been closed. + * + * @beta + */ + async sendFunctionResponses(functionResponses: FunctionResponse[]): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + const message: _LiveClientToolResponse = { + toolResponse: { + functionResponses, + }, + }; + this.webSocketHandler.send(JSON.stringify(message)); + } + + /** + * Yields messages received from the server. + * This can only be used by one consumer at a time. + * + * @returns An `AsyncGenerator` that yields server messages as they arrive. + * @throws If the session is already closed, or if we receive a response that we don't support. + * + * @beta + */ + async *receive(): AsyncGenerator< + LiveServerContent | LiveServerToolCall | LiveServerToolCallCancellation + > { + if (this.isClosed) { + throw new AIError( + AIErrorCode.SESSION_CLOSED, + 'Cannot read from a Live session that is closed. Try starting a new Live session.', + ); + } + for await (const message of this.serverMessages) { + if (message && typeof message === 'object') { + if (LiveResponseType.SERVER_CONTENT in message) { + yield { + type: 'serverContent', + ...(message as { serverContent: Omit }).serverContent, + } as LiveServerContent; + } else if (LiveResponseType.TOOL_CALL in message) { + yield { + type: 'toolCall', + ...(message as { toolCall: Omit }).toolCall, + } as LiveServerToolCall; + } else if (LiveResponseType.TOOL_CALL_CANCELLATION in message) { + yield { + type: 'toolCallCancellation', + ...( + message as { + toolCallCancellation: Omit; + } + ).toolCallCancellation, + } as LiveServerToolCallCancellation; + } else { + logger.warn( + `Received an unknown message type from the server: ${JSON.stringify(message)}`, + ); + } + } else { + logger.warn(`Received an invalid message from the server: ${JSON.stringify(message)}`); + } + } + } + + /** + * Closes this session. + * All methods on this session will throw an error once this resolves. + * + * @beta + */ + async close(): Promise { + if (!this.isClosed) { + this.isClosed = true; + await this.webSocketHandler.close(1000, 'Client closed session.'); + } + } + + /** + * Sends realtime input to the server. + * + * @deprecated Use `sendTextRealtime()`, `sendAudioRealtime()`, and `sendVideoRealtime()` instead. + * + * @param mediaChunks - The media chunks to send. + * @throws If this session has been closed. + * + * @beta + */ + async sendMediaChunks(mediaChunks: GenerativeContentBlob[]): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + // The backend does not support sending more than one mediaChunk in one message. + // Work around this limitation by sending mediaChunks in separate messages. + mediaChunks.forEach(mediaChunk => { + const message: _LiveClientRealtimeInput = { + realtimeInput: { mediaChunks: [mediaChunk] }, + }; + this.webSocketHandler.send(JSON.stringify(message)); + }); + } + + /** + * @deprecated Use `sendTextRealtime()`, `sendAudioRealtime()`, and `sendVideoRealtime()` instead. + * + * Sends a stream of {@link GenerativeContentBlob}. + * + * @param mediaChunkStream - The stream of {@link GenerativeContentBlob} to send. + * @throws If this session has been closed. + * + * @beta + */ + async sendMediaStream(mediaChunkStream: ReadableStream): Promise { + if (this.isClosed) { + throw new AIError( + AIErrorCode.REQUEST_ERROR, + 'This LiveSession has been closed and cannot be used.', + ); + } + + const reader = mediaChunkStream.getReader(); + while (true) { + try { + const { done, value } = await reader.read(); + + if (done) { + break; + } else if (!value) { + throw new Error('Missing chunk in reader, but reader is not done.'); + } + + await this.sendMediaChunks([value]); + } catch (e) { + // Re-throw any errors that occur during stream consumption or sending. + const message = e instanceof Error ? e.message : 'Error processing media stream.'; + throw new AIError(AIErrorCode.REQUEST_ERROR, message); + } + } + } +} diff --git a/packages/ai/lib/models/ai-model.ts b/packages/ai/lib/models/ai-model.ts index ca8b99d759..c996131ff2 100644 --- a/packages/ai/lib/models/ai-model.ts +++ b/packages/ai/lib/models/ai-model.ts @@ -16,10 +16,8 @@ */ import { ApiSettings } from '../types/internal'; -import { AIError } from '../errors'; -import { AIErrorCode } from '../types'; import { AI, BackendType } from '../public-types'; -import { AIService } from '../service'; +import { initApiSettings } from './utils'; /** * Base class for Firebase AI model APIs. @@ -59,41 +57,8 @@ export abstract class AIModel { * @internal */ protected constructor(ai: AI, modelName: string) { - if (!ai.app?.options?.apiKey) { - throw new AIError( - AIErrorCode.NO_API_KEY, - `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.`, - ); - } else if (!ai.app?.options?.projectId) { - throw new AIError( - AIErrorCode.NO_PROJECT_ID, - `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.`, - ); - } else if (!ai.app?.options?.appId) { - throw new AIError( - AIErrorCode.NO_APP_ID, - `The "appId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid app ID.`, - ); - } else { - this._apiSettings = { - apiKey: ai.app.options.apiKey, - project: ai.app.options.projectId, - appId: ai.app.options.appId, - automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, - location: ai.location, - backend: ai.backend, - }; - - if ((ai as AIService).appCheck) { - this._apiSettings.getAppCheckToken = () => (ai as AIService).appCheck!.getToken(); - } - - if ((ai as AIService).auth?.currentUser) { - this._apiSettings.getAuthToken = () => (ai as AIService).auth!.currentUser!.getIdToken(); - } - - this.model = AIModel.normalizeModelName(modelName, this._apiSettings.backend.backendType); - } + this._apiSettings = initApiSettings(ai); + this.model = AIModel.normalizeModelName(modelName, this._apiSettings.backend.backendType); } /** diff --git a/packages/ai/lib/models/imagen-model.ts b/packages/ai/lib/models/imagen-model.ts index 7cfc350539..1cd41aeb56 100644 --- a/packages/ai/lib/models/imagen-model.ts +++ b/packages/ai/lib/models/imagen-model.ts @@ -107,12 +107,14 @@ export class ImagenModel extends AIModel { ...this.safetySettings, }); const response = await makeRequest( - this.model, - Task.PREDICT, - this._apiSettings, - /* stream */ false, + { + model: this.model, + task: Task.PREDICT, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions, + }, JSON.stringify(body), - this.requestOptions, ); return handlePredictResponse(response); } @@ -146,12 +148,14 @@ export class ImagenModel extends AIModel { ...this.safetySettings, }); const response = await makeRequest( - this.model, - Task.PREDICT, - this._apiSettings, - /* stream */ false, + { + model: this.model, + task: Task.PREDICT, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions, + }, JSON.stringify(body), - this.requestOptions, ); return handlePredictResponse(response); } diff --git a/packages/ai/lib/models/index.ts b/packages/ai/lib/models/index.ts index cb694a5360..1524ef9721 100644 --- a/packages/ai/lib/models/index.ts +++ b/packages/ai/lib/models/index.ts @@ -18,3 +18,7 @@ export * from './ai-model'; export * from './generative-model'; export * from './imagen-model'; +export * from './live-generative-model'; +export * from './template-generative-model'; +export * from './template-imagen-model'; +export { initApiSettings } from './utils'; diff --git a/packages/ai/lib/models/live-generative-model.ts b/packages/ai/lib/models/live-generative-model.ts new file mode 100644 index 0000000000..f68ea7fb8f --- /dev/null +++ b/packages/ai/lib/models/live-generative-model.ts @@ -0,0 +1,130 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AIModel } from './ai-model'; +import { LiveSession } from '../methods/live-session'; +import { AIError } from '../errors'; +import { + AI, + AIErrorCode, + BackendType, + Content, + LiveGenerationConfig, + LiveModelParams, + Tool, + ToolConfig, +} from '../public-types'; +import { WebSocketHandler } from '../websocket'; +import { WebSocketUrl } from '../requests/request'; +import { formatSystemInstruction } from '../requests/request-helpers'; +import { _LiveClientSetup } from '../types/live-responses'; + +/** + * Class for Live generative model APIs. The Live API enables low-latency, two-way multimodal + * interactions with Gemini. + * + * This class should only be instantiated with {@link getLiveGenerativeModel}. + * + * @beta + */ +export class LiveGenerativeModel extends AIModel { + generationConfig: LiveGenerationConfig; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: Content; + + /** + * @internal + */ + constructor( + ai: AI, + modelParams: LiveModelParams, + /** + * @internal + */ + private _webSocketHandler: WebSocketHandler, + ) { + super(ai, modelParams.model); + this.generationConfig = modelParams.generationConfig || {}; + this.tools = modelParams.tools; + this.toolConfig = modelParams.toolConfig; + this.systemInstruction = formatSystemInstruction(modelParams.systemInstruction); + } + + /** + * Starts a {@link LiveSession}. + * + * @returns A {@link LiveSession}. + * @throws If the connection failed to be established with the server. + * + * @beta + */ + async connect(): Promise { + const url = new WebSocketUrl(this._apiSettings); + await this._webSocketHandler.connect(url.toString()); + + let fullModelPath: string; + if (this._apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + fullModelPath = `projects/${this._apiSettings.project}/${this.model}`; + } else { + fullModelPath = `projects/${this._apiSettings.project}/locations/${this._apiSettings.location}/${this.model}`; + } + + // inputAudioTranscription and outputAudioTranscription are on the generation config in the public API, + // but the backend expects them to be in the `setup` message. + const { inputAudioTranscription, outputAudioTranscription, ...generationConfig } = + this.generationConfig; + + const setupMessage: _LiveClientSetup = { + setup: { + model: fullModelPath, + generationConfig, + tools: this.tools, + toolConfig: this.toolConfig, + systemInstruction: this.systemInstruction, + inputAudioTranscription, + outputAudioTranscription, + }, + }; + + try { + // Begin listening for server messages, and begin the handshake by sending the 'setupMessage' + const serverMessages = this._webSocketHandler.listen(); + this._webSocketHandler.send(JSON.stringify(setupMessage)); + + // Verify we received the handshake response 'setupComplete' + const firstMessage = (await serverMessages.next()).value; + if ( + !firstMessage || + !(typeof firstMessage === 'object') || + !('setupComplete' in firstMessage) + ) { + await this._webSocketHandler.close(1011, 'Handshake failure'); + throw new AIError( + AIErrorCode.RESPONSE_ERROR, + 'Server connection handshake failed. The server did not respond with a setupComplete message.', + ); + } + + return new LiveSession(this._webSocketHandler, serverMessages); + } catch (e) { + // Ensure connection is closed on any setup error + await this._webSocketHandler.close(); + throw e; + } + } +} diff --git a/packages/ai/lib/models/template-generative-model.ts b/packages/ai/lib/models/template-generative-model.ts new file mode 100644 index 0000000000..d97dcbcb30 --- /dev/null +++ b/packages/ai/lib/models/template-generative-model.ts @@ -0,0 +1,98 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + templateGenerateContent, + templateGenerateContentStream, +} from '../methods/generate-content'; +import { GenerateContentResult, RequestOptions } from '../types'; +import { AI, GenerateContentStreamResult } from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { initApiSettings } from './utils'; + +/** + * {@link GenerativeModel} APIs that execute on a server-side template. + * + * This class should only be instantiated with {@link getTemplateGenerativeModel}. + * + * @beta + */ +export class TemplateGenerativeModel { + /** + * @internal + */ + _apiSettings: ApiSettings; + + /** + * Additional options to use when making requests. + */ + requestOptions?: RequestOptions; + + /** + * @hideconstructor + */ + constructor(ai: AI, requestOptions?: RequestOptions) { + this.requestOptions = requestOptions || {}; + this._apiSettings = initApiSettings(ai); + } + + /** + * Makes a single non-streaming call to the model and returns an object + * containing a single {@link GenerateContentResponse}. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateContent( + templateId: string, + templateVariables: object, // anything! + ): Promise { + return templateGenerateContent( + this._apiSettings, + templateId, + { inputs: templateVariables }, + this.requestOptions, + ); + } + + /** + * Makes a single streaming call to the model and returns an object + * containing an iterable stream that iterates over all chunks in the + * streaming response as well as a promise that returns the final aggregated + * response. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateContentStream( + templateId: string, + templateVariables: object, + ): Promise { + return templateGenerateContentStream( + this._apiSettings, + templateId, + { inputs: templateVariables }, + this.requestOptions, + ); + } +} diff --git a/packages/ai/lib/models/template-imagen-model.ts b/packages/ai/lib/models/template-imagen-model.ts new file mode 100644 index 0000000000..a788c91180 --- /dev/null +++ b/packages/ai/lib/models/template-imagen-model.ts @@ -0,0 +1,77 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { RequestOptions } from '../types'; +import { AI, ImagenGenerationResponse, ImagenInlineImage } from '../public-types'; +import { ApiSettings } from '../types/internal'; +import { makeRequest, ServerPromptTemplateTask } from '../requests/request'; +import { handlePredictResponse } from '../requests/response-helpers'; +import { initApiSettings } from './utils'; + +/** + * Class for Imagen model APIs that execute on a server-side template. + * + * This class should only be instantiated with {@link getTemplateImagenModel}. + * + * @beta + */ +export class TemplateImagenModel { + /** + * @internal + */ + _apiSettings: ApiSettings; + + /** + * Additional options to use when making requests. + */ + requestOptions?: RequestOptions; + + /** + * @hideconstructor + */ + constructor(ai: AI, requestOptions?: RequestOptions) { + this.requestOptions = requestOptions || {}; + this._apiSettings = initApiSettings(ai); + } + + /** + * Makes a single call to the model and returns an object containing a single + * {@link ImagenGenerationResponse}. + * + * @param templateId - The ID of the server-side template to execute. + * @param templateVariables - A key-value map of variables to populate the + * template with. + * + * @beta + */ + async generateImages( + templateId: string, + templateVariables: object, + ): Promise> { + const response = await makeRequest( + { + task: ServerPromptTemplateTask.TEMPLATE_PREDICT, + templateId, + apiSettings: this._apiSettings, + stream: false, + requestOptions: this.requestOptions, + }, + JSON.stringify({ inputs: templateVariables }), + ); + return handlePredictResponse(response); + } +} diff --git a/packages/ai/lib/models/utils.ts b/packages/ai/lib/models/utils.ts new file mode 100644 index 0000000000..0f2dba4923 --- /dev/null +++ b/packages/ai/lib/models/utils.ts @@ -0,0 +1,68 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AIError } from '../errors'; +import { AI, AIErrorCode } from '../public-types'; +import { AIService } from '../service'; +import { ApiSettings } from '../types/internal'; + +/** + * Initializes an {@link ApiSettings} object from an {@link AI} instance. + * + * @internal + */ +export function initApiSettings(ai: AI): ApiSettings { + if (!ai.app?.options?.apiKey) { + throw new AIError( + AIErrorCode.NO_API_KEY, + `The "apiKey" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid API key.`, + ); + } else if (!ai.app?.options?.projectId) { + throw new AIError( + AIErrorCode.NO_PROJECT_ID, + `The "projectId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid project ID.`, + ); + } else if (!ai.app?.options?.appId) { + throw new AIError( + AIErrorCode.NO_APP_ID, + `The "appId" field is empty in the local Firebase config. Firebase AI requires this field to contain a valid app ID.`, + ); + } + + const apiSettings: ApiSettings = { + apiKey: ai.app.options.apiKey, + project: ai.app.options.projectId, + appId: ai.app.options.appId, + automaticDataCollectionEnabled: ai.app.automaticDataCollectionEnabled, + location: ai.location, + backend: ai.backend, + }; + + if ((ai as AIService).appCheck) { + if (ai.options?.useLimitedUseAppCheckTokens) { + apiSettings.getAppCheckToken = () => (ai as AIService).appCheck!.getLimitedUseToken(); + } else { + apiSettings.getAppCheckToken = () => (ai as AIService).appCheck!.getToken(); + } + } + + if ((ai as AIService).auth?.currentUser) { + apiSettings.getAuthToken = () => (ai as AIService).auth!.currentUser!.getIdToken(); + } + + return apiSettings; +} diff --git a/packages/ai/lib/requests/request.ts b/packages/ai/lib/requests/request.ts index 0b1c668ece..1a6b978b7d 100644 --- a/packages/ai/lib/requests/request.ts +++ b/packages/ai/lib/requests/request.ts @@ -27,6 +27,7 @@ import { } from '../constants'; import { logger } from '../logger'; import { GoogleAIBackend, VertexAIBackend } from '../backend'; +import { BackendType } from '../public-types'; export enum Task { GENERATE_CONTENT = 'generateContent', @@ -35,6 +36,12 @@ export enum Task { PREDICT = 'predict', } +export const enum ServerPromptTemplateTask { + TEMPLATE_GENERATE_CONTENT = 'templateGenerateContent', + TEMPLATE_STREAM_GENERATE_CONTENT = 'templateStreamGenerateContent', + TEMPLATE_PREDICT = 'templatePredict', +} + export class RequestUrl { constructor( public model: string, @@ -105,6 +112,43 @@ export class RequestUrl { } } +export class TemplateRequestUrl { + constructor( + public templateId: string, + public task: ServerPromptTemplateTask, + public apiSettings: ApiSettings, + public stream: boolean, + public requestOptions?: RequestOptions, + ) {} + + toString(): string { + // Manually construct URL to avoid React Native URL API issues + let baseUrl = this.baseUrl; + // Remove trailing slash if present + if (baseUrl.endsWith('/')) { + baseUrl = baseUrl.slice(0, -1); + } + + const pathname = `${this.apiSettings.backend._getTemplatePath(this.apiSettings.project, this.templateId)}:${this.task}`; + const queryString = this.queryParams; + + return `${baseUrl}${pathname}${queryString ? `?${queryString}` : ''}`; + } + + private get baseUrl(): string { + return this.requestOptions?.baseUrl || `https://${DEFAULT_DOMAIN}`; + } + + private get queryParams(): string { + let params = ''; + if (this.stream) { + params += 'alt=sse'; + } + + return params; + } +} + /** * Log language and "fire/version" to x-goog-api-client */ @@ -146,6 +190,37 @@ export async function getHeaders(url: RequestUrl): Promise { return headers; } +export async function getTemplateHeaders(url: TemplateRequestUrl): Promise { + const headers = new Headers(); + headers.append('Content-Type', 'application/json'); + headers.append('x-goog-api-client', getClientHeaders()); + headers.append('x-goog-api-key', url.apiSettings.apiKey); + if (url.apiSettings.automaticDataCollectionEnabled) { + headers.append('X-Firebase-Appid', url.apiSettings.appId); + } + if (url.apiSettings.getAppCheckToken) { + let appCheckToken; + + try { + appCheckToken = await url.apiSettings.getAppCheckToken(); + } catch (e) { + logger.warn(`Unable to obtain a valid App Check token: ${e}`); + } + if (appCheckToken) { + headers.append('X-Firebase-AppCheck', appCheckToken.token); + } + } + + if (url.apiSettings.getAuthToken) { + const authToken = await url.apiSettings.getAuthToken(); + if (authToken) { + headers.append('Authorization', `Firebase ${authToken}`); + } + } + + return headers; +} + export async function constructRequest( model: string, task: Task, @@ -165,28 +240,113 @@ export async function constructRequest( }; } -export async function makeRequest( - model: string, - task: Task, +export async function constructTemplateRequest( + templateId: string, + task: ServerPromptTemplateTask, apiSettings: ApiSettings, stream: boolean, body: string, requestOptions?: RequestOptions, +): Promise<{ url: string; fetchOptions: RequestInit }> { + const url = new TemplateRequestUrl(templateId, task, apiSettings, stream, requestOptions); + return { + url: url.toString(), + fetchOptions: { + method: 'POST', + headers: await getTemplateHeaders(url), + body, + }, + }; +} + +// Overload for model requests +export async function makeRequest( + params: { + model: string; + task: Task; + apiSettings: ApiSettings; + stream: boolean; + requestOptions?: RequestOptions; + }, + body: string, +): Promise; +// Overload for template requests +export async function makeRequest( + params: { + templateId: string; + task: ServerPromptTemplateTask; + apiSettings: ApiSettings; + stream: boolean; + requestOptions?: RequestOptions; + }, + body: string, +): Promise; +// Implementation +export async function makeRequest( + params: + | { + model: string; + task: Task; + apiSettings: ApiSettings; + stream: boolean; + requestOptions?: RequestOptions; + } + | { + templateId: string; + task: ServerPromptTemplateTask; + apiSettings: ApiSettings; + stream: boolean; + requestOptions?: RequestOptions; + }, + body: string, ): Promise { - const url = new RequestUrl(model, task, apiSettings, stream, requestOptions); + // Determine if this is a template request or model request + const isTemplateRequest = 'templateId' in params; + const url = isTemplateRequest + ? new TemplateRequestUrl( + params.templateId, + params.task, + params.apiSettings, + params.stream, + params.requestOptions, + ) + : new RequestUrl( + (params as { model: string }).model, + params.task as Task, + params.apiSettings, + params.stream, + params.requestOptions, + ); + let response; let fetchTimeoutId: string | number | NodeJS.Timeout | undefined; try { - const request = await constructRequest(model, task, apiSettings, stream, body, requestOptions); + const request = isTemplateRequest + ? await constructTemplateRequest( + params.templateId, + params.task, + params.apiSettings, + params.stream, + body, + params.requestOptions, + ) + : await constructRequest( + (params as { model: string }).model, + params.task as Task, + params.apiSettings, + params.stream, + body, + params.requestOptions, + ); const timeoutMillis = - requestOptions?.timeout != null && requestOptions.timeout >= 0 - ? requestOptions.timeout + params.requestOptions?.timeout != null && params.requestOptions.timeout >= 0 + ? params.requestOptions.timeout : DEFAULT_FETCH_TIMEOUT_MS; const abortController = new AbortController(); fetchTimeoutId = setTimeout(() => abortController.abort(), timeoutMillis); request.fetchOptions.signal = abortController.signal; - const fetchOptions = stream + const fetchOptions = params.stream ? { ...request.fetchOptions, reactNative: { @@ -263,3 +423,23 @@ export async function makeRequest( } return response; } + +export class WebSocketUrl { + constructor(public apiSettings: ApiSettings) {} + toString(): string { + // Manually construct URL to avoid React Native URL API issues + const baseUrl = `wss://${DEFAULT_DOMAIN}`; + const pathname = this.pathname; + const queryString = `key=${encodeURIComponent(this.apiSettings.apiKey)}`; + + return `${baseUrl}${pathname}?${queryString}`; + } + + private get pathname(): string { + if (this.apiSettings.backend.backendType === BackendType.GOOGLE_AI) { + return '/ws/google.firebase.vertexai.v1beta.GenerativeService/BidiGenerateContent'; + } else { + return `/ws/google.firebase.vertexai.v1beta.LlmBidiService/BidiGenerateContent/locations/${this.apiSettings.location}`; + } + } +} diff --git a/packages/ai/lib/types/error.ts b/packages/ai/lib/types/error.ts index 4fcc1ac483..811b0f519d 100644 --- a/packages/ai/lib/types/error.ts +++ b/packages/ai/lib/types/error.ts @@ -101,4 +101,7 @@ export const enum AIErrorCode { /** An error occurred due an attempt to use an unsupported feature. */ UNSUPPORTED = 'unsupported', + + /** An error occurred due to a session being closed. */ + SESSION_CLOSED = 'session-closed', } diff --git a/packages/ai/lib/types/live-responses.ts b/packages/ai/lib/types/live-responses.ts new file mode 100644 index 0000000000..a3f04a1d63 --- /dev/null +++ b/packages/ai/lib/types/live-responses.ts @@ -0,0 +1,89 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { Content, FunctionResponse, GenerativeContentBlob, Part } from './content'; +import { AudioTranscriptionConfig, LiveGenerationConfig, Tool, ToolConfig } from './requests'; +import { Transcription } from './responses'; + +/** + * User input that is sent to the model. + * + * @internal + */ +export interface _LiveClientContent { + clientContent: { + turns: [Content]; + turnComplete: boolean; + inputTranscription?: Transcription; + outputTranscription?: Transcription; + }; +} + +/** + * User input that is sent to the model in real time. + * + * @internal + */ +export interface _LiveClientRealtimeInput { + realtimeInput: { + text?: string; + audio?: GenerativeContentBlob; + video?: GenerativeContentBlob; + + /** + * @deprecated Use `text`, `audio`, and `video` instead. + */ + mediaChunks?: GenerativeContentBlob[]; + }; +} + +/** + * Function responses that are sent to the model in real time. + */ +export interface _LiveClientToolResponse { + toolResponse: { + functionResponses: FunctionResponse[]; + }; +} + +/** + * The first message in a Live session, used to configure generation options. + * + * @internal + */ +export interface _LiveClientSetup { + setup: { + model: string; + generationConfig?: _LiveGenerationConfig; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: string | Part | Content; + inputAudioTranscription?: AudioTranscriptionConfig; + outputAudioTranscription?: AudioTranscriptionConfig; + }; +} + +/** + * The Live Generation Config. + * + * The public API ({@link LiveGenerationConfig}) has `inputAudioTranscription` and `outputAudioTranscription`, + * but the server expects these fields to be in the top-level `setup` message. This was a conscious API decision. + */ +export type _LiveGenerationConfig = Omit< + LiveGenerationConfig, + 'inputAudioTranscription' | 'outputAudioTranscription' +>; diff --git a/packages/ai/lib/types/requests.ts b/packages/ai/lib/types/requests.ts index 23501ac936..a3e0432ad3 100644 --- a/packages/ai/lib/types/requests.ts +++ b/packages/ai/lib/types/requests.ts @@ -46,6 +46,18 @@ export interface ModelParams extends BaseParams { systemInstruction?: string | Part | Content; } +/** + * Params passed to {@link getLiveGenerativeModel}. + * @beta + */ +export interface LiveModelParams { + model: string; + generationConfig?: LiveGenerationConfig; + tools?: Tool[]; + toolConfig?: ToolConfig; + systemInstruction?: string | Part | Content; +} + /** * Request sent through {@link GenerativeModel.generateContent} * @public @@ -119,6 +131,75 @@ export interface GenerationConfig { thinkingConfig?: ThinkingConfig; } +/** + * Configuration parameters used by {@link LiveGenerativeModel} to control live content generation. + * + * @beta + */ +export interface LiveGenerationConfig { + /** + * Configuration for speech synthesis. + */ + speechConfig?: SpeechConfig; + /** + * Specifies the maximum number of tokens that can be generated in the response. The number of + * tokens per word varies depending on the language outputted. Is unbounded by default. + */ + maxOutputTokens?: number; + /** + * Controls the degree of randomness in token selection. A `temperature` value of 0 means that the highest + * probability tokens are always selected. In this case, responses for a given prompt are mostly + * deterministic, but a small amount of variation is still possible. + */ + temperature?: number; + /** + * Changes how the model selects tokens for output. Tokens are + * selected from the most to least probable until the sum of their probabilities equals the `topP` + * value. For example, if tokens A, B, and C have probabilities of 0.3, 0.2, and 0.1 respectively + * and the `topP` value is 0.5, then the model will select either A or B as the next token by using + * the `temperature` and exclude C as a candidate. Defaults to 0.95 if unset. + */ + topP?: number; + /** + * Changes how the model selects token for output. A `topK` value of 1 means the select token is + * the most probable among all tokens in the model's vocabulary, while a `topK` value 3 means that + * the next token is selected from among the 3 most probably using probabilities sampled. Tokens + * are then further filtered with the highest selected `temperature` sampling. Defaults to 40 + * if unspecified. + */ + topK?: number; + /** + * Positive penalties. + */ + presencePenalty?: number; + /** + * Frequency penalties. + */ + frequencyPenalty?: number; + /** + * The modalities of the response. + */ + responseModalities?: ResponseModality[]; + /** + * Enables transcription of audio input. + * + * When enabled, the model will respond with transcriptions of your audio input in the `inputTranscriptions` property + * in {@link LiveServerContent} messages. Note that the transcriptions are broken up across + * messages, so you may only receive small amounts of text per message. For example, if you ask the model + * "How are you today?", the model may transcribe that input across three messages, broken up as "How a", "re yo", "u today?". + */ + inputAudioTranscription?: AudioTranscriptionConfig; + /** + * Enables transcription of audio input. + * + * When enabled, the model will respond with transcriptions of its audio output in the `outputTranscription` property + * in {@link LiveServerContent} messages. Note that the transcriptions are broken up across + * messages, so you may only receive small amounts of text per message. For example, if the model says + * "How are you today?", the model may transcribe that output across three messages, broken up as "How a", "re yo", "u today?". + */ + outputAudioTranscription?: AudioTranscriptionConfig; +} + /** * Params for {@link GenerativeModel.startChat}. * @public @@ -343,3 +424,11 @@ export interface SpeechConfig { */ voiceConfig?: VoiceConfig; } + +/** + * Configuration for audio transcription in Live sessions. + * + * @beta + */ +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +export interface AudioTranscriptionConfig {} diff --git a/packages/ai/lib/types/responses.ts b/packages/ai/lib/types/responses.ts index 243d220f6e..e49f338cd4 100644 --- a/packages/ai/lib/types/responses.ts +++ b/packages/ai/lib/types/responses.ts @@ -412,3 +412,95 @@ export interface CountTokensResponse { */ promptTokensDetails?: ModalityTokenCount[]; } + +/** + * An incremental content update from the model. + * + * @beta + */ +export interface LiveServerContent { + type: 'serverContent'; + /** + * The content that the model has generated as part of the current conversation with the user. + */ + modelTurn?: Content; + /** + * Indicates whether the turn is complete. This is `undefined` if the turn is not complete. + */ + turnComplete?: boolean; + /** + * Indicates whether the model was interrupted by the client. An interruption occurs when + * the client sends a message before the model finishes it's turn. This is `undefined` if the + * model was not interrupted. + */ + interrupted?: boolean; + /** + * Transcription of the audio that was input to the model. + */ + inputTranscription?: Transcription; + /** + * Transcription of the audio output from the model. + */ + outputTranscription?: Transcription; +} + +/** + * Transcription of audio. This can be returned from a {@link LiveGenerativeModel} if transcription + * is enabled with the `inputAudioTranscription` or `outputAudioTranscription` properties on + * the {@link LiveGenerationConfig}. + * + * @beta + */ + +export interface Transcription { + /** + * The text transcription of the audio. + */ + text?: string; +} + +/** + * A request from the model for the client to execute one or more functions. + * + * @beta + */ +export interface LiveServerToolCall { + type: 'toolCall'; + /** + * An array of function calls to run. + */ + functionCalls: FunctionCall[]; +} + +/** + * Notification to cancel a previous function call triggered by {@link LiveServerToolCall}. + * + * @beta + */ +export interface LiveServerToolCallCancellation { + type: 'toolCallCancellation'; + /** + * IDs of function calls that were cancelled. These refer to the `id` property of a {@link FunctionCall}. + */ + functionIds: string[]; +} + +/** + * The types of responses that can be returned by {@link LiveSession.receive}. + * + * @beta + */ +export const LiveResponseType = { + SERVER_CONTENT: 'serverContent', + TOOL_CALL: 'toolCall', + TOOL_CALL_CANCELLATION: 'toolCallCancellation', +}; + +/** + * The types of responses that can be returned by {@link LiveSession.receive}. + * This is a property on all messages that can be used for type narrowing. This property is not + * returned by the server, it is assigned to a server message object once it's parsed. + * + * @beta + */ +export type LiveResponseType = (typeof LiveResponseType)[keyof typeof LiveResponseType]; diff --git a/packages/ai/lib/websocket.ts b/packages/ai/lib/websocket.ts new file mode 100644 index 0000000000..fddbe06de4 --- /dev/null +++ b/packages/ai/lib/websocket.ts @@ -0,0 +1,245 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { AIError } from './errors'; +import { logger } from './logger'; +import { AIErrorCode } from './types'; + +/** + * A standardized interface for interacting with a WebSocket connection. + * This abstraction allows the SDK to use the appropriate WebSocket implementation + * for the current JS environment (Browser vs. Node) without + * changing the core logic of the `LiveSession`. + * @internal + */ + +export interface WebSocketHandler { + /** + * Establishes a connection to the given URL. + * + * @param url The WebSocket URL (e.g., wss://...). + * @returns A promise that resolves on successful connection or rejects on failure. + */ + connect(url: string): Promise; + + /** + * Sends data over the WebSocket. + * + * @param data The string or binary data to send. + */ + send(data: string | ArrayBuffer): void; + + /** + * Returns an async generator that yields parsed JSON objects from the server. + * The yielded type is `unknown` because the handler cannot guarantee the shape of the data. + * The consumer is responsible for type validation. + * The generator terminates when the connection is closed. + * + * @returns A generator that allows consumers to pull messages using a `for await...of` loop. + */ + listen(): AsyncGenerator; + + /** + * Closes the WebSocket connection. + * + * @param code - A numeric status code explaining why the connection is closing. + * @param reason - A human-readable string explaining why the connection is closing. + */ + close(code?: number, reason?: string): Promise; +} + +/** + * A wrapper for the native `WebSocket` available in both Browsers and Node >= 22. + * + * @internal + */ +export class WebSocketHandlerImpl implements WebSocketHandler { + private ws?: WebSocket; + + constructor() { + if (typeof WebSocket === 'undefined') { + throw new AIError( + AIErrorCode.UNSUPPORTED, + 'The WebSocket API is not available in this environment. ' + + 'The "Live" feature is not supported here. It is supported in ' + + 'modern browser windows, Web Workers with WebSocket support, and Node >= 22.', + ); + } + } + + connect(url: string): Promise { + return new Promise((resolve, reject) => { + this.ws = new WebSocket(url); + // Note: binaryType is not supported in React Native's WebSocket implementation. + // We handle ArrayBuffer, Blob, and string data types in the message listener instead. + + const openHandler = (): void => { + resolve(); + this.ws?.removeEventListener('open', openHandler); + }; + + const errorHandler = (): void => { + reject(new AIError(AIErrorCode.FETCH_ERROR, `Error event raised on WebSocket`)); + this.ws?.removeEventListener('error', errorHandler); + }; + + this.ws.addEventListener('open', openHandler); + this.ws.addEventListener('error', errorHandler); + + this.ws.addEventListener('close', (event: any) => { + if (event?.reason) { + logger.warn(`WebSocket connection closed by server. Reason: '${event.reason}'`); + } + }); + }); + } + + send(data: string | ArrayBuffer): void { + if (!this.ws || this.ws.readyState !== WebSocket.OPEN) { + throw new AIError(AIErrorCode.REQUEST_ERROR, 'WebSocket is not open.'); + } + this.ws.send(data); + } + + async *listen(): AsyncGenerator { + if (!this.ws) { + throw new AIError(AIErrorCode.REQUEST_ERROR, 'WebSocket is not connected.'); + } + + const messageQueue: unknown[] = []; + const errorQueue: Error[] = []; + let resolvePromise: (() => void) | null = null; + let isClosed = false; + + const messageListener = async (event: any): Promise => { + let data: string; + + // Handle different data types across environments + if (event.data instanceof Blob) { + // Browser environment + data = await event.data.text(); + } else if (event.data instanceof ArrayBuffer) { + // React Native environment - binary data comes as ArrayBuffer + const decoder = new TextDecoder('utf-8'); + data = decoder.decode(event.data); + } else if (typeof event.data === 'string') { + // String data in all environments + data = event.data; + } else { + errorQueue.push( + new AIError( + AIErrorCode.PARSE_FAILED, + `Failed to parse WebSocket response. Expected data to be a Blob, ArrayBuffer, or string, but was ${typeof event.data}.`, + ), + ); + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + return; + } + + try { + const obj = JSON.parse(data) as unknown; + messageQueue.push(obj); + } catch (e) { + const err = e as Error; + errorQueue.push( + new AIError( + AIErrorCode.PARSE_FAILED, + `Error parsing WebSocket message to JSON: ${err.message}`, + ), + ); + } + + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + }; + + const errorListener = (): void => { + errorQueue.push(new AIError(AIErrorCode.FETCH_ERROR, 'WebSocket connection error.')); + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + }; + + const closeListener = (event: any): void => { + if (event?.reason) { + logger.warn(`WebSocket connection closed by the server with reason: ${event.reason}`); + } + isClosed = true; + if (resolvePromise) { + resolvePromise(); + resolvePromise = null; + } + // Clean up listeners to prevent memory leaks + this.ws?.removeEventListener('message', messageListener as any); + this.ws?.removeEventListener('close', closeListener as any); + this.ws?.removeEventListener('error', errorListener as any); + }; + + this.ws.addEventListener('message', messageListener as any); + this.ws.addEventListener('close', closeListener as any); + this.ws.addEventListener('error', errorListener as any); + + while (!isClosed) { + if (errorQueue.length > 0) { + const error = errorQueue.shift()!; + throw error; + } + if (messageQueue.length > 0) { + yield messageQueue.shift()!; + } else { + await new Promise(resolve => { + resolvePromise = resolve; + }); + } + } + + // If the loop terminated because isClosed is true, check for any final errors + if (errorQueue.length > 0) { + const error = errorQueue.shift()!; + throw error; + } + } + + close(code?: number, reason?: string): Promise { + return new Promise(resolve => { + if (!this.ws) { + return resolve(); + } + + const closeHandler = (): void => { + resolve(); + this.ws?.removeEventListener('close', closeHandler as any); + }; + + this.ws.addEventListener('close', closeHandler as any); + // Calling 'close' during these states results in an error. + if (this.ws.readyState === WebSocket.CLOSED || this.ws.readyState === WebSocket.CONNECTING) { + return resolve(); + } + + if (this.ws.readyState !== WebSocket.CLOSING) { + this.ws.close(code, reason); + } + }); + } +} diff --git a/tests/ios/Podfile.lock b/tests/ios/Podfile.lock index 00b3428bba..c58998aff0 100644 --- a/tests/ios/Podfile.lock +++ b/tests/ios/Podfile.lock @@ -2235,85 +2235,85 @@ SPEC CHECKSUMS: nanopb: fad817b59e0457d11a5dfbde799381cd727c1275 PromisesObjC: f5707f49cb48b9636751c5b2e7d227e43fba9f47 PromisesSwift: 9d77319bbe72ebf6d872900551f7eeba9bce2851 - RCT-Folly: e78785aa9ba2ed998ea4151e314036f6c49e6d82 + RCT-Folly: 36fe2295e44b10d831836cc0d1daec5f8abcf809 RCTDeprecation: cf39863b43871c2031050605fb884019b6193910 RCTRequired: 8fdd66f4a97f352b66f38cfef13fc11b12d2c884 RCTTypeSafety: c9c9e64389bc545fc137030615b387ef1654dcee React: 14a80ea4f13387cfdaa4250b46fbfe19754c220c React-callinvoker: fed1dad5d6cf992c7b4b5fdbf1bf67fe2e8fb6c5 - React-Core: 3c803e7f3be6fa68e3dabcac283a5a5f87340a60 - React-CoreModules: 94d556b4055defb79278c3afba95e521998b9b3a - React-cxxreact: 21c826a567199cc2d5797bd4cfe3c0d94bb3a7de + React-Core: f703e7a56fcedc3e959b8b7899995e57fd58539a + React-CoreModules: 6e87c904cc257058c271708eef1719b5b3039131 + React-cxxreact: 4153beeff710944832cd90ccb141e299ee16b7d3 React-debug: a665bbe67eb786b7a7a815ce1b7212b3f9fa962c - React-defaultsnativemodule: 8f0bea8d0d3434aa98957302f2a8741600781edd - React-domnativemodule: 889d261cc9691b91063a1c1d89d408e19903923d - React-Fabric: eb0bde19a858807eee7d4fee07f24f036bdf5e6a - React-FabricComponents: 7756ddcd87ff121d0dedff69c80fd6177ccd31d9 - React-FabricImage: c3096fe8c20a4aec77c54c271f878948efa9d477 + React-defaultsnativemodule: 7e4a2c2b13ec2943f2f3b8adec32130443909de6 + React-domnativemodule: dffaa482180243bd1e2b7fba329fd4adc12a2445 + React-Fabric: bd742f0ddb6073ff14f7e51075bc0380b84e7f7a + React-FabricComponents: 347a74f275f989b493ab5324a2161e6b3de5556e + React-FabricImage: f9da31053da5aae09b7f67afdc46329c6e3a2710 React-featureflags: 55800b546a28b63a8a0f419e1a45871d43523d32 - React-featureflagsnativemodule: b200bca78e00f9e5c7cd5a1c9f2957223fcfa33a - React-graphics: fb41a6a55ecd473159452c75b3ea7b57794903a3 - React-hermes: 3be7f73615e70b310b0457146691d39ea3752e6b - React-idlecallbacksnativemodule: dc10ee2e5ba5ae7ad56aa093aedda582345bce16 - React-ImageManager: 08592583c7737aec2b2c2a12e7c4f0ad763ae5c4 - React-jserrorhandler: eede06f57f67c8d3978ff1a9c78107aea5cbdf45 - React-jsi: 70ca5cce94e7f81e4ff51a71b4720b5fb943eea5 - React-jsiexecutor: 265d9fbb2a88c74892646e0012cec62ebb89edcf - React-jsinspector: 7a7e033c64cc40e6a1d6daf7fad4827bc1e7dd12 - React-jsinspectortracing: 77f4d18502af6e7599e77390b35f04f02814f7ce - React-jsitracing: 0608ea7ee711370d42fdd2b4a474dbced68dd275 - React-logger: 8d00d3d794041a77bd890158353361e2709b04c1 - React-Mapbuffer: 45ca4d30efe99834a8cd8d98f803c94766db534f - React-microtasksnativemodule: b5901a0b15f92ce0666ee5131eb8ab646f1d5a27 - React-NativeModulesApple: 7a9ec626a1852d444d0e699b016dc55a064b7569 - React-perflogger: d06f0fd0727356651a5535f6d717130187aeb667 - React-performancetimeline: c397114f2c025aa73412a9f21e021b08127fe820 + React-featureflagsnativemodule: a0ea334fdd3342a2e4dc05085c3e7653e16839d3 + React-graphics: 7360f5f3611fd5982aa0de772a73987ab401fb02 + React-hermes: a942bebef5e9fcc31f51c6fb814e96c260a2a20d + React-idlecallbacksnativemodule: 74d091304aad1ceb0c5b86e5dec14372fcdc0f34 + React-ImageManager: 6b2a95469d9a126f14410bbe10fb7378d37ed0e0 + React-jserrorhandler: 2643140639cbf047bf45f5c1a3ea95b04d748582 + React-jsi: b2de88284fc2cc69466a34d8b794160216d3bd2c + React-jsiexecutor: e947af1c9e42610affd9f4178cd4b649e8ac889b + React-jsinspector: 6d768dfb189027f7ff2161be31ccd69215426ded + React-jsinspectortracing: a6a70eb5c9d767d99391d0373330a7239fb6f9d0 + React-jsitracing: 69280997c7a80ac0af62b95f01a241d68581fb52 + React-logger: e6c3c1b55c18cc1b945b647ff9ada53e0e710b50 + React-Mapbuffer: 57bea44149376ecf1241dd9f02a6222adab43e57 + React-microtasksnativemodule: 2739fc8a8cb99486820c933ce560791c93be5337 + React-NativeModulesApple: 4849912ee050e6ae346f5789bd631570580d8b84 + React-perflogger: 069d41f741187be92ed46c6ac67add03b99f3166 + React-performancetimeline: eda794c5007eb6834e13bc6026a879b5f6d49c74 React-RCTActionSheet: a078d5008632fed31b0024c420ee02e612b317d5 - React-RCTAnimation: b197cc109a896c6ce23981e02e317cfc055f6fda - React-RCTAppDelegate: f7f1d7362256b7c142d9ab49f980df488101f869 - React-RCTBlob: c12d15d40db01ac3fe57c24d3ef5233ff3333da6 - React-RCTFabric: 8cdcde7157a22aac04dfeb579dfc3a1141446846 - React-RCTFBReactNativeSpec: c3a78cb9f2a98146443f1b732a4f21b2ce736abd - React-RCTImage: 7a3d9d67161c714fa4d9b93820da39a266d0f1ff - React-RCTLinking: f860b917500cd3974235a48d5b199a74a4ed6c26 - React-RCTNetwork: 6a984ab1b5a81d17a2df6cc02c24b249fb055deb - React-RCTSettings: e9a39068d8b60d78a5271dcb68f6ea7f59569cb2 - React-RCTText: 44457242238664a5ad69f06ec7a5f273a6967711 - React-RCTVibration: f448ad875c60b2ddc5fc5b06d3f5e2dfc3f18c09 + React-RCTAnimation: 82e31d191af4175e0c2df5bdac2c8569a5f3ab54 + React-RCTAppDelegate: a5c1ff79f5987462b4f62b27387459ba84012439 + React-RCTBlob: c462b8b7de6ce44ddc56dd96eebe1da0a6e54c77 + React-RCTFabric: 56b946204edb5d563885b3b045bdacbb387b27e7 + React-RCTFBReactNativeSpec: 8392ef66ad156cfa848546859bbff3b5e8a09458 + React-RCTImage: 10fad63f1bb8adbd519c4c2ef6bec3c0d95fdd32 + React-RCTLinking: 3843288a44dc33ec083c843f3ff31dd7d96ece41 + React-RCTNetwork: f237299bda8bbd56c4d01d2825110e40b75c438a + React-RCTSettings: c24ce1ee96c9b001ff5059ddd53412a20b7d5e71 + React-RCTText: d97cfb9c89b06de9530577dd43f178c47ea07853 + React-RCTVibration: 2fcefee071a4f0d416e4368416bb073ea6893451 React-rendererconsistency: c9f31b6d55877c5d49d25d69270b89f9cb208e26 - React-rendererdebug: 939c31f97f3bbf434832b7f73d8a438cf96ee1c4 + React-rendererdebug: 185ba0f801f29565873f7a37e410a812eddaa1ee React-rncore: 90e637179a4ce46643d445a9ef16f53af02a8d25 - React-RuntimeApple: 3df87718b4a8e438b4a02d1d7d47677dfcab35a5 - React-RuntimeCore: e188aa1b0fe0450f3a4e6098f2bb8d4b27c147cf + React-RuntimeApple: 713b7c24b3abed07fa39766b35deaabd679ba48e + React-RuntimeCore: 236d704919077fd3393a26fd0ecbaecc081ec94f React-runtimeexecutor: 2de0d537fc6d5b4a7074587b4459ea331c7e5715 - React-RuntimeHermes: 5aa429111da4c0a62d94b37ad5b5144deb8f49d0 - React-runtimescheduler: c3738ed7f0ba2e51b1612ec8d6cbe104c2b9b59a + React-RuntimeHermes: 3e87ad8c5160c976addacd679774a5e4fdb3c4b4 + React-runtimescheduler: bafaf0af0f68bd761d63ff4de3bf13e391820f79 React-timing: 7ad7dc61dfc93ceb4ec2b3e6d1a6ad9ad3652fe0 - React-utils: d6a3bec920c7fa710e8fb5b7c28d81fe54be8c75 - ReactAppDependencyProvider: 5df090fa3cbfc923c6bd0595b64d5ef6d89f7134 - ReactCodegen: 0c213020a601c6adda74f8826629bff9c6c408d3 - ReactCommon: 7f90ec8358d94ec2a078e11990f6e78542957b11 + React-utils: cf358d29b6802cca3d1bec20a939f2f5d50d40ba + ReactAppDependencyProvider: ad88c80e06f29900f2e6f9ccf1d4cb0bfc3e1bbc + ReactCodegen: 69c7aec61821e1860aaaf959189218ecca40e811 + ReactCommon: ef3e394efce4b78e9214587689c459cf82927368 RecaptchaInterop: 11e0b637842dfb48308d242afc3f448062325aba - RNCAsyncStorage: 6a8127b6987dc9fbce778669b252b14c8355c7ce - RNDeviceInfo: bcce8752b5043a623fe3c26789679b473f705d3c - RNFBAnalytics: 8a9771f292a0cbc9e33e84a857ba53f00a0c0db6 - RNFBApp: 76a17bd5a7ab06e4757e5b84d42d3e30c344f060 - RNFBAppCheck: d2e68ac65991d5ffdbb65e6de87034f467cbe98a - RNFBAppDistribution: 23bf60adbd70f64e8a91d169b55dc38b15ccac6c - RNFBAuth: 712e85c69a1f5b6d9c624eca9f92831976016a5e - RNFBCrashlytics: 7ba5fb1d27ae0019c9c0a46a0111758d9a7d0b0f - RNFBDatabase: 13814bc51ebf8af459d248dde4c0f0216a3bf486 - RNFBFirestore: a7c57bc634607929418b29b744be49897ae0ce1d - RNFBFunctions: 914a518d87227e0c5b3beb91789e0a2325dfb5dd - RNFBInAppMessaging: ca3a4590b39dd196751e3e70d938c7aa37a19409 - RNFBInstallations: c85e437fa09155df6ed4947c37f44abeead104c2 - RNFBMessaging: 0c3d2fa215ba2778ee8bc8610c921cb57aa85abd - RNFBML: 2cbc15122658e03c20f088a5c0b99dc116a417a4 - RNFBPerf: 507c272f6fa844e19863c05aca6268f7a9d6345e - RNFBRemoteConfig: b51f88dc4e62b00312f81dd2d54212d343025d15 - RNFBStorage: 39506275bc197187addcedb93dac722cf339c9a5 + RNCAsyncStorage: 481acf401089f312189e100815088ea5dafc583c + RNDeviceInfo: 8b6fa8379062949dd79a009cf3d6b02a9c03ca59 + RNFBAnalytics: c72330428ef1ee616250c26abcc154bc92b7c63f + RNFBApp: bf6d45c99c3ca4635b945284be0c1455b9d5c8ed + RNFBAppCheck: df199d8e604ff9f5582a8efae43b03413be01839 + RNFBAppDistribution: 64b649053cccb2ed3c7e8e6aac086bb20b67e57c + RNFBAuth: 575f269b83c5f7f14bbcbdf88004927776417f79 + RNFBCrashlytics: 24952f7befb8ecc0be5f75365f48d933d3b16c73 + RNFBDatabase: f91415fa8b123493f4477d31c9027e6b1885a19f + RNFBFirestore: 6263f354cf1299967b2638c3ebdb717d17ad4a1b + RNFBFunctions: 4732360e63f1c7902e0663c571ada0a5809ba3c0 + RNFBInAppMessaging: a9e5e8cca8fb68fe6fbd1ae0574ce3beac1dde3f + RNFBInstallations: 76fb4fc683932cbf4de2d1664fa12cefeb3bb9d1 + RNFBMessaging: 107048d3a46c89bcf71035bdf3a243e58921eb7d + RNFBML: 37c38ac8bff7a8d954b1ac9cc1a3685b9bfa227c + RNFBPerf: c9d6b7acc115b9878e7526316f49796381b67bd6 + RNFBRemoteConfig: e1500065123cd45901491bd4a2a8eee1a56995b4 + RNFBStorage: 8bfd9a3512b116873198bafaad89e0b64d094960 SocketRocket: d4aabe649be1e368d1318fdf28a022d714d65748 - Yoga: 3bb1ee33b5133befbd33872601fa46efdd48e841 + Yoga: 66a9fd80007d5d5fce19d1676ce17b4d5e16e9b1 PODFILE CHECKSUM: 3abe8cfe7b06f24b788e90bea320d8ae6ea6d11a diff --git a/tests/local-tests/ai/ai.tsx b/tests/local-tests/ai/ai.tsx index c65cf2b992..11c20ece1b 100644 --- a/tests/local-tests/ai/ai.tsx +++ b/tests/local-tests/ai/ai.tsx @@ -1,7 +1,6 @@ /* eslint-disable no-console */ import React, { JSX, useState } from 'react'; import { Button, View, Text, Pressable } from 'react-native'; - import { getApp } from '@react-native-firebase/app'; import { getAI, @@ -13,6 +12,16 @@ import { GenerateContentResult, GenerateContentStreamResult, ChatSession, + VertexAIBackend, + getLiveGenerativeModel, + LiveGenerativeModel, + LiveSession, + ResponseModality, + getTemplateGenerativeModel, + getTemplateImagenModel, + TemplateGenerativeModel, + TemplateImagenModel, + GoogleAIBackend, } from '@react-native-firebase/ai'; import { PDF_BASE_64, @@ -98,7 +107,7 @@ export function AITestComponent() { try { const app = getApp(); const ai: AI = getAI(app); - const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-1.5-flash' }); + const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-2.5-flash' }); const result: GenerateContentResult = await model.generateContent('What is 2 + 2?'); @@ -113,8 +122,8 @@ export function AITestComponent() { onPress={async (): Promise => { try { const app = getApp(); - const ai: AI = getAI(app); - const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-1.5-flash' }); + const ai: AI = getAI(app, { backend: new GoogleAIBackend() }); + const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-2.5-flash' }); const result: GenerateContentStreamResult = await model.generateContentStream( 'Write me a short, funny rap', @@ -145,7 +154,7 @@ export function AITestComponent() { try { const app = getApp(); const ai: AI = getAI(app); - const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-1.5-flash' }); + const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-2.5-flash' }); const mediaDetails: MediaDetails | null = getMediaDetails(selectedOption); if (!mediaDetails) return; @@ -190,7 +199,7 @@ export function AITestComponent() { }, }); const model: GenerativeModel = getGenerativeModel(ai, { - model: 'gemini-1.5-flash', + model: 'gemini-2.5-flash', generationConfig: { responseMimeType: 'application/json', responseSchema: jsonSchema, @@ -213,7 +222,7 @@ export function AITestComponent() { try { const app = getApp(); const ai: AI = getAI(app); - const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-1.5-flash' }); + const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-2.5-flash' }); const chat: ChatSession = model.startChat({ history: [ @@ -252,7 +261,7 @@ export function AITestComponent() { try { const app = getApp(); const ai: AI = getAI(app); - const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-1.5-flash' }); + const model: GenerativeModel = getGenerativeModel(ai, { model: 'gemini-2.5-flash' }); const result = await model.countTokens('What is 2 + 2?'); @@ -334,7 +343,7 @@ export function AITestComponent() { const app = getApp(); const ai: AI = getAI(app); const model: GenerativeModel = getGenerativeModel(ai, { - model: 'gemini-1.5-flash', + model: 'gemini-2.5-flash', // @ts-ignore tools: fetchWeatherTool, }); @@ -395,6 +404,127 @@ export function AITestComponent() { } }} /> +