Skip to content

Commit

Permalink
Encapsulates logic within AI providers better
Browse files Browse the repository at this point in the history
  • Loading branch information
eamodio committed Apr 29, 2024
1 parent 0bdd084 commit 02f44f5
Show file tree
Hide file tree
Showing 5 changed files with 491 additions and 262 deletions.
155 changes: 73 additions & 82 deletions src/ai/aiProviderService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,54 @@ import type { Repository } from '../git/models/repository';
import { isRepository } from '../git/models/repository';
import { showAIModelPicker } from '../quickpicks/aiModelPicker';
import { configuration } from '../system/configuration';
import { getSettledValue } from '../system/promise';
import type { Storage } from '../system/storage';
import { supportedInVSCodeVersion } from '../system/utils';
import { AnthropicProvider } from './anthropicProvider';
import { GeminiProvider } from './geminiProvider';
import { OpenAIProvider } from './openaiProvider';

export interface AIModel<
Provider extends AIProviders = AIProviders,
Model extends AIModels<Provider> = AIModels<Provider>,
> {
readonly id: Model;
readonly name: string;
readonly maxTokens: number;
readonly provider: {
id: Provider;
name: string;
};

readonly default?: boolean;
readonly hidden?: boolean;
}

interface AIProviderConstructor<Provider extends AIProviders = AIProviders> {
new (container: Container): AIProvider<Provider>;
}

const _supportedProviderTypes = new Map<AIProviders, AIProviderConstructor>([
['openai', OpenAIProvider],
['anthropic', AnthropicProvider],
['gemini', GeminiProvider],
]);

export interface AIProvider<Provider extends AIProviders = AIProviders> extends Disposable {
readonly id: Provider;
readonly name: string;

generateCommitMessage(diff: string, options?: { context?: string }): Promise<string | undefined>;
explainChanges(message: string, diff: string): Promise<string | undefined>;
getModels(): Promise<readonly AIModel<Provider, AIModels<Provider>>[]>;

explainChanges(
message: string,
diff: string,
options?: { cancellation?: CancellationToken },
): Promise<string | undefined>;
generateCommitMessage(
diff: string,
options?: { cancellation?: CancellationToken; context?: string },
): Promise<string | undefined>;
}

export class AIProviderService implements Disposable {
Expand All @@ -37,19 +73,25 @@ export class AIProviderService implements Disposable {
return this._provider?.id;
}

public async generateCommitMessage(
async getModels(): Promise<readonly AIModel[]> {
const providers = [..._supportedProviderTypes.values()].map(p => new p(this.container));
const models = await Promise.allSettled(providers.map(p => p.getModels()));
return models.flatMap(m => getSettledValue(m, []));
}

async generateCommitMessage(
changes: string[],
options?: { cancellation?: CancellationToken; context?: string; progress?: ProgressOptions },
): Promise<string | undefined>;
public async generateCommitMessage(
async generateCommitMessage(
repoPath: Uri,
options?: { cancellation?: CancellationToken; context?: string; progress?: ProgressOptions },
): Promise<string | undefined>;
public async generateCommitMessage(
async generateCommitMessage(
repository: Repository,
options?: { cancellation?: CancellationToken; context?: string; progress?: ProgressOptions },
): Promise<string | undefined>;
public async generateCommitMessage(
async generateCommitMessage(
changesOrRepoOrPath: string[] | Repository | Uri,
options?: { cancellation?: CancellationToken; context?: string; progress?: ProgressOptions },
): Promise<string | undefined> {
Expand Down Expand Up @@ -81,25 +123,31 @@ export class AIProviderService implements Disposable {

if (options?.progress != null) {
return window.withProgress(options.progress, async () =>
provider.generateCommitMessage(changes, { context: options?.context }),
provider.generateCommitMessage(changes, {
cancellation: options?.cancellation,
context: options?.context,
}),
);
}
return provider.generateCommitMessage(changes, { context: options?.context });
return provider.generateCommitMessage(changes, {
cancellation: options?.cancellation,
context: options?.context,
});
}

async explainCommit(
repoPath: string | Uri,
sha: string,
options?: { progress?: ProgressOptions },
options?: { cancellation?: CancellationToken; progress?: ProgressOptions },
): Promise<string | undefined>;
async explainCommit(
commit: GitRevisionReference | GitCommit,
options?: { progress?: ProgressOptions },
options?: { cancellation?: CancellationToken; progress?: ProgressOptions },
): Promise<string | undefined>;
async explainCommit(
commitOrRepoPath: string | Uri | GitRevisionReference | GitCommit,
shaOrOptions?: string | { progress?: ProgressOptions },
options?: { progress?: ProgressOptions },
options?: { cancellation?: CancellationToken; progress?: ProgressOptions },
): Promise<string | undefined> {
let commit: GitCommit | undefined;
if (typeof commitOrRepoPath === 'string' || commitOrRepoPath instanceof Uri) {
Expand Down Expand Up @@ -132,10 +180,10 @@ export class AIProviderService implements Disposable {

if (options?.progress != null) {
return window.withProgress(options.progress, async () =>
provider.explainChanges(commit.message, diff.contents),
provider.explainChanges(commit.message, diff.contents, { cancellation: options?.cancellation }),
);
}
return provider.explainChanges(commit.message, diff.contents);
return provider.explainChanges(commit.message, diff.contents, { cancellation: options?.cancellation });
}

reset() {
Expand All @@ -149,7 +197,7 @@ export class AIProviderService implements Disposable {
}

supports(provider: AIProviders | string) {
return provider === 'anthropic' || provider === 'gemini' || provider === 'openai';
return _supportedProviderTypes.has(provider as AIProviders);
}

async switchProvider() {
Expand All @@ -159,7 +207,7 @@ export class AIProviderService implements Disposable {
private async getOrChooseProvider(force?: boolean): Promise<AIProvider | undefined> {
let providerId = !force ? configuration.get('ai.experimental.provider') || undefined : undefined;
if (providerId == null || !this.supports(providerId)) {
const pick = await showAIModelPicker();
const pick = await showAIModelPicker(this.container);
if (pick == null) return undefined;

providerId = pick.provider;
Expand All @@ -170,22 +218,13 @@ export class AIProviderService implements Disposable {
if (providerId !== this._provider?.id) {
this._provider?.dispose();

switch (providerId) {
case 'anthropic':
this._provider = new AnthropicProvider(this.container);
break;

case 'gemini':
this._provider = new GeminiProvider(this.container);
break;

case 'openai':
this._provider = new OpenAIProvider(this.container);
break;

default:
this._provider = new OpenAIProvider(this.container);
await configuration.updateEffective('ai.experimental.provider', 'openai');
let type = _supportedProviderTypes.get(providerId);
if (type == null && providerId !== 'openai') {
type = _supportedProviderTypes.get('openai');
await configuration.updateEffective('ai.experimental.provider', 'openai');
}
if (type != null) {
this._provider = new type(this.container);
}
}

Expand All @@ -204,7 +243,7 @@ async function confirmAIProviderToS(provider: AIProvider, storage: Storage): Pro
const acceptAlways: MessageItem = { title: 'Always' };
const decline: MessageItem = { title: 'No', isCloseAffordance: true };
const result = await window.showInformationMessage(
`This GitLens experimental feature requires sending a diff of the code changes to ${provider.name}. This may contain sensitive information.\n\nDo you want to continue?`,
`GitLens experimental AI features require sending a diff of the code changes to ${provider.name} for analysis. This may contain sensitive information.\n\nDo you want to continue?`,
{ modal: true },
accept,
acceptWorkspace,
Expand All @@ -227,57 +266,9 @@ async function confirmAIProviderToS(provider: AIProvider, storage: Storage): Pro
return false;
}

export function getMaxCharacters(model: AIModels, outputLength: number): number {
export function getMaxCharacters(model: AIModel, outputLength: number): number {
const tokensPerCharacter = 3.1;

let tokens;
switch (model) {
case 'gpt-4-turbo': // 128,000 tokens (4,096 max output tokens)
case 'gpt-4-turbo-2024-04-09':
case 'gpt-4-turbo-preview':
case 'gpt-4-0125-preview':
case 'gpt-4-1106-preview':
tokens = 128000;
break;
case 'gpt-4': // 8,192 tokens
case 'gpt-4-0613':
tokens = 8192;
break;
case 'gpt-4-32k': // 32,768 tokens
case 'gpt-4-32k-0613':
tokens = 32768;
break;
case 'gpt-3.5-turbo': // 16,385 tokens (4,096 max output tokens)
case 'gpt-3.5-turbo-0125':
case 'gpt-3.5-turbo-1106':
case 'gpt-3.5-turbo-16k': // (Legacy)
tokens = 16385;
break;

case 'claude-3-opus-20240229': // 200,000 tokens
case 'claude-3-sonnet-20240229':
case 'claude-3-haiku-20240307':
case 'claude-2.1':
tokens = 200000;
break;
case 'claude-2': // 100,000 tokens
case 'claude-instant-1':
tokens = 100000;
break;

case 'gemini-1.0-pro': // 30,720 tokens
tokens = 30720;
break;
case 'gemini-1.5-pro-latest': // 1,048,576 tokens
tokens = 1048576;
break;

default: // 4,096 tokens
tokens = 4096;
break;
}

const max = tokens * tokensPerCharacter - outputLength / tokensPerCharacter;
const max = model.maxTokens * tokensPerCharacter - outputLength / tokensPerCharacter;
return Math.floor(max - max * 0.1);
}

Expand Down
Loading

0 comments on commit 02f44f5

Please sign in to comment.