Skip to content

Commit

Permalink
feat: support refresh chatglm token
Browse files Browse the repository at this point in the history
  • Loading branch information
yetone committed Apr 21, 2024
1 parent c694a51 commit 47fe3d7
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 42 deletions.
18 changes: 17 additions & 1 deletion src-tauri/src/fetch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ pub(crate) struct StreamChunk {
status: u16,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub(crate) struct StreamStatusCode {
id: String,
status: u16,
}

#[derive(Debug, Serialize, Deserialize, Clone)]
pub(crate) struct AbortEventPayload {
id: String,
Expand Down Expand Up @@ -96,9 +102,19 @@ pub async fn fetch_stream(id: String, url: String, options_str: String) -> Resul

let status = resp.status();

let app_handle = APP_HANDLE.get().unwrap();
app_handle
.emit(
"fetch-stream-status-code",
StreamStatusCode {
id: id.clone(),
status: status.as_u16(),
},
)
.unwrap();

let stream = resp.bytes_stream();

let app_handle = APP_HANDLE.get().unwrap();
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let cloned_id = id.clone();
let listen_id = app_handle.listen_any("abort-fetch-stream", move |msg| {
Expand Down
97 changes: 58 additions & 39 deletions src/common/engines/chatglm.ts
Original file line number Diff line number Diff line change
@@ -1,15 +1,9 @@
/* eslint-disable camelcase */
import { getUniversalFetch } from '@/common/universal-fetch'
import { fetchSSE, getSettings, isDesktopApp } from '@/common/utils'
import { fetchSSE, getSettings, isDesktopApp, setSettings } from '@/common/utils'
import { AbstractEngine } from '@/common/engines/abstract-engine'
import { IModel, IMessageRequest } from '@/common/engines/interfaces'
import qs from 'qs'
import { LRUCache } from 'lru-cache'

const cache = new LRUCache<string, string>({
max: 100,
ttl: 1000 * 60 * 60,
})

export const keyChatGLMAccessToken = 'chatglm-access-token'
export const keyChatGLMRefreshToken = 'chatglm-refresh-token'
Expand Down Expand Up @@ -44,52 +38,79 @@ export class ChatGLM extends AbstractEngine {
}
}

async refreshAccessToken(onStatusCode: ((statusCode: number) => void) | undefined) {
const settings = await getSettings()

const headers = await this.getHeaders()

const fetcher = getUniversalFetch()

headers['Authorization'] = `Bearer ${settings.chatglmRefreshToken}`
const refreshResp = await fetcher('https://chatglm.cn/chatglm/backend-api/v1/user/refresh', {
method: 'POST',
headers,
body: JSON.stringify({}),
})
onStatusCode?.(refreshResp.status)
if (refreshResp.status === 200) {
const data = await refreshResp.json()
if (data.message !== 'success') {
throw new Error('Failed to refresh ChatGLM access token: ' + data.message)
}
await setSettings({
...settings,
chatglmAccessToken: data.result.accessToken,
})
} else {
throw new Error('Failed to refresh ChatGLM access token: ' + refreshResp.statusText)
}
}

async sendMessage(req: IMessageRequest): Promise<void> {
const settings = await getSettings()
const refreshToken = settings.chatglmRefreshToken
const fetcher = getUniversalFetch()

const assistantID = '65940acff94777010aa6b796'
const conversationTitle = 'OpenAI Translator'
const conversationIDCacheKey = `chatglm-conversation-id-${assistantID}`
let conversationID = cache.get(conversationIDCacheKey) || ''

if (conversationID) {
console.log('Using cached conversation ID:', conversationID)
}

req.onStatusCode?.(200)

const headers = await this.getHeaders()

if (!conversationID) {
const conversationListResp = await fetcher(
`https://chatglm.cn/chatglm/backend-api/assistant/conversation/list?${qs.stringify({
assistant_id: assistantID,
page: 1,
page_size: 25,
})}`,
{
method: 'GET',
headers,
}
)

req.onStatusCode?.(conversationListResp.status)

if (!conversationListResp.ok) {
const jsn = await conversationListResp.json()
req.onError?.(jsn.message ?? jsn.msg ?? 'unknown error')
return
const conversationListResp = await fetcher(
`https://chatglm.cn/chatglm/backend-api/assistant/conversation/list?${qs.stringify({
assistant_id: assistantID,
page: 1,
page_size: 25,
})}`,
{
method: 'GET',
headers,
}
)

const conversationList = await conversationListResp.json()
req.onStatusCode?.(conversationListResp.status)

const conversation = conversationList.result.conversation_list.find(
(c: { id: string; title: string }) => c.title === conversationTitle
)
if ((conversationListResp.status === 401 || conversationListResp.status === 422) && refreshToken) {
await this.refreshAccessToken(req.onStatusCode)
return await this.sendMessage(req)
}

conversationID = conversation?.id
if (!conversationListResp.ok) {
const jsn = await conversationListResp.json()
req.onError?.(jsn.message ?? jsn.msg ?? 'unknown error')
return
}

const conversationList = await conversationListResp.json()

const conversation = conversationList.result.conversation_list.find(
(c: { id: string; title: string }) => c.title === conversationTitle
)

let conversationID = conversation?.id

if (!conversationID) {
try {
const signalController = new AbortController()
Expand Down Expand Up @@ -141,8 +162,6 @@ export class ChatGLM extends AbstractEngine {
return
}

cache.set(conversationIDCacheKey, conversationID)

let hasError = false
let finished = false
let length = 0
Expand Down
14 changes: 12 additions & 2 deletions src/common/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -428,13 +428,23 @@ export async function fetchSSE(input: string, options: FetchSSEOptions) {

if (isTauri()) {
const id = uuidv4()
let unlisten: (() => void) | undefined = undefined
const unlistens: Array<() => void> = []
const unlisten = () => {
unlistens.forEach((unlisten) => unlisten())
}
return await new Promise<void>((resolve, reject) => {
options.signal?.addEventListener('abort', () => {
unlisten?.()
emit('abort-fetch-stream', { id })
resolve()
})
listen('fetch-stream-status-code', (event: Event<{ id: string; status: number }>) => {
if (event.payload.id === id) {
onStatusCode?.(event.payload.status)
}
})
.then((unlisten) => unlistens.push(unlisten))
.catch((e) => reject(e))
listen(
'fetch-stream-chunk',
(event: Event<{ id: string; data: string; done: boolean; status: number }>) => {
Expand Down Expand Up @@ -464,7 +474,7 @@ export async function fetchSSE(input: string, options: FetchSSEOptions) {
}
)
.then((cb) => {
unlisten = cb
unlistens.push(cb)
})
.catch((e) => {
reject(e)
Expand Down

0 comments on commit 47fe3d7

Please sign in to comment.