diff --git a/documentation/build/doctrees/environment.pickle b/documentation/build/doctrees/environment.pickle
index b3797a50..7b5e4feb 100644
Binary files a/documentation/build/doctrees/environment.pickle and b/documentation/build/doctrees/environment.pickle differ
diff --git a/documentation/source/requirements.txt b/documentation/source/requirements.txt
index e69de29b..1fe98152 100644
--- a/documentation/source/requirements.txt
+++ b/documentation/source/requirements.txt
@@ -0,0 +1,2 @@
+sphinx==6.1.3
+sphinx-rtd-theme==1.2.0
\ No newline at end of file
diff --git a/tests/config.json b/tests/config.json
index 2dfa3b8c..4bb4fa90 100644
--- a/tests/config.json
+++ b/tests/config.json
@@ -1,48 +1,26 @@
 {
-  "allow_interruptions": true,
   "waking_up_word": "computer",
   "waking_up_sound": true,
   "deactivate_sound": true,
-  "improvise_tasks": false,
   "rules": "rules.yaml",
   "functions": "functions.py",
   "llm_model": {
-    "model_is_local": false,
-    "local_model": "mistralai/Mistral-7B-Instruct-v0.1",
-    "remote_model": {
-      "model_host": "localhost",
-      "model_port": 8080
-    }
+    "model_host": "localhost",
+    "model_port": 8080
   },
   "listener_model": {
-    "model_is_local": false,
-    "local_model": "fractalego/personal-whisper-distilled-model",
-    "remote_model": {
-      "model_host": "localhost",
-      "model_port": 8080
-    },
+    "model_host": "localhost",
+    "model_port": 8080,
     "listener_hotword_logp": -8,
     "listener_volume_threshold": 0.6,
     "listener_silence_timeout": 0.7
   },
   "speaker_model": {
-    "model_is_local": false,
-    "local_model": "facebook/fastspeech2-en-ljspeech",
-    "remote_model": {
-      "model_host": "localhost",
-      "model_port": 8080
-    }
-  },
-  "entailment_model": {
-    "model_is_local": false,
-    "local_model": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
+    "model_host": "localhost",
+    "model_port": 8080
   },
   "text_embedding_model": {
-    "model_is_local": false,
-    "local_model": "TaylorAI/gte-tiny",
-    "remote_model": {
-      "model_host": "localhost",
-      "model_port": 8080
-    }
+    "model_host": "localhost",
+    "model_port": 8080
   }
 }
diff --git a/tests/test_connection.py b/tests/test_connection.py
index dd2df36e..d81daf20 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -1,16 +1,10 @@
 import asyncio
 import os
-import wave
 from unittest import TestCase
-
-import numpy as np
-
 from wafl.config import Configuration
 from wafl.connectors.bridges.llm_chitchat_answer_bridge import LLMChitChatAnswerBridge
-from wafl.connectors.local.local_llm_connector import LocalLLMConnector
 from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector
-from wafl.listener.whisper_listener import WhisperListener
 from wafl.speaker.fairseq_speaker import FairSeqSpeaker

 _path = os.path.dirname(__file__)
@@ -52,33 +46,3 @@ def test__connection_to_generative_model_can_generate_a_python_list(self):
         prediction = asyncio.run(connector.predict(prompt))
         print(prediction)
         assert len(prediction) > 0
-
-    def test__local_llm_connector_can_generate_a_python_list(self):
-        config = Configuration.load_from_filename("local_config.json")
-        connector = LocalLLMConnector(config.get_value("llm_model"))
-        connector._num_prediction_tokens = 200
-        prompt = "Generate a list of 4 chapters names for a space opera book. The output needs to be a python list of strings: "
-        prediction = asyncio.run(connector.predict(prompt))
-        print(prediction)
-        assert len(prediction) > 0
-
-    def test__chit_chat_bridge_can_run_locally(self):
-        config = Configuration.load_from_filename("local_config.json")
-        dialogue_bridge = LLMChitChatAnswerBridge(config)
-        answer = asyncio.run(dialogue_bridge.get_answer("", "", "bot: hello"))
-        assert len(answer) > 0
-
-    def test__listener_local_connector(self):
-        config = Configuration.load_from_filename("local_config.json")
-        listener = WhisperListener(config)
-        f = wave.open(os.path.join(_path, "data/1002.wav"), "rb")
-        waveform = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16) / 32768
-        result = asyncio.run(listener.input_waveform(waveform))
-        expected = "DELETE BATTERIES FROM THE GROCERY LIST"
-        assert expected.lower() in result
-
-    def test__speaker_local_connector(self):
-        config = Configuration.load_from_filename("local_config.json")
-        speaker = FairSeqSpeaker(config)
-        text = "Hello world"
-        asyncio.run(speaker.speak(text))
diff --git a/tests/test_voice.py b/tests/test_voice.py
index a4abbd12..6a95cba8 100644
--- a/tests/test_voice.py
+++ b/tests/test_voice.py
@@ -68,14 +68,6 @@ def test__random_sounds_are_excluded(self):
         expected = "[unclear]"
         assert result == expected

-    def test__voice_interface_receives_config(self):
-        config = Configuration.load_local_config()
-        interface = VoiceInterface(config)
-        assert (
-            interface.listener_model_name
-            == config.get_value("listener_model")["local_model"]
-        )
-
     def test__hotword_listener_activated_using_recording_of_hotword(self):
         f = wave.open(os.path.join(_path, "data/computer.wav"), "rb")
         waveform = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16) / 32768
diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py
index 53dce11b..998ab0ab 100644
--- a/wafl/answerer/dialogue_answerer.py
+++ b/wafl/answerer/dialogue_answerer.py
@@ -18,6 +18,7 @@ class DialogueAnswerer(BaseAnswerer):
     def __init__(self, config, knowledge, interface, code_path, logger):
+        self._delete_current_rule = "<delete_rule>"  # any unique non-empty marker works; "" would match every answer
         self._bridge = LLMChitChatAnswerBridge(config)
         self._knowledge = knowledge
         self._logger = logger
@@ -27,6 +28,7 @@ def __init__(self, config, knowledge, interface, code_path, logger):
         self._max_num_past_utterances_for_rules = 0
         self._prior_facts_with_timestamp = []
         self._init_python_module(code_path.replace(".py", ""))
+        self._prior_rule_with_timestamp = None
         self._max_predictions = 3

     async def answer(self, query_text):
@@ -48,6 +50,16 @@
         dialogue_items = dialogue
         dialogue_items = sorted(dialogue_items, key=lambda x: x[0])

+        if rules_texts:
+            last_timestamp = dialogue_items[-1][0]
+            self._prior_rule_with_timestamp = (last_timestamp, rules_texts)
+            dialogue_items = self._insert_rule_into_dialogue_items(rules_texts, last_timestamp, dialogue_items)
+
+        elif self._prior_rule_with_timestamp:
+            last_timestamp = self._prior_rule_with_timestamp[0]
+            rules_texts = self._prior_rule_with_timestamp[1]
+            dialogue_items = self._insert_rule_into_dialogue_items(rules_texts, last_timestamp, dialogue_items)
+
         last_bot_utterances = get_last_bot_utterances(dialogue_items, num_utterances=3)
         last_user_utterance = get_last_user_utterance(dialogue_items)
         dialogue_items = [item[1] for item in dialogue_items if item[0] >= start_time]
@@ -76,6 +88,11 @@
                 dialogue_items = last_user_utterance
                 continue

+            if self._delete_current_rule in answer_text:
+                self._prior_rule_with_timestamp = None
+                dialogue_items += f"\n{original_answer_text}"
+                continue
+
             if not memories:
                 break
@@ -131,6 +148,8 @@
             for cause_index, causes in enumerate(rule.causes):
                 rules_text += f" {cause_index + 1}) {causes.text}\n"
+            rules_text += f' {len(rule.causes) + 1}) After you have completed all the steps, output "{self._delete_current_rule}"\n'
+
             rules_texts.append(rules_text)
             await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")
@@ -196,3 +215,15 @@
             result = f'\n"""python\n{to_execute}\n"""'

         return result
+
+    def _insert_rule_into_dialogue_items(self, rules_texts, rule_timestamp, dialogue_items):
+        new_dialogue_items = []
+        already_inserted = False
+        for timestamp, utterance in dialogue_items:
+            if not already_inserted and utterance.startswith("user:") and rule_timestamp == timestamp:
+                new_dialogue_items.append((rule_timestamp, f"user: I want you to follow these rules:\n{rules_texts}"))
+                already_inserted = True
+
+            new_dialogue_items.append((timestamp, utterance))
+
+        return new_dialogue_items
diff --git a/wafl/connectors/bridges/llm_chitchat_answer_bridge.py b/wafl/connectors/bridges/llm_chitchat_answer_bridge.py
index 3acb086b..88449c68 100644
--- a/wafl/connectors/bridges/llm_chitchat_answer_bridge.py
+++ b/wafl/connectors/bridges/llm_chitchat_answer_bridge.py
@@ -15,19 +15,6 @@ async def get_answer(self, text: str, dialogue: str, query: str) -> str:
         return await self._connector.generate(prompt)

     async def _get_answer_prompt(self, text, rules_text, dialogue=None):
-        if rules_text:
-            rules_to_use = f"I want you to follow these rules:\n{rules_text.strip()}\n"
-            pattern = "\nuser: "
-            if pattern in dialogue:
-                last_user_position = dialogue.rfind(pattern)
-                before_user_dialogue, after_user_dialogue = (
-                    dialogue[:last_user_position],
-                    dialogue[last_user_position + len(pattern) :],
-                )
-                dialogue = f"{before_user_dialogue}\nuser: {rules_to_use}\nuser: {after_user_dialogue}"
-            else:
-                dialogue = f"user: {rules_to_use}\n{dialogue}"
-
         prompt = f"""
 The following is a summary of a conversation.
 All the elements of the conversation are described briefly:
diff --git a/wafl/interface/voice_interface.py b/wafl/interface/voice_interface.py
index 44caea9a..ce9d149c 100644
--- a/wafl/interface/voice_interface.py
+++ b/wafl/interface/voice_interface.py
@@ -28,7 +28,6 @@ def __init__(self, config, output_filter=None):
         self._deactivation_sound_filename = self.__get_deactivation_sound_from_config(
             config
         )
-        self.listener_model_name = config.get_value("listener_model")["local_model"]
         self._speaker = FairSeqSpeaker(config)
         self._listener = WhisperListener(config)
         self._listener.set_timeout(
diff --git a/wafl/knowledge/single_file_knowledge.py b/wafl/knowledge/single_file_knowledge.py
index 8d69f1d1..66589d5b 100644
--- a/wafl/knowledge/single_file_knowledge.py
+++ b/wafl/knowledge/single_file_knowledge.py
@@ -31,9 +31,9 @@ class SingleFileKnowledge(BaseKnowledge):
     _threshold_for_questions_from_bot = 0.6
     _threshold_for_questions_in_rules = 0.49
     _threshold_for_facts = 0.4
-    _threshold_for_fact_rules = 0.22
-    _threshold_for_fact_rules_for_creation = 0.1
-    _threshold_for_partial_facts = 0.48
+    _threshold_for_fact_rules = 0.6
+    _threshold_for_fact_rules_for_creation = 0.6
+    _threshold_for_partial_facts = 0.6
     _max_rules_per_type = 3

     def __init__(self, config, rules_text=None, knowledge_name=None, logger=None):
diff --git a/wafl/variables.py b/wafl/variables.py
index 0a5126b3..47a4ebed 100644
--- a/wafl/variables.py
+++ b/wafl/variables.py
@@ -1,4 +1,4 @@
 def get_variables():
     return {
-        "version": "0.0.80",
+        "version": "0.0.81",
     }
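
Note: the core behavioral change above is that rule text is no longer spliced into the prompt string by LLMChitChatAnswerBridge; it now travels with the timestamped dialogue items via DialogueAnswerer._insert_rule_into_dialogue_items. Below is a minimal standalone sketch of that step. The function body mirrors the hunk above; the module name, sample dialogue, timestamps, and rule string are invented for illustration.

    # sketch_rule_injection.py -- illustration only, not part of the patch
    def insert_rule_into_dialogue_items(rules_texts, rule_timestamp, dialogue_items):
        """Module-level copy of DialogueAnswerer._insert_rule_into_dialogue_items."""
        new_dialogue_items = []
        already_inserted = False
        for timestamp, utterance in dialogue_items:
            # Inject the rules exactly once, just before the user turn whose
            # timestamp triggered them.
            if (
                not already_inserted
                and utterance.startswith("user:")
                and rule_timestamp == timestamp
            ):
                new_dialogue_items.append(
                    (rule_timestamp, f"user: I want you to follow these rules:\n{rules_texts}")
                )
                already_inserted = True
            new_dialogue_items.append((timestamp, utterance))
        return new_dialogue_items

    if __name__ == "__main__":
        dialogue = [  # hypothetical (timestamp, utterance) pairs
            (1.0, "user: hello"),
            (2.0, "bot: hello there"),
            (3.0, "user: add milk to the shopping list"),
        ]
        rules = "- If the user wants to add an item go through the following points:\n 1) ...\n"
        for item in insert_rule_into_dialogue_items(rules, 3.0, dialogue):
            print(item)
        # Prints the original items, with the rule utterance inserted at
        # timestamp 3.0, immediately before the matching user turn.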