In [None]:
import asyncio
import json
import pyaudio
import websockets
import ssl
import nest_asyncio
from openai import OpenAI
from collections import deque
import requests
import sounddevice as sd
import soundfile as sf
from io import BytesIO
import time
import wave
import sqlite3
import faiss
import numpy as np
import torch
import warnings
from sentence_transformers import SentenceTransformer, CrossEncoder
import pickle
from typing import List, Dict, Tuple
from tqdm.notebook import tqdm

# Suppress every UserWarning and FutureWarning for the rest of the session.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Patch asyncio so an already-running event loop can be re-entered, which lets
# run_until_complete be called from inside e.g. a Jupyter notebook's loop.
nest_asyncio.apply()

class MentalHealthRAG:
    """Retrieve mental-health resources relevant to a piece of conversation.

    Resources are read from the SQLite table ``resources``, embedded with a
    SentenceTransformer model and stored in a flat L2 FAISS index. FAISS
    candidates can optionally be re-scored with a CrossEncoder reranker and
    filtered by a score threshold.
    """

    # Keyword -> minimum rerank score required before recommending resources
    # when that keyword appears in the conversation.
    _TRIGGER_THRESHOLDS = {
        '抑郁': 0.6,
        '焦虑': 0.6,
        '睡眠': 0.6,
        '压力': 0.5,
        '药物': 0.7,
        '家属': 0.5,
        '量表': 0.8,
        '检查': 0.6,
        '治疗': 0.7,
        '症状': 0.6
    }

    def __init__(
        self,
        db_path: str = "mental_health_resources.db",
        embedding_model_name: str = "shibing624/text2vec-base-chinese",
        reranker_model_name: str = "BAAI/bge-reranker-base",
        index_path: str = "faiss_index.bin",
        resources_path: str = "resources.pkl",
        debug: bool = True
    ):
        """Load the embedding and reranker models.

        The FAISS index is NOT loaded here; call ``load_index`` or
        ``build_index`` before ``search``.

        Args:
            db_path: SQLite database containing the ``resources`` table.
            embedding_model_name: SentenceTransformer model used for vectors.
            reranker_model_name: CrossEncoder model used for reranking.
            index_path: File the FAISS index is written to / read from.
            resources_path: Pickle file holding the resource dicts, row-aligned
                with the index.
            debug: When True, print retrieval diagnostics.
        """
        self.db_path = db_path
        print(f"\n加载向量化模型: {embedding_model_name}")
        self.embedding_model = SentenceTransformer(embedding_model_name)
        print(f"加载重排序模型: {reranker_model_name}")
        self.reranker = CrossEncoder(reranker_model_name)
        self.index_path = index_path
        self.resources_path = resources_path
        self.index = None
        self.resources = []
        self.debug = debug

    def load_resources(self) -> List[Dict]:
        """Read every row of the ``resources`` table.

        Columns are read positionally: 0=id, 1=title, 2=summary, 3=link.

        Returns:
            One dict per row with keys ``id``, ``title``, ``summary``, ``link``
            and ``text`` (title and summary joined by a space, used for
            embedding and reranking).
        """
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute("SELECT * FROM resources").fetchall()
        finally:
            # Release the connection even if the query fails.
            conn.close()

        return [
            {
                'id': r[0],
                'title': r[1],
                'summary': r[2],
                'link': r[3],
                'text': f"{r[1]} {r[2]}"
            }
            for r in rows
        ]

    def build_index(self):
        """Embed all database resources, build a flat L2 FAISS index and save it.

        Writes the index to ``index_path`` and the resource list (row-aligned
        with the index) to ``resources_path``.

        Raises:
            ValueError: if the ``resources`` table is empty, because there is
                nothing to index and the embedding dimension is unknown.
        """
        print("加载资源...")
        self.resources = self.load_resources()
        if not self.resources:
            raise ValueError(f"resources 表为空，无法构建索引: {self.db_path}")

        print("生成文本向量...")
        texts = [r['text'] for r in self.resources]
        # encode() draws its own progress bar; it has no documented parameter
        # for supplying a custom tqdm class.
        embeddings = self.embedding_model.encode(
            texts,
            show_progress_bar=True,
            convert_to_numpy=True,
        )

        print("构建FAISS索引...")
        dimension = embeddings.shape[1]
        self.index = faiss.IndexFlatL2(dimension)
        self.index.add(embeddings.astype('float32'))

        print("保存索引和资源数据...")
        faiss.write_index(self.index, self.index_path)
        with open(self.resources_path, 'wb') as f:
            pickle.dump(self.resources, f)

        print("索引构建完成！")

    def load_index(self):
        """Load a previously saved FAISS index and its resource list.

        Raises:
            ValueError: if the index row count does not match the number of
                stored resources (the two files are out of sync).
            Whatever ``faiss.read_index``, ``open`` or ``pickle.load`` raise
            when the files are missing or unreadable.
        """
        self.index = faiss.read_index(self.index_path)
        # NOTE(review): pickle.load can execute arbitrary code from the file;
        # only load resources.pkl files this program wrote itself.
        with open(self.resources_path, 'rb') as f:
            self.resources = pickle.load(f)
        if self.index.ntotal != len(self.resources):
            raise ValueError(
                f"索引条目数({self.index.ntotal})与资源数({len(self.resources)})不一致"
            )

    def search(
        self,
        query: str,
        k: int = 3,
        rerank: bool = True,
        threshold: float = 0.5
    ) -> List[Dict]:
        """Return resources relevant to ``query``.

        FAISS returns up to ``k`` nearest neighbours by L2 distance. Each hit is
        a copy of the stored resource dict with ``initial_score`` set to
        ``1 / (1 + distance)``.

        When ``rerank`` is True, the hits are re-scored with the CrossEncoder.
        Only hits scoring ``>= threshold`` are kept; they get a
        ``rerank_score`` and are sorted by it in descending order. When
        ``rerank`` is False, the FAISS hits are returned in FAISS order
        without filtering.

        Raises:
            RuntimeError: if no index has been loaded or built yet.
        """
        if self.index is None:
            raise RuntimeError("索引未加载，请先调用 load_index() 或 build_index()")

        if self.debug:
            print("\nRAG检索过程:")
            print(f"查询文本: {query}")

        query_vector = self.embedding_model.encode([query])
        distances, indices = self.index.search(
            query_vector.astype('float32'),
            k
        )

        if self.debug:
            print("\n初始检索结果:")

        results = []
        for idx, distance in zip(indices[0], distances[0]):
            # FAISS pads the result with -1 when fewer than k vectors exist.
            if idx == -1:
                continue
            result = self.resources[idx].copy()
            similarity = 1 / (1 + distance)  # map L2 distance into (0, 1]
            result['initial_score'] = similarity
            results.append(result)
            if self.debug:
                print(f"- {result['title']}")
                print(f"  初始相似度: {similarity:.4f}")
                print(f"  L2距离: {distance:.4f}")

        if not (rerank and results):
            return results

        if self.debug:
            print("\n重排序过程:")

        pairs = [(query, r['text']) for r in results]
        scores = self.reranker.predict(pairs)

        reranked = []
        for score, result in zip(scores, results):
            passed = score >= threshold
            # Report every candidate, including rejected ones, so the
            # pass/fail line in the debug output is meaningful.
            if self.debug:
                print(f"- {result['title']}")
                print(f"  重排序分数: {score:.4f}")
                print(f"  是否通过阈值({threshold}): {'是' if passed else '否'}")
            if passed:
                result['rerank_score'] = float(score)
                reranked.append(result)

        reranked.sort(key=lambda x: x['rerank_score'], reverse=True)

        if self.debug:
            print(f"\n最终返回 {len(reranked)} 个结果")
        return reranked

    def get_suggestion(self, conversation_text: str) -> Tuple[bool, List[Dict]]:
        """Decide whether ``conversation_text`` warrants resource recommendations.

        If any keyword in ``_TRIGGER_THRESHOLDS`` occurs in the text (plain
        substring match), the text is searched with reranking. The threshold
        used is the highest one among the matched keywords.

        Returns:
            ``(found_any, results)``. This is ``(False, [])`` when no keyword
            matches or no result passes the threshold.
        """
        if self.debug:
            print("\n资源推荐流程开始:")

        matched_triggers = [
            (word, threshold)
            for word, threshold in self._TRIGGER_THRESHOLDS.items()
            if word in conversation_text
        ]

        if self.debug:
            if matched_triggers:
                print("检测到以下触发词:")
                for word, threshold in matched_triggers:
                    print(f"- {word} (阈值: {threshold})")
            else:
                print("未检测到触发词")

        if not matched_triggers:
            return False, []

        # The strictest (highest) threshold among the matched keywords wins.
        max_threshold = max(threshold for _, threshold in matched_triggers)
        if self.debug:
            print(f"使用最高阈值: {max_threshold}")

        results = self.search(
            conversation_text,
            k=3,
            rerank=True,
            threshold=max_threshold
        )

        return bool(results), results

class MicrophoneClient:
    """Voice chat client: microphone -> websocket ASR -> DeepSeek chat
    (plus RAG resource recommendations) -> HTTP TTS -> speaker playback.
    """

    # Seconds to wait for the TTS HTTP service before giving up on one reply.
    tts_timeout = 120

    def __init__(
        self,
        asr_url="ws://127.0.0.1:10095",
        tts_url="http://127.0.0.1:9880",
        deepseek_api_key=None,
        db_path="mental_health_resources.db",
        debug=True
    ):
        """Configure endpoints, the DeepSeek client and the RAG index.

        Args:
            asr_url: ``ws://`` or ``wss://`` URL of the streaming ASR server.
            tts_url: URL that receives a JSON POST and returns audio bytes.
            deepseek_api_key: API key. When falsy, AI replies are disabled.
            db_path: SQLite database passed to ``MentalHealthRAG``.
            debug: Forwarded to ``MentalHealthRAG`` for retrieval diagnostics.

        Raises:
            ValueError: if ``asr_url`` is not a valid ws:// or wss:// URL.
        """
        self.asr_url = asr_url
        self.use_ssl, self.host, self.port = self._parse_asr_url(asr_url)

        # Streaming parameters sent to the ASR server in the config message.
        self.chunk_size = [5, 10, 5]
        self.chunk_interval = 10
        self.websocket = None
        self.last_text = ""
        self.tts_url = tts_url
        # While False, the capture loop neither reads nor sends audio.
        self.recording = True
        self.stream = None
        self.audio = None
        self.debug = debug

        self.deepseek_client = OpenAI(
            api_key=deepseek_api_key,
            base_url="https://api.deepseek.com"
        ) if deepseek_api_key else None

        # Last 10 exchanges, replayed to the model as chat context.
        self.conversation_history = deque(maxlen=10)

        self.rag = MentalHealthRAG(db_path=db_path, debug=debug)
        try:
            print("加载RAG索引...")
            self.rag.load_index()
        except Exception as e:
            # Any failure to read the saved index/resources (missing, corrupt
            # or out-of-sync files) falls back to rebuilding from the database.
            print("RAG索引不存在，正在构建新索引...")
            if debug:
                print(f"  原因: {e}")
            self.rag.build_index()

    @staticmethod
    def _parse_asr_url(asr_url):
        """Split a ws:// or wss:// URL into ``(use_ssl, host, port)``.

        The port defaults to 443 for wss and 80 for ws when it is omitted.

        Raises:
            ValueError: for another scheme, a missing host, or an invalid port.
        """
        from urllib.parse import urlparse

        parsed = urlparse(asr_url)
        if parsed.scheme not in ("ws", "wss"):
            raise ValueError(f"ASR地址必须以 ws:// 或 wss:// 开头: {asr_url}")
        if not parsed.hostname:
            raise ValueError(f"ASR地址缺少主机名: {asr_url}")
        use_ssl = parsed.scheme == "wss"
        port = parsed.port  # raises ValueError for a non-numeric / out-of-range port
        if port is None:
            port = 443 if use_ssl else 80
        return use_ssl, parsed.hostname, port

    def play_audio(self, wav_data):
        """Decode an in-memory audio file and play it, blocking until done.

        Meant to run in an executor thread. Errors are printed, not raised.
        """
        try:
            data, samplerate = sf.read(BytesIO(wav_data))
            sd.play(data, samplerate, blocking=True)
            # Brief pause after playback before returning.
            time.sleep(0.3)
        except Exception as e:
            print(f"音频播放错误: {e}")

    def _resume_recording(self):
        """Restart the input stream (if open) and re-enable audio sending."""
        if self.stream:
            try:
                self.stream.start_stream()
            except OSError as e:
                print(f"恢复录音失败: {e}")
        self.recording = True
        print("录音已恢复")

    async def text_to_speech(self, text):
        """Synthesize ``text`` with the TTS service and play it.

        Capture is paused (``recording`` False, stream stopped) for the whole
        request and playback. Presumably this keeps the played-back voice out
        of the ASR input. Capture is always resumed afterwards, even when the
        request or playback fails. Errors are printed, not raised.
        """
        loop = asyncio.get_running_loop()
        self.recording = False
        try:
            if self.stream:
                self.stream.stop_stream()
                await asyncio.sleep(1.0)

            payload = {
                "text": text,
                "text_language": "zh",
            }

            print("\n正在生成语音...")
            # requests is synchronous: run it in a worker thread so the event
            # loop is not frozen, and bound it so a hung TTS server cannot
            # leave recording paused forever.
            response = await loop.run_in_executor(
                None,
                lambda: requests.post(
                    self.tts_url, json=payload, timeout=self.tts_timeout
                )
            )

            if response.status_code == 200:
                print("播放中...")
                await loop.run_in_executor(None, self.play_audio, response.content)
                await asyncio.sleep(1.5)
            else:
                print(f"TTS API调用失败: {response.status_code}")

            print("准备继续录音...")
            await asyncio.sleep(1.5)
        except Exception as e:
            print(f"TTS转换错误: {e}")
        finally:
            self._resume_recording()

    async def format_resource_recommendations(self, resources):
        """Render recommended resources as speech-friendly text.

        Returns an empty string for no resources. Otherwise it returns a
        header followed by one entry per resource: the title, the rerank
        score as a percentage when present, and the summary on the next line.
        """
        if not resources:
            return ""

        parts = ["\n\n根据您的描述，我为您推荐以下参考资源："]
        for r in resources:
            entry = f"\n\n{r['title']}"
            if 'rerank_score' in r:
                entry += f"（相关度：{r['rerank_score']:.2%}）"
            parts.append(entry + f"\n{r['summary']}")
        return "".join(parts)

    async def get_ai_response(self, text):
        """Get the doctor-persona reply for ``text``, plus any resource tips.

        The last exchanges in ``conversation_history`` are sent as context.
        The new exchange, including appended recommendations, is then stored
        in the history. Failures are returned as a message string rather than
        raised.
        """
        if not self.deepseek_client:
            return "未配置DeepSeek API密钥"

        try:
            messages = [
                {"role": "system", "content": (
                    "你是ABC医院王医生，请针对患者或患者家属的提问给出合适的回复。"
                    "回复要言简意赅，控制在100字以内。"
                )}
            ]
            for history in self.conversation_history:
                messages.append({"role": "user", "content": history["user"]})
                messages.append({"role": "assistant", "content": history["assistant"]})
            messages.append({"role": "user", "content": text})

            loop = asyncio.get_running_loop()
            # The OpenAI SDK call is blocking; keep it off the event loop.
            response = await loop.run_in_executor(
                None,
                lambda: self.deepseek_client.chat.completions.create(
                    model="deepseek-chat",
                    messages=messages,
                    max_tokens=400,
                    stream=False
                )
            )
            # message.content is Optional in the SDK; treat None as empty.
            ai_reply = response.choices[0].message.content or ""

            # Embedding + reranking is model inference; run it in a worker
            # thread so the capture loop keeps draining the microphone.
            should_recommend, resources = await loop.run_in_executor(
                None, self.rag.get_suggestion, text + " " + ai_reply
            )
            if should_recommend:
                ai_reply += await self.format_resource_recommendations(resources)

            self.conversation_history.append({
                "user": text,
                "assistant": ai_reply
            })
            return ai_reply
        except Exception as e:
            return f"获取AI回复失败: {str(e)}"

    async def start_streaming(self):
        """Connect to the ASR server, retrying, then run the session.

        Only establishing the connection is retried: at most ``MAX_RETRIES``
        attempts, ``RETRY_DELAY`` seconds apart. Once connected, microphone
        capture and message handling run concurrently until both finish, and
        the websocket is closed afterwards.

        Raises:
            The last connection error once all retries are exhausted.
        """
        MAX_RETRIES = 3
        RETRY_DELAY = 2

        uri = f"{'wss' if self.use_ssl else 'ws'}://{self.host}:{self.port}"
        ssl_context = None
        if self.use_ssl:
            # NOTE(review): certificate and hostname verification are disabled,
            # so this TLS connection is encrypted but not authenticated.
            ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
            ssl_context.check_hostname = False
            ssl_context.verify_mode = ssl.CERT_NONE

        retry_count = 0
        while True:
            try:
                print(f"正在连接服务器 {uri}...")
                websocket = await websockets.connect(
                    uri,
                    subprotocols=["binary"],
                    ping_interval=None,
                    ssl=ssl_context
                )
                break
            except (websockets.exceptions.WebSocketException, asyncio.TimeoutError,
                    TimeoutError, OSError) as e:
                # OSError also covers ConnectionRefusedError and DNS failures.
                retry_count += 1
                if retry_count < MAX_RETRIES:
                    print(f"连接失败: {str(e)}")
                    print(f"将在 {RETRY_DELAY} 秒后进行第 {retry_count + 1} 次重试...")
                    await asyncio.sleep(RETRY_DELAY)
                else:
                    print(f"连接失败，已达到最大重试次数 ({MAX_RETRIES})。")
                    print("请检查:")
                    print("1. 语音识别服务是否已启动")
                    print(f"2. 服务器地址 ({self.host}) 和端口 ({self.port}) 是否正确")
                    print("3. 网络连接是否正常")
                    print(f"详细错误: {str(e)}")
                    raise

        self.websocket = websocket
        print("连接成功！")
        try:
            await asyncio.gather(
                self._record_microphone(),
                self._receive_messages()
            )
        finally:
            await websocket.close()

    async def _record_microphone(self):
        """Capture 16 kHz mono 16-bit audio and stream it to the ASR server.

        The JSON config message is sent first. After that, one chunk is read
        and sent per iteration while ``self.recording`` is True. The loop ends
        when the websocket closes. The PyAudio stream and instance are always
        released and reset to None.
        """
        FORMAT = pyaudio.paInt16
        CHANNELS = 1
        RATE = 16000
        # Chunk duration in ms: 60 * chunk_size[1] / chunk_interval
        # (60 ms with the defaults) -> CHUNK frames per read.
        chunk_ms = 60 * self.chunk_size[1] / self.chunk_interval
        CHUNK = int(RATE / 1000 * chunk_ms)

        self.audio = pyaudio.PyAudio()
        try:
            self.stream = self.audio.open(
                format=FORMAT,
                channels=CHANNELS,
                rate=RATE,
                input=True,
                frames_per_buffer=CHUNK
            )

            config_message = json.dumps({
                "mode": "2pass",
                "chunk_size": self.chunk_size,
                "chunk_interval": self.chunk_interval,
                "wav_name": "microphone",
                "is_speaking": True,
                "hotwords": "",
                "itn": True
            })
            await self.websocket.send(config_message)

            try:
                while True:
                    if self.recording:
                        try:
                            # Don't raise on input overflow (frames that
                            # piled up while the loop was busy); return the
                            # chunk instead of dropping it with an error.
                            data = self.stream.read(CHUNK, exception_on_overflow=False)
                            await self.websocket.send(data)
                        except websockets.exceptions.ConnectionClosed:
                            # Server is gone: leave the loop instead of
                            # printing an error every 5 ms forever.
                            raise
                        except Exception as e:
                            if self.recording:
                                print(f"录音错误: {e}")
                    await asyncio.sleep(0.005)
            except Exception as e:
                print(f"录音循环错误: {e}")
        finally:
            # Reset the attributes first so other code (TTS, final cleanup)
            # never touches a closed stream.
            stream, self.stream = self.stream, None
            if stream:
                stream.stop_stream()
                stream.close()
            audio, self.audio = self.audio, None
            if audio:
                audio.terminate()

    async def _receive_messages(self):
        """Handle ASR results until the websocket closes or errors.

        ``2pass-online`` messages are shown as live partial text. A
        ``2pass-offline`` message is a final utterance: it is answered via
        ``get_ai_response`` and spoken via ``text_to_speech``. Non-JSON
        messages and empty texts are ignored.
        """
        try:
            while True:
                message = await self.websocket.recv()
                try:
                    msg_data = json.loads(message)
                    if "text" in msg_data:
                        new_text = msg_data["text"]
                        mode = msg_data.get("mode", "")

                        if len(new_text.strip()) == 0:
                            continue

                        if mode == "2pass-online":
                            if new_text != self.last_text:
                                print(f"\r实时识别: {new_text}", end="")
                                self.last_text = new_text
                        elif mode == "2pass-offline":
                            print(f"\nVAD结果: {new_text}")
                            ai_response = await self.get_ai_response(new_text)
                            print(f"AI回复: {ai_response}\n")
                            await self.text_to_speech(ai_response)
                            print(f"[对话历史: {len(self.conversation_history)}轮]\n")
                            self.last_text = ""

                except json.JSONDecodeError:
                    continue
                except Exception as e:
                    print(f"处理消息错误: {e}")
        except Exception as e:
            print(f"接收消息错误: {e}")

def start_mic_client(
    deepseek_api_key=None, 
    tts_url="http://127.0.0.1:9880",
    asr_url="wss://127.0.0.1:10095",  # New parameter
    db_path="mental_health_resources.db",
    debug=True
):
    """启动麦克风客户端的便捷函数"""
    client = MicrophoneClient(
        deepseek_api_key=deepseek_api_key, 
        tts_url=tts_url,
        asr_url=asr_url,  # Pass ASR URL to client
        db_path=db_path,
        debug=debug
    )
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(client.start_streaming())
    except KeyboardInterrupt:
        print("\n录音已停止")
    except Exception as e:
        print(f"发生错误: {e}")
        print("\n如需重新启动，请再次运行程序。")
    finally:
        if hasattr(client, 'stream') and client.stream:
            client.stream.stop_stream()
            client.stream.close()
        if hasattr(client, 'audio') and client.audio:
            client.audio.terminate()

# Usage example
if __name__ == "__main__":
    import os

    # Configuration comes from environment variables so secrets are never
    # committed with the notebook. The URL defaults keep the previous endpoints.
    # NOTE(review): a DeepSeek API key used to be hard-coded here. It has been
    # exposed in source/history and should be revoked and rotated.
    DEEPSEEK_API_KEY = os.environ.get("DEEPSEEK_API_KEY")
    TTS_URL = os.environ.get("TTS_URL", "http://copilot.mynatapp.cc")
    ASR_URL = os.environ.get("ASR_URL", "wss://693e14e788532c17.natapp.cc:54321")
    DB_PATH = os.environ.get("DB_PATH", "mental_health_resources.db")

    if not DEEPSEEK_API_KEY:
        print("警告: 未设置环境变量 DEEPSEEK_API_KEY，AI回复功能不可用")

    print("开始连接语音服务...")
    print(f"语音识别服务器: {ASR_URL}")
    print(f"TTS服务器: {TTS_URL}")

    start_mic_client(
        deepseek_api_key=DEEPSEEK_API_KEY,
        tts_url=TTS_URL,
        asr_url=ASR_URL,
        db_path=DB_PATH,
        debug=True
    )

开始连接语音服务...
语音识别服务器: wss://693e14e788532c17.natapp.cc:54321
TTS服务器: http://copilot.mynatapp.cc

加载向量化模型: shibing624/text2vec-base-chinese
加载重排序模型: BAAI/bge-reranker-base
加载RAG索引...
正在连接服务器 wss://693e14e788532c17.natapp.cc:54321...
连接成功！
实时识别: 法吗疗方较好
VAD结果: 呃呃，我问一下，我压力感觉比较大怎么办？有什么比较好的治疗方法吗

资源推荐流程开始:
检测到以下触发词:
- 睡眠 (阈值: 0.6)
- 压力 (阈值: 0.5)
- 治疗 (阈值: 0.7)
使用最高阈值: 0.7

RAG检索过程:
查询文本: 呃呃，我问一下，我压力感觉比较大怎么办？有什么比较好的治疗方法吗 建议您先尝试一些自我调节方法，比如每天坚持30分钟有氧运动，保证7-8小时睡眠，学习正念冥想。如果持续2周以上没有改善，建议来院做专业评估。

初始检索结果:
- 正念减压练习指南
  初始相似度: 0.0057
  L2距离: 175.6180
- 压力管理技巧
  初始相似度: 0.0052
  L2距离: 192.5255
- 抑郁自评量表(SDS)
  初始相似度: 0.0049
  L2距离: 203.0525

重排序过程:
- 正念减压练习指南
  重排序分数: 0.7882
  是否通过阈值(0.7): 是
- 压力管理技巧
  重排序分数: 0.9503
  是否通过阈值(0.7): 是

最终返回 2 个结果
AI回复: 建议您先尝试一些自我调节方法，比如每天坚持30分钟有氧运动，保证7-8小时睡眠，学习正念冥想。如果持续2周以上没有改善，建议来院做专业评估。

根据您的描述，我为您推荐以下参考资源：

压力管理技巧（相关度：95.03%）
介绍实用的压力管理方法，包括时间管理、放松技巧、问题解决策略等。

正念减压练习指南（相关度：78.82%）
正念减压是一种有效的压力管理方法，本指南包含基础的正念练习方法和日常应用技巧。


正在生成语音...
播放中...
