diff --git a/core/config/load.ts b/core/config/load.ts index 2f4344b23..ab1b7df87 100644 --- a/core/config/load.ts +++ b/core/config/load.ts @@ -1,3 +1,7 @@ +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + import { Config, ContextProviderWithParams, @@ -115,4 +119,18 @@ async function intermediateToFinalConfig( }; } -export { intermediateToFinalConfig, serializedToIntermediateConfig }; +function injectExtensionModelsToFinalConfig(config: ContinueConfig, extensionModels: readonly CustomLLM[]) { + var configWithExtensionModels = {...config}; + configWithExtensionModels.extensionModels = [...extensionModels]; + + var models = [...config.models]; + for (var modelDescription of extensionModels) { + const model = new CustomLLMClass(modelDescription); + models.push(model); + } + configWithExtensionModels.models = models; + + return configWithExtensionModels; +} + +export { injectExtensionModelsToFinalConfig, intermediateToFinalConfig, serializedToIntermediateConfig }; diff --git a/core/index.d.ts b/core/index.d.ts index 20763fc93..279a83ac3 100644 --- a/core/index.d.ts +++ b/core/index.d.ts @@ -1,3 +1,7 @@ +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + export interface ChunkWithoutID { content: string; startLine: number; @@ -586,6 +590,7 @@ export interface Config { export interface ContinueConfig { allowAnonymousTelemetry?: boolean; models: ILLM[]; + extensionModels?: CustomLLM[]; systemMessage?: string; completionOptions?: BaseCompletionOptions; slashCommands?: SlashCommand[]; diff --git a/extensions/vscode/src/activation/activate.ts b/extensions/vscode/src/activation/activate.ts index b4c267452..31a737752 100644 --- a/extensions/vscode/src/activation/activate.ts +++ b/extensions/vscode/src/activation/activate.ts @@ -1,3 +1,7 @@ +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + import { getTsConfigPath } from "core/util/paths"; import * as fs from "fs"; import path from "path"; @@ -86,7 +90,8 @@ export async function activateExtension(context: vscode.ExtensionContext) { setupInlineTips(context); showRefactorMigrationMessage(); - ideProtocolClient = new IdeProtocolClient(context); + ideProtocolClient = new IdeProtocolClient(); + context.subscriptions.push(ideProtocolClient); // Register Continue GUI as sidebar webview, and beginning a new session const provider = new ContinueGUIWebviewViewProvider(); @@ -157,4 +162,10 @@ export async function activateExtension(context: vscode.ExtensionContext) { } catch (e) { console.log("Error adding .continueignore file icon: ", e); } + + const extensionApi = { + "addExtensionModel": ideProtocolClient.addExtensionModel, + "removeExtensionModel": ideProtocolClient.removeExtensionModel, + }; + return extensionApi; } diff --git a/extensions/vscode/src/continueIdeClient.ts b/extensions/vscode/src/continueIdeClient.ts index 38cb966cc..e084bd93d 100644 --- a/extensions/vscode/src/continueIdeClient.ts +++ b/extensions/vscode/src/continueIdeClient.ts @@ -1,4 +1,8 @@ -import { FileEdit, RangeInFile } from "core"; +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + +import { CustomLLM, FileEdit, RangeInFile } from "core"; import { getConfigJsonPath, getDevDataFilePath } from "core/util/paths"; import { readFileSync, writeFileSync } from "fs"; import * as path from "path"; @@ -20,22 +24,26 @@ import { openEditorAndRevealRange, uriFromFilePath, } from "./util/vscode"; +import { VsCodeIde } from "./ideProtocol"; const util = require("util"); const exec = util.promisify(require("child_process").exec); const continueVirtualDocumentScheme = "continue"; -class IdeProtocolClient { +class IdeProtocolClient implements vscode.Disposable { private static PREVIOUS_BRANCH_FOR_WORKSPACE_DIR: { [dir: string]: string } = {}; - private readonly context: vscode.ExtensionContext; + private disposable: vscode.Disposable; - constructor(context: vscode.ExtensionContext) { - this.context = context; + constructor() { + const extensionModelsChangeSubscription = configHandler.onConfigChanged( + e => { this.configUpdate(e.config); }, + this + ); // Listen for file saving - vscode.workspace.onDidSaveTextDocument((event) => { + const configFileChangeSubscription = vscode.workspace.onDidSaveTextDocument((event) => { const filepath = event.uri.fsPath; if ( @@ -45,10 +53,7 @@ class IdeProtocolClient { filepath.endsWith(".continue\\config.ts") || filepath.endsWith(".continuerc.json") ) { - const config = readFileSync(getConfigJsonPath(), "utf8"); - const configJson = JSON.parse(config); - this.configUpdate(configJson); - configHandler.reloadConfig(); + configHandler.reloadConfig(new VsCodeIde()); } else if ( filepath.endsWith(".continueignore") || filepath.endsWith(".gitignore") @@ -96,7 +101,10 @@ class IdeProtocolClient { return uri.query; } })(); - context.subscriptions.push( + + this.disposable = vscode.Disposable.from( + extensionModelsChangeSubscription, + configFileChangeSubscription, vscode.workspace.registerTextDocumentContentProvider( continueVirtualDocumentScheme, documentContentProvider @@ -104,6 +112,10 @@ class IdeProtocolClient { ); } + dispose() { + this.disposable.dispose(); + } + visibleMessages: Set = new Set(); configUpdate(config: any) { @@ -113,6 +125,26 @@ class IdeProtocolClient { }); } + addExtensionModel(customLLM: CustomLLM, modelAddedCallback?: () => void, modelRemovedCallback?: () => void): vscode.Disposable { + var eventSubscription = configHandler.onExtensionModelsChange(e => { + if (modelAddedCallback && e.added?.find(addedModelTitle => addedModelTitle === customLLM.options?.title)) { + modelAddedCallback(); + } + + if (modelRemovedCallback && e.removed?.find(removedModelTitle => removedModelTitle === customLLM.options?.title)) { + modelRemovedCallback(); + } + }); + + configHandler.addExtensionModel(customLLM); + + return eventSubscription; + } + + removeExtensionModel(customLLMTitle: string) { + configHandler.removeExtensionModel(customLLMTitle); + } + async gotoDefinition( filepath: string, position: vscode.Position diff --git a/extensions/vscode/src/debugPanel.ts b/extensions/vscode/src/debugPanel.ts index 8e813d1b1..dcbe75e19 100644 --- a/extensions/vscode/src/debugPanel.ts +++ b/extensions/vscode/src/debugPanel.ts @@ -1,3 +1,7 @@ +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + import { ContextItemId, DiffLine, FileEdit, ModelDescription } from "core"; import { indexDocs } from "core/indexing/docs"; import TransformersJsEmbeddingsProvider from "core/indexing/embeddings/TransformersJsEmbeddingsProvider"; @@ -448,6 +452,10 @@ export function getSidebarContent( ideProtocolClient.logDevData(data.tableName, data.data); break; } + case "onLoad": + const config = await configHandler.loadConfig(ide); + ideProtocolClient.configUpdate(config); + break; case "addModel": { const model = data.model; const config = readFileSync(getConfigJsonPath(), "utf8"); @@ -461,7 +469,8 @@ export function getSidebarContent( 2 ); writeFileSync(getConfigJsonPath(), newConfigString); - ideProtocolClient.configUpdate(configJson); + // ideProtocolClient.configUpdate(configJson); + await configHandler.reloadConfig(ide); ideProtocolClient.openFile(getConfigJsonPath()); @@ -502,13 +511,17 @@ export function getSidebarContent( break; } case "deleteModel": { - const configJson = editConfigJson((config) => { - config.models = config.models.filter( - (m: any) => m.title !== data.title - ); - return config; - }); - ideProtocolClient.configUpdate(configJson); + // if the model is an extension model, we let configHandler handle the removal + if (!configHandler.removeExtensionModel(data.title)) { + // otherwise, we need to remove it from the config JSON and reload the config manually + const configJson = editConfigJson((config) => { + config.models = config.models.filter( + (m: any) => m.title !== data.title + ); + return config; + }); + await configHandler.reloadConfig(ide); + } break; } case "addOpenAIKey": { @@ -522,7 +535,7 @@ export function getSidebarContent( }); return config; }); - ideProtocolClient.configUpdate(configJson); + await configHandler.reloadConfig(ide); break; } case "llmStreamComplete": { diff --git a/extensions/vscode/src/extension.ts b/extensions/vscode/src/extension.ts index 725d09480..3782e88f5 100644 --- a/extensions/vscode/src/extension.ts +++ b/extensions/vscode/src/extension.ts @@ -1,5 +1,7 @@ /** * This is the entry point for the extension. + * + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. */ import * as vscode from "vscode"; @@ -31,7 +33,8 @@ async function dynamicImportAndActivate(context: vscode.ExtensionContext) { const { activateExtension } = await import("./activation/activate"); try { - await activateExtension(context); + const ownApi = await activateExtension(context); + return ownApi; } catch (e) { console.log("Error activating extension: ", e); vscode.window @@ -52,7 +55,7 @@ async function dynamicImportAndActivate(context: vscode.ExtensionContext) { } export function activate(context: vscode.ExtensionContext) { - dynamicImportAndActivate(context); + return dynamicImportAndActivate(context); } export function deactivate() { diff --git a/extensions/vscode/src/loadConfig.ts b/extensions/vscode/src/loadConfig.ts index 3f208eee6..7a75a682b 100644 --- a/extensions/vscode/src/loadConfig.ts +++ b/extensions/vscode/src/loadConfig.ts @@ -1,29 +1,117 @@ -import { ContinueConfig, IDE, ILLM } from "core"; +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + +import { ContinueConfig, CustomLLM, IDE, ILLM } from "core"; import * as fs from "fs"; import { Agent, ProxyAgent, fetch } from "undici"; import * as vscode from "vscode"; import { webviewRequest } from "./debugPanel"; import { VsCodeIde, loadFullConfigNode } from "./ideProtocol"; +import { EventEmitter } from "vscode"; +import { injectExtensionModelsToFinalConfig } from "core/config/load"; const tls = require("tls"); const outputChannel = vscode.window.createOutputChannel( "Continue - LLM Prompt/Completion" ); +export interface ExtensionModelsChangeEvent { + readonly added?: readonly string[]; + readonly removed?: readonly string[]; +} + +export interface ConfigChangeEvent { + readonly config: ContinueConfig; +} + class VsCodeConfigHandler { savedConfig: ContinueConfig | undefined; - reloadConfig() { + savedConfigWithExtensionModels: ContinueConfig | undefined; + private extensionModels: CustomLLM[]; + private extensionModelsChangeEventEmitter: EventEmitter; + private configChangeEventEmitter: EventEmitter; + + constructor() { + this.extensionModels = []; + this.extensionModelsChangeEventEmitter = new EventEmitter(); + this.configChangeEventEmitter = new EventEmitter(); + } + + addExtensionModel(customLLM: CustomLLM) { + if (!customLLM.options?.title) { + throw new Error("The custom model must define a unique title.") + } + if (this.savedConfig?.models.some(knownModel => knownModel.title === customLLM.options?.title)) { + throw new Error("The title is already taken by another model.") + } + + this.extensionModels.push(customLLM); + this.extensionModelsChangeEventEmitter.fire({ + added: [customLLM.options.title], + }) + + if (this.savedConfig) { + this.savedConfigWithExtensionModels = injectExtensionModelsToFinalConfig(this.savedConfig, this.extensionModels); + this.fireConfigChanged(); + } + } + + removeExtensionModel(title: string): boolean { + const customLLM = this.extensionModels.find(knownModel => knownModel.options?.title === title); + if (customLLM) { + this.extensionModels = this.extensionModels.filter(knownModel => knownModel.options?.title !== title); + this.extensionModelsChangeEventEmitter.fire({ + removed: [title], + }); + if (this.savedConfig) { + this.savedConfigWithExtensionModels = injectExtensionModelsToFinalConfig(this.savedConfig, this.extensionModels); + this.fireConfigChanged(); + } + return true; + } + return false; + } + + get onExtensionModelsChange() { + return this.extensionModelsChangeEventEmitter.event; + } + + get onConfigChanged() { + return this.configChangeEventEmitter.event; + } + + private fireConfigChanged() { + if (this.savedConfigWithExtensionModels) { + this.configChangeEventEmitter.fire({ + config: this.savedConfigWithExtensionModels + }); + } else if (this.savedConfig) { + this.configChangeEventEmitter.fire({ + config: this.savedConfig + }); + } + } + + async reloadConfig(ide: IDE) { this.savedConfig = undefined; + this.savedConfigWithExtensionModels = undefined; + await this.loadConfig(ide); } async loadConfig(ide: IDE): Promise { - if (this.savedConfig) { - return this.savedConfig; + if (this.savedConfigWithExtensionModels) { + return this.savedConfigWithExtensionModels; + } + if (!this.savedConfig) { + this.savedConfig = await loadFullConfigNode(ide); } - this.savedConfig = await loadFullConfigNode(ide); - return this.savedConfig; + this.savedConfigWithExtensionModels = injectExtensionModelsToFinalConfig(this.savedConfig, this.extensionModels); + this.fireConfigChanged(); + return this.savedConfigWithExtensionModels; } + } export const configHandler = new VsCodeConfigHandler(); @@ -31,7 +119,8 @@ export const configHandler = new VsCodeConfigHandler(); const TIMEOUT = 7200; // 7200 seconds = 2 hours export async function llmFromTitle(title?: string): Promise { - let config = await configHandler.loadConfig(new VsCodeIde()); + const ide = new VsCodeIde(); + let config = await configHandler.loadConfig(ide); if (title === undefined) { const resp = await webviewRequest("getDefaultModelTitle"); @@ -45,8 +134,8 @@ export async function llmFromTitle(title?: string): Promise { : config.models[0]; if (!llm) { // Try to reload config - configHandler.reloadConfig(); - config = await configHandler.loadConfig(new VsCodeIde()); + await configHandler.reloadConfig(ide); + config = await configHandler.loadConfig(ide); llm = config.models.find((llm) => llm.title === title); if (!llm) { throw new Error(`Unknown model ${title}`); diff --git a/gui/src/hooks/useSetup.ts b/gui/src/hooks/useSetup.ts index 5d108be84..f12f8a762 100644 --- a/gui/src/hooks/useSetup.ts +++ b/gui/src/hooks/useSetup.ts @@ -1,9 +1,15 @@ +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + import { Dispatch } from "@reduxjs/toolkit"; import { useEffect } from "react"; import { setServerStatusMessage } from "../redux/slices/miscSlice"; import { errorPopup, isJetBrains, postToIde } from "../util/ide"; +import { ContinueConfig, CustomLLM } from "core"; import { + injectExtensionModelsToFinalConfig, intermediateToFinalConfig, serializedToIntermediateConfig, } from "core/config/load"; @@ -16,7 +22,7 @@ import { RootStore } from "../redux/store"; import useChatHandler from "./useChatHandler"; function useSetup(dispatch: Dispatch) { - const loadConfig = async () => { + const loadConfig = async (previousConfig?: ContinueConfig) => { try { const ide = new ExtensionIde(); let serialized = await ide.getSerializedConfig(); @@ -42,11 +48,15 @@ function useSetup(dispatch: Dispatch) { intermediate, async (filepath) => { return new ExtensionIde().readFile(filepath); - } + }, + ); + const finalConfigWithExtensionModels = injectExtensionModelsToFinalConfig( + finalConfig, + (previousConfig?.extensionModels as (CustomLLM[] | undefined)) || [] ); // Fall back to config.json - dispatch(setConfig(finalConfig)); + dispatch(setConfig(finalConfigWithExtensionModels)); } catch (e) { console.log("Error loading config.json: ", e); errorPopup(e.message); @@ -93,7 +103,8 @@ function useSetup(dispatch: Dispatch) { dispatch(setInactive()); break; case "configUpdate": - loadConfig(); + const config = event.data.config as ContinueConfig; + loadConfig(config); break; case "submitMessage": streamResponse(event.data.message); diff --git a/gui/src/redux/slices/stateSlice.ts b/gui/src/redux/slices/stateSlice.ts index b7ea70808..3619633a8 100644 --- a/gui/src/redux/slices/stateSlice.ts +++ b/gui/src/redux/slices/stateSlice.ts @@ -1,3 +1,7 @@ +/** + * 2024-02 Modified by Lukas Prediger, Copyright (c) 2023 CSC - IT Center for Science Ltd. + */ + import { createSlice } from "@reduxjs/toolkit"; import { JSONContent } from "@tiptap/react"; import { @@ -99,6 +103,7 @@ const initialState: RootStore["state"] = { title: "GPT-3.5-Turbo (Free Trial)", }), ], + extensionModels: [], slashCommands: [ EditSlashCommand, CommentSlashCommand,