diff --git a/backend/api/app.py b/backend/api/app.py index 833f529e5..5623c6f49 100644 --- a/backend/api/app.py +++ b/backend/api/app.py @@ -4,6 +4,7 @@ from werkzeug import exceptions from apig_wsgi import make_lambda_handler from werkzeug.middleware.dispatcher import DispatcherMiddleware +from api.models.category import CategoryModel from flask import Flask, Response, jsonify, render_template, request from flask_githubapp.core import GitHubApp @@ -12,7 +13,7 @@ from api.model import get_public_plugins, get_index, get_plugin, get_excluded_plugins, update_cache, \ move_artifact_to_s3, get_category_mapping, get_categories_mapping, get_manifest, update_activity_data, \ get_metrics_for_plugin -from api.models.category import CategoryModel +from api.models import category as categories from api.shield import get_shield from utils.utils import send_alert, reformat_ssh_key_to_pem_bytes diff --git a/backend/api/models/_tests/test_category.py b/backend/api/models/_tests/test_category.py index 824127bea..35edc23db 100644 --- a/backend/api/models/_tests/test_category.py +++ b/backend/api/models/_tests/test_category.py @@ -1,23 +1,96 @@ -from unittest.mock import Mock +from typing import List +import pytest + +from api.models._tests.conftest import create_dynamo_table +from moto import mock_dynamodb + +TEST_BUCKET = "test-bucket" +TEST_STACK_NAME = "None" +TEST_BUCKET_PATH = "test-path" +TEST_CATEGORY_PATH = "category/EDAM-BIOIMAGING/alpha06.json" +TEST_CATEGORY_VERSION = "EDAM-BIOIMAGING:alpha06" class TestCategory: - def test_get_category_has_result(self, monkeypatch): - mock_category = Mock( - return_value=[ - Mock(label="label1", dimension="dimension1", hierarchy=["hierarchy1"]), - Mock( - label="label2", - dimension="dimension2", - hierarchy=["hierarchy1", "hierarchy2"], - ), - ] - ) + @pytest.fixture + def setup_env_variables(self, monkeypatch): + monkeypatch.setenv("BUCKET", TEST_BUCKET) + monkeypatch.setenv("BUCKET_PATH", TEST_BUCKET_PATH) + @pytest.fixture() + def categories_table(self, aws_credentials, setup_env_variables): from api.models.category import CategoryModel - monkeypatch.setattr(CategoryModel, "query", mock_category) - actual = CategoryModel.get_category("name", "version") + with mock_dynamodb(): + yield create_dynamo_table(CategoryModel, "category") + + def _get_version_hash(self, hash: str) -> str: + return f"{TEST_CATEGORY_VERSION}:{hash}" + + def _put_item( + self, + table, + name: str, + version: str, + version_hash: str, + formatted_name: str, + dimension: str, + label: str, + hierarchy: List[str], + ): + item = { + "name": name, + "version": version, + "version_hash": self._get_version_hash(version_hash), + "formatted_name": formatted_name, + "label": label, + "dimension": dimension, + "hierarchy": hierarchy, + } + table.put_item(Item=item) + + def _seed_data(self, table): + self._put_item( + table, + name="name1", + version=TEST_CATEGORY_VERSION, + version_hash="hash1", + formatted_name="Name1", + dimension="dimension1", + label="label1", + hierarchy=["hierarchy1"], + ) + + self._put_item( + table, + name="name1", + version=TEST_CATEGORY_VERSION, + version_hash="hash2", + formatted_name="Name1", + dimension="dimension2", + label="label2", + hierarchy=["hierarchy1", "hierarchy2"], + ) + + self._put_item( + table, + name="name2", + version=TEST_CATEGORY_VERSION, + version_hash="hash3", + formatted_name="Name2", + dimension="dimension3", + label="label3", + hierarchy=["hierarchy3"], + ) + + def test_get_category_has_result( + self, aws_credentials, setup_env_variables, categories_table + ): + self._seed_data(categories_table) + + from api.models.category import get_category + + actual = get_category("name1", TEST_CATEGORY_VERSION) expected = [ { "label": "label1", @@ -33,40 +106,31 @@ def test_get_category_has_result(self, monkeypatch): assert actual == expected - def test_get_all_categories(self, monkeypatch): - mock_category = Mock( - return_value=[ - Mock( - formatted_name="name1", - version="version", - label="label1", - dimension="dimension1", - hierarchy=["hierarchy1"], - ), - Mock( - formatted_name="name1", - version="version", - label="label2", - dimension="dimension2", - hierarchy=["hierarchy1", "hierarchy2"], - ), - Mock( - formatted_name="name2", - version="version", - label="label3", - dimension="dimension3", - hierarchy=["hierarchy3"], - ), - ] - ) + def test_get_category_has_no_result( + self, aws_credentials, setup_env_variables, categories_table + ): + self._seed_data(categories_table) - from api.models.category import CategoryModel + from api.models.category import get_category + + actual = get_category("foobar", TEST_CATEGORY_VERSION) + expected = [] + + assert actual == expected + + def test_get_all_categories( + self, + aws_credentials, + setup_env_variables, + categories_table, + ): + self._seed_data(categories_table) - monkeypatch.setattr(CategoryModel, "scan", mock_category) - actual = CategoryModel.get_all_categories("version") + from api.models.category import get_all_categories + actual = get_all_categories(TEST_CATEGORY_VERSION) expected = { - "name1": [ + "Name1": [ { "label": "label1", "dimension": "dimension1", @@ -78,7 +142,7 @@ def test_get_all_categories(self, monkeypatch): "hierarchy": ["hierarchy1", "hierarchy2"], }, ], - "name2": [ + "Name2": [ { "label": "label3", "dimension": "dimension3", @@ -88,3 +152,16 @@ def test_get_all_categories(self, monkeypatch): } assert actual == expected + + def test_get_all_categories_empty_table( + self, + aws_credentials, + setup_env_variables, + categories_table, + ): + from api.models.category import get_all_categories + + actual = get_all_categories(TEST_CATEGORY_VERSION) + expected = {} + + assert actual == expected diff --git a/backend/api/models/_tests/test_install_activity.py b/backend/api/models/_tests/test_install_activity.py index 0f17287ad..52a7cbf6d 100644 --- a/backend/api/models/_tests/test_install_activity.py +++ b/backend/api/models/_tests/test_install_activity.py @@ -44,7 +44,7 @@ def _put_item(self, table, granularity, timestamp, install_count, is_total=None, 'plugin_name': plugin, 'type_timestamp': self._to_type_timestamp(granularity, timestamp), 'install_count': install_count, - 'granularity': granularity, + 'type': granularity, 'timestamp': to_millis(timestamp), } if is_total: diff --git a/backend/api/models/category.py b/backend/api/models/category.py index c4efc03e3..cfa84b9cc 100644 --- a/backend/api/models/category.py +++ b/backend/api/models/category.py @@ -1,18 +1,18 @@ import os import time -from slugify import slugify +from api.models.helper import set_ddb_metadata from collections import defaultdict from pynamodb.attributes import ListAttribute, NumberAttribute, UnicodeAttribute from pynamodb.models import Model +from slugify import slugify from utils.time import get_current_timestamp, print_perf_duration +@set_ddb_metadata("category") class CategoryModel(Model): class Meta: - host = os.getenv('LOCAL_DYNAMO_HOST') - region = os.environ.get("AWS_REGION", "us-west-2") - table_name = f"{os.environ.get('STACK_NAME')}-category" + pass name = UnicodeAttribute(hash_key=True) version_hash = UnicodeAttribute(range_key=True) @@ -23,57 +23,68 @@ class Meta: label = UnicodeAttribute() last_updated_timestamp = NumberAttribute(default_for_new=get_current_timestamp) - @classmethod - def _get_category_from_model(cls, category): - return { - "label": category.label, - "dimension": category.dimension, - "hierarchy": category.hierarchy, - } - - @classmethod - def get_category(cls, name: str, version: str): - """ - Gets the category data for a particular category and EDAM version. - """ - - category = [] - start = time.perf_counter() - - for item in cls.query( - slugify(name), cls.version_hash.startswith(f"{version}:") - ): - category.append(cls._get_category_from_model(item)) - - print_perf_duration(start, f"CategoryModel.get_category({name})") - - return category - - @classmethod - def get_all_categories(cls, version: str): - """ - Gets all available category mappings from a particular EDAM version. - """ - - start = time.perf_counter() - categories = cls.scan( - cls.version == version, - attributes_to_get=[ - "formatted_name", - "version", - "dimension", - "hierarchy", - "label", - ], + def __eq__(self, other): + return isinstance(other, CategoryModel) and ( + self.name == other.name + and self.version_hash == other.version_hash + and self.version == other.version + and self.formatted_name == other.formatted_name + and self.dimension == other.dimension + and self.hierarchy == other.hierarchy + and self.label == other.label ) - mapped_categories = defaultdict(list) - for category in categories: - mapped_categories[category.formatted_name].append( - cls._get_category_from_model(category) - ) +def _get_category_from_model(category): + return { + "label": category.label, + "dimension": category.dimension, + "hierarchy": category.hierarchy, + } + + +def get_category(name: str, version: str): + """ + Gets the category data for a particular category and EDAM version. + """ + + category = [] + start = time.perf_counter() + + for item in CategoryModel.query( + slugify(name), CategoryModel.version_hash.startswith(f"{version}:") + ): + category.append(_get_category_from_model(item)) + + print_perf_duration(start, f"CategoryModel.get_category({name})") + + return category + + +def get_all_categories(version: str): + """ + Gets all available category mappings from a particular EDAM version. + """ + + start = time.perf_counter() + categories = CategoryModel.scan( + CategoryModel.version == version, + attributes_to_get=[ + "formatted_name", + "version", + "dimension", + "hierarchy", + "label", + ], + ) + + mapped_categories = defaultdict(list) + + for category in categories: + mapped_categories[category.formatted_name].append( + _get_category_from_model(category) + ) - print_perf_duration(start, "CategoryModel.get_all_categories()") + print_perf_duration(start, "CategoryModel.get_all_categories()") - return mapped_categories + return mapped_categories diff --git a/data-workflows/categories/category_model.py b/data-workflows/categories/category_model.py index 60d7d26a5..7c66a686d 100644 --- a/data-workflows/categories/category_model.py +++ b/data-workflows/categories/category_model.py @@ -24,3 +24,14 @@ class Meta: hierarchy = ListAttribute() # List[str] label = UnicodeAttribute() last_updated_timestamp = NumberAttribute(default_for_new=get_current_timestamp) + + def __eq__(self, other): + return isinstance(other, CategoryModel) and ( + self.name == other.name + and self.version_hash == other.version_hash + and self.version == other.version + and self.formatted_name == other.formatted_name + and self.dimension == other.dimension + and self.hierarchy == other.hierarchy + and self.label == other.label + ) diff --git a/data-workflows/categories/processor.py b/data-workflows/categories/processor.py index addc865ff..db5859783 100644 --- a/data-workflows/categories/processor.py +++ b/data-workflows/categories/processor.py @@ -1,11 +1,10 @@ import os import time -import hashlib import logging from categories.category_model import CategoryModel +from categories.utils import hash_category from slugify import slugify -from typing import Dict from utils.env import get_required_env from utils.s3 import S3Client @@ -14,27 +13,6 @@ LOGGER = logging.getLogger() -def _hash_category(category: Dict[str, str]) -> str: - """ - Hashes a category object using the MD5 hash algorithm. This works by - creating a hash from the string and string array fields in the category - object. - """ - - label = category.get("label", "") - dimension = category.get("dimension") - hierarchy = category.get("hierarchy", []) - - category_hash = hashlib.new("md5") - category_hash.update(label.encode("utf-8")) - category_hash.update(dimension.encode("utf-8")) - - for value in hierarchy: - category_hash.update(value.encode("utf-8")) - - return category_hash.hexdigest() - - def seed_s3_categories_workflow(version: str, categories_path: str): """ Runs data workflow for populating the category dynamo table from an S3 @@ -75,7 +53,7 @@ def seed_s3_categories_workflow(version: str, categories_path: str): for category in categories: item = CategoryModel( name=slugify(name), - version_hash=f"{version}:{_hash_category(category)}", + version_hash=f"{version}:{hash_category(category)}", version=version, formatted_name=name, dimension=category.get("dimension", ""), diff --git a/data-workflows/categories/tests/__init__.py b/data-workflows/categories/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/data-workflows/categories/tests/test_processor.py b/data-workflows/categories/tests/test_processor.py new file mode 100644 index 000000000..4e6ba261a --- /dev/null +++ b/data-workflows/categories/tests/test_processor.py @@ -0,0 +1,193 @@ +from typing import Any, Dict +import boto3 +import json +import pytest + +from categories.category_model import CategoryModel +from categories.utils import hash_category +from conftest import create_dynamo_table +from moto import mock_dynamodb, mock_s3 +from unittest.mock import Mock + + +TEST_BUCKET = "test-bucket" +TEST_STACK_NAME = "None" +TEST_BUCKET_PATH = "test-path" +TEST_CATEGORY_PATH = "category/EDAM-BIOIMAGING/alpha06.json" +TEST_CATEGORY_VERSION = "EDAM-BIOIMAGING:alpha06" +TEST_CATEGORY_DATA = json.dumps( + { + "Foo": [ + { + "dimension": "dimension1", + "label": "label", + "hierarchy": ["1", "2", "3"], + }, + { + "dimension": "dimension2", + "label": "label", + "hierarchy": ["1", "2"], + }, + ], + "Foo Bar": [ + { + "dimension": "dimension1", + "label": "label", + "hierarchy": ["1", "2"], + }, + ], + } +) + + +def _get_version_hash(category: Dict[str, Any]): + hash = hash_category(category) + return f"{TEST_CATEGORY_VERSION}:{hash}" + + +class BatchWriteMock: + def __init__(self, commit=Mock(), save=Mock()): + self.commit = commit + self.save = save + + +@mock_s3 +class TestPluginManifest: + @pytest.fixture + def setup_env_variables(self, monkeypatch): + monkeypatch.setenv("BUCKET", TEST_BUCKET) + monkeypatch.setenv("BUCKET_PATH", TEST_BUCKET_PATH) + monkeypatch.setenv("STACK_NAME", TEST_STACK_NAME) + + @pytest.fixture() + def categories_table(self, aws_credentials, setup_env_variables): + from categories.category_model import CategoryModel + + with mock_dynamodb(): + yield create_dynamo_table(CategoryModel, "category") + + def _set_up_mock_batch_write(self, monkeypatch, commit=Mock(), save=Mock()): + self._mock_batch_write = BatchWriteMock(commit=commit, save=save) + monkeypatch.setattr( + CategoryModel, "batch_write", lambda: self._mock_batch_write + ) + + def _set_up_s3(self, bucket_name=TEST_BUCKET): + self._s3 = boto3.resource("s3") + bucket = self._s3.Bucket(bucket_name) + bucket.create() + + def _seed_data(self): + complete_path = f"{TEST_BUCKET_PATH}/{TEST_CATEGORY_PATH}" + self._s3.Object(TEST_BUCKET, complete_path).put( + Body=bytes(TEST_CATEGORY_DATA, "utf-8") + ) + + def test_write_category_data( + self, + aws_credentials, + setup_env_variables, + categories_table, + ): + self._set_up_s3() + self._seed_data() + + import categories.processor + from categories.category_model import CategoryModel + + categories.processor.seed_s3_categories_workflow( + TEST_CATEGORY_VERSION, TEST_CATEGORY_PATH + ) + + data = list(CategoryModel.scan()) + assert data == [ + CategoryModel( + name="foo", + version_hash=_get_version_hash( + { + "dimension": "dimension1", + "label": "label", + "hierarchy": ["1", "2", "3"], + } + ), + version=TEST_CATEGORY_VERSION, + formatted_name="Foo", + dimension="dimension1", + label="label", + hierarchy=["1", "2", "3"], + ), + CategoryModel( + name="foo", + version_hash=_get_version_hash( + { + "dimension": "dimension2", + "label": "label", + "hierarchy": ["1", "2"], + } + ), + version=TEST_CATEGORY_VERSION, + formatted_name="Foo", + dimension="dimension2", + label="label", + hierarchy=["1", "2"], + ), + CategoryModel( + name="foo-bar", + version=TEST_CATEGORY_VERSION, + version_hash=_get_version_hash( + { + "dimension": "dimension1", + "label": "label", + "hierarchy": ["1", "2"], + } + ), + formatted_name="Foo Bar", + dimension="dimension1", + label="label", + hierarchy=["1", "2"], + ), + ] + + def test_write_category_data_missing_params(self): + import categories.processor + + with pytest.raises(ValueError): + categories.processor.seed_s3_categories_workflow("", TEST_CATEGORY_PATH) + + with pytest.raises(ValueError): + categories.processor.seed_s3_categories_workflow(TEST_CATEGORY_VERSION, "") + + def test_write_category_data_missing_required_env(self): + import categories.processor + + with pytest.raises(ValueError): + categories.processor.seed_s3_categories_workflow( + TEST_CATEGORY_VERSION, TEST_CATEGORY_PATH + ) + + def test_write_category_data_s3_load_error( + self, aws_credentials, setup_env_variables, categories_table, monkeypatch + ): + self._set_up_s3() + self._set_up_mock_batch_write(monkeypatch) + + import categories.processor + + categories.processor.seed_s3_categories_workflow( + TEST_CATEGORY_VERSION, TEST_CATEGORY_PATH + ) + + self._mock_batch_write.save.assert_not_called() + + def test_write_category_data_batch_write_error( + self, aws_credentials, setup_env_variables, categories_table, monkeypatch + ): + self._set_up_s3() + self._set_up_mock_batch_write(monkeypatch, commit=Mock(side_effect=Exception())) + + import categories.processor + + with pytest.raises(Exception): + categories.processor.seed_s3_categories_workflow( + TEST_CATEGORY_VERSION, TEST_CATEGORY_PATH + ) diff --git a/data-workflows/categories/utils.py b/data-workflows/categories/utils.py new file mode 100644 index 000000000..30810d91a --- /dev/null +++ b/data-workflows/categories/utils.py @@ -0,0 +1,24 @@ +import hashlib + +from typing import Dict + + +def hash_category(category: Dict[str, str]) -> str: + """ + Hashes a category object using the MD5 hash algorithm. This works by + creating a hash from the string and string array fields in the category + object. + """ + + label = category.get("label", "") + dimension = category.get("dimension") + hierarchy = category.get("hierarchy", []) + + category_hash = hashlib.new("md5") + category_hash.update(label.encode("utf-8")) + category_hash.update(dimension.encode("utf-8")) + + for value in hierarchy: + category_hash.update(value.encode("utf-8")) + + return category_hash.hexdigest() diff --git a/data-workflows/conftest.py b/data-workflows/conftest.py index 031d562a3..5684dd279 100644 --- a/data-workflows/conftest.py +++ b/data-workflows/conftest.py @@ -1,9 +1,11 @@ +import boto3 import os - import pytest +from pynamodb.models import Model + -@pytest.fixture(scope='module') +@pytest.fixture(scope="module") def aws_credentials(): """Mocked AWS Credentials for moto.""" os.environ["AWS_ACCESS_KEY_ID"] = "testing" @@ -11,3 +13,10 @@ def aws_credentials(): os.environ["AWS_SECURITY_TOKEN"] = "testing" os.environ["AWS_SESSION_TOKEN"] = "testing" os.environ["AWS_DEFAULT_REGION"] = "us-east-1" + + +def create_dynamo_table(pynamo_ddb_model: Model, table_name: str): + pynamo_ddb_model.create_table() + return boto3.resource("dynamodb", region_name="us-west-2").Table( + f"None-{table_name}" + ) diff --git a/data-workflows/tests/test_handler.py b/data-workflows/tests/test_handler.py index c4ba60e99..a1bd00c96 100644 --- a/data-workflows/tests/test_handler.py +++ b/data-workflows/tests/test_handler.py @@ -5,47 +5,90 @@ import pytest import activity.processor +import categories.processor class TestHandle: - @pytest.fixture(autouse=True) def setup(self, monkeypatch): self._update_activity = Mock(spec=activity.processor.update_activity) - monkeypatch.setattr(activity.processor, 'update_activity', self._update_activity) + self._seed_s3_categories_workflow = Mock( + spec=categories.processor.seed_s3_categories_workflow + ) + monkeypatch.setattr( + activity.processor, "update_activity", self._update_activity + ) + monkeypatch.setattr( + categories.processor, + "seed_s3_categories_workflow", + self._seed_s3_categories_workflow, + ) - def _verify(self, activity_call_count: int = 0): + def _verify(self, activity_call_count: int = 0, s3_seed_call_count: int = 0): assert self._update_activity.call_count == activity_call_count + assert self._seed_s3_categories_workflow.call_count == s3_seed_call_count - @pytest.mark.parametrize('event_type', ["Activity", "AcTiviTy", "ACTIVITY"]) - def test_handle_event_type_in_different_case(self, event_type: str): + @pytest.mark.parametrize( + "event_type,activity_call_count,s3_seed_call_count", + [ + ("Activity", 1, 0), + ("AcTiviTy", 1, 0), + ("ACTIVITY", 1, 0), + ("seed-s3-categories", 0, 1), + ("SeEd-S3-cAtEgorIes", 0, 1), + ("SEED-S3-CATEGORIES", 0, 1), + ], + ) + def test_handle_event_type_in_different_case( + self, + event_type: str, + activity_call_count: int, + s3_seed_call_count: int, + ): from handler import handle - handle({'Records': [{'body': '{"type":"activity"}'}]}, None) - self._verify(activity_call_count=1) + handle({"Records": [{"body": '{"type":"' + event_type + '"}'}]}, None) + self._verify( + activity_call_count=activity_call_count, + s3_seed_call_count=s3_seed_call_count, + ) - def test_handle_activity_event_type(self): + def test_handle_event_type(self): from handler import handle - handle({'Records': [ - {'body': '{"type":"activity"}'}, - {'body': '{"type":"bar"}'} - ]}, None) - self._verify(activity_call_count=1) + + handle( + { + "Records": [ + {"body": '{"type":"activity"}'}, + {"body": '{"type":"seed-s3-categories"}'}, + {"body": '{"type":"bar"}'}, + ] + }, + None, + ) + self._verify(activity_call_count=1, s3_seed_call_count=1) def test_handle_invalid_json(self): with pytest.raises(JSONDecodeError): from handler import handle - handle({'Records': [{'body': '{"type:"activity"}'}]}, None) + + handle({"Records": [{"body": '{"type:"activity"}'}]}, None) + self._verify() self._verify() - @pytest.mark.parametrize('event', [ - ({'Records': [{'body': '{"type":"foo"}'}, {'body': '{"type":"bar"}'}]}), - ({'Records': [{'body': '{"type":"foo"}'}]}), - ({'Records': [{'foo': 'bar'}]}), - ({'Records': []}), - ({}), - ]) + @pytest.mark.parametrize( + "event", + [ + ({"Records": [{"body": '{"type":"foo"}'}, {"body": '{"type":"bar"}'}]}), + ({"Records": [{"body": '{"type":"foo"}'}]}), + ({"Records": [{"foo": "bar"}]}), + ({"Records": []}), + ({}), + ], + ) def test_handle_invalid_event(self, event: Dict): from handler import handle + handle(event, None) - self._verify() \ No newline at end of file + self._verify() + self._verify() diff --git a/data-workflows/tests/test_run_workflow.py b/data-workflows/tests/test_run_workflow.py new file mode 100644 index 000000000..c5c7ee272 --- /dev/null +++ b/data-workflows/tests/test_run_workflow.py @@ -0,0 +1,66 @@ +from json import JSONDecodeError +from typing import Dict +from unittest.mock import Mock + +import pytest + +import activity.processor +import categories.processor + + +class TestHandle: + @pytest.fixture(autouse=True) + def setup(self, monkeypatch): + self._update_activity = Mock(spec=activity.processor.update_activity) + self._seed_s3_categories_workflow = Mock( + spec=categories.processor.seed_s3_categories_workflow + ) + monkeypatch.setattr( + activity.processor, "update_activity", self._update_activity + ) + monkeypatch.setattr( + categories.processor, + "seed_s3_categories_workflow", + self._seed_s3_categories_workflow, + ) + + def _verify_update_activity(self, activity_call_count: int = 0): + assert self._update_activity.call_count == activity_call_count + + def _verify_s3_seed(self, s3_seed_call_count: int = 0): + assert self._seed_s3_categories_workflow.call_count == s3_seed_call_count + + @pytest.mark.parametrize("event_type", ["Activity", "AcTiviTy", "ACTIVITY"]) + def test_handle_event_type_in_different_case(self, event_type: str): + from run_workflow import run_workflow + + run_workflow( + { + "type": event_type, + } + ) + + self._verify_update_activity(activity_call_count=1) + + def test_handle_activity_event_type(self): + from run_workflow import run_workflow + + run_workflow({"type": "activity"}) + self._verify_update_activity(activity_call_count=1) + + def test_handle_seed_s3_categories_event_type(self): + from run_workflow import run_workflow + + run_workflow({"type": "seed-s3-categories"}) + self._verify_s3_seed(s3_seed_call_count=1) + + @pytest.mark.parametrize( + "event", + [{"type": "foo"}, {"type": "bar"}, {}], + ) + def test_handle_invalid_event(self, event: Dict): + from run_workflow import run_workflow + + run_workflow(event) + self._verify_update_activity() + self._verify_s3_seed()