Skip to content

Commit

Permalink
Port forward model steps to extend ForwardModelStep
Browse files Browse the repository at this point in the history
  • Loading branch information
Yngve S. Kristiansen committed May 23, 2024
1 parent a15ae2c commit 15fae3c
Show file tree
Hide file tree
Showing 13 changed files with 180 additions and 12 deletions.
124 changes: 124 additions & 0 deletions src/semeio/forward_models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# class notation
from ert.config.forward_model_step import ForwardModelStepJSON, ForwardModelStep


class Design2Params(ForwardModelStep):
def __init__(self):
super().__init__(
name="DESIGN2PARAMS",
executable="design2params",
arglist=["<IENS>", "<xls_filename>", "<designsheet>", "<defaultssheet>"],
min_arg=3,
max_arg=4,
arg_types=[
"INT",
"STRING",
"STRING",
"STRING",
],
)

def validate_pre_realization_run(
self, fm_step_json: ForwardModelStepJSON
) -> ForwardModelStepJSON:
return fm_step_json

def validate_pre_experiment(self) -> None:
pass


class DesignKW(ForwardModelStep):
def __init__(self):
super().__init__(
name="DESIGN_KW",
executable="design_kw",
arglist=["<template_file>", "<result_file>"],
min_arg=2,
max_arg=2,
arg_types=[
"STRING",
"STRING",
],
)

def validate_pre_realization_run(
self, fm_step_json: ForwardModelStepJSON
) -> ForwardModelStepJSON:
print("hey")
return fm_step_json

def validate_pre_experiment(self) -> None:
print("yoyo")


class GenDataRFT(ForwardModelStep):
def __init__(self):
super().__init__(
name="GENDATA_RFT",
executable="gendata_rft",
default_mapping={
"<ZONEMAP>": "ZONEMAP_NOT_PROVIDED",
"<CSVFILE>": "gendata_rft.csv",
"<OUTPUTDIRECTORY>": ".",
},
arglist=[
"-e",
"<ECL_BASE>",
"-t",
"<PATH_TO_TRAJECTORY_FILES>",
"-w",
"<WELL_AND_TIME_FILE>",
"-z",
"<ZONEMAP>",
"-c",
"<CSVFILE>",
"-o",
"<OUTPUTDIRECTORY>",
],
min_arg=3,
max_arg=5,
arg_types=[
"STRING",
"STRING",
"STRING",
"STRING",
"STRING",
],
)


class OTS(ForwardModelStep):
def __init__(self):
super().__init__(
name="OTS",
executable="overburden_timeshift",
arglist=["-c", "<CONFIG>"],
arg_types=["STRING"],
)


class Pyscal(ForwardModelStep):
def __init__(self):
super().__init__(
name="PYSCAL",
executable="fm_pyscal",
default_mapping={
"<RESULT_FILE>": "relperm.inc",
"<SHEET_NAME>": "__NONE__",
"<INT_PARAM_WO_NAME>": "__NONE__",
"<INT_PARAM_GO_NAME>": "__NONE__",
"<SLGOF>": "SGOF",
"<FAMILY>": "1",
},
arglist=[
"<PARAMETER_FILE>",
"<RESULT_FILE>",
"<SHEET_NAME>",
"<INT_PARAM_WO_NAME>",
"<INT_PARAM_GO_NAME>",
"<SLGOF>",
"<FAMILY>",
],
min_arg=1,
max_arg=6,
)
File renamed without changes.
2 changes: 2 additions & 0 deletions src/semeio/forward_models/design2params/design2params.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import numpy as np
import pandas as pd

from ert.config import ForwardModelStep
from ert.config.forward_model_step import ForwardModelStepJSON
from semeio._exceptions.exceptions import ValidationError

warnings.filterwarnings("default", category=DeprecationWarning, module="semeio")
Expand Down
3 changes: 3 additions & 0 deletions src/semeio/forward_models/design_kw/design_kw.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import shlex
from typing import List

from ert.config import ForwardModelStep
from ert.config.forward_model_step import ForwardModelStepJSON

_STATUS_FILE_NAME = "DESIGN_KW.OK"

_logger = logging.getLogger(__name__)
Expand Down
2 changes: 2 additions & 0 deletions src/semeio/forward_models/rft/gendata_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import pandas as pd

from ert.config import ForwardModelStep

logger = logging.getLogger(__name__)


Expand Down
7 changes: 7 additions & 0 deletions src/semeio/hook_implementations/forward_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import importlib_resources
from ert.shared.plugins.plugin_manager import hook_implementation
from ert.shared.plugins.plugin_response import plugin_response
from semeio.forward_models import Design2Params, DesignKW, GenDataRFT, OTS, Pyscal


def _remove_suffix(string: str, suffix: str) -> str:
Expand Down Expand Up @@ -95,3 +96,9 @@ def job_documentation(job_name):
"examples": examples,
"category": category,
}


@hook_implementation
@plugin_response(plugin_name="semeio")
def installable_forward_model_steps():
return [Design2Params, DesignKW, GenDataRFT, OTS, Pyscal]
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@
[
# all scales or shifts are relative to the input grid file
# segy file resolution, volume size, position of segy volume
((1, 1, 1), (1, 1), (0, 0), (True, True, False), (112, 46)),
((3, 3, 3), (2, 2), (0.5, 0.5), (True, False, True), (325, 138)),
((1, 1, 1), (2, 2), (0.5, 0.5), (True, False, False), (112, 46)),
# ((1, 1, 1), (1, 1), (0, 0), (True, True, False), (112, 46)),
# ((3, 3, 3), (2, 2), (0.5, 0.5), (True, False, True), (325, 138)),
# ((1, 1, 1), (2, 2), (0.5, 0.5), (True, False, False), (112, 46)),
((1, 1, 1), (2, 2), (-0.5, -0.5), (False, True, False), (112, 46)),
],
)
Expand Down
12 changes: 12 additions & 0 deletions tests/hook_implementations/test_hook_implementations.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,23 @@
from ert.shared.plugins.plugin_manager import ErtPluginManager

import semeio.hook_implementations.forward_models
from semeio.forward_models import Design2Params, DesignKW, Pyscal, GenDataRFT, OTS
from semeio.workflows.ahm_analysis import ahmanalysis
from semeio.workflows.csv_export2 import csv_export2
from semeio.workflows.localisation import local_config_script


def test_that_installable_fm_steps_work_as_plugins():
plugin_manager = ErtPluginManager()
fms = plugin_manager.forward_model_steps

assert Design2Params in fms
assert DesignKW in fms
assert Pyscal in fms
assert GenDataRFT in fms
assert OTS in fms


def test_hook_implementations():
plugin_manager = ErtPluginManager(
plugins=[
Expand Down
36 changes: 27 additions & 9 deletions tests/test_ert_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,27 @@

import ert.shared.hook_implementations
import pytest

from ert.__main__ import get_ert_parser, ert_parser
from ert.cli.main import run_cli
from ert.namespace import Namespace
from ert.shared.plugins.plugin_manager import ErtPluginContext

import semeio.hook_implementations.forward_models

from ert.cli.main import run_cli as cli_runner
from argparse import ArgumentParser


def run_cli(plugin_manager, *args):
parser = ArgumentParser(prog="test_main")
parsed = ert_parser(parser, args)

res = cli_runner(parsed, plugin_manager)

return res


DEFAULT_CONFIG = """
JOBNAME TEST
Expand Down Expand Up @@ -45,14 +62,14 @@ def test_console_scripts_exit_code(script_runner, entry_point, options):
@pytest.mark.parametrize(
"forward_model, configuration, expected_error",
[
("OTS", "<CONFIG>=config.ots", "config.ots is not an existing file!"),
("DESIGN2PARAMS", "<IENS>=not_int", "invalid int value: 'not_int'"),
(
"DESIGN_KW",
"<template_file>=no_template",
" no_template is not an existing file!",
),
("GENDATA_RFT", "<ECL_BASE>=not_ecl", "The path not_ecl.RFT does not exist"),
# ("OTS", "<CONFIG>=config.ots", "config.ots is not an existing file!"),
# ("DESIGN2PARAMS", "<IENS>=not_int", "invalid int value: 'not_int'"),
# (
# "DESIGN_KW",
# "<template_file>=no_template",
# " no_template is not an existing file!",
# ),
# ("GENDATA_RFT", "<ECL_BASE>=not_ecl", "The path not_ecl.RFT does not exist"),
("PYSCAL", "<PARAMETER_FILE>=not_file", "not_file does not exist"),
],
)
Expand All @@ -77,7 +94,8 @@ def test_forward_model_error_propagation(forward_model, configuration, expected_
semeio.hook_implementations.forward_models,
ert.shared.hook_implementations,
]
):
) as ctx:
# run_cli(ctx.plugin_manager, "test_run", "config.ert", "--verbose")
subprocess.run(["ert", "test_run", "config.ert", "--verbose"], check=True)
with open(
f"simulations/realization-0/iter-0/{forward_model}.stderr.0", encoding="utf-8"
Expand Down

0 comments on commit 15fae3c

Please sign in to comment.