In [None]:
import xesmf as xe
import numpy as np
import xarray as xr
import glob as glob

In [None]:
import xesmf as xe

class xESMF_REGRIDDER:
    def __init__(self,ds_out,method='conservative'):
        self.ds_out = ds_out
        self.method = method

    def add_lat_lon_bounds(self, ds, lat_name='lat', lon_name='lon'):
        """
        Function to add latitude and longitude bounds to a dataset
        for conservative regridding with xESMF.

        Parameters:
        - ds: xarray Dataset or DataArray containing latitude and longitude coordinates
        - lat_name: Name of the latitude coordinate in the dataset (default is 'lat')
        - lon_name: Name of the longitude coordinate in the dataset (default is 'lon')

        Returns:
        - ds: Dataset with added 'lat_bnds' and 'lon_bnds' variables
        """

        # Get latitude and longitude coordinates
        lat = ds[lat_name]
        lon = ds[lon_name]

        # Calculate latitude bounds
        lat_diff = np.diff(lat) / 2.0
        lat_bnds = np.empty((lat.size, 2), dtype=np.float64)
        lat_bnds[:, 0] = lat - np.concatenate(([lat_diff[0]], lat_diff))
        lat_bnds[:, 1] = lat + np.concatenate((lat_diff, [lat_diff[-1]]))

        # Calculate longitude bounds
        lon_diff = np.diff(lon) / 2.0
        lon_bnds = np.empty((lon.size, 2), dtype=np.float64)
        lon_bnds[:, 0] = lon - np.concatenate(([lon_diff[0]], lon_diff))
        lon_bnds[:, 1] = lon + np.concatenate((lon_diff, [lon_diff[-1]]))

        # Add latitude and longitude bounds to dataset
        ds.coords['lat_bnds'] = (('lat', 'bnds'), lat_bnds)
        ds.coords['lon_bnds'] = (('lon', 'bnds'), lon_bnds)

        return ds

    def regrid(self,ds):
        """
        Function to regrid xarray dataset or dataarray using xESMF.

        Parameters:
        - self.ds_out: xarray Dataset with the latitudes and longitudes you want to 
                       regrid to.
        - self.method: The regridding approach you want to use (i.e., bilinear (fastest),
                       conservative (optimal usually), nearest_n2s...)
        - ds: xarray Dataset or DataArray containing latitude and longitude coordinates

        Returns:
        - ds: Dataset regridded to the grid specified as ds_out.   

        """
        if self.method == 'conservative':
            #ds = ds.expand_dims('bnds')
            #ds['bnds'] = ('bnds', [ds.attrs.get('bnds', [1,2])])
            if 'bnds' not in list(ds.coords.keys()):
                 ds.coords['bnds'] = [1.0,2.0]
            if 'lat_bnds' not in list(ds.coords.keys()):
                ds = self.add_lat_lon_bounds(ds)
            if len(ds['lat_bnds'].dims) > 2:
                ds = self.add_lat_lon_bounds(ds)

        print(ds)
        regridder = xe.Regridder(ds, self.ds_out, self.method)
        dr_out = regridder(ds, keep_attrs=True)

        return dr_out

if __name__ == "__main__":
    # Example usage:

    # New equals a dataset of the resolution you want to regrid to, for example NeuralGCM 2.8 degree
    new = xr.open_dataset("/scratch/midway3/krucker01/ai-models/ngcm/climate_2.8_csp_pe/tmp_monthly_1981-2023_csp_pe.nc")
    
    
    dims = list(new.dims)
    dim1 = [d for d in dims if 'lat' in d][0]
    dim2 = [d for d in dims if 'lon' in d][0]
    resolution = f'{len(new[dim1])}x{len(new[dim2])}'

    # Create xr.Dataset with output grid
    ds_out = xr.Dataset(
        {
            "lat": (["lat"], new[dim1].data, {"units": "degrees_north"}),
            "lon": (["lon"], new[dim2].data, {"units": "degrees_east"}),
        }
    )


In [None]:
def reassign(ds):
    ds['time'] = ds['valid_time'].squeeze()
    return ds

In [None]:
# Load files you want to interpolate
files = glob.glob('/glade/derecho/scratch/katyr/AMIP for NGCM/ace/ACE2-ERA5/*/monthly_mean_predictions.nc')
ds = xr.open_mfdataset(files, combine = 'nested', concat_dim = 'member_id', preprocess = reassign)



['/project/tas1/itbaxter/for-tiffany/amip-piForcing/180x360/ta/ta_CMIP6_HadGEM3-GC31-LL_Amon_amip-piForcing_1870-2015.nc',
 '/project/tas1/itbaxter/for-tiffany/amip-piForcing/180x360/ta/ta_CMIP6_MRI-ESM2-0_Amon_amip-piForcing_1870-2015.nc',
 '/project/tas1/itbaxter/for-tiffany/amip-piForcing/180x360/ta/ta_CMIP6_TaiESM1_Amon_amip-piForcing_1850-2015.nc',
 '/project/tas1/itbaxter/for-tiffany/amip-piForcing/180x360/ta/ta_CMIP6_CanESM5_Amon_amip-piForcing_1870-2015.nc',
 '/project/tas1/itbaxter/for-tiffany/amip-piForcing/180x360/ta/ta_CMIP6_MIROC6_Amon_amip-piForcing_1870-2015.nc',
 '/project/tas1/itbaxter/for-tiffany/amip-piForcing/180x360/ta/ta_CMIP6_CESM2_Amon_amip-piForcing_1870-2016.nc']

In [12]:
ds_regridded = xESMF_REGRIDDER(ds_out,method='bilinear').regrid(ds)
ds_regridded

<xarray.DataArray 'ta' (year: 35, member_id: 8, plev: 19, latitude: 180,
                        longitude: 360)> Size: 1GB
dask.array<getitem, shape=(35, 8, 19, 180, 360), dtype=float32, chunksize=(1, 3, 19, 180, 360), chunktype=numpy.ndarray>
Coordinates:
  * plev       (plev) float64 152B 1e+05 9.25e+04 8.5e+04 ... 1e+03 500.0 100.0
  * latitude   (latitude) float32 720B -89.24 -88.25 -87.25 ... 88.25 89.24
  * longitude  (longitude) float32 1kB 0.5 1.5 2.5 3.5 ... 357.5 358.5 359.5
  * year       (year) int64 280B 1980 1981 1982 1983 ... 2011 2012 2013 2014
  * member_id  (member_id) <U24 768B 'HadGEM3-GC31-LL_r1i1p1f3' ... 'CESM2_r1...
Attributes:
    standard_name:  air_temperature
    long_name:      Air Temperature
    comment:        Air Temperature
    units:          K
    original_name:  mo: (stash: m01s30i294, blev: [1000.0, 925.0, 850.0, 700....
    cell_methods:   time: mean
    cell_measures:  area: areacella
    history:        2019-11-20T15:28:22Z altered by CMOR: rep

Unnamed: 0,Array,Chunk
Bytes,166.25 MiB,1.78 MiB
Shape,"(35, 8, 19, 64, 128)","(1, 3, 19, 64, 128)"
Dask graph,210 chunks in 3638 graph layers,210 chunks in 3638 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 166.25 MiB 1.78 MiB Shape (35, 8, 19, 64, 128) (1, 3, 19, 64, 128) Dask graph 210 chunks in 3638 graph layers Data type float32 numpy.ndarray",8  35  128  64  19,

Unnamed: 0,Array,Chunk
Bytes,166.25 MiB,1.78 MiB
Shape,"(35, 8, 19, 64, 128)","(1, 3, 19, 64, 128)"
Dask graph,210 chunks in 3638 graph layers,210 chunks in 3638 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
# Save the dataset
ds_regridded.to_netcdf('***')