Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: added tests #250

Merged
merged 27 commits into from
Jul 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions tests/chunkers/test_text.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# ruff: noqa: E501
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any specific reason to add this? :D

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Otherwise, the linter (Ruff rule E501) complains about the long lines in all of the docstrings.


import unittest

from embedchain.chunkers.text import TextChunker


class TestTextChunker(unittest.TestCase):
    def test_chunks(self):
        """
        Smoke-test the chunk output generated by TextChunker.

        # TODO: Not a very precise test.
        """
        config = {
            "chunk_size": 10,
            "chunk_overlap": 0,
            "length_function": len,
        }
        chunker = TextChunker(config=config)
        sample = "Lorem ipsum dolor sit amet, consectetur adipiscing elit."

        documents = chunker.create_chunks(MockLoader(), sample)["documents"]

        # A 10-character chunk size over this ~56-character sample must
        # produce at least five pieces.
        self.assertGreaterEqual(len(documents), 5)

    # Additional test cases can be added to cover different scenarios


class MockLoader:
    """Stand-in loader used by the chunker tests."""

    def load_data(self, src):
        """
        Wrap *src* in the list-of-dicts shape real loaders return.

        Adjust this method to return different data for testing.
        """
        record = {
            "content": src,
            "meta_data": {"url": "none"},
        }
        return [record]
26 changes: 26 additions & 0 deletions tests/embedchain/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import unittest
from unittest.mock import MagicMock, patch

from embedchain import App


class TestApp(unittest.TestCase):
    # NOTE(review): a dummy key is exported before App() is built; presumably
    # the constructor only checks that the variable is present.
    os.environ["OPENAI_API_KEY"] = "test_key"

    def setUp(self):
        self.app = App()

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_add(self):
        """
        Verify that App.add records the request in 'user_asks'.

        A web-page source is added together with metadata; afterwards
        'user_asks' must hold exactly one entry mirroring the input type,
        URL and metadata. chromadb's Collection.add is mocked so no real
        vector-store write takes place.
        """
        self.app.add("web_page", "https://example.com", {"meta": "meta-data"})

        expected_user_asks = [["web_page", "https://example.com", {"meta": "meta-data"}]]
        self.assertEqual(self.app.user_asks, expected_user_asks)
48 changes: 48 additions & 0 deletions tests/embedchain/test_chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import os
import unittest
from unittest.mock import patch

from embedchain import App


class TestApp(unittest.TestCase):
    # Dummy key so App() can be constructed; nothing reaches OpenAI here.
    os.environ["OPENAI_API_KEY"] = "test_key"

    def setUp(self):
        self.app = App()

    @patch("embedchain.embedchain.memory", autospec=True)
    @patch.object(App, "retrieve_from_database", return_value=["Test context"])
    @patch.object(App, "get_answer_from_llm", return_value="Test answer")
    def test_chat_with_memory(self, _mock_answer, _mock_retrieve, memory_mock):
        """
        Exercise App.chat twice and verify the conversation-memory hooks.

        Turn one initializes the chat history: the user message and the AI
        answer must each be recorded exactly once. Turn two, after the call
        records are cleared, must perform the same bookkeeping for the new
        query — implying the history from turn one is carried along.

        'retrieve_from_database', 'get_answer_from_llm' and the module-level
        'memory' are all mocked so only the chat plumbing itself is tested.
        """
        memory_mock.load_memory_variables.return_value = {"history": []}
        app = App()

        for turn, query in enumerate(["Test query 1", "Test query 2"]):
            if turn:
                # Clear call records so the second turn is asserted on its own.
                memory_mock.chat_memory.add_user_message.reset_mock()
                memory_mock.chat_memory.add_ai_message.reset_mock()

            self.assertEqual(app.chat(query), "Test answer")
            memory_mock.chat_memory.add_user_message.assert_called_once_with(query)
            memory_mock.chat_memory.add_ai_message.assert_called_once_with("Test answer")
52 changes: 52 additions & 0 deletions tests/embedchain/test_dryrun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os
import unittest
from string import Template
from unittest.mock import patch

from embedchain import App
from embedchain.embedchain import QueryConfig


class TestApp(unittest.TestCase):
os.environ["OPENAI_API_KEY"] = "test_key"

def setUp(self):
self.app = App()

@patch("logging.info")
def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info):
"""
Test that the 'query' method logs the same prompt as the 'dry_run' method.
This is the only way I found to test the prompt in query, that's not returned.
"""
with patch.object(self.app, "retrieve_from_database") as mock_retrieve:
mock_retrieve.return_value = ["Test context"]
input_query = "Test query"
config = QueryConfig(
number_documents=3,
template=Template("Question: $query, context: $context, history: $history"),
history=["Past context 1", "Past context 2"],
)

with patch.object(self.app, "get_answer_from_llm"):
self.app.dry_run(input_query, config)
self.app.query(input_query, config)

# Access the log messages captured during the execution
logged_messages = [call[0][0] for call in mock_logging_info.call_args_list]

# Extract the prompts from the log messages
dry_run_prompt = self.extract_prompt(logged_messages[0])
query_prompt = self.extract_prompt(logged_messages[1])

# Perform assertions on the prompts
self.assertEqual(dry_run_prompt, query_prompt)

def extract_prompt(self, log_message):
"""
Extracts the prompt value from the log message.
Adjust this method based on the log message format in your implementation.
"""
# Modify this logic based on your log message format
prefix = "Prompt: "
return log_message.split(prefix, 1)[1]
39 changes: 39 additions & 0 deletions tests/embedchain/test_embedchain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
import unittest
from unittest.mock import patch

from embedchain import App
from embedchain.config import InitConfig


class TestChromaDbHostsLoglevel(unittest.TestCase):
    # Dummy key; every outbound call below is mocked.
    os.environ["OPENAI_API_KEY"] = "test_key"

    # NOTE: @patch decorators are applied bottom-up, so the mock for the
    # decorator closest to the function arrives as the FIRST argument.
    @patch("chromadb.api.models.Collection.Collection.add")
    @patch("chromadb.api.models.Collection.Collection.get")
    @patch("embedchain.embedchain.EmbedChain.retrieve_from_database")
    @patch("embedchain.embedchain.EmbedChain.get_answer_from_llm")
    @patch("embedchain.embedchain.EmbedChain.get_llm_model_answer")
    def test_whole_app(
        self,
        _mock_get_llm_model_answer,
        _mock_get_answer_from_llm,
        _mock_retrieve_from_database,
        _mock_collection_get,
        mock_collection_add,
    ):
        """
        Smoke-test an add/query/chat round trip with log_level="DEBUG".

        Bug fixed: the mock parameters were previously named in top-down
        decorator order, so the variable called 'mock_ec_get_llm_model_answer'
        actually held the Collection.add mock and the final assertion silently
        inspected the wrong mock. The names now match the real bottom-up
        injection order, and the assertion explicitly targets the
        Collection.add mock it was always inspecting — runtime behavior is
        unchanged, but the test now says what it checks.
        """
        config = InitConfig(log_level="DEBUG")

        app = App(config)

        knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing"

        app.add_local("text", knowledge)

        app.query("What text did I give you?")
        app.chat("What text did I give you?")

        # The vector store must have received the original text as its
        # 'documents' keyword argument.
        self.assertEqual(mock_collection_add.call_args[1]["documents"], [knowledge])
66 changes: 66 additions & 0 deletions tests/embedchain/test_generate_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
import unittest
from string import Template

from embedchain import App
from embedchain.embedchain import QueryConfig


class TestGeneratePrompt(unittest.TestCase):
    def setUp(self):
        self.app = App()

    def test_generate_prompt_with_template(self):
        """
        generate_prompt must render a caller-supplied template, substituting
        the joined contexts for ${context} and the query for ${query}.
        """
        query = "Test query"
        contexts = ["Context 1", "Context 2", "Context 3"]
        custom_template = Template("You are a bot. Context: ${context} - Query: ${query} - Helpful answer:")
        config = QueryConfig(template=custom_template)

        prompt = self.app.generate_prompt(query, contexts, config)

        expected = "You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:"
        self.assertEqual(prompt, expected)

    def test_generate_prompt_with_contexts_list(self):
        """
        With the default template, a list of contexts must be joined with
        ' | ' and substituted alongside the query.
        """
        query = "Test query"
        contexts = ["Context 1", "Context 2", "Context 3"]
        config = QueryConfig()

        prompt = self.app.generate_prompt(query, contexts, config)

        expected = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=query)
        self.assertEqual(prompt, expected)

    def test_generate_prompt_with_history(self):
        """
        When QueryConfig carries a history, a template with a $history
        placeholder must receive it (rendered as the str() of the list).
        """
        config = QueryConfig(history=["Past context 1", "Past context 2"])
        # The history-aware template is assigned after construction on
        # purpose — do not move it into the constructor call.
        config.template = Template("Context: $context | Query: $query | History: $history")

        prompt = self.app.generate_prompt("Test query", ["Test context"], config)

        expected = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']"
        self.assertEqual(prompt, expected)
43 changes: 43 additions & 0 deletions tests/embedchain/test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import os
import unittest
from unittest.mock import MagicMock, patch

from embedchain import App
from embedchain.embedchain import QueryConfig


class TestApp(unittest.TestCase):
    # Dummy key so App() can be constructed; no real API traffic occurs.
    os.environ["OPENAI_API_KEY"] = "test_key"

    def setUp(self):
        self.app = App()

    @patch("chromadb.api.models.Collection.Collection.add", MagicMock)
    def test_query(self):
        """
        Verify the plumbing of App.query.

        With 'retrieve_from_database' returning a fixed context list and
        'get_llm_model_answer' returning a fixed answer, 'query' must:
        - call 'retrieve_from_database' once with the query string and a
          QueryConfig instance,
        - call 'get_llm_model_answer' exactly once (arguments unchecked),
        - return whatever 'get_llm_model_answer' produced.

        chromadb's Collection.add is mocked to keep the test isolated.
        """
        patched_retrieve = patch.object(self.app, "retrieve_from_database")
        patched_answer = patch.object(self.app, "get_llm_model_answer")

        with patched_retrieve as mock_retrieve, patched_answer as mock_answer:
            mock_retrieve.return_value = ["Test context"]
            mock_answer.return_value = "Test answer"

            answer = self.app.query("Test query")

            self.assertEqual(answer, "Test answer")
            retrieve_args = mock_retrieve.call_args[0]
            self.assertEqual(retrieve_args[0], "Test query")
            self.assertIsInstance(retrieve_args[1], QueryConfig)
            mock_answer.assert_called_once()
29 changes: 0 additions & 29 deletions tests/test_embedchain.py

This file was deleted.

Loading