# ADD attrs

In [1]:
import numpy as np
import pandas as pd
import xarray as xr
import pyproj
from rasterio.transform import Affine

import os
from glob import glob

import histlib.box as box
import histlib.aviso as aviso
import histlib.cstes as cstes
#import histlib.diagnosis as diag
import histlib.erastar as eras

from dask.distributed import wait
from histlib.cstes import labels, zarr_dir



In [2]:
if True:
    from dask.distributed import Client
    from dask_jobqueue import PBSCluster
    from dask import config
    config.set({"distributed.comm.timeouts.connect": "200s"})
    cluster = PBSCluster(cores=2, processes=2, walltime='04:00:00')
    #cluster = PBSCluster(cores=20, processes=20, walltime='02:00:00')#8
    w = cluster.scale(jobs=1)
else:
    from dask.distributed import Client, LocalCluster
    cluster = LocalCluster()

client = Client(cluster)
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.PBSCluster
Dashboard: http://10.148.1.70:8787/status,

0,1
Dashboard: http://10.148.1.70:8787/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://10.148.1.70:42148,Workers: 0
Dashboard: http://10.148.1.70:8787/status,Total threads: 0
Started: Just now,Total memory: 0 B


# BASIC MATCHUP

In [17]:
def add_ggx_attrs(ds_data):
    listv = [l for l in list(ds_data.variables) if 'sla' in l]+['alti_mdt','alti_ocean_tide', 'alti_dac', 'alti_internal_tide']
    listv = [l for l in listv if 'gg' not in l]
    for v in listv :
        ds_data[v.replace('alti', 'alti_ggx')].attrs['comment'] = ds_data[v].attrs['comment']
        ds_data[v.replace('alti', 'alti_ggx')].attrs['units'] = r'$m.s^{-2}$'
        ds_data[v.replace('alti', 'alti_ggx')].attrs['long_name']= r'$g\partial_x$'+v.replace('alti_','')

def add_adt_to_ds_data(ds_data):
    add_ggx_attrs(ds_data)
    ds_data = ds_data.rename({'drifter_acc_x':'drifter_acc_x_0', 'drifter_acc_y':'drifter_acc_y_0', 'drifter_coriolis_x':'drifter_coriolis_x_0', 'drifter_coriolis_y':'drifter_coriolis_y_0'})
    for sla in ['alti_ggx_sla_filtered','alti_ggx_sla_unfiltered','alti_ggx_sla_unfiltered_denoised','alti_ggx_sla_unfiltered_imf1']:
        ds_data[sla.replace('sla', 'adt')] = ds_data[sla] + ds_data.alti_ggx_mdt
        ds_data[sla.replace('sla', 'adt')].attrs['comment'] = ds_data[sla].attrs['comment']
        ds_data[sla.replace('sla', 'adt')].attrs['units'] = r'$m.s^{-2}$'
        ds_data[sla.replace('sla', 'adt')].attrs['long_name']= ds_data[sla].attrs['long_name'].replace('sla', 'adt')
    return ds_data

def change_obs_coords(ds) :
    o = np.array(ds['obs']).astype('U')
    L = np.full_like(o, l+'__')
    ob = xr.DataArray(np.char.add(L, o), dims='obs')
    ds = ds.drop('obs')
    ds = ds.assign_coords({'obs':ob})
    return ds
    
import histlib.stress_to_windterm as stw
_data_var = [
    "f",
    "box_theta_lon",
    "__site_matchup_indice",
    "box_theta_lat",
    "drifter_theta_lon",
    "drifter_theta_lat",
    "drifter_typebuoy",
    "alti___distance",
    "alti___time_difference",
    'alti_ggx_dac',
    'alti_ggx_internal_tide',
    'alti_ggx_mdt',
    'alti_ggx_ocean_tide',
    'alti_ggx_sla_filtered',
    'alti_ggx_sla_unfiltered',
    'alti_ggx_sla_unfiltered_denoised',
    #'alti_ggx_sla_unfiltered_imf1',
    'alti_ggx_adt_filtered',
    'alti_ggx_adt_unfiltered',
    'alti_ggx_adt_unfiltered_denoised',
    #'alti_ggx_adt_unfiltered_imf1',
    "drifter_vx",
    "drifter_vy",
    "drifter_acc_x_0",
    "drifter_acc_y_0",
    "drifter_coriolis_x_0",
    "drifter_coriolis_y_0",
    
]
_aviso_var = [
    "aviso_alti_matchup_ggx_adt",
    "aviso_alti_matchup_ggy_adt",
    "aviso_drifter_matchup_ggx_adt",
    "aviso_drifter_matchup_ggy_adt",
    "aviso_alti_matchup_ggx_sla",
    "aviso_alti_matchup_ggy_sla",
    "aviso_drifter_matchup_ggx_sla",
    "aviso_drifter_matchup_ggy_sla",
]
_stress_var = [
    "e5_alti_matchup_taue",
    "e5_alti_matchup_taun",
    "es_alti_matchup_taue",
    "es_alti_matchup_taun",
    "e5_drifter_matchup_taue",
    "e5_drifter_matchup_taun",
    "es_drifter_matchup_taue",
    "es_drifter_matchup_taun",
]


list_wd_srce_suffix = ["es", "e5"]
list_func = [stw.cst_rio_z0, stw.cst_rio_z15]
list_func_suffix = ["cstrio_z0", "cstrio_z15"]


In [28]:
def matchup_dataset_one(l):
    ds_data = xr.open_zarr(os.path.join(zarr_dir,'test', f'{l}.zarr')).chunk({'obs':5}).persist()
    ds_data = add_adt_to_ds_data(ds_data)
    ds_aviso = xr.open_zarr(os.path.join(zarr_dir,'test', f'aviso_{l}.zarr')).chunk({'obs':5}).persist()
    ds_stress = xr.open_zarr(os.path.join(zarr_dir,'test', f'erastar_{l}.zarr')).chunk({'obs':5}).persist()
        # SELECT MATCHUP
    
    # COLOCALIZATIONS DATA
    drogue_status = ds_data.time<ds_data.drifter_drogue_lost_date.mean('site_obs')
    ds_data = ds_data[_data_var].reset_coords(['drifter_lat','drifter_lon','drifter_time','drifter_x','drifter_y',])
    _ds_data = ds_data.where(ds_data.site_obs == ds_data.__site_matchup_indice).sum('site_obs')# site_obs dimension (__site_matchup_indice not teh same for all, need where
    _ds_data = _ds_data.isel(alti_time_mid=ds_data.dims['alti_time_mid']//2).drop(["alti_time_mid", "alti_x_mid", "alti_y_mid"]) #alti_matchup
    for v in _ds_data.variables:
        _ds_data[v].attrs = ds_data[v].attrs
    _ds_data['drogue_status'] = drogue_status.assign_attrs({'long_name':'drogue status', 'description':'True if drogued, False if undrogued (day precision only)'})
    
    # AVISO
    _ds_aviso = ds_aviso[_aviso_var]

    #ERASTAR
    _ds_stress = ds_stress[_stress_var]

    # FOR IND PDFS
    _ds = xr.merge([_ds_data, _ds_aviso, _ds_stress])

    # COMPUTE WD TERM
    _ds = xr.merge(
        [
            _ds,
            stw.compute_wd_from_stress(
                _ds, list_wd_srce_suffix, list_func, list_func_suffix, False
            ),
        ]
    )
    # CLEANING : drop useless variables
    _ds = _ds.drop(
        _stress_var
        + [
            "f",
            "box_theta_lon",
            "box_theta_lat",
            "drifter_theta_lon",
            "drifter_theta_lat",
            "__site_matchup_indice",
        ]
    ).set_coords(["alti___distance", "alti___time_difference"])
    _ds = _ds.rename({v: v.replace("_matchup", "") for v in _ds})
    _ds = change_obs_coords(_ds)  
    _ds = _ds.drop(['box_x', 'box_y'])
    return _ds

In [29]:
l='gps_Jason-3_2020'
ds = matchup_dataset_one(l).compute()
ds

# CREATE COMB

In [36]:
def combinations(_ds, wd_x=None, wd_y=None, grad_x=None, grad_y=None, cutoff=None):
    """Create a list of dictionnaries containing the different data combinations possible to rebuild the moment conservation

    Parameters
    ----------
    _ds: dataset
        contains - drifter_acc_x/y,
                 - drifter_coriolisx/y,
                 - sla gradients from the different sources all fiishing with '_ggx/y',
                 - wind terms from the different sources and way to compute it from stress all finishing with '_wd_x/y'
    Returns
    ----------
    [{'acc': 'drifter_acc_x','coriolis': 'drifter_coriolis_x','ggx': 'alti_ggx','wind': 'es_cstrio_z0_alti_wd_x','id': 'co_es_cstrio_z0_alti_x'},....]
    list of dictionnaries containing the varaibles taken for each term and an identification: gradsrc_wdsrc_wdmethod_wddepth_matchupposition_x/y
    """
    if not wd_x : wd_x = [l for l in _ds if "wd_x" in l]
    if not wd_y : wd_y = [l for l in _ds if "wd_y" in l]
    if not grad_x : grad_x = [l for l in _ds if "ggx_adt" in l or "ggx_sla" in l]
    if not grad_y : grad_y = [l for l in _ds if "ggy_adt" in l or "ggy_sla" in l]
    if not cutoff : cutoff = [l.split('acc_x_')[-1] for l in _ds if "acc_x" in l]
    
    LIST = []

    for cf in cutoff :
        for grad in grad_x :
            for wd in wd_x :
                lx = {
                    "acc": 'drifter_acc_x_'+cf,
                    "coriolis": 'drifter_coriolis_x_'+cf,
                    "ggrad": "",
                    "wind": "",
                    "id": "",
                }
                # AVISO grad
                if "aviso" in grad:  #
                    if ("alti" in grad) and ("alti" in wd):
                        lx["ggrad"] = grad
                        lx["wind"] = wd
                        lx["id"] = (
                            "aviso__"
                            +cf+'__'
                            +grad[-3:]
                            +'__'
                            +"_".join(wd.split("_")[:3])
                            + "__alti_x"
                        )
                        LIST.append(lx)
                    elif ("drifter" in grad) and ("drifter" in wd):
                        lx["ggrad"] = grad
                        lx["wind"] = wd
                        lx["id"] = (
                            "aviso__"
                            +cf+'__'
                            +grad[-3:]
                            +'__'
                            + "_".join(wd.split("_")[:3])
                            + "__drifter_x"
                        )
                        LIST.append(lx)
                # Altimeters' grad
                elif "alti" in grad:
                    if "alti" in wd:
                        lx["ggrad"] = grad
                        lx["wind"] = wd
                        lx["id"] = (
                            "co__"
                            +cf+'_'
                            + grad.replace("alti_", "").replace("ggx", "")
                            +'__'
                            + "_".join(wd.split("_")[:3])
                            + "__alti_x"
                        )
                        LIST.append(lx)
                    elif "drifter" in wd:
                        lx["ggrad"] = grad
                        lx["wind"] = wd
                        lx["id"] = (
                            "co__"
                            +cf+'_'
                            + grad.replace("alti_", "").replace("ggx", "")
                            +'__'
                            + "_".join(wd.split("_")[:3])
                            + "__drifter_x"
                        )
                        LIST.append(lx)
        for grad in grad_y:
            for wd in wd_y:
                ly = {
                    "acc": "drifter_acc_y_"+cf,
                    "coriolis": "drifter_coriolis_y_"+cf,
                    "ggrad": "",
                    "wind": "",
                    "id": "",
                }
    
                if ("alti" in grad) and ("alti" in wd):
                    ly["ggrad"] = grad
                    ly["wind"] = wd
                    ly["id"] = (
                        "aviso__"
                        +cf+'__'
                        +grad[-3:]
                        + "__"
                        + "_".join(wd.split("_")[:3])
                        + "__alti_y"
                    )
                    LIST.append(ly)
    
                elif ("drifter" in grad) and ("drifter" in wd):
                    ly["ggrad"] = grad
                    ly["wind"] = wd
                    ly["id"] = (                            
                        "aviso__"
                        +cf+'__'
                        +grad[-3:]
                        + "__"
                        + "_".join(wd.split("_")[:3])
                        + "__drifter_y"
                    )
                    LIST.append(ly)
    return LIST

In [58]:
def datasets_for_pdfs(
    ds_matchup,
    sum_=False,
    except_=False,
    wd_x=None,
    wd_y=None,
    grad_x=None,
    grad_y=None,
    cutoff=None
):
    """Create a list of dictionnaries containing the different data combinations possible to rebuild the moment conservation

    Parameters
    ----------
    ds_data: dataset
            dataset containing colocalisations, should contain at least the _data_var
    ds_aviso: dataset
            dataset containing aviso sla gradient terms
    ds_stress:dataset
            dataset containing wind stress that will be used to compute the wind stress divergence term
    sum_: bool
            if true create a dataset with sum values of all the different combinations found with the combination function
    except_: bool
            if true create a dataset with sum values except one term for all the different combinations found with the combination function
    _data_var: str list
            list of variables to extract from ds_data, should contains at least ['f','box_theta_lon','__site_matchup_indice','box_theta_lat','drifter_theta_lon', 'drifter_theta_lat', 'alti___distance','alti___time_difference','drifter_acc_x', 'drifter_acc_y', 'drifter_coriolis_x', 'drifter_coriolis_y']
    _aviso_var: str list
            list of variables to extract from ds_aviso
    _stress_var: str list
            list of variables to extract from ds_stress
    list_wd_srce_suffix : str list
            list of wind stress term source suffix, ex : if we want only erastar wind stress list_wd_srce_suffix= ['es'], default is all sources
    list_func : function list
            list of functions to compute wind term from wind stress, these functions have to take the functions in the stress_to_windterm.py library as a model
    list_func_suffix : str list
            suffix to put in wind term variable name to identify the function used to compute wind term from wind stress, should correspond to list_func

    """
    
    # SUM combination
    if not wd_x : wd_x = [l for l in ds_matchup if "wd_x" in l]
    if not wd_y : wd_y = [l for l in ds_matchup if "wd_y" in l]
    if not grad_x : grad_x = [l for l in ds_matchup if "ggx_adt" in l or "ggx_sla" in l]
    if not grad_y : grad_y = [l for l in ds_matchup if "ggy_adt" in l or "ggy_sla" in l]
    if not cutoff : cutoff = [l.split('acc_x_')[-1] for l in ds_matchup if "acc_x" in l]
        
    COMB = combinations(ds_matchup, wd_x, wd_y, grad_x, grad_y, cutoff)
    ds_matchup = ds_matchup[wd_x + wd_y+ grad_x + grad_y + ["drifter_acc_x_"+cf for cf in cutoff]+["drifter_acc_y_"+cf for cf in cutoff]+["drifter_coriolis_x_"+cf for cf in cutoff]+["drifter_coriolis_y_"+cf for cf in cutoff]]
    _ds_sum = xr.Dataset()
    _ds_except = xr.Dataset()

    id_comb_list = []
    for comb in COMB:
        _id = comb["id"]
        id_comb_list.append(_id)
        comb.pop("id")
        ds_matchup["id_comb"] = id_comb_list

        if sum_:
            # TOTAL SUM
            S = 0
            print(comb.values())

            for l in list(comb.values()):
                print(l)
                S = S + ds_matchup[l]
            id_str = "sum_" + _id
            _ds_sum[id_str] = xr.DataArray(
                data=S,
                attrs={
                    "description": "+".join(comb.keys()),
                    "long_name": "+".join(
                        [ds_matchup[comb[v]].attrs["long_name"] for v in comb.keys()]
                    ),
                    "units": r"$m.s^{-2}$",
                    **comb,
                },
            )
            _ds_sum["id_comb"] = id_comb_list

        if except_:
            # EXCEPT ONE
            for except_key in list(comb.keys()):
                id_str_2 = "exc_" + except_key + "_" + _id
                S2 = 0
                keys = [l for l in comb.keys() if l != except_key]
                for key in keys:
                    if key != except_key:
                        S2 = S2 + ds_matchup[comb[key]]
                _ds_except[id_str_2] = xr.DataArray(
                    data=S2,
                    attrs={
                        "description": "+".join(keys),
                        "long_name": "+".join(
                            [ds_matchup[comb[v]].attrs["long_name"] for v in keys]
                        ),
                        "units": r"$m.s^{-2}$",
                        **comb,
                    },
                )
                _ds_except["id_comb"] = id_comb_list

    DS = [ds_matchup]
    if except_:
        DS.append(_ds_except)
    if sum_:
        DS.append(_ds_sum)
    return xr.merge(DS)


In [59]:
dsm = datasets_for_pdfs(ds, sum_=True, except_=True, )

dict_values(['drifter_acc_x_0', 'drifter_coriolis_x_0', 'alti_ggx_sla_filtered', 'es_cstrio_z0_alti_wd_x'])
drifter_acc_x_0
drifter_coriolis_x_0
alti_ggx_sla_filtered
es_cstrio_z0_alti_wd_x
dict_values(['drifter_acc_x_0', 'drifter_coriolis_x_0', 'alti_ggx_sla_filtered', 'es_cstrio_z0_drifter_wd_x'])
drifter_acc_x_0
drifter_coriolis_x_0
alti_ggx_sla_filtered
es_cstrio_z0_drifter_wd_x
dict_values(['drifter_acc_x_0', 'drifter_coriolis_x_0', 'alti_ggx_sla_filtered', 'e5_cstrio_z0_alti_wd_x'])
drifter_acc_x_0
drifter_coriolis_x_0
alti_ggx_sla_filtered
e5_cstrio_z0_alti_wd_x
dict_values(['drifter_acc_x_0', 'drifter_coriolis_x_0', 'alti_ggx_sla_filtered', 'e5_cstrio_z0_drifter_wd_x'])
drifter_acc_x_0
drifter_coriolis_x_0
alti_ggx_sla_filtered
e5_cstrio_z0_drifter_wd_x
dict_values(['drifter_acc_x_0', 'drifter_coriolis_x_0', 'alti_ggx_sla_filtered', 'es_cstrio_z15_alti_wd_x'])
drifter_acc_x_0
drifter_coriolis_x_0
alti_ggx_sla_filtered
es_cstrio_z15_alti_wd_x
dict_values(['drifter_acc_x_0', 'dri

In [None]:
def store_base_matchup_datasets_for_pdfs(labels=cstes.labels, zarr_dir=cstes.zarr_dir,
                      **kwargs
                     ):
    for l in labels :
        ds_data = xr.open_zarr(os.path.join(zarr_dir, f'{l}.zarr')).chunk({'obs':500})
        ds_data = add_adt_to_ds_data(ds_data)
        ds_aviso = xr.open_zarr(os.path.join(zarr_dir, f'aviso_{l}.zarr')).chunk({'obs':500})
        ds_stress = xr.open_zarr(os.path.join(zarr_dir, f'erastar_{l}.zarr')).chunk({'obs':500})
    
        #DATASET FOR PDF
        ds = datasets_for_pdfs(ds_data, ds_aviso, ds_stress, **kwargs)
        ds = ds.where(ds.alti___time_difference<=1800, drop=True)
        ds = change_obs_coords(ds)
    if l==labels[0]
        ds.to_zarr(os.path.join(zarr_dir_ok,f'matchup.zarr'), mode='w')
    else : 
        ds.to_zarr(os.path.join(zarr_dir_ok,f'matchup.zarr'), append_dim='obs')    

________
# Test matchup selection on ds_data
CAUTION : ds.__site_matchup_indice not the same for all dimensions ! -> need .where and then .sum() to reduce site_obs dim

In [46]:
ds = ds_data[_data_var].isel(obs=slice(0,2))
ds = ds.reset_coords(['drifter_lat','drifter_lon','drifter_time','drifter_x','drifter_y',])
ds['__site_matchup_indice'] = xr.DataArray([0,1], dims='obs')# change for easiest visualisation
dw = ds.where(ds.site_obs == ds.__site_matchup_indice).compute()
dwm =dw.sum('site_obs')

In [47]:
dw

In [48]:
dwm

In [30]:
cluster.close()