Skip to content

Commit

Permalink
VertexAI: add test cases for countTokens()
Browse files Browse the repository at this point in the history
  • Loading branch information
tanzimfh committed Jun 13, 2024
1 parent a90255a commit f02446f
Show file tree
Hide file tree
Showing 5 changed files with 133 additions and 0 deletions.
93 changes: 93 additions & 0 deletions packages/vertexai/src/methods/count-tokens.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import { expect, use } from 'chai';
import { match, restore, stub } from 'sinon';
import sinonChai from 'sinon-chai';
import chaiAsPromised from 'chai-as-promised';
import { getMockResponse } from '../../test-utils/mock-response';
import * as request from '../requests/request';
import { countTokens } from './count-tokens';
import { CountTokensRequest } from '../types';
import { ApiSettings } from '../types/internal';
import { Task } from '../requests/request';

use(sinonChai);
use(chaiAsPromised);

const fakeApiSettings: ApiSettings = {
apiKey: 'key',
project: 'my-project',
location: 'us-central1'
};

const fakeRequestParams: CountTokensRequest = {
contents: [{ parts: [{ text: 'hello' }], role: 'user' }]
};

describe('countTokens()', () => {
afterEach(() => {
restore();
});
it('total tokens', async () => {
const mockResponse = getMockResponse(
'count-tokens-success-total-tokens.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
);
expect(result.totalTokens).to.equal(6);
expect(result.totalBillableCharacters).to.equal(16);
expect(makeRequestStub).to.be.calledWith(
'model',
Task.COUNT_TOKENS,
fakeApiSettings,
false,
match((value: string) => {
return value.includes('contents');
}),
undefined
);
});
it('total tokens no billable characters', async () => {
const mockResponse = getMockResponse(
'count-tokens-success-no-billable-characters.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
const result = await countTokens(
fakeApiSettings,
'model',
fakeRequestParams
);
expect(result.totalTokens).to.equal(258);
expect(result).to.not.have.property('totalBillableCharacters');
expect(makeRequestStub).to.be.calledWith(
'model',
Task.COUNT_TOKENS,
fakeApiSettings,
false,
match((value: string) => {
return value.includes('contents');
}),
undefined
);
});
it('model not found', async () => {
const mockResponse = getMockResponse(
'count-tokens-failure-model-not-found.json'
);
const mockFetch = stub(globalThis, 'fetch').resolves({
ok: false,
status: 404,
json: mockResponse.json
} as Response);
await expect(
countTokens(fakeApiSettings, 'model', fakeRequestParams)
).to.be.rejectedWith(/404.*not found/);
expect(mockFetch).to.be.called;
});
});
20 changes: 20 additions & 0 deletions packages/vertexai/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -262,4 +262,24 @@ describe('GenerativeModel', () => {
);
restore();
});
it('calls countTokens', async () => {
const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' });
const mockResponse = getMockResponse(
'count-tokens-success-total-tokens.json'
);
const makeRequestStub = stub(request, 'makeRequest').resolves(
mockResponse as Response
);
await genModel.countTokens('hello');
expect(makeRequestStub).to.be.calledWith(
'publishers/google/models/my-model',
request.Task.COUNT_TOKENS,
match.any,
false,
match((value: string) => {
return value.includes('hello');
})
);
restore();
});
});
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
{
"error": {
"code": 404,
"message": "models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods.",
"status": "NOT_FOUND",
"details": [
{
"@type": "type.googleapis.com/google.rpc.DebugInfo",
"detail": "[ORIGINAL ERROR] generic::not_found: models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods. [google.rpc.error_details_ext] { message: \"models/test-model-name is not found for API version v1beta, or is not supported for countTokens. Call ListModels to see the list of available models and their supported methods.\" }"
}
]
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
{
"totalTokens": 258
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"totalTokens": 6,
"totalBillableCharacters": 16
}

0 comments on commit f02446f

Please sign in to comment.