参考https://github.com/THUDM/ChatGLM-6B/blob/main/web_demo.py

In [None]:
from transformers import AutoModel, AutoTokenizer
import gradio as gr
import mdtex2html

tokenizer = AutoTokenizer.from_pretrained("../../model/chatglm2-6b", trust_remote_code=True, local_files_only=True)
model = AutoModel.from_pretrained("../../model/chatglm2-6b", trust_remote_code=True, local_files_only=True).half().quantize(4).cuda()
model = model.eval()

from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import Milvus

vector_db = Milvus(
    embedding_function = HuggingFaceEmbeddings(model_name="../../model/m3e-base"),
    connection_args={"host": "127.0.0.1", "port": "19530"},
    collection_name="qa1",
)

In [7]:


"""Override Chatbot.postprocess"""

# 重写Chatbot.postprocess方法，用于处理输出结果
def postprocess(self, y):
    if y is None:
        return []
    for i, (message, response) in enumerate(y):
        y[i] = (
            None if message is None else mdtex2html.convert((message)),
            None if response is None else mdtex2html.convert(response),
        )
    return y


gr.Chatbot.postprocess = postprocess



# 定义一个函数，用于解析文本中的代码块
def parse_text(text):
    docs = vector_db.similarity_search(text, k=3)
    q = '''使用以下内容回答最后的问题。如果你不知道答案，就说你不知道，不要试图编造答案。在回答的最后一定要说"感谢您的提问！"
    {context}
    问题：{query}
    有用的回答： 
    '''.format(context=docs, query=text)
    return q


# def parse_text(text):
#     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
#     lines = text.split("\n")
#     lines = [line for line in lines if line != ""]
#     count = 0
#     for i, line in enumerate(lines):
#         if "```" in line:
#             count += 1
#             items = line.split('`')
#             if count % 2 == 1:
#                 lines[i] = f'<pre><code class="language-{items[-1]}">'
#             else:
#                 lines[i] = f'<br></code></pre>'
#         else:
#             if i > 0:
#                 if count % 2 == 1:
#                     line = line.replace("`", "\`")
#                     line = line.replace("<", "&lt;")
#                     line = line.replace(">", "&gt;")
#                     line = line.replace(" ", "&nbsp;")
#                     line = line.replace("*", "&ast;")
#                     line = line.replace("_", "&lowbar;")
#                     line = line.replace("-", "&#45;")
#                     line = line.replace(".", "&#46;")
#                     line = line.replace("!", "&#33;")
#                     line = line.replace("(", "&#40;")
#                     line = line.replace(")", "&#41;")
#                     line = line.replace("$", "&#36;")
#                 lines[i] = "<br>"+line
#     text = "".join(lines)
#     return text

# 定义一个函数，用于预测输入并返回聊天记录和历史对话
def predict(input, chatbot, max_length, top_p, temperature, history):
    chatbot.append((parse_text(input), ""))
    for response, history in model.stream_chat(tokenizer, parse_text(input), history, max_length=max_length, top_p=top_p,
                                               temperature=temperature):
        chatbot[-1] = (input, response)
        yield chatbot, history

# 定义一个函数，用于清除用户输入
def reset_user_input():
    return gr.update(value='')

# 定义一个函数，用于清空历史对话
def reset_state():
    return [], []

# 使用gradio创建一个界面
with gr.Blocks() as demo:
    gr.HTML("""<h1 align="center">ChatGLM</h1>""")

    # 创建一个聊天框
    chatbot = gr.Chatbot()
    with gr.Row():
        with gr.Column(scale=4):
            with gr.Column(scale=12):
                user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                    container=False)
            with gr.Column(min_width=32, scale=1):
                submitBtn = gr.Button("Submit", variant="primary")
        with gr.Column(scale=1):
            emptyBtn = gr.Button("Clear History")
            max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
            top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
            temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    # 创建一个保存历史对话的状态对象
    history = gr.State([])

    # 给提交按钮添加点击事件
    submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    # 给清空按钮添加点击事件
    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

# 启动应用
demo.queue().launch(share=False, inbrowser=True)


  user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(


Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [6]:
demo.close()

Closing server running on port: 7860
