# In [None]:  (notebook cell marker left over from export; kept as a comment so the file parses)
# -*- coding: utf-8 -*-

import os
import shutil
import configparser
import gradio as gr
from loguru import logger
from claude_api import Client
import threading
import time
from utils import xlsx_to_csv,markdown_to_csv

# Proxy mapping handed to the Claude client; both schemes go through the same
# local address.
proxies = {scheme: "127.0.0.1:10908" for scheme in ("http", "https")}

class ChatPDFWebUI:
    """Gradio web UI for chatting with Claude about uploaded documents.

    Every instance works inside its own content directory,
    ``<pwd_path>/<server_port>content``. Uploaded files are copied there and
    the session key, conversation id and last query are persisted in
    ``usersettings.ini`` in the same directory.
    """

    def __init__(self, pwd_path,server_name,server_port):
        """Remember the server address, derive the content paths and load settings.

        Args:
            pwd_path: Base directory under which the content directory lives.
            server_name: Host name/address passed to ``demo.launch``.
            server_port: Port passed to ``demo.launch``; also used as the
                prefix of the content directory name.
        """
        self.pwd_path = pwd_path
        self.server_name = server_name
        self.server_port = server_port
        self.CONTENT_DIR = os.path.join(pwd_path, str(server_port)+"content")
        self.USER_SETTINGS_FILE = os.path.join(self.CONTENT_DIR, "usersettings.ini")
        logger.info(f"CONTENT_DIR: {self.CONTENT_DIR}")
        self.load_user_settings()
        # Attachment records returned by the Claude client for uploaded files.
        self.attachments = []
        # Local paths of files uploaded through the upload button.
        self.uploaded_list = []

    def get_local_files(self):
        """Return the names of the document files stored in the content directory.

        Only ``.csv``, ``.txt``, ``.pdf``, ``.docx`` and ``.md`` files are
        listed. An empty list is returned when the directory does not exist.
        """
        if not os.path.isdir(self.CONTENT_DIR):
            return []
        suffixes = (".csv", ".txt", ".pdf", ".docx", ".md")
        return [f for f in os.listdir(self.CONTENT_DIR) if f.endswith(suffixes)]

    def save_user_settings(self):
        """Write session key, conversation id and last query to the settings file."""
        # Interpolation is disabled so values containing "%" (for example
        # URL-encoded cookies) are stored verbatim instead of being rejected.
        config = configparser.ConfigParser(interpolation=None)
        config["USER_SETTINGS"] = {
            "session_key": self.session_key,
            "conversation_id": self.conversation_id,
            "last_query": self.last_query
        }
        # The content directory may not exist yet (nothing uploaded so far).
        os.makedirs(self.CONTENT_DIR, exist_ok=True)
        with open(self.USER_SETTINGS_FILE, "w") as configfile:
            config.write(configfile)

    def load_user_settings(self):
        """Load persisted settings, defaulting every value to an empty string.

        ``ConfigParser.read`` ignores a missing file and ``get`` falls back to
        ``""``, so the three attributes are always defined after this call.
        """
        config = configparser.ConfigParser(interpolation=None)
        config.read(self.USER_SETTINGS_FILE)
        self.session_key = config.get("USER_SETTINGS", "session_key", fallback="")
        self.conversation_id = config.get("USER_SETTINGS", "conversation_id", fallback="")
        self.last_query = config.get("USER_SETTINGS", "last_query", fallback="")

    def launch_web_ui(self):
        """Build the Gradio interface, wire its event handlers and launch it.

        The server is started on ``self.server_name`` / ``self.server_port``.
        """
        block_css = """.importantButton {
            background: linear-gradient(45deg, #7e0570,#5d1c99, #6e00ff) !important;
            border: none !important;
        }
        .importantButton:hover {
            background: linear-gradient(45deg, #ff00e0,#8500ff, #6e00ff) !important;
            border: none !important;
        }"""
        webui_title = """
        # 🎉智能文档🎉
        """

        def chatbot_print(message,chatbot):
            """Append a bot-side status row listing every current attachment."""
            ret = "已成功"+message+"文件"
            for attachment in self.attachments:
                file_size_kb = attachment['file_size'] / 1024
                ret += "<br>文件名称：" + attachment['file_name'] + '<br>' + f'文件大小 {file_size_kb:.2f} KB'
            chatbot.append([None, ret])
            return chatbot

        def init_claude():
            """Create the Claude client and return the initial chat rows.

            Without a stored conversation id a new chat is created and saved;
            otherwise the existing conversation history is loaded.
            """
            chatbot = []
            self.claude = Client(self.session_key,proxies)
            if '' == self.conversation_id:
                self.conversation_id = self.claude.create_new_chat()['uuid']
                self.save_user_settings()
            else:
                # Export the existing chat history.
                history = self.claude.chat_conversation_history(self.conversation_id)
                text_values = [message['text'] for message in history['chat_messages']]
                chatbot = [text_values[i:i + 2] for i in range(0, len(text_values), 2)]
                # Rows are two-item [user, bot] lists everywhere else in this
                # class; pad an unpaired trailing message so the last row is too.
                if chatbot and len(chatbot[-1]) == 1:
                    chatbot[-1].append(None)
            return chatbot

        def on_session_key_change(session_key_input):
            """Switch to a new session key and start over with a fresh conversation."""
            self.session_key = session_key_input
            self.conversation_id = ''
            chatbot = init_claude()
            return chatbot

        def process_query(history,message):
            """Send a non-blank message to Claude and append the exchange to history."""
            if message is not None and message.strip() != "":
                response = self.claude.send_message(message, self.conversation_id,attachments=self.attachments,timeout=3000)
                history.append([message, response])

        def on_query_submit(history,message):
            """Handle a submitted prompt and remember it as the last query."""
            process_query(history,message)
            self.last_query = message
            self.save_user_settings()
            return history

        def on_example_query(query):
            """Pass the selected example text through to the prompt box."""
            return query

        def clear_chatbot(chatbot):
            """Start a new conversation (when a client exists) and empty the chat."""
            if hasattr(self, 'claude'):
                self.conversation_id = self.claude.create_new_chat()['uuid']
                self.save_user_settings()
            return gr.Chatbot.update(value = [])

        def upload_file(files,chatbot):
            """Copy uploaded files into the content directory and attach them."""
            os.makedirs(self.CONTENT_DIR, exist_ok=True)
            for file in files:
                filename = os.path.basename(file.name)
                file_path = os.path.join(self.CONTENT_DIR, filename)
                self.uploaded_list.append(file_path)
                # Keep an already stored copy with the same name.
                if not os.path.exists(file_path):
                    shutil.copy(file.name, file_path)
                # xlsx files are converted to csv first and the csv files are uploaded.
                if file_path.endswith('.xlsx'):
                    csv_list = xlsx_to_csv(file_path)
                    for path in csv_list:
                        self.attachments.append(self.claude.upload_attachment(path))
                elif os.path.exists(file_path):
                    self.attachments.append(self.claude.upload_attachment(file_path))
            chatbot = chatbot_print("上传",chatbot)
            return self.uploaded_list,gr.CheckboxGroup.update(choices=self.get_local_files()),chatbot

        def delete_files(local_files):
            """Delete the selected files from the content directory and refresh the list."""
            for file in local_files:
                file_path = os.path.join(self.CONTENT_DIR, file)
                if os.path.exists(file_path):
                    os.remove(file_path)
            return gr.CheckboxGroup.update(choices=self.get_local_files())

        def clear_attachments(chatbot):
            """Forget all pending attachments and uploaded paths."""
            self.attachments = []
            self.uploaded_list = []
            chatbot = chatbot_print("清空",chatbot)
            return chatbot

        def upload_history_file(local_files,chatbot):
            """Attach files that are already stored in the content directory."""
            for file in local_files:
                file_path = os.path.join(self.CONTENT_DIR, file)
                self.attachments.append(self.claude.upload_attachment(file_path))
            chatbot = chatbot_print("上传",chatbot)
            return chatbot

        def chatbot_select(evt: gr.SelectData):
            """Save the selected chat message as ``markdown_table.csv``."""
            if evt.value:
                markdown_to_csv(evt.value,os.path.join(self.CONTENT_DIR, "markdown_table.csv"))

        with gr.Blocks(css=block_css) as demo:
            gr.Markdown(webui_title)
            with gr.Row():
                with gr.Column(scale=1):
                    session_key_input = gr.Textbox(label="cookie", value=self.session_key)
                    with gr.Tab("上传文件"):
                        update_files = gr.File(label="已上传文件列表：")
                        upload_btn = gr.UploadButton("点击上传",
                        file_types=['.xlsx', '.csv', '.txt', '.md', '.docx', '.pdf'],
                        file_count="multiple")
                        # Clicking clears the attachment list.
                        clear_btn = gr.ClearButton(value="清空上传文件", components=update_files)
                    with gr.Tab("本地文件"):
                        local_files = gr.CheckboxGroup(
                        label="请选择你将要会话的文件：",
                        choices=self.get_local_files(),
                        type="value",
                        default=[],
                        css_class="select-files-checkbox-group")
                        upload_history_btn = gr.Button("上传选中历史文件")
                        delete_history_btn = gr.Button("删除选中历史文件")
                with gr.Column(scale=2):
                    history = init_claude()
                    chatbot = gr.Chatbot(history,
                                         elem_id="会话框",
                                         label="conversation ID:"+self.conversation_id,
                                         max_messages=None).style(height=600)
                    query = gr.Textbox(label="prompt",
                                       placeholder="请输入提问内容，按回车进行提交",
                                       value=self.last_query,
                                       lines=1)
                    clear_chatbot_btn = gr.Button('🔄清空并创建新对话', elem_id='clear')
                    # Define example queries
                    gr.Examples(examples=[["Help me summarize a table information of 6 columns based on the attachments' content, columns including: manufacturer, bacterial filter product, BFE(Bacterial Filter efficiency),VFE(Viral Filter efficiency),resistance to flow and dead space, rows including: subject device and predicate device in each attachment, when filling out information, it is necessary to cover the answers in different situations, If there is no relevant information, please fill in \"NA\" in the relevant column of the table."],
                                          ["good"]],
                                fn=on_example_query,
                                inputs=[query],
                                outputs=[query],label="示例prompt",
                                examples_per_page = 2)
            # Selecting a chat message converts its table to a csv file.
            chatbot.select(fn=chatbot_select)
            upload_history_btn.click(fn=upload_history_file, inputs=[local_files,chatbot], outputs=[chatbot])
            delete_history_btn.click(fn=delete_files, inputs=[local_files], outputs=[local_files])
            upload_btn.upload(fn=upload_file, inputs=[upload_btn,chatbot], outputs=[update_files,local_files,chatbot])
            clear_btn.click(fn = clear_attachments,inputs=[chatbot], outputs=[chatbot])
            query.submit(fn=on_query_submit, inputs=[chatbot, query], outputs=[chatbot])
            clear_chatbot_btn.click(fn=clear_chatbot, inputs=[chatbot], outputs=[chatbot])
            session_key_input.change(fn=on_session_key_change,inputs=[session_key_input],outputs=[chatbot])
        demo.launch(server_name=self.server_name, server_port=self.server_port, share=False, inbrowser=False)

def thread_run(pwd_path,server_name,port):
    """Thread target: create a ChatPDFWebUI for the given port and launch it."""
    ChatPDFWebUI(pwd_path, server_name, port).launch_web_ui()
if __name__ == "__main__":
    # Raw string: "\8" and "\l" are not valid escape sequences, so the plain
    # literal only worked by accident and triggers a warning on newer Pythons.
    pwd_path = r"E:\89_DataMining\lancet"
    server_name = '10.10.70.78'
    port_list = [36552,36546,38931,38539,38410,36670,38484,38996,38930,38322,37936,38656]
    # One UI instance per port, each in its own thread, started 5 seconds apart.
    for port in port_list:
        th = threading.Thread(target=thread_run,args=(pwd_path,server_name,port))
        th.start()
        time.sleep(5)
        