# Compare the correlation of $\nabla SST$ and $\nabla \cdot \vec{U}$ for different environmental conditions

# Packages

In [1]:
import intake
from easygems import healpix as egh
import cartopy.crs as ccrs
import cartopy.feature as cf
import cmocean
import healpy as hp
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
import pandas as pd
import warnings
import math
from scipy.stats import sem
from scipy.special import gammaincc
from pycoare import coare_35
import sys
sys.path.append('/home/b/b383497/hk25-teams/hk25-ShallowCirc/src/')
from toolbox import attach_coords, compute_hder, compute_conv, nest2ring_index
sys.path.append('/home/b/b383497/hk25-ASintTrops/Scripts/lfdavoli/')
import geometry as gm
from concurrent.futures import ProcessPoolExecutor

warnings.filterwarnings("ignore", category=FutureWarning) # don't warn us about future package conflicts

# Functions

In [2]:
def cells_of_region(ds, region):
    """Return the positional indices of the cells of ``ds`` inside a region.

    The box is read from the module-level ``regions`` dictionary as
    ``regions[region]['boundaries'] = [minlon, maxlon, minlat, maxlat]``.
    Both longitude bounds are wrapped with ``np.mod(..., 360)`` before being
    compared with ``ds.lon``, so ``ds.lon`` is expected to be expressed in
    [0, 360) -- NOTE(review): confirm against the datasets in use.

    Parameters:
    - ds: object exposing ``coords``, ``lon`` and ``lat``
    - region: key of the ``regions`` dictionary

    Returns the first element of ``np.where`` applied to the in-box mask.

    Raises NameError when ``ds`` has no 'lon' or no 'lat' coordinate.
    """
    if ('lon' not in list(ds.coords)) | ('lat' not in list(ds.coords)):
        raise NameError('Missing coordinates')

    lon_min, lon_max, lat_min, lat_max = regions[region]['boundaries']
    west = np.mod(lon_min, 360)
    east = np.mod(lon_max, 360)

    if west <= east:
        inside_lon = (ds.lon > west) & (ds.lon < east)
    else:
        # The box straddles the 0° meridian (e.g. -62..15 wraps to 298..15):
        # a cell is inside when it is east of the western edge OR west of the
        # eastern edge. Requiring both at once would select nothing.
        inside_lon = (ds.lon > west) | (ds.lon < east)

    inside_lat = (ds.lat > lat_min) & (ds.lat < lat_max)
    return np.where(inside_lon & inside_lat)[0]

def get_region(ds,region):
    """Subset ``ds`` to the cells that fall inside ``region``.

    The cell dimension is looked up under the name 'value' first and 'cell'
    second; the matching dimension is indexed with ``cells_of_region``.
    When ``ds`` has neither dimension, None is returned.
    """
    dim_names = list(ds.dims)
    for cell_dim in ('value', 'cell'):
        if cell_dim in dim_names:
            return ds.isel(**{cell_dim: cells_of_region(ds, region)})
    return None
    

def worldmap(var, extent=None,title=None,**kwargs):
    """Draw ``var`` on a Robinson-projection map with coastlines and borders.

    Parameters:
    - var: field handed to ``egh.healpix_show``
    - extent: optional extent passed to ``ax.set_extent`` (PlateCarree crs);
      the map is global when omitted
    - title: optional axes title
    - **kwargs: forwarded to ``egh.healpix_show``
    """
    fig, ax = plt.subplots(
        figsize=(8, 4),
        subplot_kw={"projection": ccrs.Robinson(central_longitude=-135.5808361)},
        constrained_layout=True,
    )
    ax.set_global()
    if extent is not None:
        ax.set_extent(extent, crs=ccrs.PlateCarree())

    egh.healpix_show(var, ax=ax, **kwargs)

    if title is not None:
        ax.set_title(title)

    # Coastlines are drawn thicker than political borders.
    for feature, width in ((cf.COASTLINE, 0.8), (cf.BORDERS, 0.4)):
        ax.add_feature(feature, linewidth=width)
    
def get_nn_lon_lat_index(nside, lons, lats):
    """Return the nested HEALPix pixel index for each node of a lon/lat grid.

    Parameters:
    - nside: HEALPix nside passed to ``hp.ang2pix``
    - lons, lats: 1D coordinate vectors; their meshgrid defines the nodes

    Returns an ``xr.DataArray`` with dimensions ("lat", "lon") holding the
    result of ``hp.ang2pix(..., nest=True, lonlat=True)``.
    """
    lon_grid, lat_grid = np.meshgrid(lons, lats)
    pixels = hp.ang2pix(nside, lon_grid, lat_grid, nest=True, lonlat=True)
    return xr.DataArray(pixels, coords=[("lat", lats), ("lon", lons)])

def get_nside(ds):
    """Return the ``healpix_nside`` attribute stored on ``ds.crs``."""
    crs = ds.crs
    return crs.healpix_nside

def get_nn_data(var, nx=1000, ny=1000, ax=None):
    """
    Sample ``var`` on the pixel grid of a plot axis.

    var: variable (array-like)
    nx: image resolution in x-direction
    ny: image resolution in y-direction
    ax: axis to plot on
    returns: values on the points in the plot grid.
    """
    grid = get_lonlat_for_plot_grid(nx, ny, ax)

    # First attempt: treat var as a HEALPix field. A ValueError from that
    # attempt is taken to mean "not HEALPix" and triggers the lon/lat fallback.
    try:
        return get_healpix_nn_data(var, grid)
    except ValueError:
        pass

    on_meshgrid = set(var.dims) == {"lat", "lon"}
    sampler = get_lonlat_meshgrid_nn_data if on_meshgrid else get_lonlat_nn_data
    return sampler(var, grid)


def get_healpix_nn_data(var, lonlat):
    """
    Look up HEALPix values at the requested coordinates.

    var: variable on healpix coordinates (array-like)
    lonlat: coordinates at which to get the data
    returns: values on the points in the plot grid; points whose coordinates
             are not finite are left as NaN.
    """
    finite = np.all(np.isfinite(lonlat), axis=-1)
    coords = lonlat[finite].T  # .T reverts index order: row 0 = lon, row 1 = lat
    lon, lat = coords[0], coords[1]

    # nside is derived from the number of pixels in var (nested ordering).
    nside = hp.npix2nside(len(var))
    pix = hp.ang2pix(nside, theta=lon, phi=lat, nest=True, lonlat=True)

    sampled = np.full(lonlat.shape[:-1], np.nan, dtype=var.dtype)
    sampled[finite] = var[pix]
    return sampled


def get_lonlat_nn_data(var, lonlat):
    """
    Nearest-neighbour lookup of ``var`` at the requested coordinates.

    Source and target points are converted to Cartesian coordinates with
    ``lonlat_to_xyz`` and matched with a KD-tree.

    var: variable with lon and lat attributes (2d slice)
    lonlat: coordinates at which to get the data
    returns: values on the points in the plot grid; points whose coordinates
             are not finite are left as NaN.
    """
    # KDTree is not among the file-level imports visible in this notebook, so
    # it is imported here to keep the function usable on its own.
    from scipy.spatial import KDTree

    var_xyz = lonlat_to_xyz(lon=var.lon.values.flatten(), lat=var.lat.values.flatten())
    tree = KDTree(var_xyz)

    valid = np.all(np.isfinite(lonlat), axis=-1)
    ll_valid = lonlat[valid].T
    plot_xyz = lonlat_to_xyz(lon=ll_valid[0], lat=ll_valid[1])

    # Only the indices of the nearest source points are needed.
    _, inds = tree.query(plot_xyz)
    res = np.full(lonlat.shape[:-1], np.nan, dtype=var.dtype)
    res[valid] = var.values.flatten()[inds]
    return res


def get_lonlat_meshgrid_nn_data(var, lonlat):
    """
    Nearest-neighbour lookup for a variable stored on a lon/lat meshgrid.

    The "lon" and "lat" dimensions are stacked into a single "cell" dimension
    before delegating to ``get_lonlat_nn_data``.

    var: variable with lon and lat attributes (2d slice)
    lonlat: coordinates at which to get the data
    returns: values on the points in the plot grid.
    """
    flattened = var.stack(cell=("lon", "lat"))
    return get_lonlat_nn_data(flattened, lonlat)


def get_lonlat_for_plot_grid(nx, ny, ax=None):
    """
    Build an nx-by-ny grid over the axis limits and express it in PlateCarree.

    nx: image resolution in x-direction
    ny: image resolution in y-direction
    ax: axis to plot on (defaults to the current axis)
    returns: coordinates of the points in the plot grid.
    """
    if ax is None:
        ax = plt.gca()

    x_start, x_stop = ax.get_xlim()
    y_start, y_stop = ax.get_ylim()
    grid_x, grid_y = np.meshgrid(
        np.linspace(x_start, x_stop, nx),
        np.linspace(y_start, y_stop, ny),
    )

    # Convert from the axis projection to PlateCarree, with a zero z-coordinate.
    return ccrs.PlateCarree().transform_points(
        ax.projection, grid_x, grid_y, np.zeros_like(grid_x)
    )


def lonlat_to_xyz(lon, lat):
    """
    Convert longitude/latitude to Cartesian coordinates on the unit sphere.

    lon: longitude in degree E
    lat: latitude in degree N
    returns numpy array with coordinates on the unit sphere; for 1D inputs the
    shape is (len(lon), 3) because the stacked (x, y, z) array is transposed.
    """
    lon_rad = np.deg2rad(lon)
    lat_rad = np.deg2rad(lat)
    cos_lat = np.cos(lat_rad)

    x = np.cos(lon_rad) * cos_lat
    y = np.sin(lon_rad) * cos_lat
    z = np.sin(lat_rad)
    return np.array((x, y, z)).T

def plot_map_diff(var, ref, colorbar_label="", title="", extent=None, **kwargs):
    """
    Plot the difference ``var - ref`` on a Robinson-projection map.

    var: data set
    ref: reference data
    colorbar_label: label for the colorbar
    title: title string
    extent: optional extent passed to ``ax.set_extent`` (PlateCarree crs)
    **kwargs: get passed to imshow
    returns figure, axis objects
    """
    fig, ax = plt.subplots(
        figsize=(8, 4),
        subplot_kw={"projection": ccrs.Robinson(central_longitude=-135.5808361)},
        constrained_layout=True,
    )
    ax.set_global()
    if extent is not None:
        ax.set_extent(extent, crs=ccrs.PlateCarree())

    # Both fields are sampled on the same plot grid before being subtracted.
    difference = get_nn_data(var, ax=ax) - get_nn_data(ref, ax=ax)
    image = ax.imshow(
        difference, extent=ax.get_xlim() + ax.get_ylim(), origin="lower", **kwargs
    )

    # Coastlines and borders
    for feature, width in ((cf.COASTLINE, 0.8), (cf.BORDERS, 0.4)):
        ax.add_feature(feature, linewidth=width)

    # Gridlines, labelled on the bottom and left sides only
    grid = ax.gridlines(draw_labels=True, linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    grid.top_labels = False
    grid.right_labels = False
    grid.xlabel_style = {'size': 8}
    grid.ylabel_style = {'size': 8}

    # Colorbar and title
    fig.colorbar(image, label=colorbar_label)
    plt.title(title)
    return (fig, ax)


def compute_SST_wind_derivative_fields(str_a,str_b,sigma,llon,llat,sst,u_interp,v_interp):
    """Compute an SST-derived field, a wind-derived field and the background wind speed.

    Background fields are obtained with ``gm.nan_gaussian_filter(field, sigma)``
    and anomalies ("primes") are the input minus its background. The background
    wind defines a direction (cosphi, sinphi); below, "r" is the direction along
    it and "s" the direction (-sinphi, cosphi) across it. Horizontal derivatives
    come from ``gm.grad_sphere`` / ``gm.div_sphere`` (presumably derivatives on
    the sphere -- see the ``geometry`` module).

    Parameters:
    - str_a: SST field selector, one of
        'gamma'     : u_interp*dSST'/dx + v_interp*dSST'/dy
        'dsst_dr'   : derivative of SST' along r
        'lapl_sst'  : divergence of the gradient of SST'
        'd2sst_ds2' : second derivative of SST' along s
        'sst_prime' : SST' itself
    - str_b: wind field selector, one of
        'wind_div'        : divergence of (u_interp, v_interp)
        'dr_dot_prime_dr' : derivative along r of the r-component of the wind anomaly
        'ds_dot_prime_ds' : derivative along s of the s-component of the wind anomaly
        'ws_prime'        : wind speed minus background wind speed
    - sigma: filter width passed to ``gm.nan_gaussian_filter``
    - llon, llat: coordinate grids passed to the ``gm`` derivative functions
    - sst, u_interp, v_interp: SST and wind components on that grid

    Returns (a_prime, b_prime, smooth_ws). NaNs are not removed.

    Raises ValueError when ``str_a`` or ``str_b`` is not a known selector.
    """
    # Get the background wind field.
    smooth_u = gm.nan_gaussian_filter(u_interp,sigma)
    smooth_v = gm.nan_gaussian_filter(v_interp,sigma)
    smooth_ws = np.sqrt(smooth_u**2+smooth_v**2)
    smooth_sst = gm.nan_gaussian_filter(sst,sigma)

    # Large-scale wind direction
    cosphi = smooth_u/smooth_ws
    sinphi = smooth_v/smooth_ws

    # Get the anomalies with respect to the background fields.
    u_prime = u_interp-smooth_u
    v_prime = v_interp-smooth_v
    sst_prime = sst-smooth_sst

    dsst_dx, dsst_dy = gm.grad_sphere(sst_prime,llon,llat)
    if str_a=='gamma':
        a_prime = u_interp*dsst_dx + v_interp*dsst_dy
    elif str_a=='dsst_dr':
        a_prime = dsst_dx*cosphi + dsst_dy*sinphi
    elif str_a=='lapl_sst':
        a_prime = gm.div_sphere(dsst_dx,dsst_dy,llon,llat)
    elif str_a=='d2sst_ds2':
        dsst_ds = -dsst_dx*sinphi + dsst_dy*cosphi
        ddsst_ds_dx, ddsst_ds_dy = gm.grad_sphere(dsst_ds,llon,llat)
        a_prime = -ddsst_ds_dx*sinphi + ddsst_ds_dy*cosphi
    elif str_a=='sst_prime':
        # This branch used to read an undefined global (l3_sst); the anomaly
        # of the sst argument, already computed above, is what is meant.
        a_prime = sst_prime
    else:
        raise ValueError(f'Unknown SST field selector: {str_a!r}')

    if str_b=='wind_div':
        b_prime = gm.div_sphere(u_interp,v_interp,llon,llat)
    elif str_b=='dr_dot_prime_dr':
        r_dot_prime = u_prime*cosphi + v_prime*sinphi
        dr_dot_prime_dx, dr_dot_prime_dy = gm.grad_sphere(r_dot_prime,llon,llat)
        b_prime = dr_dot_prime_dx*cosphi + dr_dot_prime_dy*sinphi
    elif str_b=='ds_dot_prime_ds':
        s_dot_prime = -u_prime*sinphi + v_prime*cosphi
        ds_dot_prime_dx, ds_dot_prime_dy = gm.grad_sphere(s_dot_prime,llon,llat)
        b_prime = -ds_dot_prime_dx*sinphi + ds_dot_prime_dy*cosphi
    elif str_b=='ws_prime':
        b_prime = np.sqrt(u_interp**2+v_interp**2)-smooth_ws
    else:
        raise ValueError(f'Unknown wind field selector: {str_b!r}')

    # Fields are returned on the full grid; dropping NaNs is left to the caller.
    return a_prime, b_prime, smooth_ws


def compute_SST_var_derivative_fields(str_a,str_b,sigma,llon,llat,sst,u_interp,v_interp,var):
    """Compute an SST-derived field, a field derived from ``var`` and the background wind speed.

    Background fields are obtained with ``gm.nan_gaussian_filter(field, sigma)``
    and anomalies ("primes") are the input minus its background. The background
    wind defines a direction (cosphi, sinphi); below, "r" is the direction along
    it and "s" the direction (-sinphi, cosphi) across it. Horizontal derivatives
    come from ``gm.grad_sphere`` / ``gm.div_sphere`` (presumably derivatives on
    the sphere -- see the ``geometry`` module).

    Parameters:
    - str_a: SST field selector, one of
        'gamma'     : u_interp*dSST'/dx + v_interp*dSST'/dy
        'dsst_dr'   : derivative of SST' along r
        'lapl_sst'  : divergence of the gradient of SST'
        'd2sst_ds2' : second derivative of SST' along s
        'sst_prime' : SST' itself
    - str_b: second field selector, one of
        'dr_dot_prime_dr' : derivative of var' along r
        'ds_dot_prime_ds' : derivative of var' along s
        'wind_div'        : divergence of (u_interp, v_interp); ``var`` is not used
        'ws_prime'        : wind speed minus background wind speed; ``var`` is not used
    - sigma: filter width passed to ``gm.nan_gaussian_filter``
    - llon, llat: coordinate grids passed to the ``gm`` derivative functions
    - sst, u_interp, v_interp: SST and wind components on that grid
    - var: additional field on the same grid

    Returns (a_prime, b_prime, smooth_ws). NaNs are not removed.

    Raises ValueError when ``str_a`` or ``str_b`` is not a known selector.
    """
    # Get the background fields.
    smooth_u = gm.nan_gaussian_filter(u_interp,sigma)
    smooth_v = gm.nan_gaussian_filter(v_interp,sigma)
    smooth_ws = np.sqrt(smooth_u**2+smooth_v**2)
    smooth_sst = gm.nan_gaussian_filter(sst,sigma)
    smooth_var = gm.nan_gaussian_filter(var,sigma)

    # Large-scale wind direction
    cosphi = smooth_u/smooth_ws
    sinphi = smooth_v/smooth_ws

    # Get the anomalies with respect to the background fields.
    sst_prime = sst-smooth_sst
    var_prime = var-smooth_var

    dsst_dx, dsst_dy = gm.grad_sphere(sst_prime,llon,llat)
    if str_a=='gamma':
        a_prime = u_interp*dsst_dx + v_interp*dsst_dy
    elif str_a=='dsst_dr':
        a_prime = dsst_dx*cosphi + dsst_dy*sinphi
    elif str_a=='lapl_sst':
        a_prime = gm.div_sphere(dsst_dx,dsst_dy,llon,llat)
    elif str_a=='d2sst_ds2':
        dsst_ds = -dsst_dx*sinphi + dsst_dy*cosphi
        ddsst_ds_dx, ddsst_ds_dy = gm.grad_sphere(dsst_ds,llon,llat)
        a_prime = -ddsst_ds_dx*sinphi + ddsst_ds_dy*cosphi
    elif str_a=='sst_prime':
        # This branch used to read an undefined global (l3_sst); the anomaly
        # of the sst argument, already computed above, is what is meant.
        a_prime = sst_prime
    else:
        raise ValueError(f'Unknown SST field selector: {str_a!r}')

    dvar_dx, dvar_dy = gm.grad_sphere(var_prime,llon,llat)
    if str_b=='wind_div':
        b_prime = gm.div_sphere(u_interp,v_interp,llon,llat)
    elif str_b=='dr_dot_prime_dr':
        b_prime = dvar_dx*cosphi + dvar_dy*sinphi
    elif str_b=='ds_dot_prime_ds':
        b_prime = -dvar_dx*sinphi + dvar_dy*cosphi
    elif str_b=='ws_prime':
        b_prime = np.sqrt(u_interp**2+v_interp**2)-smooth_ws
    else:
        raise ValueError(f'Unknown field selector: {str_b!r}')

    # Fields are returned on the full grid; dropping NaNs is left to the caller.
    return a_prime, b_prime, smooth_ws



def binned_mean_plot(x, y, bins=10, xlabel='x', ylabel='Mean y',title=None):
    """
    Bin x into equally spaced intervals, plot the mean of y in each bin with
    its standard deviation as error bar, and overlay the weighted linear fit
    returned by ``binned_weighted_lin_regr`` for the same data.

    Parameters:
    - x: array-like, independent variable
    - y: array-like, dependent variable
    - bins: int, number of bins
    - xlabel: str, label for x-axis
    - ylabel: str, label for y-axis
    - title: str or None, figure title
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Compute bin edges and indices
    bin_edges = np.linspace(np.nanmin(x), np.nanmax(x), bins + 1)
    bin_indices = np.digitize(x, bin_edges) - 1  # bins are 0-indexed

    bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:])
    members = [bin_indices == i for i in range(bins)]
    mean_y = np.array([
        np.nanmean(y[in_bin]) if np.any(in_bin) else np.nan
        for in_bin in members
    ])
    std_y = np.array([
        np.nanstd(y[in_bin]) if np.any(in_bin) else np.nan
        for in_bin in members
    ])

    # Fit the data that is actually being plotted, with the same number of
    # bins (this call used to read the globals a_prime/b_prime and bins=10).
    lin_regr = binned_weighted_lin_regr(x, y, bins=bins)
    intercept = lin_regr['intercept'][0]
    slope = lin_regr['slope'][0]

    # Plot
    plt.figure(figsize=(8, 5))
    plt.errorbar(x=bin_centers,y=mean_y,yerr=std_y,marker='o', linestyle='-', color='black')
    plt.plot(bin_centers,intercept+bin_centers*slope,marker='', linestyle='-', color='red')
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.show()

def weighted_linear_regression_with_pvalues(x, y, sigma_y):
    '''
        Compute a weighted linear regression of (x, y +/- sigma_y), where the
        uncertainty on the y observations is taken into account
        (each point is weighted by 1/sigma_y**2).

        Parameters:
        - x: array-like, independent variable
        - y: array-like, dependent variable
        - sigma_y: array-like, uncertainty of each y value

        Returns a dictionary in the form:
                "intercept" : (intercept, sigma_intercept),
                "slope" : (slope, sigma_slope),
                "chi2" : chi square value,
                "q" : goodness of fit (1.0 when there are 2 points or fewer)

        Raises ValueError if the inputs are empty or differ in shape.

        From Numerical Recipe 15.2
        If q is larger than, say, 0.1, then the goodness-of-fit is believable.
        If it is larger than, say, 0.001, then the fit may be acceptable if the
        errors are nonnormal or have been moderately underestimated. If q is
        less than 0.001, then the model and/or estimation procedure can rightly
        be called into question. In this latter case, turn to 15.7 to proceed
        further.
    '''

    # Numerical recipe code 15.2 (https://numerical.recipes/book.html),
    # expressed with array operations instead of per-element loops.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    sigma_y = np.asarray(sigma_y, dtype=float)

    ndata = len(x)
    if ndata == 0:
        raise ValueError('At least one data point is required')
    if not (x.shape == y.shape == sigma_y.shape):
        raise ValueError('x, y and sigma_y must have the same shape')

    # Weighted sums.
    weights = 1.0 / sigma_y ** 2
    ss = weights.sum()
    sx = (x * weights).sum()
    sy = (y * weights).sum()
    sxoss = sx / ss

    # Slope from the centred, sigma-scaled abscissas.
    t = (x - sxoss) / sigma_y
    st2 = (t * t).sum()
    b = (t * y / sigma_y).sum() / st2
    a = (sy - sx * b) / ss
    sigma_a = np.sqrt((1.0 + sx * sx / (ss * st2)) / ss)
    sigma_b = np.sqrt(1.0 / st2)

    chi2 = (((y - a - b * x) / sigma_y) ** 2).sum()

    # Goodness of fit; it stays at 1.0 when there are no degrees of freedom.
    q = 1.0
    if ndata > 2:
        q = gammaincc(0.5 * (ndata - 2), 0.5 * chi2)

    return {
        "intercept": (a, sigma_a),
        "slope": (b, sigma_b),
        "chi2" : chi2,
        "q" : q,
    }

def binned_weighted_lin_regr(x, y, bins=10):
    """
    Bin x into equally spaced intervals and fit a straight line through the
    per-bin means of y, using the per-bin standard deviation of y as the
    uncertainty of each mean.

    Bins whose mean or standard deviation is not finite, or whose standard
    deviation is zero, are left out of the fit.

    Parameters:
    - x: array-like, independent variable
    - y: array-like, dependent variable
    - bins: int, number of bins

    Returns the dictionary produced by ``weighted_linear_regression_with_pvalues``.
    """
    x = np.asarray(x)
    y = np.asarray(y)

    # Equally spaced edges over the finite range of x; bin numbers start at 0.
    edges = np.linspace(np.nanmin(x), np.nanmax(x), bins + 1)
    which_bin = np.digitize(x, edges) - 1
    centers = 0.5 * (edges[:-1] + edges[1:])

    def _per_bin(statistic):
        # statistic of y inside each bin, NaN for bins without any sample
        return np.array([
            statistic(y[which_bin == k]) if np.any(which_bin == k) else np.nan
            for k in range(bins)
        ])

    means = _per_bin(np.nanmean)
    spreads = _per_bin(np.nanstd)

    usable = np.isfinite(means) & np.isfinite(spreads) & (spreads != 0)
    return weighted_linear_regression_with_pvalues(
        centers[usable], means[usable], spreads[usable]
    )
    

def plot_map_with_contours(data,llat,llon,background=None,title=None,cmap=None,vmin=None,vmax=None,label=None, contour_levels=10):
    """
    Show a 2D field on a PlateCarree map, optionally overlaid with labelled
    black contours of a second field, and display the figure.

    Parameters:
    - data: 2D array (lat x lon)
    - llat, llon: 2D arrays matching shape of data
    - background: Optional 2D array (same shape) to use as contour overlay
    - title: Plot title
    - cmap: Colormap for main data
    - vmin, vmax: Colour limits for main data
    - label: Colorbar label
    - contour_levels: Number or list of contour levels
    """
    plate_carree = ccrs.PlateCarree()
    plt.figure(figsize=(10, 6))
    ax = plt.axes(projection=ccrs.PlateCarree())

    # Main data plot
    mesh = ax.pcolormesh(
        llon, llat, data,
        transform=plate_carree, cmap=cmap, vmin=vmin, vmax=vmax,
    )

    # Optional contour overlay
    if background is not None:
        overlay = ax.contour(
            llon, llat, background,
            levels=contour_levels, colors='black', linewidths=0.6,
            transform=ccrs.PlateCarree(),
        )
        ax.clabel(overlay, inline=True, fontsize=8)

    # Map decorations
    ax.coastlines()
    ax.add_feature(cf.BORDERS, linewidth=0.5)
    ax.gridlines(draw_labels=True, linewidth=0.5, linestyle='--')

    plt.colorbar(mesh, ax=ax, orientation='vertical', label=label,shrink=0.5)
    plt.title(title)
    plt.tight_layout()
    plt.show()


def runCOARE(ds, uas, vas, psl, sst, tas, hur, pr, rsds, rlds, lat_name='lat', lon_name='lon'):
    ''' Convert the atmospheric variables of ``ds`` to the units expected by
        COARE 3.5, run ``coare_35`` and store selected outputs in ``ds``.

        source: https://github.com/pyCOARE/coare/blob/main/pycoare/coare_35.py

        Values handed to ``coare_35`` (keyword <- variable of ``ds``):
            u    <- wsp  : wind speed, sqrt(uas**2 + vas**2)
            t    <- tas  : air temperature, after subtracting 273.15
            rh   <- hur  : relative humidity, after multiplying by 100
            ts   <- sst  : sea water temperature, after subtracting 273.15
            p    <- psl  : pressure, after multiplying by 0.01
            rs   <- rsds : downward shortwave radiation
            rl   <- rlds : downward longwave radiation
            rain <- pr   : rain rate, after multiplying by 3600

        No other ``coare_35`` argument (e.g. sensor heights or boundary layer
        height) is passed, so those presumably keep the package defaults.

        Inputs to this wrapping function:
            - ds: dataset with all atmospheric variables; it is modified in
              place (unit conversions overwrite psl, sst, tas, hur and pr)
            - uas, vas: wind components, used for the wind speed
            - tas: used for the shape of the output grid
            - psl, sst, hur, pr, rsds, rlds: accepted but not read; the values
              are taken from ``ds`` under these names
              (NOTE(review): confirm this is intended)
            - lat_name, lon_name: names of the lat/lon dimensions, used for 'wsp'

        Output: the same dataset with the converted variables, 'wsp', and the
                COARE outputs 'cwst' (tau), 'ccdrag' (cd) and 'ccdragn' (cdn_rf).

    '''
    # Dimensions and shape used to put the flattened COARE output back on the grid.
    dimensions = list(ds.dims)
    grid_shape = tas.shape

    # Wind speed from the wind components.
    ds['wsp'] = ((lat_name, lon_name), np.sqrt(uas**2 + vas**2).data)

    # Unit conversions, applied to the dataset variables in place.
    conversions = (
        ('psl', lambda field: field * 0.01),    # Pa -> hPa
        ('sst', lambda field: field - 273.15),  # K -> deg C
        ('tas', lambda field: field - 273.15),  # K -> deg C
        ('hur', lambda field: field * 100),     # fraction -> %
        ('pr', lambda field: field * 3600),     # kg/m2/s -> mm/hr
    )
    for name, convert in conversions:
        ds[name] = convert(ds[name])

    # run COARE on the flattened fields
    coare_inputs = {
        'u': 'wsp', 't': 'tas', 'rh': 'hur', 'ts': 'sst',
        'p': 'psl', 'rs': 'rsds', 'rl': 'rlds', 'rain': 'pr',
    }
    bulk_params = coare_35(
        **{keyword: ds[name].data.flatten() for keyword, name in coare_inputs.items()}
    )

    # retrieve the variables of interest and reshape them to the grid
    for target, coare_name in (('cwst', 'tau'), ('ccdrag', 'cd'), ('ccdragn', 'cdn_rf')):
        ds[target] = (dimensions, bulk_params._return_vars(coare_name).reshape(grid_shape))

    return ds

def save_friday_data(ds,friday):
    """Extract one time step, regrid it over the configured region, compute the
    SST/wind/BLH derivative fields and optionally write them to a CSV file.

    The function relies on the module-level configuration (``regions``,
    ``region``, ``fine_zoom``, ``fine_latlon_gridstep``, ``supersampling``,
    ``sst_deriv``, ``wind_deriv``, ``blh_deriv``, ``submeso_pass_lower_sigma``,
    ``str_mech``, ``time_res``, ``WORKDIR``, ``enable_save_file``).

    Parameters:
    - ds: dataset with a 'time' coordinate, 'lat'/'lon' variables and the
      variables sst, tas, uas, vas and blh on a 'cell' dimension
    - friday: time value selected with ``ds.sel(time=friday)``

    Nothing is returned; progress is printed.
    """
    print(f'{friday} # Start')
    fine_ds = ds.sel(time=friday).drop_vars(['lat','lon'])
    # 9999 is treated as a missing-value marker in every floating-point variable.
    for name in fine_ds.data_vars:
        if np.issubdtype(fine_ds[name].dtype, np.floating):
            fine_ds[name] = fine_ds[name].where(fine_ds[name] != 9999, np.nan)

    # Regrid and select region: nearest HEALPix cell for each node of a
    # lon/lat grid that is finer than the target grid by the supersampling factors.
    lon_min, lon_max, lat_min, lat_max = regions[region]['boundaries']
    target_lons = np.arange(lon_min, lon_max, fine_latlon_gridstep/supersampling['lon'])
    target_lats = np.arange(lat_min, lat_max, fine_latlon_gridstep/supersampling['lat'])
    region_idx = get_nn_lon_lat_index(2**fine_zoom, target_lons, target_lats)

    print(f'{friday} # Compute dask')

    def _regrid(field):
        # pick the region cells, average the supersampled nodes, load the result
        return field.isel(cell=region_idx).coarsen(supersampling).mean().compute()

    fine_sst = _regrid(fine_ds.sst.where(fine_ds.sst!=9999))
    fine_tas = _regrid(fine_ds['tas'])
    fine_10u = _regrid(fine_ds['uas'])
    fine_10v = _regrid(fine_ds['vas'])
    fine_blh = _regrid(fine_ds['blh'])

    # Remove land: keep only the points where the SST is finite.
    has_sst = np.isfinite(fine_sst)
    fine_sst = fine_sst.where(has_sst)
    fine_10u = fine_10u.where(has_sst)
    fine_10v = fine_10v.where(has_sst)
    fine_blh = fine_blh.where(has_sst)

    llon, llat = np.meshgrid(fine_sst.lon,fine_sst.lat)

    # SST grad, wind and blh div along large-scale wind direction
    # Sub-mesoscale
    print(f'{friday} # Compute fields')
    sst_prime, wind_prime, smooth_ws = compute_SST_wind_derivative_fields(
        sst_deriv,wind_deriv,submeso_pass_lower_sigma,llon,llat,fine_sst,fine_10u,fine_10v
    )
    _, bhl_prime, _ = compute_SST_var_derivative_fields(
        sst_deriv,wind_deriv,submeso_pass_lower_sigma,llon,llat,fine_sst,fine_10u,fine_10v,fine_blh
    )

    smooth_tas = gm.nan_gaussian_filter(fine_tas,submeso_pass_lower_sigma)
    smooth_sst = gm.nan_gaussian_filter(fine_sst,submeso_pass_lower_sigma)

    # One flattened column per field.
    output = {
        sst_deriv : sst_prime.flatten(),
        wind_deriv : wind_prime.flatten(),
        blh_deriv : bhl_prime.flatten(),
        'smooth_ws' : smooth_ws.flatten(),
        'smooth_tas' : smooth_tas.flatten(),
        'smooth_sst' : smooth_sst.flatten(),
    }
    output_pd = pd.DataFrame.from_dict(output)

    if enable_save_file:
        csv_path = f'{WORKDIR}/fields/{str_mech}_zoom_{fine_zoom}_timeres_{time_res}_latlon_res_{fine_latlon_gridstep*100}km_high_pass_{submeso_pass_lower_sigma}_{region}_{friday}.csv'
        output_pd.to_csv(csv_path)
    print(f'{friday} # DONE')

def parallel_apply(my_function,arr):
    """Apply ``my_function`` to every element of ``arr`` in a process pool.

    Returns the results as a list, in the order of ``arr``.
    """
    with ProcessPoolExecutor() as pool:
        return [outcome for outcome in pool.map(my_function, arr)]

# Data extraction

## Config

In [3]:
# Define configuration
model = 'ifs_tco3999-ng5_rcbmf_cf'
time_res = 'PT1H'
fine_zoom = 11
fine_latlon_gridstep = 0.02  # [°]
supersampling = dict(lon=4, lat=4)
# [gridstep]. Original note: for ifs_zoom = 11 => 0.01° latlon resampling
# => 10km gaussian smoothing => ~20km filter
submeso_pass_lower_sigma = 5

# Region to analyse (uncomment one).
#region = 'eurec4a_extended'
region = 'gulf_stream'
#region = 'EAC'

# (key, long name, [minlon, maxlon, minlat, maxlat]);
# lon -> [-180,180], lat -> [-90,90]. The key doubles as the short name.
_REGION_TABLE = (
    ('gulf_stream', 'Gulf Stream', (-83, -30, 30, 55)),
    ('gulf_stream_detail', 'Gulf Stream - detail', (-73, -50, 30, 45)),
    ('tropical_atlantic', 'Tropical Atlantic', (-62, 15, -20, 20)),
    ('tropical_atlantic_detail', 'Tropical Atlantic - detail', (-40, -25, -5, 0)),
    ('eurec4a', '$EUREC^{4}A$', (-62, -48, 4, 16)),
    ('eurec4a_extended', r'extended $EUREC^{4}A$', (-62, -20, 0, 20)),
    ('EAC', r'East Australian Current', (150, 165, -45, -25)),
)
regions = {
    key: {
        'long_name': long_name,
        'short_name': key,
        'boundaries': list(bounds),
    }
    for key, long_name, bounds in _REGION_TABLE
}

# Select here the fields to be analysed.
# Each mechanism fixes three field names: the SST field (sst_deriv), the wind
# field (wind_deriv) and the boundary-layer-height field (blh_deriv).
str_mech = 'DMM'
#str_mech = 'PA'
if str_mech == 'DMM':
    sst_deriv = 'dsst_dr' # Choose between: 'dsst_dr', 'lapl_sst', 'd2sst_ds2', 'sst_prime'
    wind_deriv = 'dr_dot_prime_dr' # Choose between: 'wind_div', 'dr_dot_prime_dr', 'ds_dot_prime_ds', 'ws_prime'
    blh_deriv = 'dblh_prime_dr'
elif str_mech == 'PA':
    sst_deriv = 'd2sst_ds2'
    wind_deriv = 'ds_dot_prime_ds'
    # This line used to assign wind_deriv a second time, which overwrote the
    # wind selector and left blh_deriv undefined for the 'PA' mechanism.
    blh_deriv = 'dblh_prime_ds'
else:
    raise NameError('Mechanism not recognised')

# Output location and switch for writing the CSV files.
WORKDIR = '/work/bb1153/b383497'
enable_save_file = True

## Catalog

In [4]:
# Open the hackathon intake catalog (fetched from the URL below) and keep its "EU" entry.
cat = intake.open_catalog("https://digital-earths-global-hackathon.github.io/catalog/catalog.yaml")["EU"]
# Print the user parameters described by the catalog entry of the selected model.
print(pd.DataFrame(cat[model].describe()["user_parameters"]))
# Last expression of the cell: the notebook displays the names of the catalog entries.
list(cat)

   name                     description type  allowed default
0  time  time resolution of the dataset  str   [PT1H]    PT1H
1  zoom       zoom level of the dataset  int  [7, 11]       7


['CERES_EBAF',
 'ERA5',
 'IR_IMERG',
 'JRA3Q',
 'MERRA2',
 'arp-gem-1p3km',
 'arp-gem-2p6km',
 'casesm2_10km_nocumulus',
 'icon_d3hp003',
 'icon_d3hp003aug',
 'icon_d3hp003feb',
 'icon_ngc4008',
 'ifs_tco2559_rcbmf',
 'ifs_tco3999-ng5_deepoff',
 'ifs_tco3999-ng5_rcbmf',
 'ifs_tco3999-ng5_rcbmf_cf',
 'ifs_tco3999_rcbmf',
 'nicam_220m_test',
 'nicam_gl11',
 'scream-dkrz',
 'tracking-d3hp003',
 'um_Africa_km4p4_RAL3P3_n1280_GAL9_nest',
 'um_CTC_km4p4_RAL3P3_n1280_GAL9_nest',
 'um_SAmer_km4p4_RAL3P3_n1280_GAL9_nest',
 'um_SEA_km4p4_RAL3P3_n1280_GAL9_nest',
 'um_glm_n1280_CoMA9_TBv1p2',
 'um_glm_n1280_GAL9',
 'um_glm_n2560_RAL3p3']

## Generate dataset

In [5]:
# Load the dataset at the configured zoom level and time resolution
with cat[model](zoom=fine_zoom,time=time_res).to_dask() as ds:
    # Choose only the best day for each week (ref. Giaccio et. al, PREPRINT):
    # keep every time step that falls on a Friday (weekday == 4) strictly
    # inside the analysis window.
    times = pd.to_datetime(ds.time.data)
    window_start = pd.to_datetime('2020-02-07T13:00:00')
    window_end = pd.to_datetime('2020-03-01')
    in_window = (times > window_start) & (times < window_end)
    fridays = times[(times.weekday == 4) & in_window]

    # Save the fields, one time step after the other
    for friday in fridays:
        save_friday_data(ds,friday)

  self._ds = xr.open_dataset(self.urlpath, **kw)


2020-02-07 13:00:48 # Start
2020-02-07 13:00:48 # Compute dask
2020-02-07 13:00:48 # Compute fields




2020-02-07 13:00:48 # DONE
2020-02-07 14:00:32 # Start
2020-02-07 14:00:32 # Compute dask
2020-02-07 14:00:32 # Compute fields




2020-02-07 14:00:32 # DONE
2020-02-07 15:00:16 # Start
2020-02-07 15:00:16 # Compute dask
2020-02-07 15:00:16 # Compute fields




2020-02-07 15:00:16 # DONE
2020-02-07 16:00:00 # Start
2020-02-07 16:00:00 # Compute dask
2020-02-07 16:00:00 # Compute fields




2020-02-07 16:00:00 # DONE
2020-02-07 16:59:44 # Start
2020-02-07 16:59:44 # Compute dask
2020-02-07 16:59:44 # Compute fields




2020-02-07 16:59:44 # DONE
2020-02-07 17:59:28 # Start
2020-02-07 17:59:28 # Compute dask
2020-02-07 17:59:28 # Compute fields




2020-02-07 17:59:28 # DONE
2020-02-07 18:59:12 # Start
2020-02-07 18:59:12 # Compute dask
2020-02-07 18:59:12 # Compute fields




2020-02-07 18:59:12 # DONE
2020-02-07 20:01:04 # Start
2020-02-07 20:01:04 # Compute dask
2020-02-07 20:01:04 # Compute fields




2020-02-07 20:01:04 # DONE
2020-02-07 21:00:48 # Start
2020-02-07 21:00:48 # Compute dask
2020-02-07 21:00:48 # Compute fields




2020-02-07 21:00:48 # DONE
2020-02-07 22:00:32 # Start
2020-02-07 22:00:32 # Compute dask
2020-02-07 22:00:32 # Compute fields




2020-02-07 22:00:32 # DONE
2020-02-07 23:00:16 # Start
2020-02-07 23:00:16 # Compute dask
2020-02-07 23:00:16 # Compute fields




2020-02-07 23:00:16 # DONE
2020-02-14 00:00:00 # Start
2020-02-14 00:00:00 # Compute dask
2020-02-14 00:00:00 # Compute fields




2020-02-14 00:00:00 # DONE
2020-02-14 00:59:44 # Start
2020-02-14 00:59:44 # Compute dask
2020-02-14 00:59:44 # Compute fields




2020-02-14 00:59:44 # DONE
2020-02-14 01:59:28 # Start
2020-02-14 01:59:28 # Compute dask
2020-02-14 01:59:28 # Compute fields




2020-02-14 01:59:28 # DONE
2020-02-14 02:59:12 # Start
2020-02-14 02:59:12 # Compute dask
2020-02-14 02:59:12 # Compute fields




2020-02-14 02:59:12 # DONE
2020-02-14 03:58:56 # Start
2020-02-14 03:58:56 # Compute dask
2020-02-14 03:58:56 # Compute fields




2020-02-14 03:58:56 # DONE
2020-02-14 05:00:48 # Start
2020-02-14 05:00:48 # Compute dask
2020-02-14 05:00:48 # Compute fields




KeyboardInterrupt: 