Make loading from netcdf lazy using load_by_guid etc. #5711

Merged · 5 commits · Feb 1, 2024
1 change: 1 addition & 0 deletions docs/changes/newsfragments/5711.improved
@@ -0,0 +1 @@
+As an extension to the feature added in #5627, datasets are also no longer converted into QCoDeS format when loaded from netcdf using ``load_by_guid``, ``load_by_id``, ``load_by_run_spec``, or ``load_by_counter``.
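A hedged usage sketch of what this changelog entry means in practice. The run id is a placeholder, and `cache._data` is private API, inspected here (as in the tests below) only to make the lazy behavior visible:

```python
# Sketch only: assumes a dataset that was previously exported to netcdf.
from qcodes.dataset import load_by_id

ds = load_by_id(42)             # placeholder run id; metadata comes from the database
assert ds.cache._data == {}     # the netcdf payload has not been parsed yet

xr_ds = ds.to_xarray_dataset()  # reads the netcdf file directly via xarray
assert ds.cache._data == {}     # still no round trip through the QCoDeS format
```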
37 changes: 13 additions & 24 deletions src/qcodes/dataset/data_set_in_memory.py
@@ -332,34 +332,23 @@ def _load_from_db(cls, conn: ConnectionPlus, guid: str) -> DataSetInMem:
             export_info=export_info,
             snapshot=run_attributes["snapshot"],
         )
-        xr_path = export_info.export_paths.get("nc")
+        xr_path_temp = export_info.export_paths.get("nc")
+        xr_path = Path(xr_path_temp) if xr_path_temp is not None else None
 
         cls._set_cache_from_netcdf(ds, xr_path)
         return ds
 
     @classmethod
-    def _set_cache_from_netcdf(cls, ds: DataSetInMem, xr_path: str | None) -> bool:
-        import cf_xarray as cfxr
-        import xarray as xr
+    def _set_cache_from_netcdf(cls, ds: DataSetInMem, xr_path: Path | None) -> bool:
 
         success = True
-        if xr_path is not None:
-            try:
-                loaded_data = xr.load_dataset(xr_path, engine="h5netcdf")
-                loaded_data = cfxr.coding.decode_compress_to_multi_index(loaded_data)
-                ds._cache = DataSetCacheInMem(ds)
-                ds._cache._data = cls._from_xarray_dataset_to_qcodes_raw_data(
-                    loaded_data
-                )
-            except (
-                FileNotFoundError,
-                OSError,
-            ):  # older versions of h5py may throw an OSError here
-                success = False
-                warnings.warn(
-                    "Could not load raw data for dataset with guid :"
-                    f"{ds.guid} from location {xr_path}"
-                )
+        if xr_path is not None and xr_path.is_file():
+            ds._cache = DataSetCacheDeferred(ds, xr_path)
+        elif xr_path is not None and not xr_path.is_file():
+            success = False
+            warnings.warn(
+                f"Could not load raw data for dataset with guid : {ds.guid} from location {xr_path}"
+            )
         else:
             warnings.warn(f"No raw data stored for dataset with guid : {ds.guid}")
             success = False
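The hunk above is the core of the change: the old branch eagerly parsed the netcdf file and converted it into QCoDeS format, while the new branch only records the file path in a `DataSetCacheDeferred`. A minimal sketch of that deferral pattern follows; the real `DataSetCacheDeferred` is defined elsewhere in qcodes and differs in detail, so treat this as illustration only:

```python
# Illustrative sketch of the deferred-cache idea, not the real implementation.
from pathlib import Path


class DeferredNetcdfCache:
    """Remember where the data lives; load it only on first access."""

    def __init__(self, dataset, nc_path: Path) -> None:
        self._dataset = dataset
        self._nc_path = nc_path
        self._data: dict = {}  # stays empty until data() is first called

    def data(self) -> dict:
        if not self._data:
            import xarray as xr

            loaded = xr.load_dataset(self._nc_path, engine="h5netcdf")
            # Convert to plain numpy arrays keyed by parameter name
            # (the real cache uses qcodes' nested dict format instead).
            self._data = {
                name: loaded[name].to_numpy() for name in loaded.data_vars
            }
        return self._data
```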
@@ -375,12 +364,12 @@ def set_netcdf_location(self, path: str | Path) -> None:
         be able to use this method to update the metadata in the database to refer to
         the new location.
         """
-        if isinstance(path, Path):
-            path = str(path)
+        if isinstance(path, str):
+            path = Path(path)
         data_loaded = self._set_cache_from_netcdf(self, path)
         if data_loaded:
             export_info = self.export_info
-            export_info.export_paths["nc"] = path
+            export_info.export_paths["nc"] = str(path)
             self._set_export_info(export_info)
         else:
             raise FileNotFoundError(f"Could not load a netcdf file from {path}")
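A hedged usage sketch for `set_netcdf_location` as changed above; the run id and file path are placeholders:

```python
# Sketch only: re-point a dataset at a moved netcdf file. Both str and Path
# are accepted; the path is normalized to Path internally and stored as str.
from pathlib import Path

from qcodes.dataset import load_by_id

ds = load_by_id(42)  # placeholder run id
ds.set_netcdf_location(Path("/data/moved/qcodes_42.nc"))  # placeholder path
# On success the "nc" entry in the dataset's export metadata now points at
# the new file; if no data can be loaded, FileNotFoundError is raised.
```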
81 changes: 81 additions & 0 deletions tests/dataset/test_dataset_export.py
@@ -20,6 +20,7 @@
     DataSetType,
     Measurement,
     get_data_export_path,
+    load_by_guid,
     load_by_id,
     load_from_netcdf,
     new_data_set,
@@ -131,6 +132,21 @@ def _make_mock_dataset_grid(experiment) -> DataSet:
     return dataset
 
 
+@pytest.fixture(name="mock_dataset_in_mem_grid")
+def _make_mock_dataset_in_mem_grid(experiment) -> DataSetProtocol:
+    meas = Measurement(exp=experiment, name="in_mem_ds")
+    meas.register_custom_parameter("x", paramtype="numeric")
+    meas.register_custom_parameter("y", paramtype="numeric")
+    meas.register_custom_parameter("z", paramtype="numeric", setpoints=("x", "y"))
+
+    with meas.run(dataset_class=DataSetType.DataSetInMem) as datasaver:
+        for x in range(10):
+            for y in range(20, 25):
+                results: list[tuple[str, int]] = [("x", x), ("y", y), ("z", x + y)]
+                datasaver.add_result(*results)
+    return datasaver.dataset
+
+
 @pytest.fixture(name="mock_dataset_grid_with_shapes")
 def _make_mock_dataset_grid_with_shapes(experiment) -> DataSet:
     dataset = new_data_set("dataset")
@@ -1408,3 +1424,68 @@ def test_export_lazy_load(
     getattr(ds, function_name)()
 
     assert ds.cache._data != {}
+
+
+@given(
+    function_name=hst.sampled_from(
+        [
+            "to_xarray_dataarray_dict",
+            "to_pandas_dataframe",
+            "to_pandas_dataframe_dict",
+            "get_parameter_data",
+        ]
+    )
+)
+@settings(suppress_health_check=(HealthCheck.function_scoped_fixture,), deadline=None)
+def test_export_lazy_load_in_mem_dataset(
+    tmp_path_factory: TempPathFactory,
+    mock_dataset_in_mem_grid: DataSet,
+    function_name: str,
+) -> None:
+    tmp_path = tmp_path_factory.mktemp("export_netcdf")
+    path = str(tmp_path)
+    mock_dataset_in_mem_grid.export(
+        export_type="netcdf", path=tmp_path, prefix="qcodes_"
+    )
+
+    xr_ds = mock_dataset_in_mem_grid.to_xarray_dataset()
+    assert xr_ds["z"].dims == ("x", "y")
+
+    expected_path = f"qcodes_{mock_dataset_in_mem_grid.captured_run_id}_{mock_dataset_in_mem_grid.guid}.nc"
+    assert os.listdir(path) == [expected_path]
+    file_path = os.path.join(path, expected_path)
+    ds = load_from_netcdf(file_path)
+
+    # loading the dataset should not load the actual data into cache
+    assert ds.cache._data == {}
+    # loading directly into xarray should not round
+    # trip to qcodes format and therefore not fill the cache
+    xr_ds_reimported = ds.to_xarray_dataset()
+    assert ds.cache._data == {}
+
+    assert xr_ds_reimported["z"].dims == ("x", "y")
+    assert xr_ds.identical(xr_ds_reimported)
+
+    # but loading with any of these functions
+    # will currently fill the cache
+    getattr(ds, function_name)()
+
+    assert ds.cache._data != {}
+
+    dataset_loaded_by_guid = load_by_guid(mock_dataset_in_mem_grid.guid)
+
+    # loading the dataset should not load the actual data into cache
+    assert dataset_loaded_by_guid.cache._data == {}
+    # loading directly into xarray should not round
+    # trip to qcodes format and therefore not fill the cache
+    xr_ds_reimported = dataset_loaded_by_guid.to_xarray_dataset()
+    assert dataset_loaded_by_guid.cache._data == {}
+
+    assert xr_ds_reimported["z"].dims == ("x", "y")
+    assert xr_ds.identical(xr_ds_reimported)
+
+    # but loading with any of these functions
+    # will currently fill the cache
+    getattr(dataset_loaded_by_guid, function_name)()
+
+    assert dataset_loaded_by_guid.cache._data != {}
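In short, the contract this new test pins down, as a sketch (names as in the test; `some_guid` is a placeholder, and `cache._data` is private API inspected purely for illustration):

```python
ds = load_by_guid(some_guid)  # lazy: only metadata is read
ds.to_xarray_dataset()        # still lazy: netcdf is read directly into xarray
ds.get_parameter_data()       # eager: the QCoDeS round trip fills ds.cache._data
```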