Skip to content

Commit

Permalink
Merge pull request #679 from lss233/sourcery/pull-677
Browse files Browse the repository at this point in the history
Feat optimize tts voice manage (Sourcery refactored)
  • Loading branch information
lss233 authored Apr 20, 2023
2 parents e8c3fc9 + cdb6ee3 commit 2e61135
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 39 deletions.
35 changes: 16 additions & 19 deletions adapter/quora/poe.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,13 @@ async def ask(self, msg: str) -> Generator[str, None, None]:
self.poe_client.last_ask_time = time.time()
except Exception as e:
logger.warning(f"Poe connection error {str(e)}")
if self.process_retry <= 3:
new_poe_client = botManager.reset_bot(self.poe_client)
self.poe_client = new_poe_client
self.process_retry += 1
async for resp in self.ask(msg):
yield resp
else:
if self.process_retry > 3:
raise e
new_poe_client = botManager.reset_bot(self.poe_client)
self.poe_client = new_poe_client
self.process_retry += 1
async for resp in self.ask(msg):
yield resp

def check_and_reset_client(self):
current_time = time.time()
Expand All @@ -92,13 +91,12 @@ async def rollback(self):
self.process_retry = 0
except Exception as e:
logger.warning(f"Poe connection error {str(e)}")
if self.process_retry <= 3:
new_poe_client = botManager.reset_bot(self.poe_client)
self.poe_client = new_poe_client
self.process_retry += 1
await self.rollback()
else:
if self.process_retry > 3:
raise e
new_poe_client = botManager.reset_bot(self.poe_client)
self.poe_client = new_poe_client
self.process_retry += 1
await self.rollback()

async def on_reset(self):
"""当会话被重置时,此函数被调用"""
Expand All @@ -107,10 +105,9 @@ async def on_reset(self):
self.process_retry = 0
except Exception as e:
logger.warning(f"Poe connection error {str(e)}")
if self.process_retry <= 3:
new_poe_client = botManager.reset_bot(self.poe_client)
self.poe_client = new_poe_client
self.process_retry += 1
await self.on_reset()
else:
if self.process_retry > 3:
raise e
new_poe_client = botManager.reset_bot(self.poe_client)
self.poe_client = new_poe_client
self.process_retry += 1
await self.on_reset()
4 changes: 1 addition & 3 deletions middlewares/draw_ratelimit.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,7 @@ def __init__(self):
def handle_draw_request(self, session_id: str, prompt: str):
_id = session_id.split('-', 1)[1] if '-' in session_id else session_id
rate_usage = manager.check_draw_exceed('好友' if session_id.startswith("friend-") else '群组', _id)
if rate_usage >= 1:
return config.ratelimit.draw_exceed
return "1"
return config.ratelimit.draw_exceed if rate_usage >= 1 else "1"



Expand Down
2 changes: 1 addition & 1 deletion universal.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ async def request(_session_id, prompt: str, conversation_context, _respond):
if not config.azure.tts_speech_key and config.text_to_speech.engine == "azure":
await respond("未配置 Azure TTS 账户,无法切换语音!")
new_voice = voice_type_search[1].strip()
if new_voice == '关闭' or new_voice == "None":
if new_voice in ['关闭', "None"]:
conversation_context.conversation_voice = None
await respond("已关闭语音,让我们继续聊天吧!")
elif config.text_to_speech.engine == "vits":
Expand Down
5 changes: 3 additions & 2 deletions utils/edge_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ async def load_edge_tts_voices():
if edge_tts_voices:
return edge_tts_voices
for el in await edge_tts.list_voices():
tts_voice = TtsVoice.parse("edge", el.get('ShortName', ''), el.get('Gender', None))
if tts_voice:
if tts_voice := TtsVoice.parse(
"edge", el.get('ShortName', ''), el.get('Gender', None)
):
edge_tts_voices[tts_voice.alias] = tts_voice
logger.info(f"{len(edge_tts_voices)} edge tts voices loaded.")
return edge_tts_voices
Expand Down
26 changes: 12 additions & 14 deletions utils/text_to_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def parse(engine, voice: str, gender=None):
tts_voice.engine = engine
tts_voice.full_name = voice
tts_voice.gender = gender
if engine == "edge" or engine == "azure":
if engine in ["edge", "azure"]:
"""如:zh-CN-liaoning-XiaobeiNeural、uz-UZ-SardorNeural"""
voice_info = voice.split("-")
if len(voice_info) < 3:
Expand All @@ -62,31 +62,29 @@ def parse(engine, voice: str, gender=None):
tts_voice.name = name
tts_voice.alias = alias
tts_voice.sub_region = sub_region
return tts_voice
else:
tts_voice.lang = voice
tts_voice.alias = voice
return tts_voice

return tts_voice


class TtsVoiceManager:
"""tts音色管理"""

@staticmethod
async def parse_tts_voice(tts_engine, voice_name) -> TtsVoice:
if tts_engine == "edge":
from utils.edge_tts import load_edge_tts_voices
if "edge" not in tts_voice_dic:
tts_voice_dic["edge"] = await load_edge_tts_voices()
_voice_dic = tts_voice_dic["edge"]
_voice = TtsVoice.parse(tts_engine, voice_name)
if _voice:
return _voice_dic.get(_voice.alias, None)
if voice_name in _voice_dic:
return _voice_dic[voice_name]
else:
if tts_engine != "edge":
# todo support other engines
return TtsVoice.parse(tts_engine, voice_name)
from utils.edge_tts import load_edge_tts_voices
if "edge" not in tts_voice_dic:
tts_voice_dic["edge"] = await load_edge_tts_voices()
_voice_dic = tts_voice_dic["edge"]
if _voice := TtsVoice.parse(tts_engine, voice_name):
return _voice_dic.get(_voice.alias, None)
if voice_name in _voice_dic:
return _voice_dic[voice_name]

@staticmethod
async def list_tts_voices(tts_engine, voice_prefix):
Expand Down

0 comments on commit 2e61135

Please sign in to comment.