Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce a SemanticSimilarity provider model #179640

Merged
merged 4 commits into from Apr 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
19 changes: 11 additions & 8 deletions src/vs/base/common/async.ts
Expand Up @@ -94,18 +94,21 @@ export function raceCancellationError<T>(promise: Promise<T>, token: Cancellatio
}

/**
* Returns as soon as one of the promises is resolved and cancels remaining promises
* Returns as soon as one of the promises resolves or rejects and cancels remaining promises
*/
export async function raceCancellablePromises<T>(cancellablePromises: CancelablePromise<T>[]): Promise<T> {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an improvement, thanks, but it is a behavior change so if you haven't already, make sure you take a look at the other places it's used and try to check whether that might be an issue

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's only used in one other place and I think based on the code it seems ok. @sandy081 FYI

let resolvedPromiseIndex = -1;
const promises = cancellablePromises.map((promise, index) => promise.then(result => { resolvedPromiseIndex = index; return result; }));
const result = await Promise.race(promises);
cancellablePromises.forEach((cancellablePromise, index) => {
if (index !== resolvedPromiseIndex) {
cancellablePromise.cancel();
}
});
return result;
try {
const result = await Promise.race(promises);
return result;
} finally {
cancellablePromises.forEach((cancellablePromise, index) => {
if (index !== resolvedPromiseIndex) {
cancellablePromise.cancel();
}
});
}
}

export function raceTimeout<T>(promise: Promise<T>, timeout: number, onTimeout?: () => void): Promise<T | undefined> {
Expand Down
1 change: 1 addition & 0 deletions src/vs/workbench/api/browser/extensionHost.contribution.ts
Expand Up @@ -81,6 +81,7 @@ import './mainThreadTimeline';
import './mainThreadTesting';
import './mainThreadSecretState';
import './mainThreadProfilContentHandlers';
import './mainThreadSemanticSimilarity';

export class ExtensionPoints implements IWorkbenchContribution {

Expand Down
36 changes: 36 additions & 0 deletions src/vs/workbench/api/browser/mainThreadSemanticSimilarity.ts
@@ -0,0 +1,36 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import { Disposable, DisposableMap } from 'vs/base/common/lifecycle';
import { ExtHostContext, ExtHostSemanticSimilarityShape, MainContext, MainThreadSemanticSimilarityShape } from 'vs/workbench/api/common/extHost.protocol';
import { IExtHostContext, extHostNamedCustomer } from 'vs/workbench/services/extensions/common/extHostCustomers';
import { ISemanticSimilarityProvider, ISemanticSimilarityService } from 'vs/workbench/services/semanticSimilarity/common/semanticSimilarityService';

@extHostNamedCustomer(MainContext.MainThreadSemanticSimilarity)
export class MainThreadSemanticSimilarity extends Disposable implements MainThreadSemanticSimilarityShape {
private readonly _proxy: ExtHostSemanticSimilarityShape;
private readonly _registrations = this._register(new DisposableMap<number>());

constructor(
context: IExtHostContext,
@ISemanticSimilarityService private readonly _semanticSimilarityService: ISemanticSimilarityService
) {
super();
this._proxy = context.getProxy(ExtHostContext.ExtHostSemanticSimilarity);
}

$registerSemanticSimilarityProvider(handle: number): void {
const provider: ISemanticSimilarityProvider = {
provideSimilarityScore: (string1, comparisons, token) => {
return this._proxy.$provideSimilarityScore(handle, string1, comparisons, token);
},
};
this._registrations.set(handle, this._semanticSimilarityService.registerSemanticSimilarityProvider(provider));
}

$unregisterSemanticSimilarityProvider(handle: number): void {
this._registrations.deleteAndDispose(handle);
}
}
11 changes: 11 additions & 0 deletions src/vs/workbench/api/common/extHost.api.impl.ts
Expand Up @@ -100,6 +100,7 @@ import { ExtHostQuickDiff } from 'vs/workbench/api/common/extHostQuickDiff';
import { ExtHostInteractiveSession } from 'vs/workbench/api/common/extHostInteractiveSession';
import { ExtHostInteractiveEditor } from 'vs/workbench/api/common/extHostInteractiveEditor';
import { ExtHostNotebookDocumentSaveParticipant } from 'vs/workbench/api/common/extHostNotebookDocumentSaveParticipant';
import { ExtHostSemanticSimilarity } from 'vs/workbench/api/common/extHostSemanticSimilarity';

export interface IExtensionRegistries {
mine: ExtensionDescriptionRegistry;
Expand Down Expand Up @@ -197,6 +198,7 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
rpcProtocol.set(ExtHostContext.ExtHostInteractive, new ExtHostInteractive(rpcProtocol, extHostNotebook, extHostDocumentsAndEditors, extHostCommands, extHostLogService));
const extHostInteractiveEditor = rpcProtocol.set(ExtHostContext.ExtHostInteractiveEditor, new ExtHostInteractiveEditor(rpcProtocol, extHostDocuments, extHostLogService));
const extHostInteractiveSession = rpcProtocol.set(ExtHostContext.ExtHostInteractiveSession, new ExtHostInteractiveSession(rpcProtocol, extHostLogService));
const extHostSemanticSimilarity = rpcProtocol.set(ExtHostContext.ExtHostSemanticSimilarity, new ExtHostSemanticSimilarity(rpcProtocol));

// Check that no named customers are missing
const expected = Object.values<ProxyIdentifier<any>>(ExtHostContext);
Expand Down Expand Up @@ -1248,9 +1250,18 @@ export function createApiFactoryAndRegisterActors(accessor: ServicesAccessor): I
}
};

// namespace: ai
const ai: typeof vscode.ai = {
registerSemanticSimilarityProvider(provider: vscode.SemanticSimilarityProvider) {
checkProposedApiEnabled(extension, 'semanticSimilarity');
return extHostSemanticSimilarity.registerSemanticSimilarityProvider(extension, provider);
}
};

return <typeof vscode>{
version: initData.version,
// namespaces
ai,
authentication,
commands,
comments,
Expand Down
13 changes: 12 additions & 1 deletion src/vs/workbench/api/common/extHost.protocol.ts
Expand Up @@ -1555,6 +1555,15 @@ export interface ExtHostAuthenticationShape {
$setProviders(providers: AuthenticationProviderInformation[]): Promise<void>;
}

export interface ExtHostSemanticSimilarityShape {
$provideSimilarityScore(handle: number, string1: string, comparisons: string[], token: CancellationToken): Promise<number[]>;
}

export interface MainThreadSemanticSimilarityShape extends IDisposable {
$registerSemanticSimilarityProvider(handle: number): void;
$unregisterSemanticSimilarityProvider(handle: number): void;
}

export interface ExtHostSecretStateShape {
$onDidChangePassword(e: { extensionId: string; key: string }): Promise<void>;
}
Expand Down Expand Up @@ -2452,7 +2461,8 @@ export const MainContext = {
MainThreadTunnelService: createProxyIdentifier<MainThreadTunnelServiceShape>('MainThreadTunnelService'),
MainThreadTimeline: createProxyIdentifier<MainThreadTimelineShape>('MainThreadTimeline'),
MainThreadTesting: createProxyIdentifier<MainThreadTestingShape>('MainThreadTesting'),
MainThreadLocalization: createProxyIdentifier<MainThreadLocalizationShape>('MainThreadLocalizationShape')
MainThreadLocalization: createProxyIdentifier<MainThreadLocalizationShape>('MainThreadLocalizationShape'),
MainThreadSemanticSimilarity: createProxyIdentifier<MainThreadSemanticSimilarityShape>('MainThreadSemanticSimilarity')
};

export const ExtHostContext = {
Expand Down Expand Up @@ -2506,6 +2516,7 @@ export const ExtHostContext = {
ExtHostInteractive: createProxyIdentifier<ExtHostInteractiveShape>('ExtHostInteractive'),
ExtHostInteractiveEditor: createProxyIdentifier<ExtHostInteractiveEditorShape>('ExtHostInteractiveEditor'),
ExtHostInteractiveSession: createProxyIdentifier<ExtHostInteractiveSessionShape>('ExtHostInteractiveSession'),
ExtHostSemanticSimilarity: createProxyIdentifier<ExtHostSemanticSimilarityShape>('ExtHostSemanticSimilarity'),
ExtHostTheming: createProxyIdentifier<ExtHostThemingShape>('ExtHostTheming'),
ExtHostTunnelService: createProxyIdentifier<ExtHostTunnelServiceShape>('ExtHostTunnelService'),
ExtHostAuthentication: createProxyIdentifier<ExtHostAuthenticationShape>('ExtHostAuthentication'),
Expand Down
47 changes: 47 additions & 0 deletions src/vs/workbench/api/common/extHostSemanticSimilarity.ts
@@ -0,0 +1,47 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

import { IExtensionDescription } from 'vs/platform/extensions/common/extensions';
import { ExtHostSemanticSimilarityShape, IMainContext, MainContext, MainThreadSemanticSimilarityShape } from 'vs/workbench/api/common/extHost.protocol';
import type { CancellationToken, SemanticSimilarityProvider } from 'vscode';
import { Disposable } from 'vs/workbench/api/common/extHostTypes';

export class ExtHostSemanticSimilarity implements ExtHostSemanticSimilarityShape {
private _semanticSimilarityProviders: Map<number, SemanticSimilarityProvider> = new Map();
private _nextHandle = 0;

private readonly _proxy: MainThreadSemanticSimilarityShape;

constructor(
mainContext: IMainContext
) {
this._proxy = mainContext.getProxy(MainContext.MainThreadSemanticSimilarity);
}

async $provideSimilarityScore(handle: number, string1: string, comparisons: string[], token: CancellationToken): Promise<number[]> {
if (this._semanticSimilarityProviders.size === 0) {
throw new Error('No semantic similarity providers registered');
}

const provider = this._semanticSimilarityProviders.get(handle);
if (!provider) {
throw new Error('Semantic similarity provider not found');
}

const result = await provider.provideSimilarityScore(string1, comparisons, token);
return result;
}

registerSemanticSimilarityProvider(extension: IExtensionDescription, provider: SemanticSimilarityProvider): Disposable {
const handle = this._nextHandle;
this._nextHandle++;
this._semanticSimilarityProviders.set(handle, provider);
this._proxy.$registerSemanticSimilarityProvider(handle);
return new Disposable(() => {
this._proxy.$unregisterSemanticSimilarityProvider(handle);
this._semanticSimilarityProviders.delete(handle);
});
}
}
Expand Up @@ -33,7 +33,7 @@ import { IPreferencesService } from 'vs/workbench/services/preferences/common/pr
import { stripIcons } from 'vs/base/common/iconLabels';
import { isFirefox } from 'vs/base/browser/browser';
import { IProductService } from 'vs/platform/product/common/productService';
import { ISemanticSimilarityService } from 'vs/workbench/contrib/quickaccess/browser/semanticSimilarityService';
import { ISemanticSimilarityService } from 'vs/workbench/services/semanticSimilarity/common/semanticSimilarityService';
import { timeout } from 'vs/base/common/async';

export class CommandsQuickAccessProvider extends AbstractEditorCommandsQuickAccessProvider {
Expand Down
Expand Up @@ -62,6 +62,7 @@ export const allApiProposals = Object.freeze({
scmActionButton: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.scmActionButton.d.ts',
scmSelectedProvider: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.scmSelectedProvider.d.ts',
scmValidation: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.scmValidation.d.ts',
semanticSimilarity: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.semanticSimilarity.d.ts',
showLocal: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.showLocal.d.ts',
tabInputTextMerge: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.tabInputTextMerge.d.ts',
taskPresentationGroup: 'https://raw.githubusercontent.com/microsoft/vscode/main/src/vscode-dts/vscode.proposed.taskPresentationGroup.d.ts',
Expand Down
Expand Up @@ -5,25 +5,110 @@

import { CommandsRegistry, ICommandService } from 'vs/platform/commands/common/commands';
import { IFileService } from 'vs/platform/files/common/files';
import { InstantiationType, registerSingleton } from 'vs/platform/instantiation/common/extensions';
import { createDecorator } from 'vs/platform/instantiation/common/instantiation';
import { URI } from 'vs/base/common/uri';
import { CancellationToken } from 'vs/base/common/cancellation';
import { raceCancellation } from 'vs/base/common/async';
import { CancelablePromise, createCancelablePromise, raceCancellablePromises, raceCancellation, timeout } from 'vs/base/common/async';
import { IDisposable } from 'vs/base/common/lifecycle';
import { InstantiationType, registerSingleton } from 'vs/platform/instantiation/common/extensions';

export const ISemanticSimilarityService = createDecorator<ISemanticSimilarityService>('IEmbeddingsService');
export const ISemanticSimilarityService = createDecorator<ISemanticSimilarityService>('ISemanticSimilarityService');

export interface ISemanticSimilarityService {
readonly _serviceBrand: undefined;

isEnabled(): boolean;
getSimilarityScore(string1: string, comparisons: string[], token: CancellationToken): Promise<number[]>;
registerSemanticSimilarityProvider(provider: ISemanticSimilarityProvider): IDisposable;
}

export interface ISemanticSimilarityProvider {
provideSimilarityScore(string1: string, comparisons: string[], token: CancellationToken): Promise<number[]>;
}

export class SemanticSimilarityService implements ISemanticSimilarityService {
readonly _serviceBrand: undefined;

static readonly DEFAULT_TIMEOUT = 1000 * 10; // 30 seconds
TylerLeonhardt marked this conversation as resolved.
Show resolved Hide resolved

private readonly _providers: ISemanticSimilarityProvider[] = [];
// remove when we move over to API
private readonly oldService: OldSemanticSimilarityService;

constructor(
// for the old service
@ICommandService commandService: ICommandService,
@IFileService fileService: IFileService
) {
this.oldService = new OldSemanticSimilarityService(commandService, fileService);
}

isEnabled(): boolean {
return this._providers.length > 0;
}

registerSemanticSimilarityProvider(provider: ISemanticSimilarityProvider): IDisposable {
this._providers.push(provider);
return {
dispose: () => {
const index = this._providers.indexOf(provider);
if (index >= 0) {
this._providers.splice(index, 1);
}
}
};
}

async getSimilarityScore(string1: string, comparisons: string[], token: CancellationToken): Promise<number[]> {
if (this._providers.length === 0) {
// Remove when we have a provider shipping in extensions
if (this.oldService.isEnabled()) {
return this.oldService.getSimilarityScore(string1, comparisons, token);
}
throw new Error('No semantic similarity providers registered');
}

const cancellablePromises: Array<CancelablePromise<number[]>> = [];

const timer = timeout(SemanticSimilarityService.DEFAULT_TIMEOUT);
const disposible = token.onCancellationRequested(() => {
disposible.dispose();
timer.cancel();
});

for (const provider of this._providers) {
cancellablePromises.push(createCancelablePromise(async t => {
try {
return await provider.provideSimilarityScore(string1, comparisons, t);
} catch (e) {
// logged in extension host
}
await timer;
TylerLeonhardt marked this conversation as resolved.
Show resolved Hide resolved
throw new Error('Semantic similarity provider timed out');
}));
}

cancellablePromises.push(createCancelablePromise(async (t) => {
const disposible = t.onCancellationRequested(() => {
timer.cancel();
disposible.dispose();
});
await timer;
throw new Error('Semantic similarity provider timed out');
}));

const result = await raceCancellablePromises(cancellablePromises);
return result;
}
}

// TODO: remove this when the extensions are updated

interface ICommandsEmbeddingsCache {
[commandId: string]: { embedding: number[] };
}

// TODO: use proper API for this instead of commands
export class SemanticSimilarityService implements ISemanticSimilarityService {
class OldSemanticSimilarityService {
declare _serviceBrand: undefined;

static readonly CALCULATE_EMBEDDING_COMMAND_ID = '_vscode.ai.calculateEmbedding';
Expand All @@ -39,7 +124,7 @@ export class SemanticSimilarityService implements ISemanticSimilarityService {
}

private async loadCache(): Promise<ICommandsEmbeddingsCache> {
const path = await this.commandService.executeCommand<string>(SemanticSimilarityService.COMMAND_EMBEDDING_CACHE_COMMAND_ID);
const path = await this.commandService.executeCommand<string>(OldSemanticSimilarityService.COMMAND_EMBEDDING_CACHE_COMMAND_ID);
if (!path) {
return {};
}
Expand All @@ -48,11 +133,10 @@ export class SemanticSimilarityService implements ISemanticSimilarityService {
}

isEnabled(): boolean {
return !!CommandsRegistry.getCommand(SemanticSimilarityService.CALCULATE_EMBEDDING_COMMAND_ID);
return !!CommandsRegistry.getCommand(OldSemanticSimilarityService.CALCULATE_EMBEDDING_COMMAND_ID);
}

async getSimilarityScore(str: string, comparisons: string[], token: CancellationToken): Promise<number[]> {

const embedding1 = await this.computeEmbedding(str, token);
const scores: number[] = [];
for (const comparison of comparisons) {
Expand All @@ -74,7 +158,7 @@ export class SemanticSimilarityService implements ISemanticSimilarityService {
if (!this.isEnabled()) {
throw new Error('Embeddings are not enabled');
}
const result = await raceCancellation(this.commandService.executeCommand<number[][]>(SemanticSimilarityService.CALCULATE_EMBEDDING_COMMAND_ID, text), token);
const result = await raceCancellation(this.commandService.executeCommand<number[][]>(OldSemanticSimilarityService.CALCULATE_EMBEDDING_COMMAND_ID, text), token);
if (!result) {
throw new Error('No result');
}
Expand Down
22 changes: 22 additions & 0 deletions src/vscode-dts/vscode.proposed.semanticSimilarity.d.ts
@@ -0,0 +1,22 @@
/*---------------------------------------------------------------------------------------------
* Copyright (c) Microsoft Corporation. All rights reserved.
* Licensed under the MIT License. See License.txt in the project root for license information.
*--------------------------------------------------------------------------------------------*/

declare module 'vscode' {
export interface SemanticSimilarityProvider {
/**
* Computes the semantic similarity score between two strings.
* @param string1 The string to compare to all other strings.
* @param comparisons An array of strings to compare string1 to. An array allows you to batch multiple comparisons in one call.
* @param token A cancellation token.
* @return A promise that resolves to the semantic similarity scores between string1 and each string in comparisons.
* The score should be a number between 0 and 1, where 0 means no similarity and 1 means
* perfect similarity.
*/
provideSimilarityScore(string1: string, comparisons: string[], token: CancellationToken): Thenable<number[]>;
}
export namespace ai {
export function registerSemanticSimilarityProvider(provider: SemanticSimilarityProvider): Disposable;
}
}