From 48a02e887a01288c423eb2c583c4959eb2b46a8d Mon Sep 17 00:00:00 2001 From: Nate Sesti <33237525+sestinj@users.noreply.github.com> Date: Sun, 26 May 2024 11:33:46 -0700 Subject: [PATCH] Free Trial Auth (#1367) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🎨 pass gh auth token to free trial * 💄 improved onboarding flow * 🎨 make fewer auth reqs * 🎨 refactor ordering of vscode deps * 🐛 resolve confighandler --- core/config/default.ts | 77 +++---- core/config/handler.ts | 10 +- core/config/load.ts | 12 ++ core/config/onboarding.ts | 55 +++-- core/core.ts | 20 +- core/index.d.ts | 1 + core/llm/llms/FreeTrial.ts | 12 ++ core/protocol/core.ts | 3 +- core/protocol/ide.ts | 2 + core/protocol/ideWebview.ts | 2 + core/util/filesystem.ts | 7 +- core/util/messageIde.ts | 3 + extensions/vscode/src/commands.ts | 3 - extensions/vscode/src/debugPanel.ts | 10 +- .../vscode/src/extension/VsCodeMessenger.ts | 47 +++-- .../vscode/src/extension/vscodeExtension.ts | 69 +++---- extensions/vscode/src/ideProtocol.ts | 24 +++ extensions/vscode/src/webviewProtocol.ts | 22 +- gui/src/App.tsx | 5 +- gui/src/components/Layout.tsx | 21 +- gui/src/components/loaders/ProgressBar.tsx | 13 +- .../quickSetup/QuickModelSetup.tsx | 19 +- gui/src/hooks/useInputHistory.ts | 2 +- gui/src/pages/gui.tsx | 4 +- gui/src/pages/onboarding/apiKeyOnboarding.tsx | 8 +- gui/src/pages/onboarding/onboarding.tsx | 194 ++++++++++++------ 26 files changed, 382 insertions(+), 263 deletions(-) diff --git a/core/config/default.ts b/core/config/default.ts index 25b265228f..da50659098 100644 --- a/core/config/default.ts +++ b/core/config/default.ts @@ -1,35 +1,33 @@ import { ContextProviderWithParams, + ModelDescription, SerializedContinueConfig, } from "../index.js"; +export const FREE_TRIAL_MODELS: ModelDescription[] = [ + { + title: "GPT-4o (Free Trial)", + provider: "free-trial", + model: "gpt-4o", + systemMessage: + "You are an expert software developer. You give helpful and concise responses.", + }, + { + title: "Llama3 70b (Free Trial)", + provider: "free-trial", + model: "llama3-70b", + systemMessage: + "You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.", + }, + { + title: "Claude 3 Sonnet (Free Trial)", + provider: "free-trial", + model: "claude-3-sonnet-20240229", + }, +]; + export const defaultConfig: SerializedContinueConfig = { - models: [ - // { - // title: "Codestral (Free Trial)", - // provider: "free-trial", - // model: "codestral", - // }, - { - title: "GPT-4o (Free Trial)", - provider: "free-trial", - model: "gpt-4o", - systemMessage: - "You are an expert software developer. You give helpful and concise responses.", - }, - { - title: "Llama3 70b (Free Trial)", - provider: "free-trial", - model: "llama3-70b", - systemMessage: - "You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.", - }, - { - title: "Claude 3 Sonnet (Free Trial)", - provider: "free-trial", - model: "claude-3-sonnet-20240229", - }, - ], + models: FREE_TRIAL_MODELS, customCommands: [ { name: "test", @@ -46,32 +44,7 @@ export const defaultConfig: SerializedContinueConfig = { }; export const defaultConfigJetBrains: SerializedContinueConfig = { - models: [ - // { - // title: "Codestral (Free Trial)", - // provider: "free-trial", - // model: "codestral", - // }, - { - title: "GPT-4o (Free Trial)", - provider: "free-trial", - model: "gpt-4o", - systemMessage: - "You are an expert software developer. You give helpful and concise responses.", - }, - { - title: "Llama3 70b (Free Trial)", - provider: "free-trial", - model: "llama3-70b", - systemMessage: - "You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.", - }, - { - title: "Claude 3 Sonnet (Free Trial)", - provider: "free-trial", - model: "claude-3-sonnet-20240229", - }, - ], + models: FREE_TRIAL_MODELS, customCommands: [ { name: "test", diff --git a/core/config/handler.ts b/core/config/handler.ts index cf87c56667..f3ed73e93f 100644 --- a/core/config/handler.ts +++ b/core/config/handler.ts @@ -19,7 +19,6 @@ export class ConfigHandler { private readonly ide: IDE, private ideSettingsPromise: Promise, private readonly writeLog: (text: string) => Promise, - private readonly onConfigUpdate: () => void, ) { this.ide = ide; this.ideSettingsPromise = ideSettingsPromise; @@ -44,11 +43,10 @@ export class ConfigHandler { reloadConfig() { this.savedConfig = undefined; this.savedBrowserConfig = undefined; - this.loadConfig().then(() => { - for (const listener of this.updateListeners) { - listener(); - } - }); + this.loadConfig(); + for (const listener of this.updateListeners) { + listener(); + } } async getSerializedConfig(): Promise { diff --git a/core/config/load.ts b/core/config/load.ts index 76b9417c96..0833c0eff5 100644 --- a/core/config/load.ts +++ b/core/config/load.ts @@ -31,6 +31,7 @@ import TransformersJsEmbeddingsProvider from "../indexing/embeddings/Transformer import { AllEmbeddingsProviders } from "../indexing/embeddings/index.js"; import { BaseLLM } from "../llm/index.js"; import CustomLLMClass from "../llm/llms/CustomLLM.js"; +import FreeTrial from "../llm/llms/FreeTrial.js"; import { llmFromDescription } from "../llm/llms/index.js"; import { IdeSettings } from "../protocol/ideWebview.js"; import { fetchwithRequestOptions } from "../util/fetchWithOptions.js"; @@ -277,6 +278,17 @@ async function intermediateToFinalConfig( }; } + // Obtain auth token (only if free trial being used) + const freeTrialModels = models.filter( + (model) => model.providerName === "free-trial", + ); + if (freeTrialModels.length > 0) { + const ghAuthToken = await ide.getGitHubAuthToken(); + for (const model of freeTrialModels) { + (model as FreeTrial).setupGhAuthToken(ghAuthToken); + } + } + // Tab autocomplete model let autocompleteLlm: BaseLLM | undefined = undefined; if (config.tabAutocompleteModel) { diff --git a/core/config/onboarding.ts b/core/config/onboarding.ts index ce5a46f037..f6fc278bb3 100644 --- a/core/config/onboarding.ts +++ b/core/config/onboarding.ts @@ -1,38 +1,12 @@ import { SerializedContinueConfig } from "../index.js"; +import { FREE_TRIAL_MODELS } from "./default.js"; -export const TRIAL_FIM_MODEL = "codestral-latest"; - -export function setupOptimizedMode( +export function setupApiKeysMode( config: SerializedContinueConfig, ): SerializedContinueConfig { return { ...config, - models: [ - // { - // title: "Codestral (Free Trial)", - // provider: "free-trial", - // model: "codestral", - // }, - { - title: "GPT-4o (Free Trial)", - provider: "free-trial", - model: "gpt-4o", - systemMessage: - "You are an expert software developer. You give helpful and concise responses.", - }, - { - title: "Llama3 70b (Free Trial)", - provider: "free-trial", - model: "llama3-70b", - systemMessage: - "You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.", - }, - { - title: "Claude 3 Sonnet (Free Trial)", - provider: "free-trial", - model: "claude-3-sonnet-20240229", - }, - ], + models: config.models.filter((model) => model.provider !== "free-trial"), tabAutocompleteModel: { title: "Tab Autocomplete", provider: "free-trial", @@ -96,6 +70,29 @@ export function setupLocalMode( }; } +export function setupFreeTrialMode( + config: SerializedContinueConfig, +): SerializedContinueConfig { + return { + ...config, + models: [ + ...FREE_TRIAL_MODELS, + ...config.models.filter((model) => model.provider !== "free-trial"), + ], + tabAutocompleteModel: { + title: "Tab Autocomplete", + provider: "free-trial", + model: "starcoder-7b", + }, + embeddingsProvider: { + provider: "free-trial", + }, + reranker: { + name: "free-trial", + }, + }; +} + export function setupLocalAfterFreeTrial( config: SerializedContinueConfig, ): SerializedContinueConfig { diff --git a/core/core.ts b/core/core.ts index aacd5937d9..ae749784b6 100644 --- a/core/core.ts +++ b/core/core.ts @@ -3,10 +3,11 @@ import { ContextItemId, IDE } from "."; import { CompletionProvider } from "./autocomplete/completionProvider"; import { ConfigHandler } from "./config/handler"; import { + setupApiKeysMode, + setupFreeTrialMode, setupLocalAfterFreeTrial, setupLocalMode, setupOptimizedExistingUserMode, - setupOptimizedMode, } from "./config/onboarding"; import { addModel, addOpenAIKey, deleteModel } from "./config/util"; import { ContinueServerClient } from "./continueServer/stubs/client"; @@ -55,12 +56,15 @@ export class Core { constructor( private readonly messenger: IMessenger, private readonly ide: IDE, + private readonly onWrite: (text: string) => Promise = async () => {}, ) { const ideSettingsPromise = messenger.request("getIdeSettings", undefined); this.configHandler = new ConfigHandler( this.ide, ideSettingsPromise, - async (text: string) => {}, + this.onWrite, + ); + this.configHandler.onConfigUpdate( (() => this.messenger.send("configUpdate", undefined)).bind(this), ); @@ -469,11 +473,13 @@ export class Core { editConfigJson( mode === "local" ? setupLocalMode - : mode === "localAfterFreeTrial" - ? setupLocalAfterFreeTrial - : mode === "optimized" - ? setupOptimizedMode - : setupOptimizedExistingUserMode, + : mode === "freeTrial" + ? setupFreeTrialMode + : mode === "localAfterFreeTrial" + ? setupLocalAfterFreeTrial + : mode === "apiKeys" + ? setupApiKeysMode + : setupOptimizedExistingUserMode, ); this.configHandler.reloadConfig(); }); diff --git a/core/index.d.ts b/core/index.d.ts index 969e2f6150..bf2852ab49 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -454,6 +454,7 @@ export interface IDE { getGitRootPath(dir: string): Promise; listDir(dir: string): Promise<[string, FileType][]>; getLastModified(files: string[]): Promise<{ [path: string]: number }>; + getGitHubAuthToken(): Promise; } // Slash Commands diff --git a/core/llm/llms/FreeTrial.ts b/core/llm/llms/FreeTrial.ts index 53c51fe256..da74e2281d 100644 --- a/core/llm/llms/FreeTrial.ts +++ b/core/llm/llms/FreeTrial.ts @@ -9,9 +9,21 @@ import { streamResponse } from "../stream.js"; class FreeTrial extends BaseLLM { static providerName: ModelProvider = "free-trial"; + private ghAuthToken: string | undefined = undefined; + + setupGhAuthToken(ghAuthToken: string | undefined) { + this.ghAuthToken = ghAuthToken; + } + private async _getHeaders() { + if (!this.ghAuthToken) { + throw new Error( + "Please sign in with GitHub in order to use the free trial. If you'd like to use Continue without signing in, you can set up your own local model or API key.", + ); + } return { "Content-Type": "application/json", + Authorization: `Bearer ${this.ghAuthToken}`, ...(await getHeaders()), }; } diff --git a/core/protocol/core.ts b/core/protocol/core.ts index 324f74ac71..83d6106223 100644 --- a/core/protocol/core.ts +++ b/core/protocol/core.ts @@ -125,8 +125,9 @@ export type ToCoreFromIdeOrWebviewProtocol = { { mode: | "local" - | "optimized" + | "apiKeys" | "custom" + | "freeTrial" | "localExistingUser" | "optimizedExistingUser" | "localAfterFreeTrial"; diff --git a/core/protocol/ide.ts b/core/protocol/ide.ts index 31369a5eb7..83048f4ef6 100644 --- a/core/protocol/ide.ts +++ b/core/protocol/ide.ts @@ -69,4 +69,6 @@ export type ToIdeFromWebviewOrCoreProtocol = { getGitRootPath: [{ dir: string }, string | undefined]; listDir: [{ dir: string }, [string, FileType][]]; getLastModified: [{ files: string[] }, { [path: string]: number }]; + + getGitHubAuthToken: [undefined, string | undefined]; }; diff --git a/core/protocol/ideWebview.ts b/core/protocol/ideWebview.ts index db22fd2067..b6041e4d8b 100644 --- a/core/protocol/ideWebview.ts +++ b/core/protocol/ideWebview.ts @@ -32,6 +32,7 @@ export type ToIdeFromWebviewProtocol = ToIdeFromWebviewOrCoreProtocol & { insertAtCursor: [{ text: string }, void]; copyText: [{ text: string }, void]; "jetbrains/editorInsetHeight": [{ height: number }, void]; + setGitHubAuthToken: [{ token: string }, void]; }; export type ToWebviewFromIdeProtocol = ToWebviewFromIdeOrCoreProtocol & { @@ -64,4 +65,5 @@ export type ToWebviewFromIdeProtocol = ToWebviewFromIdeOrCoreProtocol & { addApiKey: [undefined, void]; setupLocalModel: [undefined, void]; incrementFtc: [undefined, void]; + openOnboarding: [undefined, void]; }; diff --git a/core/util/filesystem.ts b/core/util/filesystem.ts index cef3ac4199..a0cc2b5e1a 100644 --- a/core/util/filesystem.ts +++ b/core/util/filesystem.ts @@ -13,6 +13,9 @@ import { import { getContinueGlobalPath } from "./paths.js"; class FileSystemIde implements IDE { + async getGitHubAuthToken(): Promise { + return undefined; + } getLastModified(files: string[]): Promise<{ [path: string]: number }> { return new Promise((resolve) => { resolve({ @@ -31,8 +34,8 @@ class FileSystemIde implements IDE { dirent.isDirectory() ? FileType.Directory : dirent.isSymbolicLink() - ? FileType.SymbolicLink - : FileType.File, + ? FileType.SymbolicLink + : FileType.File, ]); return Promise.resolve(all); } diff --git a/core/util/messageIde.ts b/core/util/messageIde.ts index efbc36c0e6..b810450961 100644 --- a/core/util/messageIde.ts +++ b/core/util/messageIde.ts @@ -17,6 +17,9 @@ export class MessageIde implements IDE { data: ToIdeFromWebviewOrCoreProtocol[T][0], ) => Promise, ) {} + getGitHubAuthToken(): Promise { + return this.request("getGitHubAuthToken", undefined); + } getLastModified(files: string[]): Promise<{ [path: string]: number }> { return this.request("getLastModified", { files }); } diff --git a/extensions/vscode/src/commands.ts b/extensions/vscode/src/commands.ts index cde4f748f8..8c1927c821 100644 --- a/extensions/vscode/src/commands.ts +++ b/extensions/vscode/src/commands.ts @@ -478,9 +478,6 @@ const commandsMap: ( panel.webview.html = sidebar.getSidebarContent( extensionContext, panel, - ide, - configHandler, - verticalDiffManager, undefined, undefined, true, diff --git a/extensions/vscode/src/debugPanel.ts b/extensions/vscode/src/debugPanel.ts index b9cd361a7c..41d2f26ea5 100644 --- a/extensions/vscode/src/debugPanel.ts +++ b/extensions/vscode/src/debugPanel.ts @@ -1,7 +1,6 @@ -import type { FileEdit, IDE } from "core"; +import type { FileEdit } from "core"; import type { ConfigHandler } from "core/config/handler"; import * as vscode from "vscode"; -import type { VerticalPerLineDiffManager } from "./diff/verticalPerLine/manager"; import { getTheme } from "./util/getTheme"; import { getExtensionVersion } from "./util/util"; import { getExtensionUri, getNonce, getUniqueId } from "./util/vscode"; @@ -51,7 +50,12 @@ export class ContinueGUIWebviewViewProvider private readonly windowId: string, private readonly extensionContext: vscode.ExtensionContext, ) { - this.webviewProtocol = new VsCodeWebviewProtocol(); + this.webviewProtocol = new VsCodeWebviewProtocol( + (async () => { + const configHandler = await this.configHandlerPromise; + return configHandler.reloadConfig(); + }).bind(this), + ); } getSidebarContent( diff --git a/extensions/vscode/src/extension/VsCodeMessenger.ts b/extensions/vscode/src/extension/VsCodeMessenger.ts index ccc8f86934..7b6ea40f00 100644 --- a/extensions/vscode/src/extension/VsCodeMessenger.ts +++ b/extensions/vscode/src/extension/VsCodeMessenger.ts @@ -68,7 +68,7 @@ export class VsCodeMessenger { >, private readonly webviewProtocol: VsCodeWebviewProtocol, private readonly ide: VsCodeIde, - private readonly verticalDiffManager: VerticalPerLineDiffManager, + private readonly verticalDiffManagerPromise: Promise, ) { /** WEBVIEW ONLY LISTENERS **/ this.onWebview("showFile", (msg) => { @@ -134,7 +134,7 @@ export class VsCodeMessenger { editor.selection = new vscode.Selection(start, end); } - this.verticalDiffManager.streamEdit( + (await this.verticalDiffManagerPromise).streamEdit( `The following code was suggested as an edit:\n\`\`\`\n${msg.data.text}\n\`\`\`\nPlease apply it to the previous code.`, await this.webviewProtocol.request("getDefaultModelTitle", undefined), ); @@ -201,72 +201,72 @@ export class VsCodeMessenger { return ide.getIdeSettings(); }); this.onWebviewOrCore("getDiff", async (msg) => { - return await ide.getDiff(); + return ide.getDiff(); }); this.onWebviewOrCore("getTerminalContents", async (msg) => { - return await ide.getTerminalContents(); + return ide.getTerminalContents(); }); this.onWebviewOrCore("getDebugLocals", async (msg) => { - return await ide.getDebugLocals(Number(msg.data.threadIndex)); + return ide.getDebugLocals(Number(msg.data.threadIndex)); }); this.onWebviewOrCore("getAvailableThreads", async (msg) => { - return await ide.getAvailableThreads(); + return ide.getAvailableThreads(); }); this.onWebviewOrCore("getTopLevelCallStackSources", async (msg) => { - return await ide.getTopLevelCallStackSources( + return ide.getTopLevelCallStackSources( msg.data.threadIndex, msg.data.stackDepth, ); }); this.onWebviewOrCore("listWorkspaceContents", async (msg) => { - return await ide.listWorkspaceContents(); + return ide.listWorkspaceContents(); }); this.onWebviewOrCore("getWorkspaceDirs", async (msg) => { - return await ide.getWorkspaceDirs(); + return ide.getWorkspaceDirs(); }); this.onWebviewOrCore("listFolders", async (msg) => { - return await ide.listFolders(); + return ide.listFolders(); }); this.onWebviewOrCore("writeFile", async (msg) => { - return await ide.writeFile(msg.data.path, msg.data.contents); + return ide.writeFile(msg.data.path, msg.data.contents); }); this.onWebviewOrCore("showVirtualFile", async (msg) => { - return await ide.showVirtualFile(msg.data.name, msg.data.content); + return ide.showVirtualFile(msg.data.name, msg.data.content); }); this.onWebviewOrCore("getContinueDir", async (msg) => { - return await ide.getContinueDir(); + return ide.getContinueDir(); }); this.onWebviewOrCore("openFile", async (msg) => { - return await ide.openFile(msg.data.path); + return ide.openFile(msg.data.path); }); this.onWebviewOrCore("runCommand", async (msg) => { await ide.runCommand(msg.data.command); }); this.onWebviewOrCore("getSearchResults", async (msg) => { - return await ide.getSearchResults(msg.data.query); + return ide.getSearchResults(msg.data.query); }); this.onWebviewOrCore("subprocess", async (msg) => { - return await ide.subprocess(msg.data.command); + return ide.subprocess(msg.data.command); }); this.onWebviewOrCore("getProblems", async (msg) => { - return await ide.getProblems(msg.data.filepath); + return ide.getProblems(msg.data.filepath); }); this.onWebviewOrCore("getBranch", async (msg) => { const { dir } = msg.data; - return await ide.getBranch(dir); + return ide.getBranch(dir); }); this.onWebviewOrCore("getOpenFiles", async (msg) => { - return await ide.getOpenFiles(); + return ide.getOpenFiles(); }); this.onWebviewOrCore("getCurrentFile", async () => { - return await ide.getCurrentFile(); + return ide.getCurrentFile(); }); this.onWebviewOrCore("getPinnedFiles", async (msg) => { - return await ide.getPinnedFiles(); + return ide.getPinnedFiles(); }); this.onWebviewOrCore("showLines", async (msg) => { const { filepath, startLine, endLine } = msg.data; - return await ide.showLines(filepath, startLine, endLine); + return ide.showLines(filepath, startLine, endLine); }); // Other this.onWebviewOrCore("errorPopup", (msg) => { @@ -278,5 +278,8 @@ export class VsCodeMessenger { } }); }); + this.onWebviewOrCore("getGitHubAuthToken", (msg) => + ide.getGitHubAuthToken(), + ); } } diff --git a/extensions/vscode/src/extension/vscodeExtension.ts b/extensions/vscode/src/extension/vscodeExtension.ts index 4099e98bec..f9e57dd0c1 100644 --- a/extensions/vscode/src/extension/vscodeExtension.ts +++ b/extensions/vscode/src/extension/vscodeExtension.ts @@ -39,37 +39,6 @@ export class VsCodeExtension { constructor(context: vscode.ExtensionContext) { this.diffManager = new DiffManager(context); this.ide = new VsCodeIde(this.diffManager); - - const ideSettings = this.ide.getIdeSettings(); - const { remoteConfigServerUrl } = ideSettings; - - // Config Handler with output channel - const outputChannel = vscode.window.createOutputChannel( - "Continue - LLM Prompt/Completion", - ); - this.configHandler = new ConfigHandler( - this.ide, - Promise.resolve(ideSettings), - async (log: string) => { - outputChannel.appendLine( - "==========================================================================", - ); - outputChannel.appendLine( - "==========================================================================", - ); - outputChannel.append(log); - }, - (() => this.webviewProtocol?.request("configUpdate", undefined)).bind( - this, - ), - ); - - this.configHandler.reloadConfig(); - this.verticalDiffManager = new VerticalPerLineDiffManager( - this.configHandler, - ); - this.diffManager = new DiffManager(context); - this.ide = new VsCodeIde(this.diffManager, this.webviewProtocolPromise); this.extensionContext = context; this.windowId = uuidv4(); @@ -105,9 +74,10 @@ export class VsCodeExtension { ); resolveWebviewProtocol(this.sidebar.webviewProtocol); - // Indexing + pause token - this.diffManager.webviewProtocol = this.webviewProtocol; - + // Config Handler with output channel + const outputChannel = vscode.window.createOutputChannel( + "Continue - LLM Prompt/Completion", + ); const inProcessMessenger = new InProcessMessenger< ToCoreProtocol, FromCoreProtocol @@ -116,9 +86,36 @@ export class VsCodeExtension { inProcessMessenger, this.webviewProtocol, this.ide, - this.verticalDiffManager, + verticalDiffManagerPromise, ); - this.core = new Core(inProcessMessenger, this.ide); + this.core = new Core(inProcessMessenger, this.ide, async (log: string) => { + outputChannel.appendLine( + "==========================================================================", + ); + outputChannel.appendLine( + "==========================================================================", + ); + outputChannel.append(log); + }); + this.configHandler = this.core.configHandler; + resolveConfigHandler?.(this.configHandler); + this.configHandler.onConfigUpdate(() => { + this.webviewProtocol?.request("configUpdate", undefined); + }); + + this.configHandler.reloadConfig(); + this.verticalDiffManager = new VerticalPerLineDiffManager( + this.configHandler, + ); + resolveVerticalDiffManager?.(this.verticalDiffManager); + this.tabAutocompleteModel = new TabAutocompleteModel(this.configHandler); + + setupRemoteConfigSync( + this.configHandler.reloadConfig.bind(this.configHandler), + ); + + // Indexing + pause token + this.diffManager.webviewProtocol = this.webviewProtocol; if ( !( diff --git a/extensions/vscode/src/ideProtocol.ts b/extensions/vscode/src/ideProtocol.ts index 016de4ac8d..ee593e2eb9 100644 --- a/extensions/vscode/src/ideProtocol.ts +++ b/extensions/vscode/src/ideProtocol.ts @@ -37,6 +37,30 @@ class VsCodeIde implements IDE { this.ideUtils = new VsCodeIdeUtils(); } + private authToken: string | undefined; + private askedForAuth = false; + + async getGitHubAuthToken(): Promise { + if (this.authToken) { + return this.authToken; + } + try { + const session = await vscode.authentication.getSession("github", [], { + silent: this.askedForAuth, + createIfNone: !this.askedForAuth, + }); + if (session) { + this.authToken = session.accessToken; + return session.accessToken; + } + } catch (error) { + console.error("Failed to get GitHub authentication session:", error); + } finally { + this.askedForAuth = true; + } + return undefined; + } + async infoPopup(message: string): Promise { vscode.window.showInformationMessage(message); } diff --git a/extensions/vscode/src/webviewProtocol.ts b/extensions/vscode/src/webviewProtocol.ts index 7c21e1040c..9810dd1ce7 100644 --- a/extensions/vscode/src/webviewProtocol.ts +++ b/extensions/vscode/src/webviewProtocol.ts @@ -153,6 +153,26 @@ export class VsCodeWebviewProtocol this.request("setupLocalModel", undefined); } }); + } else if (message.includes("Please sign in with GitHub")) { + vscode.window + .showInformationMessage( + message, + "Sign In", + "Use API key / local model", + ) + .then((selection) => { + if (selection === "Sign In") { + vscode.authentication + .getSession("github", [], { + createIfNone: true, + }) + .then(() => { + this.reloadConfig(); + }); + } else if (selection === "Use API key / local model") { + this.request("openOnboarding", undefined); + } + }); } else { vscode.window .showErrorMessage(message, "Show Logs", "Troubleshooting") @@ -175,7 +195,7 @@ export class VsCodeWebviewProtocol }); } - constructor() {} + constructor(private readonly reloadConfig: () => void) {} invoke( messageType: T, data: ToCoreOrIdeFromWebviewProtocol[T][0], diff --git a/gui/src/App.tsx b/gui/src/App.tsx index 1dafdf7406..c015b54b0a 100644 --- a/gui/src/App.tsx +++ b/gui/src/App.tsx @@ -17,6 +17,7 @@ import useSubmenuContextProviders from "./hooks/useSubmenuContextProviders"; import { useVscTheme } from "./hooks/useVscTheme"; import GUI from "./pages/gui"; import LocalOnboarding from "./pages/localOnboarding"; +import ApiKeyOnboarding from "./pages/onboarding/apiKeyOnboarding"; import ExistingUserOnboarding from "./pages/onboarding/existingUserOnboarding"; import Onboarding from "./pages/onboarding/onboarding"; import Stats from "./pages/stats"; @@ -87,10 +88,6 @@ const router = createMemoryRouter([ path: "/apiKeyOnboarding", element: , }, - { - path: "/apiKeyAutocompleteOnboarding", - element: , - }, ], }, ]); diff --git a/gui/src/components/Layout.tsx b/gui/src/components/Layout.tsx index 03813700ad..e621cc33f0 100644 --- a/gui/src/components/Layout.tsx +++ b/gui/src/components/Layout.tsx @@ -19,7 +19,6 @@ import { defaultModelSelector } from "../redux/selectors/modelSelectors"; import { setBottomMessage, setBottomMessageCloseTimeout, - setDialogMessage, setShowDialog, } from "../redux/slices/uiStateSlice"; import { RootState } from "../redux/store"; @@ -31,7 +30,6 @@ import { ftl } from "./dialogs/FTCDialog"; import IndexingProgressBar from "./loaders/IndexingProgressBar"; import ProgressBar from "./loaders/ProgressBar"; import ModelSelect from "./modelSelection/ModelSelect"; -import QuickModelSetup from "./modelSelection/quickSetup/QuickModelSetup"; // #region Styled Components const FOOTER_HEIGHT = "1.8em"; @@ -93,7 +91,9 @@ const DropdownPortalDiv = styled.div` const HIDE_FOOTER_ON_PAGES = [ "/onboarding", "/existingUserOnboarding", + "/onboarding", "/localOnboarding", + "/apiKeyOnboarding", ]; const Layout = () => { @@ -175,8 +175,15 @@ const Layout = () => { useWebviewListener( "addApiKey", async () => { - dispatch(setShowDialog(true)); - dispatch(setDialogMessage()); + navigate("/apiKeyOnboarding"); + }, + [navigate], + ); + + useWebviewListener( + "openOnboarding", + async () => { + navigate("/onboarding"); }, [navigate], ); @@ -212,11 +219,7 @@ const Layout = () => { !location.pathname.startsWith("/onboarding") && !location.pathname.startsWith("/existingUserOnboarding") ) { - if (getLocalStorage("mainTextEntryCounter")) { - navigate("/existingUserOnboarding"); - } else { - navigate("/onboarding"); - } + navigate("/onboarding"); } }, [location]); diff --git a/gui/src/components/loaders/ProgressBar.tsx b/gui/src/components/loaders/ProgressBar.tsx index 87840062e3..5cddaeb1a2 100644 --- a/gui/src/components/loaders/ProgressBar.tsx +++ b/gui/src/components/loaders/ProgressBar.tsx @@ -1,5 +1,6 @@ import ReactDOM from "react-dom"; import { useDispatch } from "react-redux"; +import { useNavigate } from "react-router-dom"; import styled from "styled-components"; import { StyledTooltip, lightGray, vscForeground } from ".."; import { @@ -52,6 +53,7 @@ interface ProgressBarProps { const ProgressBar = ({ completed, total }: ProgressBarProps) => { const dispatch = useDispatch(); + const navigate = useNavigate(); const fillPercentage = Math.min(100, Math.max(0, (completed / total) * 100)); const tooltipPortalDiv = document.getElementById("tooltip-portal-div"); @@ -62,7 +64,16 @@ const ProgressBar = ({ completed, total }: ProgressBarProps) => { data-tooltip-id="usage_progress_bar" onClick={() => { dispatch(setShowDialog(true)); - dispatch(setDialogMessage()); + dispatch( + setDialogMessage( + { + dispatch(setShowDialog(false)); + navigate("/"); + }} + />, + ), + ); }} > diff --git a/gui/src/components/modelSelection/quickSetup/QuickModelSetup.tsx b/gui/src/components/modelSelection/quickSetup/QuickModelSetup.tsx index 7ae6633243..f002c3dfa9 100644 --- a/gui/src/components/modelSelection/quickSetup/QuickModelSetup.tsx +++ b/gui/src/components/modelSelection/quickSetup/QuickModelSetup.tsx @@ -5,13 +5,14 @@ import { useNavigate } from "react-router-dom"; import { Button, Input, SecondaryButton } from "../.."; import { IdeMessengerContext } from "../../../context/IdeMessenger"; import { setDefaultModel } from "../../../redux/slices/stateSlice"; -import { setShowDialog } from "../../../redux/slices/uiStateSlice"; import { getLocalStorage } from "../../../util/localStorage"; import { PROVIDER_INFO } from "../../../util/modelData"; import { ftl } from "../../dialogs/FTCDialog"; import QuickSetupListBox from "./QuickSetupListBox"; -interface QuickModelSetupProps {} +interface QuickModelSetupProps { + onDone: () => void; +} function QuickModelSetup(props: QuickModelSetupProps) { const [selectedProvider, setSelectedProvider] = useState( @@ -29,14 +30,16 @@ function QuickModelSetup(props: QuickModelSetupProps) { setSelectedModel(selectedProvider.packages[0]); }, [selectedProvider]); + const [hasAddedModel, setHasAddedModel] = useState(false); + return (
-

+ {/*

{getLocalStorage("ftc") > ftl() ? "Set up your own model" : "Add a new model"} -

+ */} {getLocalStorage("ftc") > ftl() && (

@@ -127,18 +130,16 @@ function QuickModelSetup(props: QuickModelSetupProps) { }; ideMessenger.post("config/addModel", { model }); dispatch(setDefaultModel({ title: model.title, force: true })); - navigate("/"); + setHasAddedModel(true); }} className="w-full" > Add Model diff --git a/gui/src/hooks/useInputHistory.ts b/gui/src/hooks/useInputHistory.ts index 21d66ed591..ec3c968805 100644 --- a/gui/src/hooks/useInputHistory.ts +++ b/gui/src/hooks/useInputHistory.ts @@ -12,7 +12,7 @@ const MAX_HISTORY_LENGTH = 100; export function useInputHistory() { const [inputHistory, setInputHistory] = useState( - getLocalStorage("inputHistory").slice(-MAX_HISTORY_LENGTH) ?? [], + getLocalStorage("inputHistory")?.slice(-MAX_HISTORY_LENGTH) ?? [], ); const [pendingInput, setPendingInput] = useState(emptyJsonContent()); diff --git a/gui/src/pages/gui.tsx b/gui/src/pages/gui.tsx index 6f371ba4b7..00aff42ae8 100644 --- a/gui/src/pages/gui.tsx +++ b/gui/src/pages/gui.tsx @@ -32,7 +32,6 @@ import StepContainer from "../components/gui/StepContainer"; import TimelineItem from "../components/gui/TimelineItem"; import ContinueInputBox from "../components/mainInput/ContinueInputBox"; import { defaultInputModifiers } from "../components/mainInput/inputModifiers"; -import QuickModelSetup from "../components/modelSelection/quickSetup/QuickModelSetup"; import { IdeMessengerContext } from "../context/IdeMessenger"; import useChatHandler from "../hooks/useChatHandler"; import useHistory from "../hooks/useHistory"; @@ -248,8 +247,7 @@ function GUI(props: GUIProps) { setLocalStorage("ftc", u + 1); if (u >= ftl()) { - dispatch(setShowDialog(true)); - dispatch(setDialogMessage()); + navigate("/onboarding"); posthog?.capture("ftc_reached"); return; } diff --git a/gui/src/pages/onboarding/apiKeyOnboarding.tsx b/gui/src/pages/onboarding/apiKeyOnboarding.tsx index a086cdc879..a0d6c761cc 100644 --- a/gui/src/pages/onboarding/apiKeyOnboarding.tsx +++ b/gui/src/pages/onboarding/apiKeyOnboarding.tsx @@ -3,7 +3,6 @@ import { useContext } from "react"; import { useNavigate } from "react-router-dom"; import QuickModelSetup from "../../components/modelSelection/quickSetup/QuickModelSetup"; import { IdeMessengerContext } from "../../context/IdeMessenger"; -import { getLocalStorage } from "../../util/localStorage"; function ApiKeyOnboarding() { const ideMessenger = useContext(IdeMessengerContext); @@ -24,12 +23,7 @@ function ApiKeyOnboarding() { { ideMessenger.post("showTutorial", undefined); - - if (getLocalStorage("signedInToGh") === true) { - navigate("/"); - } else { - navigate("/apiKeyAutocompleteOnboarding"); - } + navigate("/"); }} >

diff --git a/gui/src/pages/onboarding/onboarding.tsx b/gui/src/pages/onboarding/onboarding.tsx index 9cd7f67456..d2da11a848 100644 --- a/gui/src/pages/onboarding/onboarding.tsx +++ b/gui/src/pages/onboarding/onboarding.tsx @@ -1,8 +1,16 @@ import { useContext, useState } from "react"; import { useNavigate } from "react-router-dom"; -import { greenButtonColor } from "../../components"; +import { + Button, + Input, + SecondaryButton, + greenButtonColor, + vscForeground, +} from "../../components"; +import { ftl } from "../../components/dialogs/FTCDialog"; import { IdeMessengerContext } from "../../context/IdeMessenger"; -import { setLocalStorage } from "../../util/localStorage"; +import { isJetBrains } from "../../util"; +import { getLocalStorage, setLocalStorage } from "../../util/localStorage"; import { Div, StyledButton } from "./components"; function Onboarding() { @@ -11,16 +19,30 @@ function Onboarding() { const [hovered0, setHovered0] = useState(false); const [hovered1, setHovered1] = useState(false); - const [hovered2, setHovered2] = useState(false); const [selected, setSelected] = useState(-1); + const [jbGhAuthToken, setJbGhAuthToken] = useState(""); + return (
-

Welcome to Continue

-

- Let's find the setup that works best for you -

+ {getLocalStorage("ftc") > ftl() ? ( + <> +

Free trial limit reached

+

+ To keep using Continue, please enter an API key or set up a local + model +

+ + ) : ( + <> +

Welcome to Continue

+

+ Let's find the setup that works best for you +

+ + )} +
setHovered0(true)} onMouseLeave={() => setHovered0(false)} > -

✨ Cloud models

+

✨ Use your API key

- This is the best experience. Continue will use the strongest available - commercial models to index code and answer questions. Code is only - ever stored locally. + Enter an OpenAI or other API key for the best experience. Continue + will use the best available commercial models to index code. Code is + only ever stored locally.

{selected === 0 && (

- Embeddings: Voyage Code 2 + Chat: Whichever model you choose

- Autocomplete: Starcoder 7b via Fireworks AI (free trial) + Embeddings: Voyage Code 2

- Chat: GPT-4, Claude 3, and others (free trial) + Autocomplete: Starcoder 7B via Fireworks AI

)}

@@ -70,60 +92,16 @@ function Onboarding() {
{selected === 1 && (

- Embeddings: Local sentence-transformers model + Chat: Llama 3 with Ollama, LM Studio, etc.

- Autocomplete: Starcoder2 3b (set up with Ollama, LM Studio, - etc.) + Embeddings: Nomic Embed

- Chat: Llama 3 with Ollama, LM Studio, etc. + Autocomplete: Starcoder2 3B

)}

- {/*

- - Read the docs - {" "} - to learn more and fully customize Continue by opening config.json. -

*/} -
setHovered2(true)} - onMouseLeave={() => setHovered2(false)} - onClick={() => { - setSelected(2); - ideMessenger.post("openConfigJson", undefined); - }} - > -

⚙️ Your own models

-

- Continue lets you use your own API key or self-hosted LLMs.{" "} - - Read the docs - {" "} - to learn more about using config.json to customize Continue. This can - always be done later. -

-
- {selected === 2 && ( -

- Use config.json to configure your own{" "} - models,{" "} - - context providers - - ,{" "} - - slash commands - - , and more. -

- )} -
{ ideMessenger.post("completeOnboarding", { - mode: ["optimized", "local", "custom"][selected] as any, + mode: ["apiKeys", "local"][selected] as any, }); setLocalStorage("onboardingComplete", true); @@ -146,15 +124,97 @@ function Onboarding() { } else { // Only needed when we switch from the default (local) embeddings provider ideMessenger.post("index/forceReIndex", undefined); - // Don't show the tutorial above yet because there's another step to complete at /localOnboarding - ideMessenger.post("showTutorial", undefined); - navigate("/"); + navigate("/apiKeyOnboarding"); } }} > Continue
+ + {getLocalStorage("onboardingComplete") || ( + <> +
+ +

+ OR sign in with GitHub to try 25 free requests +

+ {isJetBrains() ? ( +
+
+ { + ideMessenger.post( + "openUrl", + "https://github.com/settings/tokens/new?scopes=user:email&description=Continue%20Free%20Trial%20Token%20", + ); + }} + className="grid grid-flow-col items-center gap-2" + > + + + + Generate Token + +
+ setJbGhAuthToken(e.target.value)} + /> + +
+ ) : ( +
+ { + await ideMessenger.request("getGitHubAuthToken", undefined); + setLocalStorage("onboardingComplete", true); + await ideMessenger.request("completeOnboarding", { + mode: "freeTrial", + }); + navigate("/"); + }} + className="grid grid-flow-col items-center gap-2" + > + + + + Sign in with GitHub + +
+ )} + + )} ); }