diff --git a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts index 9a886a6713af2..028ef5db32cf2 100644 --- a/src/vs/workbench/api/browser/mainThreadLanguageModels.ts +++ b/src/vs/workbench/api/browser/mainThreadLanguageModels.ts @@ -60,7 +60,10 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { } finally { this._pendingProgress.delete(requestId); } - } + }, + provideTokenCount: (str, token) => { + return this._proxy.$provideTokenLength(handle, str, token); + }, })); if (metadata.auth) { dipsosables.add(this._registerAuthenticationProvider(metadata.extension, metadata.auth)); @@ -119,6 +122,11 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape { return task; } + + $countTokens(provider: string, value: string, token: CancellationToken): Promise { + return this._chatProviderService.computeTokenLength(provider, value, token); + } + private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable { // This needs to be done in both MainThread & ExtHost ChatProvider const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value; diff --git a/src/vs/workbench/api/common/extHost.protocol.ts b/src/vs/workbench/api/common/extHost.protocol.ts index 6f7a2931bcf38..13a592f0e97f2 100644 --- a/src/vs/workbench/api/common/extHost.protocol.ts +++ b/src/vs/workbench/api/common/extHost.protocol.ts @@ -1199,6 +1199,8 @@ export interface MainThreadLanguageModelsShape extends IDisposable { $prepareChatAccess(extension: ExtensionIdentifier, providerId: string, justification?: string): Promise; $fetchResponse(extension: ExtensionIdentifier, provider: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise; + + $countTokens(provider: string, value: string, token: CancellationToken): Promise; } export interface 
ExtHostLanguageModelsShape { @@ -1206,6 +1208,7 @@ export interface ExtHostLanguageModelsShape { $updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void; $provideLanguageModelResponse(handle: number, requestId: number, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, token: CancellationToken): Promise; $handleResponseFragment(requestId: number, chunk: IChatResponseFragment): Promise; + $provideTokenLength(handle: number, value: string, token: CancellationToken): Promise; } export interface IExtensionChatAgentMetadata extends Dto { diff --git a/src/vs/workbench/api/common/extHostLanguageModels.ts b/src/vs/workbench/api/common/extHostLanguageModels.ts index 8221534013c05..623ad2b6e3b85 100644 --- a/src/vs/workbench/api/common/extHostLanguageModels.ts +++ b/src/vs/workbench/api/common/extHostLanguageModels.ts @@ -21,6 +21,7 @@ import { createDecorator } from 'vs/platform/instantiation/common/instantiation' import { IExtHostRpcService } from 'vs/workbench/api/common/extHostRpcService'; import { IExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication'; import { ILogService } from 'vs/platform/log/common/log'; +import { Iterable } from 'vs/base/common/iterator'; export interface IExtHostLanguageModels extends ExtHostLanguageModels { } @@ -180,6 +181,18 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { return data.provider.provideLanguageModelResponse2(messages.map(typeConvert.LanguageModelMessage.to), options, ExtensionIdentifier.toKey(from), progress, token); } + + //#region --- token counting + + $provideTokenLength(handle: number, value: string, token: CancellationToken): Promise { + const data = this._languageModels.get(handle); + if (!data) { + return Promise.resolve(0); + } + return Promise.resolve(data.provider.provideTokenCount(value, token)); + } + + //#region --- making request $updateLanguageModels(data: { added?: 
ILanguageModelChatMetadata[] | undefined; removed?: string[] | undefined }): void { @@ -378,6 +391,23 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape { return undefined; } return list.has(data.extension); + }, + async computeTokenLength(languageModelId: string, value: string, token?: vscode.CancellationToken): Promise { + + token ??= CancellationToken.None; + + const data = that._allLanguageModelData.get(languageModelId); + if (!data) { + throw LanguageModelError.NotFound(`Language model '${languageModelId}' is unknown.`); + } + + const local = Iterable.find(that._languageModels.values(), candidate => candidate.languageModelId === languageModelId); + if (local) { + // stay inside the EH + return local.provider.provideTokenCount(value, token); + } + + return that._proxy.$countTokens(data.identifier, value, token); } }; } diff --git a/src/vs/workbench/contrib/chat/common/languageModels.ts b/src/vs/workbench/contrib/chat/common/languageModels.ts index 7304d00262887..5c95913f016ae 100644 --- a/src/vs/workbench/contrib/chat/common/languageModels.ts +++ b/src/vs/workbench/contrib/chat/common/languageModels.ts @@ -40,6 +40,7 @@ export interface ILanguageModelChatMetadata { export interface ILanguageModelChat { metadata: ILanguageModelChatMetadata; provideChatResponse(messages: IChatMessage[], from: ExtensionIdentifier, options: { [name: string]: any }, progress: IProgress, token: CancellationToken): Promise; + provideTokenCount(str: string, token: CancellationToken): Promise; } export const ILanguageModelsService = createDecorator('ILanguageModelsService'); @@ -57,6 +58,8 @@ export interface ILanguageModelsService { registerLanguageModelChat(identifier: string, provider: ILanguageModelChat): IDisposable; makeLanguageModelChatRequest(identifier: string, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, progress: IProgress, token: CancellationToken): Promise; + + computeTokenLength(identifier: string, message: 
string, token: CancellationToken): Promise; } export class LanguageModelsService implements ILanguageModelsService { @@ -100,4 +103,12 @@ export class LanguageModelsService implements ILanguageModelsService { } return provider.provideChatResponse(messages, from, options, progress, token); } + + computeTokenLength(identifier: string, message: string, token: CancellationToken): Promise { + const provider = this._providers.get(identifier); + if (!provider) { + throw new Error(`Chat response provider with identifier ${identifier} is not registered.`); + } + return provider.provideTokenCount(message, token); + } } diff --git a/src/vscode-dts/vscode.proposed.chatProvider.d.ts b/src/vscode-dts/vscode.proposed.chatProvider.d.ts index b6aa1ffdadfe2..3a4d53afbbf5c 100644 --- a/src/vscode-dts/vscode.proposed.chatProvider.d.ts +++ b/src/vscode-dts/vscode.proposed.chatProvider.d.ts @@ -17,6 +17,8 @@ declare module 'vscode' { */ export interface ChatResponseProvider { provideLanguageModelResponse2(messages: LanguageModelChatMessage[], options: { [name: string]: any }, extensionId: string, progress: Progress, token: CancellationToken): Thenable; + + provideTokenCount(text: string, token: CancellationToken): Thenable; } export interface ChatResponseProviderMetadata { diff --git a/src/vscode-dts/vscode.proposed.languageModels.d.ts b/src/vscode-dts/vscode.proposed.languageModels.d.ts index 98a61ecde4252..4544241b130d1 100644 --- a/src/vscode-dts/vscode.proposed.languageModels.d.ts +++ b/src/vscode-dts/vscode.proposed.languageModels.d.ts @@ -232,6 +232,17 @@ declare module 'vscode' { // TODO@API SYNC or ASYNC? // TODO@API future // retrieveQuota(languageModelId: string): { remaining: number; resets: Date }; + + // TODO@API SHOULD THIS BE in vscode.lm? + // TODO@API should this check for access/permissions? 
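+ /** + * + * Compute the token length (number of tokens) of the given text using the given language model. + * @param languageModelId The identifier of the language model to count tokens with. + * @param text The text whose tokens are counted. + * @param token An optional cancellation token. + */ + computeTokenLength(languageModelId: string, text: string, token?: CancellationToken): Thenable; } export interface ExtensionContext {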