Skip to content

Commit

Permalink
Implement simple filters (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielhuppmann committed Mar 14, 2024
1 parent b33a700 commit 3f45bba
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 6 deletions.
21 changes: 20 additions & 1 deletion ixmp4/db/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from types import UnionType
from typing import Any, ClassVar, Optional, Union, get_args, get_origin

from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
from pydantic.fields import FieldInfo

from ixmp4 import db
Expand Down Expand Up @@ -255,6 +255,11 @@ class BaseFilter(BaseModel, metaclass=FilterMeta):
)
sqla_model: ClassVar[type | None] = None

@model_validator(mode="before")
@classmethod
def expand_simple_filters(cls, v):
return expand_simple_filter(v)

def __init__(self, **data: Any) -> None:
try:
super().__init__(**data)
Expand Down Expand Up @@ -299,3 +304,17 @@ def apply(self, exc: db.sql.Select, model, session) -> db.sql.Select:
column = getattr(model, sqla_column, None)
exc = filter_func(exc, column, value, session=session)
return exc.distinct()


def expand_simple_filter(value):
if isinstance(value, str):
if "*" in value:
return dict(name__like=value)
else:
return dict(name=value)
elif isinstance(value, list):
if any(["*" in v for v in value]):
raise NotImplementedError("Filter by list with wildcard is not implemented")
return dict(name__in=value)

return value
19 changes: 15 additions & 4 deletions tests/core/test_iamc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,28 @@ def test_unit_as_string_dimensionless_raises(test_mp, test_data_annual, request)


@all_platforms
def test_run_tabulate_with_filter_raw(test_mp, test_data_annual, request):
@pytest.mark.parametrize(
"filters",
(
dict(variable={"name": "Primary Energy"}),
dict(variable={"name": "Primary Energy"}, unit={"name": "EJ/yr"}),
dict(variable={"name__like": "* Energy"}, unit={"name": "EJ/yr"}),
dict(variable={"name__in": ["Primary Energy", "Some Other Variable"]}),
dict(variable="Primary Energy"),
dict(variable="Primary Energy", unit="EJ/yr"),
dict(variable="* Energy", unit="EJ/yr"),
dict(variable=["Primary Energy", "Some Other Variable"]),
),
)
def test_run_tabulate_with_filter_raw(test_mp, test_data_annual, request, filters):
test_mp = request.getfixturevalue(test_mp)
# Filter run directly
add_regions(test_mp, test_data_annual["region"].unique())
add_units(test_mp, test_data_annual["unit"].unique())

run = test_mp.runs.create("Model", "Scenario")
run.iamc.add(test_data_annual, type=DataPoint.Type.ANNUAL)
obs = run.iamc.tabulate(
raw=True, variable={"name": "Primary Energy"}, unit={"name": "EJ/yr"}
).drop(["id", "type"], axis=1)
obs = run.iamc.tabulate(raw=True, **filters).drop(["id", "type"], axis=1)
exp = test_data_annual[test_data_annual.variable == "Primary Energy"]
assert_unordered_equality(obs, exp, check_like=True)

Expand Down
1 change: 0 additions & 1 deletion tests/data/test_iamc_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def test_filtering(test_mp, filter, exp_filter, request):
@pytest.mark.parametrize(
"filter",
[
{"unit": "test"},
{"dne": {"dne": "test"}},
{"region": {"dne": "test"}},
{"region": {"name__in": False}},
Expand Down

0 comments on commit 3f45bba

Please sign in to comment.