In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
import warnings
# Silence the two warning categories that are explicitly named here.
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
# NOTE(review): with no category given, this call ignores every warning,
# which makes the two category-specific filters above redundant.
warnings.simplefilter("ignore")

In [None]:
from IPython.display import YouTubeVideo, Image

# [XARRAY](https://github.com/pydata/xarray)

[xarray](https://github.com/pydata/xarray) (formerly `xray`) has been developed by scientists / engineers working at the [Climate Corporation](http://climate.com/).

It is an open source project and Python package that aims to bring
the labeled data power of [pandas](http://pandas.pydata.org) to the
physical sciences, by providing N-dimensional variants of the core
[pandas](http://pandas.pydata.org) data structures, `Series` and
`DataFrame`: the xarray `DataArray` and `Dataset`.

The goal is to provide a pandas-like and pandas-compatible toolkit for
analytics on multi-dimensional arrays, rather than the tabular data at
which pandas excels. The approach adopts the [Common Data
Model](http://www.unidata.ucar.edu/software/thredds/current/netcdf-java/CDM)
for self-describing scientific data in widespread use in the Earth
sciences (e.g., [netCDF](http://www.unidata.ucar.edu/software/netcdf)
and [OPeNDAP](http://www.opendap.org/)): `xarray.Dataset` is an in-memory
representation of a netCDF file.

-   HTML documentation: <http://xarray.readthedocs.org>: **really good doc !**
-   Source code: <http://github.com/pydata/xarray>

The main advantages of using [xarray](https://github.com/pydata/xarray) versus [netCDF4](https://github.com/Unidata/netcdf4-python) are: 

+ intelligent selection along **labelled dimensions** (and also indexes)
+ **groupby** operations
+ **resampling** operations
+ data alignment 
+ IO (netcdf)
+ conversion from / to [Pandas.DataFrames](http://pandas.pydata.org/pandas-docs/dev/generated/pandas.DataFrame.html)


To install the latest version of xarray (via conda): 

    ᐅ conda install xarray

or if you want the bleeding edge: 


    ᐅ pip install git+https://github.com/pydata/xarray

There's too much to cover in the context of this talk. To learn more about all the cool **xarray** features, watch: 

PyData talk by **Stephan Hoyer**: <https://www.youtube.com/watch?v=T5CZyNwBa9c>

In [None]:
YouTubeVideo('T5CZyNwBa9c', width=500, height=400, start=0)

## Some examples

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
import os
import numpy as np
import pandas as pd

In [None]:
import xarray as xr; print(xr.__version__)

### Open a netcdf file: monthly HGT from NCEP / NCAR from January 1948 to July 2017

The file (270 Mb) can be downloaded at [ftp://ftp.cdc.noaa.gov/Datasets/ncep.reanalysis.derived/pressure/hgt.mon.mean.nc](ftp://ftp.cdc.noaa.gov/Datasets/ncep.reanalysis.derived/pressure/hgt.mon.mean.nc)

In [None]:
dset = xr.open_dataset('../data/hgt.2019.nc')

In [None]:
# Call info() so the ncdump-style summary is printed; the bare attribute
# `dset.info` only displays the bound-method object.
dset.info()

**dset** is an [xarray.Dataset](http://xarray.pydata.org/en/stable/data-structures.html#dataset). It is a dict-like container of labeled arrays (DataArray objects) with aligned dimensions. It is designed as an in-memory representation of the data model from the netCDF file format.

In [None]:
Image('http://xarray.pydata.org/en/stable/_images/dataset-diagram.png', width=700)

In [None]:
dset.dims

### accessing variables

In [None]:
lat = dset['lat']

In [None]:
lat

In [None]:
type(lat)

In [None]:
lat.attrs

In [None]:
type(lat.data)

In [None]:
lat.data

### selecting a Dataset along dimensions

In [None]:
dset

In [None]:
dset.sel(time='2019-01-01')

### integer indexing also available via `isel`

In [None]:
dset.isel(time=0)

### and you can slice along one or multiple dimensions 

In [None]:
dset.sel(time=slice('2019-01-01','2019-01-31'))

### when selecting along any dimension, you need to respect the ORDER of the dimensions, i.e. in this case, the latitudes go from NORTH TO SOUTH 

In [None]:
dset.sel(time=slice('2019-01-01','2019-01-31'), lat=slice(40,-40))

### or you can sort 

In [None]:
# Re-order the dataset along 'lat' only when the first latitude is larger
# than the last one (i.e. the axis is currently descending), so that label
# slices can afterwards be written in ascending order, e.g. slice(-40, 40).
if dset.lat[0] > dset.lat[-1]: 
    dset = dset.sortby('lat')

In [None]:
dset.sel(time=slice('2019-01-01','2019-01-31'), lat=slice(-40,40))

In [None]:
subset = dset.sel(time='2019-01-01', level=1000, lat=slice(-50,40), lon=slice(100, 220))

In [None]:
subset

In [None]:
# Remove the 'time_bnds' variable from the selection.
# drop_vars replaces the deprecated Dataset.drop, and errors='ignore' keeps
# this cell re-runnable after the variable has already been removed.
subset = subset.drop_vars('time_bnds', errors='ignore')

In [None]:
subset

In [None]:
subset['hgt']

In [None]:
subset['hgt'].plot()

### basic mapping with cartopy

In [None]:
import cartopy.crs as ccrs

In [None]:
crs = ccrs.PlateCarree(central_longitude=180)

In [None]:
f, ax = plt.subplots(figsize=(10,8), subplot_kw={'projection':crs})
subset['hgt'].plot.contourf(ax=ax, transform=ccrs.PlateCarree(), levels=20);
ax.coastlines(resolution='10m')
f.savefig('../figures/map_hgt', dpi=200, bbox_inches='tight')

In [None]:
crs = ccrs.Orthographic(central_longitude=160)

In [None]:
f, ax = plt.subplots(subplot_kw={'projection':crs}, figsize=(10, 8))
# dset.isel(time=0, level=0)['hgt'].plot.contourf(ax=ax, transform=ccrs.PlateCarree(central_longitude=0), levels=20);
subset['hgt'].plot.contourf(ax=ax, transform=ccrs.PlateCarree(), levels=20);
ax.set_global()
ax.coastlines(resolution='50m')

### important to close the datasets, or else you might encounter issues 

In [None]:
dset.close()

In [None]:
subset.close()

### reading multiple files datasets 

reading datasets that are split in different files (i.e. one file per year or month) is relatively easy in xarray, you just need to pass a LIST (Python list) of file paths

In [None]:
import pathlib

In [None]:
# NOTE(review): absolute, machine-specific location of the workshop data;
# it has to be adapted before this section can run on another machine.
path = pathlib.Path('/home/nicolasf/drives/auck_scratch/fauchereaun/Wellington_Python_Workshop_data/')

In [None]:
lfiles = list(path.rglob("ersst.??????.nc"))

In [None]:
lfiles.sort()

In [None]:
len(lfiles)

In [None]:
lfiles[0]

In [None]:
lfiles[-1]

In [None]:
dset = xr.open_mfdataset(lfiles)

In [None]:
dset

In [None]:
dset.nbytes / 1e6

In [None]:
dset = dset.squeeze()

In [None]:
subset = dset.sel(lon=slice(100., 280.), lat=slice(-5., 5.))

In [None]:
subset

### aggregation functions along dimensions

you can apply functions along the dimensions of an **xarray** dataset, e.g. calculate a mean over time, or an average over latitudes (strictly speaking a *meridional* mean, even though the variable below is named `zonal_mean`)

In [None]:
time_mean = subset.mean('time')

In [None]:
time_mean

In [None]:
time_mean['sst'].plot(cmap=plt.get_cmap('RdBu_r'))

In [None]:
zonal_mean = subset.mean('lat')

In [None]:
zonal_mean

In [None]:
# Longitude / time section of the latitude-averaged SST anomalies since 2018.
# `zonal_mean` is `subset` (lat 5S-5N) averaged over 'lat', and the plotted
# variable is 'ssta', so the title describes exactly that.
f, my_axes = plt.subplots(figsize=(8,8))
zonal_mean.sel(time=slice('2018', None))['ssta'].plot.contourf(ax=my_axes, levels=30)
my_axes.set_title('SST anomalies averaged over 5S-5N')
# flip the vertical axis of the section
my_axes.invert_yaxis()

### And that can be handy to calculate an index ... example with NINO3.4

NINO3.4 is the regional average of the SST anomalies in the domain (5N-5S, 170W-120W)

In [None]:
index = subset.sel(lon=slice(360-170, 360-120))

In [None]:
NINO34 = index.mean(dim=('lon','lat'))

or 

In [None]:
NINO34 = index.mean('lon').mean('lat')

In [None]:
NINO34['ssta'].plot()

### calculates a monthly climatology using the groupby machinery

In [None]:
Image(filename='images/split-apply-combine.png', width=800)

### CLIMATOLOGY

In [None]:
path

In [None]:
sfc_temp = xr.open_dataset(path / 'air.mon.mean.nc')

In [None]:
sfc_temp

In [None]:
clim = sfc_temp.groupby('time.month').mean('time')

In [None]:
clim

In [None]:
clim.sel(month = 8)['air'].plot(cmap=plt.get_cmap('RdBu_r'))

In [None]:
from calendar import month_abbr

In [None]:
months = month_abbr[1:]

In [None]:
months

In [None]:
clim['month'] = (('month'), months)

In [None]:
clim

In [None]:
clim.sel(month = 'Jan')['air'].plot(cmap=plt.get_cmap('RdBu_r'))

### the plot method accepts parameters that allow you to make faceted plots 

In [None]:
clim['air'].plot(x='lon', y='lat', col='month', col_wrap=4)

### to plot it on a proper map, use cartopy

In [None]:
from calendar import month_name

In [None]:
proj = ccrs.PlateCarree(central_longitude=180)

In [None]:
transform = ccrs.PlateCarree(central_longitude=0)

In [None]:
months

In [None]:
# 4 x 3 grid of maps: one panel per entry of `months`, all drawn with the
# same colormap and contour levels.
f, axes = plt.subplots(nrows=4, ncols=3, figsize=(14,10), subplot_kw={'projection':proj})
f.subplots_adjust(hspace=0.1, wspace=0.1)
axes = axes.flatten()
for ax, month in zip(axes, months):
    monthly_air = clim.sel(month=month)['air']
    monthly_air.plot.contourf(
        ax=ax,
        transform=transform,
        cmap=plt.get_cmap('RdBu_r'),
        levels=np.arange(-50, 55, 5),
        extend='both',
    )
    ax.coastlines(resolution='50m')

**NOTE**: If you have **DAILY** data, you can calculate a daily climatology using the `dayofyear` attribute, e.g.: 
    
```python 

clim = dset.groupby('time.dayofyear').mean('time')

```

### calculates a seasonal (DJF, MAM, ...) climatology

In [None]:
seas_clim = sfc_temp.groupby('time.season').mean('time')

In [None]:
seas_clim

In [None]:
# 2 x 2 grid of maps, filled column by column ('F' order): one panel per
# season label found in `seas_clim`.
f, axes = plt.subplots(nrows=2, ncols=2, figsize=(14,10), subplot_kw={'projection':proj})
f.subplots_adjust(hspace=0.1, wspace=0.1)
axes = axes.flatten('F')
for ax, seas in zip(axes, seas_clim['season'].values):
    seasonal_air = seas_clim.sel(season=seas)['air']
    seasonal_air.plot.contourf(
        ax=ax,
        transform=transform,
        cmap=plt.get_cmap('RdBu_r'),
        levels=np.arange(-50, 55, 5),
        extend='both',
    )
    ax.coastlines(resolution='50m')

### calculates anomalies with respect to a specific climatological *normal*

#### 1. defines the function

In [None]:
def demean(x, start='1981-01-01', end='2010-12-01'):
    """Return `x` as anomalies with respect to its mean over a baseline period.

    Parameters
    ----------
    x : object exposing ``.sel(time=...)`` and ``.mean('time')``
        In this notebook, one group produced by ``groupby('time.month')``.
    start, end : time labels, optional
        Bounds of the baseline period, passed to ``slice(start, end)`` for
        the selection along ``time``. The defaults reproduce the original
        hard-coded 1981-2010 climatological normal.

    Returns
    -------
    ``x`` minus the ``time`` mean of ``x`` restricted to the baseline period.
    """
    # the baseline is computed from the same object it is subtracted from,
    # so with a monthly groupby each calendar month gets its own normal
    baseline = x.sel(time=slice(start, end)).mean('time')
    return x - baseline

#### 2. apply the function to the groupby object

In [None]:
sfc_anoms = sfc_temp.groupby('time.month').apply(demean) 

In [None]:
sfc_anoms

In [None]:
sfc_anoms.isel(time=-1)['air'].plot(cbar_kwargs={'orientation':'horizontal','shrink':0.9, 'label':u'\N{DEGREE SIGN}C'})

In [None]:
crs = ccrs.PlateCarree(central_longitude=180)

In [None]:
f, ax = plt.subplots(figsize=(10,8), subplot_kw={'projection':crs})
sfc_anoms.isel(time=-1)['air'].plot.contourf(ax=ax, transform=ccrs.PlateCarree(), levels=20, cbar_kwargs={'shrink':0.5});
ax.coastlines(resolution='50m')

In [None]:
# Area-weighted global mean of the anomalies.
# A plain mean over ('lat', 'lon') gives every grid point the same weight,
# which over-represents high latitudes; weight each latitude by cos(lat).
# Assumes 'lat' is in degrees on a regular lat/lon grid — TODO confirm.
lat_weights = np.cos(np.deg2rad(sfc_anoms['lat']))
lat_weights.name = 'weights'
global_anom = sfc_anoms.weighted(lat_weights).mean(dim=('lat','lon'))

In [None]:
global_anom['air'].plot()

In [None]:
f, ax = plt.subplots()
global_anom.rolling(dim={'time':12}, min_periods=12, center=True).mean()['air'].plot(ax=ax)
ax.axhline(0, color='k', lw=0.5)

### then you can export it to a dataframe, and e.g. save it into a csv file 

In [None]:
global_anom_df = global_anom['air'].to_dataframe()

In [None]:
global_anom_df.head()

### Creates an xarray dataset object from numpy arrays

In [None]:
lon = np.linspace(0, 357.5, 144, endpoint=True)

lat = np.linspace(-90,90, 73, endpoint=True)

lons, lats = np.meshgrid(lon,lat)

lev = np.array([1000,925,850])

time = pd.date_range(start='2015-1-1',end='2015-1-3')

In [None]:
lat

In [None]:
arr = np.random.randn(3,3,73,144)

The dictionary **keys** are the **variables** contained in the Dataset.<br><br>
The dictionary **values** are **tuples**, with first the (or the list of) dimension(s) over which the array varies, then the array itself

In [None]:
d = {}
d['time'] = ('time',time)
d['latitudes'] = ('latitudes',lat)
d['longitudes'] = ('longitudes', lon)
d['level'] = ('level', lev)
d['var'] = (['time','level','latitudes','longitudes'], arr)

In [None]:
dset = xr.Dataset(d)

In [None]:
dset

adding global attributes

In [None]:
dset.attrs['author'] = 'nicolas.fauchereau@niwa.co.nz'

In [None]:
dset

adding variables attributes

In [None]:
dset.longitudes.attrs['units'] = 'degrees_east'
dset.latitudes.attrs['units'] = 'degrees_north'

In [None]:
dset.latitudes.attrs

In [None]:
dset.sel(time='2015-1-2', level=1000)

In [None]:
dset.to_netcdf('../data/dset_from_dict.nc')

In [None]:
!ncdump -h ../data/dset_from_dict.nc

### Creates an xarray dataset object from a Pandas DataFrame

In [None]:
import string
# 365 rows of random values indexed by consecutive days starting 2014-01-01,
# with the first five ASCII letters as column names.
df = pd.DataFrame(
    data=np.random.randn(365, 5),
    index=pd.date_range(start='2014-1-1', periods=365),
    columns=[letter for letter in string.ascii_letters[:5]],
)

In [None]:
df.head()

In [None]:
df_ds = xr.Dataset.from_dataframe(df)

In [None]:
df_ds

In [None]:
group = df_ds.groupby('index.month').mean('index')

In [None]:
group

### Opening a file over the network with OPeNDAP

In [None]:
url = 'http://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/interp_OLR/olr.mon.mean.nc'

In [None]:
olr_dset = xr.open_dataset(url)

In [None]:
olr_sub = olr_dset.sel(time='1998-1-1',lat=slice(30,-30), lon=slice(170, 300))

In [None]:
olr_sub

In [None]:
f, ax = plt.subplots(figsize=(10,8), subplot_kw={'projection':ccrs.PlateCarree(central_longitude=180)})
olr_sub['olr'].plot.contourf(ax=ax, transform=ccrs.PlateCarree(), levels=20, cbar_kwargs={'shrink':0.5});
ax.coastlines(resolution='50m')

### now with the gridlines and lat / lon labels 

In [None]:
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

In [None]:
# Filled-contour map of the OLR subset with gridlines and formatted
# longitude / latitude tick labels, saved to ../figures/olr_map.png.
f, ax = plt.subplots(figsize=(10,8), subplot_kw={'projection':ccrs.PlateCarree(central_longitude=180)})

olr_sub['olr'].plot.contourf(ax=ax, transform=ccrs.PlateCarree(), levels=20, cbar_kwargs={'shrink':0.5});

ax.coastlines(resolution='50m')

# Tick locations: np.arange excludes its stop value, so one step is added to
# the upper bound of both ranges to include the map edges (300 and 30).
xticks = np.arange(170, 300 + 30, 30)

yticks = np.arange(-30., 30. + 10., 10.)

gl = ax.gridlines(draw_labels=False, linewidth=0.5, linestyle='--', xlocs=xticks, ylocs=yticks, crs=ccrs.PlateCarree())

ax.set_xticks(xticks, crs=ccrs.PlateCarree())

ax.set_yticks(yticks, crs=ccrs.PlateCarree())

lon_formatter = LongitudeFormatter(zero_direction_label=True, dateline_direction_label=True)

lat_formatter = LatitudeFormatter()

ax.xaxis.set_major_formatter(lon_formatter)

ax.yaxis.set_major_formatter(lat_formatter)

ax.set_extent([170, 300, -30, 30], crs=ccrs.PlateCarree())

ax.set_xlabel('longitude')

# the original called set_xlabel twice, overwriting 'longitude' and leaving
# the y axis unlabelled
ax.set_ylabel('latitude')

f.savefig('../figures/olr_map.png', bbox_inches='tight', dpi=200)

In [None]:
# Display the saved figure inline rather than launching the external `eom`
# viewer, which only exists on some desktops and holds up "Run All" until
# its window is closed.
Image(filename='../figures/olr_map.png')

### interpolation from one grid to another 

In [None]:
olr_sub

In [None]:
olr_sub = olr_sub.sortby('lat')

In [None]:
olr_sub

In [None]:
d = {}
d['lon'] = (('lon'), np.arange(olr_sub.lon[0], olr_sub.lon[-1] + 0.25, 0.25))
d['lat'] = (('lat'), np.arange(olr_sub.lat[0], olr_sub.lat[-1] + 0.25, 0.25))

In [None]:
target_grid = xr.Dataset(d)

In [None]:
target_grid

In [None]:
olr_sub_interp = olr_sub.interp_like(target_grid)

In [None]:
olr_sub_interp

In [None]:
f, axes = plt.subplots(nrows=2, figsize=(10,10))

olr_sub['olr'].plot.imshow(ax=axes[0], interpolation='none')

olr_sub_interp['olr'].plot.imshow(ax=axes[1], interpolation='none')
