Merged
9 changes: 9 additions & 0 deletions conftest.py
@@ -72,6 +72,15 @@ def pytest_addoption(parser):
}


try:
import pytest_xdist
except ImportError:
# If pytest-xdist is not available, we provide a dummy worker_id fixture.
@pytest.fixture()
def worker_id():
return "master"


@pytest.fixture(params=_BACKENDS)
def backend(pytestconfig: pytest.Config, request: pytest.FixtureRequest):
backends_provided = any(map(pytestconfig.getoption, _BACKENDS))
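For context: pytest-xdist normally provides the `worker_id` fixture itself (returning e.g. "gw0", "gw1", or "master" for a non-distributed run), so the stub above only keeps tests that request the fixture working when the plugin is missing. A minimal sketch of a hypothetical consumer fixture (not part of this changeset):

import pytest

@pytest.fixture
def scratch_dir(tmp_path_factory, worker_id):
    # Hypothetical example: give each xdist worker its own scratch area so
    # parallel runs do not collide; without xdist the fallback yields "master".
    return tmp_path_factory.mktemp(f"scratch-{worker_id}")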
8 changes: 8 additions & 0 deletions imaspy/backends/netcdf/db_entry_nc.py
@@ -33,12 +33,20 @@ def __init__(self, fname: str, mode: str, factory: IDSFactory) -> None:
"The `netCDF4` python module is not available. Please install this "
"module to read/write IMAS netCDF files with IMASPy."
)
# netCDF4 v1.4 has no mode "x", so we emulate exclusive creation by opening
# with mode "w" and `clobber=False`.
if mode == "x":
mode = "w"
clobber = False
else:
clobber = True

self._dataset = netCDF4.Dataset(
fname,
mode,
format="NETCDF4",
auto_complex=True,
clobber=clobber,
)
"""NetCDF4 dataset."""
self._factory = factory
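A note on the mode mapping above: in netCDF4-python, opening with mode "w" and `clobber=False` refuses to overwrite an existing file, which reproduces the exclusive-create behaviour of mode "x". An illustrative sketch (not part of this changeset; the path is made up):

import os
import tempfile

import netCDF4

path = os.path.join(tempfile.mkdtemp(), "example.nc")
netCDF4.Dataset(path, "w", clobber=False).close()
try:
    # The second open fails because the file exists, mirroring mode "x" semantics.
    netCDF4.Dataset(path, "w", clobber=False)
except OSError as exc:  # the exact OSError subclass varies between versions
    print(f"refused to overwrite existing file: {exc}")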
15 changes: 14 additions & 1 deletion imaspy/backends/netcdf/ids2nc.py
@@ -7,8 +7,10 @@

import netCDF4
import numpy
from packaging import version

from imaspy.backends.netcdf.nc_metadata import NCMetadata
from imaspy.exception import InvalidNetCDFEntry
from imaspy.ids_base import IDSBase
from imaspy.ids_data_type import IDSDataType
from imaspy.ids_defs import IDS_TIME_MODE_HOMOGENEOUS
@@ -185,9 +187,20 @@ def create_variables(self) -> None:

else:
dtype = dtypes[metadata.data_type]
if (
version.parse(netCDF4.__version__) < version.parse("1.7.0")
and dtype is dtypes[IDSDataType.CPX]
):
raise InvalidNetCDFEntry(
f"Found complex data in {var_name}, NetCDF 1.7.0 or"
f" later is required for complex data types"
)
kwargs = {}
if dtype is not str: # Enable compression:
kwargs.update(compression="zlib", complevel=1)
if version.parse(netCDF4.__version__) > version.parse("1.4.1"):
kwargs.update(compression="zlib", complevel=1)
else:
kwargs.update(zlib=True, complevel=1)
if dtype is not dtypes[IDSDataType.CPX]: # Set fillvalue
kwargs.update(fill_value=default_fillvals[metadata.data_type])
# Create variable
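For reference, the two branches above select between the newer `compression=` keyword and the legacy `zlib=True` flag of `createVariable`. A small sketch of the same selection, factored into a helper (illustrative only; the helper name and the usage line are hypothetical):

import netCDF4
from packaging import version

def compression_kwargs() -> dict:
    # Mirror the version gate above: newer netCDF4 accepts `compression=`,
    # older releases only understand the legacy `zlib=True` flag.
    if version.parse(netCDF4.__version__) > version.parse("1.4.1"):
        return {"compression": "zlib", "complevel": 1}
    return {"zlib": True, "complevel": 1}

# e.g. group.createVariable("values", "f8", ("time",), **compression_kwargs())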
9 changes: 5 additions & 4 deletions imaspy/backends/netcdf/nc_validate.py
@@ -23,23 +23,24 @@ def validate_netcdf_file(filename: str) -> None:
# additional variables are smuggled inside:
groups = [dataset] + [dataset[group] for group in dataset.groups]
for group in groups:
group_name = group.path.split("/")[-1]
if group.variables or group.dimensions:
raise InvalidNetCDFEntry(
"NetCDF file should not have variables or dimensions in the "
f"{group.name} group."
f"{group_name} group."
)
if group is dataset:
continue
if group.name not in ids_names:
if group_name not in ids_names:
raise InvalidNetCDFEntry(
f"Invalid group name {group.name}: there is no IDS with this name."
f"Invalid group name {group_name}: there is no IDS with this name."
)
for subgroup in group.groups:
try:
int(subgroup)
except ValueError:
raise InvalidNetCDFEntry(
f"Invalid group name {group.name}/{subgroup}: "
f"Invalid group name {group_name}/{subgroup}: "
f"{subgroup} is not a valid occurrence number."
)

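On the `group_name` change above: deriving the name from `group.path` works uniformly for the root group and for nested groups, presumably avoiding a `name` lookup that is not available on the root Dataset in all netCDF4 releases. A pure-Python illustration of the split:

# The last path component is the group name; the root group's path is "/",
# which yields an empty name.
assert "/core_profiles".split("/")[-1] == "core_profiles"
assert "/core_profiles/0".split("/")[-1] == "0"
assert "/".split("/")[-1] == ""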
5 changes: 4 additions & 1 deletion imaspy/ids_primitive.py
@@ -481,7 +481,10 @@ def _cast_value(self, value):
value = np.asanyarray(value)
if value.dtype != dtype:
logger.info(_CONVERT_MSG, value.dtype, self)
value = np.array(value, dtype=dtype, copy=False)
value = np.asarray(value, dtype=dtype)
if value.ndim != self.metadata.ndim:
raise ValueError(f"Trying to assign a {value.ndim}D value to {self!r}.")
return value
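Background on the `np.array(..., copy=False)` → `np.asarray(...)` switch: NumPy 2.0 changed `copy=False` to mean "never copy" and to raise when a copy is unavoidable (for example, when a dtype conversion is needed), whereas `np.asarray` copies only when required. A minimal sketch of the difference, assuming NumPy 2.x:

import numpy as np

data = np.array([1, 2, 3], dtype=np.int32)

converted = np.asarray(data, dtype=np.float64)  # copies only because the dtype differs

try:
    np.array(data, dtype=np.float64, copy=False)  # NumPy 2.x: raises ValueError
except ValueError as exc:
    print(f"copy=False now means 'never copy': {exc}")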
12 changes: 9 additions & 3 deletions imaspy/test/test_cli.py
@@ -4,6 +4,7 @@
from click.testing import CliRunner
from packaging.version import Version

from imaspy.backends.imas_core.imas_interface import has_imas
from imaspy.backends.imas_core.imas_interface import ll_interface
from imaspy.command.cli import print_version
from imaspy.command.db_analysis import analyze_db, process_db_analysis
@@ -12,15 +13,20 @@


@pytest.mark.cli
def test_imaspy_version():
def test_imaspy_version(requires_imas):
runner = CliRunner()
result = runner.invoke(print_version)
assert result.exit_code == 0


@pytest.mark.cli
@pytest.mark.skipif(ll_interface._al_version < Version("5.0"), reason="Needs AL >= 5")
def test_db_analysis(tmp_path):
@pytest.mark.skipif(
not has_imas or ll_interface._al_version < Version("5.0"),
reason="Requires IMAS Core and AL >= 5",
)
def test_db_analysis(tmp_path):
# This only tests the happy flow, error handling is not tested
db_path = tmp_path / "test_db_analysis"
with DBEntry(f"imas:hdf5?path={db_path}", "w") as entry:
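On the combined skip condition above: the order of the operands matters, because the AL version comparison is only meaningful when IMAS Core is present. A small illustration, assuming `_al_version` can be None when IMAS Core is unavailable:

from packaging.version import Version

has_imas = False
_al_version = None  # assumed value when IMAS Core is unavailable

# `not has_imas` is evaluated first, so the comparison `None < Version("5.0")`,
# which would raise TypeError, is never reached.
skip = not has_imas or _al_version < Version("5.0")
assert skip is True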
2 changes: 1 addition & 1 deletion imaspy/test/test_dbentry.py
@@ -82,7 +82,7 @@ def test_dbentry_constructor():
assert get_entry_attrs(entry) == (1, 2, 3, 4, None, 6)


def test_ignore_unknown_dd_version(monkeypatch, worker_id, tmp_path):
def test_ignore_unknown_dd_version(monkeypatch, worker_id, tmp_path, requires_imas):
entry = open_dbentry(imaspy.ids_defs.MEMORY_BACKEND, "w", worker_id, tmp_path)
ids = entry.factory.core_profiles()
ids.ids_properties.homogeneous_time = 0
30 changes: 22 additions & 8 deletions imaspy/test/test_helpers.py
@@ -93,7 +93,9 @@ def fill_with_random_data(structure, max_children=3):
child.value = random_data(child.metadata.data_type, child.metadata.ndim)


def maybe_set_random_value(primitive: IDSPrimitive, leave_empty: float) -> None:
def maybe_set_random_value(
primitive: IDSPrimitive, leave_empty: float, skip_complex: bool
) -> None:
"""Set the value of an IDS primitive with a certain chance.

If the IDSPrimitive has coordinates, then the size of the coordinates is taken into
@@ -153,7 +155,7 @@ def maybe_set_random_value(primitive: IDSPrimitive, leave_empty: float) -> None:
# Scale chance of not setting a coordinate by our number of dimensions,
# such that overall there is roughly a 50% chance that any coordinate
# remains empty
maybe_set_random_value(coordinate_element, 0.5**ndim)
maybe_set_random_value(coordinate_element, 0.5**ndim, skip_complex)
size = coordinate_element.shape[0 if coordinate.references else dim]

if coordinate.size: # coordinateX = <path> OR 1...1
@@ -176,13 +178,18 @@ def maybe_set_random_value(primitive: IDSPrimitive, leave_empty: float) -> None:
elif primitive.metadata.data_type is IDSDataType.FLT:
primitive.value = np.random.random_sample(size=shape)
elif primitive.metadata.data_type is IDSDataType.CPX:
if skip_complex:
# If we are skipping complex numbers, leave the value empty.
return
val = np.random.random_sample(shape) + 1j * np.random.random_sample(shape)
primitive.value = val
else:
raise ValueError(f"Invalid IDS data type: {primitive.metadata.data_type}")


def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):
def fill_consistent(
structure: IDSStructure, leave_empty: float = 0.2, skip_complex: bool = False
):
"""Fill a structure with random data, such that coordinate sizes are consistent.

Sets homogeneous_time to heterogeneous (always).
@@ -196,6 +203,9 @@ def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):
exclusive_coordinates: list of IDSPrimitives that have exclusive alternative
coordinates. These are initially not filled, and only at the very end of
filling an IDSToplevel, a choice is made between the exclusive coordinates.
skip_complex: Whether to skip populating complex numbers. Useful for
compatibility with netCDF4 < 1.7.0, which does not support complex numbers.
"""
if isinstance(structure, IDSToplevel):
unsupported_ids_name = (
@@ -218,7 +228,9 @@ def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):

for child in structure:
if isinstance(child, IDSStructure):
exclusive_coordinates.extend(fill_consistent(child, leave_empty))
exclusive_coordinates.extend(
fill_consistent(child, leave_empty, skip_complex)
)

elif isinstance(child, IDSStructArray):
if child.metadata.coordinates[0].references:
@@ -230,7 +242,7 @@ def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):
if isinstance(coor, IDSPrimitive):
# maybe fill with random data:
try:
maybe_set_random_value(coor, leave_empty)
maybe_set_random_value(coor, leave_empty, skip_complex)
except (RuntimeError, ValueError):
pass
child.resize(len(coor))
@@ -244,7 +256,9 @@ def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):
else:
child.resize(child.metadata.coordinates[0].size or 1)
for ele in child:
exclusive_coordinates.extend(fill_consistent(ele, leave_empty))
exclusive_coordinates.extend(
fill_consistent(ele, leave_empty, skip_complex)
)

else: # IDSPrimitive
coordinates = child.metadata.coordinates
@@ -256,7 +270,7 @@ def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):
exclusive_coordinates.append(child)
else:
try:
maybe_set_random_value(child, leave_empty)
maybe_set_random_value(child, leave_empty, skip_complex)
except (RuntimeError, ValueError):
pass

@@ -278,7 +292,7 @@ def fill_consistent(structure: IDSStructure, leave_empty: float = 0.2):
coor = filled_refs.pop()
unset_coordinate(coor)

maybe_set_random_value(element, leave_empty)
maybe_set_random_value(element, leave_empty, skip_complex)
else:
return exclusive_coordinates

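A possible calling pattern for the new `skip_complex` flag is to derive it from the installed netCDF4 version; the tests below simply pass it explicitly. A sketch with a hypothetical helper name:

import netCDF4
from packaging import version

from imaspy.test.test_helpers import fill_consistent

def fill_for_netcdf(ids):
    # Only generate complex-valued fields when the installed netCDF4 release
    # can store them (1.7.0 and newer).
    too_old = version.parse(netCDF4.__version__) < version.parse("1.7.0")
    fill_consistent(ids, leave_empty=0.2, skip_complex=too_old)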
2 changes: 1 addition & 1 deletion imaspy/test/test_ids_toplevel.py
@@ -46,7 +46,7 @@ def test_pretty_print(ids):
assert pprint.pformat(ids) == "<IDSToplevel (IDS:gyrokinetics)>"


def test_serialize_nondefault_dd_version():
def test_serialize_nondefault_dd_version(requires_imas):
ids = IDSFactory("3.31.0").core_profiles()
fill_with_random_data(ids)
data = ids.serialize()
13 changes: 11 additions & 2 deletions imaspy/test/test_minimal_types.py
@@ -1,5 +1,6 @@
# A minimal testcase loading an IDS file and checking that the structure built is ok
from numbers import Complex, Integral, Number, Real
from packaging import version

import numpy as np
import pytest
@@ -61,7 +62,11 @@ def test_assign_str_1d(minimal, caplog):


# Prevent the expected numpy ComplexWarnings from cluttering pytest output
@pytest.mark.filterwarnings("ignore::numpy.ComplexWarning")
@pytest.mark.filterwarnings(
"ignore::numpy.ComplexWarning"
if version.parse(np.__version__) < version.parse("1.25")
else "ignore::numpy.exceptions.ComplexWarning"
)
@pytest.mark.parametrize("typ, max_dim", [("flt", 6), ("cpx", 6), ("int", 3)])
def test_assign_numeric_types(minimal, caplog, typ, max_dim):
caplog.set_level("INFO", "imaspy")
@@ -87,7 +92,11 @@ def test_assign_numeric_types(minimal, caplog, typ, max_dim):
len(caplog.records) == 1
elif dim == other_ndim >= 1 and other_typ == "cpx":
# np allows casting of complex to float or int, but warns:
with pytest.warns(np.ComplexWarning):
with pytest.warns(
np.ComplexWarning
if version.parse(np.__version__) < version.parse("1.25")
else np.exceptions.ComplexWarning
):
caplog.clear()
minimal[name].value = value
assert len(caplog.records) == 1
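The version checks above account for `ComplexWarning` moving to `numpy.exceptions` in NumPy 1.25. An alternative pattern (a sketch, not what this change does) is to resolve the class once via a guarded import and reuse it:

try:
    from numpy.exceptions import ComplexWarning  # NumPy >= 1.25
except ImportError:
    from numpy import ComplexWarning  # older NumPy releases

# The alias can be used directly with pytest.warns(ComplexWarning); the
# filterwarnings marker still needs a dotted path given as a string.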
2 changes: 1 addition & 1 deletion imaspy/test/test_nbc_change.py
@@ -54,7 +54,7 @@ def test_nbc_structure_to_aos(caplog):
assert caplog.record_tuples[0][:2] == ("imaspy.ids_convert", logging.WARNING)


def test_nbc_0d_to_1d(caplog):
def test_nbc_0d_to_1d(caplog, requires_imas):
# channel/filter_spectrometer/radiance_calibration in spectrometer visible changed
# from FLT_0D to FLT_1D in DD 3.39.0
ids = IDSFactory("3.32.0").spectrometer_visible()
49 changes: 47 additions & 2 deletions imaspy/test/test_nc_autofill.py
@@ -1,11 +1,56 @@
from imaspy.db_entry import DBEntry
from imaspy.exception import InvalidNetCDFEntry
from imaspy.test.test_helpers import compare_children, fill_consistent
import re
import pytest
import netCDF4
from packaging import version


def test_nc_latest_dd_autofill_put_get(ids_name, tmp_path):
def test_nc_latest_dd_autofill_put_get_skip_complex(ids_name, tmp_path):
with DBEntry(f"{tmp_path}/test-{ids_name}.nc", "x") as entry:
ids = entry.factory.new(ids_name)
fill_consistent(ids, 0.5)
fill_consistent(ids, leave_empty=0.5, skip_complex=True)

entry.put(ids)
ids2 = entry.get(ids_name)

compare_children(ids, ids2)


@pytest.mark.skipif(
version.parse(netCDF4.__version__) >= version.parse("1.7.0"),
reason="Exercises the error path on netCDF4 < 1.7.0, which lacks complex number support",
)
def test_nc_latest_dd_autofill_put_get_with_complex_older_netCDF4(ids_name, tmp_path):
with DBEntry(f"{tmp_path}/test-{ids_name}.nc", "x") as entry:
ids = entry.factory.new(ids_name)
fill_consistent(ids, leave_empty=0.5, skip_complex=False)
try:
entry.put(ids)
ids2 = entry.get(ids_name)
compare_children(ids, ids2)
except InvalidNetCDFEntry as e:
# This is expected, as these versions of NetCDF4 do not support
# complex numbers.
if not re.search(
r".*NetCDF 1.7.0 or later is required for complex data types", str(e)
):
raise InvalidNetCDFEntry(e) from e


@pytest.mark.skipif(
version.parse(netCDF4.__version__) < version.parse("1.7.0"),
reason="Complex numbers require netCDF4 >= 1.7.0",
)
def test_nc_latest_dd_autofill_put_get_with_complex_newer_netCDF4(ids_name, tmp_path):
with DBEntry(f"{tmp_path}/test-{ids_name}.nc", "x") as entry:
ids = entry.factory.new(ids_name)
fill_consistent(ids, leave_empty=0.5, skip_complex=False)

entry.put(ids)
ids2 = entry.get(ids_name)
2 changes: 1 addition & 1 deletion imaspy/test/test_static_ids.py
@@ -21,7 +21,7 @@ def test_ids_valid_type():
assert ids_types in ({IDSType.NONE}, {IDSType.CONSTANT, IDSType.DYNAMIC})


def test_constant_ids(caplog):
def test_constant_ids(caplog, requires_imas):
ids = imaspy.IDSFactory().new("amns_data")
if ids.metadata.type is IDSType.NONE:
pytest.skip("IDS definition has no constant IDSs")
6 changes: 3 additions & 3 deletions imaspy/test/test_util.py
@@ -54,7 +54,7 @@ def test_inspect():
inspect(cp.profiles_1d[1].grid.rho_tor_norm) # IDSPrimitive


def test_inspect_lazy():
def test_inspect_lazy(requires_imas):
with get_training_db_entry() as entry:
cp = entry.get("core_profiles", lazy=True)
inspect(cp)
@@ -141,7 +141,7 @@ def test_idsdiffgen():
assert diff[0] == ("profiles_1d/time", -1, 0)


def test_idsdiff():
def test_idsdiff(requires_imas):
# Test the diff rendering for two sample IDSs
with get_training_db_entry() as entry:
imaspy.util.idsdiff(entry.get("core_profiles"), entry.get("equilibrium"))
@@ -179,7 +179,7 @@ def test_get_toplevel():
assert get_toplevel(cp) is cp


def test_is_lazy_loaded():
def test_is_lazy_loaded(requires_imas):
with get_training_db_entry() as entry:
assert is_lazy_loaded(entry.get("core_profiles")) is False
assert is_lazy_loaded(entry.get("core_profiles", lazy=True)) is True