Skip to content

Commit

Permalink
Allow user to abort request after offline LLM has started yielding tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
sabaimran committed Oct 23, 2023
1 parent 5bb14a0 commit 608ed89
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 3 deletions.
47 changes: 46 additions & 1 deletion src/khoj/interface/web/chat.html
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@
function chat() {
// Extract required fields for search from form
let query = document.getElementById("chat-input").value.trim();
const abortButton = document.getElementById("abort-button");
let controller = new AbortController();
let signal = controller.signal;
let resultsCount = localStorage.getItem("khojResultsCount") || 5;
console.log(`Query: ${query}`);

Expand Down Expand Up @@ -111,18 +114,33 @@
chatInput.classList.remove("option-enabled");

// Call specified Khoj API which returns a streamed response of type text/plain
fetch(url)
fetch(url, { signal })
.then(response => {
const reader = response.body.getReader();
const decoder = new TextDecoder();

abortButton.addEventListener("click", () => {
controller.abort();
console.log("Download aborted");
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
}
newResponseText.innerHTML += "<b>Aborted.</b>";
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
document.getElementById("chat-input").removeAttribute("disabled");
abortButton.style.display = "none";
});

function readStream() {
reader.read().then(({ done, value }) => {

if (done) {
// Evaluate the contents of new_response_text.innerHTML after all the data has been streamed
const currentHTML = newResponseText.innerHTML;
newResponseText.innerHTML = formatHTMLMessage(currentHTML);

abortButton.style.display = "none";

return;
}

Expand All @@ -145,6 +163,7 @@
// Display response from Khoj
if (newResponseText.getElementsByClassName("spinner").length > 0) {
newResponseText.removeChild(loadingSpinner);
abortButton.style.display = "block";
}

newResponseText.innerHTML += chunk;
Expand All @@ -153,10 +172,32 @@

// Scroll to bottom of chat window as chat response is streamed
document.getElementById("chat-body").scrollTop = document.getElementById("chat-body").scrollHeight;
}).catch(error => {
    if (error.name === 'AbortError') {
        abortButton.style.display = "none";

        const cancelUrl = `/api/chat/cancel?client=web`;

        console.log('Fetch aborted');
        // Recreate the controller so the next chat request gets a fresh,
        // un-aborted signal. BUGFIX: the original assigned to an undeclared
        // `abortController` (implicit global, ReferenceError in strict mode),
        // leaving the declared `controller`/`signal` pair permanently aborted.
        controller = new AbortController();
        signal = controller.signal;
        // Tell the server to stop the offline LLM's token generation.
        fetch(cancelUrl, { method: 'POST'})
            .then(response => {
                return response.json();
            })
            .then(data => {
                console.log(data);
            });
    } else {
        console.error('Fetch error:', error);
    }
});
}
readStream();
document.getElementById("chat-input").removeAttribute("disabled");
})
.catch(error => {
console.error('Fetch error:', error);
});
}

Expand Down Expand Up @@ -297,6 +338,7 @@
<div id="chat-tooltip" style="display: none;"></div>
<textarea id="chat-input" class="option" oninput="onChatInput()" onkeyup=incrementalChat(event) autofocus="autofocus" placeholder="Type / to see a list of commands, or just type your questions and hit enter.">
</textarea>
<button id="abort-button" class="option">Abort</button>
</div>
</body>

Expand Down Expand Up @@ -389,6 +431,9 @@
border-bottom: 0;
transform: rotate(-60deg);
}
button#abort-button {
display: none;
}
/* color chat bubble by you dark grey */
.chat-message-text.you {
color: #f8fafc;
Expand Down
12 changes: 11 additions & 1 deletion src/khoj/processor/conversation/gpt4all/chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,14 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
raise e

assert isinstance(model, GPT4All), "model should be of type GPT4All"

def callback_fn(token_id, token_string):
    """Per-token GPT4All generation callback.

    Returns True to keep streaming tokens. When a cancellation has been
    recorded on shared state, logs it, clears the flag (so the next chat
    is unaffected), and returns False to stop generation.
    """
    if not state.should_cancel_chat:
        return True
    logger.debug(f"Canceling chat as cancel_chat is set to {state.should_cancel_chat}")
    state.should_cancel_chat = False
    return False

user_message = messages[-1]
system_message = messages[0]
conversation_history = messages[1:-1]
Expand All @@ -196,7 +204,9 @@ def llm_thread(g, messages: List[ChatMessage], model: Any):
prompted_message = templated_system_message + chat_history + templated_user_message

state.chat_lock.acquire()
response_iterator = model.generate(prompted_message, streaming=True, max_tokens=500, n_batch=512)
response_iterator = model.generate(
prompted_message, streaming=True, max_tokens=500, n_batch=512, callback=callback_fn
)
try:
for response in response_iterator:
if any(stop_word in response.strip() for stop_word in stop_words):
Expand Down
2 changes: 1 addition & 1 deletion src/khoj/processor/conversation/gpt4all/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def download_model(model_name: str):

# Use GPU for Chat Model, if available
try:
model = GPT4All(model_name=model_name, device="gpu")
model = GPT4All(model_name=model_name, device="gpu", allow_download=True)
logger.debug(f"Loaded {model_name} chat model to GPU.")
except ValueError:
model = GPT4All(model_name=model_name)
Expand Down
25 changes: 25 additions & 0 deletions src/khoj/routers/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,31 @@ def update(
return {"status": "ok", "message": "khoj reloaded"}


@api.post("/chat/cancel")
def cancel_chat(
    request: Request,
    client: Optional[str] = None,
    user_agent: Optional[str] = Header(None),
    referer: Optional[str] = Header(None),
    host: Optional[str] = Header(None),
):
    """Signal the in-flight offline chat generation to stop.

    Sets the shared cancellation flag that the LLM's per-token callback
    polls, records a telemetry event, and returns a JSON acknowledgement.
    """
    # Verify a chat model is configured before accepting the request.
    perform_chat_checks()

    # NOTE(review): a single process-wide flag cancels whichever chat is
    # currently generating — presumably fine for a single-user local
    # server; confirm before multi-user deployment.
    state.should_cancel_chat = True

    update_telemetry_state(
        request=request,
        telemetry_type="api",
        api="cancel_chat",
        client=client,
        user_agent=user_agent,
        referer=referer,
        host=host,
    )

    payload = json.dumps({"status": "ok"})
    return Response(content=payload, media_type="application/json", status_code=200)


@api.get("/chat/history")
def chat_history(
request: Request,
Expand Down
1 change: 1 addition & 0 deletions src/khoj/utils/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
previous_query: str = None
demo: bool = False
khoj_version: str = None
should_cancel_chat: bool = False

if torch.cuda.is_available():
# Use CUDA GPU
Expand Down

0 comments on commit 608ed89

Please sign in to comment.