Skip to content

Commit

Permalink
Tests for eWaterCycleModel (#378)
Browse files Browse the repository at this point in the history
* Replace test_abstract with base/test_model

* Add more tests for base.model module

* Dont depend on fake models from grpc4bmi

* Make type of eWaterCycleModel.parameters property an ItemsView

To be more inline with pymt

See #365 (comment)

* Use local var instead of prop

* Fix more tests

* More and better tests

* Use public interface where possible

* Update model.py

* Sort imports
  • Loading branch information
sverhoeven committed Oct 5, 2023
1 parent deba2df commit ad99ebf
Show file tree
Hide file tree
Showing 18 changed files with 633 additions and 851 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ Formatted as described on [https://keepachangelog.com](https://keepachangelog.co
- Forcing ((#365)[https://github.com/eWaterCycle/ewatercycle/pull/365]):
- Instead of modifying an existing recipe now builds a ESMValTool recipe from scratch using a fluent interface
- DefaultForcing has overridable class methods for each step of the forcing generation process (build_recipe, run_recipe, recipe_output_to_forcing_arguments).
- eWaterCycleModel.parameters property type is ItemsView instead of dict.

### Deprecated

Expand Down
32 changes: 17 additions & 15 deletions src/ewatercycle/base/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import datetime
import inspect
import logging
from collections.abc import ItemsView
from datetime import timezone
from pathlib import Path
from typing import Annotated, Any, Iterable, Optional, Type, cast
Expand All @@ -14,7 +15,6 @@
import yaml
from cftime import num2pydate
from grpc4bmi.bmi_optionaldest import OptionalDestBmi
from grpc4bmi.reserve import reserve_values, reserve_values_at_indices
from pydantic import (
BaseModel,
BeforeValidator,
Expand Down Expand Up @@ -88,9 +88,9 @@ def _make_bmi_instance(self) -> OptionalDestBmi:
# where it is {}.items()
# TODO is this OK?
@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
"""Display the model's parameters and their values."""
return {}
return {}.items()

def setup(self, *, cfg_dir: str | None = None, **kwargs) -> tuple[str, str]:
"""Perform model setup.
Expand Down Expand Up @@ -136,9 +136,10 @@ def _make_cfg_dir(self, cfg_dir: Optional[str] = None, **kwargs) -> Path:
def _make_cfg_file(self, **kwargs):
"""Create new config file and return its path."""
cfg_file = self._cfg_dir / "config.yaml"
self.parameters.update(**kwargs)
myparameters = dict(list(self.parameters))
myparameters.update(**kwargs)
with cfg_file.open(mode="w") as file:
yaml.dump({k: v for k, v in self.parameters}, file)
yaml.dump({k: v for k, v in myparameters}, file)

return cfg_file

Expand Down Expand Up @@ -167,7 +168,10 @@ def initialize(self, config_file: str) -> None:
self._bmi.initialize(config_file)

def finalize(self) -> None:
"""Perform tear-down tasks for the model."""
"""Perform tear-down tasks for the model.
After finalization, the model should not be used anymore.
"""
self._bmi.finalize()
del self._bmi

Expand All @@ -181,10 +185,7 @@ def get_value(self, name: str) -> np.ndarray:
Args:
name: Name of variable
"""
if isinstance(self._bmi, OptionalDestBmi):
return self._bmi.get_value(name)
dest = reserve_values(self._bmi, name)
return self._bmi.get_value(name, dest)
return self._bmi.get_value(name)

def get_value_at_coords(
self, name, lat: Iterable[float], lon: Iterable[float]
Expand All @@ -198,10 +199,7 @@ def get_value_at_coords(
"""
indices = self._coords_to_indices(name, lat, lon)
indices = np.array(indices)
if isinstance(self._bmi, OptionalDestBmi):
return self._bmi.get_value_at_indices(name, indices)
dest = reserve_values_at_indices(self._bmi, name, indices)
return self._bmi.get_value_at_indices(name, dest, indices)
return self._bmi.get_value_at_indices(name, indices)

def set_value(self, name: str, value: np.ndarray) -> None:
"""Specify a new value for a model variable.
Expand Down Expand Up @@ -273,9 +271,9 @@ def get_value_as_xarray(self, name: str) -> xr.DataArray:
data=np.reshape(
self.get_value(name),
(
1,
shape[0],
shape[1],
1,
),
),
coords={
Expand Down Expand Up @@ -371,6 +369,10 @@ def end_time_as_datetime(self) -> datetime.datetime:
@property
def time_as_datetime(self) -> datetime.datetime:
"""Current time of the model as a datetime object'."""
# TODO some bmi implementations like Wflow.jl returns 'd'
# which can not be converted to a datetime object
# as nupmy2date expects a
# `<time units> since <reference time>` formatted string
return num2pydate(
self._bmi.get_current_time(),
self._bmi.get_time_units(),
Expand Down
6 changes: 3 additions & 3 deletions src/ewatercycle/plugins/hype/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import types
from pathlib import Path
from typing import Any, Iterable, Optional
from typing import Any, ItemsView, Iterable, Optional

import bmipy
import xarray as xr
Expand Down Expand Up @@ -133,7 +133,7 @@ def get_latlon_grid(self, name: str) -> tuple[Any, Any, Any]:
raise NotImplementedError("Hype coordinates cannot be mapped to grid")

@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
"""List the parameters for this model.
Exposed Lisflood parameters:
Expand All @@ -157,7 +157,7 @@ def parameters(self) -> dict[str, Any]:
"crit_time": _get_hype_time(
_get_code_in_cfg(self._config, "cdate")
).strftime(ISO_TIMEFMT),
}
}.items()

def _coords_to_indices(
self, name: str, lat: Iterable[float], lon: Iterable[float]
Expand Down
6 changes: 3 additions & 3 deletions src/ewatercycle/plugins/lisflood/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from pathlib import Path
from typing import Any
from typing import Any, ItemsView

from pydantic import PrivateAttr, model_validator

Expand Down Expand Up @@ -137,7 +137,7 @@ def _make_cfg_file(self, **kwargs) -> Path:
return lisflood_file

@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
"""List the parameters for this model.
Exposed Lisflood parameters:
Expand All @@ -158,7 +158,7 @@ def parameters(self) -> dict[str, Any]:
"MaskMap": self._get_textvar_value("MaskMap"),
"start_time": get_time(self.forcing.start_time).strftime(ISO_TIMEFMT),
"end_time": get_time(self.forcing.end_time).strftime(ISO_TIMEFMT),
}
}.items()

def finalize(self) -> None:
"""Perform tear-down tasks for the model."""
Expand Down
10 changes: 5 additions & 5 deletions src/ewatercycle/plugins/marrmot/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any, ItemsView, Literal

import scipy.io as sio
from pydantic import PrivateAttr, model_validator
Expand Down Expand Up @@ -136,7 +136,7 @@ def _make_cfg_file(self, **kwargs) -> Path:
return config_file

@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
"""List MarrmotM01's parameters and their values.
Model parameters:
Expand All @@ -157,7 +157,7 @@ def parameters(self) -> dict[str, Any]:
"solver": get_solver(self._config),
"start time": get_marrmot_time(self._config, "start"),
"end time": get_marrmot_time(self._config, "end"),
}
}.items()


M14_PARAMS = (
Expand Down Expand Up @@ -272,7 +272,7 @@ def _make_cfg_file(self, **kwargs) -> Path:
return config_file

@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
"""List the parameters for this model.
Exposed Marrmot M14 parameters:
Expand Down Expand Up @@ -304,4 +304,4 @@ def parameters(self) -> dict[str, Any]:
"end time": get_marrmot_time(self._config, "end"),
}
)
return pars
return pars.items()
6 changes: 3 additions & 3 deletions src/ewatercycle/plugins/pcrglobwb/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
from os import PathLike
from typing import Any, Optional
from typing import Any, ItemsView, Optional

import bmipy
import numpy as np
Expand Down Expand Up @@ -97,7 +97,7 @@ def _initialize_config(self: "PCRGlobWB") -> "PCRGlobWB":
return self

@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
"""List the parameters for this model."""
return {
"start_time": f"{self._config.get('globalOptions', 'startTime')}T00:00:00Z",
Expand All @@ -106,7 +106,7 @@ def parameters(self) -> dict[str, Any]:
"max_spinups_in_years": self._config.get(
"globalOptions", "maxSpinUpsInYears"
),
}
}.items()

def _make_cfg_file(self, **kwargs):
self._update_config(**kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/ewatercycle/plugins/wflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import logging
import shutil
from pathlib import Path
from typing import Any, Optional
from typing import Any, ItemsView, Optional

import bmipy
import numpy as np
Expand Down Expand Up @@ -142,11 +142,11 @@ def _make_cfg_dir(self, cfg_dir: Optional[str] = None, **kwargs) -> Path:
return self._work_dir

@property
def parameters(self) -> dict[str, Any]:
def parameters(self) -> ItemsView[str, Any]:
return {
"start_time": _wflow_to_iso(self._config.get("run", "starttime")),
"end_time": _wflow_to_iso(self._config.get("run", "endtime")),
}
}.items()


def _wflow_to_iso(time):
Expand Down
134 changes: 134 additions & 0 deletions src/ewatercycle/testing/fake_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""Fake BMI models."""
from unittest.mock import Mock

import numpy as np
from bmipy import Bmi

Expand Down Expand Up @@ -135,3 +137,135 @@ def get_grid_nodes_per_face(

def get_grid_face_edges(self, grid: int, face_edges: np.ndarray) -> np.ndarray:
raise self.exc


class NotImplementedModel(FailingModel):
def __init__(self, exc=NotImplementedError()):
super().__init__(exc=exc)


class WithMocksMixin:
"""Mock the bmi methods that return None and have no getter companion.
Use `instance.mock.<method name>.assert_called_once_with()`
to check if the method is called correctly.
"""

def __init__(self):
self.mock = Mock()

def initialize(self, config_file: str) -> None:
self.mock.initialize(config_file)

def finalize(self):
self.mock.finalize()


class WithDailyMixin:
"""Mock the bmi methods that deal wtih time.
Behaves like a daily model which started since epoch.
"""

def __init__(self) -> None:
self.time = 0.0

def update(self):
self.time = self.time + self.get_time_step()

def get_current_time(self):
return self.time

def get_start_time(self):
return 0.0

def get_end_time(self):
return 100.0

def get_time_step(self):
return 1.0

def get_time_units(self):
return "days since 1970-01-01"


class DummyModelWith2DRectilinearGrid(
WithMocksMixin, WithDailyMixin, NotImplementedModel
):
def __init__(self):
super().__init__()
# not sure why extra call to init is needed,
# but without the self.time is not initialized
WithMocksMixin.__init__(self)
WithDailyMixin.__init__(self)
self.dtype = np.dtype("float32")
self.value = np.array(
[
1.1,
2.2,
3.3,
4.4,
5.5,
6.6,
7.7,
8.8,
9.9,
10.1,
11.1,
12.1,
],
dtype=self.dtype,
)

def get_output_var_names(self) -> tuple[str]:
return ("plate_surface__temperature",)

def get_var_type(self, name):
return str(self.dtype)

def get_var_grid(self, name):
return 0

def get_var_units(self, name):
return "K"

def get_var_itemsize(self, name):
return self.dtype.itemsize

def get_var_nbytes(self, name):
return self.dtype.itemsize * self.value.size

def get_grid_type(self, grid):
return "rectilinear"

def get_grid_size(self, grid):
return 12 # 4 longs * 3 lats

def get_grid_rank(self, grid: int) -> int:
return 2

def get_grid_shape(self, grid: int, shape: np.ndarray) -> np.ndarray:
np.copyto(src=[3, 4], dst=shape)
return shape

def get_value(self, name, dest):
np.copyto(src=self.value, dst=dest)
return dest

def get_value_at_indices(self, name, dest, inds):
np.copyto(src=self.value[inds], dst=dest)
return dest

def set_value(self, name, src):
self.value[:] = src

def set_value_at_indices(self, name, inds, src):
self.value[inds] = src

def get_grid_x(self, grid: int, x: np.ndarray) -> np.ndarray:
np.copyto(src=[0.1, 0.2, 0.3, 0.4], dst=x)
return x

def get_grid_y(self, grid: int, y: np.ndarray) -> np.ndarray:
np.copyto(src=[1.1, 1.2, 1.3], dst=y)
return y

0 comments on commit ad99ebf

Please sign in to comment.