Skip to content

Commit

Permalink
Replace use of pydantic for ForwardModel with plain dict
Browse files Browse the repository at this point in the history
The pydantic version is causing our GUI to freeze for several seconds on
heavy updates. Using plain dict eliminates most of the freeze duration.
  • Loading branch information
JHolba committed Jun 21, 2024
1 parent d577b9b commit 720be68
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 164 deletions.
8 changes: 5 additions & 3 deletions src/ert/cli/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Snapshot,
SnapshotUpdateEvent,
)
from ert.ensemble_evaluator import identifiers as ids
from ert.ensemble_evaluator.state import (
ALL_REALIZATION_STATES,
COLOR_FAILED,
Expand Down Expand Up @@ -109,9 +110,10 @@ def _print_job_errors(self) -> None:
for snapshot in self._snapshots.values():
for real in snapshot.reals.values():
for job in real.forward_models.values():
if job.status == FORWARD_MODEL_STATE_FAILURE:
result = failed_jobs.get(job.error, 0)
failed_jobs[job.error] = result + 1
if job.get(ids.STATUS) == FORWARD_MODEL_STATE_FAILURE:
err = job.get(ids.ERROR)
result = failed_jobs.get(err, 0)
failed_jobs[err] = result + 1
for error, number_of_jobs in failed_jobs.items():
print(f"{number_of_jobs} jobs failed due to the error: {error}")

Expand Down
3 changes: 2 additions & 1 deletion src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from _ert.async_utils import new_event_loop
from _ert.threading import ErtThread
from ert.ensemble_evaluator import identifiers as ids
from ert.serialization import evaluator_marshaller, evaluator_unmarshaller

from ._builder import Ensemble
Expand Down Expand Up @@ -136,7 +137,7 @@ def _stopped_handler(self, events: List[CloudEvent]) -> None:
with self._snapshot_mutex:
max_memory_usage = -1
for job in self.ensemble.snapshot.get_all_forward_models().values():
memory_usage = job.max_memory_usage or "-1"
memory_usage = job[ids.MAX_MEMORY_USAGE] or "-1"
max_memory_usage = max(int(memory_usage), max_memory_usage)
logger.info(
f"Ensemble ran with maximum memory usage for a single realization job: {max_memory_usage}"
Expand Down
64 changes: 33 additions & 31 deletions src/ert/ensemble_evaluator/identifiers.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Final

from ert.event_type_constants import (
EVTYPE_ENSEMBLE_CANCELLED,
EVTYPE_ENSEMBLE_FAILED,
Expand All @@ -12,33 +14,33 @@
EVTYPE_REALIZATION_WAITING,
)

ACTIVE = "active"
CURRENT_MEMORY_USAGE = "current_memory_usage"
DATA = "data"
END_TIME = "end_time"
ERROR = "error"
ERROR_MSG = "error_msg"
ERROR_FILE = "error_file"
INDEX = "index"
JOBS = "jobs"
MAX_MEMORY_USAGE = "max_memory_usage"
METADATA = "metadata"
NAME = "name"
REALS = "reals"
START_TIME = "start_time"
STATUS = "status"
STDERR = "stderr"
STDOUT = "stdout"
STEPS = "steps"
ACTIVE: Final = "active"
CURRENT_MEMORY_USAGE: Final = "current_memory_usage"
DATA: Final = "data"
END_TIME: Final = "end_time"
ERROR: Final = "error"
ERROR_MSG: Final = "error_msg"
ERROR_FILE: Final = "error_file"
INDEX: Final = "index"
JOBS: Final = "jobs"
MAX_MEMORY_USAGE: Final = "max_memory_usage"
METADATA: Final = "metadata"
NAME: Final = "name"
REALS: Final = "reals"
START_TIME: Final = "start_time"
STATUS: Final = "status"
STDERR: Final = "stderr"
STDOUT: Final = "stdout"
STEPS: Final = "steps"

EVTYPE_FORWARD_MODEL_START = "com.equinor.ert.forward_model_job.start"
EVTYPE_FORWARD_MODEL_RUNNING = "com.equinor.ert.forward_model_job.running"
EVTYPE_FORWARD_MODEL_SUCCESS = "com.equinor.ert.forward_model_job.success"
EVTYPE_FORWARD_MODEL_FAILURE = "com.equinor.ert.forward_model_job.failure"
EVTYPE_FORWARD_MODEL_CHECKSUM = "com.equinor.ert.forward_model_job.checksum"
EVTYPE_FORWARD_MODEL_START: Final = "com.equinor.ert.forward_model_job.start"
EVTYPE_FORWARD_MODEL_RUNNING: Final = "com.equinor.ert.forward_model_job.running"
EVTYPE_FORWARD_MODEL_SUCCESS: Final = "com.equinor.ert.forward_model_job.success"
EVTYPE_FORWARD_MODEL_FAILURE: Final = "com.equinor.ert.forward_model_job.failure"
EVTYPE_FORWARD_MODEL_CHECKSUM: Final = "com.equinor.ert.forward_model_job.checksum"


EVGROUP_REALIZATION = {
EVGROUP_REALIZATION: Final = {
EVTYPE_REALIZATION_FAILURE,
EVTYPE_REALIZATION_PENDING,
EVTYPE_REALIZATION_RUNNING,
Expand All @@ -48,7 +50,7 @@
EVTYPE_REALIZATION_TIMEOUT,
}

EVGROUP_FORWARD_MODEL = {
EVGROUP_FORWARD_MODEL: Final = {
EVTYPE_FORWARD_MODEL_START,
EVTYPE_FORWARD_MODEL_RUNNING,
EVTYPE_FORWARD_MODEL_SUCCESS,
Expand All @@ -57,14 +59,14 @@

EVGROUP_FM_ALL = EVGROUP_REALIZATION | EVGROUP_FORWARD_MODEL

EVTYPE_EE_SNAPSHOT = "com.equinor.ert.ee.snapshot"
EVTYPE_EE_SNAPSHOT_UPDATE = "com.equinor.ert.ee.snapshot_update"
EVTYPE_EE_TERMINATED = "com.equinor.ert.ee.terminated"
EVTYPE_EE_USER_CANCEL = "com.equinor.ert.ee.user_cancel"
EVTYPE_EE_USER_DONE = "com.equinor.ert.ee.user_done"
EVTYPE_EE_SNAPSHOT: Final = "com.equinor.ert.ee.snapshot"
EVTYPE_EE_SNAPSHOT_UPDATE: Final = "com.equinor.ert.ee.snapshot_update"
EVTYPE_EE_TERMINATED: Final = "com.equinor.ert.ee.terminated"
EVTYPE_EE_USER_CANCEL: Final = "com.equinor.ert.ee.user_cancel"
EVTYPE_EE_USER_DONE: Final = "com.equinor.ert.ee.user_done"


EVGROUP_ENSEMBLE = {
EVGROUP_ENSEMBLE: Final = {
EVTYPE_ENSEMBLE_STARTED,
EVTYPE_ENSEMBLE_STOPPED,
EVTYPE_ENSEMBLE_CANCELLED,
Expand Down
130 changes: 64 additions & 66 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)

from cloudevents.http import CloudEvent
from dateutil.parser import parse
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel
from qtpy.QtGui import QColor
from typing_extensions import TypedDict

from ert.ensemble_evaluator import identifiers as ids
from ert.ensemble_evaluator import state
Expand Down Expand Up @@ -97,18 +97,18 @@ def _filter_nones(some_dict: Dict[str, Any]) -> Dict[str, Any]:
class PartialSnapshot:
def __init__(self, snapshot: Optional["Snapshot"] = None) -> None:
self._realization_states: Dict[
str, Dict[str, Union[bool, datetime.datetime, str]]
str,
Dict[str, Union[bool, datetime.datetime, str, Dict[str, "ForwardModel"]]],
] = defaultdict(dict)
"""A shallow dictionary of realization states. The key is a string with
realization number, pointing to a dict with keys active (bool),
start_time (datetime), end_time (datetime) and status (str)."""

self._forward_model_states: Dict[
Tuple[str, str], Dict[str, Union[str, datetime.datetime]]
] = defaultdict(dict)
self._forward_model_states: Dict[Tuple[str, str], "ForwardModel"] = defaultdict(
lambda: ForwardModel()
)
"""A shallow dictionary of forward_model states. The key is a tuple of two
strings with realization id and forward_model id, pointing to a dict with
the same members as the ForwardModel."""
strings with realization id and forward_model id, pointing to a ForwardModel."""

self._ensemble_state: Optional[str] = None
# TODO not sure about possible values at this point, as GUI hijacks this one as
Expand Down Expand Up @@ -150,15 +150,11 @@ def update_forward_model(
forward_model_id: str,
forward_model: "ForwardModel",
) -> "PartialSnapshot":
forward_model_update = _filter_nones(forward_model.model_dump())

self._forward_model_states[(real_id, forward_model_id)].update(
forward_model_update
)
self._forward_model_states[(real_id, forward_model_id)].update(forward_model)
if self._snapshot:
self._snapshot._my_partial._forward_model_states[
(real_id, forward_model_id)
].update(forward_model_update)
].update(forward_model)
return self

def get_all_forward_models(
Expand Down Expand Up @@ -213,17 +209,15 @@ def to_dict(self) -> Dict[str, Any]:
if self._realization_states:
_dict["reals"] = self._realization_states

for fm_tuple, fm_values_dict in self._forward_model_states.items():
real_id = fm_tuple[0]
for (real_id, fm_id), fm_values_dict in self._forward_model_states.items():
if "reals" not in _dict:
_dict["reals"] = {}
if real_id not in _dict["reals"]:
_dict["reals"][real_id] = {}
if "forward_models" not in _dict["reals"][real_id]:
_dict["reals"][real_id]["forward_models"] = {}
_dict["reals"][real_id]["forward_models"] = ForwardModel()

forward_model_id = fm_tuple[1]
_dict["reals"][real_id]["forward_models"][forward_model_id] = fm_values_dict
_dict["reals"][real_id]["forward_models"][fm_id] = fm_values_dict

return _dict

Expand Down Expand Up @@ -274,15 +268,18 @@ def from_cloudevent(self, event: CloudEvent) -> "PartialSnapshot":
) in self._snapshot.get_forward_models_for_real(
_get_real_id(e_source)
).items():
if forward_model.status != state.FORWARD_MODEL_STATE_FINISHED:
if (
forward_model.get(ids.STATUS)
!= state.FORWARD_MODEL_STATE_FINISHED
):
real_id = _get_real_id(e_source)
forward_model_idx = (real_id, forward_model_id)
if forward_model_idx not in self._forward_model_states:
self._forward_model_states[forward_model_idx] = {}
self._forward_model_states[forward_model_idx].update(
{
"status": state.FORWARD_MODEL_STATE_FAILURE,
"end_time": end_time, # type: ignore
"end_time": end_time,
"error": "The run is cancelled due to "
"reaching MAX_RUNTIME",
}
Expand All @@ -305,25 +302,24 @@ def from_cloudevent(self, event: CloudEvent) -> "PartialSnapshot":
if event.data is not None:
error = event.data.get(ids.ERROR_MSG)

fm_dict = {
ids.STATUS: status,
ids.START_TIME: start_time,
ids.END_TIME: end_time,
ids.INDEX: _get_forward_model_index(e_source),
ids.ERROR: error,
}
fm = ForwardModel(
status=status,
start_time=start_time,
end_time=end_time,
index=_get_forward_model_index(e_source),
error=error,
)

if e_type == ids.EVTYPE_FORWARD_MODEL_RUNNING:
fm_dict[ids.CURRENT_MEMORY_USAGE] = event.data.get(
ids.CURRENT_MEMORY_USAGE
)
fm_dict[ids.MAX_MEMORY_USAGE] = event.data.get(ids.MAX_MEMORY_USAGE)
fm[ids.CURRENT_MEMORY_USAGE] = event.data.get(ids.CURRENT_MEMORY_USAGE)
fm[ids.MAX_MEMORY_USAGE] = event.data.get(ids.MAX_MEMORY_USAGE)
if e_type == ids.EVTYPE_FORWARD_MODEL_START:
fm_dict[ids.STDOUT] = event.data.get(ids.STDOUT)
fm_dict[ids.STDERR] = event.data.get(ids.STDERR)
fm[ids.STDOUT] = event.data.get(ids.STDOUT)
fm[ids.STDERR] = event.data.get(ids.STDERR)
self.update_forward_model(
_get_real_id(e_source),
_get_forward_model_id(e_source),
ForwardModel(**fm_dict),
fm,
)

elif e_type in ids.EVGROUP_ENSEMBLE:
Expand Down Expand Up @@ -363,17 +359,16 @@ def metadata(self) -> SnapshotMetadata:
def get_all_forward_models(
self,
) -> Mapping[Tuple[str, str], "ForwardModel"]:
return {
idx: ForwardModel(**forward_model_state)
for idx, forward_model_state in self._my_partial._forward_model_states.items()
}
return self._my_partial._forward_model_states.copy()

def get_forward_model_status_for_all_reals(
self,
) -> Mapping[Tuple[str, str], Union[str, datetime.datetime]]:
) -> Mapping[Tuple[str, str], str]:
return {
idx: forward_model_state["status"]
for idx, forward_model_state in self._my_partial._forward_model_states.items()
if "status" in forward_model_state
and forward_model_state["status"] is not None
}

@property
Expand All @@ -382,18 +377,18 @@ def reals(self) -> Mapping[str, "RealizationSnapshot"]:

def get_forward_models_for_real(self, real_id: str) -> Dict[str, "ForwardModel"]:
return {
forward_model_idx[1]: ForwardModel(**forward_model_data)
for forward_model_idx, forward_model_data in self._my_partial._forward_model_states.items()
if forward_model_idx[0] == real_id
fm_idx[1]: forward_model_data.copy()
for fm_idx, forward_model_data in self._my_partial._forward_model_states.items()
if fm_idx[0] == real_id
}

def get_real(self, real_id: str) -> "RealizationSnapshot":
return RealizationSnapshot(**self._my_partial._realization_states[real_id])

def get_job(self, real_id: str, forward_model_id: str) -> "ForwardModel":
return ForwardModel(
**self._my_partial._forward_model_states[(real_id, forward_model_id)]
)
return self._my_partial._forward_model_states[
(real_id, forward_model_id)
].copy()

def get_successful_realizations(self) -> typing.List[int]:
return [
Expand All @@ -415,18 +410,17 @@ def data(self) -> Mapping[str, Any]:
return self._my_partial.to_dict()


class ForwardModel(BaseModel):
model_config = ConfigDict(coerce_numbers_to_str=True)
class ForwardModel(TypedDict, total=False):
status: Optional[str]
start_time: Optional[datetime.datetime] = None
end_time: Optional[datetime.datetime] = None
index: Optional[str] = None
current_memory_usage: Optional[str] = None
max_memory_usage: Optional[str] = None
name: Optional[str] = None
error: Optional[str] = None
stdout: Optional[str] = None
stderr: Optional[str] = None
start_time: Optional[datetime.datetime]
end_time: Optional[datetime.datetime]
index: Optional[str]
current_memory_usage: Optional[str]
max_memory_usage: Optional[str]
name: Optional[str]
error: Optional[str]
stdout: Optional[str]
stderr: Optional[str]


class RealizationSnapshot(BaseModel):
Expand Down Expand Up @@ -479,15 +473,19 @@ def add_forward_model(
stderr: Optional[str] = None,
) -> "SnapshotBuilder":
self.forward_models[forward_model_id] = ForwardModel(
status=status,
index=index,
start_time=start_time,
end_time=end_time,
name=name,
stdout=stdout,
stderr=stderr,
current_memory_usage=current_memory_usage,
max_memory_usage=max_memory_usage,
**_filter_nones( # type: ignore
{
ids.STATUS: status,
ids.INDEX: index,
ids.START_TIME: start_time,
ids.END_TIME: end_time,
ids.NAME: name,
ids.STDOUT: stdout,
ids.STDERR: stderr,
ids.CURRENT_MEMORY_USAGE: current_memory_usage,
ids.MAX_MEMORY_USAGE: max_memory_usage,
}
)
)
return self

Expand Down
Loading

0 comments on commit 720be68

Please sign in to comment.