Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement simple filters #64

Merged
merged 9 commits into from
Mar 14, 2024
21 changes: 20 additions & 1 deletion ixmp4/db/filters.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from types import UnionType
from typing import Any, ClassVar, Optional, Union, get_args, get_origin

from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
from pydantic.fields import FieldInfo

from ixmp4 import db
Expand Down Expand Up @@ -255,6 +255,11 @@ class BaseFilter(BaseModel, metaclass=FilterMeta):
)
sqla_model: ClassVar[type | None] = None

@model_validator(mode="before")
@classmethod
def expand_simple_filters(cls, v):
return expand_simple_filter(v)

def __init__(self, **data: Any) -> None:
try:
super().__init__(**data)
Expand Down Expand Up @@ -299,3 +304,17 @@ def apply(self, exc: db.sql.Select, model, session) -> db.sql.Select:
column = getattr(model, sqla_column, None)
exc = filter_func(exc, column, value, session=session)
return exc.distinct()


def expand_simple_filter(value):
if isinstance(value, str):
if "*" in value:
return dict(name__like=value)
else:
return dict(name=value)
elif isinstance(value, list):
if any(["*" in v for v in value]):
raise NotImplementedError("Filter by list with wildcard is not implemented")
return dict(name__in=value)

return value
19 changes: 15 additions & 4 deletions tests/core/test_iamc.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,17 +76,28 @@ def test_unit_as_string_dimensionless_raises(test_mp, test_data_annual, request)


@all_platforms
def test_run_tabulate_with_filter_raw(test_mp, test_data_annual, request):
@pytest.mark.parametrize(
"filters",
(
dict(variable={"name": "Primary Energy"}),
dict(variable={"name": "Primary Energy"}, unit={"name": "EJ/yr"}),
dict(variable={"name__like": "* Energy"}, unit={"name": "EJ/yr"}),
dict(variable={"name__in": ["Primary Energy", "Some Other Variable"]}),
dict(variable="Primary Energy"),
dict(variable="Primary Energy", unit="EJ/yr"),
dict(variable="* Energy", unit="EJ/yr"),
dict(variable=["Primary Energy", "Some Other Variable"]),
),
)
def test_run_tabulate_with_filter_raw(test_mp, test_data_annual, request, filters):
test_mp = request.getfixturevalue(test_mp)
# Filter run directly
add_regions(test_mp, test_data_annual["region"].unique())
add_units(test_mp, test_data_annual["unit"].unique())

run = test_mp.runs.create("Model", "Scenario")
run.iamc.add(test_data_annual, type=DataPoint.Type.ANNUAL)
obs = run.iamc.tabulate(
raw=True, variable={"name": "Primary Energy"}, unit={"name": "EJ/yr"}
).drop(["id", "type"], axis=1)
obs = run.iamc.tabulate(raw=True, **filters).drop(["id", "type"], axis=1)
exp = test_data_annual[test_data_annual.variable == "Primary Energy"]
assert_unordered_equality(obs, exp, check_like=True)

Expand Down
1 change: 0 additions & 1 deletion tests/data/test_iamc_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ def test_filtering(test_mp, filter, exp_filter, request):
@pytest.mark.parametrize(
"filter",
[
{"unit": "test"},
{"dne": {"dne": "test"}},
{"region": {"dne": "test"}},
{"region": {"name__in": False}},
Expand Down
Loading