In [None]:
# %%writefile ui_launcher_v.py
import gradio as gr
import random
import time

import sagemaker
import boto3

# Create the SageMaker session and read the account context from it.
sess = sagemaker.Session()
# Execution role of the current environment; not referenced again in the visible code.
role = sagemaker.get_execution_role()
# Default bucket of the session; not referenced again in the visible code.
default_bucket = sess.default_bucket()
# Region of the underlying boto session; passed to the LLM wrapper and the Kendra chain below.
region = sess.boto_session.region_name

from sagemaker_endpoint_llm import SageMakerLLM

# Placeholder: paste the SageMaker endpoint name created in Step 2.
LLM_ENDPOINT_NAME = 'xxxx-xx-xx-xx-xx'

# Project wrapper around the SageMaker endpoint; the sampling settings are
# handed over through LLMArgs.
llm = SageMakerLLM(
    SageMakerEndpointName=LLM_ENDPOINT_NAME,
    AWSRegion=region,
    LLMType='GLM-6b',
    LLMArgs={'top_p': 0.45, 'temperature': 0.45},
)

from kendra_retrieval_qa_chain import KendraLLMRetrieverQAChain

# Placeholder: paste the Kendra index id created in Step 1.
KENDRA_INDEX_ID = 'xxxx-xx-xx-xx-xx'
# Language code used for Kendra queries; the full list of codes is at
# https://docs.aws.amazon.com/kendra/latest/dg/in-adding-languages.html
KENDRA_LANG_CODE = 'zh'

# Retrieval QA chain that combines the Kendra index with the LLM built above.
kqa = KendraLLMRetrieverQAChain(
    KendraIndexId=KENDRA_INDEX_ID,
    KendraLanguageCode=KENDRA_LANG_CODE,
    AWSRegion=region,
    Llm=llm.get_llm(),
)

from custom_search_helper import WebSearcher

# Placeholder RapidAPI key; a key can be applied for at
# https://rapidapi.com/microsoft-azure-org-microsoft-cognitive-services/api/bing-web-search1/
RAPID_API_KEY = '---51-digit-key---'

# Web search helper. The variable name keeps its original spelling because the
# agent definitions and the test cells refer to it as `web_sercher`.
web_sercher = WebSearcher(
    apikey=RAPID_API_KEY,
    result_count=5,
    lang_code='zh',
    # The original comment lists 'google' next to this value, presumably as an
    # alternative engine -- TODO confirm against WebSearcher.
    search_engine='bing',
)


from custom_agent_helper import CustomExecutor

# Agent given the Kendra retriever and the web searcher, with no Kendra chain.
flat_agent = CustomExecutor(
    llm=llm.sm_llm,
    retriever=kqa.kendra_retriever,
    websearcher=web_sercher,
    # Original note: the retriever will be overwritten if a chain is assigned here.
    kendra_chain=None,
    verbose=True,
)

# Agent given the full Kendra QA chain in addition to the retriever and the
# web searcher; this is the agent the HybridQA button calls.
hier_agent = CustomExecutor(
    llm=llm.sm_llm,
    retriever=kqa.kendra_retriever,
    websearcher=web_sercher,
    # Original note: the retriever will be overwritten if a chain is assigned here.
    kendra_chain=kqa,
    verbose=True,
)

# Agent given only the web searcher: no retriever and no Kendra chain.
web_agent = CustomExecutor(
    llm=llm.sm_llm,
    retriever=None,
    websearcher=web_sercher,
    # Original note: the retriever will be overwritten if a chain is assigned here.
    kendra_chain=None,
    verbose=True,
)

# Compact look: small spacing and small text sizes on top of the default theme.
theme = gr.themes.Default(
    spacing_size=gr.themes.sizes.spacing_sm,
    text_size=gr.themes.sizes.text_sm,
)

# Pass the theme by keyword so the call does not depend on parameter order.
with gr.Blocks(theme=theme) as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox(label='User Input: ')

    with gr.Row():
        clear = gr.ClearButton([msg, chatbot])
        int_btn = gr.Button("InternalQA", variant="primary")
        hyb_btn = gr.Button("HybridQA\n(Beta..)")

    def respond(message, chat_history):
        """Answer ``message`` with the hierarchical agent and extend the chat.

        Args:
            message: Text currently in the input box.
            chat_history: List of ``(user, bot)`` pairs shown in the chatbot.

        Returns:
            A tuple ``("", chat_history)``: the empty string clears the input
            box, and the history carries the new exchange when one was made.
        """
        if chat_history is None:
            chat_history = []
        # Do not spend a backend call on an empty or whitespace-only input.
        if not message or not message.strip():
            return "", chat_history

        bot_resp = hier_agent.query(message)
        chat_history.append((message, bot_resp))
        return "", chat_history

    def respond_internal(message, chat_history):
        """Answer ``message`` with the Kendra QA chain and extend the chat.

        Args:
            message: Text currently in the input box.
            chat_history: List of ``(user, bot)`` pairs shown in the chatbot.

        Returns:
            A tuple ``("", chat_history)``: the empty string clears the input
            box, and the history carries the new exchange when one was made.
        """
        if chat_history is None:
            chat_history = []
        # Do not spend a backend call on an empty or whitespace-only input.
        if not message or not message.strip():
            return "", chat_history

        bot_resp = kqa.kendra_chain_qa(message)
        # The chain response is indexed by key; 'result' holds the answer text.
        chat_history.append((message, bot_resp['result']))
        return "", chat_history

    # Pressing Enter in the textbox behaves like the InternalQA button.
    msg.submit(respond_internal, [msg, chatbot], [msg, chatbot])

    int_btn.click(respond_internal, [msg, chatbot], [msg, chatbot])
    hyb_btn.click(respond, [msg, chatbot], [msg, chatbot])

demo.launch()

## Unit tests

In [None]:
# Model test: call the LLM directly with a sample prompt
# (Chinese for "Up to what date does your latest knowledge extend?").
llm.sm_llm('你最新的知识截止到什么时间')

In [None]:
# Retriever test: fetch the documents the Kendra retriever returns for a sample
# question (Chinese for "Which data sources does AWS clean room support?").
kqa.kendra_retriever.get_relevant_documents('AWS clean room支持的数据源有哪些')

In [None]:
# Retrieval QA chain test: run the Kendra chain on the same sample question and
# display the query next to the generated result.
resp = kqa.kendra_chain_qa('AWS clean room支持的数据源有哪些')
resp['query'], resp['result']

In [None]:
# Web search API test: run one search through the web searcher
# (query is Chinese for "MOSS large model").
web_sercher.search('moss大模型')

In [None]:
# Hierarchical chain test: send a sample question through the hierarchical agent
# (Chinese for "Which data sources does AWS clean room support?").
hier_agent.query('AWS clean room支持哪些数据源')