Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support refresh chatglm token #1462

Merged
merged 1 commit into the base branch on Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src-tauri/src/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ pub(crate) struct StreamChunk {
status: u16,
}

/// Payload for the `fetch-stream-status-code` event, emitted as soon as
/// response headers are received so the frontend can observe the HTTP
/// status of a streaming fetch before any body chunks arrive.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub(crate) struct StreamStatusCode {
    // Correlates this event with the originating fetch request.
    id: String,
    // HTTP response status code (e.g. 200, 401).
    status: u16,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub(crate) struct AbortEventPayload {
id: String,
Expand Down Expand Up @@ -96,9 +102,19 @@ pub async fn fetch_stream(id: String, url: String, options_str: String) -> Resul

let status = resp.status();

let app_handle = APP_HANDLE.get().unwrap();
app_handle
.emit(
"fetch-stream-status-code",
StreamStatusCode {
id: id.clone(),
status: status.as_u16(),
},
)
.unwrap();

let stream = resp.bytes_stream();

let app_handle = APP_HANDLE.get().unwrap();
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let cloned_id = id.clone();
let listen_id = app_handle.listen_any("abort-fetch-stream", move |msg| {
Expand Down
97 changes: 58 additions & 39 deletions src/common/engines/chatglm.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
/* eslint-disable camelcase */
import { getUniversalFetch } from '@/common/universal-fetch'
import { fetchSSE, getSettings, isDesktopApp } from '@/common/utils'
import { fetchSSE, getSettings, isDesktopApp, setSettings } from '@/common/utils'
import { AbstractEngine } from '@/common/engines/abstract-engine'
import { IModel, IMessageRequest } from '@/common/engines/interfaces'
import qs from 'qs'
import { LRUCache } from 'lru-cache'

const cache = new LRUCache<string, string>({
max: 100,
ttl: 1000 * 60 * 60,
})

export const keyChatGLMAccessToken = 'chatglm-access-token'
export const keyChatGLMRefreshToken = 'chatglm-refresh-token'
Expand Down Expand Up @@ -44,52 +38,79 @@ export class ChatGLM extends AbstractEngine {
}
}

/**
 * Exchanges the stored ChatGLM refresh token for a new access token and
 * persists it via `setSettings`.
 *
 * @param onStatusCode optional callback invoked with the HTTP status of
 *                     the refresh request
 * @throws Error when the request is not a 200 or the API reports failure
 */
async refreshAccessToken(onStatusCode: ((statusCode: number) => void) | undefined) {
    const settings = await getSettings()
    const fetcher = getUniversalFetch()

    // This endpoint authenticates with the refresh token rather than the
    // (possibly expired) access token, so override the Authorization header.
    const headers = await this.getHeaders()
    headers['Authorization'] = `Bearer ${settings.chatglmRefreshToken}`

    const resp = await fetcher('https://chatglm.cn/chatglm/backend-api/v1/user/refresh', {
        method: 'POST',
        headers,
        body: JSON.stringify({}),
    })
    onStatusCode?.(resp.status)

    // Guard clause: anything other than a 200 is a hard failure.
    if (resp.status !== 200) {
        throw new Error('Failed to refresh ChatGLM access token: ' + resp.statusText)
    }

    const data = await resp.json()
    if (data.message !== 'success') {
        throw new Error('Failed to refresh ChatGLM access token: ' + data.message)
    }

    // Persist the fresh access token; subsequent requests read it from settings.
    await setSettings({
        ...settings,
        chatglmAccessToken: data.result.accessToken,
    })
}

async sendMessage(req: IMessageRequest): Promise<void> {
const settings = await getSettings()
const refreshToken = settings.chatglmRefreshToken
const fetcher = getUniversalFetch()

const assistantID = '65940acff94777010aa6b796'
const conversationTitle = 'OpenAI Translator'
const conversationIDCacheKey = `chatglm-conversation-id-${assistantID}`
let conversationID = cache.get(conversationIDCacheKey) || ''

if (conversationID) {
console.log('Using cached conversation ID:', conversationID)
}

req.onStatusCode?.(200)

const headers = await this.getHeaders()

if (!conversationID) {
const conversationListResp = await fetcher(
`https://chatglm.cn/chatglm/backend-api/assistant/conversation/list?${qs.stringify({
assistant_id: assistantID,
page: 1,
page_size: 25,
})}`,
{
method: 'GET',
headers,
}
)

req.onStatusCode?.(conversationListResp.status)

if (!conversationListResp.ok) {
const jsn = await conversationListResp.json()
req.onError?.(jsn.message ?? jsn.msg ?? 'unknown error')
return
const conversationListResp = await fetcher(
`https://chatglm.cn/chatglm/backend-api/assistant/conversation/list?${qs.stringify({
assistant_id: assistantID,
page: 1,
page_size: 25,
})}`,
{
method: 'GET',
headers,
}
)

const conversationList = await conversationListResp.json()
req.onStatusCode?.(conversationListResp.status)

const conversation = conversationList.result.conversation_list.find(
(c: { id: string; title: string }) => c.title === conversationTitle
)
if ((conversationListResp.status === 401 || conversationListResp.status === 422) && refreshToken) {
await this.refreshAccessToken(req.onStatusCode)
return await this.sendMessage(req)
}

conversationID = conversation?.id
if (!conversationListResp.ok) {
const jsn = await conversationListResp.json()
req.onError?.(jsn.message ?? jsn.msg ?? 'unknown error')
return
}

const conversationList = await conversationListResp.json()

const conversation = conversationList.result.conversation_list.find(
(c: { id: string; title: string }) => c.title === conversationTitle
)

let conversationID = conversation?.id

if (!conversationID) {
try {
const signalController = new AbortController()
Expand Down Expand Up @@ -141,8 +162,6 @@ export class ChatGLM extends AbstractEngine {
return
}

cache.set(conversationIDCacheKey, conversationID)

let hasError = false
let finished = false
let length = 0
Expand Down
14 changes: 12 additions & 2 deletions src/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,23 @@ export async function fetchSSE(input: string, options: FetchSSEOptions) {

if (isTauri()) {
const id = uuidv4()
let unlisten: (() => void) | undefined = undefined
const unlistens: Array<() => void> = []
const unlisten = () => {
unlistens.forEach((unlisten) => unlisten())
}
return await new Promise<void>((resolve, reject) => {
options.signal?.addEventListener('abort', () => {
unlisten?.()
emit('abort-fetch-stream', { id })
resolve()
})
listen('fetch-stream-status-code', (event: Event<{ id: string; status: number }>) => {
if (event.payload.id === id) {
onStatusCode?.(event.payload.status)
}
})
.then((unlisten) => unlistens.push(unlisten))
.catch((e) => reject(e))
listen(
'fetch-stream-chunk',
(event: Event<{ id: string; data: string; done: boolean; status: number }>) => {
Expand Down Expand Up @@ -464,7 +474,7 @@ export async function fetchSSE(input: string, options: FetchSSEOptions) {
}
)
.then((cb) => {
unlisten = cb
unlistens.push(cb)
})
.catch((e) => {
reject(e)
Expand Down
Loading