Skip to content

Commit

Permalink
is_enabled field support; pytest for unit tests (#121)
Browse files Browse the repository at this point in the history
* is_enabled field support; pytest for unit tests

* fix tests job
  • Loading branch information
pushforce committed Mar 10, 2023
1 parent 7d247d2 commit ddd0a2e
Show file tree
Hide file tree
Showing 8 changed files with 247 additions and 114 deletions.
25 changes: 22 additions & 3 deletions .github/workflows/dp-agent.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ jobs:
pip install -r dev_requirements.txt
- name: Check code formatting
run: |
black deeppavlov_agent/ --diff --check
make check_format
lint:
needs: format
Expand All @@ -45,7 +45,7 @@ jobs:
pip install -r dev_requirements.txt
- name: Lint with flake8
run: |
flake8 deeppavlov_agent/ --count --show-source --statistics
make lint
type_check:
needs: lint
Expand All @@ -64,4 +64,23 @@ jobs:
pip install -r dev_requirements.txt
- name: Check types with mypy
run: |
mypy deeppavlov_agent/
make type_check
test:
runs-on: ubuntu-latest

steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Set up Python 3.7
uses: actions/setup-python@v3
with:
python-version: "3.7"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r dev_requirements.txt
- name: Run unit tests
run: |
make unit_test
18 changes: 18 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
.PHONY: check_format format lint type_check check_code unit_test

check_format:
black deeppavlov_agent/ --diff --check

format:
black deeppavlov_agent/

lint:
flake8 deeppavlov_agent/ --count --show-source --statistics

type_check:
mypy deeppavlov_agent/

check_code: check_format lint type_check

unit_test:
pytest deeppavlov_agent/tests
5 changes: 5 additions & 0 deletions deeppavlov_agent/core/state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,8 @@ async def add_annotation_and_reset_human_attributes_for_first_turn(
dialog.human.attributes = {
"disliked_skills": dialog.human.attributes.get("disliked_skills", [])
}


class FakeStateManager:
async def add_annotation(self, dialog: Dialog, payload: Dict, label: str, **kwargs):
pass
4 changes: 2 additions & 2 deletions deeppavlov_agent/core/workflow_manager.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from collections import defaultdict
from uuid import uuid4
from typing import Optional, Dict, Tuple
from typing import Optional, Dict, Tuple, Any
from time import time

from .state_schema import Dialog
Expand Down Expand Up @@ -40,7 +40,7 @@ def get_dialog_by_id(self, dialog_id: str) -> Optional[Dialog]:
return None

def add_task(
self, dialog_id: str, service: Service, payload: Dict, ind: int
self, dialog_id: str, service: Service, payload: Any, ind: int
) -> Optional[str]:
workflow_record = self.workflow_records.get(dialog_id, None)
if not workflow_record:
Expand Down
16 changes: 9 additions & 7 deletions deeppavlov_agent/parse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(self, state_manager: StateManager, config: Dict):
self.formatters_module = import_module(formatters_module_name)

self.fill_connectors()
self.fill_services()
self.fill_services(None, self.config["services"])

def setup_module_from_config(self, name_var):
module = None
Expand Down Expand Up @@ -157,7 +157,7 @@ def make_connector(self, name: str, data: Dict):
self.workers.extend(workers)
self.connectors[name] = connector

def make_service(self, group: str, name: str, data: Dict):
def make_service(self, group: Optional[str], name: str, data: Dict):
logger.debug(f"Create service: '{name}' config={data}")

def check_ext_module(class_name):
Expand Down Expand Up @@ -287,10 +287,12 @@ def fill_connectors(self):
self.services_names[k].add(service_name)
self.services_names[service_name].add(service_name)

def fill_services(self):
for k, v in self.config["services"].items():
def fill_services(self, group: Optional[str], services: Dict[str, dict]):
for k, v in services.items():
if "is_enabled" in v and v["is_enabled"] is False:
continue

if "connector" in v: # single service
self.make_service(None, k, v)
self.make_service(group, k, v)
else: # grouped services
for sk, sv in v.items():
self.make_service(k, sk, sv)
self.fill_services(k, v)
78 changes: 78 additions & 0 deletions deeppavlov_agent/tests/config_parser_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
from ..parse_config import PipelineConfigParser
from ..core.state_manager import FakeStateManager


def test_service_is_enabled():
config = {
"services": {
"annotator": {
"connector": {
"protocol": "python",
"class_name": "PredefinedOutputConnector",
"output": {"body": "here are my annotations"},
},
"state_manager_method": "add_annotation",
"is_enabled": True,
},
}
}

parsed_config = PipelineConfigParser(FakeStateManager(), config)

service = list(filter(lambda x: x.label == "annotator", parsed_config.services))[0]

assert service is not None


def test_service_is_disabled():
config = {
"services": {
"annotator": {
"connector": {
"protocol": "python",
"class_name": "PredefinedOutputConnector",
"output": {"body": "here are my annotations"},
},
"state_manager_method": "add_annotation",
"is_enabled": False,
},
}
}

parsed_config = PipelineConfigParser(FakeStateManager(), config)

assert len(parsed_config.services) == 0


def test_service_in_group_is_disabled():
config = {
"services": {
"annotators": {
"annotator1": {
"connector": {
"protocol": "python",
"class_name": "PredefinedOutputConnector",
"output": {"body": "annotations1"},
},
"state_manager_method": "add_annotation",
},
"annotator2": {
"connector": {
"protocol": "python",
"class_name": "PredefinedOutputConnector",
"output": {"body": "annotations2"},
},
"state_manager_method": "add_annotation",
"is_enabled": False,
},
}
}
}

parsed_config = PipelineConfigParser(FakeStateManager(), config)
filtered_services = list(
filter(lambda x: x.label == "annotator2", parsed_config.services)
)

assert len(filtered_services) == 0
assert len(parsed_config.services) == 1

0 comments on commit ddd0a2e

Please sign in to comment.