From 7978c9e94f0652b8df6597cb414c913f4f9624e9 Mon Sep 17 00:00:00 2001 From: Mayne Date: Wed, 14 Jun 2023 08:12:53 +0800 Subject: [PATCH 1/7] feat(ai): openai function call --- app/[database]/ai-chat-message-prisma.tsx | 2 +- app/[database]/ai-chat.tsx | 108 ++++++++++++------ app/settings/ai/ai-form.tsx | 38 +++++-- components/sidebar/create-table.tsx | 15 ++- components/sidebar/helper.ts | 17 ++- components/ui/progress.tsx | 28 +++++ hooks/use-ai.ts | 77 +------------ hooks/use-auto-run-code.ts | 19 +++- hooks/use-sqlite.ts | 18 +++ hooks/use-table.ts | 22 +++- lib/ai/functions/index.ts | 24 ++++ lib/ai/functions/sql-query.ts | 15 +++ lib/ai/openai.ts | 128 ++++++++++++++++++++++ package.json | 4 +- pnpm-lock.yaml | 42 ++++++- 15 files changed, 422 insertions(+), 135 deletions(-) create mode 100644 components/ui/progress.tsx create mode 100644 lib/ai/functions/index.ts create mode 100644 lib/ai/functions/sql-query.ts create mode 100644 lib/ai/openai.ts diff --git a/app/[database]/ai-chat-message-prisma.tsx b/app/[database]/ai-chat-message-prisma.tsx index d150963f..08e3bfd9 100644 --- a/app/[database]/ai-chat-message-prisma.tsx +++ b/app/[database]/ai-chat-message-prisma.tsx @@ -68,7 +68,7 @@ export const AIMessage = ({ } return (
- + {message && }
) diff --git a/app/[database]/ai-chat.tsx b/app/[database]/ai-chat.tsx index 8f3c94af..b3c3ae19 100644 --- a/app/[database]/ai-chat.tsx +++ b/app/[database]/ai-chat.tsx @@ -7,8 +7,10 @@ import { useParams } from "next/navigation" import { useKeyPress, useSize } from "ahooks" import { Bot, Loader2, Paintbrush, User } from "lucide-react" +import { handleOpenAIFunctionCall } from "@/lib/ai/openai" import { useAI } from "@/hooks/use-ai" import { useAutoRunCode } from "@/hooks/use-auto-run-code" +import { useSqliteStore } from "@/hooks/use-sqlite" import { Button } from "@/components/ui/button" import { Textarea } from "@/components/ui/textarea" @@ -16,21 +18,20 @@ import { useConfigStore } from "../settings/store" import { AIMessage } from "./ai-chat-message-prisma" import { useTableChange } from "./hook" import { useDatabaseAppStore } from "./store" -import { useSqliteStore } from "@/hooks/use-sqlite" export const AIChat = () => { - const { currentTableSchema, setCurrentQuery } = useDatabaseAppStore() + const { currentTableSchema } = useDatabaseAppStore() const { askAI } = useAI() const { database, table } = useParams() const { aiConfig } = useConfigStore() const [input, setInput] = useState("") const [loading, setLoading] = useState(false) - const { autoRun: runCode, handleRunCode } = useAutoRunCode() + const { handleFunctionCall, handleRunCode } = useAutoRunCode() const divRef = useRef() const size = useSize(divRef) const [messages, setMessages] = useState< { - role: "user" | "assistant" + role: "user" | "assistant" | "function" content: string }[] >([]) @@ -59,20 +60,46 @@ export const AIChat = () => { allTables, databaseName: database, }) - const newMessages = [ - ..._messages, - { role: "assistant", content: response?.content! }, - ] - const thisMsgIndex = newMessages.length - 1 - setMessages(newMessages) - if (response?.content && aiConfig.autoRunScope) { - setTimeout(() => { - runCode(response.content, { - msgIndex: thisMsgIndex, - width: size?.width ?? 300, - }) - }, 1000) + + if (response?.finish_reason == "function_call") { + if (aiConfig.autoRunScope) { + const res = await handleOpenAIFunctionCall( + response.message!, + handleFunctionCall + ) + if (res) { + const { name, resp } = res + const newMessages = [ + ..._messages, + response.message, + { + role: "function", + name, + content: JSON.stringify(resp), + }, + ] + const newResponse = await askAI(newMessages, { + tableSchema: currentTableSchema, + allTables, + databaseName: database, + }) + const _newMessages = [ + ...newMessages, + { role: "assistant", content: newResponse?.message?.content }, + ] + console.log({ _newMessages }) + setMessages(_newMessages as any) + } + } + } else if (response?.message) { + const newMessages = [ + ..._messages, + { role: "assistant", content: response?.message?.content }, + ] + const thisMsgIndex = newMessages.length - 1 + setMessages(newMessages) } + setLoading(false) } @@ -97,28 +124,37 @@ export const AIChat = () => { first

)} - {messages.map((message, i) => ( -
- {message.role === "assistant" ? ( - <> - - - - ) : ( + {messages + .filter((m) => (m.role === "user" || m.role == "assistant") && m.content) + .map((message, i) => ( +
+ {message.role === "assistant" && ( + <> + + + + )} + {message.role === "user" && ( + <> + +

{message.content}

+ + )} + {/* {message.role === "function" && ( <> -

{message.content}

+

run function

- )} -
- ))} + )} */} +
+ ))}
{loading && }
diff --git a/app/settings/ai/ai-form.tsx b/app/settings/ai/ai-form.tsx index 3803c7cc..742b7a4a 100644 --- a/app/settings/ai/ai-form.tsx +++ b/app/settings/ai/ai-form.tsx @@ -28,7 +28,27 @@ const AIConfigFormSchema = z.object({ }) export type AIConfigFormValues = z.infer -export const AutoRunScopes = ["SQL.SELECT", "SQL.INSERT", "D3.CHART"] +export const AutoRunScopesWithDesc = [ + { + value: "SQL.SELECT", + description: "Selects data from a SQL table.", + }, + { + value: "UI.REFRESH", + description: "Refresh the UI after querying data.", + }, + { + value: "SQL.INSERT", + description: "Inserts data into a SQL table.", + }, + { + value: "D3.CHART", + description: "Creates a D3 chart.", + }, +] + +export const AutoRunScopes = AutoRunScopesWithDesc.map((item) => item.value) + // This can come from your database or API. const defaultValues: Partial = { // name: "Your name", @@ -89,32 +109,34 @@ export function AIConfigForm() { run.
- {AutoRunScopes.map((_item) => ( + {AutoRunScopesWithDesc.map(({ value, description }) => ( { return ( { return checked - ? field.onChange([...field.value, _item]) + ? field.onChange([...field.value, value]) : field.onChange( field.value?.filter( - (value) => value !== _item + (value) => value !== value ) ) }} /> - {_item} + + {description} + ) }} diff --git a/components/sidebar/create-table.tsx b/components/sidebar/create-table.tsx index 7d9f3130..7847d7b4 100644 --- a/components/sidebar/create-table.tsx +++ b/components/sidebar/create-table.tsx @@ -15,22 +15,31 @@ import { } from "@/components/ui/dialog" import { Input } from "@/components/ui/input" import { Label } from "@/components/ui/label" +import { Progress } from "@/components/ui/progress" import { csvFile2Sql } from "./helper" export function CreateTableDialog() { const [open, setOpen] = useState(false) const [tableName, setTableName] = useState("") + const [importing, setImporting] = useState(false) + const [progress, setProgress] = useState(0) const [file, setFile] = useState(null) const params = useParams() const router = useRouter() const { database } = params - const { createTable, createTableWithSql } = useSqlite(database) + const { createTable, createTableWithSqlAndInsertSqls } = useSqlite(database) const handleCreateTable = async () => { if (file) { const res = await csvFile2Sql(file, tableName.trim()) - await createTableWithSql(res.createTableSql, res.insertSql) + setImporting(true) + await createTableWithSqlAndInsertSqls( + res.createTableSql, + res.sqls, + setProgress + ) + setImporting(false) // await createTableWithSql(res.createTableSql, res.insertSql) router.push(`/${database}/${tableName}`) setOpen(false) @@ -49,6 +58,7 @@ export function CreateTableDialog() { + Create Table @@ -86,6 +96,7 @@ export function CreateTableDialog() { /> + {importing && }