Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: model/base-url settings for AI completion, bring out of experimental #1049

Merged
merged 4 commits into from
Apr 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 13 additions & 8 deletions docs/guides/ai_completion.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,13 @@ This feature is currently experimental and is not enabled by default. To enable
1. You need add the following to your `~/.marimo.toml`:

```toml
[experimental]
ai = true
```

2. Add your OpenAI API key to your environment:

```bash
export OPENAI_API_KEY=your-api-key
[ai.open_ai]
# Get your API key from https://platform.openai.com/account/api-keys
api_key = "sk-..."
# Choose a model, we recommend "gpt-3.5-turbo"
model = "gpt-3.5-turbo"
# Change the base_url if you are using a different OpenAI-compatible API
base_url = "https://api.openai.com"
```

Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell. This will open an input to modify the cell using AI.
Expand All @@ -44,3 +43,9 @@ Once enabled, you can use AI completion by pressing `Ctrl/Cmd-Shift-e` in a cell
<figcaption>Use AI to modify a cell by pressing `Ctrl/Cmd-Shift-e`.</figcaption>
</figure>
</div>

### Using other AI providers

marimo supports OpenAI's GPT-3.5 API by default. If your provider is compatible with OpenAI's API, you can use it by changing the `base_url` in the configuration.

For other providers not compatible with OpenAI's API, please submit a [feature request](https://github.com/marimo-team/marimo/issues) or "thumbs up" an existing one.
59 changes: 59 additions & 0 deletions frontend/src/components/app-config/user-config-form.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ import { SettingTitle, SettingDescription, SettingSubtitle } from "./common";
import { THEMES } from "@/theme/useTheme";
import { isPyodide } from "@/core/pyodide/utils";
import { PackageManagerNames } from "../../core/config/config-schema";
import { Kbd } from "../ui/kbd";

export const UserConfigForm: React.FC = () => {
const [config, setConfig] = useUserConfig();
Expand Down Expand Up @@ -335,6 +336,64 @@ export const UserConfigForm: React.FC = () => {
)}
/>
</div>
<div className="flex flex-col gap-3">
<SettingSubtitle>AI Assist</SettingSubtitle>
<p className="text-sm text-muted-secondary">
You will need to store an API key in your{" "}
<Kbd className="inline">~/.marimo.toml</Kbd> file. See the{" "}
<a
className="text-link hover:underline"
href="https://docs.marimo.io/guides/ai_completion.html"
target="_blank"
rel="noreferrer"
>
documentation
</a>{" "}
for more information.
</p>
<FormField
control={form.control}
disabled={isWasm}
name="ai.open_ai.base_url"
render={({ field }) => (
<FormItem className="mb-2">
<FormLabel>Base URL</FormLabel>
<FormControl>
<Input
data-testid="code-editor-font-size-input"
className="m-0 inline-flex"
{...field}
value={field.value}
placeholder="https://api.openai.com"
onChange={(e) => field.onChange(e.target.value)}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
<FormField
control={form.control}
disabled={isWasm}
name="ai.open_ai.model"
render={({ field: { value, onChange, ...field } }) => (
<FormItem className="mb-2">
<FormLabel>Model</FormLabel>
<FormControl>
<Input
data-testid="code-editor-font-size-input"
className="m-0 inline-flex"
{...field}
defaultValue={value}
placeholder="gpt-3.5-turbo"
onBlur={(e) => onChange(e.target.value)}
/>
</FormControl>
<FormMessage />
</FormItem>
)}
/>
</div>
<div className="flex flex-col gap-3">
<SettingSubtitle>GitHub Copilot</SettingSubtitle>
<FormField
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import { saveCellConfig } from "@/core/network/requests";
import { EditorView } from "@codemirror/view";
import { useRunCell } from "../cell/useRunCells";
import { NameCellInput } from "./name-cell-input";
import { getFeatureFlag } from "@/core/config/feature-flag";
import { useSetAtom } from "jotai";
import { aiCompletionCellAtom } from "@/core/ai/state";
import { useImperativeModal } from "@/components/modal/ImperativeModal";
Expand All @@ -45,6 +44,7 @@ import {
} from "@/components/ui/dialog";
import { Label } from "@/components/ui/label";
import { MarkdownIcon, PythonIcon } from "../cell/code/icons";
import { useUserConfig } from "@/core/config/config";

export interface CellActionButtonProps
extends Pick<CellData, "name" | "config"> {
Expand Down Expand Up @@ -72,6 +72,7 @@ export function useCellActionButtons({ cell }: Props) {
const runCell = useRunCell(cell?.cellId);
const { openModal } = useImperativeModal();
const setAiCompletionCell = useSetAtom(aiCompletionCellAtom);
const [userConfig] = useUserConfig();
if (!cell) {
return [];
}
Expand Down Expand Up @@ -159,7 +160,7 @@ export function useCellActionButtons({ cell }: Props) {
{
icon: <SparklesIcon size={13} strokeWidth={1.5} />,
label: "AI completion",
hidden: !getFeatureFlag("ai"),
hidden: !userConfig.ai.open_ai?.api_key,
handle: () => {
setAiCompletionCell((current) =>
current === cellId ? null : cellId,
Expand Down
2 changes: 2 additions & 0 deletions frontend/src/core/config/__tests__/config-schema.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ test("default UserConfig - empty", () => {
const defaultConfig = UserConfigSchema.parse({});
expect(defaultConfig).toMatchInlineSnapshot(`
{
"ai": {},
"completion": {
"activate_on_typing": true,
"copilot": false,
Expand Down Expand Up @@ -58,6 +59,7 @@ test("default UserConfig - one level", () => {
});
expect(defaultConfig).toMatchInlineSnapshot(`
{
"ai": {},
"completion": {
"activate_on_typing": true,
"copilot": false,
Expand Down
13 changes: 11 additions & 2 deletions frontend/src/core/config/config-schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,19 @@ export const UserConfigSchema = z
manager: z.enum(PackageManagerNames).default("pip"),
})
.default({ manager: "pip" }),
experimental: z
ai: z
.object({
ai: z.boolean().optional(),
open_ai: z
.object({
api_key: z.string().optional(),
base_url: z.string().optional(),
model: z.string().optional(),
})
.optional(),
})
.default({}),
experimental: z
.object({})
// Pass through so that we don't remove any extra keys that the user has added.
.passthrough()
.default({}),
Expand Down
7 changes: 2 additions & 5 deletions frontend/src/core/config/feature-flag.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,10 @@ import { getUserConfig } from "./config";

// eslint-disable-next-line @typescript-eslint/no-empty-interface
export interface ExperimentalFeatures {
// None yet
ai: boolean;
// Add new feature flags here
}

const defaultValues: ExperimentalFeatures = {
ai: process.env.NODE_ENV === "development",
};
const defaultValues: ExperimentalFeatures = {};

export function getFeatureFlag<T extends keyof ExperimentalFeatures>(
feature: T,
Expand Down
1 change: 1 addition & 0 deletions frontend/src/stories/cell.stories.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ const props: CellProps = {
package_management: {
manager: "pip",
},
ai: {},
experimental: {},
},
};
Expand Down
Loading
Loading