In [4]:
import xarray as xr
import numpy as np
import zarr

In [5]:
data_dir = '../data'
zarr_group_path = f'{data_dir}/test'
chunk_sizes = {'time_periods': 1, 'parameters': 1, 'depth': 8, 'lat': 90, 'lon': 360}


In [None]:
# Simulate creating and writing the initial dataset to Zarr
ds_initial = xr.Dataset(
    {
        'mn': (('time_periods', 'parameters', 'depth', 'lat', 'lon'), np.random.rand(1, 1, 57, 720, 1440)),
        'an': (('time_periods', 'parameters', 'depth', 'lat', 'lon'), np.random.rand(1, 1, 57, 720, 1440))
    },
    coords={
        'time_periods': ['1'],
        'parameters': ['temperature'],
        'depth': np.linspace(0, 1500, 57),
        'lat': np.linspace(-90, 90, 720),
        'lon': np.linspace(-180, 180, 1440)
    }
)
ds_initial = ds_initial.chunk(chunk_sizes)
ds_initial.to_zarr(zarr_group_path, mode='w')

# Verify the initial dataset
dt1 = xr.open_zarr(zarr_group_path, consolidated=False)
print(dt1)

In [8]:
# Function to reindex and combine new data
def reindex_and_combine_existing_with_new(existing_ds, new_ds, param_name, period_key):
    if param_name not in existing_ds.parameters.values:
        new_params = np.append(existing_ds.parameters.values, param_name)
        existing_ds = existing_ds.reindex(parameters=new_params, fill_value=np.nan)

    if period_key not in existing_ds.time_periods.values:
        new_periods = np.append(existing_ds.time_periods.values, period_key)
        existing_ds = existing_ds.reindex(time_periods=new_periods, fill_value=np.nan)

    for var in new_ds.data_vars:
        existing_ds[var] = existing_ds[var].combine_first(new_ds[var])

    return existing_ds

# Simulate loading new data and combining it with existing data
param_name = 'temperature'
period_key = '2'
ds_new = xr.Dataset(
    {
        'mn': (('time_periods', 'parameters', 'depth', 'lat', 'lon'), np.random.rand(1, 1, 57, 720, 1440)),
        'an': (('time_periods', 'parameters', 'depth', 'lat', 'lon'), np.random.rand(1, 1, 57, 720, 1440))
    },
    coords={
        'time_periods': [period_key],
        'parameters': [param_name],
        'depth': np.linspace(0, 1500, 57),
        'lat': np.linspace(-90, 90, 720),
        'lon': np.linspace(-180, 180, 1440)
    }
)
ds_new = ds_new.chunk(chunk_sizes)

# Open the existing Zarr store and combine with new data
ds_existing = xr.open_zarr(zarr_group_path, consolidated=True)
ds_combined = reindex_and_combine_existing_with_new(ds_existing, ds_new, param_name, period_key)

# Verify combined data
dt1_sel = ds_combined.sel(parameters='temperature', time_periods='1')
t_mn1 = dt1_sel['mn'].values
if np.any(np.isfinite(t_mn1)):
    print("First run has non-na values")
else:
    print("First run all na values")

dt2_sel = ds_combined.sel(parameters='temperature', time_periods='2')
t_mn2 = dt2_sel['mn'].values
if np.any(np.isfinite(t_mn2)):
    print("Second run has non-na values")
else:
    print("Second run all na values")

# Save the updated combined dataset back to the Zarr store
ds_combined.to_zarr(zarr_group_path, mode='w')


First run has non-na values
Second run has non-na values


<xarray.backends.zarr.ZarrStore at 0x7fc9b1cb8940>

In [30]:
# Verify the saved dataset
dt1 = xr.open_zarr(zarr_group_path, consolidated=True)
print(dt1)


<xarray.Dataset> Size: 946MB
Dimensions:       (time_periods: 1, parameters: 1, depth: 57, lat: 720,
                   lon: 1440)
Coordinates:
  * depth         (depth) float64 456B 0.0 26.79 53.57 ... 1.473e+03 1.5e+03
  * lat           (lat) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
  * lon           (lon) float64 12kB -180.0 -179.7 -179.5 ... 179.5 179.7 180.0
  * parameters    (parameters) <U11 44B 'temperature'
  * time_periods  (time_periods) <U1 4B '1'
Data variables:
    an            (time_periods, parameters, depth, lat, lon) float64 473MB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    mn            (time_periods, parameters, depth, lat, lon) float64 473MB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>


In [10]:
if True:   
   dt1_sel = dt1.sel(parameters='temperature', time_periods='1')
   t_mn1 = dt1_sel['mn'].values
   if np.any(np.isfinite(t_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")
    
   dt2_sel = dt1.sel(parameters='temperature', time_periods='2')
   t_mn2 = dt2_sel['mn'].values
   if np.any(np.isfinite(t_mn2)):
      print("has non-na values\n")
   else:
      print("All na values\n")
 
   

All na values

has non-na values



In [32]:
# Region method
# Function to reindex and combine new data using the `region` argument
def reindex_and_combine_using_region(existing_ds, new_ds, param_name, period_key, zarr_store):
    if param_name not in existing_ds.parameters.values:
        new_params = np.append(existing_ds.parameters.values, param_name)
        existing_ds = existing_ds.reindex(parameters=new_params, fill_value=np.nan)

    if period_key not in existing_ds.time_periods.values:
        new_periods = np.append(existing_ds.time_periods.values, period_key)
        existing_ds = existing_ds.reindex(time_periods=new_periods, fill_value=np.nan)

    for var in new_ds.data_vars:
        # Define the region to update
        param_idx = np.where(existing_ds.parameters.values == param_name)[0][0]
        period_idx = np.where(existing_ds.time_periods.values == period_key)[0][0]
        region = {
            'parameters': slice(param_idx, param_idx + 1),
            'time_periods': slice(period_idx, period_idx + 1),
            'depth': slice(None),
            'lat': slice(None),
            'lon': slice(None)
        }

        # Extract the new data for the specific region
        new_data = new_ds[var].data[0, 0, :, :, :]

        # Create a temporary DataArray with the correct dimensions
        temp_da = xr.DataArray(
            new_data,
            dims=['depth', 'lat', 'lon'],
            coords={
                'depth': existing_ds.depth,
                'lat': existing_ds.lat,
                'lon': existing_ds.lon
            }
        )

        # Update the existing dataset with the new data in the specified region
        print(existing_ds)
        print(existing_ds[var])
        existing_ds[var].loc[dict(parameters=param_name, time_periods=period_key)] = temp_da

        # Write the updated data to Zarr store without re-writing the coordinates
        existing_ds[var].to_zarr(zarr_store, region=region)

    return existing_ds


In [None]:
# Simulate loading new data and combining it with existing data
param_name = 'temperature'
period_key = '2'
ds_new = xr.Dataset(
    {
        'mn': (('time_periods', 'parameters', 'depth', 'lat', 'lon'), np.random.rand(1, 1, 57, 720, 1440)),
        'an': (('time_periods', 'parameters', 'depth', 'lat', 'lon'), np.random.rand(1, 1, 57, 720, 1440))
    },
    coords={
        'time_periods': [period_key],
        'parameters': [param_name],
        'depth': np.linspace(0, 1500, 57),
        'lat': np.linspace(-90, 90, 720),
        'lon': np.linspace(-180, 180, 1440)
    }
)
ds_new = ds_new.chunk(chunk_sizes)

# Open the existing Zarr store and combine with new data
ds_existing = xr.open_zarr(zarr_group_path, consolidated=True)
ds_combined = reindex_and_combine_using_region(ds_existing, ds_new, param_name, period_key, zarr_group_path)

# Verify combined data
dt1_sel = ds_combined.sel(parameters='temperature', time_periods='1')
t_mn1 = dt1_sel['mn'].values
if np.any(np.isfinite(t_mn1)):
    print("First run has non-na values")
else:
    print("First run all na values")

dt2_sel = ds_combined.sel(parameters='temperature', time_periods='2')
t_mn2 = dt2_sel['mn'].values
if np.any(np.isfinite(t_mn2)):
    print("Second run has non-na values")
else:
    print("Second run all na values")

# Save the updated combined dataset back to the Zarr store
ds_combined.to_zarr(zarr_group_path, mode='w')

# Verify the saved dataset
dt1 = xr.open_zarr(zarr_group_path, consolidated=True)
print(dt1)

# Test the data after saving
if True:   
   dt1_sel = dt1.sel(parameters='temperature', time_periods='1')
   t_mn1 = dt1_sel['mn'].values
   if np.any(np.isfinite(t_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")
    
   dt2_sel = dt1.sel(parameters='temperature', time_periods='2')
   t_mn2 = dt2_sel['mn'].values
   if np.any(np.isfinite(t_mn2)):
      print("has non-na values\n")
   else:
      print("All na values\n")

In [6]:
data_dir = '../data'
zarr_group_path = f'{data_dir}/test'
# Define paths and parameters
grid_resolutions = {'01': '1.00', '04': '0.25'}
grid_dir = {'01': '1_degree', '04': '0.25_degree'}
parameters = {
    't': 'temperature',
    's': 'salinity',
    'o': 'DOXY',
    'O': 'O2S',
    'A': 'AOU',
    'i': 'silicate',
    'p': 'phosphate',
    'n': 'nitrate'
}
time_periods = {
    '0': 'annual',
    '1': 'january',
    '2': 'february',
    '3': 'march',
    '4': 'april',
    '5': 'may',
    '6': 'june',
    '7': 'july',
    '8': 'august',
    '9': 'september',
    '10': 'october',
    '11': 'november',
    '12': 'december',
    '13': 'winter',
    '14': 'spring',
    '15': 'summer',
    '16': 'autumn'
}
data_variables = ['an', 'mn', 'dd', 'ma', 'sd', 'se', 'oa', 'gp', 'sdo', 'sea']
chunk_sizes = {'time_periods': 1, 'parameters': 1, 'depth': 8, 'lat': 90, 'lon': 360}

# Function to initialize Zarr store with required dimensions incrementally
def initialize_zarr_store(zarr_group_path, ds, chunk_sizes):
    # Initialize the store with only the coordinates
    empty_ds = xr.Dataset(coords=ds.coords)
    empty_ds.to_zarr(zarr_group_path, mode='w')

    # Incrementally add each variable to avoid memory exhaustion
    for var in data_variables:
        data_shape = (
            len(ds.coords['time_periods']),
            len(ds.coords['parameters']),
            len(ds.coords['depth']),
            len(ds.coords['lat']),
            len(ds.coords['lon'])
        )
        temp_data = np.empty(data_shape, dtype=np.float32)
        temp_data.fill(np.nan)  # Fill with NaNs to indicate empty data
        temp_ds = xr.Dataset({var: (('time_periods', 'parameters', 'depth', 'lat', 'lon'), temp_data)}, coords=ds.coords)
        temp_ds = temp_ds.chunk(chunk_sizes)
        temp_ds.to_zarr(zarr_group_path, mode='a')
        del temp_data, temp_ds  # Clear variables to free up memory


In [5]:
ds_initial = xr.Dataset(
    coords={
        'lon': np.linspace(-180, 180, 1440),
        'lat': np.linspace(-90, 90, 720),
        'depth': np.linspace(0, 1500, 57),
        'parameters': list(parameters.values()),
        'time_periods': list(time_periods.keys())  # Use numeric keys for time_periods
    }
)

print(ds_initial)

<xarray.Dataset> Size: 18kB
Dimensions:       (lon: 1440, lat: 720, depth: 57, parameters: 8,
                   time_periods: 17)
Coordinates:
  * lon           (lon) float64 12kB -180.0 -179.7 -179.5 ... 179.5 179.7 180.0
  * lat           (lat) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
  * depth         (depth) float64 456B 0.0 26.79 53.57 ... 1.473e+03 1.5e+03
  * parameters    (parameters) <U11 352B 'temperature' 'salinity' ... 'nitrate'
  * time_periods  (time_periods) <U2 136B '0' '1' '2' '3' ... '14' '15' '16'
Data variables:
    *empty*


In [None]:
initialize_zarr_store(zarr_group_path, ds_initial, chunk_sizes)


In [3]:
# Verify the initialized Zarr store
dt1 = xr.open_zarr(zarr_group_path, consolidated=False)
print(dt1)

<xarray.Dataset> Size: 80GB
Dimensions:       (time_periods: 17, parameters: 2, depth: 57, lat: 720,
                   lon: 1440)
Coordinates:
  * depth         (depth) float64 456B 0.0 26.79 53.57 ... 1.473e+03 1.5e+03
  * lat           (lat) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
  * lon           (lon) float64 12kB -180.0 -179.7 -179.5 ... 179.5 179.7 180.0
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
  * time_periods  (time_periods) <U2 136B '0' '1' '2' '3' ... '14' '15' '16'
Data variables:
    an            (time_periods, parameters, depth, lat, lon) float32 8GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    dd            (time_periods, parameters, depth, lat, lon) float32 8GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    gp            (time_periods, parameters, depth, lat, lon) float32 8GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    ma            (time_periods, parameters, depth, lat, lon) float32 

In [46]:
# Function to append a new dataset to the existing Zarr store
def append_to_zarr_store(zarr_group_path, nc_file, param_key, period_key):
    # Open the existing Zarr store
    ds_existing = xr.open_zarr(zarr_group_path, consolidated=False)
    
    # Load the new dataset from NetCDF with decode_times=False
    ds_new = xr.open_dataset(nc_file, decode_times=False)

    # Drop unused variables if they exist
    ds_new = ds_new.drop_vars(['crs', 'lat_bnds', 'lon_bnds', 'depth_bnds', 'climatology_bounds'], errors='ignore')
    
    # Rename variables according to data_variables
    rename_vars = {f'{param_key}_{var}': var for var in data_variables}
    ds_new = ds_new.rename(rename_vars)
    
    # Remove the time dimension if it exists
    for var in data_variables:
        if 'time' in ds_new[var].dims:
            ds_new[var] = ds_new[var].squeeze('time', drop=True)
    ds_new = ds_new.drop_vars('time', errors='ignore')

    # Extract the parameter name
    param_name = parameters[param_key]

    # Ensure the new dataset has the correct coordinates and dimensions
    ds_new = ds_new.expand_dims({'parameters': [param_name], 'time_periods': [period_key]})

    for var in ds_new.data_vars:
        # Define the region to update
        # param_idx = list(ds_existing.parameters.values).index(param_name)
        # period_idx = list(ds_existing.time_periods.values).index(period_key)   
        #region = {
        #    'parameters': slice(param_idx, param_idx + 1),
        #    'time_periods': slice(period_idx, period_idx + 1),
        #    'depth': slice(None),
        #    'lat': slice(None),
        #    'lon': slice(None)
        #}

        # Extract the relevant data for the region update
        new_data = ds_new[var].data.squeeze()
        
        # Create a temporary DataArray with the correct dimensions
        temp_da = xr.DataArray(
            new_data,
            dims=['depth', 'lat', 'lon'],
            coords={
                'depth': ds_existing.depth,
                'lat': ds_existing.lat,
                'lon': ds_existing.lon
            }
        )
        
        # Update the existing dataset with the new data in the specified region
        ds_existing[var].loc[dict(parameters=param_name, time_periods=period_key)] = temp_da

    # Write the updated data to the Zarr store without re-writing the coordinates
    ds_existing.chunk(chunk_sizes).to_zarr(zarr_group_path, mode='a')


In [37]:
nc_file = '../tmp_data/woa23_decav_t01_04.nc'
#ds = xr.open_dataset(nc_file, decode_times=False)
#print(ds)
append_to_zarr_store(zarr_group_path, nc_file, 't', '1')

In [7]:
# Verify the initialized Zarr store
dt1 = xr.open_zarr(zarr_group_path, consolidated=False)
print(dt1)

<xarray.Dataset> Size: 80GB
Dimensions:       (time_periods: 17, parameters: 2, depth: 57, lat: 720,
                   lon: 1440)
Coordinates:
  * depth         (depth) float64 456B 0.0 26.79 53.57 ... 1.473e+03 1.5e+03
  * lat           (lat) float64 6kB -90.0 -89.75 -89.5 ... 89.5 89.75 90.0
  * lon           (lon) float64 12kB -180.0 -179.7 -179.5 ... 179.5 179.7 180.0
  * parameters    (parameters) <U11 88B 'temperature' 'salinity'
  * time_periods  (time_periods) <U2 136B '0' '1' '2' '3' ... '14' '15' '16'
Data variables:
    an            (time_periods, parameters, depth, lat, lon) float32 8GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    dd            (time_periods, parameters, depth, lat, lon) float32 8GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    gp            (time_periods, parameters, depth, lat, lon) float32 8GB dask.array<chunksize=(1, 1, 8, 90, 360), meta=np.ndarray>
    ma            (time_periods, parameters, depth, lat, lon) float32 

In [40]:
if True:   
   dt1_sel = dt1.sel(parameters='temperature', time_periods='1')
   t_mn1 = dt1_sel['mn'].values
   if np.any(np.isfinite(t_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")

has non-na values



In [41]:
nc_file = '../tmp_data/woa23_decav_t02_04.nc'
append_to_zarr_store(zarr_group_path, nc_file, 't', '2')

In [42]:
if True:   
   dt1_sel = dt1.sel(parameters='temperature', time_periods='1')
   t_mn1 = dt1_sel['mn'].values
   if np.any(np.isfinite(t_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")
    
   dt2_sel = dt1.sel(parameters='temperature', time_periods='2')
   t_mn2 = dt2_sel['mn'].values
   if np.any(np.isfinite(t_mn2)):
      print("has non-na values\n")
   else:
      print("All na values\n")

has non-na values

has non-na values



In [43]:
nc_file = '../tmp_data/woa23_decav_s01_04.nc'
append_to_zarr_store(zarr_group_path, nc_file, 's', '1')

In [None]:
nc_file = '../tmp_data/woa23_decav_s02_04.nc'
append_to_zarr_store(zarr_group_path, nc_file, 's', '2')

In [8]:
if True:   
   dt1_sel = dt1.sel(parameters='temperature', time_periods='1')
   t_mn1 = dt1_sel['mn'].values
   if np.any(np.isfinite(t_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")
    
   dt2_sel = dt1.sel(parameters='temperature', time_periods='2')
   t_mn2 = dt2_sel['mn'].values
   if np.any(np.isfinite(t_mn2)):
      print("has non-na values\n")
   else:
      print("All na values\n")

   dt3_sel = dt1.sel(parameters='salinity', time_periods='1')
   s_mn1 = dt3_sel['mn'].values
   if np.any(np.isfinite(s_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")

   dt4_sel = dt1.sel(parameters='salinity', time_periods='1')
   s_mn1 = dt3_sel['mn'].values
   if np.any(np.isfinite(s_mn1)):
      print("has non-na values\n")
   else:
      print("All na values\n")                  

has non-na values

has non-na values

has non-na values

has non-na values

