Skip to content

Commit

Permalink
first version of token counting API (#210177)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrieken committed Apr 11, 2024
1 parent 004a2c4 commit a808279
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 1 deletion.
10 changes: 9 additions & 1 deletion src/vs/workbench/api/browser/mainThreadLanguageModels.ts
Expand Up @@ -60,7 +60,10 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
} finally {
this._pendingProgress.delete(requestId);
}
}
},
// Token counting is forwarded to `$provideTokenLength` on the proxy, keyed by this provider's handle.
provideTokenCount: (text, cancellation) => this._proxy.$provideTokenLength(handle, text, cancellation),
}));
if (metadata.auth) {
dipsosables.add(this._registerAuthenticationProvider(metadata.extension, metadata.auth));
Expand Down Expand Up @@ -119,6 +122,11 @@ export class MainThreadLanguageModels implements MainThreadLanguageModelsShape {
return task;
}


/**
 * Computes the token length of `value` by delegating to the chat provider service.
 *
 * @param provider Identifier passed on to the service's `computeTokenLength`.
 * @param value Text whose tokens should be counted.
 * @param token Cancellation token passed through to the service.
 * @returns The promise produced by the service.
 */
$countTokens(provider: string, value: string, token: CancellationToken): Promise<number> {
	const pendingCount = this._chatProviderService.computeTokenLength(provider, value, token);
	return pendingCount;
}

private _registerAuthenticationProvider(extension: ExtensionIdentifier, auth: { providerLabel: string; accountLabel?: string | undefined }): IDisposable {
// This needs to be done in both MainThread & ExtHost ChatProvider
const authProviderId = INTERNAL_AUTH_PROVIDER_PREFIX + extension.value;
Expand Down
3 changes: 3 additions & 0 deletions src/vs/workbench/api/common/extHost.protocol.ts
Expand Up @@ -1199,13 +1199,16 @@ export interface MainThreadLanguageModelsShape extends IDisposable {

$prepareChatAccess(extension: ExtensionIdentifier, providerId: string, justification?: string): Promise<ILanguageModelChatMetadata | undefined>;
$fetchResponse(extension: ExtensionIdentifier, provider: string, requestId: number, messages: IChatMessage[], options: {}, token: CancellationToken): Promise<any>;

$countTokens(provider: string, value: string, token: CancellationToken): Promise<number>;
}

/**
 * RPC shape for language models; implemented by `ExtHostLanguageModels`.
 */
export interface ExtHostLanguageModelsShape {
$updateLanguageModels(data: { added?: ILanguageModelChatMetadata[]; removed?: string[] }): void;
$updateModelAccesslist(data: { from: ExtensionIdentifier; to: ExtensionIdentifier; enabled: boolean }[]): void;
$provideLanguageModelResponse(handle: number, requestId: number, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, token: CancellationToken): Promise<any>;
$handleResponseFragment(requestId: number, chunk: IChatResponseFragment): Promise<void>;
/**
 * Counts the tokens of `value` using the provider registered under `handle`.
 * The `ExtHostLanguageModels` implementation resolves to `0` when the handle is unknown.
 */
$provideTokenLength(handle: number, value: string, token: CancellationToken): Promise<number>;
}

export interface IExtensionChatAgentMetadata extends Dto<IChatAgentMetadata> {
Expand Down
30 changes: 30 additions & 0 deletions src/vs/workbench/api/common/extHostLanguageModels.ts
Expand Up @@ -21,6 +21,7 @@ import { createDecorator } from 'vs/platform/instantiation/common/instantiation'
import { IExtHostRpcService } from 'vs/workbench/api/common/extHostRpcService';
import { IExtHostAuthentication } from 'vs/workbench/api/common/extHostAuthentication';
import { ILogService } from 'vs/platform/log/common/log';
import { Iterable } from 'vs/base/common/iterator';

export interface IExtHostLanguageModels extends ExtHostLanguageModels { }

Expand Down Expand Up @@ -180,6 +181,18 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
return data.provider.provideLanguageModelResponse2(messages.map(typeConvert.LanguageModelMessage.to), options, ExtensionIdentifier.toKey(from), progress, token);
}


//#region --- token counting

/**
 * Counts the tokens of `value` with the provider registered under `handle`.
 * Resolves to `0` when no language model data exists for that handle.
 */
$provideTokenLength(handle: number, value: string, token: CancellationToken): Promise<number> {
	const entry = this._languageModels.get(handle);
	if (entry) {
		// The provider returns a thenable; wrap it so callers always get a real promise.
		return Promise.resolve(entry.provider.provideTokenCount(value, token));
	}
	return Promise.resolve(0);
}


//#region --- making request

$updateLanguageModels(data: { added?: ILanguageModelChatMetadata[] | undefined; removed?: string[] | undefined }): void {
Expand Down Expand Up @@ -378,6 +391,23 @@ export class ExtHostLanguageModels implements ExtHostLanguageModelsShape {
return undefined;
}
return list.has(data.extension);
},
/**
 * Computes the token length of `value` for the language model `languageModelId`.
 *
 * Throws `LanguageModelError.NotFound` when the model is unknown. A provider
 * registered in this extension host is called directly; otherwise the request
 * is sent through the proxy via `$countTokens`.
 */
async computeTokenLength(languageModelId: string, value: string, token?: vscode.CancellationToken): Promise<number> {

	const known = that._allLanguageModelData.get(languageModelId);
	if (!known) {
		throw LanguageModelError.NotFound(`Language model '${languageModelId}' is unknown.`);
	}

	// A missing (null/undefined) token means the request cannot be cancelled.
	const cancellation: vscode.CancellationToken = token ?? CancellationToken.None;

	const inProcess = Iterable.find(that._languageModels.values(), entry => entry.languageModelId === languageModelId);
	if (!inProcess) {
		return that._proxy.$countTokens(known.identifier, value, cancellation);
	}

	// stay inside the EH
	return inProcess.provider.provideTokenCount(value, cancellation);
}
};
}
Expand Down
11 changes: 11 additions & 0 deletions src/vs/workbench/contrib/chat/common/languageModels.ts
Expand Up @@ -40,6 +40,7 @@ export interface ILanguageModelChatMetadata {
/**
 * A chat language model provider as accepted by `ILanguageModelsService.registerLanguageModelChat`.
 */
export interface ILanguageModelChat {
metadata: ILanguageModelChatMetadata;
provideChatResponse(messages: IChatMessage[], from: ExtensionIdentifier, options: { [name: string]: any }, progress: IProgress<IChatResponseFragment>, token: CancellationToken): Promise<any>;
/**
 * Counts the tokens of `str`. `LanguageModelsService.computeTokenLength` forwards to this method.
 */
provideTokenCount(str: string, token: CancellationToken): Promise<number>;
}

export const ILanguageModelsService = createDecorator<ILanguageModelsService>('ILanguageModelsService');
Expand All @@ -57,6 +58,8 @@ export interface ILanguageModelsService {
registerLanguageModelChat(identifier: string, provider: ILanguageModelChat): IDisposable;

makeLanguageModelChatRequest(identifier: string, from: ExtensionIdentifier, messages: IChatMessage[], options: { [name: string]: any }, progress: IProgress<IChatResponseFragment>, token: CancellationToken): Promise<any>;

computeTokenLength(identifier: string, message: string, token: CancellationToken): Promise<number>;
}

export class LanguageModelsService implements ILanguageModelsService {
Expand Down Expand Up @@ -100,4 +103,12 @@ export class LanguageModelsService implements ILanguageModelsService {
}
return provider.provideChatResponse(messages, from, options, progress, token);
}

/**
 * Computes the token length of `message` with the provider registered under `identifier`.
 *
 * Throws synchronously (not as a rejected promise) when no provider is registered
 * for `identifier`.
 */
computeTokenLength(identifier: string, message: string, token: CancellationToken): Promise<number> {
	const registered = this._providers.get(identifier);
	if (registered) {
		return registered.provideTokenCount(message, token);
	}
	throw new Error(`Chat response provider with identifier ${identifier} is not registered.`);
}
}
2 changes: 2 additions & 0 deletions src/vscode-dts/vscode.proposed.chatProvider.d.ts
Expand Up @@ -17,6 +17,8 @@ declare module 'vscode' {
*/
export interface ChatResponseProvider {
provideLanguageModelResponse2(messages: LanguageModelChatMessage[], options: { [name: string]: any }, extensionId: string, progress: Progress<ChatResponseFragment>, token: CancellationToken): Thenable<any>;

/**
 * Computes the number of tokens in the given text.
 *
 * @param text The text to count tokens for.
 * @param token A cancellation token.
 * @returns A thenable that resolves to the token count.
 */
provideTokenCount(text: string, token: CancellationToken): Thenable<number>;
}

export interface ChatResponseProviderMetadata {
Expand Down
11 changes: 11 additions & 0 deletions src/vscode-dts/vscode.proposed.languageModels.d.ts
Expand Up @@ -232,6 +232,17 @@ declare module 'vscode' {
// TODO@API SYNC or ASYNC?
// TODO@API future
// retrieveQuota(languageModelId: string): { remaining: number; resets: Date };

// TODO@API SHOULD THIS BE in vscode.lm?
// TODO@API should this check for access/permissions?
/**
 * Compute the token length for the given text.
 *
 * Fails with a `LanguageModelError.NotFound` error when `languageModelId` is unknown.
 *
 * @param languageModelId The identifier of the language model to count tokens with.
 * @param text The text to compute the token length for.
 * @param token An optional cancellation token.
 * @returns A thenable that resolves to the token length.
 */
computeTokenLength(languageModelId: string, text: string, token?: CancellationToken): Thenable<number>;
}

export interface ExtensionContext {
Expand Down

0 comments on commit a808279

Please sign in to comment.