Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 19 additions & 17 deletions tests/unit/utils/test_checks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from types import ModuleType
from unittest.mock import patch

from typing import Any

import pytest

from utils import checks


@pytest.fixture(name="input_file")
def input_file_fixture(tmp_path):
def input_file_fixture(tmp_path: Path) -> str:
"""Create file manually using the tmp_path fixture."""
filename = os.path.join(tmp_path, "mydoc.csv")
with open(filename, "wt", encoding="utf-8") as fout:
Expand All @@ -20,17 +22,17 @@ def input_file_fixture(tmp_path):


@pytest.fixture(name="input_directory")
def input_directory_fixture(tmp_path):
def input_directory_fixture(tmp_path: Path) -> str:
"""Create directory manually using the tmp_path fixture."""
dirname = os.path.join(tmp_path, "mydir")
os.mkdir(dirname)
return dirname


def test_get_attribute_from_file_no_record():
def test_get_attribute_from_file_no_record() -> None:
"""Test the get_attribute_from_file function when record is not in dictionary."""
# no data
d = {}
d: dict[str, Any] = {}

# non-existing key
key = ""
Expand All @@ -43,7 +45,7 @@ def test_get_attribute_from_file_no_record():
assert value is None


def test_get_attribute_from_file_proper_record(input_file):
def test_get_attribute_from_file_proper_record(input_file: str) -> None:
"""Test the get_attribute_from_file function when record is present in dictionary."""
# existing key
key = "my_file"
Expand All @@ -57,7 +59,7 @@ def test_get_attribute_from_file_proper_record(input_file):
assert value == "some content!"


def test_get_attribute_from_file_improper_filename():
def test_get_attribute_from_file_improper_filename() -> None:
"""Test the get_attribute_from_file when the file does not exist."""
# existing key
key = "my_file"
Expand All @@ -70,26 +72,26 @@ def test_get_attribute_from_file_improper_filename():
checks.get_attribute_from_file(d, "my_file")


def test_file_check_existing_file(input_file):
def test_file_check_existing_file(input_file: str) -> None:
"""Test the function file_check for existing file."""
# just call the function, it should not raise an exception
checks.file_check(input_file, "description")


def test_file_check_non_existing_file():
def test_file_check_non_existing_file() -> None:
"""Test the function file_check for non existing file."""
with pytest.raises(checks.InvalidConfigurationError):
checks.file_check(Path("does-not-exists"), "description")


def test_file_check_not_readable_file(input_file):
def test_file_check_not_readable_file(input_file: str) -> None:
"""Test the function file_check for not readable file."""
with patch("os.access", return_value=False):
with pytest.raises(checks.InvalidConfigurationError):
checks.file_check(input_file, "description")


def test_directory_check_non_existing_directory():
def test_directory_check_non_existing_directory() -> None:
"""Test the function directory_check skips non-existing directory."""
# just call the function, it should not raise an exception
checks.directory_check(
Expand All @@ -101,15 +103,15 @@ def test_directory_check_non_existing_directory():
)


def test_directory_check_existing_writable_directory(input_directory):
def test_directory_check_existing_writable_directory(input_directory: str) -> None:
"""Test the function directory_check checks directory."""
# just call the function, it should not raise an exception
checks.directory_check(
input_directory, must_exists=True, must_be_writable=True, desc="foobar"
)


def test_directory_check_non_a_directory(input_file):
def test_directory_check_non_a_directory(input_file: str) -> None:
"""Test the function directory_check checks directory."""
# pass a filename not a directory name
with pytest.raises(checks.InvalidConfigurationError):
Expand All @@ -118,7 +120,7 @@ def test_directory_check_non_a_directory(input_file):
)


def test_directory_check_existing_non_writable_directory(input_directory):
def test_directory_check_existing_non_writable_directory(input_directory: str) -> None:
"""Test the function directory_check checks directory."""
with patch("os.access", return_value=False):
with pytest.raises(checks.InvalidConfigurationError):
Expand All @@ -127,7 +129,7 @@ def test_directory_check_existing_non_writable_directory(input_directory):
)


def test_import_python_module_success():
def test_import_python_module_success() -> None:
"""Test importing a Python module."""
module_path = "tests/profiles/test/profile.py"
module_name = "profile"
Expand All @@ -136,7 +138,7 @@ def test_import_python_module_success():
assert isinstance(result, ModuleType)


def test_import_python_module_error():
def test_import_python_module_error() -> None:
"""Test importing a Python module that is a .txt file."""
module_path = "tests/profiles/test_two/test.txt"
module_name = "profile"
Expand All @@ -145,7 +147,7 @@ def test_import_python_module_error():
assert result is None


def test_is_valid_profile():
def test_is_valid_profile() -> None:
"""Test if an imported profile is valid."""
module_path = "tests/profiles/test/profile.py"
module_name = "profile"
Expand All @@ -157,7 +159,7 @@ def test_is_valid_profile():
assert result is True


def test_invalid_profile():
def test_invalid_profile() -> None:
"""Test if an imported profile is valid (expect invalid)"""
module_path = "tests/profiles/test_three/profile.py"
module_name = "profile"
Expand Down
14 changes: 8 additions & 6 deletions tests/unit/utils/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,30 +8,30 @@
class TestGraniteToolParser:
"""Unit tests for functions defined in utils/types.py."""

def test_get_tool_parser_when_model_is_is_not_granite(self):
def test_get_tool_parser_when_model_is_is_not_granite(self) -> None:
"""Test that the tool_parser is None when model_id is not a granite model."""
assert (
GraniteToolParser.get_parser("ollama3.3") is None
), "tool_parser should be None"

def test_get_tool_parser_when_model_id_does_not_start_with_granite(self):
def test_get_tool_parser_when_model_id_does_not_start_with_granite(self) -> None:
"""Test that the tool_parser is None when model_id does not start with granite."""
assert (
GraniteToolParser.get_parser("a-fine-trained-granite-model") is None
), "tool_parser should be None"

def test_get_tool_parser_when_model_id_starts_with_granite(self):
def test_get_tool_parser_when_model_id_starts_with_granite(self) -> None:
"""Test that the tool_parser is not None when model_id starts with granite."""
tool_parser = GraniteToolParser.get_parser("granite-3.3-8b-instruct")
assert tool_parser is not None, "tool_parser should not be None"

def test_get_tool_calls_from_completion_message_when_none(self):
def test_get_tool_calls_from_completion_message_when_none(self) -> None:
"""Test that get_tool_calls returns an empty array when CompletionMessage is None."""
tool_parser = GraniteToolParser.get_parser("granite-3.3-8b-instruct")
assert tool_parser is not None, "tool parser was not returned"
assert tool_parser.get_tool_calls(None) == [], "get_tool_calls should return []"

def test_get_tool_calls_from_completion_message_when_not_none(self):
def test_get_tool_calls_from_completion_message_when_not_none(self) -> None:
"""Test that get_tool_calls returns an empty array when CompletionMessage has no tool_calls.""" # pylint: disable=line-too-long
tool_parser = GraniteToolParser.get_parser("granite-3.3-8b-instruct")
assert tool_parser is not None, "tool parser was not returned"
Expand All @@ -41,7 +41,9 @@ def test_get_tool_calls_from_completion_message_when_not_none(self):
completion_message
), "get_tool_calls should return []"

def test_get_tool_calls_from_completion_message_when_message_has_tool_calls(self):
def test_get_tool_calls_from_completion_message_when_message_has_tool_calls(
self,
) -> None:
"""Test that get_tool_calls returns the tool_calls when CompletionMessage has tool_calls."""
tool_parser = GraniteToolParser.get_parser("granite-3.3-8b-instruct")
assert tool_parser is not None, "tool parser was not returned"
Expand Down
Loading