<br>

# Compute Salinity Index
***

In [1]:
import xarray as xr
import scipy.signal as sg
import numpy as np

<br>

### Define Functions
***

In [2]:
def centerlon(data):
    """Re-center a 0..360 longitude grid on the prime meridian.

    Points with ``lon >= 180`` are moved in front of points with ``lon < 180``
    along the ``lon`` dimension. The ``lon`` coordinate is then wrapped into
    the -180..180 range via ``((lon + 180) % 360) - 180``.
    """
    lons = data.lon
    # lon >= 180 maps to negative (western) longitudes after wrapping, so it goes first.
    west_part = data.sel(lon=lons[lons >= 180])
    east_part = data.sel(lon=lons[lons < 180])
    rolled = xr.concat([west_part, east_part], dim="lon")

    wrapped_lon = ((rolled.lon + 180) % 360) - 180
    return rolled.assign_coords({"lon": wrapped_lon})

In [3]:
def timeplace(data, lat_s, lat_e, lon_s, lon_e, year_s, year_e):
    """Crop ``data`` to a lat/lon box and to the years ``year_s``..``year_e``.

    Latitude and longitude are selected by label slices, so the bounds must
    follow the coordinate order. The year range is inclusive at both ends.
    """
    years = data["time.year"]
    in_period = (years >= year_s) & (years <= year_e)
    return data.sel(
        lat=slice(lat_s, lat_e),
        lon=slice(lon_s, lon_e),
        time=in_period,
    )

In [4]:
def detrend_anoms(data):
    """Remove the monthly climatology from ``data`` and linearly detrend it in time.

    Parameters
    ----------
    data : xarray.DataArray
        Field with a ``time`` dimension, plus any other dimensions
        (e.g. ``lat``/``lon``).

    Returns
    -------
    xarray.DataArray
        Anomalies relative to each calendar month's mean, with the linear
        trend along ``time`` removed. A grid point with any missing value in
        its time series is set to NaN at every time step, because
        ``scipy.signal.detrend`` cannot handle gaps.
    """
    ### Subtract climatology (mean of each calendar month over all years)
    climatology = data.groupby("time.month").mean("time")
    anom_d = data.groupby("time.month") - climatology

    ### Detrend data
    # Only grid points with a complete time series are kept. Partly and fully
    # missing points both end up all-NaN.
    complete = anom_d.notnull().all("time")
    # Fill gaps temporarily so detrend can run on the whole array at once.
    # The filled points are masked out again by ``where`` below.
    detrended = sg.detrend(
        anom_d.fillna(0).values, axis=anom_d.get_axis_num("time")
    )
    return anom_d.copy(data=detrended).where(complete)

In [5]:
def seasonal_mean(data, month_s, month_e):
    """Average ``data`` over months ``month_s``..``month_e`` within each year.

    The month range is inclusive and does not wrap around the year end, so
    ``month_s`` must be <= ``month_e``. The ``time`` dimension is replaced
    by ``year``.
    """
    months = data["time.month"]
    in_season = (months >= month_s) & (months <= month_e)
    season = data.sel(time=in_season)
    return season.groupby('time.year').mean('time')

<br>

### Set Years
---

In [6]:
# Analysis period (inclusive calendar years)
year_s, year_e = 1901, 2017

<br>

## Salinity
---
---

### Set Region and Month

In [7]:
# Study region (degrees; longitudes in the -180..180 convention)
lat_s = -40
lat_e = 60

lon_s = -60
lon_e = 50

# Season: calendar months month_s..month_e inclusive (1 = January)
month_s = 3  # March
month_e = 5  # May
# NOTE(review): 3..5 selects March-May (MAM). If April-June was intended,
# use month_s = 4 and month_e = 6 instead.

### Import Data

In [8]:
# Load the salinity field into memory and close the underlying file.
with xr.open_dataset('data/da_my_sss.nc') as sss_ds:
    sssa_d = sss_ds.salinity.load()

### Process Data

In [9]:
# Pipeline:
#   center on the 0-meridian
#   -> crop to the region and years
#   -> detrended monthly anomalies
#   -> seasonal (month_s..month_e) mean per year.
# NOTE: this overwrites sssa_d (the output no longer has a `time` dimension),
# so re-run the data-loading cell before re-running this one.
sssa_d = (
    sssa_d
    .pipe(centerlon)
    .pipe(timeplace, lat_s, lat_e, lon_s, lon_e, year_s, year_e)
    .pipe(detrend_anoms)
    .pipe(seasonal_mean, month_s, month_e)
)

sssa_d

<br>

## Box Indices (as suggested by Li et al., 2016)
---

In [10]:
def reg_mean(data, region, nanregion=None, name="sssa_na"):
    """Cos(lat)-weighted mean of ``data`` over a lat/lon box, optionally excluding a sub-box.

    Parameters
    ----------
    data : xarray.DataArray
        Field with ``lat`` and ``lon`` coordinates in degrees. It is not modified.
    region : dict
        Box to average over, with keys ``lat_s``, ``lat_e``, ``lon_s``, ``lon_e``.
    nanregion : dict, optional
        Sub-box (same keys, bounds inclusive) masked out before averaging.
        ``None`` masks nothing.
    name : str, optional
        Name given to the returned DataArray.

    Returns
    -------
    xarray.DataArray
        Weighted box mean, reduced over ``lat`` and ``lon``.
    """
    if nanregion is not None:
        # Build the mask from coordinates and use ``where``, so the caller's
        # array is left untouched (no in-place ``.loc`` assignment).
        excluded = (
            (data.lat >= nanregion['lat_s']) & (data.lat <= nanregion['lat_e'])
            & (data.lon >= nanregion['lon_s']) & (data.lon <= nanregion['lon_e'])
        )
        data = data.where(~excluded)

    box = data.sel(
        lat=slice(region['lat_s'], region['lat_e']),
        lon=slice(region['lon_s'], region['lon_e']),
    )
    # cos(lat) weighting; NaN (masked or missing) points are skipped by the weighted mean.
    weights = np.cos(np.deg2rad(box.lat))
    result = box.weighted(weights).mean(("lat", "lon"))
    result.name = name
    return result

# North Atlantic box: 25-50°N, 50-15°W
region_na = dict(
    lat_s = 25,
    lat_e = 50,
    lon_s = -50,
    lon_e = -15
)

# Sub-box (38-50°N, 50-40°W) excluded from the North Atlantic average
region_nan = dict(
    lat_s = 38,
    lat_e = 50,
    lon_s = -50,
    lon_e = -40
)

sssa_na = reg_mean(sssa_d, region_na, nanregion=region_nan, name="sssa_na")

In [11]:
def reg_mean(data, region):
    """Cos(lat)-weighted mean of ``data`` over the lat/lon box ``region``.

    ``region`` is a dict with keys ``lat_s``, ``lat_e``, ``lon_s``, ``lon_e``.
    The returned DataArray is named ``'sssa_sa'``.

    NOTE(review): this redefines the ``reg_mean`` from the previous cell with
    a different signature; any later call uses this version.
    """
    lat_weights = np.cos(np.deg2rad(data.lat))
    lat_range = slice(region['lat_s'], region['lat_e'])
    lon_range = slice(region['lon_s'], region['lon_e'])

    boxed = data.sel(lat=lat_range, lon=lon_range)
    out = boxed.weighted(lat_weights).mean(("lat", "lon"))
    out.name = 'sssa_sa'
    return out

# South Atlantic box: 22.5-10°S, 42-10°W
region_sa = dict(lat_s=-22.5, lat_e=-10, lon_s=-42, lon_e=-10)

sssa_sa = reg_mean(sssa_d, region_sa)

In [12]:
# Combine both box indices into one dataset, drop the `depth` coordinate, and save.
# NOTE(review): "salinty" in the filename looks like a typo; it is kept as-is
# in case downstream code reads this exact path.
predictors = xr.merge([sssa_na, sssa_sa])
predictors = predictors.reset_coords(names='depth', drop=True)
predictors.to_netcdf("data/da_pred_salinty.nc")