In [None]:
import xarray
import intake
import numpy as np
import pytide
import dask
import dask.array
import dask.distributed

In [None]:
# Analysis window, excluding the model spin-up period.
# Both dates carry day ('D') resolution.
START_DATE = np.datetime64('2011-11-13', 'D')
END_DATE = np.datetime64('2012-11-12', 'D')

In [None]:
# Intake catalog describing the LLC4320 model output.
# NOTE(review): this URL tracks the `master` branch of pangeo-datastore; pin a
# commit or tag if the notebook must stay reproducible.
cat_url = "https://raw.githubusercontent.com/pangeo-data/pangeo-datastore/master/intake-catalogs/ocean/llc4320.yaml"
# `intake.open_catalog` is intake's public entry point for opening a catalog
# from a path or URL (the bare `Catalog` class is the base type, not a loader).
cat = intake.open_catalog(cat_url)

In [None]:
# Load the LLC4320_SSS catalog entry via `to_dask()`. The result is used below
# as an xarray Dataset exposing `.time` and the `SSS` variable (presumably
# sea-surface salinity — confirm against the catalog description).
sss = cat.LLC4320_SSS.to_dask()

In [None]:
# Materialize the model time axis as a NumPy array so it can be filtered with
# a boolean mask.
time_series = sss.time.values

In [None]:
# Boolean mask selecting the time steps in [START_DATE, END_DATE] (both bounds
# inclusive); later cells use it to subset the data.
period = (time_series >= START_DATE) & (time_series <= END_DATE)
selected_times = time_series[period]
# Fail early with a clear message: `.min()` on an empty array otherwise raises
# an unhelpful "zero-size array" reduction error.
assert selected_times.size > 0, "no time steps between START_DATE and END_DATE"
print("number of layers to process %d" % len(selected_times))
print("period [%s, %s]" % (selected_times.min(), selected_times.max()))

In [None]:
# Table of the tidal constituents to estimate in the harmonic analysis.
wave_table = pytide.WaveTable(['M2', 'K1', 'O1', 'P1', 'S1', 'S2'])
print("%d tidal constituents to be analysed" % len(wave_table))

In [None]:
# Start a Dask cluster on Kubernetes with adaptive scaling bounded to 1..400
# workers, and connect a client to it. The bare `cluster` on the last line
# displays the cluster object in the notebook.
# NOTE(review): these imports belong in the top import cell.
from dask_kubernetes import KubeCluster
from dask.distributed import Client
cluster = KubeCluster()
cluster.adapt(minimum=1, maximum=400)
client = Client(cluster)
cluster

In [None]:
def compute_nodal_corrections(client, waves, time_series):
    """Compute nodal corrections for ``waves`` and broadcast them to the cluster.

    Parameters
    ----------
    client : dask.distributed.Client
        Client used to scatter the corrections to every worker.
    waves : pytide.WaveTable
        Constituents whose ``compute_nodal_corrections`` method is called.
    time_series : numpy.ndarray of datetime64
        Dates at which the corrections are evaluated. Any datetime64 unit is
        accepted; the dates are passed to ``waves.compute_nodal_corrections``
        as float seconds since 1970-01-01.

    Returns
    -------
    tuple of dask.array.Array
        ``(f, v0u)``: the transposes of the two arrays returned by
        ``waves.compute_nodal_corrections``, each wrapped as a dask array
        backed by data broadcast to all workers.
    """
    # Normalize to nanoseconds first so the 1e-9 factor always yields
    # seconds, whatever datetime64 unit the caller's array uses.
    epoch_ns = np.asarray(time_series).astype('datetime64[ns]').astype(np.int64)
    t = epoch_ns * 1e-9
    f, v0u = waves.compute_nodal_corrections(t)
    return (_scatter_as_dask_array(client, f.T),
            _scatter_as_dask_array(client, v0u.T))


def _scatter_as_dask_array(client, array):
    """Broadcast ``array`` to every worker and wrap the result as a dask array."""
    scattered = client.scatter(array, broadcast=True)
    return dask.array.from_delayed(scattered,
                                   shape=array.shape,
                                   dtype=array.dtype)

In [None]:
# Nodal corrections over the selected analysis period, broadcast to the workers
# and wrapped as dask arrays.
f, v0u = compute_nodal_corrections(client, wave_table, time_series[period])

In [None]:
# Release objects that the cells shown below no longer use.
# NOTE(review): the period-summary cell reads `time_series`; re-running it after
# this cell raises NameError. Use Restart & Run All instead of re-running cells.
del wave_table
del time_series

In [None]:
def load_faces(ds, face, period, variable="SSS"):
    """Select one face of a variable over a time window, time as last axis.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset holding ``variable`` with dimensions ``face``, ``j``, ``i``
        and ``time``.
    face : int
        Index of the face to extract.
    period : array-like
        Indexer along ``time`` (e.g. a boolean mask) selecting the time steps
        to keep.
    variable : str, optional
        Name of the data variable to extract (default ``"SSS"``).

    Returns
    -------
    array
        The ``.data`` of the selection, ordered ``(j, i, time)`` (a dask array
        when ``ds`` is dask-backed).
    """
    # Put time last so 1-D functions can later be mapped along the final axis.
    selection = ds[variable].transpose("face", "j", "i", "time")
    return selection.isel(face=face, time=period).data

In [None]:
# Extract face 0 over the analysis period as a (j, i, time) array; the bare
# `ds` displays it.
ds = load_faces(sss, 0, period)
ds

In [None]:
def dask_array_rechunk(da, nblocks, axis=2):
    """Compute chunk sizes that split ``da`` into roughly ``nblocks`` blocks.

    ``axis`` is kept in a single chunk so that a 1-D function can later be
    applied along it. Every other axis gets a chunk size of
    ``size // int(sqrt(nblocks))`` (at least 1). For a 3-D array this gives
    about ``nblocks`` blocks in total.

    Parameters
    ----------
    da : dask.array.Array
        Array whose ``chunks`` describe its current layout.
    nblocks : int
        Approximate number of blocks wanted; must be >= 1.
    axis : int, optional
        Axis that stays unchunked; negative values count from the end.

    Returns
    -------
    tuple of int
        One chunk size per dimension, suitable for ``da.rechunk``.

    Raises
    ------
    ValueError
        If ``nblocks`` < 1 or ``axis`` is out of range.
    """
    if nblocks < 1:
        raise ValueError("nblocks must be >= 1, got %r" % (nblocks, ))
    ndim = len(da.chunks)
    if not -ndim <= axis < ndim:
        raise ValueError("axis %r is out of range for %d dimensions" %
                         (axis, ndim))
    axis %= ndim  # accept negative axes such as -1
    div = int(np.sqrt(nblocks))
    # Total length of each dimension, recovered from its chunk sizes.
    sizes = [int(sum(dim_chunks)) for dim_chunks in da.chunks]
    # Clamp to 1: a chunk size of 0 (when div > size) is invalid for rechunk.
    return tuple(size if index == axis else max(1, size // div)
                 for index, size in enumerate(sizes))

In [None]:
# Keep the time axis (2) in a single chunk and split the j and i axes so the
# array has roughly 4800 blocks; the bare `ds` displays the new layout.
ds = ds.rechunk(dask_array_rechunk(ds, 4800))
ds

In [None]:
def _apply_along_axis(arr, func1d, func1d_axis, func1d_args, func1d_kwargs):
    """Wrap apply_along_axis"""
    return np.apply_along_axis(func1d, func1d_axis, arr, *func1d_args,
                                  **func1d_kwargs)


def apply_along_axis(func1d, axis, arr, *args, **kwargs):
    """Lazily apply ``func1d`` to every 1-D slice of a dask array along ``axis``.

    Used here for the harmonic analysis (``func1d`` is
    ``pytide.WaveTable.harmonic_analysis`` with ``args == (f, v0u)``), but the
    helper itself is generic.

    Parameters
    ----------
    func1d : callable
        Called as ``func1d(slice_1d, *args, **kwargs)``; must return an
        array-like whose shape does not depend on the slice contents.
    axis : int
        Axis along which slices are taken; negative values count from the end.
    arr : array-like
        Input data, converted with ``dask.array.core.asarray``.
    *args, **kwargs
        Extra arguments forwarded to every ``func1d`` call.

    Returns
    -------
    dask.array.Array
        Lazy result in which ``axis`` is replaced by the dimensions of
        ``func1d``'s output.

    Raises
    ------
    IndexError
        If ``axis`` is out of range for ``arr``.
    """
    arr = dask.array.core.asarray(arr)

    # Validate and normalize a negative axis.
    ndim = len(arr.shape)
    if not -ndim <= axis < ndim:
        raise IndexError("axis %r is out of range for %d dimensions" %
                         (axis, ndim))
    axis %= ndim

    # Rechunk so that func1d is applied over the full axis.
    arr = arr.rechunk(arr.chunks[:axis] + (arr.shape[axis:axis + 1], ) +
                      arr.chunks[axis + 1:])

    # Probe func1d once on a dummy slice to learn its output shape and dtype.
    # The probe has the length of the real slices (arr.shape[axis]); it does
    # not depend on the layout of args[0], so any func1d works, including one
    # called without extra positional arguments.
    # NOTE(review): if some of `args` are dask arrays, this probe call may
    # materialize them — confirm this is acceptable for large inputs.
    test_data = np.ones(arr.shape[axis], dtype=arr.dtype)
    test_result = np.array(func1d(test_data, *args, **kwargs))

    # Map func1d over the blocks; the sliced axis is dropped and replaced by
    # the axes of func1d's output.
    result = arr.map_blocks(
        _apply_along_axis,
        name=dask.utils.funcname(func1d) + '-along-axis',
        dtype=test_result.dtype,
        chunks=(arr.chunks[:axis] + test_result.shape + arr.chunks[axis + 1:]),
        drop_axis=axis,
        new_axis=list(range(axis, axis + test_result.ndim, 1)),
        func1d=func1d,
        func1d_axis=axis,
        func1d_args=args,
        func1d_kwargs=kwargs,
    )

    return result

In [None]:
# Build the lazy harmonic analysis of every (j, i) time series of `ds` (axis 2
# is time), forwarding the nodal corrections to each call. Despite its name,
# `future` is a lazy dask array, not a distributed Future; nothing runs yet.
future = apply_along_axis(pytide.WaveTable.harmonic_analysis, 2, ds, f, v0u)

In [None]:
# Execute the whole graph on the cluster and gather the full result into local
# memory (presumably as a NumPy array).
analysis = future.compute()

In [None]:
# Move the last axis (the per-slice output of the harmonic analysis) to the
# front: (j, i, k) -> (k, j, i). This assumes harmonic_analysis returns a 1-D
# result per slice; presumably k indexes the tidal constituents — confirm.
analysis = np.transpose(analysis, [2, 0, 1])

In [None]:
# Plotting setup for inspecting the results.
# NOTE(review): `plt` is unused in the cells shown here; this import belongs in
# the top import cell. Recent IPython renders matplotlib inline by default.
import matplotlib.pyplot as plt
%matplotlib inline