# 0. Import needed Libraries

In [None]:
# the 3 main libraries
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from eofs.xarray import Eof as eof
# librairies to do nicer plots...
from matplotlib.patches import Rectangle    # only to draw a rectangle
import cartopy.crs as ccrs     # only for plots with projections and continents maps
import hvplot.xarray           # only for "fancy" plots

# 1. xarray basic manipulation:

 - Open file

In [None]:
# Load the monthly-mean SST NetCDF file as an xarray Dataset
sst_path = 'sst.mnmean.nc'
data = xr.open_dataset(sst_path)

 - explore [Dataset](http://xarray.pydata.org/en/stable/generated/xarray.Dataset.html)

In [None]:
# Bare last expression: the notebook renders the Dataset's repr
data  # file Dataset

In [None]:
# Coordinates attached to the Dataset
data.coords

In [None]:
# Number of steps along the 'time' dimension
data.sizes["time"]

In [None]:
# Data variables stored in the Dataset
data.data_vars

 - explore [DataArray](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.html)

In [None]:
# The 'sst' variable as a DataArray (dictionary-style access, equivalent to data.sst)
data["sst"]

In [None]:
# Shape, number of dimensions and dimension names of the SST array, one per line
print(data.sst.shape, data.sst.ndim, data.sst.dims, sep="\n")

data index selection: **[isel()](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.isel.html)**  

In [None]:
# Index-based selection of a single value: t=0, j=50 (lat), i=100 (lon)
data["sst"].isel({"time": 0, "lat": 50, "lon": 100})

data value selection: **[sel()](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.sel.html)**  

In [None]:
# Label-based selection: full time series at lon = 180.5, lat = 0.5
data["sst"].sel({"lon": 180.5, "lat": 0.5})

the **[slice](https://docs.python.org/3.8/library/functions.html#slice)** object  

In [None]:
# Warning: the slice bounds must follow the order of the latitude coordinate
# (0.5 comes before -0.5 here)
equator_band = slice(0.5, -0.5)
data.sst.sel(lon=180.5, lat=equator_band)

different syntaxes accepted for the time dimension

In [None]:
# By integer index ...
print(data.sst.isel(time=200).shape)
# ... or by label: full timestamp, date, month, year, or a range of years
for when in ('2000-01-01T00:00:00.000000000', '2000-01-01', '2000-01', '2000',
             slice('2000', '2002')):
    print(data.sst.sel(time=when).shape)

combine isel and sel selections

In [None]:
# First pick one time step by index, then select by coordinate value
sst_step = data.sst.isel(time=200)
sst_step.sel(lon=180.5, lat=slice(0.5, -0.5))

the **[GroupBy](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.groupby.html)** object  

In [None]:
# Group the time steps by calendar month and average each group over time
monthly_groups = data.sst.groupby('time.month')
monthly_groups.mean(dim='time').shape

# 2. reduce memory footprint

In [None]:
# Single selection doing both reductions at once:
#  - time: keep only the "complete" years 1982-2013
#  - lat : drop everything outside 60S-60N
data = data.sel(time=slice('1982', '2013'), lat=slice(59.5, -59.5))

# 3. a few basic plots
See xarray doc on [visualization](https://xarray-contrib.github.io/xarray-tutorial/scipy-tutorial/04_plotting_and_visualization.html)

SST time series on 2 different points, [pyplot options](https://matplotlib.org/api/_as_gen/matplotlib.pyplot.plot.html)

In [None]:
# Two stacked panels sharing the time axis
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
by_index = data.sst.isel(lon=100, lat=50)        # point chosen by grid indexes
by_value = data.sst.sel(lon=180.5, lat=0.5)      # point chosen by coordinates
by_index.plot(ax=axes[0])
by_value.plot(ax=axes[1])
plt.show()

Spatial average of monthly mean 

In [None]:
# Monthly climatology, then plain (unweighted) average over the whole grid
monthly_clim = data.sst.groupby('time.month').mean(dim='time')
monthly_clim.mean(dim=('lon', 'lat')).plot()

Get the land-sea mask

In [None]:
# Load the land-sea mask, then in one chain:
#  - keep 60S-60N only (reduce memory footprint, same band as the SST data)
#  - squeeze() to drop the length-1 time dimension of mask.mask
mask = (
    xr.open_dataset('lsmask.nc')
    .sel(lat=slice(59.5, -59.5))
    .squeeze()
)
mask

Simple SST maps without and with masked values.
[colorbars list](https://matplotlib.org/tutorials/colors/colormaps.html)

In [None]:
fig, axes = plt.subplots(nrows=2, ncols=1)
# Top panel: raw field for one time step, default colour scale
sst2d = data.sst.isel(time=100)
sst2d.plot(ax=axes[0])
# Bottom panel: same field, keeping only the points where the mask equals 1
sst2d = sst2d.where(mask.mask == 1.)
sst2d.plot(ax=axes[1], cmap='gist_rainbow_r', vmin=-2, vmax=32)
plt.show()

Map with continents and [projection](https://scitools.org.uk/cartopy/docs/latest/crs/projections.html) 

In [None]:
# Masked SST for one time step, drawn on a PlateCarree map centred on longitude 200
map_kws = {
    'projection': ccrs.PlateCarree(central_longitude=200),
    'facecolor': 'gray',
}
sst_snapshot = data.sst.isel(time=200).where(mask.mask == 1.)
p = sst_snapshot.plot(
    cmap='gist_rainbow_r', vmin=-2, vmax=32,
    transform=ccrs.PlateCarree(),      # CRS the lon/lat data are expressed in
    subplot_kws=map_kws,
)
p.axes.set_global()
p.axes.coastlines()
p.axes.gridlines()

Widget maps with continents and projection

In [None]:
# Map projection for the widget (alternative: ccrs.PlateCarree(central_longitude=200))
proj = ccrs.LambertCylindrical(central_longitude=200)

# Time steps 1..12, masked, browsed with a scrubber widget placed under the map
sst_frames = data.sst.isel(time=slice(1, 13)).where(mask.mask == 1.)
sst_frames.hvplot.quadmesh(
    'lon', 'lat',
    projection=proj, project=True, global_extent=True, coastline=True,
    clim=(-2, 32), cmap='gist_rainbow_r',
    rasterize=True, dynamic=False,
    frame_width=500, widget_type='scrubber', widget_location='bottom',
)

Longitude-time plot around the equator "[hovmoller diagram](https://fr.wikipedia.org/wiki/Diagramme_de_Hovm%C3%B6ller)"

In [None]:
# Average the latitude band 1.5N-1.5S, leaving a (time, lon) field to plot
equatorial_sst = data.sst.sel(lat=slice(1.5, -1.5)).mean(dim='lat')
equatorial_sst.plot()

# 4. basic computation:

Compute spatial weight:  
Earth is not flat! The area of a regular lat-lon grid cell varies with the cosine of the latitude (see [spherical coordinates](https://mathinsight.org/spherical_coordinates)).  
  -> weight data with mask * cos( lat )

In [None]:
# Spatial weights = mask * cos(latitude). `mask` is a Dataset, so the product is a
# Dataset too: the weight field itself is weights.mask.
weights = mask * np.cos(np.deg2rad(mask.lat))
weights

Visualize the weights

In [None]:
# Top: the land-sea mask; bottom: the same mask scaled by cos(lat)
fig, axes = plt.subplots(nrows=2, ncols=1, sharex=True)
mask["mask"].plot(ax=axes[0])
weights["mask"].plot(ax=axes[1])

In [None]:
# Weighted mean over time and longitude: one value per latitude
sst_by_lat = data.sst.weighted(weights.mask).mean(("time", "lon"))
sst_by_lat.plot()

Comparison of horizontal mean: masked and [weighted](http://xarray.pydata.org/en/stable/computation.html?highlight=rolling#weighted-array-reductions)

In [None]:
# Compare three horizontal means of the SST.
# The cos(lat)-only weights get their own name so that the shared `weights`
# Dataset (mask * cos(lat), defined earlier and used by the following cells as
# `weights.mask`) is never rebound to a different kind of object.
coslat = np.cos(np.deg2rad(data.lat))

# 1. plain arithmetic mean over all grid points
data.sst.mean(("lon", "lat")).plot(label="unweighted")
# 2. weighted by cos(lat) only
data.sst.weighted(coslat).mean(("lon", "lat")).plot(label="weighted")
# 3. weighted by mask * cos(lat)
data.sst.weighted(weights.mask).mean(("lon", "lat")).plot(label="weighted+mask")
plt.legend()

Mean seasonal cycle

In [None]:
# Mean seasonal cycle: average of each calendar month over all years
a = data.sst.groupby('time.month').mean(dim='time')
fig, axes = plt.subplots(3,1, figsize=(10, 10), constrained_layout=True)
# Panel 1: mask- and area-weighted global mean (label/title spelling fixed: "global")
a.weighted(weights.mask).mean(dim=('lon','lat')).plot(label="global mean", ax=axes[0])
axes[0].set_title('global mean')
# Panel 2: single grid point at lon=150.5, lat=2.5
a.sel(lon=150.5, lat=2.5).plot(ax=axes[1])
axes[1].set_title('West Eq. Pac.')
# Panel 3: single grid point at lon=300.5, lat=45.5
a.sel(lon=300.5, lat=45.5).plot(ax=axes[2])
axes[2].set_title('North-West Atl.')

[Rolling](http://xarray.pydata.org/en/stable/computation.html?highlight=rolling#rolling-window-operations) mean

In [None]:
# Monthly global-mean SST (mask * cos(lat) weights) and centred rolling means of it
a = data.sst.weighted(weights.mask).mean(("lon", "lat"))
a.plot(label="weighted+mask")
# Windows are expressed in months; center=True centres each window on its label
a.rolling(time=12, center=True).mean().plot(label="rolling 1 year")
b = a.rolling(time=12*5, center=True).mean()
b.plot(label="rolling 5 years")
c = a.rolling(time=12*10, center=True).mean()
c.plot(label="rolling 10 years")
# One legend call, after every curve has been drawn
plt.legend()

Global yearly mean

In [None]:
# Yearly mean at every grid point, then mask- and area-weighted global average
a = data.sst.groupby('time.year').mean(dim='time')
yearly_global = a.weighted(weights.mask).mean(dim=('lon', 'lat'))
yearly_global.plot(label="weighted+mask")
plt.legend()

Total standard deviation : [std](http://xarray.pydata.org/en/stable/generated/xarray.Dataset.std.html?highlight=std) method 

In [None]:
# Standard deviation over time at each grid point, kept where the mask equals 1
a = data.sst.std(dim='time').where(mask.mask == 1.)
fig, axes = plt.subplots(nrows=2, ncols=1)
std_scale = dict(vmin=0, vmax=6)                                  # same colour range in both panels
zoom_box = dict(lon=slice(260.5, 320.5), lat=slice(50.5, 10.5))   # sub-region for the lower panel
a.plot(ax=axes[0], **std_scale)
a.sel(**zoom_box).plot(ax=axes[1], **std_scale)

Seasonal standard deviation

In [None]:
# Standard deviation of the 12-month climatology, i.e. the spread of the mean seasonal cycle
a = data.sst.groupby('time.month').mean(dim='time')
seasonal_std = a.std(dim='month').where(mask.mask == 1.)
seasonal_std.plot(vmin=0, vmax=6)

# 5. linear trend:

Long term mean trend ([linear regression](http://xarray.pydata.org/en/stable/computation.html?highlight=polyval#fitting-polynomials) with [polyfit](http://xarray.pydata.org/en/stable/generated/xarray.DataArray.polyfit.html?highlight=polyfit) and [polyval](http://xarray.pydata.org/en/stable/generated/xarray.polyval.html?highlight=polyval) methods)

In [None]:
# Degree-1 least-squares fit along time at every grid point
linfit = data.sst.polyfit('time', 1)
# Evaluate the fitted line at every time step
trend = xr.polyval(coord=data.time, coeffs=linfit.polyfit_coefficients)
# Conversion factor: nanoseconds in a century (365-day years)
ns_century = 1.e9 * 3600. * 24. * 365. * 100.
# Slope selected by label (degree=1), scaled to "per century", kept where the mask equals 1
slope_century = linfit.polyfit_coefficients.sel(degree=1) * ns_century
slope_century.where(mask.mask == 1.).plot(vmin=-5, vmax=5, cmap='RdBu_r')

In [None]:
# Yearly, mask/area-weighted global mean of the original SST and of its linear trend
for field, curve_label in ((data.sst, "org"), (trend, "trend")):
    a = field.groupby('time.year').mean(dim='time')   # yearly mean
    a.weighted(weights.mask).mean(dim=('lon', 'lat')).plot(label=curve_label)
plt.legend()

Detrend SST

In [None]:
# Subtract the fitted trend using xarray's label-based arithmetic: operands are
# matched by dimension name and coordinate, so the result does not depend on the
# dimension order returned by polyval (subtracting `trend.values` would).
# rename() keeps the variable name of the original field on the result.
sst_detrend = (data.sst - trend).rename(data.sst.name)

In [None]:
# Yearly global-mean SST before (left) and after (right) removing the linear trend
fig, axes = plt.subplots(1,2, figsize=(6,3), constrained_layout=True)
a = data.sst.groupby('time.year').mean(dim='time')   # yearly mean
a.weighted(weights.mask).mean(dim=('lon','lat')).plot(ax=axes[0])
axes[0].set_title('With trend')
a = sst_detrend.groupby('time.year').mean(dim='time')   # yearly mean
a.weighted(weights.mask).mean(dim=('lon','lat')).plot(ax=axes[1])
# Title belongs to the right-hand panel (it was previously set on axes[0],
# which overwrote 'With trend' and left this panel untitled)
axes[1].set_title('Without trend')
plt.show()

Detrended interannual anomaly

In [None]:
# Group the detrended SST by calendar month ...
sstmth = sst_detrend.groupby("time.month")
# ... and remove each month's own long-term mean from its members
sstclim = sstmth.mean("time")
sstanom = sstmth - sstclim

interannual anomaly (with/without trend)

In [None]:
# Global-mean series of the original (non-detrended) SST
a = data.sst.weighted(weights.mask).mean(dim=('lon', 'lat'))
# Its monthly climatology and the anomaly relative to it
by_month = a.groupby("time.month")
clim = by_month.mean("time")
anom = by_month - clim
anom.plot(label="org")
# Same global mean, computed from the detrended anomaly field
sstanom.weighted(weights.mask).mean(dim=('lon', 'lat')).plot(label="detrended")
plt.legend()

# 6. Regression onto Nino3.4:

Interannual standard deviation, with [nino3.4](https://www.ncdc.noaa.gov/teleconnections/enso/indicators/sst/) box (5S-5N and 170-120W).

In [None]:
# Standard deviation over time of the anomaly, kept where the mask equals 1
anom_std = sstanom.std(dim='time').where(mask.mask == 1.)
anom_std.plot(cmap='YlGnBu', vmin=0, vmax=2)
# Overlay the nino3.4 box (170W-120W, 5S-5N) on the axes that was just drawn
ax = plt.gca()
rect = Rectangle(
    (360-170, -5),      # lower-left corner: 170W expressed in 0-360 longitude, 5S
    50, 10,             # width (deg lon), height (deg lat)
    fill=False, edgecolor='r', linewidth=1,
)
ax.add_patch(rect)

Equatorial hovmoller of the SST interannual anomaly

In [None]:
fig, axes = plt.subplots(1, 1, figsize=(5, 10))
# Same lat/lon window for the anomaly and for its weights
eq_window = dict(lat=slice(2.5, -2.5), lon=slice(100.5, 285.5))
ssteq = sstanom.sel(**eq_window)
wgteq = weights.sel(**eq_window)
# Weighted meridional mean, drawn on the axes created above
ssteq.weighted(wgteq.mask).mean(dim='lat').plot(ax=axes)

Compute nino3.4 (5S-5N and 170-120W) index

In [None]:
# nino3.4 box (170W-120W, 5S-5N) given as grid-point coordinates in 0-360 longitude
nino34_box = dict(lon=slice(360-169.5, 360-120.5), lat=slice(4.5, -4.5))
nino34 = sstanom.sel(**nino34_box)
weights_nino34 = weights.mask.sel(**nino34_box)
# Average the box subset itself. Previously the full `sstanom` field was passed and
# `nino34` was left unused, so the result relied on implicit coordinate alignment
# with the smaller weights array to restrict the mean to the box.
nino34_index = nino34.weighted(weights_nino34).mean(dim=('lon','lat'))
nino34_index.plot()

SST anomalies regressed onto Nino3.4 SST

In [None]:
# Regression of the SST anomaly onto the nino3.4 index.
# polyfit fits against the coordinate values of the chosen dimension, so the fit is
# done on a temporary copy whose 'time' coordinate holds the index values.
# `sstanom` itself is never modified, so a failure here cannot leave it with a
# corrupted time axis.
sstanom_vs_nino34 = sstanom.assign_coords(time=nino34_index.values)
linfit = sstanom_vs_nino34.polyfit('time', 1)
sstreg = linfit.polyfit_coefficients.sel(degree=1)          # slope, selected by label
# Point-wise correlation between the anomaly and the index
sstcor = xr.corr(sstanom, nino34_index, dim='time')
#
fig, axes = plt.subplots(1,2,figsize=(14, 5))
sstreg.where(mask.mask == 1.).plot(ax=axes[0])              # plot the regression coefficient
sstcor.where(mask.mask == 1.).plot(ax=axes[1])              # plot the correlation
for panel, panel_title in zip(axes, ('Regression coefficient', 'Correlation')):
    # nino3.4 rectangle (170W-120W, 5S-5N); a patch can only belong to one Axes,
    # hence a fresh Rectangle per panel
    rect = Rectangle((360-170,-5),50,10,linewidth=1,edgecolor='w',fill=False)
    panel.add_patch(rect)                                   # overlay
    panel.set_title(panel_title)
axes[1].set_xlabel('lon')
axes[1].set_ylabel('lat')