# P_bar globe, based on P_bar deep dive

Need to try this coord interp again, but with the methods that seem promising from P_bar deep dive

## Housekeeping

In [2]:
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.path import Path
import matplotlib.colors as colors
import pandas as pd
import numpy as np
from importlib import reload
import cartopy.crs as ccrs
import cmocean.cm as cmo
import gsw
import scipy.ndimage as filter
import scipy.interpolate as interpolate
from flox.xarray import xarray_reduce
from xgcm.autogenerate import generate_grid_ds
from xgcm import Grid
from tqdm.notebook import tqdm
from scipy.ndimage import uniform_filter1d

In [3]:
import os
os.chdir('/home/amf2288/argo-intern/funcs')
import filt_funcs as ff
import density_funcs as df

In [4]:
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['axes.titlesize']  = 18
plt.rcParams['axes.labelsize']  = 14
plt.rcParams['xtick.labelsize'] = 14
plt.rcParams['ytick.labelsize'] = 14
plt.rcParams['legend.fontsize'] = 14

## Load pre-made datasets

In [14]:
#individual profiles where sampling rate <=2.5m, pressure grid
ds_p = xr.open_dataset('/swot/SUM05/amf2288/res_2.5/ds_p.nc').rename({'PRES_INTERPOLATED':'PRESSURE','N_PROF':'PROFILE','LATITUDE':'LAT','LONGITUDE':'LON'})   
ds_p


In [56]:
RG_p = xr.open_dataset('/swot/SUM05/amf2288/RG_clim/RG_p_grid3.nc').rename({'lat':'LAT','lon':'LON','pressure':'PRESSURE'})
RG_p

## Sort by density and reindex (individual profiles)

In [15]:
def sorting_index_from_sig(sig):
    """
    sig0: 1D numpy array for a single profile.
    Returns: 1D array of indices (same length) that sorts the valid region
             and leaves NaNs at the edges.
    """
    valid = ~np.isnan(sig)

    if not np.any(valid):
        return np.arange(len(sig))  # nothing to sort

    # valid region boundaries
    i0 = np.argmax(valid)
    i1 = len(sig) - np.argmax(valid[::-1])

    # sorting index for the core region
    core = sig[i0:i1]
    core_sort_idx = np.argsort(core)

    # build full index array
    full_idx = np.arange(len(sig))
    full_idx[i0:i1] = full_idx[i0:i1][core_sort_idx]

    return full_idx

In [16]:
def apply_sort_index(var, sort_idx, dim="PRESSURE"):
    return xr.apply_ufunc(
        lambda v, idx: v[idx],
        var,
        sort_idx,
        input_core_dims=[[dim], [dim]],
        output_core_dims=[[dim]],
        vectorize=True,
        dask="parallelized",
        output_dtypes=[var.dtype],
    )

In [None]:
ds_sort_idx = xr.apply_ufunc(
    sorting_index_from_sig,
    ds_p["SIG0"],
    input_core_dims=[["PRESSURE"]],
    output_core_dims=[["PRESSURE"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[int],
)

In [None]:
ds_p["SIG0_sort"] = apply_sort_index(ds_p["SIG0"], ds_sort_idx)

In [None]:
#vars_to_sort = ["CT", "SA", "TEMP", "PSAL", "PRESSURE", "SPICE"]

#start with just CT for this notebook
vars_to_sort = ["CT"]

for v in vars_to_sort:
    ds_p[f"{v}_sort"] = apply_sort_index(ds_p[v], ds_sort_idx)

In [20]:
#plots to compare CT to CT_sort on the globe

## Sort by density (climatology)

In [57]:
RG_sort_idx = xr.apply_ufunc(
    sorting_index_from_sig,
    RG_p["SIG0"],
    input_core_dims=[["PRESSURE"]],
    output_core_dims=[["PRESSURE"]],
    vectorize=True,
    dask="parallelized",
    output_dtypes=[int],
)

In [58]:
RG_p["SIG0_sort"] = apply_sort_index(RG_p["SIG0"], RG_sort_idx)

In [59]:
vars_to_sort = ["CT", "SA"]

for v in vars_to_sort:
    RG_p[f"{v}_sort"] = apply_sort_index(RG_p[v], RG_sort_idx)

## Create ds_r

In [13]:
def interp_to_rho(rho_profile, var_profile, rho_grid, flag='group'):
    
    valid = ~np.isnan(rho_profile) & ~np.isnan(var_profile)
    rho_nonan = rho_profile[valid]
    var_nonan = var_profile[valid]

    if len(rho_nonan) < 3 or len(var_nonan) < 3:
        return np.full_like(rho_grid, np.nan)

    if flag == 'group':
        _, idx = np.unique(rho_nonan, return_index=True)
        rho_nonan = rho_nonan[idx]
        var_nonan = var_nonan[idx]

    fvar = interpolate.PchipInterpolator(rho_nonan, var_nonan, extrapolate=False)
    return fvar(rho_grid)

In [73]:
def get_ds_rho(ds_z, da_SIG, rho_grid, variables, dim1='PROFILE', dim2='PRESSURE', RG_flag=False):

    xrs = []
    rho_size = rho_grid.size
    
    for var in variables:
        var_interp = xr.apply_ufunc(interp_to_rho, da_SIG, ds_z[var],
            input_core_dims=[[dim2], [dim2]],
            output_core_dims=[['rho_grid']],
            dask_gufunc_kwargs={'output_sizes': {'rho_grid': rho_size}},
            vectorize=True,
            dask='parallelized',
            kwargs={'rho_grid': rho_grid},
            output_dtypes=[ds_z[var].dtype])
        print(f'Completed comp for {var}')
        var_interp = xr.DataArray(
            data=var_interp.data,
            dims=[dim1, 'rho_grid'],
            coords={'rho_grid': rho_grid, dim1: ds_z[dim1]},
            name=var)
        xrs.append(var_interp)
        print(f'Completed {var}')
    
    ds_rho = xr.merge(xrs)
    print(f'Completed merge')

    if RG_flag==False:
        ds_rho = ds_rho.assign_coords(LAT  =(dim1, ds_z.LAT.data))
        ds_rho = ds_rho.assign_coords(LON  =(dim1, ds_z.LON.data))
        ds_rho = ds_rho.assign_coords(TIME =(dim1, ds_z.TIME.data))
    
    return ds_rho

In [28]:
#again just doing CT for now, to see where we're at for method validation
variables = ['CT_sort','PRESSURE']
ds_r = get_ds_rho(ds_p, ds_p.SIG0_sort, np.arange(22,28,0.001), variables)

Completed comp for CT_sort
Completed CT_sort


  output[index] = result


Completed comp for PRESSURE
Completed PRESSURE
Completed merge


In [84]:
ds_r = ds_r.rename({'rho_grid':'DENSITY'})

#### Need to convert RG_p to RG_r to preserve sorting, requires stacking and unstacking

In [70]:
RG_p_stacked = RG_p.stack(PROFILE=("LAT", "LON"))

pressure = RG_p_stacked.PRESSURE.data  # shape (1001,)
prof = RG_p_stacked.PROFILE.size      # number of profiles

# Repeat pressure along columns to get shape (1001, prof)
pressure_2d = np.repeat(pressure[:, np.newaxis], prof, axis=1)

# Add as a new variable with fresh name to avoid conflict
RG_p_stacked["PRESSURE"] = (("PRESSURE", "PROFILE"), pressure_2d)

In [78]:
RG_r_stacked = get_ds_rho(RG_p_stacked, RG_p_stacked.SIG0_sort, np.arange(22,28,0.001), ['CT_sort','SA_sort','SIG0_sort','PRESSURE'], RG_flag=True)

Completed comp for CT_sort
Completed CT_sort
Completed comp for SA_sort
Completed SA_sort
Completed comp for SIG0_sort
Completed SIG0_sort


  output[index] = result


Completed comp for PRESSURE
Completed PRESSURE
Completed merge


In [79]:
RG_r = RG_r_stacked.unstack('PROFILE').rename({'rho_grid':'DENSITY'})
RG_r

## Match climatology to profiles

In [80]:
def match_mean_profiles(ds_r, pres_rho_gridded, lon_bins, lat_bins):
    # Bin indices
    lat_idx = np.digitize(ds_r.LAT.values, lat_bins) - 1
    lon_idx = np.digitize(ds_r.LON.values, lon_bins) - 1

    # Validate indices
    valid = (
        (lat_idx >= 0) & (lat_idx < pres_rho_gridded.sizes['LAT']) &
        (lon_idx >= 0) & (lon_idx < pres_rho_gridded.sizes['LON'])
    )

    # Transpose to match original indexing: [density, lon, lat]
    pres_grid = pres_rho_gridded.transpose('DENSITY', 'LON', 'LAT').values

    # Initialize output
    matched_profiles = np.full(
        (ds_r.sizes['PROFILE'], ds_r.sizes['DENSITY']),
        np.nan,
        dtype=np.float32
    )

    # Loop with progress bar
    for i in tqdm(range(ds_r.sizes['PROFILE']), desc="Matching profiles", unit="PROFILE"):
        if valid[i]:
            try:
                matched_profiles[i, :] = pres_grid[:, lon_idx[i], lat_idx[i]]
            except IndexError:
                continue

    # Return as DataArray
    return xr.DataArray(
        matched_profiles,
        dims=('PROFILE', 'DENSITY'),
        coords={'PROFILE': ds_r.PROFILE, 'DENSITY': ds_r.DENSITY}
    )

In [81]:
lon_bins = np.arange(-180,181,3)
lat_bins = np.arange(-90,91,3)

In [85]:
ds_r['PRESSURE_mean'] = match_mean_profiles(ds_r, RG_r.PRESSURE, lon_bins, lat_bins)

Matching profiles:   0%|          | 0/1488063 [00:00<?, ?PROFILE/s]

In [86]:
ds_r

## Create ds_pbar

In [87]:
def interp_to_pmean(var_profile, pmean_smooth, pmean_grid):

    try:
        valid = ~np.isnan(var_profile) & ~np.isnan(pmean_smooth)
        var_nonan = var_profile[valid]
        pmean_nonan = pmean_smooth[valid]

        fvar = interpolate.PchipInterpolator(pmean_nonan, var_nonan, extrapolate=False)
        return fvar(pmean_grid)

    except ValueError:
        return np.full_like(pmean_grid, np.nan)
        print(f"ValueError")

In [88]:
def get_ds_pmean(ds_rho, pres_var, pmean_grid, variables, dim1='PROFILE', dim2='rho_grid'):
    
    xrs = []
    pmean_size = pmean_grid.size
    
    # 1. Mean pressure profile
    pres_profile = ds_rho[pres_var].mean(dim1)

    # 2. Sort the mean profile
    sort_idx_mean = sorting_index_from_sig(pres_profile.values)
    pres_profile_sorted = pres_profile.values[sort_idx_mean]

    # 3. Wrap back into DataArray
    pmean_smooth = xr.DataArray(
        pres_profile_sorted,
        dims=[dim2],
        coords={dim2: ds_rho[dim2]},
        name="pmean_smooth"
    )
    
    for var in variables:
        var_interp = xr.apply_ufunc(interp_to_pmean, ds_rho[var], pmean_smooth,
                                   input_core_dims=[[dim2], [dim2]],
                                   output_core_dims=[['Pmean_grid']],
                                   dask_gufunc_kwargs={'output_sizes': {'Pmean_grid':pmean_size}},
                                   vectorize=True,
                                   dask='parallaleized',
                                   kwargs={'pmean_grid':pmean_grid},
                                   output_dtypes=[ds_rho[var].dtype])
        print(f'Completed comp for {var}')
        var_interp = xr.DataArray(
            data=var_interp.data,
            dims=[dim1,'pmean_grid'],
            coords={'pmean_grid': pmean_grid, dim1: ds_rho[dim1]},
            name=var)
        xrs.append(var_interp)
        print(f'Completed {var}')
        
    ds_pmean = xr.merge(xrs)
    print(f'Completed merge')
    ds_pmean = ds_pmean.assign_coords(LAT = (dim1, ds_rho.LAT.data))
    ds_pmean = ds_pmean.assign_coords(LON = (dim1, ds_rho.LON.data))
    ds_pmean = ds_pmean.assign_coords(TIME = (dim1, ds_rho.TIME.data))
    
    return ds_pmean

In [37]:
ds_pbar = get_ds_pmean(ds_r, 'PRESSURE', np.linspace(0, 2000, 1001), variables)

Completed comp for CT_sort
Completed CT_sort


  output[index] = result


Completed comp for PRESSURE
Completed PRESSURE
Completed merge


## Create gridded ds_p, ds_r, ds_pbar

In [41]:
def get_ds_gridded(ds, lon_bins, lat_bins, z_coord):
    # Step 1: Reduce using flox
    ds_binned = xarray_reduce(
        ds,
        'LON',
        'LAT',
        func='mean',
        expected_groups=(
            pd.IntervalIndex.from_breaks(lon_bins),
            pd.IntervalIndex.from_breaks(lat_bins)
        ),
        fill_value=np.nan,
        skipna=True
    )

    # Step 2: Rename dimensions and variables
    ds_binned = ds_binned.rename_dims({
        'LON_bins': 'LON',
        'LAT_bins': 'LAT',}).rename_vars({
        #'rho_grid': 'density'
        'LON_bins': 'LON',
        'LAT_bins': 'LAT',
        #'rho_grid': 'density'
    })

    # Step 3: Replace Interval coordinates with midpoints
    def interval_midpoints(intervals):
        return np.array([interval.mid for interval in intervals])

    ds_binned = ds_binned.assign_coords({
        'LON': ('LON', interval_midpoints(ds_binned['LON'].values)),
        'LAT': ('LAT', interval_midpoints(ds_binned['LAT'].values))
    })

    # Step 4: Generate grid and xgcm object
    ds_gridded = generate_grid_ds(ds_binned, {'X': 'LON', 'Y': 'LAT'})
    grid = Grid(ds_gridded, coords={
        'X': {'center': 'LON', 'left': 'lon_left'},
        'Y': {'center': 'LAT', 'left': 'lat_left'},
        'Z': {'center': z_coord}
    }, periodic=['X'])

    return ds_gridded, grid

In [40]:
ds_p_grid = get_ds_gridded(ds_p, lon_bins, lat_bins, 'PRESSURE')
print('Completed gridding ds_p')
ds_r_grid = get_ds_gridded(ds_r, lon_bins, lat_bins, 'rho_grid').rename({'rho_grid':'DENSITY'})
print('Completed gridding ds_r')
ds_pbar_grid = get_ds_gridded(ds_pbar, lon_bins, lat_bins, 'pmean_grid').rename({'pmean_grid':'PRESSURE_BAR'})
print('Completed gridding ds_pbar')

ValueError: cannot rename 'lon_bins' because it is not found in the dimensions of this dataset ('PRESSURE', 'LON_bins', 'LAT_bins')

## Plot results