From 7978c9e94f0652b8df6597cb414c913f4f9624e9 Mon Sep 17 00:00:00 2001
From: Mayne
Date: Wed, 14 Jun 2023 08:12:53 +0800
Subject: [PATCH 1/7] feat(ai): openai function call
---
app/[database]/ai-chat-message-prisma.tsx | 2 +-
app/[database]/ai-chat.tsx | 108 ++++++++++++------
app/settings/ai/ai-form.tsx | 38 +++++--
components/sidebar/create-table.tsx | 15 ++-
components/sidebar/helper.ts | 17 ++-
components/ui/progress.tsx | 28 +++++
hooks/use-ai.ts | 77 +------------
hooks/use-auto-run-code.ts | 19 +++-
hooks/use-sqlite.ts | 18 +++
hooks/use-table.ts | 22 +++-
lib/ai/functions/index.ts | 24 ++++
lib/ai/functions/sql-query.ts | 15 +++
lib/ai/openai.ts | 128 ++++++++++++++++++++++
package.json | 4 +-
pnpm-lock.yaml | 42 ++++++-
15 files changed, 422 insertions(+), 135 deletions(-)
create mode 100644 components/ui/progress.tsx
create mode 100644 lib/ai/functions/index.ts
create mode 100644 lib/ai/functions/sql-query.ts
create mode 100644 lib/ai/openai.ts
diff --git a/app/[database]/ai-chat-message-prisma.tsx b/app/[database]/ai-chat-message-prisma.tsx
index d150963f..08e3bfd9 100644
--- a/app/[database]/ai-chat-message-prisma.tsx
+++ b/app/[database]/ai-chat-message-prisma.tsx
@@ -68,7 +68,7 @@ export const AIMessage = ({
}
return (
)
diff --git a/app/[database]/ai-chat.tsx b/app/[database]/ai-chat.tsx
index 8f3c94af..b3c3ae19 100644
--- a/app/[database]/ai-chat.tsx
+++ b/app/[database]/ai-chat.tsx
@@ -7,8 +7,10 @@ import { useParams } from "next/navigation"
import { useKeyPress, useSize } from "ahooks"
import { Bot, Loader2, Paintbrush, User } from "lucide-react"
+import { handleOpenAIFunctionCall } from "@/lib/ai/openai"
import { useAI } from "@/hooks/use-ai"
import { useAutoRunCode } from "@/hooks/use-auto-run-code"
+import { useSqliteStore } from "@/hooks/use-sqlite"
import { Button } from "@/components/ui/button"
import { Textarea } from "@/components/ui/textarea"
@@ -16,21 +18,20 @@ import { useConfigStore } from "../settings/store"
import { AIMessage } from "./ai-chat-message-prisma"
import { useTableChange } from "./hook"
import { useDatabaseAppStore } from "./store"
-import { useSqliteStore } from "@/hooks/use-sqlite"
export const AIChat = () => {
- const { currentTableSchema, setCurrentQuery } = useDatabaseAppStore()
+ const { currentTableSchema } = useDatabaseAppStore()
const { askAI } = useAI()
const { database, table } = useParams()
const { aiConfig } = useConfigStore()
const [input, setInput] = useState("")
const [loading, setLoading] = useState(false)
- const { autoRun: runCode, handleRunCode } = useAutoRunCode()
+ const { handleFunctionCall, handleRunCode } = useAutoRunCode()
const divRef = useRef()
const size = useSize(divRef)
const [messages, setMessages] = useState<
{
- role: "user" | "assistant"
+ role: "user" | "assistant" | "function"
content: string
}[]
>([])
@@ -59,20 +60,46 @@ export const AIChat = () => {
allTables,
databaseName: database,
})
- const newMessages = [
- ..._messages,
- { role: "assistant", content: response?.content! },
- ]
- const thisMsgIndex = newMessages.length - 1
- setMessages(newMessages)
- if (response?.content && aiConfig.autoRunScope) {
- setTimeout(() => {
- runCode(response.content, {
- msgIndex: thisMsgIndex,
- width: size?.width ?? 300,
- })
- }, 1000)
+
+ if (response?.finish_reason == "function_call") {
+ if (aiConfig.autoRunScope) {
+ const res = await handleOpenAIFunctionCall(
+ response.message!,
+ handleFunctionCall
+ )
+ if (res) {
+ const { name, resp } = res
+ const newMessages = [
+ ..._messages,
+ response.message,
+ {
+ role: "function",
+ name,
+ content: JSON.stringify(resp),
+ },
+ ]
+ const newResponse = await askAI(newMessages, {
+ tableSchema: currentTableSchema,
+ allTables,
+ databaseName: database,
+ })
+ const _newMessages = [
+ ...newMessages,
+ { role: "assistant", content: newResponse?.message?.content },
+ ]
+ console.log({ _newMessages })
+ setMessages(_newMessages as any)
+ }
+ }
+ } else if (response?.message) {
+ const newMessages = [
+ ..._messages,
+ { role: "assistant", content: response?.message?.content },
+ ]
+ const thisMsgIndex = newMessages.length - 1
+ setMessages(newMessages)
}
+
setLoading(false)
}
@@ -97,28 +124,37 @@ export const AIChat = () => {
first
)}
- {messages.map((message, i) => (
-
- {message.role === "assistant" ? (
- <>
-
-
- >
- ) : (
+ {messages
+ .filter((m) => (m.role === "user" || m.role == "assistant") && m.content)
+ .map((message, i) => (
+
+ {message.role === "assistant" && (
+ <>
+
+
+ >
+ )}
+ {message.role === "user" && (
+ <>
+
+
{message.content}
+ >
+ )}
+ {/* {message.role === "function" && (
<>
-
{message.content}
+
run function
>
- )}
-
- ))}
+ )} */}
+
+ ))}
{loading && }
diff --git a/app/settings/ai/ai-form.tsx b/app/settings/ai/ai-form.tsx
index 3803c7cc..742b7a4a 100644
--- a/app/settings/ai/ai-form.tsx
+++ b/app/settings/ai/ai-form.tsx
@@ -28,7 +28,27 @@ const AIConfigFormSchema = z.object({
})
export type AIConfigFormValues = z.infer
-export const AutoRunScopes = ["SQL.SELECT", "SQL.INSERT", "D3.CHART"]
+export const AutoRunScopesWithDesc = [
+ {
+ value: "SQL.SELECT",
+ description: "Selects data from a SQL table.",
+ },
+ {
+ value: "UI.REFRESH",
+ description: "Refresh the UI after querying data.",
+ },
+ {
+ value: "SQL.INSERT",
+ description: "Inserts data into a SQL table.",
+ },
+ {
+ value: "D3.CHART",
+ description: "Creates a D3 chart.",
+ },
+]
+
+export const AutoRunScopes = AutoRunScopesWithDesc.map((item) => item.value)
+
// This can come from your database or API.
const defaultValues: Partial = {
// name: "Your name",
@@ -89,32 +109,34 @@ export function AIConfigForm() {
run.
- {AutoRunScopes.map((_item) => (
+ {AutoRunScopesWithDesc.map(({ value, description }) => (
{
return (
{
return checked
- ? field.onChange([...field.value, _item])
+ ? field.onChange([...field.value, value])
: field.onChange(
field.value?.filter(
- (value) => value !== _item
+ (value) => value !== value
)
)
}}
/>
- {_item}
+
+ {description}
+
)
}}
diff --git a/components/sidebar/create-table.tsx b/components/sidebar/create-table.tsx
index 7d9f3130..7847d7b4 100644
--- a/components/sidebar/create-table.tsx
+++ b/components/sidebar/create-table.tsx
@@ -15,22 +15,31 @@ import {
} from "@/components/ui/dialog"
import { Input } from "@/components/ui/input"
import { Label } from "@/components/ui/label"
+import { Progress } from "@/components/ui/progress"
import { csvFile2Sql } from "./helper"
export function CreateTableDialog() {
const [open, setOpen] = useState(false)
const [tableName, setTableName] = useState("")
+ const [importing, setImporting] = useState(false)
+ const [progress, setProgress] = useState(0)
const [file, setFile] = useState(null)
const params = useParams()
const router = useRouter()
const { database } = params
- const { createTable, createTableWithSql } = useSqlite(database)
+ const { createTable, createTableWithSqlAndInsertSqls } = useSqlite(database)
const handleCreateTable = async () => {
if (file) {
const res = await csvFile2Sql(file, tableName.trim())
- await createTableWithSql(res.createTableSql, res.insertSql)
+ setImporting(true)
+ await createTableWithSqlAndInsertSqls(
+ res.createTableSql,
+ res.sqls,
+ setProgress
+ )
+ setImporting(false)
// await createTableWithSql(res.createTableSql, res.insertSql)
router.push(`/${database}/${tableName}`)
setOpen(false)
@@ -49,6 +58,7 @@ export function CreateTableDialog() {
+
Create Table
@@ -86,6 +96,7 @@ export function CreateTableDialog() {
/>
+ {importing && }