Skip to content

Commit

Permalink
Seed categories script unit tests (#1049)
Browse files Browse the repository at this point in the history
* Add tests for test_handler

* Add tests for run_workflow

* Category model tests

* Refactor backend category unit tests to use moto

* Address nit feedback

* Fix install activity test typo

* Address test feedback
  • Loading branch information
codemonkey800 committed Jun 1, 2023
1 parent 3eef4be commit 0ae21ef
Show file tree
Hide file tree
Showing 12 changed files with 561 additions and 148 deletions.
3 changes: 2 additions & 1 deletion backend/api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from werkzeug import exceptions
from apig_wsgi import make_lambda_handler
from werkzeug.middleware.dispatcher import DispatcherMiddleware
from api.models.category import CategoryModel
from flask import Flask, Response, jsonify, render_template, request
from flask_githubapp.core import GitHubApp

Expand All @@ -12,7 +13,7 @@
from api.model import get_public_plugins, get_index, get_plugin, get_excluded_plugins, update_cache, \
move_artifact_to_s3, get_category_mapping, get_categories_mapping, get_manifest, update_activity_data, \
get_metrics_for_plugin
from api.models.category import CategoryModel
from api.models import category as categories
from api.shield import get_shield
from utils.utils import send_alert, reformat_ssh_key_to_pem_bytes

Expand Down
167 changes: 122 additions & 45 deletions backend/api/models/_tests/test_category.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,96 @@
from unittest.mock import Mock
from typing import List
import pytest

from api.models._tests.conftest import create_dynamo_table
from moto import mock_dynamodb

TEST_BUCKET = "test-bucket"
TEST_STACK_NAME = "None"
TEST_BUCKET_PATH = "test-path"
TEST_CATEGORY_PATH = "category/EDAM-BIOIMAGING/alpha06.json"
TEST_CATEGORY_VERSION = "EDAM-BIOIMAGING:alpha06"


class TestCategory:
def test_get_category_has_result(self, monkeypatch):
mock_category = Mock(
return_value=[
Mock(label="label1", dimension="dimension1", hierarchy=["hierarchy1"]),
Mock(
label="label2",
dimension="dimension2",
hierarchy=["hierarchy1", "hierarchy2"],
),
]
)
@pytest.fixture
def setup_env_variables(self, monkeypatch):
monkeypatch.setenv("BUCKET", TEST_BUCKET)
monkeypatch.setenv("BUCKET_PATH", TEST_BUCKET_PATH)

@pytest.fixture()
def categories_table(self, aws_credentials, setup_env_variables):
from api.models.category import CategoryModel

monkeypatch.setattr(CategoryModel, "query", mock_category)
actual = CategoryModel.get_category("name", "version")
with mock_dynamodb():
yield create_dynamo_table(CategoryModel, "category")

def _get_version_hash(self, hash: str) -> str:
return f"{TEST_CATEGORY_VERSION}:{hash}"

def _put_item(
self,
table,
name: str,
version: str,
version_hash: str,
formatted_name: str,
dimension: str,
label: str,
hierarchy: List[str],
):
item = {
"name": name,
"version": version,
"version_hash": self._get_version_hash(version_hash),
"formatted_name": formatted_name,
"label": label,
"dimension": dimension,
"hierarchy": hierarchy,
}
table.put_item(Item=item)

def _seed_data(self, table):
self._put_item(
table,
name="name1",
version=TEST_CATEGORY_VERSION,
version_hash="hash1",
formatted_name="Name1",
dimension="dimension1",
label="label1",
hierarchy=["hierarchy1"],
)

self._put_item(
table,
name="name1",
version=TEST_CATEGORY_VERSION,
version_hash="hash2",
formatted_name="Name1",
dimension="dimension2",
label="label2",
hierarchy=["hierarchy1", "hierarchy2"],
)

self._put_item(
table,
name="name2",
version=TEST_CATEGORY_VERSION,
version_hash="hash3",
formatted_name="Name2",
dimension="dimension3",
label="label3",
hierarchy=["hierarchy3"],
)

def test_get_category_has_result(
self, aws_credentials, setup_env_variables, categories_table
):
self._seed_data(categories_table)

from api.models.category import get_category

actual = get_category("name1", TEST_CATEGORY_VERSION)
expected = [
{
"label": "label1",
Expand All @@ -33,40 +106,31 @@ def test_get_category_has_result(self, monkeypatch):

assert actual == expected

def test_get_all_categories(self, monkeypatch):
mock_category = Mock(
return_value=[
Mock(
formatted_name="name1",
version="version",
label="label1",
dimension="dimension1",
hierarchy=["hierarchy1"],
),
Mock(
formatted_name="name1",
version="version",
label="label2",
dimension="dimension2",
hierarchy=["hierarchy1", "hierarchy2"],
),
Mock(
formatted_name="name2",
version="version",
label="label3",
dimension="dimension3",
hierarchy=["hierarchy3"],
),
]
)
def test_get_category_has_no_result(
self, aws_credentials, setup_env_variables, categories_table
):
self._seed_data(categories_table)

from api.models.category import CategoryModel
from api.models.category import get_category

actual = get_category("foobar", TEST_CATEGORY_VERSION)
expected = []

assert actual == expected

def test_get_all_categories(
self,
aws_credentials,
setup_env_variables,
categories_table,
):
self._seed_data(categories_table)

monkeypatch.setattr(CategoryModel, "scan", mock_category)
actual = CategoryModel.get_all_categories("version")
from api.models.category import get_all_categories

actual = get_all_categories(TEST_CATEGORY_VERSION)
expected = {
"name1": [
"Name1": [
{
"label": "label1",
"dimension": "dimension1",
Expand All @@ -78,7 +142,7 @@ def test_get_all_categories(self, monkeypatch):
"hierarchy": ["hierarchy1", "hierarchy2"],
},
],
"name2": [
"Name2": [
{
"label": "label3",
"dimension": "dimension3",
Expand All @@ -88,3 +152,16 @@ def test_get_all_categories(self, monkeypatch):
}

assert actual == expected

def test_get_all_categories_empty_table(
self,
aws_credentials,
setup_env_variables,
categories_table,
):
from api.models.category import get_all_categories

actual = get_all_categories(TEST_CATEGORY_VERSION)
expected = {}

assert actual == expected
2 changes: 1 addition & 1 deletion backend/api/models/_tests/test_install_activity.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _put_item(self, table, granularity, timestamp, install_count, is_total=None,
'plugin_name': plugin,
'type_timestamp': self._to_type_timestamp(granularity, timestamp),
'install_count': install_count,
'granularity': granularity,
'type': granularity,
'timestamp': to_millis(timestamp),
}
if is_total:
Expand Down
117 changes: 64 additions & 53 deletions backend/api/models/category.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
import os
import time

from slugify import slugify
from api.models.helper import set_ddb_metadata
from collections import defaultdict
from pynamodb.attributes import ListAttribute, NumberAttribute, UnicodeAttribute
from pynamodb.models import Model
from slugify import slugify
from utils.time import get_current_timestamp, print_perf_duration


@set_ddb_metadata("category")
class CategoryModel(Model):
class Meta:
host = os.getenv('LOCAL_DYNAMO_HOST')
region = os.environ.get("AWS_REGION", "us-west-2")
table_name = f"{os.environ.get('STACK_NAME')}-category"
pass

name = UnicodeAttribute(hash_key=True)
version_hash = UnicodeAttribute(range_key=True)
Expand All @@ -23,57 +23,68 @@ class Meta:
label = UnicodeAttribute()
last_updated_timestamp = NumberAttribute(default_for_new=get_current_timestamp)

@classmethod
def _get_category_from_model(cls, category):
return {
"label": category.label,
"dimension": category.dimension,
"hierarchy": category.hierarchy,
}

@classmethod
def get_category(cls, name: str, version: str):
"""
Gets the category data for a particular category and EDAM version.
"""

category = []
start = time.perf_counter()

for item in cls.query(
slugify(name), cls.version_hash.startswith(f"{version}:")
):
category.append(cls._get_category_from_model(item))

print_perf_duration(start, f"CategoryModel.get_category({name})")

return category

@classmethod
def get_all_categories(cls, version: str):
"""
Gets all available category mappings from a particular EDAM version.
"""

start = time.perf_counter()
categories = cls.scan(
cls.version == version,
attributes_to_get=[
"formatted_name",
"version",
"dimension",
"hierarchy",
"label",
],
def __eq__(self, other):
return isinstance(other, CategoryModel) and (
self.name == other.name
and self.version_hash == other.version_hash
and self.version == other.version
and self.formatted_name == other.formatted_name
and self.dimension == other.dimension
and self.hierarchy == other.hierarchy
and self.label == other.label
)

mapped_categories = defaultdict(list)

for category in categories:
mapped_categories[category.formatted_name].append(
cls._get_category_from_model(category)
)
def _get_category_from_model(category):
return {
"label": category.label,
"dimension": category.dimension,
"hierarchy": category.hierarchy,
}


def get_category(name: str, version: str):
"""
Gets the category data for a particular category and EDAM version.
"""

category = []
start = time.perf_counter()

for item in CategoryModel.query(
slugify(name), CategoryModel.version_hash.startswith(f"{version}:")
):
category.append(_get_category_from_model(item))

print_perf_duration(start, f"CategoryModel.get_category({name})")

return category


def get_all_categories(version: str):
"""
Gets all available category mappings from a particular EDAM version.
"""

start = time.perf_counter()
categories = CategoryModel.scan(
CategoryModel.version == version,
attributes_to_get=[
"formatted_name",
"version",
"dimension",
"hierarchy",
"label",
],
)

mapped_categories = defaultdict(list)

for category in categories:
mapped_categories[category.formatted_name].append(
_get_category_from_model(category)
)

print_perf_duration(start, "CategoryModel.get_all_categories()")
print_perf_duration(start, "CategoryModel.get_all_categories()")

return mapped_categories
return mapped_categories
11 changes: 11 additions & 0 deletions data-workflows/categories/category_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,14 @@ class Meta:
hierarchy = ListAttribute() # List[str]
label = UnicodeAttribute()
last_updated_timestamp = NumberAttribute(default_for_new=get_current_timestamp)

def __eq__(self, other):
return isinstance(other, CategoryModel) and (
self.name == other.name
and self.version_hash == other.version_hash
and self.version == other.version
and self.formatted_name == other.formatted_name
and self.dimension == other.dimension
and self.hierarchy == other.hierarchy
and self.label == other.label
)

0 comments on commit 0ae21ef

Please sign in to comment.