From 305fd5ee3f3cee5662f8f78a7a67009b85ac8cde Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Mon, 1 Apr 2024 15:05:29 -0600 Subject: [PATCH 1/4] feat: model/base-url settings for AI completion, bring out of experimental --- docs/guides/ai_completion.md | 21 ++- .../app-config/user-config-form.tsx | 59 +++++++ .../editor/actions/useCellActionButton.tsx | 5 +- .../config/__tests__/config-schema.test.ts | 2 + frontend/src/core/config/config-schema.ts | 13 +- frontend/src/core/config/feature-flag.ts | 7 +- marimo/_config/config.py | 100 +++++++++-- marimo/_config/manager.py | 20 ++- marimo/_config/utils.py | 4 +- marimo/_server/api/endpoints/ai.py | 37 +++- tests/_config/test_config.py | 79 ++++++++- tests/_config/test_manager.py | 76 ++++++++ tests/_server/api/endpoints/test_ai.py | 163 ++++++++++++++++-- tests/_server/conftest.py | 4 + 14 files changed, 526 insertions(+), 64 deletions(-) create mode 100644 tests/_config/test_manager.py diff --git a/docs/guides/ai_completion.md b/docs/guides/ai_completion.md index befc23f33b..4c222f959f 100644 --- a/docs/guides/ai_completion.md +++ b/docs/guides/ai_completion.md @@ -26,14 +26,13 @@ This feature is currently experimental and is not enabled by default. To enable 1. You need add the following to your `~/.marimo.toml`: ```toml -[experimental] -ai = true -``` - -2. Add your OpenAI API key to your environment: - -```bash -export OPENAI_API_KEY=your-api-key +[ai.open_ai] +# Get your API key from https://platform.openai.com/account/api-keys +api_key = "sk-..." +# Choose a model, we recommend "gpt-3.5-turbo" +model = "gpt-3.5-turbo" +# Change the base_url if you are using a different OpenAI-compatible API +base_url = "https://api.openai.com" ``` Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell. This will open an input to modify the cell using AI. @@ -44,3 +43,9 @@ Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell
Use AI to modify a cell by pressing `Ctrl/Cmd-Shift-e`.
+ +### Using other AI providers + +marimo supports OpenAI's GPT-3.5 API by default. If your provider is compatible with OpenAI's API, you can use it by changing the `base_url` in the configuration. + +For other providers not compatible with OpenAI's API, please submit a [feature request](https://github.com/marimo-team/marimo/issues) or "thumbs up" an existing one. diff --git a/frontend/src/components/app-config/user-config-form.tsx b/frontend/src/components/app-config/user-config-form.tsx index 0be0f6e412..2f68f73449 100644 --- a/frontend/src/components/app-config/user-config-form.tsx +++ b/frontend/src/components/app-config/user-config-form.tsx @@ -23,6 +23,7 @@ import { SettingTitle, SettingDescription, SettingSubtitle } from "./common"; import { THEMES } from "@/theme/useTheme"; import { isPyodide } from "@/core/pyodide/utils"; import { PackageManagerNames } from "../../core/config/config-schema"; +import { Kbd } from "../ui/kbd"; export const UserConfigForm: React.FC = () => { const [config, setConfig] = useUserConfig(); @@ -335,6 +336,64 @@ export const UserConfigForm: React.FC = () => { )} /> +
+ AI Assist +

+ You will need to store an API key in your{" "} + ~/.marimo.toml file. See the{" "} + + documentation + {" "} + for more information. +

+ ( + + Base URL + + field.onChange(e.target.value)} + /> + + + + )} + /> + ( + + Model + + onChange(e.target.value)} + /> + + + + )} + /> +
GitHub Copilot { @@ -72,6 +72,7 @@ export function useCellActionButtons({ cell }: Props) { const runCell = useRunCell(cell?.cellId); const { openModal } = useImperativeModal(); const setAiCompletionCell = useSetAtom(aiCompletionCellAtom); + const [userConfig] = useUserConfig(); if (!cell) { return []; } @@ -159,7 +160,7 @@ export function useCellActionButtons({ cell }: Props) { { icon: , label: "AI completion", - hidden: !getFeatureFlag("ai"), + hidden: !userConfig.ai.open_ai?.api_key, handle: () => { setAiCompletionCell((current) => current === cellId ? null : cellId, diff --git a/frontend/src/core/config/__tests__/config-schema.test.ts b/frontend/src/core/config/__tests__/config-schema.test.ts index 606e5f47c6..a93320fb91 100644 --- a/frontend/src/core/config/__tests__/config-schema.test.ts +++ b/frontend/src/core/config/__tests__/config-schema.test.ts @@ -15,6 +15,7 @@ test("default UserConfig - empty", () => { const defaultConfig = UserConfigSchema.parse({}); expect(defaultConfig).toMatchInlineSnapshot(` { + "ai": {}, "completion": { "activate_on_typing": true, "copilot": false, @@ -58,6 +59,7 @@ test("default UserConfig - one level", () => { }); expect(defaultConfig).toMatchInlineSnapshot(` { + "ai": {}, "completion": { "activate_on_typing": true, "copilot": false, diff --git a/frontend/src/core/config/config-schema.ts b/frontend/src/core/config/config-schema.ts index 6a0d3d5f62..db0c136fd7 100644 --- a/frontend/src/core/config/config-schema.ts +++ b/frontend/src/core/config/config-schema.ts @@ -67,10 +67,19 @@ export const UserConfigSchema = z manager: z.enum(PackageManagerNames).default("pip"), }) .default({ manager: "pip" }), - experimental: z + ai: z .object({ - ai: z.boolean().optional(), + open_ai: z + .object({ + api_key: z.string().optional(), + base_url: z.string().optional(), + model: z.string().optional(), + }) + .optional(), }) + .default({}), + experimental: z + .object({}) // Pass through so that we don't remove any extra keys that the user has added. .passthrough() .default({}), diff --git a/frontend/src/core/config/feature-flag.ts b/frontend/src/core/config/feature-flag.ts index a787d3f9da..d180a46788 100644 --- a/frontend/src/core/config/feature-flag.ts +++ b/frontend/src/core/config/feature-flag.ts @@ -7,13 +7,10 @@ import { getUserConfig } from "./config"; // eslint-disable-next-line @typescript-eslint/no-empty-interface export interface ExperimentalFeatures { - // None yet - ai: boolean; + // Add new feature flags here } -const defaultValues: ExperimentalFeatures = { - ai: process.env.NODE_ENV === "development", -}; +const defaultValues: ExperimentalFeatures = {}; export function getFeatureFlag( feature: T, diff --git a/marimo/_config/config.py b/marimo/_config/config.py index de38f63aa5..1eef52a294 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -1,7 +1,7 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import Any, Dict, Literal, TypedDict, Union, cast +from typing import Any, Dict, Literal, Optional, TypedDict, Union, cast from marimo._output.rich_help import mddoc from marimo._utils.deep_merge import deep_merge @@ -120,6 +120,32 @@ class PackageManagementConfig(TypedDict, total=False): manager: Literal["pip", "rye", "uv", "poetry", "pixi"] +class AiConfig(TypedDict, total=False): + """Configuration options for AI. + + **Keys.** + + - `open_ai`: the OpenAI config + """ + + open_ai: Optional[OpenAiConfig] + + +class OpenAiConfig(TypedDict, total=False): + """Configuration options for OpenAI or OpenAI-compatible services. + + **Keys.** + + - `api_key`: the OpenAI API key + - `model`: the model to use + - `base_url`: the base URL for the API + """ + + api_key: Optional[str] + model: Optional[str] + base_url: Optional[str] + + @mddoc class MarimoConfig(TypedDict, total=False): """Configuration for the marimo editor. @@ -127,17 +153,6 @@ class MarimoConfig(TypedDict, total=False): A marimo configuration is a Python `dict`. Configurations can be partially specified, with just a subset of possible keys. Partial configs will be augmented with default options. - - Use with `configure` to configure the editor. See `configure` - documentation for details on how to register the configuration. - - **Example.** - - ```python3 - config: mo.config.MarimoConfig = { - "completion": {"activate_on_typing": True}, - } - ``` """ completion: CompletionConfig @@ -148,6 +163,7 @@ class MarimoConfig(TypedDict, total=False): save: SaveConfig server: ServerConfig package_management: PackageManagementConfig + ai: AiConfig experimental: Dict[str, Any] @@ -171,11 +187,67 @@ class MarimoConfig(TypedDict, total=False): } -def merge_config(config: MarimoConfig) -> MarimoConfig: +def merge_default_config(config: MarimoConfig) -> MarimoConfig: """Merge a user configuration with the default configuration.""" + return merge_config(DEFAULT_CONFIG, config) + + +def merge_config( + config: MarimoConfig, new_config: MarimoConfig +) -> MarimoConfig: + """Merge a user configuration with a new configuration.""" return cast( MarimoConfig, deep_merge( - cast(Dict[Any, Any], DEFAULT_CONFIG), cast(Dict[Any, Any], config) + cast(Dict[Any, Any], config), cast(Dict[Any, Any], new_config) ), ) + + +def _deep_copy(obj: Any) -> Any: + if isinstance(obj, dict): + return {k: _deep_copy(v) for k, v in obj.items()} # type: ignore + if isinstance(obj, list): + return [_deep_copy(v) for v in obj] # type: ignore + return obj + + +SECRET_PLACEHOLDER = "********" + + +def mask_secrets(config: MarimoConfig) -> MarimoConfig: + def deep_remove_from_path(path: list[str], obj: Dict[str, Any]) -> None: + key = path[0] + if key not in obj: + return + if len(path) == 1: + if obj[key]: + obj[key] = SECRET_PLACEHOLDER + else: + deep_remove_from_path(path[1:], cast(Dict[str, Any], obj[key])) + + secrets = [["ai", "open_ai", "api_key"]] + + new_config = _deep_copy(config) + for secret in secrets: + deep_remove_from_path(secret, cast(Dict[str, Any], new_config)) + + return new_config # type: ignore + + +def remove_secret_placeholders(config: MarimoConfig) -> MarimoConfig: + def deep_remove(obj: Any) -> Any: + if isinstance(obj, dict): + # Filter all keys with value SECRET_PLACEHOLDER + return { + k: deep_remove(v) + for k, v in obj.items() + if v != SECRET_PLACEHOLDER + } # type: ignore + if isinstance(obj, list): + return [deep_remove(v) for v in obj] # type: ignore + if obj == SECRET_PLACEHOLDER: + return None + return obj + + return deep_remove(_deep_copy(config)) # type: ignore diff --git a/marimo/_config/manager.py b/marimo/_config/manager.py index 6936caa8e4..1b8dae3432 100644 --- a/marimo/_config/manager.py +++ b/marimo/_config/manager.py @@ -4,7 +4,13 @@ import tomlkit from marimo import _loggers -from marimo._config.config import MarimoConfig, merge_config +from marimo._config.config import ( + MarimoConfig, + mask_secrets, + merge_config, + merge_default_config, + remove_secret_placeholders, +) from marimo._config.utils import CONFIG_FILENAME, get_config_path, load_config LOGGER = _loggers.marimo_logger() @@ -17,13 +23,19 @@ def __init__(self) -> None: def save_config(self, config: MarimoConfig) -> MarimoConfig: config_path = self._get_config_path() LOGGER.debug("Saving user configuration to %s", config_path) + # Remove the secret placeholders from the incoming config + config = remove_secret_placeholders(config) + # Merge the current config with the new config + merged = merge_config(self.config, config) with open(config_path, "w", encoding="utf-8") as f: - tomlkit.dump(config, f) + tomlkit.dump(merged, f) - self.config = merge_config(config) + self.config = merge_default_config(merged) return self.config - def get_config(self) -> MarimoConfig: + def get_config(self, hide_secrets: bool = True) -> MarimoConfig: + if hide_secrets: + return mask_secrets(self.config) return self.config def _get_config_path(self) -> str: diff --git a/marimo/_config/utils.py b/marimo/_config/utils.py index 173283f177..991e3d6adf 100644 --- a/marimo/_config/utils.py +++ b/marimo/_config/utils.py @@ -8,7 +8,7 @@ from marimo._config.config import ( DEFAULT_CONFIG, MarimoConfig, - merge_config, + merge_default_config, ) LOGGER = _loggers.marimo_logger() @@ -103,7 +103,7 @@ def load_config() -> MarimoConfig: LOGGER.error("Failed to read user config at %s", path) LOGGER.error(str(e)) return DEFAULT_CONFIG - return merge_config(cast(MarimoConfig, user_config)) + return merge_default_config(cast(MarimoConfig, user_config)) else: LOGGER.debug("No config found; loading default settings.") return DEFAULT_CONFIG diff --git a/marimo/_server/api/endpoints/ai.py b/marimo/_server/api/endpoints/ai.py index 894a4c1c4b..cb8581484a 100644 --- a/marimo/_server/api/endpoints/ai.py +++ b/marimo/_server/api/endpoints/ai.py @@ -1,8 +1,8 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -import os -from typing import Generator +from ast import Dict +from typing import Generator, Optional from starlette.authentication import requires from starlette.exceptions import HTTPException @@ -37,16 +37,41 @@ async def ai_completion( app_state = AppState(request) app_state.require_current_session() + config = app_state.config_manager.get_config(hide_secrets=False) body = await parse_request(request, cls=AiCompletionRequest) - key = os.environ.get("OPENAI_API_KEY") + if "ai" not in config: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail="OpenAI not configured" + ) + if "open_ai" not in config["ai"]: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, detail="OpenAI not configured" + ) + if "api_key" not in config["ai"]["open_ai"]: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail="OpenAI API key not configured", + ) + + key: str = config["ai"]["open_ai"]["api_key"] if not key: raise HTTPException( status_code=HTTPStatus.BAD_REQUEST, - detail="OpenAI API key not found in environment", + detail="OpenAI API key not configured", ) + base_url: Optional[str] = ( + config.get("ai", {}).get("open_ai", {}).get("base_url", None) + ) + if not base_url: + base_url = None + model: str = ( + config.get("ai", {}).get("open_ai", {}).get("model", "gpt-3.5-turbo") + ) + if not model: + model = "gpt-3.5-turbo" - client = OpenAI(api_key=key) + client = OpenAI(api_key=key, base_url=base_url) system_prompt = ( "You are a helpful assistant that can answer questions " @@ -64,7 +89,7 @@ async def ai_completion( prompt = f"{prompt}\n\nCurrent code:\n{body.code}" response = client.chat.completions.create( - model="gpt-3.5-turbo", + model=model, messages=[ { "role": "system", diff --git a/tests/_config/test_config.py b/tests/_config/test_config.py index 7fafefa1f2..2b0edf8570 100644 --- a/tests/_config/test_config.py +++ b/tests/_config/test_config.py @@ -1,9 +1,16 @@ # Copyright 2024 Marimo. All rights reserved. -from marimo._config.config import DEFAULT_CONFIG, MarimoConfig, merge_config +from marimo._config.config import ( + DEFAULT_CONFIG, + MarimoConfig, + mask_secrets, + merge_config, + merge_default_config, + remove_secret_placeholders, +) def assert_config(override: MarimoConfig) -> None: - user_config = merge_config(override) + user_config = merge_default_config(override) assert user_config == {**DEFAULT_CONFIG, **override} @@ -28,3 +35,71 @@ def test_configure_full() -> None: def test_configure_unknown() -> None: assert_config({"super cool future config key": {"secret": "value"}}) # type: ignore[typeddict-unknown-key] # noqa: E501 + + +def test_merge_config() -> None: + prev_config = merge_default_config( + MarimoConfig( + ai={ + "open_ai": { + "api_key": "super_secret", + } + }, + ) + ) + assert prev_config["ai"]["open_ai"]["api_key"] == "super_secret" + + new_config = merge_config( + prev_config, + MarimoConfig( + ai={ + "open_ai": { + "model": "davinci", + } + }, + ), + ) + + assert new_config["ai"]["open_ai"]["api_key"] == "super_secret" + assert new_config["ai"]["open_ai"]["model"] == "davinci" + + +def test_mask_secrets() -> None: + config = MarimoConfig(ai={"open_ai": {"api_key": "super_secret"}}) + assert config["ai"]["open_ai"]["api_key"] == "super_secret" + + new_config = mask_secrets(config) + assert new_config["ai"]["open_ai"]["api_key"] == "********" + + # Ensure the original config is not modified + assert config["ai"]["open_ai"]["api_key"] == "super_secret" + + +def test_mask_secrets_empty() -> None: + config = MarimoConfig(ai={"open_ai": {"model": "davinci"}}) + assert config["ai"]["open_ai"]["model"] == "davinci" + + new_config = mask_secrets(config) + assert new_config["ai"]["open_ai"]["model"] == "davinci" + # Not added until the key is present + assert "api_key" not in new_config["ai"]["open_ai"] + + # Ensure the original config is not modified + assert config["ai"]["open_ai"]["model"] == "davinci" + + # Not added when key is "" + config["ai"]["open_ai"]["api_key"] = "" + new_config = mask_secrets(config) + assert new_config["ai"]["open_ai"]["api_key"] == "" + assert config["ai"]["open_ai"]["api_key"] == "" + + +def test_remove_secret_placeholders() -> None: + config = MarimoConfig(ai={"open_ai": {"api_key": "********"}}) + assert config["ai"]["open_ai"]["api_key"] == "********" + + new_config = remove_secret_placeholders(config) + assert new_config["ai"]["open_ai"]["api_key"] is None + + # Ensure the original config is not modified + assert config["ai"]["open_ai"]["api_key"] == "********" diff --git a/tests/_config/test_manager.py b/tests/_config/test_manager.py new file mode 100644 index 0000000000..2108be6b3c --- /dev/null +++ b/tests/_config/test_manager.py @@ -0,0 +1,76 @@ +import unittest +from typing import Any +from unittest.mock import patch + +from marimo._config.config import merge_default_config +from marimo._config.manager import MarimoConfig, UserConfigManager + + +class TestUserConfigManager(unittest.TestCase): + @patch("tomlkit.dump") + @patch("marimo._config.manager.load_config") + def test_save_config(self, mock_load: Any, mock_dump: Any) -> None: + mock_config = merge_default_config(MarimoConfig()) + mock_load.return_value = mock_config + manager = UserConfigManager() + + result = manager.save_config(mock_config) + + mock_load.assert_called_once() + assert result == manager.config + + assert mock_dump.mock_calls[0][1][0] == result + + @patch("tomlkit.dump") + @patch("marimo._config.manager.load_config") + def test_can_save_secrets(self, mock_load: Any, mock_dump: Any) -> None: + mock_config = merge_default_config(MarimoConfig()) + mock_load.return_value = mock_config + manager = UserConfigManager() + + manager.save_config( + merge_default_config( + MarimoConfig(ai={"open_ai": {"api_key": "super_secret"}}) + ) + ) + + assert ( + mock_dump.mock_calls[0][1][0]["ai"]["open_ai"]["api_key"] + == "super_secret" + ) + + # Do not overwrite secrets + manager.save_config( + merge_default_config( + MarimoConfig(ai={"open_ai": {"api_key": "********"}}) + ) + ) + assert ( + mock_dump.mock_calls[1][1][0]["ai"]["open_ai"]["api_key"] + == "super_secret" + ) + + @patch("marimo._config.manager.load_config") + def test_can_read_secrets(self, mock_load: Any) -> None: + mock_config = merge_default_config( + MarimoConfig(ai={"open_ai": {"api_key": "super_secret"}}) + ) + mock_load.return_value = mock_config + manager = UserConfigManager() + + assert manager.get_config()["ai"]["open_ai"]["api_key"] == "********" + assert ( + manager.get_config(hide_secrets=False)["ai"]["open_ai"]["api_key"] + == "super_secret" + ) + + @patch("marimo._config.manager.load_config") + def test_get_config(self, mock_load: Any) -> None: + mock_config = merge_default_config(MarimoConfig()) + mock_load.return_value = mock_config + manager = UserConfigManager() + + result = manager.get_config() + + mock_load.assert_called_once() + assert result == manager.config diff --git a/tests/_server/api/endpoints/test_ai.py b/tests/_server/api/endpoints/test_ai.py index 145b087dec..fd4b3acd8b 100644 --- a/tests/_server/api/endpoints/test_ai.py +++ b/tests/_server/api/endpoints/test_ai.py @@ -1,5 +1,4 @@ # Copyright 2024 Marimo. All rights reserved. -import os from contextlib import contextmanager from dataclasses import dataclass from typing import Any, List @@ -8,8 +7,9 @@ import pytest from starlette.testclient import TestClient +from marimo._config.manager import UserConfigManager from marimo._dependencies.dependencies import DependencyManager -from tests._server.conftest import get_session_manager +from tests._server.conftest import get_session_manager, get_user_config_manager from tests._server.mocks import with_session SESSION_ID = "session-123" @@ -47,20 +47,20 @@ def test_completion_without_token( del openai_mock filename = get_session_manager(client).filename assert filename + user_config_manager = get_user_config_manager(client) - response = client.post( - "/api/ai/completion", - headers=HEADERS, - json={ - "prompt": "Help me create a dataframe", - "include_other_code": "", - "code": "", - }, - ) + with no_openai_config(user_config_manager): + response = client.post( + "/api/ai/completion", + headers=HEADERS, + json={ + "prompt": "Help me create a dataframe", + "include_other_code": "", + "code": "", + }, + ) assert response.status_code == 400, response.text - assert response.json() == { - "detail": "OpenAI API key not found in environment" - } + assert response.json() == {"detail": "OpenAI API key not configured"} @staticmethod @with_session(SESSION_ID) @@ -73,6 +73,7 @@ def test_completion_without_code( ) -> None: filename = get_session_manager(client).filename assert filename + user_config_manager = get_user_config_manager(client) oaiclient = MagicMock() openai_mock.return_value = oaiclient @@ -83,7 +84,7 @@ def test_completion_without_code( ) ] - with fake_openai_env(): + with openai_config(user_config_manager): response = client.post( "/api/ai/completion", headers=HEADERS, @@ -111,6 +112,7 @@ def test_completion_with_code( ) -> None: filename = get_session_manager(client).filename assert filename + user_config_manager = get_user_config_manager(client) oaiclient = MagicMock() openai_mock.return_value = oaiclient @@ -121,7 +123,7 @@ def test_completion_with_code( ) ] - with fake_openai_env(): + with openai_config(user_config_manager): response = client.post( "/api/ai/completion", headers=HEADERS, @@ -140,11 +142,134 @@ def test_completion_with_code( "Help me create a dataframe\n\nCurrent code:\nimport pandas as pd" # noqa: E501 ) + @staticmethod + @with_session(SESSION_ID) + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + @patch("openai.OpenAI") + def test_completion_with_custom_model( + client: TestClient, openai_mock: Any + ) -> None: + filename = get_session_manager(client).filename + assert filename + user_config_manager = get_user_config_manager(client) + + oaiclient = MagicMock() + openai_mock.return_value = oaiclient + + oaiclient.chat.completions.create.return_value = [ + FakeChoices( + choices=[Choice(delta=Delta(content="import pandas as pd"))] + ) + ] + + with openai_config_custom_model(user_config_manager): + response = client.post( + "/api/ai/completion", + headers=HEADERS, + json={ + "prompt": "Help me create a dataframe", + "code": "import pandas as pd", + "include_other_code": "", + }, + ) + assert response.status_code == 200, response.text + # Assert the model it was called with + model = oaiclient.chat.completions.create.call_args.kwargs["model"] + assert model == "gpt-marimo" + + @staticmethod + @with_session(SESSION_ID) + @pytest.mark.skipif( + not HAS_DEPS, reason="optional dependencies not installed" + ) + @patch("openai.OpenAI") + def test_completion_with_custom_base_url( + client: TestClient, openai_mock: Any + ) -> None: + filename = get_session_manager(client).filename + assert filename + user_config_manager = get_user_config_manager(client) + + oaiclient = MagicMock() + openai_mock.return_value = oaiclient + + oaiclient.chat.completions.create.return_value = [ + FakeChoices( + choices=[Choice(delta=Delta(content="import pandas as pd"))] + ) + ] + + with openai_config_custom_base_url(user_config_manager): + response = client.post( + "/api/ai/completion", + headers=HEADERS, + json={ + "prompt": "Help me create a dataframe", + "code": "import pandas as pd", + "include_other_code": "", + }, + ) + assert response.status_code == 200, response.text + # Assert the base_url it was called with + base_url = openai_mock.call_args.kwargs["base_url"] + assert base_url == "https://my-openai-instance.com" + + +@contextmanager +def openai_config(config: UserConfigManager): + prev_config = config.get_config() + try: + config.save_config({"ai": {"open_ai": {"api_key": "fake-api"}}}) + yield + finally: + config.save_config(prev_config) + + +@contextmanager +def openai_config_custom_model(config: UserConfigManager): + prev_config = config.get_config() + try: + config.save_config( + { + "ai": { + "open_ai": { + "api_key": "fake-api", + "model": "gpt-marimo", + } + } + } + ) + yield + finally: + config.save_config(prev_config) + + +@contextmanager +def openai_config_custom_base_url(config: UserConfigManager): + prev_config = config.get_config() + try: + config.save_config( + { + "ai": { + "open_ai": { + "api_key": "fake-api", + "base_url": "https://my-openai-instance.com", + } + } + } + ) + yield + finally: + config.save_config(prev_config) + @contextmanager -def fake_openai_env(): +def no_openai_config(config: UserConfigManager): + prev_config = config.get_config() try: - os.environ["OPENAI_API_KEY"] = "fake-key" + config.save_config({"ai": {"open_ai": {"api_key": ""}}}) yield finally: - del os.environ["OPENAI_API_KEY"] + config.save_config(prev_config) diff --git a/tests/_server/conftest.py b/tests/_server/conftest.py index dffd64ccd0..b6b1ed2609 100644 --- a/tests/_server/conftest.py +++ b/tests/_server/conftest.py @@ -47,3 +47,7 @@ def client() -> Iterator[TestClient]: def get_session_manager(client: TestClient) -> SessionManager: return client.app.state.session_manager # type: ignore + + +def get_user_config_manager(client: TestClient) -> UserConfigManager: + return client.app.state.config_manager # type: ignore From 1631a5e70bf9131fdc043c0cfc9c47f942998bf5 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Mon, 1 Apr 2024 15:09:39 -0600 Subject: [PATCH 2/4] fix --- frontend/src/stories/cell.stories.tsx | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/src/stories/cell.stories.tsx b/frontend/src/stories/cell.stories.tsx index 0c23c5a508..082a53165b 100644 --- a/frontend/src/stories/cell.stories.tsx +++ b/frontend/src/stories/cell.stories.tsx @@ -76,6 +76,7 @@ const props: CellProps = { package_management: { manager: "pip", }, + ai: {}, experimental: {}, }, }; From 3690c6a8ff02b91f7444015aeafcfa53092ca005 Mon Sep 17 00:00:00 2001 From: Akshay Agrawal Date: Mon, 1 Apr 2024 14:32:24 -0700 Subject: [PATCH 3/4] improve typings for MarimoConfig - switch to TypedDict, total=True - use Typing.NotRequired to denote keys that don't need to be included --- marimo/_config/config.py | 49 +++++++++++++++--------------- marimo/_server/api/endpoints/ai.py | 1 - pyproject.toml | 2 +- 3 files changed, 26 insertions(+), 26 deletions(-) diff --git a/marimo/_config/config.py b/marimo/_config/config.py index 1eef52a294..998a5574a4 100644 --- a/marimo/_config/config.py +++ b/marimo/_config/config.py @@ -1,14 +1,21 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from typing import Any, Dict, Literal, Optional, TypedDict, Union, cast +import sys + +if sys.version_info < (3, 11): + from typing_extensions import NotRequired +else: + from typing import NotRequired + +from typing import Any, Dict, Literal, TypedDict, Union, cast from marimo._output.rich_help import mddoc from marimo._utils.deep_merge import deep_merge @mddoc -class CompletionConfig(TypedDict, total=False): +class CompletionConfig(TypedDict): """Configuration for code completion. A dict with key/value pairs configuring code completion in the marimo @@ -22,12 +29,11 @@ class CompletionConfig(TypedDict, total=False): """ activate_on_typing: bool - copilot: bool @mddoc -class SaveConfig(TypedDict, total=False): +class SaveConfig(TypedDict): """Configuration for saving. **Keys.** @@ -55,7 +61,7 @@ class KeymapConfig(TypedDict, total=False): @mddoc -class RuntimeConfig(TypedDict, total=False): +class RuntimeConfig(TypedDict): """Configuration for runtime. **Keys.** @@ -70,7 +76,7 @@ class RuntimeConfig(TypedDict, total=False): @mddoc -class DisplayConfig(TypedDict, total=False): +class DisplayConfig(TypedDict): """Configuration for display. **Keys.** @@ -86,7 +92,7 @@ class DisplayConfig(TypedDict, total=False): @mddoc -class FormattingConfig(TypedDict, total=False): +class FormattingConfig(TypedDict): """Configuration for code formatting. **Keys.** @@ -97,7 +103,7 @@ class FormattingConfig(TypedDict, total=False): line_length: int -class ServerConfig(TypedDict, total=False): +class ServerConfig(TypedDict): """Configuration for the server. **Keys.** @@ -109,7 +115,7 @@ class ServerConfig(TypedDict, total=False): browser: Union[Literal["default"], str] -class PackageManagementConfig(TypedDict, total=False): +class PackageManagementConfig(TypedDict): """Configuration options for package management. **Keys.** @@ -120,7 +126,7 @@ class PackageManagementConfig(TypedDict, total=False): manager: Literal["pip", "rye", "uv", "poetry", "pixi"] -class AiConfig(TypedDict, total=False): +class AiConfig(TypedDict): """Configuration options for AI. **Keys.** @@ -128,10 +134,10 @@ class AiConfig(TypedDict, total=False): - `open_ai`: the OpenAI config """ - open_ai: Optional[OpenAiConfig] + open_ai: OpenAiConfig -class OpenAiConfig(TypedDict, total=False): +class OpenAiConfig(TypedDict): """Configuration options for OpenAI or OpenAI-compatible services. **Keys.** @@ -141,19 +147,14 @@ class OpenAiConfig(TypedDict, total=False): - `base_url`: the base URL for the API """ - api_key: Optional[str] - model: Optional[str] - base_url: Optional[str] + api_key: str + model: NotRequired[str] + base_url: NotRequired[str] @mddoc -class MarimoConfig(TypedDict, total=False): - """Configuration for the marimo editor. - - A marimo configuration is a Python `dict`. Configurations - can be partially specified, with just a subset of possible keys. - Partial configs will be augmented with default options. - """ +class MarimoConfig(TypedDict): + """Configuration for the marimo editor""" completion: CompletionConfig display: DisplayConfig @@ -163,8 +164,8 @@ class MarimoConfig(TypedDict, total=False): save: SaveConfig server: ServerConfig package_management: PackageManagementConfig - ai: AiConfig - experimental: Dict[str, Any] + ai: NotRequired[AiConfig] + experimental: NotRequired[Dict[str, Any]] DEFAULT_CONFIG: MarimoConfig = { diff --git a/marimo/_server/api/endpoints/ai.py b/marimo/_server/api/endpoints/ai.py index cb8581484a..ea2a27cc0b 100644 --- a/marimo/_server/api/endpoints/ai.py +++ b/marimo/_server/api/endpoints/ai.py @@ -1,7 +1,6 @@ # Copyright 2024 Marimo. All rights reserved. from __future__ import annotations -from ast import Dict from typing import Generator, Optional from starlette.authentication import requires diff --git a/pyproject.toml b/pyproject.toml index 34b29d4457..e91c19b183 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ # websockets for use with starlette "websockets >= 10.0.0,<13.0.0", # python <=3.10 compatibility - "typing_extensions>=4.4.0; python_version < \"3.10\"", + "typing_extensions>=4.4.0; python_version < \"3.11\"", # for rst parsing "docutils>=0.17.0", # for cell formatting; if user version is not compatible, no-op From ffbd19b7b77fb9990e774e45801f91ad06253318 Mon Sep 17 00:00:00 2001 From: Myles Scolnick Date: Mon, 1 Apr 2024 15:51:20 -0600 Subject: [PATCH 4/4] fix test --- tests/_config/test_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/_config/test_config.py b/tests/_config/test_config.py index 2b0edf8570..33d6b7bb37 100644 --- a/tests/_config/test_config.py +++ b/tests/_config/test_config.py @@ -99,7 +99,7 @@ def test_remove_secret_placeholders() -> None: assert config["ai"]["open_ai"]["api_key"] == "********" new_config = remove_secret_placeholders(config) - assert new_config["ai"]["open_ai"]["api_key"] is None + assert "api_key" not in new_config["ai"]["open_ai"] # Ensure the original config is not modified assert config["ai"]["open_ai"]["api_key"] == "********"