Merge pull request #5391 from jenshnielsen/export_large_dataset
Use Dask delayed to export large datasets to NetCDF
jenshnielsen committed Oct 4, 2023
2 parents 30659dd + 9f43f65 commit 4b58493
Showing 6 changed files with 245 additions and 35 deletions.
3 changes: 3 additions & 0 deletions docs/changes/newsfragments/5391.new
@@ -0,0 +1,3 @@
Large datasets are now exported to NetCDF4 using the Dask delayed writer.
This avoids allocating a large amount of memory to process the whole dataset at once.
The size threshold is currently set to approximately 1 GB.
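For readers unfamiliar with the delayed writer mentioned above, here is a minimal, self-contained sketch (not part of this commit) of the pattern xarray and Dask provide; the toy data and the file name "example.nc" are made up, and dask plus h5netcdf are assumed to be installed:

import numpy as np
import xarray as xr

# A toy dataset; a real QCoDeS export would come from DataSet.to_xarray_dataset().
ds = xr.Dataset({"signal": ("x", np.random.rand(1_000))}, coords={"x": np.arange(1_000)})

# With compute=False, to_netcdf returns a dask.delayed.Delayed object instead of
# writing immediately, so the data is streamed to disk chunk by chunk on .compute().
write_job = ds.chunk({"x": 100}).to_netcdf("example.nc", engine="h5netcdf", compute=False)
write_job.compute()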
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -40,14 +40,15 @@ dependencies = [
"ruamel.yaml>=0.16.0,!=0.16.6",
"tabulate>=0.8.0",
"typing_extensions>=4.1.1",
"tqdm>=4.32.2",
"tqdm>=4.59.0",
"uncertainties>=3.1.4",
"versioningit>=2.0.1",
"websockets>=9.1",
"wrapt>=1.13.2",
"xarray>=2022.06.0",
"cf_xarray>=0.8.4",
"opentelemetry-api>=1.15.0",
"dask>=2022.1.0", # we are making use of xarray features that requires dask implicitly
# transitive dependencies. We list these explicitly to",
# ensure that we always use versions that do not have",
# known security vulnerabilities",
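The dask pin added above is motivated by xarray APIs used later in this diff that only work when dask is importable; a small illustration (the variable name "signal" and the file names are illustrative, e.g. per-row files like those written further below):

import xarray as xr

# open_mfdataset lazily combines many files into a single dask-backed Dataset;
# it raises an error if dask is not installed, hence the explicit dependency.
combined = xr.open_mfdataset(["ds_0.nc", "ds_1.nc"])
print(type(combined["signal"].data))  # a dask array, loaded only when computed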
104 changes: 87 additions & 17 deletions qcodes/dataset/data_set.py
@@ -3,15 +3,19 @@
import importlib
import json
import logging
import sys
import tempfile
import time
import uuid
from collections.abc import Callable, Mapping, Sequence
from dataclasses import dataclass
from pathlib import Path
from queue import Queue
from threading import Thread
from typing import TYPE_CHECKING, Any

import numpy
from tqdm.auto import trange

import qcodes
from qcodes.dataset.data_set_protocol import (
@@ -76,7 +80,6 @@
)
from qcodes.utils import (
NumpyJSONEncoder,
QCoDeSDeprecationWarning,
deprecate,
issue_deprecation_warning,
)
@@ -93,6 +96,7 @@
from .exporters.export_to_xarray import (
load_to_xarray_dataarray_dict,
load_to_xarray_dataset,
xarray_to_h5netcdf_with_complex_numbers,
)
from .subscriber import _Subscriber

@@ -244,6 +248,7 @@ def __init__(
self._cache: DataSetCacheWithDBBackend = DataSetCacheWithDBBackend(self)
self._results: list[dict[str, VALUE]] = []
self._in_memory_cache = in_memory_cache
self._export_limit = 1000

if run_id is not None:
if not run_exists(self.conn, run_id):
@@ -859,7 +864,6 @@ def to_pandas_dataframe_dict(
a column and is indexed by a :py:class:`pandas.MultiIndex` formed
by the dependencies.
"""
self._warn_if_set(*params, start=start, end=end)
datadict = self.get_parameter_data(*params,
start=start,
end=end)
@@ -958,7 +962,6 @@ def to_pandas_dataframe(
Return a pandas DataFrame with
df = ds.to_pandas_dataframe()
"""
self._warn_if_set(*params, start=start, end=end)
datadict = self.get_parameter_data(*params,
start=start,
end=end)
@@ -1010,7 +1013,6 @@ def to_xarray_dataarray_dict(
dataarray_dict = ds.to_xarray_dataarray_dict()
"""
self._warn_if_set(*params, start=start, end=end)
data = self.get_parameter_data(*params,
start=start,
end=end)
@@ -1061,7 +1063,6 @@ def to_xarray_dataset(
xds = ds.to_xarray_dataset()
"""
self._warn_if_set(*params, start=start, end=end)
data = self.get_parameter_data(*params,
start=start,
end=end)
@@ -1457,20 +1458,89 @@ def _set_export_info(self, export_info: ExportInfo) -> None:

self._export_info = export_info

@staticmethod
def _warn_if_set(
*params: str | ParamSpec | ParameterBase,
start: int | None = None,
end: int | None,
) -> None:
if len(params) > 0 or start is not None or end is not None:
QCoDeSDeprecationWarning(
"Passing params, start or stop to to_xarray_... and "
"to_pandas_... methods is deprecated "
"This will be an error in the future. "
"If you need to sub-select use `dataset.get_parameter_data`"
def _export_as_netcdf(self, path: Path, file_name: str) -> Path:
"""Export data as netcdf to a given path with file prefix"""
import xarray as xr

file_path = path / file_name
if self._estimate_ds_size() > self._export_limit:
log.info(
"Dataset is expected to be larger that threshold. Using distributed export.",
extra={
"file_name": file_path,
"qcodes_guid": self.guid,
"ds_name": self.name,
"exp_name": self.exp_name,
"_export_limit": self._export_limit,
"_estimated_ds_size": self._estimate_ds_size(),
},
)
print(
"Large dataset detected. Will write to multiple files first and combine after, to reduce memory overhead."
)
with tempfile.TemporaryDirectory() as temp_dir:
temp_path = Path(temp_dir)
log.info(
"Writing individual files to temp dir.",
extra={
"file_name": file_path,
"qcodes_guid": self.guid,
"ds_name": self.name,
"exp_name": self.exp_name,
"temp_dir": temp_dir,
},
)
num_files = len(self)
num_digits = len(str(num_files))
file_name_template = f"ds_{{:0{num_digits}d}}.nc"
for i in trange(num_files, desc="Writing individual files"):
xarray_to_h5netcdf_with_complex_numbers(
self.to_xarray_dataset(start=i + 1, end=i + 1),
temp_path / file_name_template.format(i),
)
files = tuple(temp_path.glob("*.nc"))
data = xr.open_mfdataset(files)
try:
log.info(
"Combining temp files into one file.",
extra={
"file_name": file_path,
"qcodes_guid": self.guid,
"ds_name": self.name,
"exp_name": self.exp_name,
"temp_dir": temp_dir,
},
)
xarray_to_h5netcdf_with_complex_numbers(
data, file_path, compute=False
)
finally:
data.close()
else:
log.info(
"Writing netcdf file directly.",
extra={"file_name": file_path},
)

file_path = super()._export_as_netcdf(path=path, file_name=file_name)
return file_path

def _estimate_ds_size(self) -> float:
"""
Give an estimated size of the dataset as the size of a single row
times the length of the dataset. The result is returned in megabytes (MB).
Note that this does not take overhead into account, so it is more accurate
when the row size is large.
"""
sample_data = self.get_parameter_data(start=1, end=1)
row_size = 0.0

for param_data in sample_data.values():
for array in param_data.values():
row_size += sys.getsizeof(array)
return row_size * len(self) / 1024 / 1024


# public api
def load_by_run_spec(
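To summarize the control flow added to data_set.py above, here is a condensed, standalone sketch; the names estimate_size_mb, export_in_chunks and row_datasets are illustrative and not QCoDeS API, each row is assumed to already be available as an xarray Dataset, and dask plus h5netcdf are assumed to be installed:

import sys
import tempfile
from pathlib import Path

import numpy as np
import xarray as xr


def estimate_size_mb(sample_row: dict[str, np.ndarray], n_rows: int) -> float:
    # Size of one row (the sum of its arrays) times the number of rows, in MB;
    # overhead is ignored, so the estimate is best when rows are large.
    row_size = sum(sys.getsizeof(arr) for arr in sample_row.values())
    return row_size * n_rows / 1024 / 1024


def export_in_chunks(row_datasets: list[xr.Dataset], final_path: Path) -> None:
    # Write each row to its own small file, then let dask combine them lazily
    # and stream the merged result into a single NetCDF file.
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)
        for i, row_ds in enumerate(row_datasets):
            row_ds.to_netcdf(temp_path / f"ds_{i:03d}.nc", engine="h5netcdf")
        combined = xr.open_mfdataset(sorted(temp_path.glob("*.nc")))
        try:
            write_job = combined.to_netcdf(final_path, engine="h5netcdf", compute=False)
            write_job.compute()  # must run before the temporary files are removed
        finally:
            combined.close()

In the real method, the comparison self._estimate_ds_size() > self._export_limit (with _export_limit set to 1000 MB, as shown above) decides between this chunked path and a direct call to the parent class's _export_as_netcdf.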
18 changes: 9 additions & 9 deletions qcodes/dataset/data_set_protocol.py
@@ -1,16 +1,8 @@
from __future__ import annotations

import sys

if sys.version_info >= (3, 10):
# new entrypoints api was added in 3.10
from importlib.metadata import entry_points
else:
# 3.9 and earlier
from importlib_metadata import entry_points

import logging
import os
import sys
import warnings
from collections.abc import Mapping, Sequence
from enum import Enum
@@ -39,6 +31,13 @@
from .exporters.export_to_xarray import xarray_to_h5netcdf_with_complex_numbers
from .sqlite.queries import raw_time_to_str_time

if sys.version_info >= (3, 10):
# new entrypoints api was added in 3.10
from importlib.metadata import entry_points
else:
# 3.9 and earlier
from importlib_metadata import entry_points

if TYPE_CHECKING:
import pandas as pd
import xarray as xr
@@ -242,6 +241,7 @@ def get_parameter_data(
) -> ParameterData:
...


def get_parameters(self) -> SPECS:
# used by plottr
...
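For context on the import shuffle above: both branches expose the same keyword-filtered entry_points API, so the calling code is identical on all supported Python versions. A minimal sketch follows; the group name "example.plugins" is purely illustrative and not the group QCoDeS actually queries:

import sys

if sys.version_info >= (3, 10):
    # the keyword-filtered entry_points API is in the stdlib from 3.10 onwards
    from importlib.metadata import entry_points
else:
    # 3.9 and earlier fall back to the importlib_metadata backport
    from importlib_metadata import entry_points

for ep in entry_points(group="example.plugins"):
    plugin = ep.load()  # imports and returns the registered object
    print(ep.name, plugin)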
31 changes: 23 additions & 8 deletions qcodes/dataset/exporters/export_to_xarray.py
@@ -1,12 +1,14 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Hashable, Mapping
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, cast

import numpy as np
from tqdm.dask import TqdmCallback

from qcodes.dataset.linked_datasets.links import links_to_str

@@ -23,6 +25,8 @@

from qcodes.dataset.data_set_protocol import DataSetProtocol, ParameterData

_LOG = logging.getLogger(__name__)


def _calculate_index_shape(idx: pd.Index | pd.MultiIndex) -> dict[Hashable, int]:
# heavily inspired by xarray.core.dataset.from_dataframe
@@ -207,7 +211,7 @@ def _paramspec_dict_with_extras(


def xarray_to_h5netcdf_with_complex_numbers(
xarray_dataset: xr.Dataset, file_path: str | Path
xarray_dataset: xr.Dataset, file_path: str | Path, compute: bool = True
) -> None:
import cf_xarray as cfxr
from pandas import MultiIndex
Expand All @@ -230,18 +234,29 @@ def xarray_to_h5netcdf_with_complex_numbers(
internal_ds.data_vars[data_var].dtype.kind for data_var in internal_ds.data_vars
]
coord_kinds = [internal_ds.coords[coord].dtype.kind for coord in internal_ds.coords]
if "c" in data_var_kinds or "c" in coord_kinds:
allow_invalid_netcdf = "c" in data_var_kinds or "c" in coord_kinds

with warnings.catch_warnings():
# see http://xarray.pydata.org/en/stable/howdoi.html
# for how to export complex numbers
with warnings.catch_warnings():
if allow_invalid_netcdf:
warnings.filterwarnings(
"ignore",
module="h5netcdf",
message="You are writing invalid netcdf features",
category=UserWarning,
)
internal_ds.to_netcdf(
path=file_path, engine="h5netcdf", invalid_netcdf=True
)
else:
internal_ds.to_netcdf(path=file_path, engine="h5netcdf")
maybe_write_job = internal_ds.to_netcdf(
path=file_path,
engine="h5netcdf",
invalid_netcdf=allow_invalid_netcdf,
compute=compute, # pyright: ignore
)
# https://github.com/microsoft/pyright/issues/6069
if not compute and maybe_write_job is not None:
with TqdmCallback(desc="Combining files"):
_LOG.info(
"Writing netcdf file using Dask delayed writer.",
extra={"file_name": file_path},
)
maybe_write_job.compute()
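Below is a standalone sketch of the two behaviours this hunk combines: allowing complex-valued data via invalid_netcdf and deferring the write with compute=False plus a tqdm progress bar. The data and the output file name are made up, and dask, h5netcdf and tqdm are assumed to be installed:

import numpy as np
import xarray as xr
from tqdm.dask import TqdmCallback

# dtype kind "c" marks complex data; h5netcdf can store it, but only as an
# HDF5 extension that is not valid NetCDF, hence invalid_netcdf=True.
ds = xr.Dataset({"s21": ("f", np.exp(1j * np.linspace(0, np.pi, 500)))})
needs_invalid = any(ds[v].dtype.kind == "c" for v in ds.data_vars)

write_job = ds.chunk({"f": 100}).to_netcdf(
    "complex_example.nc",
    engine="h5netcdf",
    invalid_netcdf=needs_invalid,
    compute=False,
)
with TqdmCallback(desc="Writing chunks"):
    write_job.compute()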
