In [1]:
import glob
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr
from scipy.interpolate import RegularGridInterpolator

In [2]:
def _regular_axis(coord, step):
    """Return ascending points ``min(coord) + k*step`` that never exceed ``max(coord)``.

    ``np.arange(min, max + step, step)`` adds one point past ``max`` whenever the
    span is not an exact multiple of ``step``. That point lies outside the source
    grid and interpolates to an all-NaN row/column, so it is excluded here.
    """
    start = float(np.min(coord))
    stop = float(np.max(coord))
    # Small tolerance so float round-off in the division cannot drop the last point.
    n_points = int(np.floor((stop - start) / step + 1e-9)) + 1
    return start + step * np.arange(n_points)


def regrid_dataset_3d(dataset, new_resolution):
    """
    Regrid a 3D dataset (time, lat, lon) to a new resolution using linear interpolation.

    Only the first data variable of ``dataset`` is regridded. The target grid
    starts at the minimum latitude/longitude of the input and advances in steps
    of ``new_resolution``, stopping at the last point that does not exceed the
    input maximum, so every target point lies inside the source grid. A target
    point whose interpolation neighbours include NaN is NaN.

    Parameters:
    dataset (xarray.Dataset): Input dataset whose first data variable has the
        dimensions time, lat and lon (stored in any order).
    new_resolution (float): Desired new grid spacing (e.g., 1.0 for 1°). Must be positive.

    Returns:
    xarray.Dataset: Regridded dataset with dimensions (time, lat, lon) and
        ascending lat/lon coordinates. The variable's and the dataset's attrs
        are copied from the input.

    Raises:
    ValueError: If ``new_resolution`` is not a positive number.
    """
    if not new_resolution > 0:
        raise ValueError(f'new_resolution must be positive, got {new_resolution!r}')

    # Get original coordinates
    lat = dataset['lat'].values
    lon = dataset['lon'].values
    time = dataset['time'].values

    # Target coordinates, clipped to the extent of the source grid
    new_lat = _regular_axis(lat, new_resolution)
    new_lon = _regular_axis(lon, new_resolution)

    varname_nc = list(dataset.data_vars.keys())[0]
    # Force (time, lat, lon) order so each 2D slice lines up with the (lat, lon)
    # axes handed to the interpolator, whatever order the file stores.
    field = dataset[varname_nc].transpose('time', 'lat', 'lon')

    # The target mesh is the same for every time step, so build it once.
    new_lat_grid, new_lon_grid = np.meshgrid(new_lat, new_lon, indexing='ij')

    regridded_slices = []
    for t in range(len(time)):
        interp_func = RegularGridInterpolator(
            (lat, lon),
            field.isel(time=t).values,
            method='linear',
            bounds_error=False,
            fill_value=np.nan,
        )
        regridded_slices.append(interp_func((new_lat_grid, new_lon_grid)))

    # np.stack rejects an empty list, so an empty time axis gets an explicit empty array.
    if regridded_slices:
        regridded_data = np.stack(regridded_slices, axis=0)
    else:
        regridded_data = np.empty((0, new_lat.size, new_lon.size))

    regridded_dataset = xr.Dataset(
        {
            varname_nc: (['time', 'lat', 'lon'], regridded_data)
        },
        coords={
            'time': time,
            'lat': new_lat,
            'lon': new_lon
        }
    )
    # Keep metadata (units, long_name, ...) that xr.Dataset construction would otherwise drop.
    regridded_dataset[varname_nc].attrs.update(dataset[varname_nc].attrs)
    regridded_dataset.attrs.update(dataset.attrs)

    return regridded_dataset

In [3]:
# Input directory: each `<var>.nc` file found here is read and regridded in the loop below.
# The trailing slash matters, because file paths are built by plain string concatenation.
path_weekly_anoms = '/glade/derecho/scratch/jhayron/Data4Predictability/WeeklyAnoms_DetrendedStd_v3/'

In [4]:
# All input netCDF files, sorted so variables are processed in a deterministic order.
list_files = np.sort(glob.glob(f'{path_weekly_anoms}*.nc'))
# Fail loudly on a wrong or empty directory instead of silently processing nothing later.
if list_files.size == 0:
    raise FileNotFoundError(f'No .nc files found in {path_weekly_anoms}')

In [11]:
# Variable name = file name without its directory and final extension.
# rsplit('.', 1) strips only the extension, so a name that itself contains a dot
# stays intact (split('.')[0] would truncate it at the first dot).
list_vars = [path.split('/')[-1].rsplit('.', 1)[0] for path in list_files]

In [12]:
# Display the variable names derived from the input file names.
list_vars

['IC_SODA',
 'IT_SODA',
 'MLD_SODA',
 'OHC100_SODA',
 'OHC200_SODA',
 'OHC300_SODA',
 'OHC50_SODA',
 'OHC700_SODA',
 'OLR_ERA5',
 'SD_ERA5',
 'SSH_SODA',
 'SST_OISSTv2',
 'SST_SODA',
 'STL_1m_ERA5',
 'STL_28cm_ERA5',
 'STL_7cm_ERA5',
 'STL_full_ERA5',
 'SWVL_1m_ERA5',
 'SWVL_28cm_ERA5',
 'SWVL_7cm_ERA5',
 'SWVL_full_ERA5',
 'U10_ERA5',
 'U200_ERA5',
 'Z500_ERA5']

In [14]:
# Output directory for the regridded files.
# NOTE(review): despite the `_1dg` suffix, this points to the `_2dg` directory, and
# the loop below regrids with `resolution = 2`. Consider renaming the variable.
path_weekly_anoms_1dg = '/glade/derecho/scratch/jhayron/Data4Predictability/WeeklyAnoms_DetrendedStd_v3_2dg/'

In [15]:
# Target grid spacing in degrees. A value of 0.5 skips regridding and writes the data unchanged.
resolution = 2
# Create the output directory if needed so a fresh run does not fail on to_netcdf.
os.makedirs(path_weekly_anoms_1dg, exist_ok=True)

for ivar, var in enumerate(list_vars):
    print(ivar, var)
    path_nc_anoms = f'{path_weekly_anoms}{var}.nc'
    # The context manager closes the source file once the output is written.
    # Without it, every iteration would leave a netCDF handle open.
    with xr.open_dataset(path_nc_anoms) as anoms_raw:
        # Drop the time-of-day component so the time coordinate holds calendar dates only.
        anoms = anoms_raw.assign_coords(time=pd.DatetimeIndex(anoms_raw.time).normalize())
        ### PROCESS AND LOAD DATA
        if resolution == 0.5:
            regridded_dataset = anoms
        else:
            regridded_dataset = regrid_dataset_3d(anoms, resolution)
        # Write while the source is still open: in the 0.5 branch the output is
        # the (possibly lazily loaded) input itself, which still reads from the file.
        regridded_dataset.to_netcdf(f'{path_weekly_anoms_1dg}{var}.nc')

0 IC_SODA
1 IT_SODA
2 MLD_SODA
3 OHC100_SODA
4 OHC200_SODA
5 OHC300_SODA
6 OHC50_SODA
7 OHC700_SODA
8 OLR_ERA5
9 SD_ERA5
10 SSH_SODA
11 SST_OISSTv2
12 SST_SODA
13 STL_1m_ERA5
14 STL_28cm_ERA5
15 STL_7cm_ERA5
16 STL_full_ERA5
17 SWVL_1m_ERA5
18 SWVL_28cm_ERA5
19 SWVL_7cm_ERA5
20 SWVL_full_ERA5
21 U10_ERA5
22 U200_ERA5
23 Z500_ERA5
