This notebook explores the ocean model simulation used as an input to the generator.
We make use of the Southern Ocean Regionally Refined Mesh v 2.1 run of the E3SM / MPAS-Ocean model, a 1000 year simulation of the ocean circulation, provided at a monthly resolution.
The notebook uses `dask` to chunk this dataset throughout the workflow to enable scalable computation.

In [90]:
import sys
import os
os.environ['USE_PYGEOS'] = '0'
import gc
import collections
from pathlib import Path

import cartopy.crs as ccrs
import cartopy
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib import rcParams, cycler
from matplotlib import animation, rc
from matplotlib.gridspec import GridSpec
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
import geopandas as gpd

import numpy as np
import xarray as xr
from xeofs.xarray import EOF
import rioxarray

import dask
import distributed

import scipy
from scipy import signal
import cftime
from shapely.geometry import mapping
from xarrayutils.utils import linear_trend, xr_linregress
import pandas as pd

In [None]:
client = distributed.Client()
client

In [88]:
# File path directories

# inDirName = os.path.abspath(os.path.join(os.getcwd(), os.pardir))

# Get full path of the aislens_emulation directory. All file IO is relative to this path.
main_dir = Path.cwd().parent

dir_ext_data = 'data/external/'
dir_interim_data = 'data/interim/'

# DATASET FILEPATHS
# Ocean model output - E3SM (SORRMv2.1.ISMF), data received from Darin Comeau / Matt Hoffman at LANL
DIR_SORRMv21 = 'data/external/SORRMv2.1.ISMF/regridded_output/'

# INTERIM GENERATED FILEPATHS
DIR_SORRMv21_Interim = 'data/interim/SORRMv2.1.ISMF/iceShelves_dedraft/'

# DATA FILENAMES
FILE_SORRMv21 = 'Regridded_SORRMv2.1.ISMF.FULL.nc'

# Ice shelf basin/catchment definitions
FILE_iceShelvesShape = 'iceShelves.geojson'

In [74]:
SORRMv21 = xr.open_dataset(inDirName+DIR_SORRMv21+FILE_SORRMv21, chunks={"Time":36})

ICESHELVES_MASK = gpd.read_file(inDirName+DIR_external+FILE_iceShelvesShape)
icems = ICESHELVES_MASK.to_crs({'init': 'epsg:3031'});
crs = ccrs.SouthPolarStereo();

  in_crs_string = _prepare_from_proj_string(in_crs_string)


In [75]:
flux = SORRMv21.timeMonthly_avg_landIceFreshwaterFlux
# ssh = SORRMv21.timeMonthly_avg_ssh

### Detrend and deseasonalize data

In [None]:
# Linear trend, if any
# Debug for dask implementation, a "consistent source of headaches".
# https://ncar.github.io/esds/posts/2022/dask-debug-detrend/

def detrend_dim(data, dim, deg):
    # detrend along a single dimension
    p = data.polyfit(dim=dim, deg=deg)
    fit = xr.polyval(data[dim], p.polyfit_coefficients)
    return data - fit

In [None]:
p = flux.polyfit(dim='Time',deg=1)
p.polyfit_coefficients

In [None]:
fit = xr.polyval(flux['Time'], p.polyfit_coefficients)

In [None]:
dim = 'Time'
data = flux
chunked_dim = xr.DataArray(dask.array.from_array(data[dim].data, chunks=data.chunksizes[dim]), dims=dim, name=dim)

In [76]:
# polyval function defined below does not work for a Datetime vector time dimension, only for integers. 
# TODO: Modify function to convert that to a vector of floats or ints.

def polyval(coord, coeffs, degree_dim="degree"):
    x = np.array(range(12000))# coord.data

    deg_coord = coeffs[degree_dim]
    N = int(deg_coord.max()) + 1

    lhs = xr.DataArray(
        np.stack([x ** (N - 1 - i) for i in range(N)], axis=1),
        dims=(coord.name, degree_dim),
        coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
    )
    return (lhs * coeffs).sum(degree_dim)


# Function to detrend
# Modified from source: https://gist.github.com/rabernat/1ea82bb067c3273a6166d1b1f77d490f
def detrend_dim(da, dim, deg=1):
    """detrend along a single dimension."""
    # calculate polynomial coefficients
    p = da.polyfit(dim=dim, deg=deg, skipna=False)
    # first create a chunked version of the "ocean_time" dimension
    chunked_dim = xr.DataArray(dask.array.from_array(da[dim].data, chunks=da.chunksizes[dim]), dims=dim, name=dim)
    fit = polyval(chunked_dim, p.polyfit_coefficients)
    # evaluate trend
    # remove the trend
    return da - fit

In [77]:
flux_detrend = detrend_dim(flux,"Time")

In [79]:
flux_detrend_computed = flux_detrend.compute()

Key:       ('solve-triangular-d0271ab970130deee2ab1b77776e01e4', 0, 0)
Function:  solve_triangular_safe
args:      (array([[-1.        ,  0.98123351],
       [ 0.        , -0.19282321]]), array([[0., 0., 0., ..., 0., 0., 0.],
       [0., 0., 0., ..., 0., 0., 0.]]))
kwargs:    {}
Exception: "ValueError('array must not contain infs or NaNs')"



ValueError: array must not contain infs or NaNs

In [None]:
"""
import scipy.signal as sps
import scipy.linalg as spl


def detrend(da, dim, detrend_type="constant"):
    """
    Detrend a DataArray

    Parameters
    ----------
    da : xarray.DataArray
        The data to detrend
    dim : str or list
        Dimensions along which to apply detrend.
        Can be either one dimension or a list with two dimensions.
        Higher-dimensional detrending is not supported.
        If dask data are passed, the data must be chunked along dim.
    detrend_type : {'constant', 'linear'}
        If ``constant``, a constant offset will be removed from each dim.
        If ``linear``, a linear least-squares fit will be estimated and removed
        from the data.

    Returns
    -------
    da : xarray.DataArray
        The detrended data.

    Notes
    -----
    This function will act lazily in the presence of dask arrays on the
    input.
    """

    if dim is None:
        dim = list(da.dims)
    else:
        if isinstance(dim, str):
            dim = [dim]

    if detrend_type not in ["constant", "linear", None]:
        raise NotImplementedError(
            "%s is not a valid detrending option. Valid "
            "options are: 'constant','linear', or None." % detrend_type
        )

    if detrend_type is None:
        return da
    elif detrend_type == "constant":
        return da - da.mean(dim=dim)
    elif detrend_type == "linear":
        data = da.data
        axis_num = [da.get_axis_num(d) for d in dim]
        chunks = getattr(data, "chunks", None)
        if chunks:
            axis_chunks = [data.chunks[a] for a in axis_num]
            print(axis_chunks)
            if not all([len(ac) == 1 for ac in axis_chunks]):
                raise ValueError("Contiguous chunks required for detrending.")
        if len(dim) == 1:
            dt = xr.apply_ufunc(
                sps.detrend,
                da,
                axis_num[0],
                output_dtypes=[da.dtype],
                dask="parallelized",
            )
        elif len(dim) == 2:
            dt = xr.apply_ufunc(
                _detrend_2d_ufunc,
                da,
                input_core_dims=[dim],
                output_core_dims=[dim],
                output_dtypes=[da.dtype],
                vectorize=True,
                dask="parallelized",
            )
        else:  # pragma: no cover
            raise NotImplementedError(
                "Only 1D and 2D detrending are implemented so far."
            )

    return dt


def _detrend_2d_ufunc(arr):
    assert arr.ndim == 2
    N = arr.shape

    col0 = np.ones(N[0] * N[1])
    col1 = np.repeat(np.arange(N[0]), N[1]) + 1
    col2 = np.tile(np.arange(N[1]), N[0]) + 1
    G = np.stack([col0, col1, col2]).transpose()

    d_obs = np.reshape(arr, (N[0] * N[1], 1))
    m_est = np.dot(np.dot(spl.inv(np.dot(G.T, G)), G.T), d_obs)
    d_est = np.dot(G, m_est)
    linear_fit = np.reshape(d_est, N)
    return arr - linear_fit
"""

In [None]:
# Deseasonalize
# Remove climatologies to isolate anomalies / deseasonalize
def deseasonalize(data):
    data_month = data.groupby("Time.month")
    data_clm = data_month.mean("Time") # Climatologies
    data_anm = data_month - data_clm # Deseasonalized anomalies
    return data_anm

In [None]:
flux_detrend = detrend_dim(flux,"Time",1)
# flux_clean = deseasonalize(flux_detrend)

In [None]:
# flux_detrend = flux_detrend.compute()

### Temporal Standard Deviation

In [None]:
# Standard deviation in time
flux_std = flux.std('Time').compute()
flux_std.where(flux_std!=0).plot()

### Temporal Mean

In [None]:
# Time mean
flux_tmean = flux.mean('Time').compute()
flux_tmean.where(flux_std!=0).plot()

### Cumulative melt rate (across the ice sheet)

In [None]:
flux_ts = flux.sum(['x','y']).compute()

In [None]:
plt.figure(figsize=(25,8))
flux_ts.plot()
plt.xlabel('Time (Simulation years)')
plt.title('Freshwater Flux - AIS Cumulative')

### Mean freshwater flux in each catchment

In [None]:
def clip_data(total_data, basin):
    clipped_data = total_data.rio.clip(icems.loc[[basin],'geometry'].apply(mapping))
    #clipped_data = clipped_data.dropna('time',how='all')
    #clipped_data = clipped_data.dropna('y',how='all')
    #clipped_data = clipped_data.dropna('x',how='all')
    # clipped_data = clipped_data.drop("month")
    return clipped_data

In [None]:
mean_flux = np.empty(len(icems[33:133]))

for i in range(33,133):
    clip_ds = clip_data(flux_tmean, i)
    mean_flux[i-33] = clip_ds.mean()

In [None]:
# flux_tmean.rio.write_crs("epsg:3031",inplace=True);

In [None]:
plt.figure(figsize=(25,8))
plt.plot(mean_flux, marker='x', lw=0.0)

