Skip to content

Commit

Permalink
Free Trial Auth (#1367)
Browse files Browse the repository at this point in the history
* 🎨 pass gh auth token to free trial

* 💄 improved onboarding flow

* 🎨 make fewer auth reqs

* 🎨 refactor ordering of vscode deps

* 🐛 resolve confighandler
  • Loading branch information
sestinj committed Jun 23, 2024
1 parent d33ec43 commit 48a02e8
Show file tree
Hide file tree
Showing 26 changed files with 382 additions and 263 deletions.
77 changes: 25 additions & 52 deletions core/config/default.ts
Original file line number Diff line number Diff line change
@@ -1,35 +1,33 @@
import {
ContextProviderWithParams,
ModelDescription,
SerializedContinueConfig,
} from "../index.js";

export const FREE_TRIAL_MODELS: ModelDescription[] = [
{
title: "GPT-4o (Free Trial)",
provider: "free-trial",
model: "gpt-4o",
systemMessage:
"You are an expert software developer. You give helpful and concise responses.",
},
{
title: "Llama3 70b (Free Trial)",
provider: "free-trial",
model: "llama3-70b",
systemMessage:
"You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.",
},
{
title: "Claude 3 Sonnet (Free Trial)",
provider: "free-trial",
model: "claude-3-sonnet-20240229",
},
];

export const defaultConfig: SerializedContinueConfig = {
models: [
// {
// title: "Codestral (Free Trial)",
// provider: "free-trial",
// model: "codestral",
// },
{
title: "GPT-4o (Free Trial)",
provider: "free-trial",
model: "gpt-4o",
systemMessage:
"You are an expert software developer. You give helpful and concise responses.",
},
{
title: "Llama3 70b (Free Trial)",
provider: "free-trial",
model: "llama3-70b",
systemMessage:
"You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.",
},
{
title: "Claude 3 Sonnet (Free Trial)",
provider: "free-trial",
model: "claude-3-sonnet-20240229",
},
],
models: FREE_TRIAL_MODELS,
customCommands: [
{
name: "test",
Expand All @@ -46,32 +44,7 @@ export const defaultConfig: SerializedContinueConfig = {
};

export const defaultConfigJetBrains: SerializedContinueConfig = {
models: [
// {
// title: "Codestral (Free Trial)",
// provider: "free-trial",
// model: "codestral",
// },
{
title: "GPT-4o (Free Trial)",
provider: "free-trial",
model: "gpt-4o",
systemMessage:
"You are an expert software developer. You give helpful and concise responses.",
},
{
title: "Llama3 70b (Free Trial)",
provider: "free-trial",
model: "llama3-70b",
systemMessage:
"You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.",
},
{
title: "Claude 3 Sonnet (Free Trial)",
provider: "free-trial",
model: "claude-3-sonnet-20240229",
},
],
models: FREE_TRIAL_MODELS,
customCommands: [
{
name: "test",
Expand Down
10 changes: 4 additions & 6 deletions core/config/handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ export class ConfigHandler {
private readonly ide: IDE,
private ideSettingsPromise: Promise<IdeSettings>,
private readonly writeLog: (text: string) => Promise<void>,
private readonly onConfigUpdate: () => void,
) {
this.ide = ide;
this.ideSettingsPromise = ideSettingsPromise;
Expand All @@ -44,11 +43,10 @@ export class ConfigHandler {
reloadConfig() {
this.savedConfig = undefined;
this.savedBrowserConfig = undefined;
this.loadConfig().then(() => {
for (const listener of this.updateListeners) {
listener();
}
});
this.loadConfig();
for (const listener of this.updateListeners) {
listener();
}
}

async getSerializedConfig(): Promise<BrowserSerializedContinueConfig> {
Expand Down
12 changes: 12 additions & 0 deletions core/config/load.ts
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import TransformersJsEmbeddingsProvider from "../indexing/embeddings/Transformer
import { AllEmbeddingsProviders } from "../indexing/embeddings/index.js";
import { BaseLLM } from "../llm/index.js";
import CustomLLMClass from "../llm/llms/CustomLLM.js";
import FreeTrial from "../llm/llms/FreeTrial.js";
import { llmFromDescription } from "../llm/llms/index.js";
import { IdeSettings } from "../protocol/ideWebview.js";
import { fetchwithRequestOptions } from "../util/fetchWithOptions.js";
Expand Down Expand Up @@ -277,6 +278,17 @@ async function intermediateToFinalConfig(
};
}

// Obtain auth token (only if free trial being used)
const freeTrialModels = models.filter(
(model) => model.providerName === "free-trial",
);
if (freeTrialModels.length > 0) {
const ghAuthToken = await ide.getGitHubAuthToken();
for (const model of freeTrialModels) {
(model as FreeTrial).setupGhAuthToken(ghAuthToken);
}
}

// Tab autocomplete model
let autocompleteLlm: BaseLLM | undefined = undefined;
if (config.tabAutocompleteModel) {
Expand Down
55 changes: 26 additions & 29 deletions core/config/onboarding.ts
Original file line number Diff line number Diff line change
@@ -1,38 +1,12 @@
import { SerializedContinueConfig } from "../index.js";
import { FREE_TRIAL_MODELS } from "./default.js";

export const TRIAL_FIM_MODEL = "codestral-latest";

export function setupOptimizedMode(
export function setupApiKeysMode(
config: SerializedContinueConfig,
): SerializedContinueConfig {
return {
...config,
models: [
// {
// title: "Codestral (Free Trial)",
// provider: "free-trial",
// model: "codestral",
// },
{
title: "GPT-4o (Free Trial)",
provider: "free-trial",
model: "gpt-4o",
systemMessage:
"You are an expert software developer. You give helpful and concise responses.",
},
{
title: "Llama3 70b (Free Trial)",
provider: "free-trial",
model: "llama3-70b",
systemMessage:
"You are an expert software developer. You give helpful and concise responses. Whenever you write a code block you include the language after the opening ticks.",
},
{
title: "Claude 3 Sonnet (Free Trial)",
provider: "free-trial",
model: "claude-3-sonnet-20240229",
},
],
models: config.models.filter((model) => model.provider !== "free-trial"),
tabAutocompleteModel: {
title: "Tab Autocomplete",
provider: "free-trial",
Expand Down Expand Up @@ -96,6 +70,29 @@ export function setupLocalMode(
};
}

export function setupFreeTrialMode(
config: SerializedContinueConfig,
): SerializedContinueConfig {
return {
...config,
models: [
...FREE_TRIAL_MODELS,
...config.models.filter((model) => model.provider !== "free-trial"),
],
tabAutocompleteModel: {
title: "Tab Autocomplete",
provider: "free-trial",
model: "starcoder-7b",
},
embeddingsProvider: {
provider: "free-trial",
},
reranker: {
name: "free-trial",
},
};
}

export function setupLocalAfterFreeTrial(
config: SerializedContinueConfig,
): SerializedContinueConfig {
Expand Down
20 changes: 13 additions & 7 deletions core/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ import { ContextItemId, IDE } from ".";
import { CompletionProvider } from "./autocomplete/completionProvider";
import { ConfigHandler } from "./config/handler";
import {
setupApiKeysMode,
setupFreeTrialMode,
setupLocalAfterFreeTrial,
setupLocalMode,
setupOptimizedExistingUserMode,
setupOptimizedMode,
} from "./config/onboarding";
import { addModel, addOpenAIKey, deleteModel } from "./config/util";
import { ContinueServerClient } from "./continueServer/stubs/client";
Expand Down Expand Up @@ -55,12 +56,15 @@ export class Core {
constructor(
private readonly messenger: IMessenger<ToCoreProtocol, FromCoreProtocol>,
private readonly ide: IDE,
private readonly onWrite: (text: string) => Promise<void> = async () => {},
) {
const ideSettingsPromise = messenger.request("getIdeSettings", undefined);
this.configHandler = new ConfigHandler(
this.ide,
ideSettingsPromise,
async (text: string) => {},
this.onWrite,
);
this.configHandler.onConfigUpdate(
(() => this.messenger.send("configUpdate", undefined)).bind(this),
);

Expand Down Expand Up @@ -469,11 +473,13 @@ export class Core {
editConfigJson(
mode === "local"
? setupLocalMode
: mode === "localAfterFreeTrial"
? setupLocalAfterFreeTrial
: mode === "optimized"
? setupOptimizedMode
: setupOptimizedExistingUserMode,
: mode === "freeTrial"
? setupFreeTrialMode
: mode === "localAfterFreeTrial"
? setupLocalAfterFreeTrial
: mode === "apiKeys"
? setupApiKeysMode
: setupOptimizedExistingUserMode,
);
this.configHandler.reloadConfig();
});
Expand Down
1 change: 1 addition & 0 deletions core/index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ export interface IDE {
getGitRootPath(dir: string): Promise<string | undefined>;
listDir(dir: string): Promise<[string, FileType][]>;
getLastModified(files: string[]): Promise<{ [path: string]: number }>;
getGitHubAuthToken(): Promise<string | undefined>;
}

// Slash Commands
Expand Down
12 changes: 12 additions & 0 deletions core/llm/llms/FreeTrial.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,21 @@ import { streamResponse } from "../stream.js";
class FreeTrial extends BaseLLM {
static providerName: ModelProvider = "free-trial";

private ghAuthToken: string | undefined = undefined;

setupGhAuthToken(ghAuthToken: string | undefined) {
this.ghAuthToken = ghAuthToken;
}

private async _getHeaders() {
if (!this.ghAuthToken) {
throw new Error(
"Please sign in with GitHub in order to use the free trial. If you'd like to use Continue without signing in, you can set up your own local model or API key.",
);
}
return {
"Content-Type": "application/json",
Authorization: `Bearer ${this.ghAuthToken}`,
...(await getHeaders()),
};
}
Expand Down
3 changes: 2 additions & 1 deletion core/protocol/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ export type ToCoreFromIdeOrWebviewProtocol = {
{
mode:
| "local"
| "optimized"
| "apiKeys"
| "custom"
| "freeTrial"
| "localExistingUser"
| "optimizedExistingUser"
| "localAfterFreeTrial";
Expand Down
2 changes: 2 additions & 0 deletions core/protocol/ide.ts
Original file line number Diff line number Diff line change
Expand Up @@ -69,4 +69,6 @@ export type ToIdeFromWebviewOrCoreProtocol = {
getGitRootPath: [{ dir: string }, string | undefined];
listDir: [{ dir: string }, [string, FileType][]];
getLastModified: [{ files: string[] }, { [path: string]: number }];

getGitHubAuthToken: [undefined, string | undefined];
};
2 changes: 2 additions & 0 deletions core/protocol/ideWebview.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ export type ToIdeFromWebviewProtocol = ToIdeFromWebviewOrCoreProtocol & {
insertAtCursor: [{ text: string }, void];
copyText: [{ text: string }, void];
"jetbrains/editorInsetHeight": [{ height: number }, void];
setGitHubAuthToken: [{ token: string }, void];
};

export type ToWebviewFromIdeProtocol = ToWebviewFromIdeOrCoreProtocol & {
Expand Down Expand Up @@ -64,4 +65,5 @@ export type ToWebviewFromIdeProtocol = ToWebviewFromIdeOrCoreProtocol & {
addApiKey: [undefined, void];
setupLocalModel: [undefined, void];
incrementFtc: [undefined, void];
openOnboarding: [undefined, void];
};
7 changes: 5 additions & 2 deletions core/util/filesystem.ts
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ import {
import { getContinueGlobalPath } from "./paths.js";

class FileSystemIde implements IDE {
async getGitHubAuthToken(): Promise<string | undefined> {
return undefined;
}
getLastModified(files: string[]): Promise<{ [path: string]: number }> {
return new Promise((resolve) => {
resolve({
Expand All @@ -31,8 +34,8 @@ class FileSystemIde implements IDE {
dirent.isDirectory()
? FileType.Directory
: dirent.isSymbolicLink()
? FileType.SymbolicLink
: FileType.File,
? FileType.SymbolicLink
: FileType.File,
]);
return Promise.resolve(all);
}
Expand Down
3 changes: 3 additions & 0 deletions core/util/messageIde.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ export class MessageIde implements IDE {
data: ToIdeFromWebviewOrCoreProtocol[T][0],
) => Promise<ToIdeFromWebviewOrCoreProtocol[T][1]>,
) {}
getGitHubAuthToken(): Promise<string | undefined> {
return this.request("getGitHubAuthToken", undefined);
}
getLastModified(files: string[]): Promise<{ [path: string]: number }> {
return this.request("getLastModified", { files });
}
Expand Down
3 changes: 0 additions & 3 deletions extensions/vscode/src/commands.ts
Original file line number Diff line number Diff line change
Expand Up @@ -478,9 +478,6 @@ const commandsMap: (
panel.webview.html = sidebar.getSidebarContent(
extensionContext,
panel,
ide,
configHandler,
verticalDiffManager,
undefined,
undefined,
true,
Expand Down
10 changes: 7 additions & 3 deletions extensions/vscode/src/debugPanel.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import type { FileEdit, IDE } from "core";
import type { FileEdit } from "core";
import type { ConfigHandler } from "core/config/handler";
import * as vscode from "vscode";
import type { VerticalPerLineDiffManager } from "./diff/verticalPerLine/manager";
import { getTheme } from "./util/getTheme";
import { getExtensionVersion } from "./util/util";
import { getExtensionUri, getNonce, getUniqueId } from "./util/vscode";
Expand Down Expand Up @@ -51,7 +50,12 @@ export class ContinueGUIWebviewViewProvider
private readonly windowId: string,
private readonly extensionContext: vscode.ExtensionContext,
) {
this.webviewProtocol = new VsCodeWebviewProtocol();
this.webviewProtocol = new VsCodeWebviewProtocol(
(async () => {
const configHandler = await this.configHandlerPromise;
return configHandler.reloadConfig();
}).bind(this),
);
}

getSidebarContent(
Expand Down
Loading

0 comments on commit 48a02e8

Please sign in to comment.