Skip to content

Commit

Permalink
Merge pull request #85 from fractalego/multi-dialogue-turn-rules
Browse files Browse the repository at this point in the history
Multi dialogue turn rules
  • Loading branch information
fractalego committed Dec 23, 2023
2 parents 897ed13 + 42a3cd2 commit d6953ac
Show file tree
Hide file tree
Showing 10 changed files with 45 additions and 92 deletions.
Binary file modified documentation/build/doctrees/environment.pickle
Binary file not shown.
2 changes: 2 additions & 0 deletions documentation/source/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
sphinx==6.1.3
sphinx-rtd-theme==1.2.0
38 changes: 8 additions & 30 deletions tests/config.json
Original file line number Diff line number Diff line change
@@ -1,48 +1,26 @@
{
"allow_interruptions": true,
"waking_up_word": "computer",
"waking_up_sound": true,
"deactivate_sound": true,
"improvise_tasks": false,
"rules": "rules.yaml",
"functions": "functions.py",
"llm_model": {
"model_is_local": false,
"local_model": "mistralai/Mistral-7B-Instruct-v0.1",
"remote_model": {
"model_host": "localhost",
"model_port": 8080
}
"model_host": "localhost",
"model_port": 8080
},
"listener_model": {
"model_is_local": false,
"local_model": "fractalego/personal-whisper-distilled-model",
"remote_model": {
"model_host": "localhost",
"model_port": 8080
},
"model_host": "localhost",
"model_port": 8080,
"listener_hotword_logp": -8,
"listener_volume_threshold": 0.6,
"listener_silence_timeout": 0.7
},
"speaker_model": {
"model_is_local": false,
"local_model": "facebook/fastspeech2-en-ljspeech",
"remote_model": {
"model_host": "localhost",
"model_port": 8080
}
},
"entailment_model": {
"model_is_local": false,
"local_model": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
"model_host": "localhost",
"model_port": 8080
},
"text_embedding_model": {
"model_is_local": false,
"local_model": "TaylorAI/gte-tiny",
"remote_model": {
"model_host": "localhost",
"model_port": 8080
}
"model_host": "localhost",
"model_port": 8080
}
}
36 changes: 0 additions & 36 deletions tests/test_connection.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
import asyncio
import os
import wave

from unittest import TestCase

import numpy as np

from wafl.config import Configuration
from wafl.connectors.bridges.llm_chitchat_answer_bridge import LLMChitChatAnswerBridge
from wafl.connectors.local.local_llm_connector import LocalLLMConnector
from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector
from wafl.listener.whisper_listener import WhisperListener
from wafl.speaker.fairseq_speaker import FairSeqSpeaker

_path = os.path.dirname(__file__)
Expand Down Expand Up @@ -52,33 +46,3 @@ def test__connection_to_generative_model_can_generate_a_python_list(self):
prediction = asyncio.run(connector.predict(prompt))
print(prediction)
assert len(prediction) > 0

def test__local_llm_connector_can_generate_a_python_list(self):
    # End-to-end check: the local LLM connector should produce non-empty
    # text for a list-generation prompt using the local test configuration.
    local_config = Configuration.load_from_filename("local_config.json")
    llm_connector = LocalLLMConnector(local_config.get_value("llm_model"))
    llm_connector._num_prediction_tokens = 200
    list_prompt = "Generate a list of 4 chapters names for a space opera book. The output needs to be a python list of strings: "
    generated_text = asyncio.run(llm_connector.predict(list_prompt))
    print(generated_text)
    assert len(generated_text) > 0

def test__chit_chat_bridge_can_run_locally(self):
    # The chit-chat bridge should return a non-empty answer when driven
    # by the local configuration and a minimal bot-opening dialogue.
    local_config = Configuration.load_from_filename("local_config.json")
    bridge = LLMChitChatAnswerBridge(local_config)
    reply = asyncio.run(bridge.get_answer("", "", "bot: hello"))
    assert len(reply) > 0

def test__listener_local_connector(self):
    # Transcribe a known WAV fixture with the local Whisper listener and
    # verify the expected phrase appears in the transcription.
    local_config = Configuration.load_from_filename("local_config.json")
    whisper = WhisperListener(local_config)
    audio_file = wave.open(os.path.join(_path, "data/1002.wav"), "rb")
    raw_frames = audio_file.readframes(audio_file.getnframes())
    # Normalize 16-bit PCM samples to the [-1, 1) float range.
    samples = np.frombuffer(raw_frames, dtype=np.int16) / 32768
    transcription = asyncio.run(whisper.input_waveform(samples))
    expected_phrase = "DELETE BATTERIES FROM THE GROCERY LIST"
    assert expected_phrase.lower() in transcription

def test__speaker_local_connector(self):
    # Smoke test: the local TTS speaker should synthesize a short
    # utterance without raising.
    local_config = Configuration.load_from_filename("local_config.json")
    tts_speaker = FairSeqSpeaker(local_config)
    utterance = "Hello world"
    asyncio.run(tts_speaker.speak(utterance))
8 changes: 0 additions & 8 deletions tests/test_voice.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,14 +68,6 @@ def test__random_sounds_are_excluded(self):
expected = "[unclear]"
assert result == expected

def test__voice_interface_receives_config(self):
    # The voice interface should expose the listener model named in the
    # loaded local configuration.
    local_config = Configuration.load_local_config()
    voice_interface = VoiceInterface(local_config)
    expected_model = local_config.get_value("listener_model")["local_model"]
    assert voice_interface.listener_model_name == expected_model

def test__hotword_listener_activated_using_recording_of_hotword(self):
f = wave.open(os.path.join(_path, "data/computer.wav"), "rb")
waveform = np.frombuffer(f.readframes(f.getnframes()), dtype=np.int16) / 32768
Expand Down
31 changes: 31 additions & 0 deletions wafl/answerer/dialogue_answerer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class DialogueAnswerer(BaseAnswerer):
def __init__(self, config, knowledge, interface, code_path, logger):
self._delete_current_rule = "<delete_rule>"
self._bridge = LLMChitChatAnswerBridge(config)
self._knowledge = knowledge
self._logger = logger
Expand All @@ -27,6 +28,7 @@ def __init__(self, config, knowledge, interface, code_path, logger):
self._max_num_past_utterances_for_rules = 0
self._prior_facts_with_timestamp = []
self._init_python_module(code_path.replace(".py", ""))
self._prior_rule_with_timestamp = None
self._max_predictions = 3

async def answer(self, query_text):
Expand All @@ -48,6 +50,16 @@ async def answer(self, query_text):

dialogue_items = dialogue
dialogue_items = sorted(dialogue_items, key=lambda x: x[0])
if rules_texts:
last_timestamp = dialogue_items[-1][0]
self._prior_rule_with_timestamp = (last_timestamp, rules_texts)
dialogue_items = self._insert_rule_into_dialogue_items(rules_texts, last_timestamp, dialogue_items)

elif self._prior_rule_with_timestamp:
last_timestamp = self._prior_rule_with_timestamp[0]
rules_texts = self._prior_rule_with_timestamp[1]
dialogue_items = self._insert_rule_into_dialogue_items(rules_texts, last_timestamp, dialogue_items)

last_bot_utterances = get_last_bot_utterances(dialogue_items, num_utterances=3)
last_user_utterance = get_last_user_utterance(dialogue_items)
dialogue_items = [item[1] for item in dialogue_items if item[0] >= start_time]
Expand Down Expand Up @@ -76,6 +88,11 @@ async def answer(self, query_text):
dialogue_items = last_user_utterance
continue

if self._delete_current_rule in answer_text:
self._prior_rule_with_timestamp = None
dialogue_items += f"\n{original_answer_text}"
continue

if not memories:
break

Expand Down Expand Up @@ -131,6 +148,8 @@ async def _get_relevant_rules(self, query, max_num_rules=1):
for cause_index, causes in enumerate(rule.causes):
rules_text += f" {cause_index + 1}) {causes.text}\n"

rules_text += f' {len(rule.causes) + 1}) After you completed all the steps output "{self._delete_current_rule}"\n'

rules_texts.append(rules_text)
await self._interface.add_fact(f"The bot remembers the rule:\n{rules_text}")

Expand Down Expand Up @@ -196,3 +215,15 @@ async def _run_code(self, to_execute):
result = f'\n"""python\n{to_execute}\n"""'

return result

def _insert_rule_into_dialogue_items(self, rules_texts, rule_timestamp, dialogue_items):
new_dialogue_items = []
already_inserted = False
for timestamp, utterance in dialogue_items:
if not already_inserted and utterance.startswith("user:") and rule_timestamp == timestamp:
new_dialogue_items.append((rule_timestamp, f"user: I want you to follow these rules:\n{rules_texts}"))
already_inserted = True

new_dialogue_items.append((timestamp, utterance))

return new_dialogue_items
13 changes: 0 additions & 13 deletions wafl/connectors/bridges/llm_chitchat_answer_bridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,6 @@ async def get_answer(self, text: str, dialogue: str, query: str) -> str:
return await self._connector.generate(prompt)

async def _get_answer_prompt(self, text, rules_text, dialogue=None):
if rules_text:
rules_to_use = f"I want you to follow these rules:\n{rules_text.strip()}\n"
pattern = "\nuser: "
if pattern in dialogue:
last_user_position = dialogue.rfind(pattern)
before_user_dialogue, after_user_dialogue = (
dialogue[:last_user_position],
dialogue[last_user_position + len(pattern) :],
)
dialogue = f"{before_user_dialogue}\nuser: {rules_to_use}\nuser: {after_user_dialogue}"
else:
dialogue = f"user: {rules_to_use}\n{dialogue}"

prompt = f"""
The following is a summary of a conversation. All the elements of the conversation are described briefly:
<summary>
Expand Down
1 change: 0 additions & 1 deletion wafl/interface/voice_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ def __init__(self, config, output_filter=None):
self._deactivation_sound_filename = self.__get_deactivation_sound_from_config(
config
)
self.listener_model_name = config.get_value("listener_model")["local_model"]
self._speaker = FairSeqSpeaker(config)
self._listener = WhisperListener(config)
self._listener.set_timeout(
Expand Down
6 changes: 3 additions & 3 deletions wafl/knowledge/single_file_knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ class SingleFileKnowledge(BaseKnowledge):
_threshold_for_questions_from_bot = 0.6
_threshold_for_questions_in_rules = 0.49
_threshold_for_facts = 0.4
_threshold_for_fact_rules = 0.22
_threshold_for_fact_rules_for_creation = 0.1
_threshold_for_partial_facts = 0.48
_threshold_for_fact_rules = 0.6
_threshold_for_fact_rules_for_creation = 0.6
_threshold_for_partial_facts = 0.6
_max_rules_per_type = 3

def __init__(self, config, rules_text=None, knowledge_name=None, logger=None):
Expand Down
2 changes: 1 addition & 1 deletion wafl/variables.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def get_variables():
    """Return package-level metadata for wafl.

    Returns:
        dict: a mapping with the package version string under "version".
    """
    # NOTE(review): the rendered source carried both "0.0.80" and "0.0.81"
    # for the same key (diff residue / duplicate dict key); only the newer
    # post-commit value is kept here.
    return {
        "version": "0.0.81",
    }

0 comments on commit d6953ac

Please sign in to comment.