From 9c1fea35bcbf37f387d630431b8713eef8bd74d7 Mon Sep 17 00:00:00 2001 From: peterdudfield Date: Tue, 22 Mar 2022 18:38:00 +0000 Subject: [PATCH 1/2] add method to read all lcations + tests --- nowcasting_datamodel/read/read.py | 28 ++++++++++++++++++++++++++++ tests/test_read.py | 10 ++++++++++ 2 files changed, 38 insertions(+) diff --git a/nowcasting_datamodel/read/read.py b/nowcasting_datamodel/read/read.py index 79433c0e..79ffbd52 100644 --- a/nowcasting_datamodel/read/read.py +++ b/nowcasting_datamodel/read/read.py @@ -175,6 +175,34 @@ def get_location(session: Session, gsp_id: int) -> LocationSQL: return location + +def get_all_location(session: Session, gsp_ids: List[int] = None) -> List[LocationSQL]: + """ + Get all location object from gsp id + + :param session: database session + :param gsp_ids: list of gsp id of the location + + return: List of GSP locations + + """ + + # start main query + query = session.query(LocationSQL) + query = query.distinct(LocationSQL.gsp_id) + + # filter on gsp_id + if gsp_ids is not None: + query = query.filter(LocationSQL.gsp_id.in_(gsp_ids)) + + query = query.order_by(LocationSQL.gsp_id) + + # get all results + locations = query.all() + + return locations + + def get_model(session: Session, name: str, version: str) -> MLModelSQL: """ Get model object from name and version diff --git a/tests/test_read.py b/tests/test_read.py index d8364fd1..d61057cc 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -13,12 +13,22 @@ get_latest_national_forecast, get_model, get_pv_system, + get_all_location ) from nowcasting_datamodel.save import save_pv_system logger = logging.getLogger(__name__) +def test_get_all_location(db_session): + + db_session.add(LocationSQL(label='GSP_1',gsp_id=1)) + db_session.add(LocationSQL(label='GSP_2', gsp_id=2)) + + locations = get_all_location(session=db_session) + assert len(locations) == 2 + + def test_get_model(db_session): model_read_1 = get_model(session=db_session, name="test_name", version="9.9.9") From 617dccee86f2eb439900a3b3b87a73a220666f9b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 22 Mar 2022 18:38:56 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- nowcasting_datamodel/read/read.py | 1 - tests/test_read.py | 6 +++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/nowcasting_datamodel/read/read.py b/nowcasting_datamodel/read/read.py index 79ffbd52..b613f573 100644 --- a/nowcasting_datamodel/read/read.py +++ b/nowcasting_datamodel/read/read.py @@ -175,7 +175,6 @@ def get_location(session: Session, gsp_id: int) -> LocationSQL: return location - def get_all_location(session: Session, gsp_ids: List[int] = None) -> List[LocationSQL]: """ Get all location object from gsp id diff --git a/tests/test_read.py b/tests/test_read.py index d61057cc..954b9f3b 100644 --- a/tests/test_read.py +++ b/tests/test_read.py @@ -8,12 +8,12 @@ from nowcasting_datamodel.models import Forecast, ForecastValue, LocationSQL, MLModel, PVSystem from nowcasting_datamodel.read.read import ( get_all_gsp_ids_latest_forecast, + get_all_location, get_forecast_values, get_latest_forecast, get_latest_national_forecast, get_model, get_pv_system, - get_all_location ) from nowcasting_datamodel.save import save_pv_system @@ -22,8 +22,8 @@ def test_get_all_location(db_session): - db_session.add(LocationSQL(label='GSP_1',gsp_id=1)) - db_session.add(LocationSQL(label='GSP_2', gsp_id=2)) + db_session.add(LocationSQL(label="GSP_1", gsp_id=1)) + db_session.add(LocationSQL(label="GSP_2", gsp_id=2)) locations = get_all_location(session=db_session) assert len(locations) == 2