-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
test: added tests #250
Merged
Merged
test: added tests #250
Changes from all commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
00ab70b
chore: move to folder
cachho 664c60e
test: prompt generation
cachho 03897a8
docs: updated docstring
cachho e2912f5
chore: linting
cachho fecc589
docs: updated docstring
cachho 198307c
chore: ignore from linting
cachho ebb73c0
fix: remove unused imports
cachho 60c5e42
chore: renamed
cachho 2fabe88
test: added test for history
cachho c276774
fix: remove unused
cachho 4751ab2
test: add dry run test
cachho a94d4b5
test: expanded test
cachho 808b1ad
test: add chunker test
cachho b8f2e6c
test: added port test
cachho 54bfcba
test: add tests
cachho 55f6bca
test: added test for whole app
cachho aabdad2
test: include chat in test
cachho 1ebdf3f
test: added chat test
cachho 8703304
Delete test.py
cachho 74793c4
fix: removed function
cachho c11ba69
Merge branch 'test/UpdateTests' of github.com:cachho/embedchain into …
cachho b457f33
fix: break lines
cachho bcd0d52
Merge branch 'main' into test/UpdateTests
cachho be9b40b
chore: add support for metadata
cachho ebae393
Merge branch 'main' into test/UpdateTests
cachho 9531aca
chore: linting, 120 chars per line
cachho 5bf346c
chore: removed unused imports
cachho File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# ruff: noqa: E501 | ||
|
||
import unittest | ||
|
||
from embedchain.chunkers.text import TextChunker | ||
|
||
|
||
class TestTextChunker(unittest.TestCase): | ||
def test_chunks(self): | ||
""" | ||
Test the chunks generated by TextChunker. | ||
# TODO: Not a very precise test. | ||
""" | ||
chunker_config = { | ||
"chunk_size": 10, | ||
"chunk_overlap": 0, | ||
"length_function": len, | ||
} | ||
chunker = TextChunker(config=chunker_config) | ||
text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit." | ||
|
||
result = chunker.create_chunks(MockLoader(), text) | ||
|
||
documents = result["documents"] | ||
|
||
self.assertGreaterEqual(len(documents), 5) | ||
|
||
# Additional test cases can be added to cover different scenarios | ||
|
||
|
||
class MockLoader: | ||
def load_data(self, src): | ||
""" | ||
Mock loader that returns a list of data dictionaries. | ||
Adjust this method to return different data for testing. | ||
""" | ||
return [ | ||
{ | ||
"content": src, | ||
"meta_data": {"url": "none"}, | ||
} | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
|
||
from embedchain import App | ||
|
||
|
||
class TestApp(unittest.TestCase): | ||
os.environ["OPENAI_API_KEY"] = "test_key" | ||
|
||
def setUp(self): | ||
self.app = App() | ||
|
||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock) | ||
def test_add(self): | ||
""" | ||
This test checks the functionality of the 'add' method in the App class. | ||
It begins by simulating the addition of a web page with a specific URL to the application instance. | ||
The 'add' method is expected to append the input type and URL to the 'user_asks' attribute of the App instance. | ||
By asserting that 'user_asks' is updated correctly after the 'add' method is called, we can confirm that the | ||
method is working as intended. | ||
The Collection.add method from the chromadb library is mocked during this test to isolate the behavior of the | ||
'add' method. | ||
""" | ||
self.app.add("web_page", "https://example.com", {"meta": "meta-data"}) | ||
self.assertEqual(self.app.user_asks, [["web_page", "https://example.com", {"meta": "meta-data"}]]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import os | ||
import unittest | ||
from unittest.mock import patch | ||
|
||
from embedchain import App | ||
|
||
|
||
class TestApp(unittest.TestCase): | ||
os.environ["OPENAI_API_KEY"] = "test_key" | ||
|
||
def setUp(self): | ||
self.app = App() | ||
|
||
@patch("embedchain.embedchain.memory", autospec=True) | ||
@patch.object(App, "retrieve_from_database", return_value=["Test context"]) | ||
@patch.object(App, "get_answer_from_llm", return_value="Test answer") | ||
def test_chat_with_memory(self, mock_answer, mock_retrieve, mock_memory): | ||
""" | ||
This test checks the functionality of the 'chat' method in the App class with respect to the chat history | ||
memory. | ||
The 'chat' method is called twice. The first call initializes the chat history memory. | ||
The second call is expected to use the chat history from the first call. | ||
|
||
Key assumptions tested: | ||
- After the first call, 'memory.chat_memory.add_user_message' and 'memory.chat_memory.add_ai_message' are | ||
called with correct arguments, adding the correct chat history. | ||
- During the second call, the 'chat' method uses the chat history from the first call. | ||
|
||
The test isolates the 'chat' method behavior by mocking out 'retrieve_from_database', 'get_answer_from_llm' and | ||
'memory' methods. | ||
""" | ||
mock_memory.load_memory_variables.return_value = {"history": []} | ||
app = App() | ||
|
||
# First call to chat | ||
first_answer = app.chat("Test query 1") | ||
self.assertEqual(first_answer, "Test answer") | ||
mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 1") | ||
mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer") | ||
|
||
mock_memory.chat_memory.add_user_message.reset_mock() | ||
mock_memory.chat_memory.add_ai_message.reset_mock() | ||
|
||
# Second call to chat | ||
second_answer = app.chat("Test query 2") | ||
self.assertEqual(second_answer, "Test answer") | ||
mock_memory.chat_memory.add_user_message.assert_called_once_with("Test query 2") | ||
mock_memory.chat_memory.add_ai_message.assert_called_once_with("Test answer") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
import unittest | ||
from string import Template | ||
from unittest.mock import patch | ||
|
||
from embedchain import App | ||
from embedchain.embedchain import QueryConfig | ||
|
||
|
||
class TestApp(unittest.TestCase): | ||
os.environ["OPENAI_API_KEY"] = "test_key" | ||
|
||
def setUp(self): | ||
self.app = App() | ||
|
||
@patch("logging.info") | ||
def test_query_logs_same_prompt_as_dry_run(self, mock_logging_info): | ||
""" | ||
Test that the 'query' method logs the same prompt as the 'dry_run' method. | ||
This is the only way I found to test the prompt in query, that's not returned. | ||
""" | ||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve: | ||
mock_retrieve.return_value = ["Test context"] | ||
input_query = "Test query" | ||
config = QueryConfig( | ||
number_documents=3, | ||
template=Template("Question: $query, context: $context, history: $history"), | ||
history=["Past context 1", "Past context 2"], | ||
) | ||
|
||
with patch.object(self.app, "get_answer_from_llm"): | ||
self.app.dry_run(input_query, config) | ||
self.app.query(input_query, config) | ||
|
||
# Access the log messages captured during the execution | ||
logged_messages = [call[0][0] for call in mock_logging_info.call_args_list] | ||
|
||
# Extract the prompts from the log messages | ||
dry_run_prompt = self.extract_prompt(logged_messages[0]) | ||
query_prompt = self.extract_prompt(logged_messages[1]) | ||
|
||
# Perform assertions on the prompts | ||
self.assertEqual(dry_run_prompt, query_prompt) | ||
|
||
def extract_prompt(self, log_message): | ||
""" | ||
Extracts the prompt value from the log message. | ||
Adjust this method based on the log message format in your implementation. | ||
""" | ||
# Modify this logic based on your log message format | ||
prefix = "Prompt: " | ||
return log_message.split(prefix, 1)[1] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import os | ||
import unittest | ||
from unittest.mock import patch | ||
|
||
from embedchain import App | ||
from embedchain.config import InitConfig | ||
|
||
|
||
class TestChromaDbHostsLoglevel(unittest.TestCase): | ||
os.environ["OPENAI_API_KEY"] = "test_key" | ||
|
||
@patch("chromadb.api.models.Collection.Collection.add") | ||
@patch("chromadb.api.models.Collection.Collection.get") | ||
@patch("embedchain.embedchain.EmbedChain.retrieve_from_database") | ||
@patch("embedchain.embedchain.EmbedChain.get_answer_from_llm") | ||
@patch("embedchain.embedchain.EmbedChain.get_llm_model_answer") | ||
def test_whole_app( | ||
self, | ||
_mock_get, | ||
_mock_add, | ||
_mock_ec_retrieve_from_database, | ||
_mock_get_answer_from_llm, | ||
mock_ec_get_llm_model_answer, | ||
): | ||
""" | ||
Test if the `App` instance is initialized without a config that does not contain default hosts and ports. | ||
""" | ||
config = InitConfig(log_level="DEBUG") | ||
|
||
app = App(config) | ||
|
||
knowledge = "lorem ipsum dolor sit amet, consectetur adipiscing" | ||
|
||
app.add_local("text", knowledge) | ||
|
||
app.query("What text did I give you?") | ||
app.chat("What text did I give you?") | ||
|
||
self.assertEqual(mock_ec_get_llm_model_answer.call_args[1]["documents"], [knowledge]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
import unittest | ||
from string import Template | ||
|
||
from embedchain import App | ||
from embedchain.embedchain import QueryConfig | ||
|
||
|
||
class TestGeneratePrompt(unittest.TestCase): | ||
def setUp(self): | ||
self.app = App() | ||
|
||
def test_generate_prompt_with_template(self): | ||
""" | ||
Tests that the generate_prompt method correctly formats the prompt using | ||
a custom template provided in the QueryConfig instance. | ||
|
||
This test sets up a scenario with an input query and a list of contexts, | ||
and a custom template, and then calls generate_prompt. It checks that the | ||
returned prompt correctly incorporates all the contexts and the query into | ||
the format specified by the template. | ||
""" | ||
# Setup | ||
input_query = "Test query" | ||
contexts = ["Context 1", "Context 2", "Context 3"] | ||
template = "You are a bot. Context: ${context} - Query: ${query} - Helpful answer:" | ||
config = QueryConfig(template=Template(template)) | ||
|
||
# Execute | ||
result = self.app.generate_prompt(input_query, contexts, config) | ||
|
||
# Assert | ||
expected_result = ( | ||
"You are a bot. Context: Context 1 | Context 2 | Context 3 - Query: Test query - Helpful answer:" | ||
) | ||
self.assertEqual(result, expected_result) | ||
|
||
def test_generate_prompt_with_contexts_list(self): | ||
""" | ||
Tests that the generate_prompt method correctly handles a list of contexts. | ||
|
||
This test sets up a scenario with an input query and a list of contexts, | ||
and then calls generate_prompt. It checks that the returned prompt | ||
correctly includes all the contexts and the query. | ||
""" | ||
# Setup | ||
input_query = "Test query" | ||
contexts = ["Context 1", "Context 2", "Context 3"] | ||
config = QueryConfig() | ||
|
||
# Execute | ||
result = self.app.generate_prompt(input_query, contexts, config) | ||
|
||
# Assert | ||
expected_result = config.template.substitute(context="Context 1 | Context 2 | Context 3", query=input_query) | ||
self.assertEqual(result, expected_result) | ||
|
||
def test_generate_prompt_with_history(self): | ||
""" | ||
Test the 'generate_prompt' method with QueryConfig containing a history attribute. | ||
""" | ||
config = QueryConfig(history=["Past context 1", "Past context 2"]) | ||
config.template = Template("Context: $context | Query: $query | History: $history") | ||
prompt = self.app.generate_prompt("Test query", ["Test context"], config) | ||
|
||
expected_prompt = "Context: Test context | Query: Test query | History: ['Past context 1', 'Past context 2']" | ||
self.assertEqual(prompt, expected_prompt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import os | ||
import unittest | ||
from unittest.mock import MagicMock, patch | ||
|
||
from embedchain import App | ||
from embedchain.embedchain import QueryConfig | ||
|
||
|
||
class TestApp(unittest.TestCase): | ||
os.environ["OPENAI_API_KEY"] = "test_key" | ||
|
||
def setUp(self): | ||
self.app = App() | ||
|
||
@patch("chromadb.api.models.Collection.Collection.add", MagicMock) | ||
def test_query(self): | ||
""" | ||
This test checks the functionality of the 'query' method in the App class. | ||
It simulates a scenario where the 'retrieve_from_database' method returns a context list and | ||
'get_llm_model_answer' returns an expected answer string. | ||
|
||
The 'query' method is expected to call 'retrieve_from_database' and 'get_llm_model_answer' methods | ||
appropriately and return the right answer. | ||
|
||
Key assumptions tested: | ||
- 'retrieve_from_database' method is called exactly once with arguments: "Test query" and an instance of | ||
QueryConfig. | ||
- 'get_llm_model_answer' is called exactly once. The specific arguments are not checked in this test. | ||
- 'query' method returns the value it received from 'get_llm_model_answer'. | ||
|
||
The test isolates the 'query' method behavior by mocking out 'retrieve_from_database' and | ||
'get_llm_model_answer' methods. | ||
""" | ||
with patch.object(self.app, "retrieve_from_database") as mock_retrieve: | ||
mock_retrieve.return_value = ["Test context"] | ||
with patch.object(self.app, "get_llm_model_answer") as mock_answer: | ||
mock_answer.return_value = "Test answer" | ||
answer = self.app.query("Test query") | ||
|
||
self.assertEqual(answer, "Test answer") | ||
self.assertEqual(mock_retrieve.call_args[0][0], "Test query") | ||
self.assertIsInstance(mock_retrieve.call_args[0][1], QueryConfig) | ||
mock_answer.assert_called_once() |
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any specific reason to add this? :D
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
otherwise it complains about all the doc strings.