diff --git a/docs/guides/ai_completion.md b/docs/guides/ai_completion.md
index befc23f33b..4c222f959f 100644
--- a/docs/guides/ai_completion.md
+++ b/docs/guides/ai_completion.md
@@ -26,14 +26,13 @@ This feature is currently experimental and is not enabled by default. To enable
1. You need add the following to your `~/.marimo.toml`:
```toml
-[experimental]
-ai = true
-```
-
-2. Add your OpenAI API key to your environment:
-
-```bash
-export OPENAI_API_KEY=your-api-key
+[ai.open_ai]
+# Get your API key from https://platform.openai.com/account/api-keys
+api_key = "sk-..."
+# Choose a model, we recommend "gpt-3.5-turbo"
+model = "gpt-3.5-turbo"
+# Change the base_url if you are using a different OpenAI-compatible API
+base_url = "https://api.openai.com"
```
Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell. This will open an input to modify the cell using AI.
@@ -44,3 +43,9 @@ Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell
Use AI to modify a cell by pressing `Ctrl/Cmd-Shift-e`.
+
+### Using other AI providers
+
+marimo supports OpenAI's GPT-3.5 API by default. If your provider is compatible with OpenAI's API, you can use it by changing the `base_url` in the configuration.
+
+For other providers not compatible with OpenAI's API, please submit a [feature request](https://github.com/marimo-team/marimo/issues) or "thumbs up" an existing one.
diff --git a/frontend/src/components/app-config/user-config-form.tsx b/frontend/src/components/app-config/user-config-form.tsx
index 0be0f6e412..2f68f73449 100644
--- a/frontend/src/components/app-config/user-config-form.tsx
+++ b/frontend/src/components/app-config/user-config-form.tsx
@@ -23,6 +23,7 @@ import { SettingTitle, SettingDescription, SettingSubtitle } from "./common";
import { THEMES } from "@/theme/useTheme";
import { isPyodide } from "@/core/pyodide/utils";
import { PackageManagerNames } from "../../core/config/config-schema";
+import { Kbd } from "../ui/kbd";
export const UserConfigForm: React.FC = () => {
const [config, setConfig] = useUserConfig();
@@ -335,6 +336,64 @@ export const UserConfigForm: React.FC = () => {
)}
/>
+
GitHub Copilot
{
@@ -72,6 +72,7 @@ export function useCellActionButtons({ cell }: Props) {
const runCell = useRunCell(cell?.cellId);
const { openModal } = useImperativeModal();
const setAiCompletionCell = useSetAtom(aiCompletionCellAtom);
+ const [userConfig] = useUserConfig();
if (!cell) {
return [];
}
@@ -159,7 +160,7 @@ export function useCellActionButtons({ cell }: Props) {
{
icon: ,
label: "AI completion",
- hidden: !getFeatureFlag("ai"),
+ hidden: !userConfig.ai.open_ai?.api_key,
handle: () => {
setAiCompletionCell((current) =>
current === cellId ? null : cellId,
diff --git a/frontend/src/core/config/__tests__/config-schema.test.ts b/frontend/src/core/config/__tests__/config-schema.test.ts
index 606e5f47c6..a93320fb91 100644
--- a/frontend/src/core/config/__tests__/config-schema.test.ts
+++ b/frontend/src/core/config/__tests__/config-schema.test.ts
@@ -15,6 +15,7 @@ test("default UserConfig - empty", () => {
const defaultConfig = UserConfigSchema.parse({});
expect(defaultConfig).toMatchInlineSnapshot(`
{
+ "ai": {},
"completion": {
"activate_on_typing": true,
"copilot": false,
@@ -58,6 +59,7 @@ test("default UserConfig - one level", () => {
});
expect(defaultConfig).toMatchInlineSnapshot(`
{
+ "ai": {},
"completion": {
"activate_on_typing": true,
"copilot": false,
diff --git a/frontend/src/core/config/config-schema.ts b/frontend/src/core/config/config-schema.ts
index 6a0d3d5f62..db0c136fd7 100644
--- a/frontend/src/core/config/config-schema.ts
+++ b/frontend/src/core/config/config-schema.ts
@@ -67,10 +67,19 @@ export const UserConfigSchema = z
manager: z.enum(PackageManagerNames).default("pip"),
})
.default({ manager: "pip" }),
- experimental: z
+ ai: z
.object({
- ai: z.boolean().optional(),
+ open_ai: z
+ .object({
+ api_key: z.string().optional(),
+ base_url: z.string().optional(),
+ model: z.string().optional(),
+ })
+ .optional(),
})
+ .default({}),
+ experimental: z
+ .object({})
// Pass through so that we don't remove any extra keys that the user has added.
.passthrough()
.default({}),
diff --git a/frontend/src/core/config/feature-flag.ts b/frontend/src/core/config/feature-flag.ts
index a787d3f9da..d180a46788 100644
--- a/frontend/src/core/config/feature-flag.ts
+++ b/frontend/src/core/config/feature-flag.ts
@@ -7,13 +7,10 @@ import { getUserConfig } from "./config";
// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface ExperimentalFeatures {
- // None yet
- ai: boolean;
+ // Add new feature flags here
}
-const defaultValues: ExperimentalFeatures = {
- ai: process.env.NODE_ENV === "development",
-};
+const defaultValues: ExperimentalFeatures = {};
export function getFeatureFlag(
feature: T,
diff --git a/frontend/src/stories/cell.stories.tsx b/frontend/src/stories/cell.stories.tsx
index 0c23c5a508..082a53165b 100644
--- a/frontend/src/stories/cell.stories.tsx
+++ b/frontend/src/stories/cell.stories.tsx
@@ -76,6 +76,7 @@ const props: CellProps = {
package_management: {
manager: "pip",
},
+ ai: {},
experimental: {},
},
};
diff --git a/marimo/_config/config.py b/marimo/_config/config.py
index de38f63aa5..998a5574a4 100644
--- a/marimo/_config/config.py
+++ b/marimo/_config/config.py
@@ -1,6 +1,13 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations
+import sys
+
+if sys.version_info < (3, 11):
+ from typing_extensions import NotRequired
+else:
+ from typing import NotRequired
+
from typing import Any, Dict, Literal, TypedDict, Union, cast
from marimo._output.rich_help import mddoc
@@ -8,7 +15,7 @@
@mddoc
-class CompletionConfig(TypedDict, total=False):
+class CompletionConfig(TypedDict):
"""Configuration for code completion.
A dict with key/value pairs configuring code completion in the marimo
@@ -22,12 +29,11 @@ class CompletionConfig(TypedDict, total=False):
"""
activate_on_typing: bool
-
copilot: bool
@mddoc
-class SaveConfig(TypedDict, total=False):
+class SaveConfig(TypedDict):
"""Configuration for saving.
**Keys.**
@@ -55,7 +61,7 @@ class KeymapConfig(TypedDict, total=False):
@mddoc
-class RuntimeConfig(TypedDict, total=False):
+class RuntimeConfig(TypedDict):
"""Configuration for runtime.
**Keys.**
@@ -70,7 +76,7 @@ class RuntimeConfig(TypedDict, total=False):
@mddoc
-class DisplayConfig(TypedDict, total=False):
+class DisplayConfig(TypedDict):
"""Configuration for display.
**Keys.**
@@ -86,7 +92,7 @@ class DisplayConfig(TypedDict, total=False):
@mddoc
-class FormattingConfig(TypedDict, total=False):
+class FormattingConfig(TypedDict):
"""Configuration for code formatting.
**Keys.**
@@ -97,7 +103,7 @@ class FormattingConfig(TypedDict, total=False):
line_length: int
-class ServerConfig(TypedDict, total=False):
+class ServerConfig(TypedDict):
"""Configuration for the server.
**Keys.**
@@ -109,7 +115,7 @@ class ServerConfig(TypedDict, total=False):
browser: Union[Literal["default"], str]
-class PackageManagementConfig(TypedDict, total=False):
+class PackageManagementConfig(TypedDict):
"""Configuration options for package management.
**Keys.**
@@ -120,26 +126,36 @@ class PackageManagementConfig(TypedDict, total=False):
manager: Literal["pip", "rye", "uv", "poetry", "pixi"]
-@mddoc
-class MarimoConfig(TypedDict, total=False):
- """Configuration for the marimo editor.
+class AiConfig(TypedDict):
+ """Configuration options for AI.
+
+ **Keys.**
+
+ - `open_ai`: the OpenAI config
+ """
+
+ open_ai: OpenAiConfig
- A marimo configuration is a Python `dict`. Configurations
- can be partially specified, with just a subset of possible keys.
- Partial configs will be augmented with default options.
- Use with `configure` to configure the editor. See `configure`
- documentation for details on how to register the configuration.
+class OpenAiConfig(TypedDict):
+ """Configuration options for OpenAI or OpenAI-compatible services.
- **Example.**
+ **Keys.**
- ```python3
- config: mo.config.MarimoConfig = {
- "completion": {"activate_on_typing": True},
- }
- ```
+ - `api_key`: the OpenAI API key
+ - `model`: the model to use
+ - `base_url`: the base URL for the API
"""
+ api_key: str
+ model: NotRequired[str]
+ base_url: NotRequired[str]
+
+
+@mddoc
+class MarimoConfig(TypedDict):
+ """Configuration for the marimo editor"""
+
completion: CompletionConfig
display: DisplayConfig
formatting: FormattingConfig
@@ -148,7 +164,8 @@ class MarimoConfig(TypedDict, total=False):
save: SaveConfig
server: ServerConfig
package_management: PackageManagementConfig
- experimental: Dict[str, Any]
+ ai: NotRequired[AiConfig]
+ experimental: NotRequired[Dict[str, Any]]
DEFAULT_CONFIG: MarimoConfig = {
@@ -171,11 +188,67 @@ class MarimoConfig(TypedDict, total=False):
}
-def merge_config(config: MarimoConfig) -> MarimoConfig:
+def merge_default_config(config: MarimoConfig) -> MarimoConfig:
"""Merge a user configuration with the default configuration."""
+ return merge_config(DEFAULT_CONFIG, config)
+
+
+def merge_config(
+ config: MarimoConfig, new_config: MarimoConfig
+) -> MarimoConfig:
+ """Merge a user configuration with a new configuration."""
return cast(
MarimoConfig,
deep_merge(
- cast(Dict[Any, Any], DEFAULT_CONFIG), cast(Dict[Any, Any], config)
+ cast(Dict[Any, Any], config), cast(Dict[Any, Any], new_config)
),
)
+
+
+def _deep_copy(obj: Any) -> Any:
+ if isinstance(obj, dict):
+ return {k: _deep_copy(v) for k, v in obj.items()} # type: ignore
+ if isinstance(obj, list):
+ return [_deep_copy(v) for v in obj] # type: ignore
+ return obj
+
+
+SECRET_PLACEHOLDER = "********"
+
+
+def mask_secrets(config: MarimoConfig) -> MarimoConfig:
+ def deep_remove_from_path(path: list[str], obj: Dict[str, Any]) -> None:
+ key = path[0]
+ if key not in obj:
+ return
+ if len(path) == 1:
+ if obj[key]:
+ obj[key] = SECRET_PLACEHOLDER
+ else:
+ deep_remove_from_path(path[1:], cast(Dict[str, Any], obj[key]))
+
+ secrets = [["ai", "open_ai", "api_key"]]
+
+ new_config = _deep_copy(config)
+ for secret in secrets:
+ deep_remove_from_path(secret, cast(Dict[str, Any], new_config))
+
+ return new_config # type: ignore
+
+
+def remove_secret_placeholders(config: MarimoConfig) -> MarimoConfig:
+ def deep_remove(obj: Any) -> Any:
+ if isinstance(obj, dict):
+ # Filter all keys with value SECRET_PLACEHOLDER
+ return {
+ k: deep_remove(v)
+ for k, v in obj.items()
+ if v != SECRET_PLACEHOLDER
+ } # type: ignore
+ if isinstance(obj, list):
+ return [deep_remove(v) for v in obj] # type: ignore
+ if obj == SECRET_PLACEHOLDER:
+ return None
+ return obj
+
+ return deep_remove(_deep_copy(config)) # type: ignore
diff --git a/marimo/_config/manager.py b/marimo/_config/manager.py
index 6936caa8e4..1b8dae3432 100644
--- a/marimo/_config/manager.py
+++ b/marimo/_config/manager.py
@@ -4,7 +4,13 @@
import tomlkit
from marimo import _loggers
-from marimo._config.config import MarimoConfig, merge_config
+from marimo._config.config import (
+ MarimoConfig,
+ mask_secrets,
+ merge_config,
+ merge_default_config,
+ remove_secret_placeholders,
+)
from marimo._config.utils import CONFIG_FILENAME, get_config_path, load_config
LOGGER = _loggers.marimo_logger()
@@ -17,13 +23,19 @@ def __init__(self) -> None:
def save_config(self, config: MarimoConfig) -> MarimoConfig:
config_path = self._get_config_path()
LOGGER.debug("Saving user configuration to %s", config_path)
+ # Remove the secret placeholders from the incoming config
+ config = remove_secret_placeholders(config)
+ # Merge the current config with the new config
+ merged = merge_config(self.config, config)
with open(config_path, "w", encoding="utf-8") as f:
- tomlkit.dump(config, f)
+ tomlkit.dump(merged, f)
- self.config = merge_config(config)
+ self.config = merge_default_config(merged)
return self.config
- def get_config(self) -> MarimoConfig:
+ def get_config(self, hide_secrets: bool = True) -> MarimoConfig:
+ if hide_secrets:
+ return mask_secrets(self.config)
return self.config
def _get_config_path(self) -> str:
diff --git a/marimo/_config/utils.py b/marimo/_config/utils.py
index 173283f177..991e3d6adf 100644
--- a/marimo/_config/utils.py
+++ b/marimo/_config/utils.py
@@ -8,7 +8,7 @@
from marimo._config.config import (
DEFAULT_CONFIG,
MarimoConfig,
- merge_config,
+ merge_default_config,
)
LOGGER = _loggers.marimo_logger()
@@ -103,7 +103,7 @@ def load_config() -> MarimoConfig:
LOGGER.error("Failed to read user config at %s", path)
LOGGER.error(str(e))
return DEFAULT_CONFIG
- return merge_config(cast(MarimoConfig, user_config))
+ return merge_default_config(cast(MarimoConfig, user_config))
else:
LOGGER.debug("No config found; loading default settings.")
return DEFAULT_CONFIG
diff --git a/marimo/_server/api/endpoints/ai.py b/marimo/_server/api/endpoints/ai.py
index 894a4c1c4b..ea2a27cc0b 100644
--- a/marimo/_server/api/endpoints/ai.py
+++ b/marimo/_server/api/endpoints/ai.py
@@ -1,8 +1,7 @@
# Copyright 2024 Marimo. All rights reserved.
from __future__ import annotations
-import os
-from typing import Generator
+from typing import Generator, Optional
from starlette.authentication import requires
from starlette.exceptions import HTTPException
@@ -37,16 +36,41 @@ async def ai_completion(
app_state = AppState(request)
app_state.require_current_session()
+ config = app_state.config_manager.get_config(hide_secrets=False)
body = await parse_request(request, cls=AiCompletionRequest)
- key = os.environ.get("OPENAI_API_KEY")
+ if "ai" not in config:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST, detail="OpenAI not configured"
+ )
+ if "open_ai" not in config["ai"]:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST, detail="OpenAI not configured"
+ )
+ if "api_key" not in config["ai"]["open_ai"]:
+ raise HTTPException(
+ status_code=HTTPStatus.BAD_REQUEST,
+ detail="OpenAI API key not configured",
+ )
+
+ key: str = config["ai"]["open_ai"]["api_key"]
if not key:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
- detail="OpenAI API key not found in environment",
+ detail="OpenAI API key not configured",
)
+ base_url: Optional[str] = (
+ config.get("ai", {}).get("open_ai", {}).get("base_url", None)
+ )
+ if not base_url:
+ base_url = None
+ model: str = (
+ config.get("ai", {}).get("open_ai", {}).get("model", "gpt-3.5-turbo")
+ )
+ if not model:
+ model = "gpt-3.5-turbo"
- client = OpenAI(api_key=key)
+ client = OpenAI(api_key=key, base_url=base_url)
system_prompt = (
"You are a helpful assistant that can answer questions "
@@ -64,7 +88,7 @@ async def ai_completion(
prompt = f"{prompt}\n\nCurrent code:\n{body.code}"
response = client.chat.completions.create(
- model="gpt-3.5-turbo",
+ model=model,
messages=[
{
"role": "system",
diff --git a/pyproject.toml b/pyproject.toml
index 34b29d4457..e91c19b183 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -34,7 +34,7 @@ dependencies = [
# websockets for use with starlette
"websockets >= 10.0.0,<13.0.0",
# python <=3.10 compatibility
- "typing_extensions>=4.4.0; python_version < \"3.10\"",
+ "typing_extensions>=4.4.0; python_version < \"3.11\"",
# for rst parsing
"docutils>=0.17.0",
# for cell formatting; if user version is not compatible, no-op
diff --git a/tests/_config/test_config.py b/tests/_config/test_config.py
index 7fafefa1f2..33d6b7bb37 100644
--- a/tests/_config/test_config.py
+++ b/tests/_config/test_config.py
@@ -1,9 +1,16 @@
# Copyright 2024 Marimo. All rights reserved.
-from marimo._config.config import DEFAULT_CONFIG, MarimoConfig, merge_config
+from marimo._config.config import (
+ DEFAULT_CONFIG,
+ MarimoConfig,
+ mask_secrets,
+ merge_config,
+ merge_default_config,
+ remove_secret_placeholders,
+)
def assert_config(override: MarimoConfig) -> None:
- user_config = merge_config(override)
+ user_config = merge_default_config(override)
assert user_config == {**DEFAULT_CONFIG, **override}
@@ -28,3 +35,71 @@ def test_configure_full() -> None:
def test_configure_unknown() -> None:
assert_config({"super cool future config key": {"secret": "value"}}) # type: ignore[typeddict-unknown-key] # noqa: E501
+
+
+def test_merge_config() -> None:
+ prev_config = merge_default_config(
+ MarimoConfig(
+ ai={
+ "open_ai": {
+ "api_key": "super_secret",
+ }
+ },
+ )
+ )
+ assert prev_config["ai"]["open_ai"]["api_key"] == "super_secret"
+
+ new_config = merge_config(
+ prev_config,
+ MarimoConfig(
+ ai={
+ "open_ai": {
+ "model": "davinci",
+ }
+ },
+ ),
+ )
+
+ assert new_config["ai"]["open_ai"]["api_key"] == "super_secret"
+ assert new_config["ai"]["open_ai"]["model"] == "davinci"
+
+
+def test_mask_secrets() -> None:
+ config = MarimoConfig(ai={"open_ai": {"api_key": "super_secret"}})
+ assert config["ai"]["open_ai"]["api_key"] == "super_secret"
+
+ new_config = mask_secrets(config)
+ assert new_config["ai"]["open_ai"]["api_key"] == "********"
+
+ # Ensure the original config is not modified
+ assert config["ai"]["open_ai"]["api_key"] == "super_secret"
+
+
+def test_mask_secrets_empty() -> None:
+ config = MarimoConfig(ai={"open_ai": {"model": "davinci"}})
+ assert config["ai"]["open_ai"]["model"] == "davinci"
+
+ new_config = mask_secrets(config)
+ assert new_config["ai"]["open_ai"]["model"] == "davinci"
+ # Not added until the key is present
+ assert "api_key" not in new_config["ai"]["open_ai"]
+
+ # Ensure the original config is not modified
+ assert config["ai"]["open_ai"]["model"] == "davinci"
+
+ # Not added when key is ""
+ config["ai"]["open_ai"]["api_key"] = ""
+ new_config = mask_secrets(config)
+ assert new_config["ai"]["open_ai"]["api_key"] == ""
+ assert config["ai"]["open_ai"]["api_key"] == ""
+
+
+def test_remove_secret_placeholders() -> None:
+ config = MarimoConfig(ai={"open_ai": {"api_key": "********"}})
+ assert config["ai"]["open_ai"]["api_key"] == "********"
+
+ new_config = remove_secret_placeholders(config)
+ assert "api_key" not in new_config["ai"]["open_ai"]
+
+ # Ensure the original config is not modified
+ assert config["ai"]["open_ai"]["api_key"] == "********"
diff --git a/tests/_config/test_manager.py b/tests/_config/test_manager.py
new file mode 100644
index 0000000000..2108be6b3c
--- /dev/null
+++ b/tests/_config/test_manager.py
@@ -0,0 +1,76 @@
+import unittest
+from typing import Any
+from unittest.mock import patch
+
+from marimo._config.config import merge_default_config
+from marimo._config.manager import MarimoConfig, UserConfigManager
+
+
+class TestUserConfigManager(unittest.TestCase):
+ @patch("tomlkit.dump")
+ @patch("marimo._config.manager.load_config")
+ def test_save_config(self, mock_load: Any, mock_dump: Any) -> None:
+ mock_config = merge_default_config(MarimoConfig())
+ mock_load.return_value = mock_config
+ manager = UserConfigManager()
+
+ result = manager.save_config(mock_config)
+
+ mock_load.assert_called_once()
+ assert result == manager.config
+
+ assert mock_dump.mock_calls[0][1][0] == result
+
+ @patch("tomlkit.dump")
+ @patch("marimo._config.manager.load_config")
+ def test_can_save_secrets(self, mock_load: Any, mock_dump: Any) -> None:
+ mock_config = merge_default_config(MarimoConfig())
+ mock_load.return_value = mock_config
+ manager = UserConfigManager()
+
+ manager.save_config(
+ merge_default_config(
+ MarimoConfig(ai={"open_ai": {"api_key": "super_secret"}})
+ )
+ )
+
+ assert (
+ mock_dump.mock_calls[0][1][0]["ai"]["open_ai"]["api_key"]
+ == "super_secret"
+ )
+
+ # Do not overwrite secrets
+ manager.save_config(
+ merge_default_config(
+ MarimoConfig(ai={"open_ai": {"api_key": "********"}})
+ )
+ )
+ assert (
+ mock_dump.mock_calls[1][1][0]["ai"]["open_ai"]["api_key"]
+ == "super_secret"
+ )
+
+ @patch("marimo._config.manager.load_config")
+ def test_can_read_secrets(self, mock_load: Any) -> None:
+ mock_config = merge_default_config(
+ MarimoConfig(ai={"open_ai": {"api_key": "super_secret"}})
+ )
+ mock_load.return_value = mock_config
+ manager = UserConfigManager()
+
+ assert manager.get_config()["ai"]["open_ai"]["api_key"] == "********"
+ assert (
+ manager.get_config(hide_secrets=False)["ai"]["open_ai"]["api_key"]
+ == "super_secret"
+ )
+
+ @patch("marimo._config.manager.load_config")
+ def test_get_config(self, mock_load: Any) -> None:
+ mock_config = merge_default_config(MarimoConfig())
+ mock_load.return_value = mock_config
+ manager = UserConfigManager()
+
+ result = manager.get_config()
+
+ mock_load.assert_called_once()
+ assert result == manager.config
diff --git a/tests/_server/api/endpoints/test_ai.py b/tests/_server/api/endpoints/test_ai.py
index 145b087dec..fd4b3acd8b 100644
--- a/tests/_server/api/endpoints/test_ai.py
+++ b/tests/_server/api/endpoints/test_ai.py
@@ -1,5 +1,4 @@
# Copyright 2024 Marimo. All rights reserved.
-import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List
@@ -8,8 +7,9 @@
import pytest
from starlette.testclient import TestClient
+from marimo._config.manager import UserConfigManager
from marimo._dependencies.dependencies import DependencyManager
-from tests._server.conftest import get_session_manager
+from tests._server.conftest import get_session_manager, get_user_config_manager
from tests._server.mocks import with_session
SESSION_ID = "session-123"
@@ -47,20 +47,20 @@ def test_completion_without_token(
del openai_mock
filename = get_session_manager(client).filename
assert filename
+ user_config_manager = get_user_config_manager(client)
- response = client.post(
- "/api/ai/completion",
- headers=HEADERS,
- json={
- "prompt": "Help me create a dataframe",
- "include_other_code": "",
- "code": "",
- },
- )
+ with no_openai_config(user_config_manager):
+ response = client.post(
+ "/api/ai/completion",
+ headers=HEADERS,
+ json={
+ "prompt": "Help me create a dataframe",
+ "include_other_code": "",
+ "code": "",
+ },
+ )
assert response.status_code == 400, response.text
- assert response.json() == {
- "detail": "OpenAI API key not found in environment"
- }
+ assert response.json() == {"detail": "OpenAI API key not configured"}
@staticmethod
@with_session(SESSION_ID)
@@ -73,6 +73,7 @@ def test_completion_without_code(
) -> None:
filename = get_session_manager(client).filename
assert filename
+ user_config_manager = get_user_config_manager(client)
oaiclient = MagicMock()
openai_mock.return_value = oaiclient
@@ -83,7 +84,7 @@ def test_completion_without_code(
)
]
- with fake_openai_env():
+ with openai_config(user_config_manager):
response = client.post(
"/api/ai/completion",
headers=HEADERS,
@@ -111,6 +112,7 @@ def test_completion_with_code(
) -> None:
filename = get_session_manager(client).filename
assert filename
+ user_config_manager = get_user_config_manager(client)
oaiclient = MagicMock()
openai_mock.return_value = oaiclient
@@ -121,7 +123,7 @@ def test_completion_with_code(
)
]
- with fake_openai_env():
+ with openai_config(user_config_manager):
response = client.post(
"/api/ai/completion",
headers=HEADERS,
@@ -140,11 +142,134 @@ def test_completion_with_code(
"Help me create a dataframe\n\nCurrent code:\nimport pandas as pd" # noqa: E501
)
+ @staticmethod
+ @with_session(SESSION_ID)
+ @pytest.mark.skipif(
+ not HAS_DEPS, reason="optional dependencies not installed"
+ )
+ @patch("openai.OpenAI")
+ def test_completion_with_custom_model(
+ client: TestClient, openai_mock: Any
+ ) -> None:
+ filename = get_session_manager(client).filename
+ assert filename
+ user_config_manager = get_user_config_manager(client)
+
+ oaiclient = MagicMock()
+ openai_mock.return_value = oaiclient
+
+ oaiclient.chat.completions.create.return_value = [
+ FakeChoices(
+ choices=[Choice(delta=Delta(content="import pandas as pd"))]
+ )
+ ]
+
+ with openai_config_custom_model(user_config_manager):
+ response = client.post(
+ "/api/ai/completion",
+ headers=HEADERS,
+ json={
+ "prompt": "Help me create a dataframe",
+ "code": "import pandas as pd",
+ "include_other_code": "",
+ },
+ )
+ assert response.status_code == 200, response.text
+ # Assert the model it was called with
+ model = oaiclient.chat.completions.create.call_args.kwargs["model"]
+ assert model == "gpt-marimo"
+
+ @staticmethod
+ @with_session(SESSION_ID)
+ @pytest.mark.skipif(
+ not HAS_DEPS, reason="optional dependencies not installed"
+ )
+ @patch("openai.OpenAI")
+ def test_completion_with_custom_base_url(
+ client: TestClient, openai_mock: Any
+ ) -> None:
+ filename = get_session_manager(client).filename
+ assert filename
+ user_config_manager = get_user_config_manager(client)
+
+ oaiclient = MagicMock()
+ openai_mock.return_value = oaiclient
+
+ oaiclient.chat.completions.create.return_value = [
+ FakeChoices(
+ choices=[Choice(delta=Delta(content="import pandas as pd"))]
+ )
+ ]
+
+ with openai_config_custom_base_url(user_config_manager):
+ response = client.post(
+ "/api/ai/completion",
+ headers=HEADERS,
+ json={
+ "prompt": "Help me create a dataframe",
+ "code": "import pandas as pd",
+ "include_other_code": "",
+ },
+ )
+ assert response.status_code == 200, response.text
+ # Assert the base_url it was called with
+ base_url = openai_mock.call_args.kwargs["base_url"]
+ assert base_url == "https://my-openai-instance.com"
+
+
+@contextmanager
+def openai_config(config: UserConfigManager):
+ prev_config = config.get_config()
+ try:
+ config.save_config({"ai": {"open_ai": {"api_key": "fake-api"}}})
+ yield
+ finally:
+ config.save_config(prev_config)
+
+
+@contextmanager
+def openai_config_custom_model(config: UserConfigManager):
+ prev_config = config.get_config()
+ try:
+ config.save_config(
+ {
+ "ai": {
+ "open_ai": {
+ "api_key": "fake-api",
+ "model": "gpt-marimo",
+ }
+ }
+ }
+ )
+ yield
+ finally:
+ config.save_config(prev_config)
+
+
+@contextmanager
+def openai_config_custom_base_url(config: UserConfigManager):
+ prev_config = config.get_config()
+ try:
+ config.save_config(
+ {
+ "ai": {
+ "open_ai": {
+ "api_key": "fake-api",
+ "base_url": "https://my-openai-instance.com",
+ }
+ }
+ }
+ )
+ yield
+ finally:
+ config.save_config(prev_config)
+
@contextmanager
-def fake_openai_env():
+def no_openai_config(config: UserConfigManager):
+ prev_config = config.get_config()
try:
- os.environ["OPENAI_API_KEY"] = "fake-key"
+ config.save_config({"ai": {"open_ai": {"api_key": ""}}})
yield
finally:
- del os.environ["OPENAI_API_KEY"]
+ config.save_config(prev_config)
diff --git a/tests/_server/conftest.py b/tests/_server/conftest.py
index dffd64ccd0..b6b1ed2609 100644
--- a/tests/_server/conftest.py
+++ b/tests/_server/conftest.py
@@ -47,3 +47,7 @@ def client() -> Iterator[TestClient]:
def get_session_manager(client: TestClient) -> SessionManager:
return client.app.state.session_manager # type: ignore
+
+
+def get_user_config_manager(client: TestClient) -> UserConfigManager:
+ return client.app.state.config_manager # type: ignore