# 0. Import the required libraries

In [None]:
# the 3 main libraries
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from eofs.xarray import Eof as eof
# librairies to do nicer plots...
from matplotlib.patches import Rectangle    # only to draw a rectangle

### 1. Get data:

In [None]:
# SST product used throughout the notebook; one of 'erssst', 'oisst', 'oisst025'.
sstproduct = 'oisst'

In [None]:
# Load the monthly SST dataset selected by `sstproduct` and build `mask`,
# which is 1 where SST is valid (ocean) and 0 elsewhere.
if sstproduct == "erssst":

    data = xr.tutorial.open_dataset('ersstv5')
    data = data.sel(lat=slice(60., -60.))          # get rid of data out of 60S-60N
    data = data.sel(time=slice('1970','2021'))     # select years between 1970 and 2021
    # land is NaN in this product: derive the mask from the first time step
    mask = data.sst.isel(time=0)
    mask.data = np.where(np.isnan(mask.data), 0., 1.)

elif sstproduct == "oisst":

    data = xr.open_dataset('sst.mnmean.nc')
    data = data.sel(time=slice('1982','2022'))     # select years between 1982 and 2022
    data_mask = xr.open_dataset('lsmask.nc')
    mask = data_mask.mask                          # get the DataArray from the Dataset
    mask = mask.squeeze()                          # get rid of the time dimension in the variable mask

elif sstproduct == "oisst025":

    data = xr.open_dataset('sst.mon.mean.nc')
    data = data.sel(time=slice('1981-11','2023-11'))     # select Nov 1981 to Nov 2023
    # land is NaN in this product: derive the mask from the first time step
    mask = data.sst.isel(time=0)
    mask.data = np.where(np.isnan(mask.data), 0., 1.)

else:
    # Stop here: every later cell needs `data` and `mask`, and continuing would
    # only fail later with an unrelated NameError.
    raise ValueError(f"wrong name of SST product: {sstproduct!r}")

In [None]:
# Area weights: cos(latitude), zeroed where the mask marks invalid (land) points.
weights = mask * np.cos(np.deg2rad(mask.lat))

In [None]:
# Show the SST dataset (variables, coordinates) as selected by the loading cell
data

### Detrend SST

In [None]:
# Fit a straight line in time at every grid point...
linfit = data.sst.polyfit(dim='time', deg=1)
# ...evaluate it on the time axis (SST trend)...
trend = xr.polyval(data.time, linfit.polyfit_coefficients)
# ...and remove it, adding back the time mean so the field keeps its level.
sst_detrend = data.sst - trend.values + data.sst.mean(dim='time')

### Detrended interannual anomaly

In [None]:
# Interannual anomaly: subtract, month by month, the detrended seasonal cycle.
sstbymth = sst_detrend.groupby("time.month")
mthclim = sstbymth.mean("time")                      # detrended climatological months (seasonal cycle)
sstanom = (sstbymth - mthclim).rename('sstanom')     # detrended interannual anomaly, renamed

### Compute the Nino3.4 index (5S-5N, 170W-120W)

In [None]:
# Nino3.4 index: area-weighted mean SST anomaly over 5S-5N, 170W-120W.
# The slice bounds sit half a degree inside the box edges (4.5, 190.5, 239.5).
# oisst025 latitudes are sliced south->north, the other products north->south.
latslice = slice(-4.5, 4.5) if sstproduct == "oisst025" else slice(4.5, -4.5)

weights_nino34 = weights.sel(lon=slice(360 - 169.5, 360 - 120.5), lat=latslice)
nino34_index = (
    sstanom.weighted(weights_nino34)
    .mean(dim=('lon', 'lat'))
    .rename('nino34')
)
nino34_index.plot()

### Cut out the tropical Indian Ocean (30E-120E, 20S-20N)

In [None]:
# Restrict SST, anomalies and mask to 30E-120E, 20S-20N, then rebuild the
# area weights on the reduced grid.
latslice = slice(-20, 20) if sstproduct == "oisst025" else slice(20, -20)
_io_box = {'lon': slice(30, 120), 'lat': latslice}

data = data.sel(_io_box)
sstanom = sstanom.sel(_io_box)
mask = mask.sel(_io_box)
weights = mask * np.cos(np.deg2rad(mask.lat))

### SST anomalies regressed onto Nino3.4 SST

In [None]:
# Regression coefficient and correlation of the SST anomaly onto Nino3.4.
# Least-squares slope b = cov(N, T) / var(N), computed per grid point over the
# months where both series are valid. This is the slope polyfit returns, but it
# no longer overwrites the global `sstanom`'s time coordinate with index
# values; a failure in between would have left `sstanom` corrupted.
_sst, _nino = xr.align(sstanom, nino34_index, join='inner')
_valid = _sst.notnull() & _nino.notnull()
sstreg = xr.cov(_sst, _nino, dim='time') / _nino.where(_valid).var(dim='time', ddof=1)
sstcor = xr.corr(sstanom, nino34_index, dim='time')

#  create a 2 panel figure
fig, axes = plt.subplots(1,2,figsize=(14, 4))
# Overlay the IOD west (50E-70E, 10S-10N) and east (90E-110E, 10S-0) boxes as
# (lon, lat) lower-left corner + width, height. The Nino3.4 box (190E-240E)
# falls outside this 30E-120E domain and so is not drawn.
iod_boxes = [(50, -10, 20, 20), (90, -10, 20, 10)]
for ax, field, title in zip(axes, (sstreg, sstcor), ('Regression coefficient', 'Correlation')):
    field.where(mask == 1.).plot(ax=ax)                # ocean points only
    for x0, y0, width, height in iod_boxes:
        ax.add_patch(Rectangle((x0, y0), width, height, linewidth=1, edgecolor='w', fill=False))
    ax.set_title(title)
axes[1].set_xlabel('lon')
axes[1].set_ylabel('lat')

# 4. EOF of SST anomalies:

### Create an [Eof](https://ajdawson.github.io/eofs/latest/api/eofs.xarray.html#eofs.xarray.Eof) object

In [None]:
# EOF solver for the Indian Ocean SST anomalies.
# eofs multiplies the data by `weights` before building the covariance matrix;
# its documentation uses sqrt(cos(lat)) for area weighting. Passing cos(lat)
# itself would weight each point's variance by cos(lat)**2. Masked points keep
# a zero weight, since sqrt(0) == 0.
solver = eof(sstanom, weights=np.sqrt(weights))

### Variance explained by each EOF/PC pair

In [None]:
# Fraction of variance explained by each of the 10 leading EOF/PC pairs
_varfrac10 = solver.varianceFraction(neigs=10)
_varfrac10.plot.step(where='mid')

###  Plot the first 2 EOFs, PCs and Variance Fraction:  

In [None]:
# eof = linear_regression_coef(PC,sstanom) = Cov(PC,sstanom)/Var(PC)
# pcscaling=1 (set explicitly for both calls) --> Var(PC) = 1 --> eof = Cov(PC,sstanom)
eofs = solver.eofsAsCovariance(neofs=2, pcscaling=1)      # beware of syntax: capital C in eofsAsCovariance
pcs  = solver.pcs(npcs=2, pcscaling=1)
varfrac = solver.varianceFraction(neigs=2)                # beware of syntax: capital F in varianceFraction

# create a 4 panel figure: EOF maps on the top row, PCs below
fig, axes = plt.subplots(2,2,figsize=(10, 5),constrained_layout=True)
for k in range(2):
    # Mask land like the regression/correlation maps. Where land SST values are
    # present, the covariance there is not an ocean signal.
    eofs.sel(mode=k).where(mask == 1.).plot(ax=axes[0,k], cbar_kwargs={'label': '°C for 1 std of PC'})
    # round (rather than truncate) the explained-variance percentage
    axes[0,k].set_title(f'EOF {k+1}: {varfrac.values[k]*100:.0f}%')
    pcs.sel(mode=k).plot(ax=axes[1,k])
    axes[1,k].set_title(f'PC {k+1}')

In [None]:
# Check: with pcscaling=1 each PC should have (close to) unit variance
_pcs10 = solver.pcs(npcs=10, pcscaling=1)
print(_pcs10.var(dim='time'))

In [None]:
# Correlation of the two leading (unit-variance) PCs with the Nino3.4 index
pcs = solver.pcs(npcs=2, pcscaling=1)
_pc_nino_corr = xr.corr(pcs, nino34_index, dim='time')
_pc_nino_corr.data


### The Indian Ocean Dipole indices
    IOD west: 50°E to 70°E and 10°S to 10°N
    IOD east: 90°E to 110°E and 10°S to 0°
IOD index = IOD west − IOD east (difference of the area-weighted SST anomalies in the two boxes)

In [None]:
def _lat_band(south, north):
    """Latitude slice from `south` to `north` in the order used for this product."""
    # oisst025 is sliced south->north, the other products north->south
    return slice(south, north) if sstproduct == "oisst025" else slice(north, south)


# Western box: 50E-70E, 10S-10N
weights_iodw = weights.sel(lon=slice(50, 70), lat=_lat_band(-10, 10))
iodw = sstanom.weighted(weights_iodw).mean(dim=('lon', 'lat'))

# Eastern box: 90E-110E, 10S-0
weights_iode = weights.sel(lon=slice(90, 110), lat=_lat_band(-10, 0))
iode = sstanom.weighted(weights_iode).mean(dim=('lon', 'lat'))

# Dipole index: west minus east
iod_index = iodw - iode

### IOD and Nino34 plot

In [None]:
# Overlay the IOD and Nino3.4 time series on a single axis
fig, axes = plt.subplots(nrows=1, figsize=(10, 3))
for series, label in ((iod_index, 'IOD'), (nino34_index, 'Nino34')):
    series.plot(ax=axes, label=label)
axes.legend()

### Correlation of IOD with PC2 (second mode) and Nino34

In [None]:
# Correlate the IOD index with the second-mode PC and with Nino3.4
_partners = (
    ('Correlation IOD-EOF2  :', pcs.sel(mode=1)),
    ('Correlation IOD-Nino34:', nino34_index),
)
for label, other in _partners:
    print(label, xr.corr(iod_index, other, dim='time').data)

### StdDev of IOD and Nino34

In [None]:
# Standard deviation of each index over the whole record
iod_std, nino34_std = (idx.std(dim='time') for idx in (iod_index, nino34_index))
for label, value in (('IOD StdDev   :', iod_std), ('Nino34 StdDev:', nino34_std)):
    print(label, value.data)

### Monthly std of IOD index

In [None]:
# IOD amplitude by calendar month: std of the index within each month
_iod_by_month = iod_index.groupby('time.month')
_iod_by_month.std(dim='time').plot.step(where='mid')

## Relation with precipitation

In [None]:
# Monthly precipitation restricted to 60S-60N and to the SST analysis period.
data2 = xr.open_dataset('precip.mon.mean.nc')
data2 = data2.sel(lat=slice(60., -60.))          # get rid of data out of 60S-60N
# Take the period from the SST record itself (first/last month, inclusive)
# instead of hard-coding it per product: the old fallback applied 1982-2022
# to 'erssst' too, whose SST spans 1970-2021, so the series lengths disagreed.
# If the precip file starts later than the SST record, the lengths can still
# differ, so combine the two only after aligning on time.
_t_start, _t_end = data.time.dt.strftime('%Y-%m').values[[0, -1]]
data2 = data2.sel(time=slice(_t_start, _t_end))
data2

In [None]:
# Time-mean precipitation map, drawn on a map centred at 200E
plt.figure(figsize=(10, 3))
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=200))
crspc = ccrs.PlateCarree(central_longitude=0)      # CRS of the lon/lat data
_precip_mean = data2.precip.mean(dim='time')
_precip_mean.plot(ax=ax, transform=crspc, cmap='YlGnBu')
ax.gridlines(draw_labels=True)
ax.coastlines()

In [None]:
# Detrend precip (linear fit per grid point, time mean added back)...
linfit = data2.precip.polyfit(dim='time', deg=1)
trend = xr.polyval(data2.time, linfit.polyfit_coefficients)   # precip trend
precip_detrend = data2.precip - trend.values + data2.precip.mean(dim='time')
# ...then remove the detrended seasonal cycle to get interannual anomalies.
precipbymth = precip_detrend.groupby("time.month")
mthclim = precipbymth.mean("time")                            # detrended climatological months
precipanom = (precipbymth - mthclim).rename('precipanom')     # detrended interannual anomaly

In [None]:
# Quick look at the first month of precipitation anomalies
plt.figure(figsize=(10, 3))
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=200))
crspc = ccrs.PlateCarree(central_longitude=0)      # CRS of the lon/lat data
_first_anom = precipanom.isel(time=0)
_first_anom.plot(ax=ax, transform=crspc, cmap='seismic')
ax.gridlines(draw_labels=True)
ax.coastlines()

### Precip vs. Nino3.4

In [None]:
# Regression and correlation of precipitation anomalies onto Nino3.4.
# Least-squares slope b = cov(N, P) / var(N), per grid point, over the months
# common to both records and valid at that point. Aligning on time replaces the
# old trick of overwriting precipanom's time coordinate with index values,
# which mutated a global and failed whenever the two records had different
# lengths.
_p, _n = xr.align(precipanom, nino34_index, join='inner')
_valid = _p.notnull() & _n.notnull()
precipreg_nino = xr.cov(_p, _n, dim='time') / _n.where(_valid).var(dim='time', ddof=1)
precipcor_nino = xr.corr(precipanom, nino34_index, dim='time')

#  create a 2 panel figure (map projection centred at 200E)
crspc = ccrs.PlateCarree(central_longitude=200)
fig, axes = plt.subplots(2,1,figsize=(10, 5), subplot_kw=dict(projection=crspc))

crspc = ccrs.PlateCarree(central_longitude=0)    # CRS of the lon/lat data
# plot the regression coefficient
precipreg_nino.plot(ax=axes[0], transform=crspc, cmap='seismic')
axes[0].set_title('Regression coefficient')
axes[0].coastlines()
# plot the correlation
precipcor_nino.plot(ax=axes[1], transform=crspc, cmap='seismic')
axes[1].set_title('Correlation')
axes[1].coastlines()

### Precip vs. IOD

In [None]:
# Regression and correlation of precipitation anomalies onto the IOD index.
# Least-squares slope b = cov(I, P) / var(I), per grid point, over the months
# common to both records and valid at that point. Aligning on time replaces the
# old trick of overwriting precipanom's time coordinate with index values,
# which mutated a global and failed whenever the two records had different
# lengths.
_p, _i = xr.align(precipanom, iod_index, join='inner')
_valid = _p.notnull() & _i.notnull()
precipreg_iod = xr.cov(_p, _i, dim='time') / _i.where(_valid).var(dim='time', ddof=1)
precipcor_iod = xr.corr(precipanom, iod_index, dim='time')

#  create a 2 panel figure (map projection centred at 200E)
crspc = ccrs.PlateCarree(central_longitude=200)
fig, axes = plt.subplots(2,1,figsize=(10, 5), subplot_kw=dict(projection=crspc))

crspc = ccrs.PlateCarree(central_longitude=0)    # CRS of the lon/lat data
# plot the regression coefficient
precipreg_iod.plot(ax=axes[0], transform=crspc, cmap='seismic')
axes[0].set_title('Regression coefficient')
axes[0].coastlines()
# plot the correlation
precipcor_iod.plot(ax=axes[1], transform=crspc, cmap='seismic')
axes[1].set_title('Correlation')
axes[1].coastlines()

In [None]:
#  create a 2 pannels figure
crspc = ccrs.PlateCarree(central_longitude=200)
fig, axes = plt.subplots(2,1,figsize=(10, 5), subplot_kw=dict(projection=crspc))

crspc = ccrs.PlateCarree(central_longitude=0)
# plot the regression coefficient
(precipreg_nino*nino34_std).plot(ax=axes[0], transform=crspc, cmap='seismic',vmin=-2.5,vmax=2.5)
axes[0].set_title('Precip/Nino Regression * Nino StdDev')
axes[0].coastlines()
# plot the correlation
(precipreg_iod*iod_std).plot(ax=axes[1], transform=crspc, cmap='seismic',vmin=-2.5,vmax=2.5)
axes[1].set_title('Precip/IOD Regression * IOD StdDev')
axes[1].coastlines()

### Precip anomaly composites for pure IOD, pure Nino and mixed IOD+Nino events

In [None]:
# Composite (time-mean) precip anomaly over months in three event classes,
# defined from the IOD and Nino3.4 indices relative to their std deviations.
crspc = ccrs.PlateCarree(central_longitude=200)
fig, axes = plt.subplots(3, 1, figsize=(10, 7), subplot_kw={'projection': crspc})

crspc = ccrs.PlateCarree(central_longitude=0)    # CRS of the lon/lat data
_iod_strong = iod_index > 2 * iod_std
_iod_weak = iod_index < 0.5 * iod_std
_nino_strong = nino34_index > 2 * nino34_std
_nino_weak = nino34_index < 0.5 * nino34_std

precip_iod = precipanom.where(_iod_strong & _nino_weak).mean(dim='time')        # pure IOD
precip_nino = precipanom.where(_iod_weak & _nino_strong).mean(dim='time')       # pure Nino
precip_iodnino = precipanom.where(_iod_strong & _nino_strong).mean(dim='time')  # IOD + Nino

_composites = (
    (precip_iod, '"Pure IOD" precip'),
    (precip_nino, '"Pure Nino" precip'),
    (precip_iodnino, '"IOD+Nino" precip'),
)
for ax, (field, title) in zip(axes, _composites):
    field.plot(ax=ax, transform=crspc, cmap='seismic', vmin=-10, vmax=10)
    ax.set_title(title)
    ax.coastlines()

### Precip November Mean

In [None]:
# Climatological November precipitation (mean over all Novembers in data2)
plt.figure(figsize=(10, 3))
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=200))
crspc = ccrs.PlateCarree(central_longitude=0)      # CRS of the lon/lat data
novmean = (
    data2.precip
    .groupby('time.month')
    .mean(dim='time')
    .sel(month=11)
)
novmean.plot(ax=ax, transform=crspc, cmap='YlGnBu', vmin=0, vmax=10)
ax.gridlines(draw_labels=True)
ax.coastlines()

### Ratio of the IOD+Nino precip anomaly to the November mean precip

In [None]:
# IOD+Nino composite anomaly relative to November mean precip, shown only
# where the November mean exceeds 1 (to avoid dividing by near-zero values).
plt.figure(figsize=(10, 3))
ax = plt.axes(projection=ccrs.PlateCarree(central_longitude=200))
crspc = ccrs.PlateCarree(central_longitude=0)      # CRS of the lon/lat data
_ratio = precip_iodnino / novmean
_ratio.where(novmean > 1).plot(ax=ax, transform=crspc, cmap='seismic', vmin=-2, vmax=2)
ax.gridlines(draw_labels=True)
ax.coastlines()