# Computation with xarray


## Learning Objectives

- Do basic arithmetic with DataArrays and Datasets
- Perform aggregation or reduction along one or multiple dimensions of a
  DataArray or Dataset


## Arithmetic Operations

Arithmetic operations with a single DataArray automatically vectorize (like numpy) over all array values. 

![](https://jakevdp.github.io/PythonDataScienceHandbook/figures/02.05-broadcasting.png)

Image Credit: Jake VanderPlas, Python Data Science Handbook

We will demonstrate this with an example dataset containing monthly means of sea surface temperature (SST). The data are in degrees Celsius, and we want to convert them to kelvin.


In [None]:
import xarray as xr
import matplotlib.pyplot as plt
import numpy as np

In [None]:
ds = xr.open_dataset('sst.mnmean.nc')
da = ds['sst']
da # check that units are degC

In [None]:
# convert to kelvin
sst_k = da + 273.15
sst_k

You will notice a few changes in this new *local* `DataArray`. 
- the Python variable is called `sst_k`, but the `name` of the DataArray is still `sst`
- the attributes have disappeared but the coordinates are still associated
- the data have been modified as required

This new DataArray is not part of the Dataset. The following cell adds it to the Dataset as another variable under the new name. The attributes are not carried over, so you need to copy them manually and modify them accordingly:

In [None]:
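ZEROK = -273.15  # absolute zero expressed in degrees Celsius
ds['sst_k'] = da - ZEROK
print(ds.data_vars)
# copy the attributes from the original variable, then update the unit-dependent ones
ds.sst_k.attrs = ds.sst.attrs
ds.sst_k.attrs['units'] = 'K'
ds.sst_k.attrs['actual_range'] = ds.sst.attrs['actual_range'] - ZEROK
ds.sst_k.attrs['valid_range'] = ds.sst.attrs['valid_range'] - ZEROK
ds.sst_k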
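As an alternative to copying attributes by hand, xarray can be asked to carry them through arithmetic with `xr.set_options(keep_attrs=True)`. The sketch below uses a small synthetic array instead of the SST file (the values and the `long_name` attribute are made up for illustration). The units still have to be corrected manually, because xarray does not interpret them.

```python
import numpy as np
import xarray as xr

# Synthetic stand-in for the SST DataArray (values and attributes are illustrative)
sst = xr.DataArray(
    np.array([[10.0, 15.0], [20.0, 25.0]]),
    dims=("lat", "lon"),
    coords={"lat": [-10.0, 10.0], "lon": [0.0, 2.0]},
    name="sst",
    attrs={"units": "degC", "long_name": "sea surface temperature"},
)

# Keep the attributes through the arithmetic operation
with xr.set_options(keep_attrs=True):
    sst_k = sst + 273.15

# The attributes were copied as-is, so the units must still be fixed by hand
sst_k.attrs["units"] = "K"

print(sst_k.name, sst_k.attrs)
print(sst.attrs["units"])  # the original array keeps its own attributes
```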

## Aggregation (Reduction) Methods

Xarray supports many of the aggregation methods that numpy implements. By _aggregation or reduction methods_ we mean all the operations that reduce the number of dimensions of an array. When you compute the arithmetic mean of a series of 10 values, you go from a 1D array to a 0D scalar: the 10 numbers are aggregated into 1 value.

A partial list of methods includes: `all, any, argmax, argmin, max, mean, median, min, prod, sum, std, var`.

The power of xarray is that, whereas numpy requires integer axis numbers (i.e. 0, 1, 2, etc.), **xarray can use dimension names**. In the following code, the 3D object `sst` is reduced to a 2D array by computing the long-term climatological annual mean. The argument `dim` of the `mean()` method accepts the dimension labels.

In [None]:
da_mean = da.mean(dim='time')
print(da_mean)
# the following command is an example of calling other matplotlib functions from xarray methods
da_mean.plot.contourf(levels=np.arange(0.,30.,2.),cmap='turbo')

This plot uses one of the many [matplotlib colormaps](https://matplotlib.org/stable/tutorials/colors/colormaps.html).
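To make the difference between integer axes and dimension names concrete, here is a minimal sketch on a synthetic `(time, lat, lon)` array (the array `toy` and its random values are assumptions for illustration, not the SST data). The named reduction gives the same result as the numpy one, and keeps working after the dimensions are reordered.

```python
import numpy as np
import xarray as xr

# Synthetic (time, lat, lon) array standing in for the SST data
rng = np.random.default_rng(0)
data = rng.normal(size=(24, 3, 4))
toy = xr.DataArray(data, dims=("time", "lat", "lon"))

# numpy: you must remember that time is axis 0
mean_np = data.mean(axis=0)

# xarray: the dimension is addressed by name, wherever it sits in the array
mean_xr = toy.mean(dim="time")
mean_xr_t = toy.transpose("lat", "lon", "time").mean(dim="time")

print(mean_xr.dims)                          # ('lat', 'lon')
print(np.allclose(mean_xr.values, mean_np))  # True
```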

Aggregation also works on **multiple dimensions**. Aggregating over the two spatial dimensions returns a 1D object: in the case shown below, a time series of the spatial standard deviation of SST over the global ocean for each month:

In [None]:
sst_std = da.std(dim=['lat', 'lon'])
print(sst_std)
sst_std.plot(figsize=(10,8),marker='o')
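One detail worth keeping in mind for SST-like fields, where land points are typically stored as NaN: for floating-point data, xarray reductions skip NaN values by default, unlike the plain numpy functions. The following sketch uses a made-up array with one missing grid cell. Like the cell above, it is an unweighted statistic.

```python
import numpy as np
import xarray as xr

# Synthetic (time, lat, lon) array; one grid cell is missing (NaN), like a land point
data = np.arange(24, dtype=float).reshape(2, 3, 4)
data[:, 0, 0] = np.nan
toy = xr.DataArray(data, dims=("time", "lat", "lon"))

# Reducing over both spatial dimensions leaves a 1D time series
spatial_std = toy.std(dim=["lat", "lon"])

# Plain numpy propagates the NaN; np.nanstd matches the xarray default
std_np = data.std(axis=(1, 2))
std_nan = np.nanstd(data, axis=(1, 2))

print(spatial_std.dims)    # ('time',)
print(spatial_std.values)  # finite values
print(std_np)              # [nan nan]
```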

## Broadcasting
**Broadcasting** allows an operator or a function to act on two or more arrays
even if these arrays do not have the same shape.

Not every combination of shapes can be broadcast, however: the dimensions must meet certain rules. In xarray, arrays are matched by dimension name rather than by position.
The image below illustrates how performing an operation on arrays with different dimensions results in automatic broadcasting.

![](https://tutorial.xarray.dev/_images/broadcasting.png)

Image Credit: Stephan Hoyer

In [None]:
da.shape, da.dims

In [None]:
da_mean.shape, da_mean.dims

The following operation subtracts the mean (2D array) from the original array (3D array) to obtain anomalies relative to the long-term mean. Under the hood, the climatological mean is broadcast along `time` to create a 3D object that can be subtracted from the original array.

**Note**: this operation creates a monthly anomaly that retains the seasonal signal, because the same long-term annual mean is subtracted from every month. See the groupby section below for computing monthly anomalies with the climatological seasonal cycle removed (a repeated cycle containing the 12 climatological months is subtracted from the original data).

In [None]:
anom = da - da_mean
anom
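The same mechanism can be checked on small synthetic arrays, where the shapes are easy to follow. In this sketch (all array names and values are illustrative), two 1D arrays on different dimensions combine into a 2D array, and a 2D time mean is broadcast along `time` when subtracted from a 3D array.

```python
import numpy as np
import xarray as xr

lat = xr.DataArray([10.0, 20.0, 30.0], dims="lat")
lon = xr.DataArray([1.0, 2.0], dims="lon")

# Two 1D arrays on different dimensions broadcast to a 2D (lat, lon) array
grid = lat + lon

# A 3D array minus its time mean: the 2D mean is broadcast along time
cube = xr.DataArray(np.arange(24.0).reshape(4, 3, 2), dims=("time", "lat", "lon"))
cube_anom = cube - cube.mean(dim="time")

print(grid.dims, grid.shape)            # ('lat', 'lon') (3, 2)
print(cube_anom.dims, cube_anom.shape)  # ('time', 'lat', 'lon') (4, 3, 2)
```

By construction, the time mean of `cube_anom` is zero at every grid point.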

## High level computation: groupby, resample, rolling

Xarray has some very useful high level objects that let you do common
computations:

- `groupby` :
  [Bin data into groups and reduce](https://xarray.pydata.org/en/stable/groupby.html)
- `resample` :
  [Groupby specialized for time axes. Either downsample or upsample your data](https://xarray.pydata.org/en/stable/time-series.html#resampling-and-grouped-operations).
- `rolling` :
  [Operate on rolling windows of your data e.g. running mean](https://xarray.pydata.org/en/stable/computation.html#rolling-window-operations)



### The _groupby_ method
This is a very powerful method, modeled on the pandas method of the same name. It is easier to first understand it using pandas DataFrames before applying it to Datasets.

Single DataArrays and whole Datasets can be binned into groups based on certain criteria, for instance by season or by any other time period.
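For readers who want to see the pandas version first, here is a minimal sketch of the split-and-reduce idea on a DataFrame. The column name `value` and the numbers are made up for illustration.

```python
import numpy as np
import pandas as pd

# Two years of made-up monthly values: 1..12 in the first year, 11..22 in the second
index = pd.date_range("2000-01-01", periods=24, freq="MS")
values = np.concatenate([np.arange(1.0, 13.0), np.arange(11.0, 23.0)])
df = pd.DataFrame({"value": values}, index=index)

# Split the rows by calendar month, then reduce each group with the mean
monthly_mean = df.groupby(df.index.month)["value"].mean()

print(monthly_mean.loc[1])   # 6.0  (mean of 1 and 11)
print(monthly_mean.loc[12])  # 17.0 (mean of 12 and 22)
```

The result has one row per group (12 calendar months), which is the same pattern xarray follows when it creates a new dimension for the grouping variable.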

Check out the examples below that compute the *climatological seasonal mean* and compare the dimensions of the created objects. A new dimension is created, reflecting the grouping criterion.

In [None]:
ds

In [None]:
# seasonal groups
ds.groupby('time.season')

In [None]:
# day of the week groups
ds.groupby('time.dayofweek')

In [None]:
# compute a climatological seasonal mean by applying an aggregator on the group
seasonal_mean = ds.groupby('time.season').mean()
seasonal_mean

Note that the seasons are out of order (they are alphabetically sorted). This is a common annoyance, which does not preclude the use of the data if you are extracting them through the labels.

However, it may be a problem if you want to use one great feature of xarray visualization: [faceting](http://xarray.pydata.org/en/stable/plotting.html#faceting). **Faceting** allows you to display the same kind of plot repeated multiple times, along one dimension. The order of the dimension is followed, and hence in this case we would plot the seasons in the wrong order.

The solution is to use the method `reindex`, which rearranges the index of your Dataset. Faceting can then be applied by choosing the dimension to show in the columns (`col`) and the number of columns after which to *wrap* (`col_wrap`).

In [None]:
seasonal_mean = seasonal_mean.reindex(season=['DJF', 'MAM', 'JJA', 'SON'])
seasonal_mean.sst.plot(col='season', col_wrap=2, cmap='turbo', vmin=0., vmax=28.)
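The reordering step can be verified without plotting. The sketch below builds a synthetic monthly series whose value is simply the month number (an assumption chosen so the seasonal means are easy to check by hand) and reindexes the seasons into calendar order.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic monthly data; the value is simply the month number
time = pd.date_range("2000-01-01", periods=24, freq="MS")
toy = xr.DataArray(time.month.values.astype(float), dims="time", coords={"time": time})

season_mean = toy.groupby("time.season").mean()
print(season_mean.season.values)  # order as returned by groupby

# Put the seasons in calendar order before plotting or faceting
ordered = season_mean.reindex(season=["DJF", "MAM", "JJA", "SON"])
print(ordered.season.values)
print(ordered.values)  # DJF=5 (months 12, 1, 2), MAM=4, JJA=7, SON=10
```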

The groupby method also allows more advanced broadcasting, for instance computing monthly anomalies by removing the climatological seasonal cycle. This is a typical method in climate science for analysing the interannual variability of a signal. For example, we can check how different the mean SST in January 2016 was from the climatological January mean, and likewise for every individual month. Remember that, to obtain a meaningful anomaly, the climatology should be computed over a period long enough to capture the natural climate variability (usually 30 years).

In the following, a repeated cycle containing the 12 climatological months is subtracted from the original data (still grouped by month), and we extract a point to check the time series. Notice that now the range is centred on zero.

In [None]:
monclim = da.groupby('time.month').mean(dim='time')
print(monclim.shape)
monanom = da.groupby('time.month')-monclim
monanom.sel(lat=-38,lon=12).plot(figsize=(12,8),marker='o')
plt.grid()
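To see what the grouped subtraction does, it helps to apply it to a series whose answer is known in advance. This sketch builds a synthetic monthly series from a fixed seasonal cycle plus a different constant offset for each year (all values are made up), so the anomalies should recover exactly the yearly offsets.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Synthetic monthly series: seasonal cycle plus a different offset for each year
time = pd.date_range("2000-01-01", periods=36, freq="MS")
seasonal = 5.0 * np.sin(2 * np.pi * (time.month.values - 1) / 12)
offset = np.repeat([-1.0, 0.0, 1.0], 12)  # interannual signal, one value per year
toy = xr.DataArray(20.0 + seasonal + offset, dims="time", coords={"time": time})

# 12-value climatology, then subtract the matching month from every time step
clim = toy.groupby("time.month").mean(dim="time")
toy_anom = toy.groupby("time.month") - clim

print(clim.sizes["month"])               # 12
print(np.allclose(toy_anom.values, offset))  # True: only the yearly offsets remain
```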

### The _resample_ method
Another operation modeled on [pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html#resampling) is resampling. It is another way to aggregate data and to compute statistics on downsampled values (upsampling is also possible, for example to create a higher-frequency time series by linear interpolation). The key parameter is the frequency string (the *rule* in pandas), which sets the target frequency.

In [None]:
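# resample to annual frequency. Check the difference between using '1Y' (year end) and '1YS' (year start)
# note: recent pandas versions spell the year-end alias 'YE' instead of 'Y'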
ym=ds.sst.resample(time="1YS").mean()
ym

In [None]:
ym.sel(lat=-36,lon=10).plot(marker='s')
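The effect of resampling is easy to check on a synthetic series. In this sketch (values made up for illustration), 24 monthly values are reduced to two annual means, each labeled with the first day of its year because of the `"YS"` frequency. A `groupby` on `time.year` gives the same numbers but replaces the time axis with an integer `year` dimension.

```python
import numpy as np
import pandas as pd
import xarray as xr

# Two years of synthetic monthly values: 0..11 in 2000, 12..23 in 2001
time = pd.date_range("2000-01-01", periods=24, freq="MS")
toy = xr.DataArray(np.arange(24.0), dims="time", coords={"time": time})

# "YS" labels each annual mean with the first day of the year
annual = toy.resample(time="YS").mean()

# Same numbers with groupby, but indexed by an integer year instead of a datetime
by_year = toy.groupby("time.year").mean()

print(annual.time.values)  # 2000-01-01 and 2001-01-01
print(annual.values)       # 5.5 and 17.5
print(by_year.dims)        # ('year',)
```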

### Performing operations on rolling windows
The method `rolling` can be used to compute running means. The rolling window is expressed as a number of time steps (in this case months). By default, each mean is labeled at the end of its window, so the first 12 values are missing and the line starts from month 13. You can change this behaviour by adding the parameter `center=True`, which labels each mean at the center of its window.

In [None]:
# A rolling mean with a window size of 13 months, to cover the full year
axis = plt.axes()
roll=ds.sst.rolling(time=13).mean()
roll.sel(lat=-36,lon=10).plot(ax=axis,label='smoothed (1 year)')
ds.sst.sel(lat=-36,lon=10).plot(ax=axis,label='original',marker='+')
plt.legend()
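The difference between the default labeling and `center=True` can be seen by counting the missing values. This sketch uses a synthetic linear series of 36 months (an assumption for illustration), for which a centered running mean reproduces the original values.

```python
import numpy as np
import pandas as pd
import xarray as xr

time = pd.date_range("2000-01-01", periods=36, freq="MS")
toy = xr.DataArray(np.arange(36.0), dims="time", coords={"time": time})

# Default: each value is labeled at the end of its 13-month window
trailing = toy.rolling(time=13).mean()
# center=True: each value is labeled at the middle of its window
centered = toy.rolling(time=13, center=True).mean()

print(int(trailing.isnull().sum()))       # 12 missing values, all at the start
print(int(centered[:6].isnull().sum()))   # 6 missing values at the start
print(int(centered[-6:].isnull().sum()))  # 6 missing values at the end
```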

## Going Further


<div class="alert alert-block alert-success">
  <p>Computation with xarray (extended version): <a href="https://xarray-contrib.github.io/xarray-tutorial/scipy-tutorial/03_computation_with_xarray.html">
      Computation with xarray notebook</a></p>
  <p>Plotting and visualization (extended version): <a href="https://xarray-contrib.github.io/xarray-tutorial/scipy-tutorial/04_plotting_and_visualization.html">Plotting and Visualization notebook</a></p>
</div>
