16 changes: 2 additions & 14 deletions example-apps/chatbot-rag-app/api/app.py
@@ -1,9 +1,7 @@
from flask import Flask, jsonify, request, Response
from flask_cors import CORS
from queue import Queue
from uuid import uuid4
from chat import ask_question, parse_stream_message
import threading
from chat import ask_question
import os
import sys

@@ -23,18 +21,8 @@ def api_chat():
if question is None:
return jsonify({"msg": "Missing question from request JSON"}), 400

stream_queue = Queue()
session_id = request.args.get("session_id", str(uuid4()))

print("Chat session ID: ", session_id)

threading.Thread(
target=ask_question, args=(question, stream_queue, session_id)
).start()

return Response(
parse_stream_message(session_id, stream_queue), mimetype="text/event-stream"
)
return Response(ask_question(question, session_id), mimetype="text/event-stream")


@app.cli.command()
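Because the removed queue/threading lines and the new streaming return are interleaved above, the simplified handler after this change reads roughly as the sketch below. The route decorator and the way the question is pulled from the request JSON sit outside the hunk, so they are assumptions here, not part of the diff.

from uuid import uuid4

from chat import ask_question
from flask import Flask, jsonify, request, Response
from flask_cors import CORS

app = Flask(__name__)
CORS(app)


@app.route("/api/chat", methods=["POST"])  # assumed route; not shown in the hunk
def api_chat():
    question = (request.get_json() or {}).get("question")  # assumed; parsed above the hunk
    if question is None:
        return jsonify({"msg": "Missing question from request JSON"}), 400

    session_id = request.args.get("session_id", str(uuid4()))

    # ask_question is now a generator, so its SSE lines stream straight to the
    # client -- no Queue, no background thread, no parse_stream_message.
    return Response(ask_question(question, session_id), mimetype="text/event-stream")
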
149 changes: 28 additions & 121 deletions example-apps/chatbot-rag-app/api/chat.py
@@ -1,18 +1,10 @@
from langchain.callbacks.base import BaseCallbackHandler
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts.chat import (
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
ChatPromptTemplate,
)
from langchain.prompts.prompt import PromptTemplate
from langchain.vectorstores import ElasticsearchStore
from queue import Queue
from llm_integrations import get_llm
from elasticsearch_client import (
elasticsearch_client,
get_elasticsearch_chat_message_history,
)
from flask import render_template, stream_with_context, current_app
import json
import os

@@ -21,135 +13,50 @@
"ES_INDEX_CHAT_HISTORY", "workplace-app-docs-chat-history"
)
ELSER_MODEL = os.getenv("ELSER_MODEL", ".elser_model_2")
POISON_MESSAGE = "~~~END~~~"
SESSION_ID_TAG = "[SESSION_ID]"
SOURCE_TAG = "[SOURCE]"
DONE_TAG = "[DONE]"


class QueueCallbackHandler(BaseCallbackHandler):
def __init__(
self,
queue: Queue,
):
self.queue = queue
self.in_human_prompt = True

def on_retriever_end(self, documents, *, run_id, parent_run_id=None, **kwargs):
if len(documents) > 0:
for doc in documents:
source = {
"name": doc.metadata["name"],
"page_content": doc.page_content,
"url": doc.metadata["url"],
"icon": doc.metadata["category"],
"updated_at": doc.metadata.get("updated_at", None),
}
self.queue.put(f"{SOURCE_TAG} {json.dumps(source)}")

def on_llm_new_token(self, token, **kwargs):
if not self.in_human_prompt:
self.queue.put(token)

def on_llm_start(
self,
serialized,
prompts,
*,
run_id,
parent_run_id=None,
tags=None,
metadata=None,
**kwargs,
):
self.in_human_prompt = prompts[0].startswith("Human:")

def on_llm_end(self, response, *, run_id, parent_run_id=None, **kwargs):
if not self.in_human_prompt:
self.queue.put(POISON_MESSAGE)


store = ElasticsearchStore(
es_connection=elasticsearch_client,
index_name=INDEX,
strategy=ElasticsearchStore.SparseVectorRetrievalStrategy(model_id=ELSER_MODEL),
)

general_system_template = """
Human: Use the following passages to answer the user's question.
Each passage has a SOURCE which is the title of the document. When answering, give the source name of the passages you are answering from, put them in a comma seperated list, prefixed at the start with SOURCES: $sources then print an empty line.

Example:

Question: What is the meaning of life?
Response:
The meaning of life is 42. \n

SOURCES: Hitchhiker's Guide to the Galaxy \n

If you don't know the answer, just say that you don't know, don't try to make up an answer.

----
{context}
----
@stream_with_context
def ask_question(question, session_id):
yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
current_app.logger.debug("Chat session ID: %s", session_id)

"""
general_user_template = "Question: {question}"
qa_prompt = ChatPromptTemplate.from_messages(
[
SystemMessagePromptTemplate.from_template(general_system_template),
HumanMessagePromptTemplate.from_template(general_user_template),
]
)
chat_history = get_elasticsearch_chat_message_history(
INDEX_CHAT_HISTORY, session_id
)

document_prompt = PromptTemplate(
input_variables=["page_content", "name"],
template="""
---
NAME: "{name}"
PASSAGE:
{page_content}
---
""",
)
if len(chat_history.messages) > 0:
# create a condensed question
condense_question_prompt = render_template(
'condense_question_prompt.txt', question=question,
chat_history=chat_history.messages)
question = get_llm().invoke(condense_question_prompt).content

retriever = store.as_retriever()
llm = get_llm()
chat = ConversationalRetrievalChain.from_llm(
llm=llm,
retriever=store.as_retriever(),
return_source_documents=True,
combine_docs_chain_kwargs={"prompt": qa_prompt, "document_prompt": document_prompt},
verbose=True,
)
current_app.logger.debug('Question: %s', question)

docs = store.as_retriever().invoke(question)
for doc in docs:
doc_source = {**doc.metadata, 'page_content': doc.page_content}
current_app.logger.debug('Retrieved document passage from: %s', doc.metadata['name'])
yield f'data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n'

def parse_stream_message(session_id, queue: Queue):
yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
qa_prompt = render_template('rag_prompt.txt', question=question, docs=docs)

message = None
break_out_flag = False
while True:
message = queue.get()
for line in message.splitlines():
if line == POISON_MESSAGE:
break_out_flag = True
break
yield f"data: {line}\n\n"
if break_out_flag:
break
answer = ''
for chunk in get_llm().stream(qa_prompt):
yield f'data: {chunk.content}\n\n'
answer += chunk.content

yield f"data: {DONE_TAG}\n\n"
current_app.logger.debug('Answer: %s', answer)


def ask_question(question, queue, session_id):
chat_history = get_elasticsearch_chat_message_history(
INDEX_CHAT_HISTORY, session_id
)
result = chat(
{"question": question, "chat_history": chat_history.messages},
callbacks=[QueueCallbackHandler(queue)],
)

chat_history.add_user_message(result["question"])
chat_history.add_ai_message(result["answer"])
chat_history.add_user_message(question)
chat_history.add_ai_message(answer)
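The chat.py hunk is the hardest to read because the old callback and queue machinery is deleted around the new code. Reassembled from just the added lines above, the new generator reads roughly as follows; module-level setup (the ElasticsearchStore, the tag constants, get_llm and the Flask imports) stays at the top of the file, and the placement of the trailing chat-history writes follows the hunk order.

@stream_with_context
def ask_question(question, session_id):
    yield f"data: {SESSION_ID_TAG} {session_id}\n\n"
    current_app.logger.debug("Chat session ID: %s", session_id)

    chat_history = get_elasticsearch_chat_message_history(
        INDEX_CHAT_HISTORY, session_id
    )

    if len(chat_history.messages) > 0:
        # create a condensed question
        condense_question_prompt = render_template(
            'condense_question_prompt.txt', question=question,
            chat_history=chat_history.messages)
        question = get_llm().invoke(condense_question_prompt).content

    current_app.logger.debug('Question: %s', question)

    docs = store.as_retriever().invoke(question)
    for doc in docs:
        doc_source = {**doc.metadata, 'page_content': doc.page_content}
        current_app.logger.debug('Retrieved document passage from: %s', doc.metadata['name'])
        yield f'data: {SOURCE_TAG} {json.dumps(doc_source)}\n\n'

    qa_prompt = render_template('rag_prompt.txt', question=question, docs=docs)

    answer = ''
    for chunk in get_llm().stream(qa_prompt):
        yield f'data: {chunk.content}\n\n'
        answer += chunk.content

    yield f"data: {DONE_TAG}\n\n"
    current_app.logger.debug('Answer: %s', answer)

    chat_history.add_user_message(question)
    chat_history.add_ai_message(answer)
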
20 changes: 10 additions & 10 deletions example-apps/chatbot-rag-app/api/llm_integrations.py
@@ -5,15 +5,15 @@

LLM_TYPE = os.getenv("LLM_TYPE", "openai")

def init_openai_chat():
def init_openai_chat(temperature):
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=0.2)
def init_vertex_chat():
return ChatOpenAI(openai_api_key=OPENAI_API_KEY, streaming=True, temperature=temperature)
def init_vertex_chat(temperature):
VERTEX_PROJECT_ID = os.getenv("VERTEX_PROJECT_ID")
VERTEX_REGION = os.getenv("VERTEX_REGION", "us-central1")
vertexai.init(project=VERTEX_PROJECT_ID, location=VERTEX_REGION)
return ChatVertexAI(streaming=True, temperature=0.2)
def init_azure_chat():
return ChatVertexAI(streaming=True, temperature=temperature)
def init_azure_chat(temperature):
OPENAI_VERSION=os.getenv("OPENAI_VERSION", "2023-05-15")
BASE_URL=os.getenv("OPENAI_BASE_URL")
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY")
@@ -24,8 +24,8 @@ def init_azure_chat():
openai_api_version=OPENAI_VERSION,
openai_api_key=OPENAI_API_KEY,
streaming=True,
temperature=0.2)
def init_bedrock():
temperature=temperature)
def init_bedrock(temperature):
AWS_ACCESS_KEY=os.getenv("AWS_ACCESS_KEY")
AWS_SECRET_KEY=os.getenv("AWS_SECRET_KEY")
AWS_REGION=os.getenv("AWS_REGION")
@@ -35,7 +35,7 @@ def init_bedrock():
client=BEDROCK_CLIENT,
model_id=AWS_MODEL_ID,
streaming=True,
model_kwargs={"temperature":0.2})
model_kwargs={"temperature":temperature})

MAP_LLM_TYPE_TO_CHAT_MODEL = {
"azure": init_azure_chat,
@@ -44,8 +44,8 @@ def init_bedrock():
"vertex": init_vertex_chat,
}

def get_llm():
def get_llm(temperature=0.2):
if not LLM_TYPE in MAP_LLM_TYPE_TO_CHAT_MODEL:
raise Exception("LLM type not found. Please set LLM_TYPE to one of: " + ", ".join(MAP_LLM_TYPE_TO_CHAT_MODEL.keys()) + ".")

return MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE]()
return MAP_LLM_TYPE_TO_CHAT_MODEL[LLM_TYPE](temperature=temperature)
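The behavioural change in llm_integrations.py is only that each init function now receives its temperature from the caller of get_llm, still defaulting to 0.2. A hedged usage sketch; the temperature value and prompt strings below are illustrative, not taken from this PR.

from llm_integrations import get_llm

# A colder model for deterministic steps such as condensing a follow-up question
# (the value 0 is illustrative).
condense_llm = get_llm(temperature=0)
standalone_question = condense_llm.invoke("...rendered condense_question_prompt...").content

# The default temperature (0.2) for streaming the final answer.
for chunk in get_llm().stream("...rendered rag_prompt..."):
    print(chunk.content, end="", flush=True)
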
8 changes: 8 additions & 0 deletions example-apps/chatbot-rag-app/api/templates/condense_question_prompt.txt
@@ -0,0 +1,8 @@
Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.

Chat history:
{% for dialogue_turn in chat_history -%}
{% if dialogue_turn.type == 'human' %}Question: {{ dialogue_turn.content }}{% elif dialogue_turn.type == 'ai' %}Response: {{ dialogue_turn.content }}{% endif %}
{% endfor -%}
Follow Up Question: {{ question }}
Standalone question:
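In the app this template is rendered with Flask's render_template (see chat.py above), receiving chat_history.messages and the raw question. As a rough way to exercise it outside Flask, one could render it with Jinja2 directly; the message stand-ins below only mimic the .type and .content attributes the template reads, and the loader path assumes you run from the chatbot-rag-app directory.

from types import SimpleNamespace as Msg  # stand-in for LangChain chat messages
from jinja2 import Environment, FileSystemLoader

env = Environment(loader=FileSystemLoader("api/templates"))
template = env.get_template("condense_question_prompt.txt")

history = [
    Msg(type="human", content="How many vacation days do we get?"),
    Msg(type="ai", content="Full-time employees get 25 days per year."),
]

print(template.render(chat_history=history, question="Does that include public holidays?"))
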
26 changes: 26 additions & 0 deletions example-apps/chatbot-rag-app/api/templates/rag_prompt.txt
@@ -0,0 +1,26 @@
Use the following passages to answer the user's question.
Each passage has a NAME which is the title of the document. When answering, give the source name of the passages you are answering from at the end. Put them in a comma separated list, prefixed with SOURCES:.

Example:

Question: What is the meaning of life?
Response:
The meaning of life is 42.

SOURCES: Hitchhiker's Guide to the Galaxy

If you don't know the answer, just say that you don't know, don't try to make up an answer.

----

{% for doc in docs -%}
---
NAME: {{ doc.metadata.name }}
PASSAGE:
{{ doc.page_content }}
---

{% endfor -%}
----
Question: {{ question }}
Response:
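
Taken together, a client sees the endpoint as a server-sent-event stream: one [SESSION_ID] line, one [SOURCE] JSON document per retrieved passage, then raw answer tokens, and finally [DONE]. A rough consumption sketch; the host, port, route, and request payload shape are assumptions, not part of this PR.

import json
import requests

resp = requests.post(
    "http://localhost:4000/api/chat",       # host, port and route are assumptions
    params={"session_id": "demo-session"},  # optional; the server generates one otherwise
    json={"question": "What is our vacation policy?"},
    stream=True,
)

answer = ""
for raw in resp.iter_lines(decode_unicode=True):
    if not raw or not raw.startswith("data: "):
        continue  # skip the blank separator lines between SSE events
    payload = raw[len("data: "):]
    if payload.startswith("[SESSION_ID]"):
        session_id = payload.split(" ", 1)[1]
    elif payload.startswith("[SOURCE]"):
        source = json.loads(payload.split(" ", 1)[1])
        print("source:", source["name"])
    elif payload == "[DONE]":
        break
    else:
        answer += payload  # a streamed token chunk from the LLM

print()
print(answer)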