In [1]:
import os
import sys
sys.path.append("../utils/")
sys.path.append("../models/")
sys.path.append("../handle_data/")
import glob
import argparse
from datetime import datetime as dt
print("Start with importing packages at {0}".format(dt.strftime(dt.now(), "%Y-%m-%d %H:%M:%S")))
import gc
import json as js
from timeit import default_timer as timer
import numpy as np
import xarray as xr
import multiprocessing
import tensorflow as tf
import tensorflow.keras as keras
from model_utils import ModelEngine, handle_opt_utils
from handle_data_class import HandleDataClass, get_dataset_filename
from all_normalizations import ZScore
from timeit import default_timer as timer

Start with importing packages at 2023-01-31 15:35:10


In [2]:
import re
from typing import List
from operator import itemgetter
import dask

class StreamMonthlyNetCDF(object):
    # TO-DO:
    # - get samples_per_file from the data rather than predefining it (varying samples per file for monthly data files!)

    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time", selected_predictors: List = None,
                 selected_predictands: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj = None):
        self.data_dir = datadir
        self.file_list = patt
        self.ds = xr.open_mfdataset(list(self.file_list), combine="nested", concat_dim=sample_dim)  # , parallel=True)
        # check if norm_dims or norm_obj should be used
        assert norm_obj or norm_dims, f"Neither norm_obj nor norm_dims has been provided."
        if norm_obj and norm_dims:
            print("WARNING: norm_obj and norm_dims have been passed. norm_dims will be ignored.")
            norm_dims = None
        if norm_obj:
            #assert isinstance(norm_obj, Normalize), "norm_obj is not an instance of the Normalize-class."
            self.data_norm = norm_obj
        else:
            self.data_norm = ZScore(norm_dims)      # TO-DO: Allow for arbitrary normalization
        self.sample_dim = sample_dim
        self.data_dim = self.get_data_dim()
        self.norm_params = self.data_norm.get_required_stats(self.ds.to_array(dim="variables"))
        self.nsamples = self.ds.dims[sample_dim]
        self.variables = list(self.ds.variables)
        self.samples_per_file = 28175                # TO-DO avoid hard-coding
        self.predictor_list = selected_predictors
        self.predictand_list = selected_predictands
        self.n_predictands = len(self.predictand_list)
        self.var_tar2in = var_tar2in
        self.data = None

        print(f"Number of used workers: {workers:d}")
        self.pool = multiprocessing.pool.ThreadPool(workers)

    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data

    def getitems(self, indices):
        return np.array(self.pool.map(self.__getitem__, indices))

    def get_data_dim(self):
        """
        Retrieve the dimensionality of the data to be handled, i.e. without sample_dim which will be batched in a
        data stream.
        :return: tuple of data dimensions
        """
        # get existing dimension names and remove sample_dim
        dimnames = list(self.ds.coords)
        dimnames.remove(self.sample_dim)

        # get the dimensionality of the data of interest
        all_dims = dict(self.ds.dims)
        data_dim = itemgetter(*dimnames)(all_dims)

        return data_dim

    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            raise NotADirectoryError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        self._file_list = np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group())))

    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if not sample_dim in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    @property
    def predictor_list(self):
        return self._predictor_list

    @predictor_list.setter
    def predictor_list(self, selected_predictors: List):
        """
        Initalizes predictor list. In case that selected_predictors is set to None, all variables with suffix `_in` in their names are selected.
        In case that a list of selected_predictors is parsed, their availability is checked.
        :param selected_predictors: list of predictor variables or None
        """
        self._predictor_list = self.check_and_choose_vars(selected_predictors, "_in")
        
    @property
    def predictand_list(self):
        return self._predictand_list
    
    @predictand_list.setter
    def predictand_list(self, selected_predictands: List):
        self._predictand_list = self.check_and_choose_vars(selected_predictands, "_tar")
        
    def check_and_choose_vars(self, var_list, suffix: str = "*"):
        """
        Checks list of variables for availability or retrieves all variables named with a given suffix (for var_list = None)
        :param var_list: list of predictor variables or None
        :param suffix: optional suffix of variables to selected. Only effective if var_list is None.
        :return selected_vars: list of selected variables
        """
        if var_list is None:
            selected_vars = [var for var in self.variables if var.endswith(suffix)]
        else:
            stat_list = [var in self.variables for var in var_list]         
            if all(stat_list):
                selected_vars = var_list
            else:
                miss_inds = [i for i, x in enumerate(stat_list) if x]
                miss_vars = [var_list[i] for i in miss_inds]
                raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}")
        
        return selected_vars

    def read_netcdf(self, fname):
        fname = tf.keras.backend.get_value(fname)
        fname = str(fname).lstrip("b'").rstrip("'")
        print(f"Load data from {fname}...")
        ds_now = xr.open_dataset(str(fname), engine="netcdf4")
        var_list = list(self.predictor_list) + list(self.predictand_list) 
        ds_now = ds_now[var_list]
        da_now = ds_now.to_array("variables").transpose(..., "variables").astype("float32", copy=False)
        # da_now = self.data_norm.normalize(da_now)
        # da_now = xr.concat([da_now.sel({"variables": "hsurf_tar"}), da_now], dim="variables")
        
        da_now = dask.compute(da_now)[0]
        # da_now = self.data_norm.normalize(da_now)
        # self.data = xr.concat([da_now.sel({"variables": "hsurf_tar"}), da_now], dim="variables")
        self.data = da_now
        print(self.data)
        
        self.nsamples = len(ds_now[self.sample_dim])

        return True
    
    def index_to_sample(self, index):
        data = self.data.isel({self.sample_dim: index})
        data = self.data_norm.normalize(data)
        data = xr.concat([data.sel({"variables": "hsurf_tar"}), data], dim="variables")

        return data
        #return self.data.isel({self.sample_dim: index})        

        
    def normalize_batch(self, batch):
        return self.data_norm.normalize(batch)


In [3]:
js_model = "../HPC_batch_scripts/config_wgan.json"
js_ds = "../HPC_batch_scripts/config_ds_tier2.json"

datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files/tmp/"

In [4]:
model_instance = ModelEngine("wgan")

# read configuration files for model and dataset
with open(js_ds, "r") as dsf:
    ds_dict = js.load(dsf)

print(ds_dict)
with open(js_model, "r") as mdf:
    hparams_dict = js.load(mdf)

named_targets = hparams_dict.get("named_targets", False)

bs_train = ds_dict["batch_size"] * hparams_dict["d_steps"] + 1 if "d_steps" in hparams_dict else ds_dict["batch_size"]

{'norm_dims': ['time', 'rlat', 'rlon'], 'batch_size': 32, 'var_tar2in': 'hsurf_tar'}


In [5]:
nworkers = multiprocessing.cpu_count()

norm_dims = ds_dict["norm_dims"]
norm_obj = ZScore(norm_dims)
norm_obj.read_norm_from_file(os.path.join(datadir, "norm.json"))

In [6]:
import inspect
#print(inspect.getmembers(norm_obj))
mu, sigma = norm_obj.norm_stats["mu"].copy(), norm_obj.norm_stats["sigma"].copy()
mu, sigma = mu.astype("float32"), sigma.astype("float32")
norm_obj.norm_stats["mu"], norm_obj.norm_stats["sigma"] = mu, sigma

In [7]:
ds_obj = StreamMonthlyNetCDF(os.path.join(datadir), "ds_resampled*.nc", workers=1,
                             var_tar2in="hsurf_tar", norm_obj=norm_obj, 
                             selected_predictands=["t_2m_tar", "hsurf_tar"])

Mu and sigma are parsed for (de-)normalization.
Number of used workers: 1


In [None]:
tf_read_nc = lambda fname: tf.py_function(ds_obj.read_netcdf, [fname], tf.bool)
tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
#tf_norm = lambda arr: tf.numpy_function(ds_obj.normalize_batch, [arr], tf.float32)
tf_split = lambda arr: (arr[..., 0:-ds_obj.n_predictands], arr[...,-ds_obj.n_predictands:])

tfds = tf.data.Dataset.from_tensor_slices(ds_obj.file_list).map(tf_read_nc)
#ds_obj.samples_per_file = 1500
tfds = tfds.flat_map(lambda x: tf.data.Dataset.range(ds_obj.samples_per_file).shuffle(ds_obj.samples_per_file)\
                       .batch(bs_train, drop_remainder=True).map(tf_getdata))#.map(tf_norm))
tfds = tfds.map(tf_split).repeat()

In [7]:
def run_test(tfds, iterations=250, test_case="default"):
    t0 = timer()
    for i, x in enumerate(tfds):
        data_in, data_tar = x
        print(i)
        if i > iterations:
            break
            
    time_tot = timer()-t0
    print(f"Iteration for {test_case}-test took {time_tot:.2f}s.")
    
    return time_tot

In [None]:
exp_time = run_test(tfds, test_case="worker1_load")

In [None]:
%xdel tfds
%xdel ds_obj

In [None]:
ds_obj = StreamMonthlyNetCDF(os.path.join(datadir), "ds_resampled*.nc", workers=1,
                             var_tar2in="hsurf_tar", norm_obj=norm_obj)

tf_read_nc = lambda fname: tf.py_function(ds_obj.read_netcdf, [fname], tf.bool)
tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
tf_split = lambda arr: (arr[..., 0:-2], arr[...,-2:])

tfds = tf.data.Dataset.from_tensor_slices(ds_obj.file_list).map(tf_read_nc)
#ds_obj.samples_per_file = 1500
tfds = tfds.flat_map(lambda x: tf.data.Dataset.range(ds_obj.samples_per_file).shuffle(ds_obj.samples_per_file)\
                       .batch(bs_train, drop_remainder=True).map(tf_getdata, num_parallel_calls=tf.data.AUTOTUNE))
tfds = tfds.map(tf_split).repeat()

exp_time = run_test(tfds, test_case="worker1_parallel")

In [None]:
%xdel tfds
%xdel ds_obj

In [None]:
import re
from typing import List
from operator import itemgetter
import dask

class StreamMonthlyNetCDF(object):
    # TO-DO:
    # - get samples_per_file from the data rather than predefining it (varying samples per file for monthly data files!)

    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time", selected_predictors: List = None,
                 selected_predictands: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj = None):
        self.data_dir = datadir
        self.file_list = patt
        print(self.file_list)
        self.ds = xr.open_mfdataset(list(self.file_list), combine="nested", concat_dim=sample_dim)  # , parallel=True)
        # check if norm_dims or norm_obj should be used
        assert norm_obj or norm_dims, f"Neither norm_obj nor norm_dims has been provided."
        if norm_obj and norm_dims:
            print("WARNING: norm_obj and norm_dims have been passed. norm_dims will be ignored.")
            norm_dims = None
        if norm_obj:
            #assert isinstance(norm_obj, Normalize), "norm_obj is not an instance of the Normalize-class."
            self.data_norm = norm_obj
        else:
            self.data_norm = ZScore(norm_dims)      # TO-DO: Allow for arbitrary normalization
        self.sample_dim = sample_dim
        self.data_dim = self.get_data_dim()
        self.norm_params = self.data_norm.get_required_stats(self.ds.to_array(dim="variables"))
        self.nsamples = self.ds.dims[sample_dim]
        self.variables = list(self.ds.variables)
        self.samples_per_file = 28175                # TO-DO avoid hard-coding
        self.predictor_list = selected_predictors
        self.predictand_list = selected_predictands
        self.n_predictands = len(self.predictand_list)
        print(self.predictand_list)
        print(self.predictor_list)   
        self.var_tar2in = var_tar2in
        self.data = None

        print(f"Number of used workers: {workers:d}")
        self.pool = multiprocessing.pool.ThreadPool(workers)

    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data

    def getitems(self, indices):
        return np.array(self.pool.map(self.__getitem__, indices))

    def get_data_dim(self):
        """
        Retrieve the dimensionality of the data to be handled, i.e. without sample_dim which will be batched in a
        data stream.
        :return: tuple of data dimensions
        """
        # get existing dimension names and remove sample_dim
        dimnames = list(self.ds.coords)
        dimnames.remove(self.sample_dim)

        # get the dimensionality of the data of interest
        all_dims = dict(self.ds.dims)
        data_dim = itemgetter(*dimnames)(all_dims)

        return data_dim

    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            raise NotADirectoryError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        self._file_list = np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group())))

    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if not sample_dim in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    @property
    def predictor_list(self):
        return self._predictor_list

    @predictor_list.setter
    def predictor_list(self, selected_predictors: List):
        """
        Initalizes predictor list. In case that selected_predictors is set to None, all variables with suffix `_in` in their names are selected.
        In case that a list of selected_predictors is parsed, their availability is checked.
        :param selected_predictors: list of predictor variables or None
        """
        self._predictor_list = self.check_and_choose_vars(selected_predictors, "_in")
        
    @property
    def predictand_list(self):
        return self._predictand_list
    
    @predictand_list.setter
    def predictand_list(self, selected_predictands: List):
        self._predictand_list = self.check_and_choose_vars(selected_predictands, "_tar")
        
    def check_and_choose_vars(self, var_list, suffix: str = "*"):
        """
        Checks list of variables for availability or retrieves all variables named with a given suffix (for var_list = None)
        :param var_list: list of predictor variables or None
        :param suffix: optional suffix of variables to selected. Only effective if var_list is None.
        :return selected_vars: list of selected variables
        """
        if var_list is None:
            selected_vars = [var for var in self.variables if var.endswith(suffix)]
        else:
            stat_list = [var in self.variables for var in var_list]         
            if all(stat_list):
                selected_vars = var_list
            else:
                miss_inds = [i for i, x in enumerate(stat_list) if x]
                miss_vars = [var_list[i] for i in miss_inds]
                raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}")
        
        return selected_vars

    def read_netcdf(self, fname):
        fname = tf.keras.backend.get_value(fname)
        fname = str(fname).lstrip("b'").rstrip("'")
        print(f"Load data from {fname}...")
        ds_now = xr.open_dataset(str(fname), engine="netcdf4")
        var_list = list(self.predictor_list) + list(self.predictand_list) 
        ds_now = ds_now[var_list]
        da_now = ds_now.to_array("variables").transpose(..., "variables").astype("float32", copy=False)
        # da_now = self.data_norm.normalize(da_now)
        # da_now = xr.concat([da_now.sel({"variables": "hsurf_tar"}), da_now], dim="variables")
        
        da_now = dask.compute(da_now)[0]
        #da_now = self.data_norm.normalize(da_now)
        self.data = xr.concat([da_now.sel({"variables": "hsurf_tar"}), da_now], dim="variables")
        #self.data = da_now
        print(self.data)
        
        self.nsamples = len(ds_now[self.sample_dim])

        return True
    
    def index_to_sample(self, index):
        #data = self.data.isel({self.sample_dim: index})
        #data = self.data_norm.normalize(data)
        #data = xr.concat([data.sel({"variables": "hsurf_tar"}), data], dim="variables")
        return self.data.isel({self.sample_dim: index})
        
        #return data


In [None]:
ds_obj = StreamMonthlyNetCDF(os.path.join(datadir), "ds_resampled*.nc", workers=1,
                             var_tar2in="hsurf_tar", norm_obj=norm_obj)

tf_read_nc = lambda fname: tf.py_function(ds_obj.read_netcdf, [fname], tf.bool)
tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
tf_split = lambda arr: (arr[..., 0:-2], arr[...,-2:])

tfds = tf.data.Dataset.from_tensor_slices(ds_obj.file_list).map(tf_read_nc)
tfds = tfds.flat_map(lambda x: tf.data.Dataset.range(ds_obj.samples_per_file).shuffle(ds_obj.samples_per_file)\
                       .batch(bs_train, drop_remainder=True).map(tf_getdata))
tfds = tfds.map(tf_split).repeat()

exp_time = run_test(tfds, test_case=f"prenorm_noparallel")

%xdel tfds
%xdel ds_obj

In [None]:
ds_obj = StreamMonthlyNetCDF(os.path.join(datadir), "ds_resampled*.nc", workers=1,
                             var_tar2in="hsurf_tar", norm_obj=norm_obj)

tf_read_nc = lambda fname: tf.py_function(ds_obj.read_netcdf, [fname], tf.bool)
tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
tf_split = lambda arr: (arr[..., 0:-2], arr[...,-2:])

tfds = tf.data.Dataset.from_tensor_slices(ds_obj.file_list).map(tf_read_nc)
#ds_obj.samples_per_file = 1500
tfds = tfds.flat_map(lambda x: tf.data.Dataset.range(ds_obj.samples_per_file).shuffle(ds_obj.samples_per_file)\
                       .batch(bs_train, drop_remainder=True).map(tf_getdata, num_parallel_calls=tf.data.AUTOTUNE))
tfds = tfds.map(tf_split).repeat()

exp_time = run_test(tfds, test_case=f"prenorm_autotune")

%xdel tfds
%xdel ds_obj

## Experiment with vectorized indexing

Here, we check if vectorized indexing (cf. `getitems`-method) actually performs better than (potentially) parallelized mapping of the indices.

In [30]:
class StreamMonthlyNetCDF(object):
    # TO-DO:
    # - get samples_per_file from the data rather than predefining it (varying samples per file for monthly data files!)

    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time", selected_predictors: List = None,
                 selected_predictands: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj = None):
        self.data_dir = datadir
        self.file_list = patt
        print(self.file_list)
        self.ds = xr.open_mfdataset(list(self.file_list), combine="nested", concat_dim=sample_dim)  # , parallel=True)
        self.sample_dim = sample_dim
        self.data_dim = self.get_data_dim()
        # check if norm_dims or norm_obj should be used
        assert norm_obj or norm_dims, f"Neither norm_obj nor norm_dims has been provided."
        if norm_obj and norm_dims:
            print("WARNING: norm_obj and norm_dims have been passed. norm_dims will be ignored.")
            norm_dims = None
        if norm_obj:
            #assert isinstance(norm_obj, Normalize), "norm_obj is not an instance of the Normalize-class."
            self.data_norm = norm_obj
        else:
            self.data_norm = ZScore(norm_dims)      # TO-DO: Allow for arbitrary normalization
        self.sample_dim = sample_dim
        self.data_dim = self.get_data_dim()
        self.norm_params = self.data_norm.get_required_stats(self.ds.to_array(dim="variables"))
        self.nsamples = self.ds.dims[sample_dim]
        self.variables = list(self.ds.variables)
        self.samples_per_file = 28175                # TO-DO avoid hard-coding
        self.predictor_list = selected_predictors
        self.predictand_list = selected_predictands
        self.n_predictands = len(self.predictand_list)
        self.var_tar2in = var_tar2in
        self.data = None

    def __len__(self):
        return self.nsamples

    def getitems(self, indices):
        return self.data.isel({self.sample_dim: indices}).transpose(..., "variables")

    def get_data_dim(self):
        """
        Retrieve the dimensionality of the data to be handled, i.e. without sample_dim which will be batched in a
        data stream.
        :return: tuple of data dimensions
        """
        # get existing dimension names and remove sample_dim
        dimnames = list(self.ds.coords)
        dimnames.remove(self.sample_dim)

        # get the dimensionality of the data of interest
        all_dims = dict(self.ds.dims)
        data_dim = itemgetter(*dimnames)(all_dims)

        return data_dim

    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            raise NotADirectoryError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        self._file_list = np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group())))

    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if not sample_dim in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    @property
    def predictor_list(self):
        return self._predictor_list

    @predictor_list.setter
    def predictor_list(self, selected_predictors: List):
        """
        Initalizes predictor list. In case that selected_predictors is set to None, all variables with suffix `_in` in their names are selected.
        In case that a list of selected_predictors is parsed, their availability is checked.
        :param selected_predictors: list of predictor variables or None
        """
        self._predictor_list = self.check_and_choose_vars(selected_predictors, "_in")
        
    @property
    def predictand_list(self):
        return self._predictand_list
    
    @predictand_list.setter
    def predictand_list(self, selected_predictands: List):
        self._predictand_list = self.check_and_choose_vars(selected_predictands, "_tar")
        
    def check_and_choose_vars(self, var_list, suffix: str = "*"):
        """
        Checks list of variables for availability or retrieves all variables named with a given suffix (for var_list = None)
        :param var_list: list of predictor variables or None
        :param suffix: optional suffix of variables to selected. Only effective if var_list is None.
        :return selected_vars: list of selected variables
        """
        if var_list is None:
            selected_vars = [var for var in self.variables if var.endswith(suffix)]
        else:
            stat_list = [var in self.variables for var in var_list]         
            if all(stat_list):
                selected_vars = var_list
            else:
                miss_inds = [i for i, x in enumerate(stat_list) if x]
                miss_vars = [var_list[i] for i in miss_inds]
                raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}")
        
        return selected_vars

    def read_netcdf(self, fname):
        fname = tf.keras.backend.get_value(fname)
        fname = str(fname).lstrip("b'").rstrip("'")
        print(f"Load data from {fname}...")
        ds_now = xr.open_dataset(str(fname), engine="netcdf4")
        var_list = list(self.predictor_list) + list(self.predictand_list) 
        ds_now = ds_now[var_list]
        da_now = ds_now.to_array("variables").astype("float32", copy=False)
        self.data = xr.concat([da_now, da_now.sel({"variables": "hsurf_tar"})], dim="variables")

        return True

In [31]:
ds_obj = StreamMonthlyNetCDF(os.path.join(datadir), "ds_resampled*.nc", workers=1,
                             var_tar2in="hsurf_tar", norm_obj=norm_obj, 
                             selected_predictands=["t_2m_tar", "hsurf_tar"])

['/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files/tmp/ds_resampled_0.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files/tmp/ds_resampled_1.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files/tmp/ds_resampled_2.nc']


In [32]:
tf_read_nc = lambda fname: tf.py_function(ds_obj.read_netcdf, [fname], tf.bool)
tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
tf_split = lambda arr: (arr[..., 0:-ds_obj.n_predictands], arr[...,-ds_obj.n_predictands:])

tfds = tf.data.Dataset.from_tensor_slices(ds_obj.file_list).map(tf_read_nc)
tfds = tfds.flat_map(lambda x: tf.data.Dataset.range(ds_obj.samples_per_file).shuffle(ds_obj.samples_per_file)\
                       .batch(bs_train, drop_remainder=True).map(tf_getdata))#.map(tf_norm))
tfds = tfds.map(tf_split).repeat()

In [33]:
exp_time = run_test(tfds, test_case="worker1_vectorized")

Load data from /p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files/tmp/ds_resampled_0.nc...
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
Load data from /p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files/tmp/ds_resampled_1.nc...
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212


Vectorized indexing yields a test time slightly above 70s which is 15-20s quicker than the prevouis experiments.

## Data streaming with `xarray.open_mfdataset`

Finally, we check an approach where no data has to be written on disk beforehand for streaming, but where streaming is realized by loading the data on-the-fly into memory.
This is realized by creating sublists of all files to be processed which are then merged while loading with `xarray.open_mfdataset`. <br>
Te advantage of this approach is that normalization can be performed exploiting the `preprocess`-argument of `xarray.open_mfdataset`. 
However, we must first revise the Normalization-class to enable processing on datasets as well as since currently only `xarray.DataArray`s are supported.

In [11]:
"""
Abstract class to perform normalization on data
"""

__email__ = "m.langguth@fz-juelich.de"
__author__ = "Michael Langguth"
__date__ = "2022-11-24"

from abc import ABC, abstractmethod
from typing import Union, List
import os
import json as js
import numpy as np
import xarray as xr
import dask

da_or_ds = Union[xr.DataArray, xr.Dataset]


class Normalize(ABC):
    """
    Abstract class for normalizing data.
    """
    def __init__(self, method: str, norm_dims: List):
        self.method = method
        self.norm_dims = norm_dims
        self.norm_stats = None

    def normalize(self, data: xr.DataArray, **stats):
        """
        Normalize data.
        :param data: The DataArray to be normalized.
        :param **stats: Parameters to perform normalization. Must fit to normalization type!
        :return: DataArray with normalized data.
        """
        # sanity checks
        #if not isinstance(data, xr.DataArray):
        #    raise TypeError(f"Passed data must be a xarray.DataArray, but is of type {str(type(data))}.")

        _ = self._check_norm_dims(data)
        # do the computation
        norm_stats = self.get_required_stats(data, **stats)
        data_norm = self.normalize_data(data, *norm_stats)

        return data_norm

    def denormalize(self, data: da_or_ds, **stats):
        """
        Denormalize data.
        :param data: The DataArray to be denormalized.
        :param **stats: Parameters to perform denormalization. Must fit to normalization type!
        :return: DataArray with denormalized data.
        """
        # sanity checks
        #if not isinstance(data, xr.DataArray):
        #    raise TypeError(f"Passed data must be a xarray.DataArray, but is of type {str(type(data))}.")

        _ = self._check_norm_dims(data)
        # do the computation
        norm_stats = self.get_required_stats(data, **stats)
        data_denorm = self.denormalize_data(data, *norm_stats)

        return data_denorm

    @property
    def norm_dims(self):
        return self._norm_dims

    @norm_dims.setter
    def norm_dims(self, norm_dims):
        if norm_dims is None:
            raise AttributeError("norm_dims must not be None. Please parse a list of dimensions" +
                                 "over which normalization should be applied.")

        self._norm_dims = list(norm_dims)

    def _check_norm_dims(self, data):
        """
        Check if dimension for normalization reside in dimensions of data.
        :param data: the data (xr.DataArray) to be normalized
        :return True: in case of passed check, a ValueError is risen else
        """
        data_dims = list(data.dims)
        norm_dims_check = [norm_dim in data_dims for norm_dim in self.norm_dims]
        if not all(norm_dims_check):
            imiss = np.where(~np.array(norm_dims_check))[0]
            miss_dims = list(np.array(self.norm_dims)[imiss])
            raise ValueError("The following dimensions do not reside in the data: " +
                             f"{', '.join(miss_dims)}")

        return True

    def save_norm_to_file(self, js_file, missdir_ok: bool = True):
        """
        Write normalization parameters to file.
        :param js_file: Path to JSON-file to be created.
        :param missdir_ok: If True, base-directory of JSON-file can be missing and will be created then.
        :return: -
        """
        if self.norm_stats is None:
            raise AttributeError("norm_stats is still None. Please run (de-)normalization to get parameters.")

        if any([stat is None for stat in self.norm_stats.values()]):
            raise AttributeError("Some parameters of norm_stats are None.")

        norm_serialized = {key: da.to_dict() for key, da in self.norm_stats.items()}

        # serialization and (later) deserialization depends from data type.
        # Thus, we have to save it to the dictionary
        d0 = list(data_norm.norm_stats.values())[0]
        if isinstance(d0, xr.DataArray):
            norm_serialized["data_type"]= "data_array"
        elif isinstance(d0, xr.Dataset):
            norm_serialized["data_type"]= "data_set"
        
        if missdir_ok: os.makedirs(os.path.dirname(js_file), exist_ok=True)

        with open(js_file, "w") as jsf:
            js.dump(norm_serialized, jsf)

    def read_norm_from_file(self, js_file):
        """
        Read normalization parameters from file. Inverse function to write_norm_from_file.
        :param js_file: Path to JSON-file to be read.
        :return: Parameters set to self.norm_stats
        """
        with open(js_file, "r") as jsf:
            norm_data = js.load(jsf)

        data_type = norm_data.pop('data_type', None)
        
        if data_type == "data_array":
            xr_obj = xr.DataArray
        elif data_type == "data_set":
            xr_obj = xr.Dataset
        else:
            raise ValueError(f"Unknown data_type {data_type} in {js_file}. Only 'data_array' or 'data_set' are allowed.")

        norm_data.pop('data_type', None)
            
        norm_dict_restored = {key: xr_obj.from_dict(da_dict) for key, da_dict in norm_data.items()}

        self.norm_stats = norm_dict_restored

    @abstractmethod
    def get_required_stats(self, data, *stats):
        """
        Function to retrieve either normalization parameters from data or from keyword arguments
        """
        pass

    @staticmethod
    @abstractmethod
    def normalize_data(data, *norm_param):
        """
        Function to normalize data.
        """
        pass

    @staticmethod
    @abstractmethod
    def denormalize_data(data, *norm_param):
        """
        Function to denormalize data.
        """
        pass
    
class ZScore(Normalize):
    def __init__(self, norm_dims: List):
        super().__init__("z_score", norm_dims)
        self.norm_stats = {"mu": None, "sigma": None}

    def get_required_stats(self, data: xr.DataArray, **stats):
        """
        Get required parameters for z-score normalization. They are either computed from the data
        or can be parsed as keyword arguments.
        :param data: the data to be (de-)normalized
        :param stats: keyword arguments for mean (mu) and standard deviation (std) used for normalization
        :return (mu, sigma): Parameters for normalization
        """
        mu, std = stats.get("mu", self.norm_stats["mu"]), stats.get("sigma", self.norm_stats["sigma"])

        if mu is None or std is None:
            print("Retrieve mu and sigma from data...")
            mu, std = data.mean(self.norm_dims), data.std(self.norm_dims)
            # the following ensure that both parameters are computed in one graph! 
            # This significantly reduces memory footprint as we don't end up having data duplicates 
            # in memory due to multiple graphs (and also seem to enfore usage of data chunks as well)
            mu, std = dask.compute(mu, std)
            self.norm_stats = {"mu": mu, "sigma": std}
        # else:
        #    print("Mu and sigma are parsed for (de-)normalization.")
        
        return mu, std

    @staticmethod
    def normalize_data(data, mu, std):
        """
        Perform z-score normalization on data
        :param data: Data array of interest
        :param mu: mean of data for normalization
        :param std: standard deviation of data for normalization
        :return data_norm: normalized data
        """
        data_norm = (data - mu) / std

        return data_norm

    @staticmethod
    def denormalize_data(data, mu, std):
        """
        Perform z-score denormalization on data.
        :param data: Data array of interest
        :param mu: mean of data for denormalization
        :param std: standard deviation of data for denormalization
        :return data_norm: denormalized data
        """
        data_denorm = data * std + mu

        return data_denorm

### New auxiliary function
Besides, the following should be added to to be added to `other_utils.py`

In [12]:
def find_closest_divisor(n1, div):
    
    def getDivisors(n, res=None) : 
        res = res or []
        i = 1
        while i <= n : 
            if (n % i==0) : 
                res.append(i), 
            i = i + 1
        return res
    
    all_divs = getDivisors(n1)
    
    if div in all_divs:
        return div
    else:
        i = np.argmin(np.abs(np.array(all_divs) - div))
        print(a)
        return all_divs[a]

### Running the experiment

The corresponding StreamMonthlyNetCDF-class is set-up in the following.

In [21]:
data_norm = ZScore(["time", "rlat", "rlon"])
data_norm.read_norm_from_file("./norm_test.json")

In [41]:
import random
from other_utils import to_list
from functools import partial

class StreamMonthlyNetCDF(object):
    # TO-DO:
    # - get samples_per_file from the data rather than predefining it (varying samples per file for monthly data files!)

    def __init__(self, datadir, patt, nfiles_merge: int, sample_dim: str = "time", selected_predictors: List = None,
                 selected_predictands: List = None, var_tar2in: str = None, norm_dims: List = None, norm_obj = None):
        self.data_dir = datadir
        self.file_list = patt
        self.nfiles = len(self.file_list)
        self.file_list_random = random.sample(self.file_list, self.nfiles)
        self.nfiles2merge = nfiles_merge
        self.nfiles_merged = int(self.nfiles/self.nfiles2merge)
        self.samples_merged = self.get_samples_per_merged_file()
        self.predictor_list = selected_predictors
        self.predictand_list = selected_predictands
        self.n_predictands, self.n_predictors = len(self.predictand_list), len(self.predictor_list)
        self.all_vars = self.predictor_list + self.predictand_list
        self.ds_all = xr.open_mfdataset(list(self.file_list), decode_cf=False, data_vars=self.all_vars)  # , parallel=True)
        self.var_tar2in = var_tar2in
        if self.var_tar2in is not None:
            print(self.n_predictors)
            print("Add var_tar2in...")
            self.n_predictors += len(to_list(self.var_tar2in))
            print(self.n_predictors)
            print(self.predictor_list)
        self.sample_dim = sample_dim
        self.nsamples = self.ds_all.dims[sample_dim]
        self.data_dim = self.get_data_dim()
        print("Start computing normalization parameters.")
        t0 = timer()
        if norm_obj is None:                        # TO-DO: remove usgae of norm_obj
            self.data_norm = ZScore(norm_dims)      # TO-DO: Allow for arbitrary normalization
            self.norm_params = self.data_norm.get_required_stats(self.ds_all)
        else:
            self.data_norm = norm_obj
        self.normalization_time = timer() - t0

    def __len__(self):
        return self.nsamples

    def getitems(self, indices):
        da_now = self.data.isel({self.sample_dim: indices}).to_array("variables")
        if self.var_tar2in is not None:
            da_now = xr.concat([da_now, da_now.sel({"variables": self.var_tar2in})], dim="variables")
        return da_now.transpose(..., "variables")

    def get_data_dim(self):
        """
        Retrieve the dimensionality of the data to be handled, i.e. without sample_dim which will be batched in a
        data stream.
        :return: tuple of data dimensions
        """
        # get existing dimension names and remove sample_dim
        dimnames = list(self.ds_all.coords)
        dimnames.remove(self.sample_dim)

        # get the dimensionality of the data of interest
        all_dims = dict(self.ds_all.dims)
        data_dim = itemgetter(*dimnames)(all_dims)

        return data_dim
    
    def get_samples_per_merged_file(self):
        nsamples_merged = []
        
        for i in range(self.nfiles_merged):
            file_list_now = self.file_list_random[i*self.nfiles2merge : (i+1)*self.nfiles2merge]
            ds_now = xr.open_mfdataset(list(file_list_now), decode_cf=False)
            print(ds_now)
            nsamples_merged.append(ds_now.dims["time"])                         # To-Do avoid hard-coding
            
        print(nsamples_merged)
        print(max(nsamples_merged))
        return max(nsamples_merged)
        
            
    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            raise NotADirectoryError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        self._file_list = list(np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group()))))

    @property
    def nfiles2merge(self):
        return self._nfiles2merge
    
    @nfiles2merge.setter
    def nfiles2merge(self, n2merge):
        n = find_closest_divisor(self.nfiles, n2merge)
        if n != n2merge:
            print(f"{n2merge} is not a divisor of the total number of files. Value is changed to {n}")
        
        self._nfiles2merge = n
    
    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if not sample_dim in self.ds_all.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    @property
    def predictor_list(self):
        return self._predictor_list

    @predictor_list.setter
    def predictor_list(self, selected_predictors: List):
        """
        Initalizes predictor list. In case that selected_predictors is set to None, all variables with suffix `_in` in their names are selected.
        In case that a list of selected_predictors is parsed, their availability is checked.
        :param selected_predictors: list of predictor variables or None
        """
        self._predictor_list = self.check_and_choose_vars(selected_predictors, "_in")
        
    @property
    def predictand_list(self):
        return self._predictand_list
    
    @predictand_list.setter
    def predictand_list(self, selected_predictands: List):
        self._predictand_list = self.check_and_choose_vars(selected_predictands, "_tar")
        
    def check_and_choose_vars(self, var_list: List, suffix: str = "*"):
        """
        Checks list of variables for availability or retrieves all variables named with a given suffix (for var_list = None)
        :param var_list: list of predictor variables or None
        :param suffix: optional suffix of variables to selected. Only effective if var_list is None.
        :return selected_vars: list of selected variables
        """
        ds_test = xr.open_dataset(self.file_list[0])
        all_vars = list(ds_test.variables)
        
        if var_list is None:
            selected_vars = [var for var in all_vars if var.endswith(suffix)]
        else:
            stat_list = [var in all_vars for var in var_list]         
            if all(stat_list):
                selected_vars = var_list
            else:
                miss_inds = [i for i, x in enumerate(stat_list) if x]
                miss_vars = [var_list[i] for i in miss_inds]
                raise ValueError(f"Could not find the following variables in the dataset: {*miss_vars,}")
        
        return selected_vars
    
    #@staticmethod
    def _preprocess_ds(self, ds, data_norm):
        ds = data_norm.normalize(ds[self.all_vars])
        return ds.astype("float32")

    def read_netcdf(self, ind):
        ind = tf.keras.backend.get_value(ind)
        ind = int(str(ind).lstrip("b'").rstrip("'"))        
        print(f"Load data from {ind}th set of files...")
        file_list_now = self.file_list[ind*self.nfiles2merge:(ind+1)*self.nfiles2merge]
        # read the normalized data into memory
        ds_now = xr.open_mfdataset(list(file_list_now), decode_cf=False, data_vars=self.all_vars, 
                                   preprocess=partial(self._preprocess_ds, data_norm=ds_obj.data_norm), parallel=True).load()
        print("Data loaded successfully from.")
        #da_now = ds_now.to_array("variables").astype("float32", copy=False)
        nsamples = ds_now.dims[self.sample_dim]
        if nsamples < self.samples_merged:
            t0 = timer()
            add_samples = self.samples_merged - nsamples
            print(f"Add {add_samples:d} samples to dataset.")
            add_inds = random.sample(range(nsamples), add_samples)
            ds_add = ds_now.isel({self.sample_dim: add_inds})
            ds_add[self.sample_dim] = ds_add[self.sample_dim] + 1.
            ds_now = xr.concat([ds_now, ds_add], dim=self.sample_dim)
            print(ds_now)
            print(f"Appending data took {timer()-t0:.2f}s.")
            
        self.data = ds_now#xr.concat([da_now, da_now.sel({"variables": "hsurf_tar"})], dim="variables")

        return True

#datadir2 = "/p/cscratch/fs/deepacf//maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files_copy/"
ds_obj = StreamMonthlyNetCDF(os.path.join(datadir, ".."), "downscaling_tier2_train*.nc", 30, norm_dims = ["time", "rlat", "rlon"], norm_obj=data_norm, var_tar2in="hsurf_tar")

<xarray.Dataset>
Dimensions:       (time: 21094, rlon: 120, rlat: 96)
Coordinates:
  * time          (time) float64 1.411e+03 1.412e+03 ... 8.397e+04 8.397e+04
  * rlon          (rlon) float64 -8.273 -8.218 -8.163 ... -1.838 -1.783 -1.728
  * rlat          (rlat) float64 -3.933 -3.878 -3.823 ... 1.182 1.237 1.292
Data variables:
    rotated_pole  (time) int32 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1
    2t_in         (time, rlat, rlon) float32 dask.array<chunksize=(679, 96, 120), meta=np.ndarray>
    sshf_in       (time, rlat, rlon) float32 dask.array<chunksize=(679, 96, 120), meta=np.ndarray>
    slhf_in       (time, rlat, rlon) float32 dask.array<chunksize=(679, 96, 120), meta=np.ndarray>
    blh_in        (time, rlat, rlon) float32 dask.array<chunksize=(679, 96, 120), meta=np.ndarray>
    10u_in        (time, rlat, rlon) float32 dask.array<chunksize=(679, 96, 120), meta=np.ndarray>
    10v_in        (time, rlat, rlon) float32 dask.array<chunksize=(679, 96, 120), meta=np.nda

In [35]:
if "ds_obj" not in locals():
    data_norm = ZScore(["time", "rlat", "rlon"])
    data_norm.read_norm_from_file("./norm_test.json")
else:
    data_norm = ds_obj.data_norm
    print(ds_obj.normalization_time)
print(data_norm.norm_stats)
#data_norm.save_norm_to_file("./norm_test.json")

8.399947546422482e-07
{'mu': <xarray.Dataset>
Dimensions:       ()
Data variables:
    rotated_pole  float64 1.0
    2t_in         float64 281.5
    sshf_in       float64 -4.993e+04
    slhf_in       float64 -1.643e+05
    blh_in        float64 512.4
    10u_in        float64 0.7001
    10v_in        float64 0.3209
    z_in          float64 5.686e+03
    t850_in       float64 277.5
    t925_in       float64 281.4
    hsurf_tar     float64 571.6
    t_2m_tar      float64 282.0, 'sigma': <xarray.Dataset>
Dimensions:       ()
Data variables:
    rotated_pole  float64 0.0
    2t_in         float64 8.401
    sshf_in       float64 1.934e+05
    slhf_in       float64 2.493e+05
    blh_in        float64 476.3
    10u_in        float64 2.356
    10v_in        float64 1.918
    z_in          float64 4.485e+03
    t850_in       float64 6.883
    t925_in       float64 7.484
    hsurf_tar     float64 497.5
    t_2m_tar      float64 8.482}


Next, we run the experiment:

In [None]:
! jutil env activate -p deepacf
#! export datadir1="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files_copy/"; echo $datadir1; /opt/ddn/ime/bin/ime-ctl --prestage ${datadir1}/downscaling_tier2_train*.nc
#! datadir1="${SCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files_copy/"; ls ${datadir1}
! jutil env activate -p deepacf; /opt/ddn/ime/bin/ime-ctl --frag-stat ${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_tier2/monthly_files_copy/downscaling_tier2_train_9.nc
! echo $HOSTNAME

In [39]:
tf_read_nc = lambda ind_set: tf.py_function(ds_obj.read_netcdf, [ind_set], tf.bool)
tf_getdata = lambda i: tf.numpy_function(ds_obj.getitems, [i], tf.float32)
tf_split = lambda arr: (arr[..., 0:-ds_obj.n_predictands], arr[...,-ds_obj.n_predictands:])

tfds = tf.data.Dataset.range(int(ds_obj.nfiles_merged)).map(tf_read_nc)
tfds = tfds.flat_map(lambda x: tf.data.Dataset.range(ds_obj.samples_merged).shuffle(ds_obj.samples_merged)\
                       .batch(bs_train, drop_remainder=True).map(tf_getdata))#.map(tf_norm))
tfds = tfds.map(tf_split).repeat()

In [40]:
for i in tfds:
    print(i)
    break

Load data from 0th set of files...
Data loaded successfully from.
Add 46 samples to dataset.
<xarray.Dataset>
Dimensions:    (time: 21394, rlat: 96, rlon: 120)
Coordinates:
  * time       (time) float64 2.155e+03 2.156e+03 ... 2.137e+04 1.137e+04
  * rlon       (rlon) float64 -8.273 -8.218 -8.163 ... -1.838 -1.783 -1.728
  * rlat       (rlat) float64 -3.933 -3.878 -3.823 -3.768 ... 1.182 1.237 1.292
Data variables:
    2t_in      (time, rlat, rlon) float32 -0.3611 -0.3556 ... -0.04588 -0.07664
    sshf_in    (time, rlat, rlon) float32 0.3705 0.3714 0.3732 ... 0.492 0.4727
    slhf_in    (time, rlat, rlon) float32 0.6547 0.6546 0.6546 ... 0.5076 0.5063
    blh_in     (time, rlat, rlon) float32 -1.036 -1.036 -1.037 ... 2.068 2.036
    10u_in     (time, rlat, rlon) float32 -0.1098 -0.1155 ... 2.001 2.043
    10v_in     (time, rlat, rlon) float32 -0.04733 -0.03653 ... -2.074 -2.075
    z_in       (time, rlat, rlon) float32 0.53 0.5259 0.5286 ... -1.09 -1.091
    t850_in    (time, rlat, rlo

In [22]:
exp_time = run_test(tfds, test_case="xarray_mfdataset")

Load data from 0th set of files...
Data loaded successfully from.
Add 121 samples to dataset.
Appending data took 3.83s.
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
Load data from 1th set of files...
132
Data loaded successfully from.
Add 120 samples to dataset.
Appending data took 3.84s.
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216


... and find that data streaming is reasonably quick. It takes only about 80s for 250 mini-batches. This is only a bit slower than the test with vectorized indexing, see above.

