Skip to content

Commit

Permalink
Refactor snapshot metadata to be typeddict
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-eq committed Jun 11, 2024
1 parent 48ab658 commit ded8131
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 32 deletions.
45 changes: 36 additions & 9 deletions src/ert/ensemble_evaluator/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,23 @@
import re
import typing
from collections import defaultdict
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union
from typing import (
Any,
DefaultDict,
Dict,
List,
Mapping,
Optional,
Sequence,
Tuple,
TypedDict,
Union,
)

from cloudevents.http import CloudEvent
from dateutil.parser import parse
from pydantic import BaseModel, ConfigDict
from qtpy.QtGui import QColor

from ert.ensemble_evaluator import identifiers as ids
from ert.ensemble_evaluator import state
Expand Down Expand Up @@ -67,7 +79,18 @@ def convert_iso8601_to_datetime(
return parse(timestamp)


def _filter_nones(some_dict: Mapping[str, Any]) -> Dict[str, Any]:
RealId = str
FmStepId = str


class SnapshotMetadata(TypedDict, total=False):
aggr_job_status_colors: DefaultDict[RealId, Dict[FmStepId, QColor]]
real_status_colors: Dict[RealId, QColor]
sorted_real_ids: List[RealId]
sorted_forward_model_ids: DefaultDict[RealId, List[FmStepId]]


def _filter_nones(some_dict: Dict[str, Any]) -> Dict[str, Any]:
return {key: value for key, value in some_dict.items() if value is not None}


Expand All @@ -90,18 +113,22 @@ def __init__(self, snapshot: Optional["Snapshot"] = None) -> None:
self._ensemble_state: Optional[str] = None
# TODO not sure about possible values at this point, as GUI hijacks this one as
# well
self._metadata: Dict[str, Any] = defaultdict(dict)

self._metadata = SnapshotMetadata(
aggr_job_status_colors=defaultdict(dict),
real_status_colors=dict(),
sorted_real_ids=list(),
sorted_forward_model_ids=defaultdict(list),
)
self._snapshot = snapshot

@property
def status(self) -> Optional[str]:
return self._ensemble_state

def update_metadata(self, metadata: Dict[str, Any]) -> None:
def update_metadata(self, metadata: SnapshotMetadata) -> None:
"""only used in gui snapshot model, which only cares about the partial
snapshot's metadata"""
self._metadata.update(_filter_nones(metadata))
self._metadata.update(metadata)

def update_realization(
self,
Expand Down Expand Up @@ -169,7 +196,7 @@ def get_real_ids(self) -> Sequence[str]:
return sorted(real_ids, key=int)

@property
def metadata(self) -> Mapping[str, Any]:
def metadata(self) -> SnapshotMetadata:
return self._metadata

def get_real(self, real_id: str) -> "RealizationSnapshot":
Expand Down Expand Up @@ -319,7 +346,7 @@ def merge_event(self, event: PartialSnapshot) -> None:
def merge(self, update_as_nested_dict: Mapping[str, Any]) -> None:
self._my_partial._merge(_from_nested_dict(update_as_nested_dict))

def merge_metadata(self, metadata: Dict[str, Any]) -> None:
def merge_metadata(self, metadata: SnapshotMetadata) -> None:
self._my_partial._metadata.update(metadata)

def to_dict(self) -> Dict[str, Any]:
Expand All @@ -330,7 +357,7 @@ def status(self) -> Optional[str]:
return self._my_partial._ensemble_state

@property
def metadata(self) -> Mapping[str, Any]:
def metadata(self) -> SnapshotMetadata:
return self._my_partial.metadata

def get_all_forward_models(
Expand Down
42 changes: 19 additions & 23 deletions src/ert/gui/model/snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import logging
from collections import defaultdict
from contextlib import ExitStack
from typing import Any, Dict, Final, List, Mapping, Optional, Sequence, Union
from typing import Dict, Final, List, Mapping, Optional, Sequence, Union

from dateutil import tz
from qtpy.QtCore import QAbstractItemModel, QModelIndex, QObject, QSize, Qt, QVariant
from qtpy.QtGui import QColor, QFont

from ert.ensemble_evaluator import PartialSnapshot, Snapshot, state
from ert.ensemble_evaluator import identifiers as ids
from ert.ensemble_evaluator.snapshot import SnapshotMetadata
from ert.gui.model.node import (
ForwardModelStepNode,
IterNode,
Expand Down Expand Up @@ -42,11 +43,6 @@

DURATION = "Duration"

SORTED_REALIZATION_IDS = "_sorted_real_ids"
SORTED_JOB_IDS = "_sorted_forward_model_ids"
REAL_JOB_STATUS_AGGREGATED = "_aggr_job_status_colors"
REAL_STATUS_COLOR = "_real_status_colors"

COLUMNS: Dict[NodeType, Sequence[str]] = {
NodeType.ROOT: ["Name", "Status"],
NodeType.ITER: ["Name", "Status", "Active"],
Expand Down Expand Up @@ -108,24 +104,24 @@ def prerender(
if not reals and not forward_model_states:
return None

metadata: Dict[str, Any] = {
metadata = SnapshotMetadata(
# A mapping from real to job to that job's QColor status representation
REAL_JOB_STATUS_AGGREGATED: defaultdict(dict),
aggr_job_status_colors=defaultdict(dict),
# A mapping from real to that real's QColor status representation
REAL_STATUS_COLOR: defaultdict(dict),
}
real_status_colors=defaultdict(dict),
)

for real_id, real in reals.items():
if real.status:
metadata[REAL_STATUS_COLOR][real_id] = _QCOLORS[
metadata["real_status_colors"][real_id] = _QCOLORS[
state.REAL_STATE_TO_COLOR[real.status]
]

isSnapshot = False
if isinstance(snapshot, Snapshot):
isSnapshot = True
metadata[SORTED_REALIZATION_IDS] = sorted(snapshot.reals.keys(), key=int)
metadata[SORTED_JOB_IDS] = defaultdict(list)
metadata["sorted_real_ids"] = sorted(snapshot.reals.keys(), key=int)
metadata["sorted_forward_model_ids"] = defaultdict(list)

running_forward_model_id: Dict[str, int] = {}
for (
Expand All @@ -140,7 +136,7 @@ def prerender(
forward_model_id,
), forward_model_status in forward_model_states.items():
if isSnapshot:
metadata[SORTED_JOB_IDS][real_id].append(forward_model_id)
metadata["sorted_forward_model_ids"][real_id].append(forward_model_id)
if (
real_id in running_forward_model_id
and int(forward_model_id) > running_forward_model_id[real_id]
Expand All @@ -153,7 +149,7 @@ def prerender(
color = _QCOLORS[
state.FORWARD_MODEL_STATE_TO_COLOR[forward_model_status]
]
metadata[REAL_JOB_STATUS_AGGREGATED][real_id][forward_model_id] = color
metadata["aggr_job_status_colors"][real_id][forward_model_id] = color

if isSnapshot:
snapshot.merge_metadata(metadata)
Expand Down Expand Up @@ -195,13 +191,13 @@ def _add_partial_snapshot(self, partial: PartialSnapshot, iter_: int) -> None:
if real and real.status:
real_node.data.status = real.status
for real_forward_model_id, color in (
metadata[REAL_JOB_STATUS_AGGREGATED].get(real_id, {}).items()
metadata["aggr_job_status_colors"].get(real_id, {}).items()
):
real_node.data.forward_model_step_status_color_by_id[
real_forward_model_id
] = color
if real_id in metadata[REAL_STATUS_COLOR]:
real_node.data.real_status_color = metadata[REAL_STATUS_COLOR][
if real_id in metadata["real_status_colors"]:
real_node.data.real_status_color = metadata["real_status_colors"][
real_id
]
reals_changed.append(real_node.row())
Expand Down Expand Up @@ -278,9 +274,9 @@ def _add_snapshot(self, snapshot: Snapshot, iter_: int) -> None:
iter_,
data=IterNodeData(
status=snapshot.status,
sorted_realization_ids=metadata[SORTED_REALIZATION_IDS],
sorted_realization_ids=metadata["sorted_real_ids"],
sorted_forward_model_step_ids_by_realization_id=metadata[
SORTED_JOB_IDS
"sorted_forward_model_ids"
],
),
)
Expand All @@ -292,14 +288,14 @@ def _add_snapshot(self, snapshot: Snapshot, iter_: int) -> None:
status=real.status,
active=real.active,
forward_model_step_status_color_by_id=metadata[
REAL_JOB_STATUS_AGGREGATED
"aggr_job_status_colors"
][real_id],
real_status_color=metadata[REAL_STATUS_COLOR][real_id],
real_status_color=metadata["real_status_colors"][real_id],
),
)
snapshot_tree.add_child(real_node)

for forward_model_id in metadata[SORTED_JOB_IDS][real_id]:
for forward_model_id in metadata["sorted_forward_model_ids"][real_id]:
job = snapshot.get_job(real_id, forward_model_id)
job_node = ForwardModelStepNode(
id_=forward_model_id, data=job, parent=real_node
Expand Down

0 comments on commit ded8131

Please sign in to comment.