diff --git a/src/services/apis/claude-web.mjs b/src/services/apis/claude-web.mjs index f8ec51f9..ab9ac4ee 100644 --- a/src/services/apis/claude-web.mjs +++ b/src/services/apis/claude-web.mjs @@ -1,4 +1,4 @@ -import { pushRecord } from './shared.mjs' +import { pushRecord, setAbortController } from './shared.mjs' import Claude from '../clients/claude' import { Models } from '../../config/index.mjs' @@ -18,6 +18,7 @@ export async function generateAnswersWithClaudeWebApi( ) { const bot = new Claude({ sessionKey }) await bot.init() + const { controller, cleanController } = setAbortController(port) let answer = '' const progressFunc = ({ completion }) => { @@ -31,22 +32,34 @@ export async function generateAnswersWithClaudeWebApi( port.postMessage({ answer: answer, done: true, session: session }) } + const params = { + progress: progressFunc, + done: doneFunc, + model: Models[modelName].value, + signal: controller.signal, + } + if (!session.claude_conversation) await bot - .startConversation(question, { - progress: progressFunc, - done: doneFunc, - model: Models[modelName].value, - }) + .startConversation(question, params) .then((conversation) => { session.claude_conversation = conversation port.postMessage({ answer: answer, done: true, session: session }) + cleanController() + }) + .catch((err) => { + cleanController() + throw err }) else - await bot.sendMessage(question, { - conversation: session.claude_conversation, - progress: progressFunc, - done: doneFunc, - model: Models[modelName].value, - }) + await bot + .sendMessage(question, { + conversation: session.claude_conversation, + ...params, + }) + .then(cleanController) + .catch((err) => { + cleanController() + throw err + }) } diff --git a/src/services/apis/shared.mjs b/src/services/apis/shared.mjs index 8a915f44..a263876b 100644 --- a/src/services/apis/shared.mjs +++ b/src/services/apis/shared.mjs @@ -36,7 +36,12 @@ export function setAbortController(port, onStop, onDisconnect) { } port.onDisconnect.addListener(disconnectListener) - return { controller, messageListener, disconnectListener } + const cleanController = () => { + port.onMessage.removeListener(messageListener) + port.onDisconnect.removeListener(disconnectListener) + } + + return { controller, cleanController, messageListener, disconnectListener } } export function pushRecord(session, question, answer) { diff --git a/src/services/clients/claude/index.mjs b/src/services/clients/claude/index.mjs index 134dd2ef..eafb0153 100644 --- a/src/services/clients/claude/index.mjs +++ b/src/services/clients/claude/index.mjs @@ -1,6 +1,9 @@ // https://github.com/Explosion-Scratch/claude-unofficial-api /* eslint-disable */ +import { fetchSSE } from '../../../utils/index.mjs' +import { isEmpty } from 'lodash-es' + /** * The main Claude API client class. * @typedef Claude @@ -267,6 +270,7 @@ export class Claude { cookie: `sessionKey=${this.sessionKey}`, }, method: 'POST', + signal: params.signal, body: JSON.stringify({ name: '', uuid: uuid(), @@ -555,6 +559,7 @@ export class Conversation { done = () => {}, progress = () => {}, rawResponse = () => {}, + signal = null, } = {}, ) { if (model === 'default') { @@ -565,10 +570,16 @@ export class Conversation { attachments, timezone, } - const response = await this.request( - `/api/organizations/${this.claude.organizationId}/chat_conversations/${this.conversationId}/${ - retry ? 'retry_completion' : 'completion' - }`, + let resolve, reject + let returnPromise = new Promise((r, j) => { + resolve = r + reject = j + }) + let fullResponse = '' + await fetchSSE( + `https://claude.ai/api/organizations/${this.claude.organizationId}/chat_conversations/${ + this.conversationId + }/${retry ? 'retry_completion' : 'completion'}`, { method: 'POST', headers: { @@ -576,56 +587,49 @@ export class Conversation { 'content-type': 'application/json', cookie: `sessionKey=${this.claude.sessionKey}`, }, + signal: signal, body: JSON.stringify(body), + onMessage(message) { + console.debug('sse message', message) + let parsed + try { + parsed = JSON.parse(message) + } catch (error) { + console.debug('json error', error) + return + } + if (parsed.completion) fullResponse += parsed.completion + const PROGRESS_OBJECT = { + ...parsed, + completion: fullResponse, + delta: parsed.completion || '', + } + progress(PROGRESS_OBJECT) + if (parsed.stop_reason === 'stop_sequence') { + done(PROGRESS_OBJECT) + resolve(PROGRESS_OBJECT) + } + }, + async onStart() {}, + async onEnd() { + resolve({ + completion: fullResponse, + }) + }, + async onError(resp) { + if (resp instanceof Error) { + reject(resp) + return + } + const error = await resp.json().catch(() => ({})) + reject( + new Error( + !isEmpty(error) ? JSON.stringify(error) : `${resp.status} ${resp.statusText}`, + ), + ) + }, }, ) - let resolve - let returnPromise = new Promise((r) => (resolve = r)) - let parsed - readStream(response, (a, fullResponse) => { - rawResponse(a, fullResponse) - if (!a.toString().startsWith('data:')) { - return - } - try { - parsed = JSON.parse( - a - .toString() - .replace(/^data\:/, '') - .split('\n\ndata:')[0] - ?.trim() || '{}', - ) - } catch (e) { - return - } - const PROGRESS_OBJECT = { - ...parsed, - completion: fullResponse - .split('\n\n') - .filter((i) => i.startsWith('data:')) - .map((i) => { - try { - return JSON.parse( - i - .toString() - .replace(/^data\: */, '') - .split('\n\ndata:')[0] - ?.trim() || '{}', - ) - } catch (e) { - return {} - } - }) - .map((i) => i.completion) - .join(''), - delta: parsed.completion, - } - progress(PROGRESS_OBJECT) - if (parsed.stop_reason === 'stop_sequence') { - done(PROGRESS_OBJECT) - resolve(PROGRESS_OBJECT) - } - }) return returnPromise } /** @@ -815,7 +819,7 @@ function errorHandle(msg) { return (e) => { console.error(`Error at: ${msg}`) console.error(e) - process.exit(0) + // process.exit(0) } }