# Improving data streaming from netCDF files with TensorFlow datasets

In this Jupyter Notebook, approaches for handling data from multiple netCDF files with TF datasets are tested. The aim is to find a performant approach that reads the data directly from the netCDF files (rather than converting them to TFRecords) while keeping the memory footprint low, so that datasets that do not fit into the memory of a compute node can be handled.

The first approach samples data from multiple netCDF files when creating individual batches. To speed up reading, a thread pool (`multiprocessing.pool.ThreadPool`) is tested.

In [None]:
import os, glob
import re
#from tqdm import tqdm
from timeit import default_timer as timer
import pandas as pd
import numpy as np
import xarray as xr
import tensorflow as tf
import multiprocessing

In [None]:
class StreamMonthlyNetCDF():
    """Draw samples from a set of monthly netCDF files.

    All files matching ``patt`` in ``datadir`` are opened lazily as one combined
    dataset (to obtain the sample times) and, additionally, one by one. A sample
    is read from the file whose path (relative to ``datadir``) contains the year
    and month (``%Y-%m``) of the sample's time stamp. Batches are read with a
    thread pool.

    Args:
        datadir: directory containing the netCDF files.
        patt: glob pattern of the files relative to ``datadir`` (``.nc`` is
            appended if missing).
        workers: number of threads used by :meth:`getitems`.
        sample_dim: name of the dimension along which samples are drawn.

    Raises:
        FileNotFoundError: if ``datadir`` is not a directory or no file matches.
        KeyError: if ``sample_dim`` is not a dimension of the data.
    """
    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time"):
        # ``import multiprocessing`` alone does not guarantee that the
        # ``multiprocessing.pool`` submodule is loaded, so import it explicitly.
        from multiprocessing.pool import ThreadPool

        self.data_dir = datadir
        self.file_list = patt
        self.ds = xr.open_mfdataset(list(self.file_list), parallel=True)
        self.sample_dim = sample_dim
        self.times = self.ds[sample_dim].load()
        # Dataset.sizes maps dimension name -> length (mapping access via
        # Dataset.dims is deprecated in recent xarray versions).
        self.nsamples = self.ds.sizes[sample_dim]
        # One open handle (and its loaded time coordinate) per monthly file.
        self.file_handles = {}
        self.time_dict_times = {}
        for fnc in self.file_list:
            self.file_handles[fnc] = xr.open_dataset(fnc)
            self.time_dict_times[fnc] = self.file_handles[fnc][sample_dim].load()

        print(f"Number of used workers: {workers:d}")
        self.pool = ThreadPool(workers)

    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data

    def getitems(self, indices):
        """Return the samples at positions ``indices`` stacked into one numpy array.

        The samples are read concurrently by the thread pool.
        """
        print(indices)
        return np.array(self.pool.map(self.__getitem__, indices))

    def close(self):
        """Release the thread pool and all open netCDF files."""
        self.pool.close()
        self.pool.join()
        for handle in self.file_handles.values():
            handle.close()
        self.file_handles.clear()
        self.ds.close()

    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            # There is no built-in DirectoryNotFoundError in Python.
            raise FileNotFoundError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        self._file_list = sorted(files)

    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if sample_dim not in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    def index_to_sample(self, index):
        """Read the sample at position ``index`` from the file of its month.

        Raises:
            FileNotFoundError: if no file matches the sample's year and month.
            ValueError: if more than one file matches.
        """
        curr_time = pd.to_datetime(self.times[index].values)
        month_str = curr_time.strftime("%Y-%m")

        # Match on the path relative to the data directory so that a date
        # contained in the data directory itself cannot cause false matches.
        fname = [s for s in self.file_list
                 if month_str in os.path.relpath(s, self.data_dir)]
        if not fname:
            raise FileNotFoundError(f"Could not find a file matching requested date {curr_time}")
        elif len(fname) > 1:
            raise ValueError(f"Files found for requested date {curr_time} is not unique.")

        ds = self.file_handles[fname[0]]
        return ds.sel({self.sample_dim: curr_time}).to_array()
                                        

To speed up data reading, we stage the files on `CSCRATCH`, the high performance storage tier at JSC.

In [None]:
# Each `!` line runs in a separate subshell: variables defined on one line are not
# visible on the next one. Hence, `datadir` is defined and echoed in the same command.
# NOTE(review): for the same reason, the environment set by `jutil env activate` presumably
# does not persist to the following lines — confirm that CSCRATCH is set in the kernel env.
! jutil env activate -p deepacf
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data"; echo "$datadir"
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data"; for yr in {2006..2018}; do for mm in {01..12}; do /opt/ddn/ime/bin/ime-ctl --prestage ${datadir}/${yr}/${yr}-${mm}/preproc_${yr}-${mm}.nc; done; done

Let's check if the data is really staged:

In [None]:
# Query `ime-ctl --frag-stat` for one sample file (2015-11) to check whether it has been staged.
! /opt/ddn/ime/bin/ime-ctl --frag-stat $CSCRATCH/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/2015/2015-11/preproc_2015-11.nc

Let's run a first test on the Tier-2 dataset of MAELSTROM's downscaling application:

In [None]:
# Location of the monthly netCDF files (Tier-2 dataset)
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/"
# Mini-batch size and number of batches timed in the benchmark
batch_size, test_steps = 32, 500

# At most one reading thread per sample of a mini-batch, bounded by the
# number of available (virtual) CPUs.
max_workers = multiprocessing.cpu_count()
workers = batch_size if batch_size < max_workers else max_workers
workers_now = workers
print(f"Number of available (virtual) CPUs: {max_workers:d}.")

In [None]:
# instantiate an example monthly data stream over all files matching
# 'preproc_*.nc' in `datadir`, reading batches with `workers_now` threads
all_data = StreamMonthlyNetCDF(datadir, "preproc_*.nc", workers=int(workers_now))

Check the handled dataset (should have more than 100K samples):

In [None]:
# NOTE: rebinds the notebook-global `ds` to the combined dataset of all files
ds = all_data.ds
print(ds)
# check timestamps and available number of samples
print(f"Available samples in dataset: {len(all_data):d}.")
print(all_data.times)

Wrap everything into a numpy-function for TensorFlow and create a dataset. Note that the indices are shuffled while also ensuring that the buffer size is large enough to enable 'reasonable' sampling (20 K corresponds to 2.5 years of data).

In [None]:
# Wrap the batch reader into a TF op. NOTE(review): the declared output dtype
# (float64) must match the dtype of the array returned by `getitems` — confirm
# the dtype of the stored data.
tf_fun2=lambda i: tf.numpy_function(all_data.getitems, [i] , tf.float64 )
inp=tf.Variable(range(10))
# some test: read the samples 0-9 through the TF op
data_test = tf_fun2(inp)

# set-up TF dataset: shuffle all sample indices (buffer of 20000), read them in
# chunks of 4*workers_now indices and re-batch the samples to `batch_size`
ds=tf.data.Dataset.range(len(all_data)).shuffle(buffer_size=20000).batch(int(workers_now*4)) \
                  .map(tf_fun2).unbatch().batch(batch_size)

Run the test:

In [None]:
#%%timeit
#from timeit import default_timer as timer
# Time the iteration over the TF dataset `ds` (assuming in-order execution, the
# shuffled pipeline of the previous cell). One timestamp is appended after each of
# the first `test_steps` batches; batch number `test_steps` (0-based) is still
# fetched before the loop breaks.
batch_time = [timer()]
for i, x in enumerate(ds):#tqdm(enumerate(ds)):
    if i == 0:
        print(tf.shape(x))
    elif i > test_steps -1:
        break
    print(i)
    batch_time.append(timer())

In [None]:
batch_time = np.asarray(batch_time)
# time per batch = difference of consecutive timestamps (seconds)
elapsed_times = batch_time[1:] - batch_time[0:-1]

# The first 3 batch times are skipped (presumably as warm-up, e.g. filling the
# shuffle buffer) — TODO confirm the intended number of skipped batches.
print(f"Average time per batch: {np.mean(elapsed_times[3::]):.2f}s (+/- {np.std(elapsed_times[3::]):.3f}s). \n" +
      f"Total time: {np.sum(elapsed_times[3::]):.1f}s") 

In [None]:
# Per-batch times in seconds (all batches)
print(elapsed_times)

## Preliminary result
The experiments conducted above show that reading from multiple netCDF files for sampling constitutes a severe bottleneck. Even with a higher number of threads, it takes 1.5 s to 2.5 s to sample a mini-batch of size 32. This is much slower than a forward and backward pass, e.g. of a U-Net (about 0.1 s to 0.2 s).

Without shuffling, sampling is considerably faster. This is because a mini-batch is then (mostly) drawn from a single netCDF file instead of from several files.

The new idea is to shuffle the data manually before training: all data are read lazily and shuffled, and new netCDF files composed of randomly drawn time steps are created. By doing so, the netCDF files can be consumed sequentially while the data sampling remains random. To vary the ordering of the data during training, one might permute the order of the file list. However, it is not yet clear how this can be realized, maybe with the help of Keras callbacks.

### Prepare the shuffled dataset

We start by writing the data in shuffled order to new netCDF files. Since coordinates in netCDF files must be monotonically ascending or descending, we need to introduce a sample index. <br>
For later merging (i.e. opening with `xr.open_mfdataset`), the sample index must be unique and will thus be defined globally, i.e. it runs from 0 to `len(dataset) - 1`.

In [None]:
# Replace the time dimension by a global, 0-based sample index (0 ... ntimes-1);
# after shuffling, the time coordinate would no longer be monotonic.
ds = all_data.ds 
times = all_data.times.copy(deep=True)
ntimes = np.shape(times)[0]
ds = ds.rename_dims({"time": "sample_ind"})
ds["sample_ind"] = range(np.shape(times)[0])

(Double-)Check the number of samples:

In [None]:
# Total number of samples (time steps)
print(ntimes)

Since any parallelized writing of netCDF-files proved to be terribly slow (e.g. with `xr.save_mfdataset` or using `multiprocessing`), we pursue a sequential netCDF-creation procedure.

In [None]:
def sample2netcdf(id, indices, dataset=None, outdir=None, istart=0):
    """Write the samples ``indices`` of a dataset to ``ds_resampled_<id>_test.nc``.

    Used for the (slow) parallel file-creation experiment.

    Args:
        id: number of the output file (also shown in the log messages).
        indices: positions along ``sample_ind`` to write.
        dataset: dataset to sample from; defaults to the notebook-global ``ds``.
        outdir: output directory; defaults to ``<datadir>/test2`` (notebook-global
            ``datadir``). It is created if it does not exist.
        istart: first value of the new ``sample_ind`` coordinate. The default 0
            yields a per-file index; pass the global offset of the file to get
            globally unique indices (required for ``xr.open_mfdataset``).
    """
    if dataset is None:
        dataset = ds
    if outdir is None:
        outdir = os.path.join(datadir, "test2")

    print(f"indices of process {id}: {indices}", flush=True)
    ds_subset = dataset.isel({"sample_ind": indices}).load()
    print("Data loaded sucsessfully!", flush=True)

    nsamples_now = np.shape(ds_subset["sample_ind"])[0]
    # contiguous, monotonic index for the samples of this file
    ds_subset["sample_ind"] = range(istart, istart + nsamples_now)
    # to_netcdf does not create missing directories
    os.makedirs(outdir, exist_ok=True)
    fname_now = os.path.join(outdir, f"ds_resampled_{id:0d}_test.nc")

    print(f"Write data subset to file '{fname_now}'.", flush=True)
    ds_subset.to_netcdf(fname_now)
    

In [None]:
# Create the output directory of the shuffled files (`<datadir>/test2`, which the
# writing cell below writes to); `datadir` itself is left unchanged.
print(datadir)
os.makedirs(os.path.join(datadir, "test2"), exist_ok=True)
print(os.path.join(datadir, "test2"))

In [None]:
# Number of samples per output file (8640 = 360 days if the data are hourly).
samples_per_file = 8640

# Random permutation of all sample indices; consecutive chunks of it are written
# to the individual files, so that each file holds randomly drawn time steps.
inds = np.arange(ntimes)
np.random.shuffle(inds)

# Parallel alternatives (a multiprocessing ThreadPool around `sample2netcdf` and
# `xr.save_mfdataset`) proved to be slow, hence the sequential loop below.
outdir = os.path.join(datadir, "test2")
os.makedirs(outdir, exist_ok=True)

# Round up so that the remainder of ntimes / samples_per_file ends up in a last,
# smaller file instead of being dropped.
nfiles = int(np.ceil(ntimes / samples_per_file))

t0 = timer()
for i in range(nfiles):
    inds_now = inds[i*samples_per_file: (i+1)*samples_per_file]
    print(f"Load data to memory for {i+1:d}th subset...")
    ds_subset = ds.isel({"sample_ind": inds_now}).load()
    print("Data loaded sucsessfully!")

    nsamples_now = np.shape(ds_subset["sample_ind"])[0]
    # Globally unique, contiguous sample index (file i starts at i*samples_per_file),
    # so that the files can be combined with xr.open_mfdataset and the file of a
    # sample is given by sample_ind // samples_per_file.
    istart = i * samples_per_file
    ds_subset["sample_ind"] = range(istart, istart + nsamples_now)
    fname_now = os.path.join(outdir, f"ds_resampled_{i:0d}.nc")

    print(f"Write data subset to file '{fname_now}'.")
    ds_subset.to_netcdf(fname_now)

print(f"File creation took {timer()-t0:.2f}s.")

Next, we stage (again) the netCDF files to `CSCRATCH` for quicker data access: 

In [None]:
# Prestage the shuffled files ds_resampled_0.nc ... ds_resampled_11.nc (subdirectory `test`)
# and check the staging status of the last one.
# NOTE(review): the writing cell stores its files in `test2`, and the number of files (12)
# is hard-coded — confirm both match the files that were actually written.
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test"; for i in {0..11}; do /opt/ddn/ime/bin/ime-ctl --prestage ${datadir}/ds_resampled_${i}.nc; done
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test"; /opt/ddn/ime/bin/ime-ctl --frag-stat ${datadir}/ds_resampled_11.nc

In [None]:
class StreamMonthlyNetCDF():
    """Read batches of samples from pre-shuffled netCDF files.

    The files are expected to carry a globally unique sample index with
    ``samples_per_file`` samples per file, so that the file holding sample ``i``
    is ``file_list[i // samples_per_file]``. The files required for a batch are
    loaded into memory on demand.

    Args:
        datadir: directory containing the netCDF files.
        patt: glob pattern of the files relative to ``datadir`` (``.nc`` is
            appended if missing). Files are sorted by the first number in their name.
        workers: size of the thread pool ``self.pool`` (batch selection itself is
            a single vectorized operation).
        sample_dim: name of the sample dimension (coordinate).
        samples_per_file: number of samples per file.
    """
    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time", samples_per_file: int = 8640):
        # ``import multiprocessing`` alone does not guarantee that the
        # ``multiprocessing.pool`` submodule is loaded, so import it explicitly.
        from multiprocessing.pool import ThreadPool

        self.data_dir = datadir
        self.file_list = patt
        self.ds = xr.open_mfdataset(list(self.file_list), parallel=True)
        self.sample_dim = sample_dim
        self.times = self.ds[sample_dim].load()
        # Dataset.sizes maps dimension name -> length
        self.nsamples = self.ds.sizes[sample_dim]
        self.samples_per_file = samples_per_file
        # in-memory data of the currently loaded file(s)
        self.ds_now = None
        self.loaded_files = []

        print(f"Number of used workers: {workers:d}")
        self.pool = ThreadPool(workers)

    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data

    def getitems(self, indices):
        """Return the samples with (global) sample indices ``indices`` as one array.

        The required files are (re-)loaded if they differ from the currently
        loaded ones. All samples are then selected in one vectorized ``sel`` call.
        This replaces per-sample lookups from a thread pool, which presumably
        caused the intermittent failures ("racing condition") seen in the
        experiments.

        Returns:
            numpy array of shape (len(indices), n_variables, ...).
        """
        inds_fname = sorted({int(i) // self.samples_per_file for i in indices})
        # before getting the data, check if we must load new files
        if self.ds_now is None or set(self.file_list[inds_fname]) != set(self.loaded_files):
            print(f"Load datafiles {*self.file_list[inds_fname],}")
            if self.ds_now is not None:
                # release the file handles of the previously loaded files
                self.ds_now.close()
            self.loaded_files = self.file_list[inds_fname]
            self.ds_now = xr.open_mfdataset(list(self.loaded_files)).load()
        batch = self.ds_now.sel({self.sample_dim: np.asarray(indices)}).to_array()
        # to_array() puts the "variable" dimension first; move the sample
        # dimension to the front (layout: sample, variable, ...).
        return np.asarray(batch.transpose(self.sample_dim, "variable", ...))

    def close(self):
        """Release the thread pool and the opened datasets."""
        self.pool.close()
        self.pool.join()
        if self.ds_now is not None:
            self.ds_now.close()
        self.ds.close()

    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            # There is no built-in DirectoryNotFoundError in Python.
            raise FileNotFoundError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        # numeric ("natural") ordering: ds_resampled_2.nc before ds_resampled_10.nc
        self._file_list = np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group())))

    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if sample_dim not in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    def index_to_sample(self, index):
        """Return the sample with (global) sample index ``index`` from the loaded data."""
        return self.ds_now.sel({self.sample_dim: index}).to_array()

Next, we will instantiate the data object, ...

In [None]:
# Data stream over the shuffled files in `<datadir>/test`, with `sample_ind` as sample dimension
all_data_new = StreamMonthlyNetCDF(os.path.join(datadir, "test"), "ds_resampled_*.nc", workers=int(workers_now), sample_dim = "sample_ind")

In [None]:
# Combined (lazily opened) view of all shuffled files
print(all_data_new.ds)

build a TF dataset (without shuffling to ensure sequential data loading), ...

In [None]:
# Wrap the batch reader of the shuffled files into a TF op.
tf_fun2=lambda i: tf.numpy_function(all_data_new.getitems, [i] , tf.float64)

# No shuffling: sequential indices keep consecutive batches within the same file(s).
# The index range must cover the samples of the shuffled data stream (`all_data_new`),
# not of the monthly one (`all_data`); both differ if samples were dropped when
# writing the shuffled files.
#ds=tf.data.Dataset.range(8578, 8643).batch(int(33)) \
ds=tf.data.Dataset.range(len(all_data_new)).batch(int(33)) \
                  .map(tf_fun2).unbatch().batch(32)

... and conduct the test:

In [None]:
batch_time = [timer()]
test_steps = 500

print(f"Test for {test_steps}-times")
# Sequential (unshuffled) pipeline: time the first `test_steps` batches; batch
# number `test_steps` (0-based) is still fetched before the loop breaks.
for i, x in enumerate(ds):#tqdm(enumerate(ds)):
    if i == 0:
        print(tf.shape(x))
    elif i > test_steps -1:
        break
    batch_time.append(timer())
    print(i)
    
batch_time = np.asarray(batch_time)
# time per batch = difference of consecutive timestamps (seconds)
elapsed_times = batch_time[1:] - batch_time[0:-1]

# NOTE: unlike the first benchmark, no warm-up batches are excluded here.
print(f"Average time per batch: {np.mean(elapsed_times):.2f} (+/- {np.std(elapsed_times):.3f}). \n" +
      f"Total time: {np.sum(elapsed_times):.1f}") 

It can be seen that data sampling is now much quicker, with an average time per batch below 0.2 s.
Thus, this approach is a candidate for a real test when training the model. <br>
However, open issues persist. These are:
- [ ] Some samples are missing when the total number of samples is not a multiple of samples_per_file
- [ ] Potential race condition when self.ds_now has to be updated (see the hacky try-except handling)
- [ ] Fixed ordering of shuffled training samples (How to get variation into it?)

In [None]:
class StreamMonthlyNetCDF():
    """Stream samples from pre-shuffled netCDF files, one complete file at a time.

    A file is loaded into memory with :meth:`read_netcdf`. Batches are then
    selected from it with :meth:`getitems`, using indices relative to the first
    sample of that file (``self.istart``).

    Args:
        datadir: directory containing the netCDF files.
        patt: glob pattern of the files relative to ``datadir`` (``.nc`` is
            appended if missing). Files are sorted by the first number in their name.
        workers: size of the thread pool ``self.pool`` (batch selection itself is
            a single vectorized operation).
        sample_dim: name of the sample dimension (coordinate).
        samples_per_file: number of samples per file (stored; not used by the
            file-wise reading).
    """
    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time", samples_per_file: int = 8640):
        # ``import multiprocessing`` alone does not guarantee that the
        # ``multiprocessing.pool`` submodule is loaded, so import it explicitly.
        from multiprocessing.pool import ThreadPool

        self.data_dir = datadir
        self.file_list = patt
        print(self.file_list)
        # opened without parallel=True
        self.ds = xr.open_mfdataset(list(self.file_list))
        self.sample_dim = sample_dim
        self.times = self.ds[sample_dim].load()
        # Dataset.sizes maps dimension name -> length
        self.nsamples = self.ds.sizes[sample_dim]
        self.samples_per_file = samples_per_file
        # in-memory data of the current file and the sample index of its first sample
        self.ds_now = None
        self.istart = 0
        self.loaded_files = []

        print(f"Number of used workers: {workers:d}")
        self.pool = ThreadPool(workers)

    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data

    def getitems(self, indices):
        """Return the samples at file-local positions ``indices`` as one float32 array.

        All samples are selected in one vectorized ``sel`` call. This replaces
        per-sample lookups from a thread pool, which presumably caused the
        intermittent failures ("racing condition") seen in the experiments.

        Returns:
            numpy array of shape (len(indices), n_variables, ...).
        """
        print(indices)
        # read data and offset once so that the whole batch uses the same file
        ds_now, istart = self.ds_now, self.istart
        locs = np.asarray(indices) + istart
        batch = ds_now.sel({self.sample_dim: locs}).astype("float32", copy=False).to_array()
        # to_array() puts the "variable" dimension first; move the sample
        # dimension to the front (layout: sample, variable, ...).
        return np.asarray(batch.transpose(self.sample_dim, "variable", ...))

    def close(self):
        """Release the thread pool and the opened datasets."""
        self.pool.close()
        self.pool.join()
        if self.ds_now is not None:
            self.ds_now.close()
        self.ds.close()

    @property
    def data_dir(self):
        return self._data_dir

    @data_dir.setter
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            # There is no built-in DirectoryNotFoundError in Python.
            raise FileNotFoundError(f"Parsed data directory '{datadir}' does not exist.")

        self._data_dir = datadir

    @property
    def file_list(self):
        return self._file_list

    @file_list.setter
    def file_list(self, patt):
        patt = patt if patt.endswith(".nc") else f"{patt}.nc"
        files = glob.glob(os.path.join(self.data_dir, patt))

        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")

        # numeric ("natural") ordering: ds_resampled_2.nc before ds_resampled_10.nc
        self._file_list = np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group())))

    @property
    def sample_dim(self):
        return self._sample_dim

    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if sample_dim not in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")

        self._sample_dim = sample_dim

    def read_netcdf(self, fname):
        """Load the complete netCDF file ``fname`` (a string tensor) into memory as float32.

        Returns:
            True (used as the output of a ``tf.py_function``).
        """
        fname = tf.keras.backend.get_value(fname)
        # A string tensor evaluates to bytes: decode it instead of stripping the
        # characters "b" and "'" from its repr (which also strips them from the path).
        if isinstance(fname, bytes):
            fname = fname.decode("utf-8")
        fname = str(fname)
        print(f"Load data from {fname}...")
        # close the file once its content is in memory
        with xr.open_dataset(fname, engine="netcdf4") as ds_file:
            ds_now = ds_file.astype("float32", copy=False).load()
        istart = ds_now[self.sample_dim][0].values
        self.ds_now, self.istart = ds_now, istart
        print(self.istart)

        return True

    def index_to_sample(self, index):
        """Return the sample at file-local position ``index`` of the loaded file (float32)."""
        ds_now, istart = self.ds_now, self.istart
        index_loc = index + istart
        return ds_now.sel({self.sample_dim: index_loc}).astype("float32", copy=False).to_array()

In [None]:
# (Re-)instantiate the data stream over the shuffled files with the latest class definition
all_data_new = StreamMonthlyNetCDF(os.path.join(datadir, "test"), "ds_resampled_*.nc", workers=int(workers_now), sample_dim = "sample_ind")

In [None]:
# Run tf.function-decorated functions eagerly (global setting).
tf.config.run_functions_eagerly(True)

# Load one complete netCDF file into `all_data_new` (via read_netcdf); yields True.
tf_fun1 = lambda fname: tf.py_function(all_data_new.read_netcdf, [fname], tf.bool)
# Read a batch from the currently loaded file. `getitems` returns a single float32
# array, so a single output dtype is declared (a tuple of two dtypes would not match).
tf_fun2 = lambda i: tf.numpy_function(all_data_new.getitems, [i] , tf.float32)

data_dir = all_data_new.data_dir
#print(data_dir)


def dataset_reader(fname):
    """Read a complete netCDF file into a numpy array.

    Args:
        fname: scalar string tensor holding the file path.

    Returns:
        numpy array with dimensions ordered as (sample_ind, variable, ...).
    """
    fname = tf.keras.backend.get_value(fname)
    # A string tensor evaluates to bytes: decode it instead of stripping the
    # characters "b" and "'" from its repr (which also strips them from the path).
    if isinstance(fname, bytes):
        fname = fname.decode("utf-8")
    fname = str(fname)
    print(fname)
    # close the file once its content is in memory
    with xr.open_dataset(fname, engine="netcdf4") as ds_file:
        da = np.array(ds_file.load().to_array().transpose("sample_ind", "variable", ...))

    #dataset = tf.data.Dataset.from_tensor_slices(da)
    return da

tf_fun = lambda path: tf.py_function(dataset_reader, [path], Tout=tf.float64)
data_iter = tf.data.Dataset.from_tensor_slices(all_data_new.file_list).map(tf_fun1)#.range(10).batch(5).map(tf_fun2)

# NOTE(review): the interleave function ignores its input `x`; every element
# produces the batches [0..4] and [5..9] of whichever file is loaded in
# `all_data_new` when `tf_fun2` runs. With interleave's default cycle_length,
# several files may be loaded before their batches are read — confirm the ordering.
data_iter = data_iter.interleave(lambda x: tf.data.Dataset.range(10).batch(5).map(tf_fun2))
#data_iter = tf.data.Dataset.from_tensor_slices(all_data_new.file_list).map(tf_fun)#.from_tensor_slices()
#data_iter = tf.data.Dataset.from_tensor_slices(all_data_new.file_list).map(tf_fun)#.from_tensor_slices()

In [None]:
# Print the first three elements of the pipeline (a fourth one is fetched before the loop breaks)
for i, ele in enumerate(data_iter):
    if i < 3:
        print(ele)
    else:
        break

In [None]:
# Check that the first shuffled file exists at the expected (hard-coded) location
print(os.path.isfile("/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_0.nc"))

In [None]:
#import tensorflow.data.Dataset as Dataset

# Stand-alone experiment with tf.data.Dataset.interleave on plain text files
# (expects ./test/file1.txt ... ./test/file4.txt to exist).
filenames = ["./test/file1.txt", "./test//file2.txt",
             "./test/file3.txt", "./test/file4.txt"]
dataset = tf.data.Dataset.from_tensor_slices(filenames)
# Maps every text line to the dataset 0..9 (the line content is ignored), so the
# elements of the result are presumably nested datasets — confirm.
def parse_fn(filename):
  return tf.data.Dataset.range(10)
dataset = dataset.interleave(lambda x:
    tf.data.TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
    cycle_length=4, block_length=16)


#dataset = tf.data.Dataset.range(1, 6)
#dataset = dataset.interleave(
#    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
#    cycle_length=2, block_length=4)
#list(dataset.as_numpy_iterator())
# Each element is itself iterated as a dataset and its values are printed.
# NOTE: the loop variable `ds` rebinds the notebook-global `ds`.
for ds in dataset:
    for i in ds.as_numpy_iterator():
        print(i)