Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/esnb/core/NotebookDiagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from esnb.sites import gfdl

from . import html, util
from .CaseGroup2 import CaseGroup2
from .RequestedVariable import RequestedVariable
from .util2 import flatten_list, generate_tempdir_path, read_json, reset_encoding
from .util2 import (flatten_list, generate_tempdir_path, read_json,
reset_encoding)
from .VirtualDataset import VirtualDataset

# import warnings
Expand Down Expand Up @@ -442,6 +444,22 @@ def resolve(self, groups=None):
groups : list or None, optional
List of groups to resolve. If None, uses an empty list.
"""
esnb_case_data = os.environ.get("ESNB_CASE_DATA", None)
esnb_case_file = os.environ.get("ESNB_CASE_FILE", None)

logger.debug(f"Case override settings: ESNB_CASE_DATA={esnb_case_data}")
logger.debug(f"Case override settings: ESNB_CASE_FILE={esnb_case_file}")

if esnb_case_data is not None:
logger.info("Converting case override data to dict")
logger.info("This feature is not fully implemented; falling back to original groups")
groups = groups
elif esnb_case_file is not None:
logger.info(f"Reading case override settings from file: {esnb_case_file}")
if not os.path.exists(esnb_case_file):
raise FileNotFoundError(f"File does not exist: {esnb_case_file}")
groups = [CaseGroup2(esnb_case_file)]

groups = [] if groups is None else groups
groups = [groups] if not isinstance(groups, list) else groups
self.groups = groups
Expand Down
16 changes: 4 additions & 12 deletions src/esnb/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
from . import (
CaseExperiment,
CaseExperiment2,
CaseGroup,
CaseGroup2,
NotebookDiagnostic,
RequestedVariable,
util,
util2,
util_catalog,
util_xr,
)
from . import (CaseExperiment, CaseExperiment2, CaseGroup, CaseGroup2,
NotebookDiagnostic, RequestedVariable, util, util2,
util_catalog, util_mdtf, util_xr)

__all__ = [
"CaseExperiment",
Expand All @@ -22,6 +13,7 @@
"util2",
"util_case",
"util_catalog",
"util_mdtf",
"util_xr",
]

Expand Down
49 changes: 49 additions & 0 deletions src/esnb/core/util_mdtf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
"""Module with mdtf helper functions"""


def mdtf_settings_template_dict(**kwargs):
settings = {}
settings["pod_list"] = []
settings["case_list"] = {
"case_name": "case_name",
"model": "CMIP",
"convention": "CMIP",
"startdate": "00010101000000",
"enddate": "99990101000000",
}
settings["DATA_CATALOG"] = ""
settings["OBS_DATA_ROOT"] = ""
settings["WORK_DIR"] = ""
settings["OUTPUT_DIR"] = ""
settings["conda_root"] = ""
settings["conda_env_root"] = ""
settings["micromamba_exe"] = ""
settings["large_file"] = False
settings["save_ps"] = False
settings["save_pp_data"] = True
settings["translate_data"] = True
settings["make_variab_tar"] = False
settings["overwrite"] = True
settings["make_multicase_figure_html"] = False
settings["run_pp"] = True
settings["user_pp_scripts"] = {}

required_keys = list(settings.keys())
for key in required_keys:
if isinstance(settings[key], dict):
subkeys = list(settings[key].keys())
for subkey in subkeys:
result = kwargs.pop(subkey,"null result")
if result != "null result":
settings[key][subkey] = result
else:
result = kwargs.pop(key,"null result")
if result != "null result":
settings[key] = result

leftover_keys = list(kwargs.keys())
if len(leftover_keys) > 0:
for key in leftover_keys:
settings[key] = kwargs.pop(key, "")

return settings
14 changes: 12 additions & 2 deletions tests/test_mdtf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

import esnb
import pytest
from esnb.core import mdtf
from esnb.core.util_mdtf import mdtf_settings_template_dict


def test_MDTFCaseSettings():
Expand All @@ -14,3 +14,13 @@ def test_MDTFCaseSettings_invalid_file():
with pytest.raises(FileNotFoundError):
x = mdtf.MDTFCaseSettings
x = x.load_mdtf_settings_file(x, "non_existent_file.yml")


def test_mdtf_settings_template_dict_1():
result = mdtf_settings_template_dict()
assert len(result) == 18


def test_mdtf_settings_template_dict_2():
result = mdtf_settings_template_dict(foo="bar", startdate="18501231000000")
assert len(result) == 19