Skip to content

Commit

Permalink
Refactor stateful test so adding deletion is easy
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Feb 9, 2024
1 parent 1bc3524 commit 0fb444f
Showing 1 changed file with 90 additions and 76 deletions.
166 changes: 90 additions & 76 deletions tests/unit_tests/storage/test_local_storage.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import shutil
import tempfile
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Dict, List
Expand Down Expand Up @@ -211,9 +210,17 @@ def observation_dicts(draw):
)


@dataclass
class Ensemble:
uuid: UUID
parameter_values: Dict[str, Any] = field(default_factory=dict)
failure_messages: Dict[int, str] = field(default_factory=dict)


@dataclass
class Experiment:
ensembles: Dict[UUID, Dict[str, Any]] = field(default_factory=dict)
uuid: UUID
ensembles: Dict[UUID, Ensemble] = field(default_factory=dict)
parameters: List[ParameterConfig] = field(default_factory=list)
responses: List[ResponseConfig] = field(default_factory=list)
observations: Dict[str, xr.Dataset] = field(default_factory=dict)
Expand Down Expand Up @@ -245,13 +252,11 @@ def __init__(self):
super().__init__()
self.tmpdir = tempfile.mkdtemp()
self.storage = open_storage(self.tmpdir + "/storage/", "w")
self.experiments = defaultdict(Experiment)
self.failure_messages = {}
self.model = {}
assert list(self.storage.ensembles) == []

experiment_ids = Bundle("experiments")
ensemble_ids = Bundle("ensembles")
failures = Bundle("failures")
experiments = Bundle("experiments")
ensembles = Bundle("ensembles")
field_list = Bundle("field_list")
grid = Bundle("grid")

Expand Down Expand Up @@ -284,7 +289,7 @@ def reopen(self):
assert cases == sorted(e.id for e in self.storage.ensembles)

@rule(
target=experiment_ids,
target=experiments,
parameters=st.one_of(parameter_configs, field_list),
responses=response_configs,
obs=observations,
Expand All @@ -298,29 +303,31 @@ def create_experiment(
experiment_id = self.storage.create_experiment(
parameters=parameters, responses=responses, observations=obs.datasets
).id
self.experiments[experiment_id].parameters = parameters
self.experiments[experiment_id].responses = responses
self.experiments[experiment_id].observations = obs.datasets
model_experiment = Experiment(experiment_id)
model_experiment.parameters = parameters
model_experiment.responses = responses
model_experiment.observations = obs.datasets

# Ensure that there is at least one ensemble in the experiment
# to avoid https://github.com/equinor/ert/issues/7040
ensemble = self.storage.create_ensemble(experiment_id, ensemble_size=1)
self.experiments[experiment_id].ensembles[ensemble.id] = {}
model_experiment.ensembles[ensemble.id] = Ensemble(ensemble.id)

return experiment_id
self.model[model_experiment.uuid] = model_experiment

return model_experiment

@rule(
ensemble_id=ensemble_ids,
model_ensemble=ensembles,
field_data=grid.flatmap(lambda g: arrays(np.float32, shape=g[1].shape)),
)
def save_field(self, ensemble_id: UUID, field_data):
ensemble = self.storage.get_ensemble(ensemble_id)
experiment_id = ensemble.experiment_id
parameters = self.experiments[experiment_id].parameters
def save_field(self, model_ensemble: Ensemble, field_data):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
parameters = model_ensemble.parameter_values.values()
fields = [p for p in parameters if isinstance(p, Field)]
for f in fields:
self.experiments[experiment_id].ensembles[ensemble_id][f.name] = field_data
ensemble.save_parameters(
model_ensemble.parameter_values[f.name] = field_data
storage_ensemble.save_parameters(
f.name,
1,
xr.DataArray(
Expand All @@ -331,111 +338,118 @@ def save_field(self, ensemble_id: UUID, field_data):
)

@rule(
ensemble_id=ensemble_ids,
model_ensemble=ensembles,
)
def get_field(self, ensemble_id: UUID):
ensemble = self.storage.get_ensemble(ensemble_id)
experiment_id = ensemble.experiment_id
field_names = self.experiments[experiment_id].ensembles[ensemble_id].keys()
def get_field(self, model_ensemble: Ensemble):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
field_names = model_ensemble.parameter_values.keys()
for f in field_names:
field_data = ensemble.load_parameters(f, 1)
field_data = storage_ensemble.load_parameters(f, 1)
np.testing.assert_array_equal(
self.experiments[experiment_id].ensembles[ensemble_id][f],
model_ensemble.parameter_values[f],
field_data["values"],
)

@rule(ensemble_id=ensemble_ids, parameter=words)
def load_unknown_parameter(self, ensemble_id: UUID, parameter: str):
ensemble = self.storage.get_ensemble(ensemble_id)
experiment_id = ensemble.experiment_id
parameter_names = [p.name for p in self.experiments[experiment_id].parameters]
@rule(model_ensemble=ensembles, parameter=words)
def load_unknown_parameter(self, model_ensemble: Ensemble, parameter: str):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
experiment_id = storage_ensemble.experiment_id
parameter_names = [p.name for p in self.model[experiment_id].parameters]
assume(parameter not in parameter_names)
with pytest.raises(
KeyError, match=f"No dataset '{parameter}' in storage for realization 0"
):
_ = ensemble.load_parameters(parameter, 0)
_ = storage_ensemble.load_parameters(parameter, 0)

@rule(
target=ensemble_ids,
experiment=experiment_ids,
target=ensembles,
model_experiment=experiments,
ensemble_size=ensemble_sizes,
)
def create_ensemble(self, experiment: UUID, ensemble_size: int):
ensemble = self.storage.create_ensemble(experiment, ensemble_size=ensemble_size)
def create_ensemble(self, model_experiment: Experiment, ensemble_size: int):
ensemble = self.storage.create_ensemble(
model_experiment.uuid, ensemble_size=ensemble_size
)
assert ensemble in self.storage.ensembles
self.experiments[experiment].ensembles[ensemble.id] = {}
model_ensemble = Ensemble(ensemble.id)
model_experiment.ensembles[ensemble.id] = model_ensemble

# https://github.com/equinor/ert/issues/7046
# assert (
# ensemble.get_ensemble_state()
# == [RealizationStorageState.UNDEFINED] * ensemble_size
# )

return ensemble.id
return model_ensemble

@rule(
target=ensemble_ids,
prior=ensemble_ids,
target=ensembles,
prior=ensembles,
)
def create_ensemble_from_prior(self, prior: UUID):
prior_ensemble = self.storage.get_ensemble(prior)
experiment = prior_ensemble.experiment_id
def create_ensemble_from_prior(self, prior: Ensemble):
prior_ensemble = self.storage.get_ensemble(prior.uuid)
experiment_id = prior_ensemble.experiment_id
size = prior_ensemble.ensemble_size
ensemble = self.storage.create_ensemble(
experiment, ensemble_size=size, prior_ensemble=prior
experiment_id, ensemble_size=size, prior_ensemble=prior.uuid
)
assert ensemble in self.storage.ensembles
self.experiments[experiment].ensembles[ensemble.id] = {}
model_ensemble = Ensemble(ensemble.id)
self.model[experiment_id].ensembles[ensemble.id] = model_ensemble
# https://github.com/equinor/ert/issues/7046
# assert (
# ensemble.get_ensemble_state()
# == [RealizationStorageState.PARENT_FAILURE] * size
# )

return ensemble.id
return model_ensemble

@rule(id=experiment_ids)
def get_experiment(self, id: UUID):
experiment = self.storage.get_experiment(id)
assert experiment.id == id
assert sorted(self.experiments[id].ensembles) == sorted(
e.id for e in experiment.ensembles
@rule(model_experiment=experiments)
def get_experiment(self, model_experiment: Experiment):
storage_experiment = self.storage.get_experiment(model_experiment.uuid)
assert storage_experiment.id == model_experiment.uuid
assert sorted(model_experiment.ensembles) == sorted(
e.id for e in storage_experiment.ensembles
)
assert (
list(experiment.response_configuration.values())
== self.experiments[id].responses
list(storage_experiment.response_configuration.values())
== model_experiment.responses
)
assert self.experiments[id].observations == pytest.approx(
experiment.observations
assert model_experiment.observations == pytest.approx(
storage_experiment.observations
)

@rule(id=ensemble_ids)
def get_ensemble(self, id: UUID):
ensemble = self.storage.get_ensemble(id)
assert ensemble.id == id
@rule(model_ensemble=ensembles)
def get_ensemble(self, model_ensemble: Ensemble):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
assert storage_ensemble.id == model_ensemble.uuid

@rule(target=failures, id=ensemble_ids, data=st.data(), message=st.text())
def set_failure(self, id: UUID, data: st.DataObject, message: str):
ensemble = self.storage.get_ensemble(id)
assert ensemble.id == id
@rule(model_ensemble=ensembles, data=st.data(), message=st.text())
def set_failure(self, model_ensemble: Ensemble, data: st.DataObject, message: str):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
assert storage_ensemble.id == model_ensemble.uuid

realization = data.draw(
st.integers(min_value=0, max_value=ensemble.ensemble_size - 1)
st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1)
)

ensemble.set_failure(
storage_ensemble.set_failure(
realization, RealizationStorageState.PARENT_FAILURE, message
)
self.failure_messages[ensemble.id, realization] = message

return (ensemble.id, realization)
model_ensemble.failure_messages[realization] = message

@rule(failure=failures)
def get_failure(self, failure):
(ensemble, realization) = failure
fail = self.storage.get_ensemble(ensemble).get_failure(realization)
assert fail is not None
assert fail.message == self.failure_messages[ensemble, realization]
@rule(model_ensemble=ensembles, data=st.data())
def get_failure(self, model_ensemble: Ensemble, data: st.DataObject):
storage_ensemble = self.storage.get_ensemble(model_ensemble.uuid)
realization = data.draw(
st.integers(min_value=0, max_value=storage_ensemble.ensemble_size - 1)
)
fail = self.storage.get_ensemble(model_ensemble.uuid).get_failure(realization)
if realization in model_ensemble.failure_messages:
assert fail is not None
assert fail.message == model_ensemble.failure_messages[realization]
else:
assert fail is None or "Failure from prior" in fail.message

def teardown(self):
if self.storage is not None:
Expand Down

0 comments on commit 0fb444f

Please sign in to comment.