In [None]:
import numpy as np
import os
import xarray as xr
from config import input_dir, output_dir

import sys
sys.path.append('ECCOv4-py/ECCOv4-py')
import ecco_v4_py as ecco

In [None]:
# Convert time indexes to year.decimal
def convert_time(ds, start_year):
    """Return ``ds`` with its ``time`` coordinate as float32 "year.decimal" values.

    If the first ``time`` value already equals ``start_year`` the axis is taken
    to be converted and is only cast to float32. Otherwise every value is
    treated as a month offset and mapped to ``start_year + time / 12``.
    """
    already_converted = ds['time'][0] == start_year
    if not already_converted:
        # Convert time to "year.decimal"
        decimal_years = start_year + (ds['time'] / 12)
        ds = ds.assign_coords({'time': decimal_years})

    # Narrow float64 to float32 to get rid of rounding errors.
    time_f32 = ds['time'].astype('float32')
    return ds.assign_coords({'time': time_f32})

In [None]:
def to_netcdf_ecco(ecco_data, file_name):
    """Write a tiled field to ``output_dir/file_name`` as the variable ``pb``.

    Parameters
    ----------
    ecco_data : numpy.ndarray, xarray.DataArray or xarray.Dataset
        Field ordered (time, tile, j, i). A field without a time axis
        (tile, j, i) is accepted and written with a single time step.
        A Dataset must hold the field under
        ``'__xarray_dataarray_variable__'`` (the name xarray uses when an
        unnamed DataArray is saved) or ``'pb'``.
    file_name : str
        File name, joined onto ``output_dir``.

    Raises
    ------
    ValueError
        If a Dataset is passed that contains neither expected variable.
    """
    # Load the ECCO grid; only its 'i', 'j' and 'tile' coordinates are used here.
    eccor5_grid = ecco.load_ecco_grid_nc('/glade/u/home/mengnanz/p2375_bp_seasonal_cycle/input_dir/ECCOllc90/r5_nctiles_grid', 'ECCO-GRID.nc')
    i = eccor5_grid['i']
    j = eccor5_grid['j']
    tile = eccor5_grid['tile']

    # Dataset.rename returns a new object and Dataset.values is the mapping
    # method, not an array, so pull the field out as a DataArray up front.
    if isinstance(ecco_data, xr.Dataset):
        for var_name in ('__xarray_dataarray_variable__', 'pb'):
            if var_name in ecco_data:
                ecco_data = ecco_data[var_name]
                break
        else:
            raise ValueError(
                "Dataset must contain '__xarray_dataarray_variable__' or 'pb'"
            )

    if isinstance(ecco_data, np.ndarray):
        values = ecco_data
        time = None
    else:
        values = ecco_data.values
        time = ecco_data['time'] if 'time' in ecco_data.dims else None

    # A field with no time axis becomes a single time step so that it matches
    # the four dimensions declared below.
    if values.ndim == 3:
        values = values[np.newaxis, ...]

    if time is None:
        # No time coordinate available: fall back to a 0-based index.
        time = (['time'], range(values.shape[0]))

    pb_ecco = xr.DataArray(
        data=values,
        dims=['time', 'tile', 'j', 'i'],
        coords={
            'time': time,
            'tile': tile,
            'j': j,
            'i': i
        },
        name='pb',
        # NOTE(review): other writers in this file label pb as "cm" or 'mm';
        # confirm that 'cm/month' is intended here.
        attrs={'units': 'cm/month'}
    )

    pb_ecco.to_netcdf(os.path.join(output_dir, file_name))

In [None]:

def to_netcdf(array, lats, lons, time, name, units, description):
    """Save a (time, j, i) array as variable ``pb`` in ``output_dir/<name>.nc``.

    Parameters
    ----------
    array : array-like
        Data ordered (time, j, i).
    lats : array-like
        Latitude values, stored as coordinate ``lat`` along dimension ``j``.
    lons : array-like
        Longitude values, stored as coordinate ``lon`` along dimension ``i``.
    time : array-like
        Values for the ``time`` coordinate.
    name : str
        Output file stem; ``.nc`` is appended.
    units : str or None
        Stored as the ``units`` attribute of ``pb`` when not None.
    description : str or None
        Stored as the dataset-level ``description`` attribute when not None.
    """
    pb_attrs = {} if units is None else {"units": units}
    # A None attribute value is left out rather than handed to the writer.
    ds_attrs = {} if description is None else {"description": description}

    ds = xr.Dataset(
        data_vars=dict(
            pb=(["time", "j", "i"], array, pb_attrs),
        ),
        coords=dict(
            # Previously lon was filled from `lats` and lat from `lons`.
            lon=(["i"], lons),
            lat=(["j"], lats),
            time=(["time"], time)
        ),
        attrs=ds_attrs,
    )
    save_to = os.path.join(output_dir, f'{name}.nc')
    print(f'save to ds {save_to}')
    ds.to_netcdf(save_to)
    

In [None]:
def to_netcdf_grace_grid(data, file_name, lon_grace=None, lat_grace=None, time=None):
    """Write data ordered (time, y, x) to ``output_dir/file_name`` as variable ``pb``.

    Parameters
    ----------
    data : numpy.ndarray, xarray.DataArray or mapping with a ``'pb'`` entry
        Field to write. For a mapping (e.g. a Dataset) the ``'pb'`` entry is used.
    file_name : str
        File name, joined onto ``output_dir``.
    lon_grace, lat_grace, time : array-like, optional
        Coordinate values. Any that is None is read from ``data``
        (``lon``/``lat``, falling back to ``x``/``y``, and ``time``).
        Values passed explicitly are always used as given.
    """
    def _coord_values(primary, fallback):
        # Read data.<primary>.values, falling back to data.<fallback>.values.
        try:
            return getattr(data, primary).values
        except AttributeError:
            return getattr(data, fallback).values

    # Fill each missing coordinate on its own, so an explicitly supplied value
    # is never overwritten and a missing `time` is always resolved.
    if lat_grace is None:
        lat_grace = _coord_values('lat', 'y')
    if lon_grace is None:
        lon_grace = _coord_values('lon', 'x')
    if time is None:
        time = data.time.values

    # Membership on an ndarray/DataArray compares against the values, so only
    # do the 'pb' lookup on containers.
    if not isinstance(data, (np.ndarray, xr.DataArray)) and 'pb' in data:
        data = data['pb']

    data_values = data if isinstance(data, np.ndarray) else data.values

    ds = xr.Dataset(
        data_vars=dict(
            pb=(["time", "y", "x"], data_values, {"units": "cm"}),
        ),
        coords=dict(
            lon=(["x"], lon_grace),
            lat=(["y"], lat_grace),
            time=(["time"], time)
        ),
        attrs=dict(description="GRACE pb, cm"),
    )
    save_to = os.path.join(output_dir, file_name)
    print(f'save data to {save_to}')
    ds.to_netcdf(save_to)

In [None]:
def to_netcdf_grace(data, file_name, lon_grace, lat_grace, time):
    '''
    For GRACE data on an ECCO grid.

    Wraps ``data`` as variable ``pb`` (units "cm") on dimensions (time, j, i),
    attaches ``lon`` on ``i``, ``lat`` on ``j`` and ``time``, and writes the
    dataset to ``output_dir/file_name``.
    '''
    pb_var = (["time", "j", "i"], data, {"units": "cm"})
    grid_coords = {
        "lon": (["i"], lon_grace),
        "lat": (["j"], lat_grace),
        "time": (["time"], time),
    }
    global_attrs = {"description": "GRACE pb, cm"}

    ds = xr.Dataset(data_vars={"pb": pb_var}, coords=grid_coords, attrs=global_attrs)
    out_path = os.path.join(output_dir, file_name)
    ds.to_netcdf(out_path)

In [None]:
import sys
sys.path.append('ECCOv4-py/ECCOv4-py')
import ecco_v4_py as ecco
# Load the ECCO grid once at module level; to_netcdf_tiles reads its 'i', 'j'
# and 'tile' entries on every call.
# NOTE(review): absolute, user-specific path — consider building it from a
# configurable directory instead.
ecco_grid = ecco.load_ecco_grid_nc('/glade/u/home/mengnanz/p2375_bp_seasonal_cycle/input_dir/ECCOllc90/r5_nctiles_grid', 'ECCO-GRID.nc')
def to_netcdf_tiles(array, name, units, description, start_year=1992.0):
    """Write a (time, tile, j, i) array to ``output_dir/<name>.nc`` as variable ``pb``.

    The ``tile``, ``j`` and ``i`` coordinates come from the module-level
    ``ecco_grid``. The time axis is generated as ``start_year + k / 12`` for
    ``k = 0 .. n_time - 1`` ("year.decimal", one step per month).

    Parameters
    ----------
    array : array-like
        Data ordered (time, tile, j, i).
    name : str
        Output file stem; ``.nc`` is appended.
    units : str or None
        ``units`` attribute of ``pb``; ``'mm'`` is used when None.
    description : str or None
        Stored as the ``description`` attribute of ``pb`` when not None.
    start_year : float, optional
        Year assigned to the first time step (default 1992.0).
    """
    # Month index -> "year.decimal"
    time_years = start_year + np.arange(array.shape[0]) / 12

    # Honour the caller's metadata; previously both arguments were ignored and
    # the units were always written as 'mm'.
    attrs = {'units': 'mm' if units is None else units}
    if description is not None:
        attrs['description'] = description

    data_array = xr.DataArray(
        data=array,
        dims=['time', 'tile', 'j', 'i'],
        coords={
            'time': time_years,
            'tile': ecco_grid['tile'],
            'j': ecco_grid['j'],
            'i': ecco_grid['i']
        },
        name='pb',
        attrs=attrs
    )

    save_to = os.path.join(output_dir, f'{name}.nc')
    print(f'save to ds {save_to}')
    data_array.to_netcdf(save_to)

In [None]:

def to_netcdf_msc(array, name, units, description):
    """Write a per-month field to ``output_dir/<name>.nc`` as variable ``pb``.

    Parameters
    ----------
    array : xarray.DataArray
        Must expose ``i``, ``j``, ``tile`` and ``month`` coordinates. Its data
        is labelled (tile, j, i, month) as-is, not transposed, so it is
        assumed to already be in that order.
    name : str
        Output file stem; ``.nc`` is appended.
    units : str or None
        Stored as the ``units`` attribute of ``pb`` when not None.
    description : str or None
        Stored as the dataset-level ``description`` attribute when not None.
    """
    pb_attrs = {} if units is None else {"units": units}
    # A None attribute value is left out rather than handed to the writer.
    ds_attrs = {} if description is None else {"description": description}

    ds = xr.Dataset(
        data_vars=dict(
            # Always pass the raw array: a DataArray inside a
            # (dims, data, attrs) tuple is rejected by recent xarray.
            pb=(["tile", "j", "i", "month"], array.data, pb_attrs),
        ),
        coords=dict(
            i=(["i"], array["i"].data),
            j=(["j"], array["j"].data),
            tile=(["tile"], array["tile"].data),
            # NOTE(review): the coordinate is named "months" but lives on the
            # "month" dimension; kept as-is so existing readers still work.
            months=(["month"], array["month"].data)
        ),
        attrs=ds_attrs,
    )
    save_to = os.path.join(output_dir, f'{name}.nc')
    print(f'save to ds {save_to}')
    ds.to_netcdf(save_to)

In [None]:

# Maybe I need to pass in another DataArray that has metadata attached?
def to_netcdf_msc_data(array, data_array, name, units, description):
    '''
    Write ``data_array`` to ``output_dir/<name>.nc`` as variable ``pb`` on
    dimensions (time, tile, j, i).

    array: only used to get the ``i``, ``j`` and ``tile`` coordinate values
    data_array: get data from this (numpy array or xarray object)
    name: output file stem; ``.nc`` is appended
    units: stored as the ``units`` attribute of ``pb`` when not None
    description: stored as the dataset ``description`` attribute when not None
    '''
    # Always hand xarray the raw array: a DataArray inside a
    # (dims, data, attrs) tuple is rejected by recent xarray.
    values = data_array if isinstance(data_array, np.ndarray) else data_array.data

    pb_attrs = {} if units is None else {"units": units}
    # A None attribute value is left out rather than handed to the writer.
    ds_attrs = {} if description is None else {"description": description}

    ds = xr.Dataset(
        data_vars=dict(
            pb=(["time", "tile", "j", "i"], values, pb_attrs),
        ),
        coords=dict(
            i=(["i"], array["i"].data),
            j=(["j"], array["j"].data),
            tile=(["tile"], array["tile"].data),
            # NOTE(review): "months" sits on its own "month" dimension, which
            # pb does not use, and pb's "time" axis gets no coordinate.
            # Kept as-is so existing readers still work; confirm the intent.
            months=(["month"], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12])
        ),
        attrs=ds_attrs,
    )
    save_to = os.path.join(output_dir, f'{name}.nc')
    print(f'save to ds {save_to}')
    ds.to_netcdf(save_to)

In [None]:
def get_aligned_data(ds1, ds2):
    '''
    Inner-align two objects on their shared indexes and difference them.

    The ``time`` coordinate of each input is first cast to float32 and rounded
    to two decimals so that matching time stamps compare equal. After the
    alignment, zeros in both inputs are replaced by NaN.

    Returns the list ``[ds1_aligned, ds2_aligned, ds1_aligned - ds2_aligned]``.
    '''
    def _with_rounded_time(ds):
        # Round time coords to make sure they match exactly
        rounded = ds.coords["time"].astype(np.float32).round(2)
        return ds.assign_coords(time=rounded)

    # Keep only the intersection of the coordinates
    left, right = xr.align(
        _with_rounded_time(ds1),
        _with_rounded_time(ds2),
        join="inner",
    )

    # Make sure data has NaNs, not zeroes.
    left = left.where(left != 0)
    right = right.where(right != 0)

    return [left, right, left - right]