Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of TF-IDF for similar commands #193418

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
35 changes: 35 additions & 0 deletions src/vs/base/common/tfIdf.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ export interface TfIdfDocument {

/**
 * A raw (un-normalized) TF-IDF score associated with a key.
 */
export interface TfIdfScore {
/**
 * Identifier of the scored entry.
 */
readonly key: string;
/**
 * An unbounded number.
 */
readonly score: number;
}

/**
 * A TF-IDF score associated with a key, with the score scaled into a fixed range.
 */
export interface NormalizedTfIdfScore {
/**
 * Identifier of the scored entry.
 */
readonly key: string;
/**
 * A number between 0 and 1.
 */
readonly score: number;
}

Expand Down Expand Up @@ -204,3 +215,27 @@ export class TfIdfCalculator {
return embedding;
}
}

/**
 * Normalize the scores to be between 0 and 1 and sort them descending.
 *
 * Neither the input array nor the score objects it contains are modified; the
 * result is a new array of new objects. Scores are divided by the highest
 * score; if the highest score is not greater than 0 (or the input is empty),
 * the scores are returned sorted but unscaled.
 * @param scores array of scores from {@link TfIdfCalculator.calculateScores}
 * @returns normalized scores
 */
export function normalizeTfIdfScores(scores: TfIdfScore[]): NormalizedTfIdfScore[] {

	// Sort a copy so the caller's array keeps its original order
	const sorted = scores.slice(0).sort((a, b) => b.score - a.score);

	// After a descending sort the first entry holds the maximum
	const max = sorted[0]?.score ?? 0;

	// Nothing to scale by: hand back fresh copies with the scores as they are
	if (!(max > 0)) {
		return sorted.map(entry => ({ ...entry }));
	}

	// Create new objects rather than writing to the readonly `score` of the inputs
	return sorted.map(entry => ({ ...entry, score: entry.score / max }));
}
64 changes: 61 additions & 3 deletions src/vs/platform/quickinput/browser/commandsQuickAccess.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ import { CancellationToken } from 'vs/base/common/cancellation';
import { toErrorMessage } from 'vs/base/common/errorMessage';
import { isCancellationError } from 'vs/base/common/errors';
import { matchesContiguousSubString, matchesPrefix, matchesWords, or } from 'vs/base/common/filters';
import { once } from 'vs/base/common/functional';
import { Disposable, DisposableStore, IDisposable } from 'vs/base/common/lifecycle';
import { LRUCache } from 'vs/base/common/map';
import { TfIdfCalculator, normalizeTfIdfScores } from 'vs/base/common/tfIdf';
import { localize } from 'vs/nls';
import { ICommandService } from 'vs/platform/commands/common/commands';
import { IConfigurationChangeEvent, IConfigurationService } from 'vs/platform/configuration/common/configuration';
Expand All @@ -25,6 +27,7 @@ import { ITelemetryService } from 'vs/platform/telemetry/common/telemetry';
/**
 * A quick pick item that represents a command.
 */
export interface ICommandQuickPick extends IPickerQuickAccessItem {
/** Identifier of the command this pick represents. */
readonly commandId: string;
/** Optional alias text for the command. */
readonly commandAlias?: string;
/**
 * TF-IDF score of this pick. Intentionally not `readonly`: it is assigned
 * after the pick has been created, when the pick is included because of its
 * TF-IDF score.
 */
tfIdfScore?: number;
// NOTE(review): presumably the arguments passed to the command when it runs — confirm against the caller.
readonly args?: any[];
}

Expand All @@ -37,6 +40,9 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc

static PREFIX = '>';

private static readonly TFIDF_THRESHOLD = 0.5;
private static readonly TFIDF_MAX_RESULTS = 5;

private static WORD_FILTER = or(matchesPrefix, matchesWords, matchesContiguousSubString);

private readonly commandsHistory = this._register(this.instantiationService.createInstance(CommandsHistory));
Expand Down Expand Up @@ -65,6 +71,19 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc
return [];
}

const runTfidf = once(() => {
const tfidf = new TfIdfCalculator();
tfidf.updateDocuments(allCommandPicks.map(commandPick => ({
key: commandPick.commandId,
textChunks: [commandPick.label + (commandPick.commandAlias ? ` ${commandPick.commandAlias}` : '')]
})));
const result = tfidf.calculateScores(filter, token);

return normalizeTfIdfScores(result)
.filter(score => score.score > AbstractCommandsQuickAccessProvider.TFIDF_THRESHOLD)
.slice(0, AbstractCommandsQuickAccessProvider.TFIDF_MAX_RESULTS);
});

// Filter
const filteredCommandPicks: ICommandQuickPick[] = [];
for (const commandPick of allCommandPicks) {
Expand All @@ -85,6 +104,21 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc
else if (filter === commandPick.commandId) {
filteredCommandPicks.push(commandPick);
}

// Handle tf-idf scoring for the rest if there's a filter
else if (filter.length >= 3) {
const tfidf = runTfidf();
if (token.isCancellationRequested) {
return [];
}

// Add if we have a tf-idf score
const tfidfScore = tfidf.find(score => score.key === commandPick.commandId);
if (tfidfScore) {
commandPick.tfIdfScore = tfidfScore.score;
filteredCommandPicks.push(commandPick);
}
}
}

// Add description to commands that have duplicate labels
Expand All @@ -101,6 +135,18 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc

// Sort by MRU order and fallback to name otherwise
filteredCommandPicks.sort((commandPickA, commandPickB) => {
// If a result came from tf-idf, we want to put that towards the bottom
if (commandPickA.tfIdfScore && commandPickB.tfIdfScore) {
if (commandPickA.tfIdfScore === commandPickB.tfIdfScore) {
return commandPickA.label.localeCompare(commandPickB.label); // prefer lexicographically smaller command
}
return commandPickB.tfIdfScore - commandPickA.tfIdfScore; // prefer higher tf-idf score
} else if (commandPickA.tfIdfScore) {
return 1; // first command has a score but other doesn't so other wins
} else if (commandPickB.tfIdfScore) {
return -1; // other command has a score but first doesn't so first wins
}

const commandACounter = this.commandsHistory.peek(commandPickA.commandId);
const commandBCounter = this.commandsHistory.peek(commandPickB.commandId);

Expand Down Expand Up @@ -139,6 +185,7 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc
const commandPicks: Array<ICommandQuickPick | IQuickPickSeparator> = [];

let addOtherSeparator = false;
let addSuggestedSeparator = true;
let addCommonlyUsedSeparator = !!this.options.suggestedCommandIds;
for (let i = 0; i < filteredCommandPicks.length; i++) {
const commandPick = filteredCommandPicks[i];
Expand All @@ -149,15 +196,20 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc
addOtherSeparator = true;
}

if (addSuggestedSeparator && commandPick.tfIdfScore !== undefined) {
commandPicks.push({ type: 'separator', label: localize('suggested', "similar commands") });
addSuggestedSeparator = false;
}

// Separator: commonly used
if (addCommonlyUsedSeparator && !this.commandsHistory.peek(commandPick.commandId) && this.options.suggestedCommandIds?.has(commandPick.commandId)) {
if (addCommonlyUsedSeparator && commandPick.tfIdfScore === undefined && !this.commandsHistory.peek(commandPick.commandId) && this.options.suggestedCommandIds?.has(commandPick.commandId)) {
commandPicks.push({ type: 'separator', label: localize('commonlyUsed', "commonly used") });
addOtherSeparator = true;
addCommonlyUsedSeparator = false;
}

// Separator: other commands
if (addOtherSeparator && !this.commandsHistory.peek(commandPick.commandId) && !this.options.suggestedCommandIds?.has(commandPick.commandId)) {
if (addOtherSeparator && commandPick.tfIdfScore === undefined && !this.commandsHistory.peek(commandPick.commandId) && !this.options.suggestedCommandIds?.has(commandPick.commandId)) {
commandPicks.push({ type: 'separator', label: localize('morecCommands', "other commands") });
addOtherSeparator = false;
}
Expand All @@ -178,7 +230,13 @@ export abstract class AbstractCommandsQuickAccessProvider extends PickerQuickAcc
return [];
}

return additionalCommandPicks.map(commandPick => this.toCommandPick(commandPick, runOptions));
const commandPicks: Array<ICommandQuickPick | IQuickPickSeparator> = additionalCommandPicks.map(commandPick => this.toCommandPick(commandPick, runOptions));
// Basically, if we haven't already added a separator, we add one before the additional picks so long
// as one hasn't been added to the start of the array.
if (addSuggestedSeparator && commandPicks[0]?.type !== 'separator') {
commandPicks.unshift({ type: 'separator', label: localize('suggested', "similar commands") });
}
return commandPicks;
})()
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import { CHAT_OPEN_ACTION_ID } from 'vs/workbench/contrib/chat/browser/actions/c

export class CommandsQuickAccessProvider extends AbstractEditorCommandsQuickAccessProvider {

private static AI_RELATED_INFORMATION_MAX_PICKS = 3;
private static AI_RELATED_INFORMATION_MAX_PICKS = 5;
private static AI_RELATED_INFORMATION_THRESHOLD = 0.8;
private static AI_RELATED_INFORMATION_DEBOUNCE = 200;

Expand Down Expand Up @@ -165,13 +165,6 @@ export class CommandsQuickAccessProvider extends AbstractEditorCommandsQuickAcce
return [];
}

if (additionalPicks.length) {
additionalPicks.unshift({
type: 'separator',
label: localize('similarCommands', "similar commands")
});
}

if (picksSoFar.length || additionalPicks.length) {
additionalPicks.push({
type: 'separator'
Expand Down