215 changes: 165 additions & 50 deletions src/esnb/core/NotebookDiagnostic.py
@@ -1,6 +1,8 @@
import json
import logging
import os
import shutil
from pathlib import Path

import fsspec
import xarray as xr
@@ -9,7 +11,13 @@

from . import html, util
from .RequestedVariable import RequestedVariable
from .util2 import flatten_list, infer_source_data_file_types, read_json
from .util2 import (
flatten_list,
generate_tempdir_path,
infer_source_data_file_types,
read_json,
reset_encoding,
)
from .VirtualDataset import VirtualDataset

# import warnings
@@ -74,6 +82,7 @@ def __init__(
dimensions=None,
variables=None,
varlist=None,
workdir=None,
**kwargs,
):
"""
@@ -93,6 +102,8 @@ def __init__(
List of variables for the diagnostic.
varlist : dict, optional
Dictionary of variable definitions.
workdir : str, optional
Path to the temporary working directory.
**kwargs
Additional keyword arguments for settings and user-defined options.
"""
@@ -103,6 +114,7 @@
self.dimensions = dimensions
self.variables = variables
self.varlist = varlist
self.workdir = workdir

self.name = self.source if self.name is None else self.name

@@ -184,6 +196,12 @@ def __init__(
# initialize an empty groups attribute
self.groups = []

# initialize workdir
if self.workdir is None:
self.workdir = generate_tempdir_path(self.name)
else:
logger.info(f"Diagnostic workdir is set to: {self.workdir}")

@property
def metrics(self):
"""
@@ -302,16 +320,18 @@ def dmget(self, status=False):
else:
gfdl.call_dmget(self.files, status=status)

def load(self, site="gfdl"):
def load(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr"):
"""
Load all groups by calling their load method.
"""
if hasattr(self.groups[0], "dmget"):
_ = [x.load() for x in self.groups]
else:
self.loader(site=site)
self.loader(
site=site, dmget=dmget, use_cache=use_cache, cache_format=cache_format
)

def loader(self, site="gfdl"):
def loader(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr"):
diag = self
groups = diag.groups
variables = diag.variables
@@ -321,7 +341,11 @@ def loader(self, site="gfdl"):
def _open_xr(files, varname=None):
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
_ds = xr.open_mfdataset(
files, decode_times=time_coder, decode_timedelta=True, **xr_merge_opts
files,
decode_times=time_coder,
decode_timedelta=True,
chunks={},
**xr_merge_opts,
)
if varname is not None:
ds = xr.Dataset()
@@ -349,7 +373,7 @@ def _open_gcs(files, varname=None):

return ds

if site == "gfdl":
if site == "gfdl" and dmget:
gfdl.call_dmget(diag.files)

# dictionary of datasets by var then group
@@ -359,41 +383,79 @@ def _open_gcs(files, varname=None):
for var in variables:
ds_by_var[var] = {}
for group in groups:
concat_dim = getattr(group, "concat_dim", None)
# print(var, group, concat_dim)
ncases = len(group.cases)
if ncases > 1:
assert concat_dim is not None, (
f"Multiple cases discovered in group {group} but no concat_dim found"
)

files = []
for case in group.cases:
# print(f" - {case}")
files.append(
list(case.catalog.search(variable_id=var.varname).df["path"])
)

# TODO implement infer_source_data_file_types()
file_type = infer_source_data_file_types(flatten_list(files))

if file_type == "unix_file":
dsets = [_open_xr(x, var.varname) for x in files]
elif file_type == "google_cloud":
dsets = [_open_gcs(x, var.varname) for x in files]
else:
raise ValueError(
f"There is no rule yet to open file type: {file_type}"
)
workdir = self.workdir
_date_range = str("_").join(list(group.date_range))
cached_file_name = f"{group.name}_{var.varname}_{_date_range}"
cached_file_name = Path(f"{cached_file_name}.{cache_format}")
cached_file_name = workdir / cached_file_name
logger.debug(f"Checking for cached file: {cached_file_name}")

if use_cache and cached_file_name.exists():
logger.info(f"Opening cached dataset: {cached_file_name}")
time_coder = xr.coders.CFDatetimeCoder(use_cftime=True)
if cache_format == "zarr":
ds = xr.open_zarr(
cached_file_name,
decode_times=time_coder,
decode_timedelta=True,
)
else:
raise ValueError(
f"Trying to open unsupported cache type: {cache_format}"
)

if len(dsets) > 1:
ds = xr.concat(dsets, concat_dim)
else:
ds = dsets[0]
concat_dim = getattr(group, "concat_dim", None)
# print(var, group, concat_dim)
ncases = len(group.cases)
if ncases > 1:
assert concat_dim is not None, (
f"Multiple cases discovered in group {group} but no concat_dim found"
)

files = []
for case in group.cases:
files.append(
list(
case.catalog.search(variable_id=var.varname).df["path"]
)
)

file_type = infer_source_data_file_types(flatten_list(files))

if file_type == "unix_file":
logger.info(f"Opening dataset files for group {group.name}")
logger.debug(
f"Opening dataset files for group {group.name}: {files}"
)
dsets = [_open_xr(x, var.varname) for x in files]
elif file_type == "google_cloud":
logger.info(
f"Opening Google Cloud Storage for group {group.name}"
)
logger.debug(
f"Opening Google Cloud Storage for group {group.name}: {files}"
)
dsets = [_open_gcs(x, var.varname) for x in files]
else:
raise ValueError(
f"There is no rule yet to open file type: {file_type}"
)

if len(dsets) > 1:
logger.info(
f"Concatenating multiple cases along dimension {concat_dim}: {group.name}"
)
ds = xr.concat(dsets, concat_dim)
else:
ds = dsets[0]

# Select date range
tcoord = "time"
ds = ds.sel({tcoord: slice(*group.date_range)})
# Select date range
tcoord = "time"
logger.info(
f"Subsetting time range {group.date_range}: {group.name}"
)
ds = ds.sel({tcoord: slice(*group.date_range)})

ds = VirtualDataset(ds)
all_datasets.append(ds)
@@ -422,8 +484,21 @@ def _open_gcs(files, varname=None):
# set top-level datasets
self._datasets = all_datasets

def open(self, site="gfdl"):
self.load(site=site)
def open(self, site="gfdl", dmget=False, use_cache=False, cache_format="zarr"):
self.load(
site=site, dmget=dmget, use_cache=use_cache, cache_format=cache_format
)

def write_cache(
self, workdir=None, output_format="zarr", overwrite=False, chunks=None
):
write_cached_datasets(
self,
workdir=workdir,
output_format=output_format,
overwrite=overwrite,
chunks=chunks,
)

@property
def datasets(self):
@@ -486,16 +561,7 @@ def _repr_html_(self):
result += f"<tr><td><strong>variables</strong></td><td>{_vars}</td></tr>"
_grps = str("<br>").join([x.name for x in self.groups])
result += f"<tr><td><strong>groups</strong></td><td>{_grps}</td></tr>"

# result += "<tr><td colspan='2'>"
# result += "<details>"
# result += "<summary>Group Details</summary>"
# result += "<div><table>"
# for grp in self.groups:
# result += f"<tr>{grp._repr_html_(title=False)}</tr>"
# result += "</table></div>"
# result += "</details>"
# result += "</td></tr>"
result += f"<tr><td><strong>workdir</strong></td><td>{self.workdir}</td></tr>"

if len(self.diag_vars) > 0:
result += "<tr><td colspan='2'>"
@@ -550,3 +616,52 @@ def _repr_html_(self):
result += "</table>"

return result


def write_cached_datasets(
diag, workdir=None, output_format="zarr", overwrite=False, chunks=None
):
if workdir is None:
workdir = diag.workdir

workdir = Path(workdir)
if not workdir.exists():
logger.info(f"workdir does not exist, creating: {workdir}")
os.makedirs(workdir)

for group in diag.groups:
for variable in group.datasets.keys():
_date_range = str("_").join(list(group.date_range))
output_name = f"{group.name}_{variable.varname}_{_date_range}"
ds = group.datasets[variable]

if output_format == "zarr":
output_name = Path(f"{output_name}.{output_format}")
output_name = workdir / output_name

if output_name.exists() and overwrite:
logger.info(f"Found existing zarr and deleting: {output_name}")
shutil.rmtree(output_name)

if not output_name.exists():
dsout = ds
dsout[variable.varname] = reset_encoding(dsout[variable.varname])

if chunks is not None:
logger.info(
f"Resetting chunks and applying new chunks: {chunks}"
)
chunksizes = chunks
else:
logger.info("Using automatic chunks")
chunksizes = "auto"

dsout = dsout.chunk(chunksizes)
chunksizes = dsout.chunksizes
logger.info(f"Output chunksizes are: {dict(chunksizes)}")

logger.info(f"Writing zarr file: {output_name}")
dsout.to_zarr(output_name)

else:
logger.info(f"Found existing zarr -- doing nothing: {output_name}")
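Taken together, the new `workdir`, `use_cache`, and `write_cache` pieces enable a cache-then-reuse workflow. A minimal sketch, assuming `diag` is an already-constructed `NotebookDiagnostic` whose groups and catalogs have been resolved; the `site`, `chunks`, and chunk sizes below are illustrative, not defaults from this PR:

```python
# `diag` is assumed to be a fully configured NotebookDiagnostic instance.

# First pass: open the source files directly, then write zarr caches into
# diag.workdir (auto-generated via generate_tempdir_path() unless a workdir
# was passed to the constructor).
diag.open(site="gfdl", dmget=False, use_cache=False)
diag.write_cache(output_format="zarr", overwrite=False, chunks={"time": 120})

# Subsequent passes: reuse the cached zarr stores instead of re-opening the
# raw netCDF / Google Cloud sources.
diag.open(site="gfdl", use_cache=True, cache_format="zarr")
```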
58 changes: 58 additions & 0 deletions src/esnb/core/util2.py
@@ -1,7 +1,14 @@
import datetime
import json
import logging
import random
import re
import string
import tempfile
from pathlib import Path

import xarray as xr

from esnb.core.CaseExperiment2 import CaseExperiment2
from esnb.core.util import is_overlapping, process_time_string

@@ -49,6 +56,13 @@ def case_time_filter(case, date_range):
return df


def clean_string(input_string):
res = re.sub(r"[^a-zA-Z0-9\s]", "", input_string)
res = res.replace(" ", "_")
res = re.sub(r"_+", "_", res)
return res
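For reference, a quick sketch of what `clean_string` produces for a typical diagnostic name (the input below is purely illustrative):

```python
clean_string("AMOC @ 26.5N (annual)")
# -> "AMOC_265N_annual"  (punctuation stripped, spaces collapsed to single underscores)
```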


def flatten_list(nested_list):
"""
Recursively flattens a nested list into a single list of elements.
@@ -78,6 +92,16 @@ def flatten_list(nested_list):
return flat_list


def generate_tempdir_path(name=None):
name = "" if name is None else clean_string(name) + "_"
date_str = datetime.datetime.now().strftime("%Y-%m-%d_%H%M%S")
rand_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
custom_name = f"{name}{date_str}_{rand_str}"
base_temp_dir = tempfile.gettempdir()
full_path = Path(base_temp_dir) / Path(custom_name)
return full_path
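A brief sketch of the kind of path this returns; the base directory comes from `tempfile.gettempdir()` (commonly `/tmp` on Linux), and the timestamp and random suffix differ on every call:

```python
generate_tempdir_path("Example Diagnostic")
# e.g. PosixPath('/tmp/Example_Diagnostic_2024-01-31_154210_k3v9qx2m')
```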


def infer_source_data_file_types(flist):
flist = [flist] if not isinstance(flist, list) else flist
flist = [Path(x) for x in flist]
@@ -190,6 +214,40 @@ def read_json(name):
return json.loads(json_str)


def reset_encoding(xobj, attrs=None):
"""Function to reset encoding attributes on an xarray object

Parameters
----------
xobj : xarray.core.dataset.Dataset or xarray.core.dataarray.DataArray
Input xarray object
attrs : list, optional
Attributes to reset, by default None

Returns
-------
xarray.core.dataset.Dataset or xarray.core.dataarray.DataArray
Xarray object without encoding attributes
"""

attrs = ["chunks", "preferred_chunks"] if attrs is None else attrs

if isinstance(xobj, xr.DataArray):
for attr in attrs:
xobj.encoding.pop(attr, None)

elif isinstance(xobj, xr.Dataset):
for attr in attrs:
xobj.encoding.pop(attr, None)
for var in xobj.variables:
xobj[var].encoding.pop(attr, None)

else:
raise ValueError("xobj must be an xarray Dataset or DataArray")

return xobj
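A small usage sketch for `reset_encoding`, showing the intended pattern of clearing stale chunk encoding before rechunking and writing to zarr; the dataset and chunk sizes are hypothetical:

```python
import numpy as np
import xarray as xr

from esnb.core.util2 import reset_encoding

# Hypothetical dataset standing in for one opened via open_mfdataset
ds = xr.Dataset({"tas": ("time", np.random.rand(24))})
ds["tas"].encoding["chunks"] = (12,)                  # stale hints from a prior store
ds["tas"].encoding["preferred_chunks"] = {"time": 12}

ds["tas"] = reset_encoding(ds["tas"])  # drop the stale chunk hints
ds = ds.chunk({"time": 6})             # apply the chunking we actually want
ds.to_zarr("example_cache.zarr")       # writes without encoding/chunk conflicts
```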


def xr_date_range_to_datetime(date_range):
"""
Converts a list of date strings into a processed datetime string.