# Pbar Globe 2nd Try

I need a clean slate where I can start from the very beginning and test each step along the way. I expect most or all of this code to be pulled from `pbar_globe`, but there is a lot going on there; I tried to jump straight to the finish line, and it didn't go well.

# Housekeeping

## Imports and plotting params

In [1]:
import xarray as xr
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.path import Path
import matplotlib.colors as colors
import pandas as pd
import numpy as np
from importlib import reload
import cartopy.crs as ccrs
import cmocean.cm as cmo
import gsw
import scipy.ndimage as filter
import scipy.interpolate as interpolate
from flox.xarray import xarray_reduce
from xgcm.autogenerate import generate_grid_ds
from xgcm import Grid

In [2]:
import os
# The project helper modules below are imported by bare name, presumably
# resolved from the current working directory, so change into their folder
# first. The chdir must stay before the two imports.
# NOTE(review): this changes the working directory for the rest of the
# notebook; sys.path.insert(0, ...) would avoid that if nothing relies on it.
os.chdir('/home.ufs/amf2288/argo-intern/funcs')
import filt_funcs as ff
import density_funcs as df

In [3]:
# Notebook-wide plot styling: bold titles and larger fonts for readability.
plt.rcParams.update({
    'axes.titleweight': 'bold',
    'axes.titlesize': 18,
    'axes.labelsize': 14,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    'legend.fontsize': 14,
})

## Load pre-made datasets

In [11]:
# Pre-made datasets, all built from profiles whose sampling rate is <= 2.5 m.
# ds_rho (individual profiles on the density grid) is required by the p-rho
# matching and plotting cells below, so it is loaded here; ds_p is not
# referenced by the cells below and stays disabled.
#ds_p = xr.open_dataset('/swot/SUM05/amf2288/res_2.5/ds_p.nc')                 #individual profiles where sampling rate <=2.5m, pressure grid
ds_rho = xr.open_dataset('/swot/SUM05/amf2288/res_2.5/ds_rho.nc')             #individual profiles where sampling rate <=2.5m, density grid

ds_p_grid3 = xr.open_dataset('/swot/SUM05/amf2288/res_2.5/ds_p_grid3.nc')     #3x3 gridded dataset where sampling rate <=2.5m, pressure grid
ds_rho_grid3 = xr.open_dataset('/swot/SUM05/amf2288/res_2.5/ds_rho_grid3.nc') #3x3 gridded dataset where sampling rate <=2.5m, density grid

# Create ds_pmean

## Construct $p-\rho$ relationship

In [None]:
def match_mean_profiles(ds_rho, pres_rho_gridded):
    """Look up, for every profile, the gridded mean PRES profile of its cell.

    Each profile in ``ds_rho`` is placed in a grid cell by comparing its
    LATITUDE/LONGITUDE with the cells' left edges (``lat_left`` /
    ``lon_left``). The cell's mean PRES profile is then copied onto the
    profile's ``rho_grid`` axis.

    Parameters
    ----------
    ds_rho : xr.Dataset
        Individual profiles with ``N_PROF`` and ``rho_grid`` dimensions and
        ``LATITUDE`` / ``LONGITUDE`` variables.
    pres_rho_gridded : xr.Dataset or xr.DataArray
        Either the gridded Dataset holding a ``PRES`` variable, or the ``PRES``
        DataArray itself. It must have ``lat`` / ``lon`` dimensions and
        ``lat_left`` / ``lon_left`` coordinates. Its remaining (density) axis
        is assumed to line up one-to-one with ``ds_rho.rho_grid`` (TODO confirm).

    Returns
    -------
    xr.DataArray
        float32 array on (N_PROF, rho_grid). Rows are all-NaN for profiles
        with a missing position or a position outside the grid.
    """
    # Accept the gridded Dataset or its PRES DataArray.
    if isinstance(pres_rho_gridded, xr.Dataset):
        pres = pres_rho_gridded['PRES']
    else:
        pres = pres_rho_gridded

    lat_bins = pres_rho_gridded.lat_left.values
    lon_bins = pres_rho_gridded.lon_left.values

    lats = ds_rho.LATITUDE.values
    lons = ds_rho.LONGITUDE.values

    # digitize(x, left_edges) - 1 is the index of the last left edge <= x.
    lat_idx = np.digitize(lats, lat_bins) - 1
    lon_idx = np.digitize(lons, lon_bins) - 1

    # NaN positions digitize past the last edge. After the -1 they would land
    # in the last cell, so exclude them explicitly.
    # NOTE(review): positions beyond the right edge of the last cell are still
    # assigned to it, because only left edges are consulted here.
    valid = (np.isfinite(lats) & np.isfinite(lons)
             & (lat_idx >= 0) & (lat_idx < len(pres_rho_gridded.lat))
             & (lon_idx >= 0) & (lon_idx < len(pres_rho_gridded.lon)))

    # Move lon/lat to the end (density-like axis first), whatever the stored
    # order, so the [:, lon, lat] indexing below hits the intended axes.
    pres_vals = pres.transpose(..., 'lon', 'lat').values

    matched_profiles = np.full((ds_rho.sizes['N_PROF'], ds_rho.sizes['rho_grid']),
                               np.nan, dtype=np.float32)
    # pres_vals[:, lon_i, lat_i] has shape (n_density, n_valid); transpose so
    # each row is one profile.
    matched_profiles[valid] = pres_vals[:, lon_idx[valid], lat_idx[valid]].T

    return xr.DataArray(matched_profiles, dims=('N_PROF', 'rho_grid'),
                        coords={'N_PROF': ds_rho.N_PROF, 'rho_grid': ds_rho.rho_grid})

In [None]:
# Pass the gridded Dataset rather than only its PRES DataArray:
# match_mean_profiles reads PRES as well as the lat/lon and lat_left/lon_left
# coordinates from this argument.
ds_rho['pres_dens'] = match_mean_profiles(ds_rho, ds_rho_grid3)

In [None]:
# Sanity check on one profile: the matched mean profile (pres_dens) should
# coincide with the mean PRES profile of the nearest grid cell. This assumes
# lat/lon are the cell centres whose left edges are lat_left/lon_left.
num = 0
prof = ds_rho.isel(N_PROF=num)
# Select on the Dataset (not the PRES DataArray) so that mean.PRES exists below.
mean = ds_rho_grid3.sel(lat=prof.LATITUDE, lon=prof.LONGITUDE, method='nearest')
print(f'PROF lat: {prof.LATITUDE.values}, lon: {prof.LONGITUDE.values}')
print(f'MEAN lat: {mean.lat.values}, lon: {mean.lon.values}')

plt.figure(figsize=(4,6))
prof.PRES.plot(y='rho_grid', label='profile')
prof.pres_dens.plot(y='rho_grid', label='pres_dens')
mean.PRES.plot(y='density', label='mean profile')
plt.legend();

## Interpolate ds_rho to ds_pmean

In [None]:
from scipy.ndimage import uniform_filter1d

def _nan_uniform_filter1d(values, size):
    """Return a centred moving average of a 1-D array that skips NaNs.

    Each finite output is the mean of the finite inputs inside its window,
    with ``mode='nearest'`` edge handling. NaN inputs stay NaN. Without any
    NaNs this matches ``uniform_filter1d(values, size, mode='nearest')``.
    """
    values = np.asarray(values, dtype=float)
    finite = np.isfinite(values)
    # uniform_filter1d is computed as a running sum, so a single NaN would
    # poison every later output. Instead, average the zero-filled values and
    # divide by the fraction of finite samples in each window.
    total = uniform_filter1d(np.where(finite, values, 0.0), size=size, mode='nearest', origin=0)
    frac = uniform_filter1d(finite.astype(float), size=size, mode='nearest', origin=0)
    out = np.full(values.shape, np.nan)
    np.divide(total, frac, out=out, where=finite & (frac > 0))
    return out


def interp_to_pmean(var_profile, pres_dens, pmean_grid, roll=50):
    """Interpolate one density-grid profile onto a mean-pressure grid.

    ``pres_dens`` (mean pressure on each density level) is smoothed with a
    NaN-aware centred moving average of width ``roll``. The smoothed values
    are then used as the abscissa of a PCHIP interpolant of ``var_profile``,
    which is evaluated on ``pmean_grid`` without extrapolation.

    Parameters
    ----------
    var_profile : 1-D array
        Variable values on the density grid.
    pres_dens : 1-D array
        Mean pressure on the same density grid (same length as ``var_profile``).
    pmean_grid : 1-D array-like
        Target pressure levels.
    roll : int
        Moving-average window length.

    Returns
    -------
    np.ndarray
        Float array shaped like ``pmean_grid``. Values are NaN outside the
        sampled pressure range. The result is all-NaN when fewer than 3 valid
        points remain or the smoothed pressures are not strictly increasing.
    """
    # Always a float fallback: np.full_like(pmean_grid, np.nan) inherits an
    # integer grid's dtype and would hold garbage integers instead of NaN.
    all_nan = np.full(np.shape(pmean_grid), np.nan)
    try:
        pmean_smooth = _nan_uniform_filter1d(pres_dens, roll)

        valid = ~np.isnan(var_profile) & ~np.isnan(pmean_smooth)
        var_nonan = var_profile[valid]
        pmean_nonan = pmean_smooth[valid]

        if len(pmean_nonan) < 3:
            return all_nan

        # Raises ValueError unless pmean_nonan is strictly increasing.
        fvar = interpolate.PchipInterpolator(pmean_nonan, var_nonan, extrapolate=False)
        return fvar(pmean_grid)

    except ValueError:
        return all_nan

In [None]:
def get_ds_pmean(ds_rho, pmean_grid, variables, dim1='N_PROF', dim_dens='rho_grid', dim_pmean='pmean', roll=50):
    """Interpolate density-gridded profile variables onto a mean-pressure grid.

    For each name in ``variables`` and each profile along ``dim1``,
    ``interp_to_pmean`` maps the variable from the ``dim_dens`` axis onto
    ``pmean_grid``. The profile's ``pres_dens`` serves as the vertical
    coordinate.

    Parameters
    ----------
    ds_rho : xr.Dataset
        Profiles on a density grid. Must contain every name in ``variables``,
        a ``pres_dens`` variable on the same density axis, and positions stored
        as either ``lat`` / ``lon`` or ``LATITUDE`` / ``LONGITUDE``.
    pmean_grid : array-like
        1-D target mean-pressure levels.
    variables : iterable of str
        Names of the ``ds_rho`` variables to interpolate.
    dim1 : str
        Profile dimension, kept in the output.
    dim_dens : str
        Density dimension consumed by the interpolation.
    dim_pmean : str
        Name of the output mean-pressure dimension.
    roll : int
        Moving-average window passed to ``interp_to_pmean``.

    Returns
    -------
    xr.Dataset
        One variable per entry of ``variables``. The Dataset has a
        ``dim_pmean`` coordinate plus ``lat`` / ``lon`` coordinates along ``dim1``.
    """
    pmean_grid = np.asarray(pmean_grid)
    pmean_size = pmean_grid.size

    xrs = []
    for var in variables:
        var_interp = xr.apply_ufunc(
            interp_to_pmean, ds_rho[var], ds_rho['pres_dens'],
            # Core dimension is the density axis named by dim_dens; a
            # hard-coded 'density' is not a dimension of ds_rho.
            input_core_dims=[[dim_dens], [dim_dens]],
            output_core_dims=[[dim_pmean]],
            dask_gufunc_kwargs={'output_sizes': {dim_pmean: pmean_size}},
            vectorize=True,
            dask='parallelized',
            kwargs={'pmean_grid': pmean_grid, 'roll': roll},
            output_dtypes=[ds_rho[var].dtype],
        )
        print(f'Completed comp for {var}')
        # Label the new axis and name the result. The coordinates along dim1
        # are carried over from ds_rho by apply_ufunc.
        var_interp = var_interp.assign_coords({dim_pmean: pmean_grid}).rename(var)
        xrs.append(var_interp)
        print(f'Completed {var}')

    ds_pmean = xr.merge(xrs)
    print('Completed merge')

    # Elsewhere in this notebook, ds_rho stores positions as LATITUDE /
    # LONGITUDE; fall back to those when lower-case lat / lon are absent.
    lat_name = 'lat' if 'lat' in ds_rho.variables else 'LATITUDE'
    lon_name = 'lon' if 'lon' in ds_rho.variables else 'LONGITUDE'
    ds_pmean = ds_pmean.assign_coords(
        lat=(dim1, ds_rho[lat_name].data),
        lon=(dim1, ds_rho[lon_name].data),
    )
    #ds_pmean = ds_pmean.assign_coords(TIME=(dim1, ds_rho.TIME.data))

    return ds_pmean