From f02446ff5062d0c1a65cd90397c6b0f7c53446a7 Mon Sep 17 00:00:00 2001
From: Tanzim Hossain
Date: Thu, 13 Jun 2024 19:39:37 +0000
Subject: [PATCH] VertexAI: add test cases for countTokens()

---
 .../vertexai/src/methods/count-tokens.test.ts | 93 +++++++++++++++++++
 .../src/models/generative-model.test.ts       | 20 ++++
 .../count-tokens-failure-model-not-found.json | 13 +++
 ...tokens-success-no-billable-characters.json |  3 +
 .../count-tokens-success-total-tokens.json    |  4 +
 5 files changed, 133 insertions(+)
 create mode 100644 packages/vertexai/src/methods/count-tokens.test.ts
 create mode 100644 packages/vertexai/test-utils/mock-responses/count-tokens-failure-model-not-found.json
 create mode 100644 packages/vertexai/test-utils/mock-responses/count-tokens-success-no-billable-characters.json
 create mode 100644 packages/vertexai/test-utils/mock-responses/count-tokens-success-total-tokens.json

diff --git a/packages/vertexai/src/methods/count-tokens.test.ts b/packages/vertexai/src/methods/count-tokens.test.ts
new file mode 100644
index 00000000000..33b4de795cc
--- /dev/null
+++ b/packages/vertexai/src/methods/count-tokens.test.ts
@@ -0,0 +1,93 @@
+import { expect, use } from 'chai';
+import { match, restore, stub } from 'sinon';
+import sinonChai from 'sinon-chai';
+import chaiAsPromised from 'chai-as-promised';
+import { getMockResponse } from '../../test-utils/mock-response';
+import * as request from '../requests/request';
+import { countTokens } from './count-tokens';
+import { CountTokensRequest } from '../types';
+import { ApiSettings } from '../types/internal';
+import { Task } from '../requests/request';
+
+use(sinonChai);
+use(chaiAsPromised);
+
+const fakeApiSettings: ApiSettings = {
+  apiKey: 'key',
+  project: 'my-project',
+  location: 'us-central1'
+};
+
+const fakeRequestParams: CountTokensRequest = {
+  contents: [{ parts: [{ text: 'hello' }], role: 'user' }]
+};
+
+describe('countTokens()', () => {
+  afterEach(() => {
+    restore();
+  });
+  it('total tokens', async () => {
+    const mockResponse = getMockResponse(
+      'count-tokens-success-total-tokens.json'
+    );
+    const makeRequestStub = stub(request, 'makeRequest').resolves(
+      mockResponse as Response
+    );
+    const result = await countTokens(
+      fakeApiSettings,
+      'model',
+      fakeRequestParams
+    );
+    expect(result.totalTokens).to.equal(6);
+    expect(result.totalBillableCharacters).to.equal(16);
+    expect(makeRequestStub).to.be.calledWith(
+      'model',
+      Task.COUNT_TOKENS,
+      fakeApiSettings,
+      false,
+      match((value: string) => {
+        return value.includes('contents');
+      }),
+      undefined
+    );
+  });
+  it('total tokens no billable characters', async () => {
+    const mockResponse = getMockResponse(
+      'count-tokens-success-no-billable-characters.json'
+    );
+    const makeRequestStub = stub(request, 'makeRequest').resolves(
+      mockResponse as Response
+    );
+    const result = await countTokens(
+      fakeApiSettings,
+      'model',
+      fakeRequestParams
+    );
+    expect(result.totalTokens).to.equal(258);
+    expect(result).to.not.have.property('totalBillableCharacters');
+    expect(makeRequestStub).to.be.calledWith(
+      'model',
+      Task.COUNT_TOKENS,
+      fakeApiSettings,
+      false,
+      match((value: string) => {
+        return value.includes('contents');
+      }),
+      undefined
+    );
+  });
+  it('model not found', async () => {
+    const mockResponse = getMockResponse(
+      'count-tokens-failure-model-not-found.json'
+    );
+    const mockFetch = stub(globalThis, 'fetch').resolves({
+      ok: false,
+      status: 404,
+      json: mockResponse.json
+    } as Response);
+    await expect(
+      countTokens(fakeApiSettings, 'model', fakeRequestParams)
+    ).to.be.rejectedWith(/404.*not found/);
+    expect(mockFetch).to.be.called;
+  });
+});
diff --git a/packages/vertexai/src/models/generative-model.test.ts b/packages/vertexai/src/models/generative-model.test.ts
index 7b0287492da..e036525bace 100644
--- a/packages/vertexai/src/models/generative-model.test.ts
+++ b/packages/vertexai/src/models/generative-model.test.ts
@@ -262,4 +262,24 @@ describe('GenerativeModel', () => {
     );
     restore();
   });
+  it('calls countTokens', async () => {
+    const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
+    const mockResponse = getMockResponse(
+      'count-tokens-success-total-tokens.json'
+    );
+    const makeRequestStub = stub(request, 'makeRequest').resolves(
+      mockResponse as Response
+    );
+    await genModel.countTokens('hello');
+    expect(makeRequestStub).to.be.calledWith(
+      'publishers/google/models/my-model',
+      request.Task.COUNT_TOKENS,
+      match.any,
+      false,
+      match((value: string) => {
+        return value.includes('hello');
+      })
+    );
+    restore();
+  });
 });
diff --git a/packages/vertexai/test-utils/mock-responses/count-tokens-failure-model-not-found.json b/packages/vertexai/test-utils/mock-responses/count-tokens-failure-model-not-found.json
new file mode 100644
index 00000000000..50fcb725667
--- /dev/null
+++ b/packages/vertexai/test-utils/mock-responses/count-tokens-failure-model-not-found.json
@@ -0,0 +1,13 @@
+{
+  "error": {
+    "code": 404,
+    "message": "models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods.",
+    "status": "NOT_FOUND",
+    "details": [
+      {
+        "@type": "type.googleapis.com/google.rpc.DebugInfo",
+        "detail": "[ORIGINAL ERROR] generic::not_found: models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods. [google.rpc.error_details_ext] { message: \"models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods.\" }"
+      }
+    ]
+  }
+}
diff --git a/packages/vertexai/test-utils/mock-responses/count-tokens-success-no-billable-characters.json b/packages/vertexai/test-utils/mock-responses/count-tokens-success-no-billable-characters.json
new file mode 100644
index 00000000000..03425d51db0
--- /dev/null
+++ b/packages/vertexai/test-utils/mock-responses/count-tokens-success-no-billable-characters.json
@@ -0,0 +1,3 @@
+{
+  "totalTokens": 258
+}
diff --git a/packages/vertexai/test-utils/mock-responses/count-tokens-success-total-tokens.json b/packages/vertexai/test-utils/mock-responses/count-tokens-success-total-tokens.json
new file mode 100644
index 00000000000..d2ad6e4ff30
--- /dev/null
+++ b/packages/vertexai/test-utils/mock-responses/count-tokens-success-total-tokens.json
@@ -0,0 +1,4 @@
+{
+  "totalTokens": 6,
+  "totalBillableCharacters": 16
+}