From 608ed89aa8f6e29e7289800d6bd53d0428bbc938 Mon Sep 17 00:00:00 2001
From: sabaimran
Date: Sun, 22 Oct 2023 22:03:38 -0700
Subject: [PATCH] Allow user to abort request after offline LLM has started
 yielding tokens

---
 src/khoj/interface/web/chat.html              | 47 ++++++++++++++++++-
 .../conversation/gpt4all/chat_model.py        | 12 ++++-
 .../processor/conversation/gpt4all/utils.py   |  2 +-
 src/khoj/routers/api.py                       | 25 ++++++++++
 src/khoj/utils/state.py                       |  1 +
 5 files changed, 84 insertions(+), 3 deletions(-)

diff --git a/src/khoj/interface/web/chat.html b/src/khoj/interface/web/chat.html
index 2230e901b..d7a9026fa 100644
--- a/src/khoj/interface/web/chat.html
+++ b/src/khoj/interface/web/chat.html
@@ -72,6 +72,9 @@
     function chat() {
         // Extract required fields for search from form
         let query = document.getElementById("chat-input").value.trim();
+        const abortButton = document.getElementById("abort-button");
+        let controller = new AbortController();
+        let signal = controller.signal;
         let resultsCount = localStorage.getItem("khojResultsCount") || 5;
         console.log(`Query: ${query}`);
@@ -111,18 +114,33 @@
         chatInput.classList.remove("option-enabled");
 
         // Call specified Khoj API which returns a streamed response of type text/plain
-        fetch(url)
+        fetch(url, { signal })
             .then(response => {
                 const reader = response.body.getReader();
                 const decoder = new TextDecoder();
 
+                abortButton.addEventListener("click", () => {
+                    controller.abort();
+                    console.log("Download aborted");
+                    if (newResponseText.getElementsByClassName("spinner").length > 0) {
+                        newResponseText.removeChild(loadingSpinner);
+                    }
+                    newResponseText.innerHTML += "Aborted.";
+                    document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                    document.getElementById("chat-input").removeAttribute("disabled");
+                    abortButton.style.display = "none";
+                });
+
                 function readStream() {
                     reader.read().then(({ done, value }) => {
+
                         if (done) {
                             // Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
                             const currentHTML = newResponseText.innerHTML;
                             newResponseText.innerHTML = formatHTMLMessage(currentHTML);
+                            abortButton.style.display = "none";
+
                             return;
                         }
@@ -145,6 +163,7 @@
                         // Display response from Khoj
                         if (newResponseText.getElementsByClassName("spinner").length > 0) {
                             newResponseText.removeChild(loadingSpinner);
+                            abortButton.style.display = "block";
                         }
 
                         newResponseText.innerHTML += chunk;
@@ -153,10 +172,32 @@
 
                         // Scroll to bottom of chat window as chat response is streamed
                         document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
+                    }).catch(error => {
+                        if (error.name === 'AbortError') {
+                            abortButton.style.display = "none";
+
+                            const cancelUrl = `/api/chat/cancel?client=web`;
+
+                            console.log('Fetch aborted');
+                            controller = new AbortController();
+                            signal = controller.signal;
+                            fetch(cancelUrl, { method: 'POST' })
+                                .then(response => {
+                                    return response.json();
+                                })
+                                .then(data => {
+                                    console.log(data);
+                                });
+                        } else {
+                            console.error('Fetch error:', error);
+                        }
                     });
                 }
                 readStream();
                 document.getElementById("chat-input").removeAttribute("disabled");
+            })
+            .catch(error => {
+                console.error('Fetch error:', error);
             });
     }
@@ -297,6 +338,7 @@
+
@@ -389,6 +431,9 @@
         border-bottom: 0;
         transform: rotate(-60deg);
     }
+    button#abort-button {
+        display: none;
+    }
     /* color chat bubble by you dark grey */
     .chat-message-text.you {
         color: #f8fafc;
diff --git a/src/khoj/processor/conversation/gpt4all/chat_model.py b/src/khoj/processor/conversation/gpt4all/chat_model.py
index 04a004f05..5420baf4d 100644
--- a/src/khoj/processor/conversation/gpt4all/chat_model.py
+++ b/src/khoj/processor/conversation/gpt4all/chat_model.py
@@ -178,6 +178,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
             raise e
 
     assert isinstance(model, GPT4All), "model should be of type GPT4All"
+
+    def callback_fn(token_id, token_string):
+        if state.should_cancel_chat:
+            logger.debug(f"Canceling chat as cancel_chat is set to {state.should_cancel_chat}")
+            state.should_cancel_chat = False
+            return False
+        return True
+
     user_message = messages[-1]
     system_message = messages[0]
     conversation_history = messages[1:-1]
@@ -196,7 +204,9 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
     prompted_message = templated_system_message + chat_history + templated_user_message
 
     state.chat_lock.acquire()
-    response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=512)
+    response_iterator = model.generate(
+        prompted_message, streaming=True, max_tokens=500, n_batch=512, callback=callback_fn
+    )
     try:
         for response in response_iterator:
             if any(stop_word in response.strip() for stop_word in stop_words):
diff --git a/src/khoj/processor/conversation/gpt4all/utils.py b/src/khoj/processor/conversation/gpt4all/utils.py
index 2bb1fbbc1..f97eaff06 100644
--- a/src/khoj/processor/conversation/gpt4all/utils.py
+++ b/src/khoj/processor/conversation/gpt4all/utils.py
@@ -13,7 +13,7 @@ def download_model(model_name: str):
 
     # Use GPU for Chat Model, if available
     try:
-        model = GPT4All(model_name=model_name, device="gpu")
+        model = GPT4All(model_name=model_name, device="gpu", allow_download=True)
         logger.debug(f"Loaded {model_name} chat model to GPU.")
     except ValueError:
         model = GPT4All(model_name=model_name)
diff --git a/src/khoj/routers/api.py b/src/khoj/routers/api.py
index 345429e86..849a46018 100644
--- a/src/khoj/routers/api.py
+++ b/src/khoj/routers/api.py
@@ -648,6 +648,31 @@ def update(
     return {"status": "ok", "message": "khoj reloaded"}
 
 
+@api.post("/chat/cancel")
+def cancel_chat(
+    request: Request,
+    client: Optional[str] = None,
+    user_agent: Optional[str] = Header(None),
+    referer: Optional[str] = Header(None),
+    host: Optional[str] = Header(None),
+):
+    perform_chat_checks()
+
+    state.should_cancel_chat = True
+
+    update_telemetry_state(
+        request=request,
+        telemetry_type="api",
+        api="cancel_chat",
+        client=client,
+        user_agent=user_agent,
+        referer=referer,
+        host=host,
+    )
+
+    return Response(content=json.dumps({"status": "ok"}), media_type="application/json", status_code=200)
+
+
 @api.get("/chat/history")
 def chat_history(
     request: Request,
diff --git a/src/khoj/utils/state.py b/src/khoj/utils/state.py
index 5ac8a8383..8ffca9712 100644
--- a/src/khoj/utils/state.py
+++ b/src/khoj/utils/state.py
@@ -31,6 +31,7 @@
 previous_query: str = None
 demo: bool = False
 khoj_version: str = None
+should_cancel_chat: bool = False
 
 if torch.cuda.is_available():
     # Use CUDA GPU
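
The cancellation this patch implements is cooperative: POST /api/chat/cancel flips the shared state.should_cancel_chat flag, and callback_fn, handed to GPT4All's generate(), checks that flag on every token and returns False to stop generation. Below is a minimal, self-contained sketch of the same flag-and-callback pattern; fake_generate and its token loop are illustrative stand-ins for model.generate(), not Khoj or GPT4All APIs.

import threading

# Shared cancellation flag, analogous to state.should_cancel_chat in the patch.
# threading.Event makes the cross-thread handoff explicit; the patch's plain
# module-level bool also works in CPython, where the flip is atomic under the GIL.
_should_cancel = threading.Event()

def cancel_chat() -> None:
    # Stand-in for the POST /api/chat/cancel handler: request cancellation.
    _should_cancel.set()

def callback_fn(token_id: int, token_string: str) -> bool:
    # Mirrors the patch's callback_fn: returning False halts generation.
    # The flag is cleared on the way out so the next chat starts clean.
    if _should_cancel.is_set():
        _should_cancel.clear()
        return False
    return True

def fake_generate(prompt: str, callback) -> str:
    # Illustrative stand-in for model.generate(..., callback=callback_fn):
    # emit "tokens" until the callback vetoes the next one.
    response = ""
    for token_id, token in enumerate(prompt.split()):
        if not callback(token_id, token):
            break
        response += token + " "
    return response.strip()

Note that cancel_chat() only takes effect when generation reaches its next token, and because the flag is global rather than per-conversation, a cancel request from any client stops whichever chat is currently streaming.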
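
To exercise the endpoint end to end, the abort flow the web client performs can be reproduced from any HTTP client: stop consuming the streamed chat response, then POST /api/chat/cancel so the server stops generating. A sketch with the requests library follows; the server address (Khoj's default port 42110), the /api/chat path, and its query parameters are assumptions based on the Khoj web client, not part of this patch.

import requests

KHOJ_URL = "http://localhost:42110"  # assumed default local Khoj server

def chat_then_cancel(query: str, chunks_to_read: int = 5) -> None:
    # Start a streaming chat request (endpoint path and params assumed).
    with requests.get(
        f"{KHOJ_URL}/api/chat", params={"q": query}, stream=True, timeout=60
    ) as response:
        for i, chunk in enumerate(response.iter_content(chunk_size=None, decode_unicode=True)):
            print(chunk, end="", flush=True)
            if i + 1 >= chunks_to_read:
                break  # stop reading mid-stream, like controller.abort() in chat.html
    # Closing the connection alone leaves the model generating server-side, so
    # tell the server to stop, as the web client does in its AbortError handler.
    requests.post(f"{KHOJ_URL}/api/chat/cancel", params={"client": "web"}, timeout=10)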