# Improving data streaming from netCDF files with TensorFlow datasets

In this Jupyter Notebook, approaches for handling data from multiple netCDF files with TF datasets will be tested. The aim is to come up with a performant approach that allows data reading from netCDF files rather than doing a conversion to TFRecords while keeping a memory-light data handling (to allow handling of datasets that do not fit into the memory of the computing node).

The first approach samples data from multiple netCDF-files when creating individual batches. To speed up the operation, threading with `multiprocessing` is tested.

In [1]:
import os, glob
import re
#from tqdm import tqdm
from timeit import default_timer as timer
import pandas as pd
import numpy as np
import xarray as xr
import tensorflow as tf
import multiprocessing

In [2]:
class StreamMonthlyNetCDF():
    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time"):
        self.data_dir = datadir
        self.file_list = patt
        self.ds = xr.open_mfdataset(list(self.file_list), parallel=True)
        self.sample_dim = sample_dim
        self.times = self.ds[sample_dim].load()
        self.nsamples = self.ds.dims[sample_dim]
        self.file_handles = {}
        self.time_dict_times = {}
        for fnc in self.file_list:
            self.file_handles[fnc] = xr.open_dataset(fnc)
            self.time_dict_times[fnc] = self.file_handles[fnc][sample_dim].load()
            # self.file_handles[fnc] = xr.open_dataset(fnc, decode_cf=False)
            # self.file_handles.append(xr.open_dataset(fnc))
        
        print(f"Number of used workers: {workers:d}")
        self.pool = multiprocessing.pool.ThreadPool(workers)
        
    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data
    
    def getitems(self, indices):
        print(indices)
        return np.array(self.pool.map(self.__getitem__ ,indices))
    
    @property
    def data_dir(self):
        return self._data_dir
    
    @data_dir.setter 
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            raise DirectoryNotFoundError(f"Parsed data directory '{datadir}' does not exist.")
            
        self._data_dir = datadir
        
    @property 
    def file_list(self):
        return self._file_list 
    
    @file_list.setter
    def file_list(self, patt):        
        patt = patt if patt.endswith(".nc") else f"{patt}.nc" 
        files = glob.glob(os.path.join(self.data_dir, patt))
        
        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")
            
        self._file_list = sorted(files)        
        
    @property
    def sample_dim(self):
        return self._sample_dim 
    
    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if not sample_dim in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")
            
        self._sample_dim = sample_dim 
        
    def index_to_sample(self, index):
        curr_time = pd.to_datetime(self.times[index].values)
        
        fname = [s for s in self.file_list if curr_time.strftime("%Y-%m") in s]
        if not fname:
            raise FileNotFoundError(f"Could not find a file matching requested date {date_ex}")
        elif len(fname) > 1:
            raise ValueError(f"Files found for requested date {date_ex} is not unique.")
        
        ds = self.file_handles[fname[0]]  #
        return ds.sel({self.sample_dim: curr_time}).to_array()
                                        

To speed up data reading, we stage the files on `CSCRATCH`, the high performance storage tier at JSC.

In [None]:
! jutil env activate -p deepacf
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data"
! echo $datadir
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data"; for yr in {2006..2018}; do for mm in {01..12}; do /opt/ddn/ime/bin/ime-ctl --prestage ${datadir}/${yr}/${yr}-${mm}/preproc_${yr}-${mm}.nc; done; done

Let's check if the data is really staged:

In [None]:
! /opt/ddn/ime/bin/ime-ctl --frag-stat $CSCRATCH/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/2015/2015-11/preproc_2015-11.nc

Let's run a first test on the Tier-2 dataset of MAELSTROM's downscaling application:

In [3]:
# Path to netCDF-files
datadir = "/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/"
# batch size
batch_size = 32
# number of test sets
test_steps = 500

# get number of available (virtual) CPUs
max_workers = multiprocessing.cpu_count()
workers = min(batch_size, max_workers) 

workers_now = int(workers)
print(f"Number of available (virtual) CPUs: {max_workers:d}.")

# instantiate an example monthly data stream
all_data = StreamMonthlyNetCDF(datadir, "preproc_*.nc", workers=int(workers_now))

Number of available (virtual) CPUs: 80.
Number of used workers: 32


Check the handled dataset (should have more than 100K samples):

In [4]:
ds = all_data.ds
print(ds)
# check timestamps and available number of samples
print(f"Available samples in dataset: {len(all_data):d}.")
print(all_data.times)

<xarray.Dataset>
Dimensions:       (time: 111548, rlon: 120, rlat: 120)
Coordinates:
  * time          (time) datetime64[ns] 2006-01-01T13:00:00 ... 2018-12-31T23...
  * rlon          (rlon) float64 -8.273 -8.218 -8.163 ... -1.838 -1.783 -1.728
  * rlat          (rlat) float64 -3.933 -3.878 -3.823 ... 2.502 2.557 2.612
Data variables:
    rotated_pole  (time) int32 1 1 1 1 1 1 1 1 1 1 1 1 ... 1 1 1 1 1 1 1 1 1 1 1
    2t_in         (time, rlat, rlon) float32 dask.array<chunksize=(639, 120, 120), meta=np.ndarray>
    sshf_in       (time, rlat, rlon) float32 dask.array<chunksize=(639, 120, 120), meta=np.ndarray>
    slhf_in       (time, rlat, rlon) float32 dask.array<chunksize=(639, 120, 120), meta=np.ndarray>
    blh_in        (time, rlat, rlon) float32 dask.array<chunksize=(639, 120, 120), meta=np.ndarray>
    10u_in        (time, rlat, rlon) float32 dask.array<chunksize=(639, 120, 120), meta=np.ndarray>
    10v_in        (time, rlat, rlon) float32 dask.array<chunksize=(639, 120, 120),

Wrap everything into a numpy-function for TensorFlow and create a dataset. Note that the indices are shuffled while also ensuring that the buffer size is large enough to enable 'reasonable' sampling (20 K corresponds to 2.5 years of data).

In [5]:
tf_fun2=lambda i: tf.numpy_function(all_data.getitems, [i] , tf.float64 )
inp=tf.Variable(range(10))
# some test
data_test = tf_fun2(inp)

# set-up TF dataset
ds=tf.data.Dataset.range(len(all_data)).shuffle(buffer_size=20000).batch(int(workers_now*4)) \
                  .map(tf_fun2).unbatch().batch(batch_size)

2022-12-21 13:11:14.278689: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX512F
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-12-21 13:11:15.960292: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 14670 MB memory:  -> device: 0, name: Tesla V100-SXM2-16GB, pci bus id: 0000:60:00.0, compute capability: 7.0
2022-12-21 13:11:15.961518: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:1 with 14670 MB memory:  -> device: 1, name: Tesla V100-SXM2-16GB, pci bus id: 0000:61:00.0, compute capability: 7.0
2022-12-21 13:11:15.962234: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:

[0 1 2 3 4 5 6 7 8 9]


Run the test:

In [6]:
#%%timeit
#from timeit import default_timer as timer
batch_time = [timer()]
# test with half of the workers
for i, x in enumerate(ds):#tqdm(enumerate(ds)):
    if i == 0:
        print(tf.shape(x))
    elif i > test_steps -1:
        break
    print(i)
    batch_time.append(timer())

2022-12-21 13:11:17.129922: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


[11099   223 19703 10904  5496 12782 10196 18934  6459 15883 16080  7995
 19303 14895 12261   629  9627  1310  1200 11806 19112  7789  9118 14279
 14764  9187 16604 18733  5029  3565  2921 14619 19796 17904 14293  9077
 11909  5104   402  9552  3132 10684  8041 16984 14013 10841  3820 17018
 12105 19309  4787  8651 19651 14367 17974  9847  8248  6754  7849  2922
  7511  6949 12293  3742 12897 20015 14401  5593  9148 14542  1468 18276
 13484  1758  5071  7459 13823 11517 18534 13866  1421  1649  2234 19711
 18627  2168  5579 10650  6878  5007  9018  2151 10700 11672 12772  2302
  8447  3933  9989 15107 13173  9116  9672 16374 14103  6285 13639 20101
 19203 10532 16038  7467  8612 19396 14554 15655 18091  6286 13211 19615
  6723 16203 13690 12544   489 17362  1145 12655]
tf.Tensor([ 32  12 120 120], shape=(4,), dtype=int32)
0
1
2
3
[16180 16018 13961 12543  1308 19447  4952 11577 12888 17601  7027 10724
 16582  3181  9253  8515 14439 15380  8728  3834  9514  8238  1482  8679
 16615  7981

In [9]:
batch_time = np.asarray(batch_time)
elapsed_times = batch_time[1:] - batch_time[0:-1]

print(f"Average time per batch: {np.mean(elapsed_times[3::]):.2f}s (+/- {np.std(elapsed_times[3::]):.3f}s). \n" +
      f"Total time: {np.sum(elapsed_times[3::]):.1f}s") 

Average time per batch: 1.43s (+/- 2.571s). 
Total time: 711.0s


In [None]:
print(elapsed_times)

## Preliminary result
The experiments conducted above show that reading from multiple netCDF-files for sampling constitutes a severe bottleneck. Even with a higher number of threads, it takes 1.5s to 2.5s to sample a mini-batch with size 32. This is much slower than the forward and backward step take, e.g. for U-Net (about 0.1 to 0.2s).

Without shuffling, the sampling is considerably quicker. This is obviously due to the fact that not several, but one netCDF-file is (most often) used for creating the mini-batch. 

The new idea is now to perform a manual shuffling before training. All the data will be read lazily and then shuffled to create netCDF-files from random time steps. By doing so, the netCDF-files can be consumed sequentially while also ensuring randomness in data sampling. For varying the ordering of the data during training, one might try to permute the file list order. However, it's not clear yet how this can be realized, maybe with the help of Keras Callbacks.

### Prepare the shuffled dataset

We start by first writing the data in a shuffled way to new netCDF files. Since coordinates in netCDF-files must be montonically ascending or descending, we need to introduce a sample index. <br>
For later merging (i.e. opening with `xr.open_mfdataset`), the sample index must be unique and thus will be defined globally. Thus, it will run from 1 to `len(dataset)`.

In [None]:
ds = all_data.ds 
times = all_data.times.copy(deep=True)
ntimes = np.shape(times)[0]
ds = ds.rename_dims({"time": "sample_ind"})
ds["sample_ind"] = range(np.shape(times)[0])

(Double-)Check the number of samples:

In [None]:
print(ntimes)

Since any parallelized writing of netCDF-files proved to be terribly slow (e.g. with `xr.save_mfdataset` or using `multiprocessing`), we pursue a sequential netCDF-creation procedure.

In [None]:
def sample2netcdf(id, indices):
    print(f"indices of process {id}: {indices}", flush=True)
    ds_subset = ds.isel({"sample_ind": indices}).load()
    print("Data loaded sucsessfully!", flush=True)
    
    nsamples_now = np.shape(ds_subset["sample_ind"])[0]
    ds_subset["sample_ind"] = range(nsamples_now)
    fname_now = os.path.join(datadir, "test2", f"ds_resampled_{id:0d}_test.nc")
    
    print(f"Write data subset to file '{fname_now}'.", flush=True)
    ds_subset.to_netcdf(fname_now)
    

In [None]:
samples_per_file = int(8640)

inds = np.arange(ntimes)
np.random.shuffle(inds)

# approach with multiprocessing -> slow
# t0 = timer()
# inds_list = [(i, inds[i*samples_per_file: (i+1)*samples_per_file]) for i in range(int(ntimes/samples_per_file))]
# with multiprocessing.pool.ThreadPool(4) as Pool:
#    for _ in Pool.starmap(sample2netcdf, inds_list):
#        print("Done!")
        
# print(f"File creation took {timer()-t0:.2f}s.")

# approach with xr.save_mfdataset -> slow
# fname_list, ds_list = [], []
# for i in range(int(ntimes/samples_per_file)):    
    # fname_list.append(os.path.join(datadir, "test2", f"ds_resampled_{i:0d}.nc"))
    # ds_list.append(ds.isel({"sample_ind": inds_now}))
# print(fname_list)
# t0 = timer()
# xr.save_mfdataset(ds_list, fname_list, mode="w")
# print(f"Saving data took {timer()-t0:.1f}s.")

samples_per_file = int(8640)

inds = np.arange(ntimes)
np.random.shuffle(inds)

t0 = timer()
for i in range(int(ntimes/samples_per_file)):
    
    inds_now = inds[i*samples_per_file: (i+1)*samples_per_file]
    print(f"Load data to memory for {i+1:d}th subset...")
    ds_subset = ds.isel({"sample_ind": inds_now}).load()
    print("Data loaded sucsessfully!")
    
    nsamples_now = np.shape(ds_subset["sample_ind"])[0]
    ds_subset["sample_ind"] = range(i*nsamples_now, (i+1)*nsamples_now)
    fname_now = os.path.join(datadir, "test", f"ds_resampled_{i:0d}.nc")
    
    print(f"Write data subset to file '{fname_now}'.")
    ds_subset.to_netcdf(fname_now)
    
print(f"File creation took {timer()-t0:.2f}s.")

Next, we stage (again) the netCDF files to `CSCRATCH` for quicker data access: 

In [None]:
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test"; for i in {0..11}; do /opt/ddn/ime/bin/ime-ctl --prestage ${datadir}/ds_resampled_${i}.nc; done
! datadir="${CSCRATCH}/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test"; /opt/ddn/ime/bin/ime-ctl --frag-stat ${datadir}/ds_resampled_11.nc

In [24]:
class StreamMonthlyNetCDF():
    def __init__(self, datadir, patt, workers=4, sample_dim: str = "time", samples_per_file: int = 8640):
        self.data_dir = datadir
        self.file_list = patt
        self.ds = xr.open_mfdataset(list(self.file_list), parallel=True)
        self.sample_dim = sample_dim
        self.times = self.ds[sample_dim].load()
        self.nsamples = self.ds.dims[sample_dim]
        self.samples_per_file = samples_per_file
        self.ds_now = None
        self.loaded_files = []
        
        print(f"Number of used workers: {workers:d}")
        self.pool = multiprocessing.pool.ThreadPool(workers)
    def __len__(self):
        return self.nsamples

    def __getitem__(self, i):
        data = self.index_to_sample(i)
        return data
    
    def getitems(self, indices):
        inds_fname = list(set([int(i/self.samples_per_file) for i in indices]))
        # before getting the data, check if we must load new files
        if self.ds_now is None or not set(self.file_list[inds_fname]) == set(self.loaded_files):
            print(f"Load datafiles {*self.file_list[inds_fname],}")
            self.loaded_files = self.file_list[inds_fname]
            self.ds_now = xr.open_mfdataset(list(self.loaded_files)).load()
        return np.array(self.pool.map(self.__getitem__ , indices))
    
    @property
    def data_dir(self):
        return self._data_dir
    
    @data_dir.setter 
    def data_dir(self, datadir):
        if not os.path.isdir(datadir):
            raise DirectoryNotFoundError(f"Parsed data directory '{datadir}' does not exist.")
            
        self._data_dir = datadir
        
    @property 
    def file_list(self):
        return self._file_list 
    
    @file_list.setter
    def file_list(self, patt):        
        patt = patt if patt.endswith(".nc") else f"{patt}.nc" 
        files = glob.glob(os.path.join(self.data_dir, patt))
        
        if not files:
            raise FileNotFoundError(f"Could not find any files with pattern '{patt}' under '{self.data_dir}'.")
            
        self._file_list = np.asarray(sorted(files, key=lambda s: int(re.search(r'\d+', os.path.basename(s)).group())))
        
    @property
    def sample_dim(self):
        return self._sample_dim 
    
    @sample_dim.setter
    def sample_dim(self, sample_dim):
        if not sample_dim in self.ds.dims:
            raise KeyError(f"Could not find dimension '{sample_dim}' in data.")
            
        self._sample_dim = sample_dim 
        
    def index_to_sample(self, index):    
        try:
            return self.ds_now.sel({self.sample_dim: index}).to_array()
        except Exception as err:
            # interestingly, this proves to work (racing condition?)
            print(self.ds_now)
            print(index)
            print(index in self.ds_now["sample_ind"])
            print({self.sample_dim: index})
            print(self.ds_now.sel({self.sample_dim: index}))
            return self.ds_now.sel({self.sample_dim: index}).to_array()
        #return ds.sel({self.sample_dim: curr_time}).to_array()

Next, we will instantiate the data object, ...

In [25]:
all_data_new = StreamMonthlyNetCDF(os.path.join(datadir, "test"), "ds_resampled_*.nc", workers=int(workers_now), sample_dim = "sample_ind")

['/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_0.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_1.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_2.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_3.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_4.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_5.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_6.nc'
 '/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/

In [26]:
print(all_data_new.ds)

<xarray.Dataset>
Dimensions:       (sample_ind: 103680, rlon: 120, rlat: 120)
Coordinates:
    time          (sample_ind) datetime64[ns] 2017-01-25T02:00:00 ... 2012-11...
  * rlon          (rlon) float64 -8.273 -8.218 -8.163 ... -1.838 -1.783 -1.728
  * rlat          (rlat) float64 -3.933 -3.878 -3.823 ... 2.502 2.557 2.612
  * sample_ind    (sample_ind) int64 0 1 2 3 4 ... 103676 103677 103678 103679
Data variables:
    rotated_pole  (sample_ind) int32 dask.array<chunksize=(8640,), meta=np.ndarray>
    2t_in         (sample_ind, rlat, rlon) float32 dask.array<chunksize=(8640, 120, 120), meta=np.ndarray>
    sshf_in       (sample_ind, rlat, rlon) float32 dask.array<chunksize=(8640, 120, 120), meta=np.ndarray>
    slhf_in       (sample_ind, rlat, rlon) float32 dask.array<chunksize=(8640, 120, 120), meta=np.ndarray>
    blh_in        (sample_ind, rlat, rlon) float32 dask.array<chunksize=(8640, 120, 120), meta=np.ndarray>
    10u_in        (sample_ind, rlat, rlon) float32 dask.array<chun

built a TF dataset (without shuffling to ensure sequential data loading), ...

In [27]:
tf_fun2=lambda i: tf.numpy_function(all_data_new.getitems, [i] , tf.float64)

# same experiment with all workers
#ds=tf.data.Dataset.range(8578, 8643).batch(int(33)) \
ds=tf.data.Dataset.range(len(all_data)).batch(int(33)) \
                  .map(tf_fun2).unbatch().batch(32)

... and conduct the test:

In [28]:
batch_time = [timer()]
test_steps = 500

print(f"Test for {test_steps}-times")
# test with half of the workers
for i, x in enumerate(ds):#tqdm(enumerate(ds)):
    if i == 0:
        print(tf.shape(x))
    elif i > test_steps -1:
        break
    batch_time.append(timer())
    print(i)
    
batch_time = np.asarray(batch_time)
elapsed_times = batch_time[1:] - batch_time[0:-1]

print(f"Average time per batch: {np.mean(elapsed_times):.2f} (+/- {np.std(elapsed_times):.3f}). \n" +
      f"Total time: {np.sum(elapsed_times):.1f}") 

Test for 500-times
Load datafiles ('/p/scratch/deepacf/maelstrom/maelstrom_data/ap5_michael/preprocessed_era5_crea6/netcdf_data/all_files/test/ds_resampled_0.nc',)
tf.Tensor([ 32  12 120 120], shape=(4,), dtype=int32)
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222


It is seen that data sampling now is much quicker with an average creation time below 0.2s.
Thus, this approach is a candidate for a real test when training the model. <br>
However, open issues persist. These are:
- [ ] Some missing data when the total number of samples is not a divider of samples_per_file
- [ ] Potential racing condition when self.ds_now has to be updated (see the hacky try-except handling)
- [ ] Fixed ordering of shuffled training samples (How to get variation into it?)