Skip to content

Commit

Permalink
feat: add stream output and memory mode
Browse files Browse the repository at this point in the history
  • Loading branch information
orangelckc committed Mar 16, 2023
1 parent fdcb490 commit 06969ec
Show file tree
Hide file tree
Showing 9 changed files with 130 additions and 53 deletions.
3 changes: 2 additions & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
]
},
"dependencies": {
"@microsoft/fetch-event-source": "^2.0.1",
"@multiavatar/multiavatar": "^1.0.7",
"@tauri-apps/api": "^1.2.0",
"pinia": "^2.0.33",
Expand Down Expand Up @@ -54,4 +55,4 @@
"vite": "^4.0.0",
"vue-tsc": "^1.0.11"
}
}
}
55 changes: 53 additions & 2 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ import { fetch, Body } from '@tauri-apps/api/http'
import { dialogErrorMessage } from '@/utils'
import type { FetchOptions } from '@tauri-apps/api/http'
import type { RecordData } from '@/types'
import { useSettingsStore } from '@/stores'
import { useSessionStore, useSettingsStore } from '@/stores'
import {
fetchEventSource,
type EventSourceMessage
} from '@microsoft/fetch-event-source'

/**
* 请求总入口
Expand Down Expand Up @@ -52,10 +56,57 @@ export const getOpenAIResultApi = async (messages: RecordData[]) => {
body: Body.json({
model: 'gpt-3.5-turbo-0301',
messages,
stream: true
temperature: 0.6,
stream: false
}),
headers: {
Authorization: `Bearer ${apiKey || import.meta.env.VITE_OPEN_AI_API_KEY}`
}
})
}

/**
* 获取 openai 对话消息(流)
* @param messages 消息列表
*/
export const getOpenAIResultStream = async (messages: RecordData[]) => {
if (!messages.length) return

const { apiKey } = useSettingsStore()
const { addSessionData } = useSessionStore()
const { streamReply } = storeToRefs(useSessionStore())
streamReply.value = ''

await fetchEventSource(import.meta.env.VITE_OPEN_AI_URL, {
method: 'POST',
body: JSON.stringify({
model: 'gpt-3.5-turbo-0301',
messages,
temperature: 0.6,
stream: true
}),
headers: {
Authorization: `Bearer ${apiKey || import.meta.env.VITE_OPEN_AI_API_KEY}`,
'Content-Type': 'application/json',
Accept: 'application/json'
},
onmessage(msg: EventSourceMessage) {
if (msg.data !== '[DONE]') {
const { choices } = JSON.parse(msg.data)

if (!choices[0].delta.content) return
streamReply.value += choices[0].delta.content
}
},
onclose() {
const res: RecordData = {
role: 'assistant',
content: streamReply.value
}
addSessionData(false, '', res)
},
onerror(err: any) {
throw new Error('流输出出错:', err)
}
})
}
4 changes: 2 additions & 2 deletions src/components/Function/components/SettingsModal.vue
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import ShortcutKey from './ShortcutKey.vue'
defineProps<{ visible: boolean; setVisible: () => void }>()
const { apiKey, autoStart } = storeToRefs(useSettingsStore())
const { apiKey, autoStart, isMemory } = storeToRefs(useSettingsStore())
</script>

<template>
Expand All @@ -29,7 +29,7 @@ const { apiKey, autoStart } = storeToRefs(useSettingsStore())
<ShortcutKey />

<div class="flex items-end gap-2">
<a-checkbox>记忆对话</a-checkbox>
<a-checkbox v-model="isMemory">记忆对话</a-checkbox>
<span class="text-3 text-[var(--color-text-3)]">
开启连续对话,将加倍消耗 Token
</span>
Expand Down
2 changes: 0 additions & 2 deletions src/components/Input/components/RoleList.vue
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ const handleAdd = () => {
}
const handleHide = (value: boolean) => {
console.log('value', value)
// 判断是否有正在编辑的角色
const isEdit = showList.value.some((item) => item.isEdit)
if (isEdit) return
Expand Down
10 changes: 8 additions & 2 deletions src/components/Session/index.vue
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import { useSettingsStore, useSessionStore, useRoleStore } from '@/stores'
const { uuid } = storeToRefs(useSettingsStore())
const { currentSession, sessionDataList } = storeToRefs(useSessionStore())
const { currentSession, sessionDataList, streamReply } = storeToRefs(
useSessionStore()
)
const { currentRole } = storeToRefs(useRoleStore())
</script>

Expand All @@ -18,9 +20,13 @@ const { currentRole } = storeToRefs(useRoleStore())
>
<Avatar class="w-14!" :value="item.is_ask ? uuid : currentRole?.name" />
<div>
{{ item.messages }}
{{ item.message }}
</div>
</div>
<div>
<p>正在回答</p>
<p>{{ streamReply }}</p>
</div>
</template>

<!-- 当前无会话 -->
Expand Down
13 changes: 5 additions & 8 deletions src/stores/role.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ export const useRoleStore = defineStore(
defaultRole.value = []

roleList.value = result.map((item) => ({ ...item, isEdit: false }))

changeCurrentRole()
}

const getFilterRoleList = (value: string) => {
Expand All @@ -40,14 +38,13 @@ export const useRoleStore = defineStore(
filterList.value.length = 0
}

const changeCurrentRole = () => {
const changeCurrentRole = async () => {
const { currentSession } = useSessionStore()
const findRole = roleList.value.find(
(role) => role.id === currentSession?.role_id
)
console.log('currentSession', currentSession, findRole)

currentRole.value = findRole ?? roleList.value[0]
const sql = `SELECT * FROM role WHERE id = ${currentSession?.role_id};`
const findRole = (await executeSQL(sql)) as RolePayload[]

currentRole.value = findRole[0] ?? roleList.value[0]
}

const addRole = async (payload: RolePayload) => {
Expand Down
23 changes: 17 additions & 6 deletions src/stores/session.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import { executeSQL } from '@/sqls'
import type { SessionPayload, SessionData, RecordData } from '@/types'
import { useRoleStore } from './role'
import { useSettingsStore } from './settings'

// TODO: 无记忆对话和有记忆对话
// 用来管理当前会话的状态
Expand All @@ -15,6 +16,8 @@ export const useSessionStore = defineStore(
const sessionList = ref<SessionPayload[]>([])
// 请求发送状态
const isThinking = ref(false)
// 流式回复
const streamReply = ref('')

const getSessionList = async () => {
const sql =
Expand Down Expand Up @@ -57,26 +60,30 @@ export const useSessionStore = defineStore(
return result.length > 0
}

const { changeCurrentRole, currentRole } = useRoleStore()

// TODO: 是否为记忆对话
// TODO: messageType从 types 中取到
const addSessionData = async (
isAsk: boolean,
messageType: string,
data: RecordData[]
data: RecordData
) => {
if (!currentSession.value) return
// 检查会话是否已经存在
const isExist = await checkSessionExist()

const { currentRole } = useRoleStore()

if (!isExist) {
const sql = `INSERT INTO session (id, title, role_id) VALUES ('${currentSession.value.id}', '${data[1].content}', '${currentRole?.id}');`
const sql = `INSERT INTO session (id, title, role_id) VALUES ('${currentSession.value.id}', '${data.content}', '${currentRole?.id}');`
executeSQL(sql)
}

const sql = `INSERT INTO session_data (session_id, is_ask, messages) VALUES (
'${currentSession.value.id}','${isAsk}', '${JSON.stringify(data)}');`
const { isMemory } = useSettingsStore()

const sql = `INSERT INTO session_data (session_id, is_ask, is_memory, message) VALUES (
'${currentSession.value.id}', ${isAsk}, ${isMemory}, '${JSON.stringify(
data
)}');`

executeSQL(sql)
getSessionData()
Expand All @@ -88,6 +95,9 @@ export const useSessionStore = defineStore(
else {
currentSession.value = session
}

const { changeCurrentRole } = useRoleStore()

changeCurrentRole()
getSessionData()
}
Expand All @@ -98,6 +108,7 @@ export const useSessionStore = defineStore(
currentSession,
sessionDataList,
isThinking,
streamReply,
sessionList,
addSessionData,
switchSession,
Expand Down
2 changes: 1 addition & 1 deletion src/types/sql.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export interface SessionData {
is_ask: boolean
is_memory: boolean
message_type?: 'text' | 'image' | 'voice'
messages: RecordData[]
message: RecordData
time?: string
}

Expand Down
71 changes: 42 additions & 29 deletions src/utils/openai.ts
Original file line number Diff line number Diff line change
@@ -1,55 +1,68 @@
import { getOpenAIResultApi } from '@/api'
import { getOpenAIResultStream } from '@/api'
import { executeSQL } from '@/sqls'
import { useSessionStore, useRoleStore } from '@/stores'
import { RecordData } from '@/types'
import { useSessionStore, useRoleStore, useSettingsStore } from '@/stores'
import { RecordData, SessionData } from '@/types'

export const getAiMessage = async (value?: string) => {
const { currentRole } = useRoleStore()
if (!currentRole) return

let messages: RecordData[]
const messages: RecordData[] = []

const { sessionDataList, currentSession } = useSessionStore()
const { isMemory } = useSettingsStore()

const lastQuestion = sessionDataList.filter((item) => item.is_ask).at(-1)

// 记忆模式,或者是第一次对话,都要生成角色描述
if (sessionDataList.length < 3 || isMemory)
messages.push({
role: 'system',
content: currentRole.description
})

// 获取记忆(限制5条),往前推直到出现is_momery为false的
// TODO 应该进行限流,防止出现过多的记忆,导致token超出
const addMemory = async () => {
if (isMemory) {
const sql = `SELECT * FROM session_data WHERE session_id = '${currentSession?.id}' ORDER BY id DESC LIMIT 5;`
const memoryList = (await executeSQL(sql)) as SessionData[]

let count = 0
const arr = []
while (count < memoryList.length) {
if (!memoryList[count].is_memory) break
arr.push(JSON.parse(memoryList[count++].message as any))
}
messages.push(...arr.reverse())
}
}

// 再次生成上一次问题
if (!value) {
const { sessionDataList } = useSessionStore()
const lastQuestion = sessionDataList.filter((item) => item.is_ask).at(-1)

if (!lastQuestion) return

// 为了保证统一,这之后的内容全部删掉
const deleteSql = `DELETE FROM session_data WHERE session_id = '${lastQuestion?.session_id}' AND id >= ${lastQuestion?.id};`

await executeSQL(deleteSql)

messages = JSON.parse(lastQuestion?.messages as any) || []
await addMemory()
messages.push(JSON.parse(lastQuestion?.message as any))
} else {
// TODO 这里可以优化,如何携带上一次的对话内容
messages = [
{
role: 'system',
content: currentRole.description
},
{
role: 'user',
content: value
}
]
await addMemory()
messages.push({
role: 'user',
content: value
})
}

const { isThinking } = storeToRefs(useSessionStore())
const { addSessionData } = useSessionStore()

isThinking.value = true

addSessionData(true, '', messages)
const result = await getOpenAIResultApi(messages)
addSessionData(true, '', messages.at(-1)!)
await getOpenAIResultStream(messages)

isThinking.value = false

console.log('result', result)

if (!result) return

// TODO 处理流式输出的结果
addSessionData(false, '', result.message)
}

0 comments on commit 06969ec

Please sign in to comment.