Skip to content

Commit

Permalink
Move report writing to client
Browse files Browse the repository at this point in the history
  • Loading branch information
oyvindeide committed Jun 11, 2024
1 parent 0092588 commit 7440057
Show file tree
Hide file tree
Showing 30 changed files with 1,320 additions and 1,593 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ dependencies=[
"sortedcontainers",
"tables; python_version >= '3.9'",
"tables<3.9;python_version == '3.8'",
"tabulate",
"tqdm>=4.62.0",
"uvicorn >= 0.17.0",
"websockets",
Expand Down
125 changes: 38 additions & 87 deletions src/ert/analysis/_es_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,10 @@
import functools
import logging
import time
from collections import defaultdict
from datetime import datetime
from fnmatch import fnmatch
from pathlib import Path
from typing import (
TYPE_CHECKING,
Callable,
DefaultDict,
Generic,
Iterable,
List,
Expand All @@ -36,6 +32,7 @@
from ..config.analysis_module import ESSettings, IESSettings
from . import misfit_preprocessor
from .event import (
AnalysisCompleteEvent,
AnalysisDataEvent,
AnalysisErrorEvent,
AnalysisEvent,
Expand All @@ -45,14 +42,13 @@
)
from .snapshots import (
ObservationAndResponseSnapshot,
ObservationStatus,
SmootherSnapshot,
)

if TYPE_CHECKING:
import numpy.typing as npt

from ert.storage import Ensemble, Experiment
from ert.storage import Ensemble

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -685,67 +681,6 @@ def analysis_IES(
return sies_smoother


def _write_update_report(
path: Path, snapshot: SmootherSnapshot, run_id: str, experiment: Experiment
) -> None:
update_step = snapshot.update_step_snapshots

(experiment._path / f"update_log_{run_id}.json").write_text(
snapshot.model_dump_json()
)

fname = path / f"{run_id}.txt"
fname.parent.mkdir(parents=True, exist_ok=True)
update_step = snapshot.update_step_snapshots

obs_info: DefaultDict[ObservationStatus, int] = defaultdict(lambda: 0)
for update in update_step:
obs_info[update.status] += 1

with open(fname, "w", encoding="utf-8") as fout:
fout.write("=" * 150 + "\n")
timestamp = datetime.now().strftime("%Y.%m.%d %H:%M:%S")
fout.write(f"Time: {timestamp}\n")
fout.write(f"Parent ensemble: {snapshot.source_ensemble_name}\n")
fout.write(f"Target ensemble: {snapshot.target_ensemble_name}\n")
fout.write(f"Alpha: {snapshot.alpha}\n")
fout.write(f"Global scaling: {snapshot.global_scaling}\n")
fout.write(f"Standard cutoff: {snapshot.std_cutoff}\n")
fout.write(f"Run id: {run_id}\n")
fout.write(f"Active observations: {obs_info[ObservationStatus.ACTIVE]}\n")
fout.write(
f"Deactivated observations - missing respons(es): {obs_info[ObservationStatus.MISSING_RESPONSE]}\n"
)
fout.write(
f"Deactivated observations - ensemble_std > STD_CUTOFF: {obs_info[ObservationStatus.STD_CUTOFF]}\n"
)
fout.write(
f"Deactivated observations - outlier: {obs_info[ObservationStatus.OUTLIER]}\n"
)
fout.write("-" * 150 + "\n")
fout.write(
"Observed history".rjust(56)
+ "|".rjust(17)
+ "Simulated data".rjust(32)
+ "|".rjust(13)
+ "Status".rjust(12)
+ "\n"
)
fout.write("-" * 150 + "\n")
for nr, step in enumerate(update_step):
obs_std = (
f"{step.obs_std:.3f}"
if step.obs_scaling == 1
else f"{step.obs_std * step.obs_scaling:.3f} ({step.obs_std:<.3f} * {step.obs_scaling:.3f})"
)
fout.write(
f"{nr+1:^6}: {step.obs_name:20} {step.obs_val:>16.3f} +/- "
f"{obs_std:<21} | {step.response_mean:>21.3f} +/- "
f"{step.response_std:<16.3f} {'|':<6} "
f"{step.get_status().capitalize()}\n"
)


def _assert_has_enough_realizations(
ens_mask: npt.NDArray[np.bool_], min_required_realizations: int
) -> None:
Expand Down Expand Up @@ -776,15 +711,13 @@ def _create_smoother_snapshot(
def smoother_update(
prior_storage: Ensemble,
posterior_storage: Ensemble,
run_id: str,
observations: Iterable[str],
parameters: Iterable[str],
analysis_config: Optional[UpdateSettings] = None,
es_settings: Optional[ESSettings] = None,
rng: Optional[np.random.Generator] = None,
progress_callback: Optional[Callable[[AnalysisEvent], None]] = None,
global_scaling: float = 1.0,
log_path: Optional[Path] = None,
) -> SmootherSnapshot:
if not progress_callback:
progress_callback = noop_progress_callback
Expand Down Expand Up @@ -819,24 +752,33 @@ def smoother_update(
analysis_config.auto_scale_observations,
)
except Exception as e:
progress_callback(
AnalysisErrorEvent(
error_msg=str(e),
data=DataSection(
header=smoother_snapshot.header,
data=smoother_snapshot.csv,
extra=smoother_snapshot.extra,
),
)
)
raise e
finally:
if log_path is not None:
_write_update_report(
log_path,
smoother_snapshot,
run_id,
prior_storage.experiment,
progress_callback(
AnalysisCompleteEvent(
data=DataSection(
header=smoother_snapshot.header,
data=smoother_snapshot.csv,
extra=smoother_snapshot.extra,
)

)
)
return smoother_snapshot


def iterative_smoother_update(
prior_storage: Ensemble,
posterior_storage: Ensemble,
sies_smoother: Optional[ies.SIES],
run_id: str,
parameters: Iterable[str],
observations: Iterable[str],
update_settings: UpdateSettings,
Expand All @@ -845,7 +787,6 @@ def iterative_smoother_update(
initial_mask: npt.NDArray[np.bool_],
rng: Optional[np.random.Generator] = None,
progress_callback: Optional[Callable[[AnalysisEvent], None]] = None,
log_path: Optional[Path] = None,
global_scaling: float = 1.0,
) -> Tuple[SmootherSnapshot, ies.SIES]:
if not progress_callback:
Expand Down Expand Up @@ -882,14 +823,24 @@ def iterative_smoother_update(
initial_mask=initial_mask,
)
except Exception as e:
progress_callback(
AnalysisErrorEvent(
error_msg=str(e),
data=DataSection(
header=smoother_snapshot.header,
data=smoother_snapshot.csv,
extra=smoother_snapshot.extra,
),
)
)
raise e
finally:
if log_path is not None:
_write_update_report(
log_path,
smoother_snapshot,
run_id,
prior_storage.experiment,
progress_callback(
AnalysisCompleteEvent(
data=DataSection(
header=smoother_snapshot.header,
data=smoother_snapshot.csv,
extra=smoother_snapshot.extra,
)

)
)
return smoother_snapshot, sies_smoother
5 changes: 5 additions & 0 deletions src/ert/analysis/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,8 @@ class AnalysisDataEvent(AnalysisEvent):
class AnalysisErrorEvent(AnalysisEvent):
error_msg: str
data: Optional[DataSection] = None


@dataclass
class AnalysisCompleteEvent(AnalysisEvent):
data: DataSection
2 changes: 1 addition & 1 deletion src/ert/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def run_cli(args: Namespace, plugin_manager: Optional[ErtPluginManager] = None)
monitor = Monitor(out=out, color_always=args.color_always)
thread.start()
try:
monitor.monitor(status_queue)
monitor.monitor(status_queue, ert_config.analysis_config.log_path)
except (SystemExit, KeyboardInterrupt, OSError):
print("\nKilling simulations...")
model.cancel()
Expand Down
21 changes: 21 additions & 0 deletions src/ert/cli/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import sys
from datetime import datetime, timedelta
from pathlib import Path
from queue import SimpleQueue
from typing import Dict, Optional, TextIO, Tuple

Expand All @@ -22,6 +23,12 @@
REAL_STATE_TO_COLOR,
)
from ert.run_models.base_run_model import StatusEvents
from ert.run_models.event import (
RunModelDataEvent,
RunModelErrorEvent,
RunModelUpdateEndEvent,
)
from ert.shared.exporter import csv_event_to_report
from ert.shared.status.utils import format_running_time

Color = Tuple[int, int, int]
Expand Down Expand Up @@ -66,6 +73,7 @@ def __init__(self, out: TextIO = sys.stdout, color_always: bool = False) -> None
def monitor(
self,
event_queue: SimpleQueue[StatusEvents],
output_path: Optional[Path] = None,
) -> None:
self._start_time = datetime.now()
while True:
Expand All @@ -83,6 +91,19 @@ def monitor(
self._print_job_errors()
return

if (
isinstance(
event,
(RunModelDataEvent, RunModelUpdateEndEvent, RunModelErrorEvent),
)
and output_path
):
name = event.name if hasattr(event, "name") else "Report"
if event.data:
csv_event_to_report(
name, event.data, output_path / str(event.run_id)
)

def _print_job_errors(self) -> None:
failed_jobs: Dict[Optional[str], int] = {}
for snapshot in self._snapshots.values():
Expand Down
7 changes: 6 additions & 1 deletion src/ert/gui/simulation/experiment_panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,12 @@ def run_experiment(self) -> None:

if not abort:
dialog = RunDialog(
self._config_file, model, event_queue, self._notifier, self.parent()
self._config_file,
model,
event_queue,
self._notifier,
self.parent(),
output_path=self.ert.ert_config.analysis_config.log_path,
)
self.run_button.setEnabled(False)
self.run_button.setText(EXPERIMENT_IS_RUNNING_BUTTON_MESSAGE)
Expand Down
16 changes: 16 additions & 0 deletions src/ert/gui/simulation/run_dialog.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from pathlib import Path
from queue import SimpleQueue
from typing import Optional

Expand Down Expand Up @@ -53,6 +54,7 @@
from ert.run_models.event import RunModelDataEvent, RunModelErrorEvent
from ert.shared.status.utils import byte_with_unit, format_running_time

from ...shared.exporter import csv_event_to_report
from ..find_ert_info import find_ert_info
from ..model.node import NodeType
from .queue_emitter import QueueEmitter
Expand All @@ -73,8 +75,10 @@ def __init__(
event_queue: SimpleQueue,
notifier: ErtNotifier,
parent=None,
output_path: Optional[Path] = None,
):
QDialog.__init__(self, parent)
self.output_path = output_path
self.setAttribute(Qt.WA_DeleteOnClose)
self.setWindowFlags(Qt.Window)
self.setWindowFlags(self.windowFlags() & ~Qt.WindowContextHelpButtonHint)
Expand Down Expand Up @@ -437,6 +441,18 @@ def _on_event(self, event: object):
if (widget := self._get_update_widget(event.iteration)) is not None:
widget.error(event)

if (
isinstance(
event, (RunModelDataEvent, RunModelUpdateEndEvent, RunModelErrorEvent)
)
and self.output_path
):
name = event.name if hasattr(event, "name") else "Report"
if event.data:
csv_event_to_report(
name, event.data, self.output_path / str(event.run_id)
)

def _get_update_widget(self, iteration: int) -> Optional[UpdateWidget]:
for i in range(0, self._tab_widget.count()):
widget = self._tab_widget.widget(i)
Expand Down
16 changes: 11 additions & 5 deletions src/ert/gui/tools/run_analysis/run_analysis_tool.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import uuid
from contextlib import contextmanager
from typing import Iterator, Optional
Expand Down Expand Up @@ -48,18 +49,20 @@ def run(self) -> None:
config = self._ert.ert_config
rng = np.random.default_rng(_seed_sequence(config.random_seed))
update_settings = config.analysis_config.observation_settings
update_id = uuid.uuid4()
try:
smoother_update(
self._source_ensemble,
self._target_ensemble,
str(uuid.uuid4()),
self._source_ensemble.experiment.observation_keys,
self._source_ensemble.experiment.update_parameters,
update_settings,
config.analysis_config.es_module,
rng,
self.send_smoother_event,
log_path=config.analysis_config.log_path,
functools.partial(
self.send_smoother_event,
update_id,
),
)
except ErtAnalysisError as e:
error = str(e)
Expand All @@ -68,13 +71,16 @@ def run(self) -> None:

self.finished.emit(error, self._source_ensemble.name)

def send_smoother_event(self, event: AnalysisEvent) -> None:
def send_smoother_event(self, run_id: uuid.UUID, event: AnalysisEvent) -> None:
if isinstance(event, AnalysisStatusEvent):
self.progress_update.emit(RunModelStatusEvent(iteration=0, msg=event.msg))
self.progress_update.emit(
RunModelStatusEvent(iteration=0, run_id=run_id, msg=event.msg)
)
elif isinstance(event, AnalysisTimeEvent):
self.progress_update.emit(
RunModelTimeEvent(
iteration=0,
run_id=run_id,
elapsed_time=event.elapsed_time,
remaining_time=event.remaining_time,
)
Expand Down
Loading

0 comments on commit 7440057

Please sign in to comment.