In [None]:
# gcsfs is presumably needed so that the "gs://" paths used below can be read;
# its version is displayed to document the environment.
import gcsfs

gcsfs.__version__

In [None]:
from typing import Optional
import datetime
import logging
import re
import os
import dask.dataframe
import dask.distributed
import numpy as np
import pandas as pd
import pyinterp.backends.xarray
import pyinterp.geodetic
import xarray as xr

## Configuration du cluster Dask (Kubernetes)

In [None]:
# Start a Dask cluster on Kubernetes with adaptive scaling between 1 and 10
# workers; the cluster object is displayed as the cell output.
import dask_kubernetes

cluster = dask_kubernetes.KubeCluster()
cluster.adapt(minimum=1, maximum=10)
cluster

In [None]:
# Connect to the cluster created above so that the computations below run on it.
client = dask.distributed.Client(cluster)
client

In [None]:
def get_dask_dataframe(
        dirname: str,
        start: Optional[datetime.date] = None,
        end: Optional[datetime.date] = None,
        index: bool = False,
) -> dask.dataframe.DataFrame:
    """Select the data frame to process between two dates.

    Args:
        dirname: Location of the Parquet dataset. It is expected to be
            partitioned on ``year`` and ``month`` (the partition filters
            below rely on it).
        start: First date selected (inclusive). Defaults to 1995-01-01.
        end: Last date selected (inclusive). A ``datetime.date`` compares as
            midnight, so later times on that day are excluded. Defaults to
            today.
        index: If true, the ``datetime`` column becomes the index.

    Returns:
        The lazy dask DataFrame holding the rows with
        ``start <= datetime <= end``.
    """
    if start is None:
        start = datetime.date(1995, 1, 1)
    if end is None:
        end = datetime.date.today()

    # Partition pruning on (year, month), written in disjunctive normal form:
    # a list of AND-ed conjunctions that are OR-ed together. Filtering year and
    # month independently would keep only the months common to every year
    # (e.g. only January for a 1990-01 -> 2019-01 range).
    if start.year == end.year:
        filters = [[("year", "==", start.year),
                    ("month", ">=", start.month),
                    ("month", "<=", end.month)]]
    else:
        filters = [[("year", "==", start.year), ("month", ">=", start.month)],
                   [("year", ">", start.year), ("year", "<", end.year)],
                   [("year", "==", end.year), ("month", "<=", end.month)]]

    ddf = dask.dataframe.read_parquet(dirname,
                                      engine="pyarrow",
                                      filters=filters)
    # Row-level selection inside the kept partitions.
    ddf = ddf[(ddf.datetime >= start.isoformat())
              & (ddf.datetime <= end.isoformat())]
    if index:
        ddf = ddf.set_index("datetime")
    return ddf

## Sélection géographique

In [None]:
def _select_area(ddf: dask.dataframe.DataFrame, box: pyinterp.geodetic.Box2D):
    """Applies geographic selection to a DataFrame of a partition"""
    return list(
        box.covered_by(ddf.longitude.values, ddf.latitude.values).astype(bool))


def select_area(ddf: dask.dataframe.DataFrame, box: pyinterp.geodetic.Box2D):
    """Applies geographic selection to a DataFrame"""
    return ddf.map_partitions(_select_area, box)

In [None]:
# Path to the Argo Parquet dataset on Google Cloud Storage
path = "gs://pangeo-cnes/argo"

In [None]:
# Lazily open the observations between 1990-01-01 and 2019-02-01. Narrowing
# the period reduces the amount of data read; widening it makes the
# computations below take longer on our virtual machine.
ddf = get_dask_dataframe(
    path,
    datetime.date(1990, 1, 1),
    datetime.date(2019, 2, 1))

In [None]:
# Creation of the data selection box, from the corner (-80, 7) to the corner
# (0, 60). NOTE(review): presumably Point2D(longitude, latitude), i.e.
# longitudes -80..0 and latitudes 7..60 — confirm.
area = pyinterp.geodetic.Box2D(
    pyinterp.geodetic.Point2D(-80, 7),
    pyinterp.geodetic.Point2D(0,60))
area

In [None]:
# Apply the geographic mask and load the selected rows into memory as a
# pandas DataFrame.
df = ddf[select_area(ddf, area)].compute()

In [None]:
# Visualization of the result
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import cartopy.feature as cfeature
%matplotlib inline

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=180))
sc = ax.scatter(
    df.longitude,
    df.latitude,
    1,
    c=[item[0] for item in df.temp],
    transform=ccrs.PlateCarree(),
    cmap='jet')
ax.coastlines()
ax.add_feature(cfeature.LAND)
ax.add_feature(cfeature.COASTLINE)
fig.colorbar(sc)

## Sélection par numéro de plateforme

In [None]:
# Keep the profiles of four floats and only the columns needed for the plot,
# then load the result into memory.
# NOTE(review): platform numbers are compared as strings — confirm the dtype of
# the platform_number column, otherwise nothing matches.
df = ddf[ddf.platform_number.isin(['2901216', '6900381', '5901026', '2902557'])]
df = df[['datetime', 'longitude', 'latitude', 'temp']]
df = df.compute()

In [None]:
# Positions of the selected floats, coloured by the first value of each
# temperature profile. Empty profiles are plotted as NaN instead of raising
# IndexError.
first_temp = [item[0] if len(item) else np.nan for item in df.temp]

fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=180))
sc = ax.scatter(
    df.longitude,
    df.latitude,
    1,
    c=first_temp,
    transform=ccrs.PlateCarree(),
    cmap='jet')
ax.coastlines()
ax.add_feature(cfeature.LAND)
ax.add_feature(cfeature.COASTLINE)
ax.set_title("Profiles of the selected floats")
# The trailing ";" hides the Colorbar repr; the figure is still displayed.
fig.colorbar(sc, label="temp (first value of the profile)");

## Calcul d'une anomalie de pression

In [None]:
# Re-open the dataset over the same period as the first read; repeating the
# call makes this section independent of the cells above.
ddf = get_dask_dataframe(
    path,
    datetime.date(1990, 1, 1),
    datetime.date(2019, 2, 1))

In [None]:
def pressure_anomalies(df):
    """Return, row by row, the raw pressure minus the adjusted pressure.

    Args:
        df: A DataFrame (typically one partition of the Argo dataset) with
            ``pres`` and ``pres_adjusted`` columns.

    Returns:
        pandas.Series: ``pres - pres_adjusted``, aligned on ``df``'s index.
    """
    raw_pressure = df["pres"]
    adjusted_pressure = df["pres_adjusted"]
    return raw_pressure - adjusted_pressure

In [None]:
# Here only columns containing the longitude and latitude of the floats are
# selected. Both results come from a single dask.compute call, so the two task
# graphs are merged and their common Parquet read tasks run once instead of
# once per .compute().
# NOTE(review): the mean-anomaly cell treats each anomaly as an array, so the
# declared 'f8' meta may not match the real dtype — confirm.
df, anomalies = dask.compute(
    ddf[['longitude', 'latitude']],
    ddf.map_partitions(pressure_anomalies, meta=(None, 'f8')))
df['anomalies'] = anomalies

In [None]:
# Mean anomaly of each profile. Profiles that are entirely NaN (or empty) are
# mapped to NaN explicitly instead of calling np.nanmean, which emits a
# RuntimeWarning on all-NaN input.
df['mean_anomalies'] = df['anomalies'].map(
    lambda series: np.nan if np.all(np.isnan(series)) else np.nanmean(series))

In [None]:
# Display the float positions together with their anomalies.
df

In [None]:
# Mean of (pres - pres_adjusted) for each profile. The colour scale is fixed
# to [-1, 1]; "extend" marks the values lying outside that range.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=180))
sc = ax.scatter(
    df.longitude,
    df.latitude,
    1,
    c=df.mean_anomalies,
    transform=ccrs.PlateCarree(),
    cmap='jet',
    vmin=-1,
    vmax=1)
ax.coastlines()
ax.add_feature(cfeature.LAND)
ax.add_feature(cfeature.COASTLINE)
ax.set_title("Mean of pres - pres_adjusted per profile")
# The trailing ";" hides the Colorbar repr; the figure is still displayed.
fig.colorbar(sc, label="mean_anomalies", extend="both");

## SLA interpolation on Argo float positions

In [None]:
class GridSeries:
    """Handles a series of grids stored in zarr format. This series is a
    time series.

    Args:
        ds: Dataset holding the grids. It must expose a ``time`` coordinate
            and the variables later requested through ``load_dataset``
            (presumably an xarray.Dataset — TODO confirm).

    Attributes:
        ds: The dataset given to the constructor.
        series (pandas.Series): The grid dates. Its positional index is used
            to select the grids in ``ds``.
        dt (datetime.timedelta): The constant step between two grids.
    """
    def __init__(self, ds):
        self.ds = ds
        self.series, self.dt = self._load_ts()

    @staticmethod
    def _is_sorted(array):
        """Return True if ``array`` is in non-decreasing order."""
        # Compare plain numpy values: slicing an xarray DataArray and comparing
        # the slices would align them on their coordinates instead of
        # comparing neighbouring elements.
        values = np.asarray(array)
        return bool(np.all(values[1:] >= values[:-1]))

    def _load_ts(self):
        """Loading the time series into memory.

        Returns:
            tuple: The grid dates as a pandas.Series and the constant step
            between two grids as a datetime.timedelta.

        Raises:
            ValueError: If the time axis is not sorted in ascending order.
            RuntimeError: If the step between two grids is not constant
                (including a time axis with fewer than two grids).
        """
        time = self.ds.time
        if not self._is_sorted(time):
            raise ValueError("The time axis of the grid series must be sorted")

        series = pd.Series(time)
        steps = np.diff(series.values.astype("datetime64[s]")).astype("int64")
        frequency = set(steps)
        if len(frequency) != 1:
            raise RuntimeError(
                "Time series does not have a constant step between two "
                f"grids: {frequency} seconds")
        return series, datetime.timedelta(seconds=float(frequency.pop()))

    def load_dataset(self, varname, start, end):
        """Loading the time series into memory for the defined period.

        Args:
            varname (str): Name of the variable to be loaded into memory.
            start (datetime.datetime): Date of the first map to be loaded.
            end (datetime.datetime): Date of the last map to be loaded.

        Return:
            pyinterp.backends.xarray.Grid3D: The interpolator handling the
            interpolation of the grid series.

        Raises:
            IndexError: If [start, end] is not within the grid dates.
        """
        if start < self.series.min() or end > self.series.max():
            raise IndexError(
                f"period [{start}, {end}] out of range [{self.series.min()}, "
                f"{self.series.max()}]")
        # Widen the window by one step on each side so that the grids
        # bracketing start and end are selected.
        first = start - self.dt
        last = end + self.dt

        selected = self.series[(self.series >= first) & (self.series < last)]
        print(f"fetch data from {selected.min()} to {selected.max()}")

        # Select from this instance's dataset, not from a global variable.
        data_array = self.ds[varname].isel(time=selected.index)
        return pyinterp.backends.xarray.Grid3D(data_array)

In [None]:
def interpolate(df, grid_series, varname):
    """Interpolate a variable 'varname' described by the time series
    'grid_series' for the locations provided in the DataFrame 'df'.

    The observations are processed month by month, so only the grids covering
    one month are loaded at a time.

    Args:
        df (pandas.DataFrame): Observations with ``datetime``, ``longitude``
            and ``latitude`` columns (e.g. one partition of the Argo dataset).
        grid_series (GridSeries): The grids to interpolate.
        varname (str): Name of the variable to interpolate.

    Returns:
        pandas.Series: One float64 value per row of ``df``, in the same order,
        indexed by the row's ``datetime``. Rows outside the time coverage of
        ``grid_series`` are NaN.
    """
    if not len(df):
        # An empty float Series (rather than a bare array) keeps the partition
        # outputs homogeneous when they are concatenated.
        return pd.Series([], index=pd.DatetimeIndex([], name="datetime"),
                         dtype=np.float64)
    df = df.set_index("datetime")
    times = df.index
    grid_start = grid_series.series.min()
    grid_end = grid_series.series.max()
    # Rows whose date cannot be interpolated keep NaN.
    values = np.full(len(df), np.nan)

    # Period boundaries: the first day of every month containing observations,
    # followed by one hour after the last observation.
    months = times.to_period("M").dropna().unique().sort_values()
    bounds = [month.to_timestamp() for month in months]
    bounds.append(times.max() + datetime.timedelta(seconds=3600))

    for start, end in zip(bounds[:-1], bounds[1:]):
        # GridSeries.load_dataset rejects periods outside the grid dates, so
        # each period is restricted to the grid coverage.
        start = max(start, grid_start)
        end = min(end, grid_end)
        if start >= end:
            continue
        # Boolean masks keep the original row order, so the results are
        # written back in place whether or not df is sorted by date.
        mask = (times >= start) & (times < end)
        if not mask.any():
            continue
        interpolator = grid_series.load_dataset(varname, start, end)
        selected = df.loc[mask, ["longitude", "latitude"]]
        values[mask] = interpolator.trivariate(
            dict(longitude=selected["longitude"].values,
                 latitude=selected["latitude"].values,
                 time=selected.index.values),
            interpolator="inverse_distance_weighting",
            num_threads=1)
    return pd.Series(values, index=times)

In [None]:
# Loading the time series: the "sea_surface_height" entry of the Pangeo ocean
# intake catalog, opened lazily with to_dask() (presumably an xarray Dataset).
import intake
cat = intake.Catalog("https://raw.githubusercontent.com/pangeo-data/pangeo-datastore"
                     "/master/intake-catalogs/ocean.yaml")
ds = cat["sea_surface_height"].to_dask()
ds

In [None]:
# Remove the "crs" variable before building the grid series.
# NOTE(review): the author flagged this cell "DELETE", presumably a temporary
# workaround — confirm whether it is still required.
# errors="ignore" keeps the cell re-runnable once "crs" is gone, and drop_vars
# replaces the deprecated Dataset.drop.
ds = ds.drop_vars("crs", errors="ignore")

In [None]:
# Index the dates of the SLA grids; GridSeries fails if the time axis is
# unsorted or if the step between two grids is not constant.
grid_series = GridSeries(ds)

In [None]:
# Select the Argo profiles between 1990-01-01 and 2019-01-02 (lazy).
ddf = get_dask_dataframe(
    path,
    datetime.date(1990, 1, 1),
    datetime.date(2019, 1, 2))

In [None]:
# Calculation of SLA: interpolate runs on every partition of ddf. The meta
# argument declares the output (a float64 Series named 'result'), so dask
# does not have to call interpolate on dummy data to infer it.
sla = ddf.map_partitions(interpolate, grid_series, 'sla', meta=('result', np.float64)).compute()

In [None]:
# Generation of a DataFrame containing the float positions and the
# interpolated SLA.
df = ddf[["datetime", "longitude", "latitude"]].compute()
# sla is expected to hold one value per row of ddf, in the same order, indexed
# by the row's datetime. Joining on "datetime" would duplicate rows whenever
# several profiles share a timestamp, so the values are attached by position
# once the alignment has been checked.
assert len(sla) == len(df) and pd.Index(df["datetime"]).equals(sla.index), \
    "interpolated SLA is not aligned with the float positions"
df["result"] = sla.to_numpy()

### Visualization of the result

In [None]:
# Relative age of each observation, used as marker size: 0 for the oldest date
# (so that point is not visible), 1 for the most recent one.
first = df.datetime.min()
last = df.datetime.max()
span = last - first
if span > datetime.timedelta(0):
    size = (df.datetime - first) / span
else:
    # A single date (or no valid date): dividing by a zero/undefined span
    # would give NaN sizes and hide every point, so all points get full size.
    size = pd.Series(1.0, index=df.index)

In [None]:
# Interpolated SLA at the float positions; the marker size grows with the date
# and the view is restricted to longitudes 80..100 and latitudes 13.5..25.
fig = plt.figure(figsize=(10, 5))
ax = fig.add_subplot(111, projection=ccrs.PlateCarree(central_longitude=180))
sc = ax.scatter(
    df.longitude,
    df.latitude,
    s=size*100,
    c=df.result,
    transform=ccrs.PlateCarree(),
    cmap='jet')
ax.coastlines()
ax.set_title("Time series of SLA "
             "(larger points are closer to the last date)")
ax.add_feature(cfeature.LAND)
ax.add_feature(cfeature.COASTLINE)
ax.set_extent([80, 100, 13.5, 25], crs=ccrs.PlateCarree())
# The trailing ";" hides the Colorbar repr; the figure is still displayed.
fig.colorbar(sc, label="SLA (interpolated)");
