In [None]:
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter(action='ignore', category=RuntimeWarning)
warnings.simplefilter("ignore")

In [None]:
from IPython.display import YouTubeVideo, Image

# [XARRAY](https://github.com/pydata/xarray)

[xarray](https://github.com/pydata/xarray) (formerly `xray`) was originally developed by scientists / engineers working at the [Climate Corporation](http://climate.com/).

It is an open source project and Python package that aims to bring
the labeled data power of [pandas](http://pandas.pydata.org) to the
physical sciences, by providing N-dimensional variants of the core
[pandas](http://pandas.pydata.org) data structures, `Series` and
`DataFrame`: the xarray `DataArray` and `Dataset`.

The goal is to provide a pandas-like and pandas-compatible toolkit for
analytics on multi-dimensional arrays, rather than the tabular data for
which pandas excels. The approach adopts the [Common Data
Model](http://www.unidata.ucar.edu/software/thredds/current/netcdf-java/CDM)
for self-describing scientific data in widespread use in the Earth
sciences (e.g., [netCDF](http://www.unidata.ucar.edu/software/netcdf)
and [OPeNDAP](http://www.opendap.org/)): `xarray.Dataset` is an in-memory
representation of a netCDF file.

-   HTML documentation: <http://xarray.readthedocs.org> (**really good documentation!**)
-   Source code: <https://github.com/pydata/xarray>

The main advantages of using [xarray](https://github.com/xarray/xarray) versus [netCDF4](https://github.com/Unidata/netcdf4-python) are: 

+ intelligent selection along **labelled dimensions** (and also indexes)
+ **groupby** operations
+ **resampling** operations
+ data alignment 
+ IO (netcdf)
+ conversion from / to [Pandas.DataFrames](http://pandas.pydata.org/pandas-docs/dev/generated/pandas.DataFrame.html)


To install the latest version of xarray (via conda): 

    ᐅ conda install xarray

or if you want the bleeding edge: 


    ᐅ pip install git+https://github.com/xarray/xarray

There's too much to cover in this talk; to learn more about all the cool **xarray** features, watch:

PyData talk by **Stephan Hoyer**: <https://www.youtube.com/watch?v=T5CZyNwBa9c>

In [None]:
YouTubeVideo('T5CZyNwBa9c', width=500, height=400, start=0)

## Some examples

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt

In [None]:
import os
import numpy as np
import pandas as pd
from mpl_toolkits.basemap import Basemap as bm

In [None]:
import xarray as xr; print(xr.__version__)

### Open a netcdf file: monthly HGT from NCEP / NCAR from January 1948 to July 2017

The file (270 Mb) can be downloaded at [ftp://ftp.cdc.noaa.gov/Datasets/ncep.reanalysis.derived/pressure/hgt.mon.mean.nc](ftp://ftp.cdc.noaa.gov/Datasets/ncep.reanalysis.derived/pressure/hgt.mon.mean.nc)

In [None]:
dpath = '/Users/nicolasf/drives/auck_scratch/fauchereaun/Python_Workshop/data/'

In [None]:
dset = xr.open_dataset(os.path.join(dpath, 'hgt.mon.mean.nc'))

In [None]:
dset

**dset** is an [xarray.Dataset](http://xray.readthedocs.org/en/stable/data-structures.html#dataset). It is a dict-like container of labeled arrays (DataArray objects) with aligned dimensions, designed as an in-memory representation of the data model of the netCDF file format.

In [None]:
Image('http://xarray.pydata.org/en/stable/_images/dataset-diagram.png', width=700)

In [None]:
dset.dims

### accessing variables

In [None]:
lat = dset['lat']

In [None]:
type(lat)

In [None]:
lat

In [None]:
lat.attrs

In [None]:
type(lat.data)

In [None]:
lat.data

### selecting a Dataset along dimensions

In [None]:
dset

In [None]:
dset.sel(time=('1998-1-01'))

In [None]:
dset.sel(time=slice('1998-01-01','2000-12-01'))

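When slicing along a dimension, the slice bounds must follow the ORDER of the coordinate values. In this dataset
the latitudes go from NORTH TO SOUTH, so the northern bound comes first (`slice(40, -40)`); `slice(-40, 40)` would select nothing.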

In [None]:
dset.sel(time=slice('1998-01-01','2000-12-01'), lat=slice(40,-40))

In [None]:
# dset.sel(time=slice('1998-01-01','2000-12-01'), lat=slice(-40,40))

In [None]:
subset = dset.sel(time='1998-01-01', level=1000, lat=slice(40,-50), lon=slice(100, 220))

Extract the lat and lon coordinates and turn them into 2D arrays (for plotting).

In [None]:
lat = subset['lat']
lon = subset['lon']

In [None]:
lons, lats = np.meshgrid(lon, lat)

In [None]:
f, axes = plt.subplots(ncols=2, figsize=(10,6))
im = axes[0].imshow(lons)
plt.colorbar(im, ax=axes[0], orientation='horizontal')
axes[0].set_title('lons')
im = axes[1].imshow(lats)
plt.colorbar(im, ax=axes[1], orientation='horizontal')
axes[1].set_title('lats')

#### defines the Basemap projection

In [None]:
m = bm(projection='cyl',llcrnrlat=lat.min(),urcrnrlat=lat.max(),\
            llcrnrlon=lon.min(),urcrnrlon=lon.max(),\
            lat_ts=0,resolution='c')

#### defines a function to plot a field (must be 2D)

In [None]:
def plot_field(X, lat, lon, vmin, vmax, step, cmap=None, ax=False, title=False, grid=False, bmap=None):
    """
    Draw a filled-contour plot of a 2D field on a Basemap map.

    Parameters
    ----------
    X : 2D array-like (numpy array or 2D DataArray)
        The field to plot. It is assumed to have the same shape as `lat` / `lon`.
    lat, lon : 2D arrays
        Latitudes / longitudes of every grid point, e.g. from `np.meshgrid(lon, lat)`.
        These are the coordinates actually used for the plot.
    vmin, vmax, step : numbers
        Contour levels are `np.arange(vmin, vmax + step, step)`; values outside
        the range are drawn with the colormap's extension colors (`extend='both'`).
    cmap : matplotlib colormap, optional
        Defaults to `plt.get_cmap('jet')`.
    ax : matplotlib Axes, optional
        Axes to draw into. If not given, a new figure is created whose height
        follows the aspect ratio of `X`.
    title : str, optional
        Axes title.
    grid : bool
        If True, draw labelled meridians (every 60 degrees) and parallels (40S, EQ, 40N).
    bmap : Basemap instance, optional
        Map projection to draw with. Defaults to the notebook-level Basemap `m`.
    """
    if cmap is None:
        # resolved at call time rather than when the function is defined
        cmap = plt.get_cmap('jet')
    if bmap is None:
        bmap = m  # notebook-level Basemap defined in an earlier cell
    if not ax:
        f, ax = plt.subplots(figsize=(10, (X.shape[0] / float(X.shape[1])) * 10))
    # Basemap drawing methods (coastlines, colorbar, ...) draw on bmap.ax
    bmap.ax = ax
    levels = np.arange(vmin, vmax + step, step)
    # use the coordinates passed in, not whatever `lons` / `lats` happen to be global
    im = bmap.contourf(lon, lat, X, levels, latlon=True, cmap=cmap, extend='both', ax=ax)
    bmap.drawcoastlines()
    if grid:
        bmap.drawmeridians(np.arange(0, 360, 60), labels=[0, 0, 0, 1])
        bmap.drawparallels([-40, 0, 40], labels=[1, 0, 0, 0])
    bmap.colorbar(im)
    if title:
        ax.set_title(title)

#### plots 

In [None]:
plot_field(subset['hgt'], lats, lons, -50, 200, 10, cmap=plt.get_cmap('RdBu_r'), grid=True)

#### close the files 

In [None]:
dset.close()

In [None]:
subset.close()

### reading multiple files datasets 

Reading datasets that are split across several files (e.g. one file per year or month) is easy in xarray: you just need to pass a LIST (Python list) of file paths to `xr.open_mfdataset`.

In [None]:
import os

In [None]:
from glob import glob

In [None]:
dpath = '/Users/nicolasf/drives/auck_scratch/fauchereaun/Python_Workshop/data/ERSST/'

In [None]:
lfiles = glob(os.path.join(dpath, "ersst*ft.nc"))

In [None]:
len(lfiles)

In [None]:
lfiles[0]

In [None]:
lfiles[1]

In [None]:
dset = xr.open_mfdataset(lfiles)

In [None]:
dset

In [None]:
subset = dset.sel(lon=slice(100., 200.), lat=slice(-50., 10.))

In [None]:
subset

### aggregation functions along dimensions

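You can apply functions along the dimensions of an **xarray** dataset, e.g. calculate a mean over time, or over latitudes. Strictly speaking, a mean over latitudes is a *meridional* mean; a *zonal* mean averages over longitudes, despite the variable name used below.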

In [None]:
subset

In [None]:
lat = subset['lat']
lon = subset['lon']
lons, lats = np.meshgrid(lon, lat)

In [None]:
time_mean = subset.mean('time')

In [None]:
time_mean

In [None]:
time_mean['sst'].plot(cmap=plt.get_cmap('RdBu_r'))

In [None]:
zonal_mean = subset.mean('lat')

In [None]:
zonal_mean

In [None]:
zonal_mean.sel(time=slice('2015', None))['sst'].plot()

And that can be handy to calculate an index ... 

In [None]:
index = subset.mean('lat').mean('lon')

In [None]:
index

In [None]:
index['sst'].plot()

### calculates a monthly climatology using the groupby machinery

In [None]:
Image(filename='images/split-apply-combine.png', width=800)

In [None]:
lat = subset['lat']
lon = subset['lon']

In [None]:
lons, lats = np.meshgrid(lon, lat)

In [None]:
clim = subset.groupby('time.month').mean('time')

In [None]:
clim.sel(month = 8)['sst'].plot(cmap=plt.get_cmap('RdBu_r'))

In [None]:
clim

In [None]:
from calendar import month_abbr

In [None]:
months = month_abbr[1:]

In [None]:
months

In [None]:
clim['month'] = (('month'), months)

In [None]:
clim

In [None]:
clim.sel(month = 'Jan')['sst'].plot(cmap=plt.get_cmap('RdBu_r'))

### The plot method accepts parameters for faceted plots

In [None]:
clim['sst'].plot(x='lon', y='lat', col='month', col_wrap=3, cmap=plt.get_cmap('RdBu_r'))

### To plot it on a proper map, use Basemap

In [None]:
from calendar import month_name

In [None]:
m = bm(projection='cyl',llcrnrlat=lat.min(),urcrnrlat=lat.max(),\
            llcrnrlon=lon.min(),urcrnrlon=lon.max(),\
            lat_ts=0,resolution='c')

In [None]:
f, axes = plt.subplots(nrows=4,ncols=3, figsize=(14,10))
f.subplots_adjust(hspace=0.1, wspace=0.1)
axes = axes.flatten()
for i, month in enumerate(months): 
    ax = axes[i]
    plot_field(clim.sel(month=month)['sst'], lats, lons, 0, 30, 1, ax=ax, title=month)

**NOTE**: If you have **DAILY** data, you can calculate a daily climatology using the `dayofyear` attribute, e.g.: 
    
```python 

clim = dset.groupby('time.dayofyear').mean('time')

```

### calculates a seasonal (DJF, MAM, ...) climatology

In [None]:
seas_clim = subset.groupby('time.season').mean('time')

In [None]:
seas_clim

In [None]:
f, axes = plt.subplots(nrows=2,ncols=2, figsize=(10,5))
f.subplots_adjust(hspace=0.1, wspace=0.1)
axes = axes.flatten('F')
for i, seas in enumerate(seas_clim['season'].values): 
    ax = axes[i]
    plot_field(seas_clim['sst'][i,:,:].values, lats, lons, 0, 30, 1, ax=ax, title=seas)

### calculates seasonal averages weighted by the number of days in each month

adapted from [http://xray.readthedocs.org/en/stable/examples/monthly-means.html#monthly-means-example](http://xray.readthedocs.org/en/stable/examples/monthly-means.html#monthly-means-example)

In [None]:
def get_dpm(time):
    """
    Return an array with the number of days in the month of each date in `time`.

    Parameters
    ----------
    time : date container exposing `.year` and `.month` arrays
        For example a pandas DatetimeIndex, as passed by `season_mean`.

    Returns
    -------
    numpy.ndarray of float, same length as `time`.
        Month lengths come from Python's `calendar` module, so leap years follow
        the (proleptic) Gregorian rules.
    """
    import calendar as cal
    # monthrange returns (weekday of the 1st, number of days in the month).
    # Builtin float replaces np.float, which was removed in NumPy 1.24.
    return np.array([cal.monthrange(int(year), int(month))[1]
                     for year, month in zip(time.year, time.month)],
                    dtype=float)

In [None]:
def season_mean(ds, calendar='standard'):
    """
    Seasonal (DJF, MAM, JJA, SON) means of `ds`, weighting each time step by its month length.

    Parameters
    ----------
    ds : xarray Dataset (or DataArray) with a datetime `time` coordinate
        Presumably monthly data: each time step is weighted by the number of days
        in its month (see `get_dpm`).
    calendar : str
        Kept for compatibility with the xarray documentation example this function
        is adapted from. It is currently NOT used: month lengths come from `get_dpm`,
        which relies on Python's `calendar` module.

    Returns
    -------
    The weighted seasonal means, with a `season` dimension replacing `time`.
    Grid points that are missing (NaN) at any time step of a season stay NaN,
    rather than being summed as zeros.
    """
    # Number of days in each month, aligned on the time coordinate
    month_length = xr.DataArray(get_dpm(ds.time.to_index()),
                                coords=[ds.time], name='month_length')

    # Weights: each month's share of the total number of days in its season
    # (pooled over all years, as groupby('time.season') does)
    weights = month_length.groupby('time.season') / month_length.groupby('time.season').sum()

    # Sanity check: weights sum to 1 within every season present in the data
    # (not necessarily all 4, e.g. for a short record)
    weight_sums = weights.groupby('time.season').sum().values
    np.testing.assert_allclose(weight_sums, np.ones_like(weight_sums))

    # skipna=False: with the default skipna, points that are NaN at all times
    # (e.g. land in an SST field) would come out as 0 instead of NaN
    return (ds * weights).groupby('time.season').sum(dim='time', skipna=False)

In [None]:
sst_seas = season_mean(subset)

In [None]:
sst_seas

In [None]:
f, axes = plt.subplots(nrows=2,ncols=2, figsize=(10,5))
f.subplots_adjust(hspace=0.1, wspace=0.1)
axes = axes.flatten('F')
# plot the day-weighted climatology returned by season_mean (sst_seas),
# not the unweighted seas_clim; select by season label rather than position
for ax, seas in zip(axes, sst_seas['season'].values):
    plot_field(sst_seas['sst'].sel(season=seas).values, lats, lons, 0, 30, 1, ax=ax, title=seas)

#### difference between non-weighted and weighted seasonal climatologies

In [None]:
diff_seas = seas_clim - sst_seas

In [None]:
f, axes = plt.subplots(nrows=2,ncols=2, figsize=(10,5))
f.subplots_adjust(hspace=0.1, wspace=0.1)
axes = axes.flatten('F')
# .values (rather than .data) always yields a numpy array, even when the data
# are dask-backed (the subset comes from xr.open_mfdataset)
for ax, seas in zip(axes, diff_seas['season'].values):
    plot_field(diff_seas['sst'].sel(season=seas).values, lats, lons, -0.1, 0.1, 0.01, ax=ax, title=seas)

### calculates anomalies with respect to a specific climatological *normal*

#### 1. defines the function

In [None]:
def demean(x, start='1981-01-01', end='2010-12-01'):
    """
    Anomalies of `x` relative to its mean over a climatological normal period.

    Meant to be applied to each group of `groupby('time.month')`, so that every
    calendar month is compared with its own normal.

    Parameters
    ----------
    x : xarray object with a `time` dimension
    start, end : str or datetime-like
        Bounds of the normal period, passed to `x.sel(time=slice(start, end))`
        (label slicing includes both ends). The default is 1981-01 to 2010-12.

    Returns
    -------
    `x` minus its time mean over the normal period.
    """
    normal = x.sel(time=slice(start, end)).mean('time')
    return x - normal

#### 2. apply the function to the groupby object

In [None]:
sst_anoms = subset['sst'].groupby('time.month').apply(demean) 

#### should be very similar to the original anomalies

In [None]:
plot_field(sst_anoms.sel(time=('2017-01-01')), lats, lons, -4, 4, 0.1, \
           cmap=plt.get_cmap('RdBu_r'), grid=True)

### Creates an xarray Dataset object from numpy arrays

In [None]:
lon = np.linspace(0, 357.5, 144, endpoint=True)
lat = np.linspace(-90,90, 73, endpoint=True)

lons, lats = np.meshgrid(lon,lat)

lev = np.array([1000,925,850])
time = pd.date_range(start='2015-1-1',end='2015-1-3')

In [None]:
lat

In [None]:
arr = np.random.randn(3,3,73,144)

The dictionary **keys** are the **variables** contained in the Dataset.<br><br>
The dictionary **values** are **tuples**: first the dimension (or list of dimensions) along which the array varies, then the array itself.

In [None]:
d = {}
d['time'] = ('time',time)
d['latitudes'] = ('latitudes',lat)
d['longitudes'] = ('longitudes', lon)
d['level'] = ('level', lev)
d['var'] = (['time','level','latitudes','longitudes'], arr)

In [None]:
dset = xr.Dataset(d)

adding global attributes

In [None]:
dset.attrs['author'] = 'nicolas.fauchereau@gmail.com'

adding variables attributes

In [None]:
dset.longitudes.attrs['units'] = 'degrees_east'
dset.latitudes.attrs['units'] = 'degrees_north'

In [None]:
dset.sel(time='2015-1-2', level=1000)

In [None]:
lons, lats = np.meshgrid(dset['longitudes'], dset['latitudes'])

In [None]:
plot_field(dset.sel(time='2015-1-2', level=1000)['var'], \
           lats, lons, -4, 4, 0.1, grid=True)

In [None]:
dset.to_netcdf('../data/dset_from_dict.nc')

In [None]:
!/usr/local/bin/ncdump -h ../data/dset_from_dict.nc

### Creates an xarray Dataset object from a Pandas DataFrame

In [None]:
import string
df = pd.DataFrame(np.random.randn(365,5), \
                  index=pd.date_range(start='2014-1-1', periods=365), \
                  columns=list(string.ascii_letters[:5]))

In [None]:
df.head()

#### from DataFrame

In [None]:
df_ds = xr.Dataset.from_dataframe(df)

In [None]:
df_ds

In [None]:
group = df_ds.groupby('index.month').mean('index')

In [None]:
group

#### converts TO a Pandas.DataFrame

In [None]:
group_df = group.to_dataframe()

In [None]:
# DataFrame.reindex_axis was removed in pandas 1.0; reindex(columns=...) is the equivalent
group_df.reindex(columns=list(string.ascii_letters[:5])).head()

In [None]:
df.groupby(df.index.month).mean().head()

### Opening a file through the network with OPeNDAP

In [None]:
url = 'http://www.esrl.noaa.gov/psd/thredds/dodsC/Datasets/interp_OLR/olr.mon.mean.nc'

In [None]:
olr_dset = xr.open_dataset(url)

In [None]:
olr_dset.time[-1]

#### the data are not loaded into memory until they are actually needed (e.g. when plotting), so _select_ a subset first

In [None]:
olr_sub = olr_dset.sel(time='1998-1-1',lat=slice(30,-30), lon=slice(170, 300))

In [None]:
olr_sub

In [None]:
m = bm(projection='merc',llcrnrlat=-30,urcrnrlat=30,\
            llcrnrlon=170,urcrnrlon=300,\
            lat_ts=0,resolution='c')

In [None]:
lons, lats = np.meshgrid(olr_sub['lon'], olr_sub['lat'])

In [None]:
plot_field(olr_sub['olr'].values, lats, lons, 80, 300, 10, grid=True)