# XMHW tests on the OFAM3 dataset

Purpose
-------
    This notebook investigates the capability of xmhw to parallelise the marine heatwave (MHW) analysis on a subset of temperature data from OFAM3, a 0.1°-resolution global ocean simulation spanning 1980–2100. The simulation runs from 1980 to 2006 under JRA55 atmospheric forcing; thereafter the JRA55 reanalysis forcing is repeated with the addition of the RCP8.5 climate trend.

    Contents:
        1. Load in Temperature Data and visualise (2D in space, 1D in time)
        2. Select the region around Australia for the heatwave analysis and discard the rest of the domain
        3. Calculate the climatology required for the heatwave analysis and save as a new netcdf file
            [ this will be read in later and in a new session for performing the heatwave analysis ]
        4. Perform heatwave analysis using xmhw by iterating around the subsetted grid

Thanks to John Reilly for sharing his [code](https://github.com/Thomas-Moore-Creative/shared_sandbox/blob/main/mhw-3d-scalingTests-gadiJup.ipynb)
    


### imports

In [None]:
import sys
import os

### data handling
import numpy as np
import pandas as pd
import xarray as xr
import scipy as sci

### plotting
import matplotlib.pyplot as plt
from matplotlib import ticker
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
import cmocean.cm as cmo
from cmocean.tools import lighten

### marine heatwaves python package
from xmhw.xmhw import threshold, detect

# Print package versions so results can be tied to a software environment.
# sys.version[:5] would truncate two-digit minor versions ("3.10.4" -> "3.10."),
# so take the first whitespace-separated token of sys.version instead.
print("python version =", sys.version.split()[0])
print("numpy version =", np.__version__)
print("pandas version =", pd.__version__)
print("xarray version =", xr.__version__)
print("scipy version =", sci.__version__)
print("matplotlib version =", sys.modules[plt.__package__].__version__)
print("cmocean version =", sys.modules[cmo.__package__].__version__)
print("cartopy version =", sys.modules[ccrs.__package__].__version__)


# Work from the OFAM3 directory: later cells open the data with paths relative to it.
wrkdir = "/g/data/fp2/OFAM3"
os.chdir(wrkdir)


### remove warnings

In [None]:
import warnings

# Install a catch-all "ignore" filter so no warning category is shown
# for the remainder of the session (keeps notebook output uncluttered).
warnings.simplefilter('ignore')

### import the dask client for assessing performance

In [None]:
# Start a local dask.distributed cluster. Only threads_per_worker is set here,
# so the number of workers is left to dask's defaults.
from dask.distributed import Client
client = Client(threads_per_worker=2)
# Bare expression on the cell's last line so the notebook renders the client summary.
client

## grab the historical temperature data from fp2

In [None]:
# Open every historical surface-temperature file as one lazily-loaded dataset;
# squeeze() removes length-1 dimensions.
sst = xr.open_mfdataset(
    "./jra55_historical.1/surface/ocean_temp_sfc_*.nc",
    combine='by_coords',
).squeeze()

# Grid-index bounds (not lon/lat values) of the Australian sub-domain.
x1, x2 = 1000, 1650
y1, y2 = 250, 700

# Plot the first time step in index space and outline the sub-domain in red.
plt.figure()
plt.pcolormesh(sst['temp'].isel(Time=0))
for x_edge in (x1, x2):       # left and right edges
    plt.plot((x_edge, x_edge), (y1, y2), 'r-')
for y_edge in (y1, y2):       # bottom and top edges
    plt.plot((x1, x2), (y_edge, y_edge), 'r-')
plt.colorbar()



## subset the data to the region around the Australian continent

In [None]:
# Cut out the Australian index window, drop the leftover st_ocean coordinate
# (presumably the squeezed-out depth level), keep only the 'temp' variable,
# rename Time -> time and load the subset into memory.
sst = (
    sst.isel(yt_ocean=slice(y1, y2), xt_ocean=slice(x1, x2))
    .drop_vars('st_ocean')['temp']
    .rename({"Time": "time"})
    .compute()
)
sst


In [None]:
# Report the in-memory size of the subset in decimal gigabytes (GB). Two decimals
# avoid the %i truncation that would print 0 for a subset smaller than 1 GB.
print("Historical SST dataset = %.2f GB" % (sst.nbytes / 1e9))


## iterate around the Australian continent and compute the heatwaves

### calculate the climatology
    which we will use later for calculating the marine heatwaves in a subsequent step

In [None]:
%%time

# Day-of-year on a fixed 366-day (leap-year) calendar, the convention of the
# Hobday et al. (2016) MHW climatology: Feb 29 is always doy 60 and, in non-leap
# years, every date from March 1 onward is shifted by +1. Plain dt.dayofyear gives
# Mar 1 doy 61 in leap years but doy 60 in non-leap years, so grouping by it would
# mix different calendar dates (and doy 366 would hold only leap-year Dec 31s).
# Cumulative days before each month in a leap year (index 0 = January).
cum_days_leap = np.array([0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335])
leap_doy = cum_days_leap[sst['time'].dt.month.values - 1] + sst['time'].dt.day.values
sst = sst.assign_coords(doy=('time', leap_doy))

# Whole time series in a single chunk (time=-1); 10x10-point chunks in space.
sst = sst.chunk({"time": -1, "yt_ocean": 10, "xt_ocean": 10})
sst


### calculate the daily climatology and 90th percentile threshold to define a MHW

In [None]:
%%time

# Tile sizes (in grid points) along x and y. Each tile's statistics are computed
# immediately with .compute() and collected in lists, then merged at the end.
di = 50
dj = 50

print("Calculating the climatology and threshold")
seas_list = []
thresh_list = []
nx = sst.sizes['xt_ocean']
ny = sst.sizes['yt_ocean']
for ii in range(0, nx, di):
    print(ii)  # progress: x offset of the current column of tiles
    for jj in range(0, ny, dj):
        tmp = sst.isel(xt_ocean=slice(ii, ii + di), yt_ocean=slice(jj, jj + dj))
        by_doy = tmp.groupby('doy')
        # Raw (not yet smoothed) per-doy mean and 90th percentile across all years.
        # NOTE(review): Hobday et al. (2016) pool an 11-day window around each doy
        # before taking the percentile; here each doy uses only that calendar day's
        # samples. Confirm this simplification is intended.
        seas_list.append(by_doy.mean(dim='time').compute())
        thresh_list.append(by_doy.quantile(0.9, dim='time', skipna=True).compute())


### merge the lists into single xarrays with the results
print("Merging results")
seas_new = xr.merge(seas_list)
thresh_new = xr.merge(thresh_list)


### perform a rolling-mean (moving-window) average across the day-of-year dimension, wrapping around the year end, and trim the padded ends

In [None]:
# Smooth the raw day-of-year statistics with a centred moving-window mean.
# The doy axis is padded by half a window on each side with mode='wrap' (values
# taken from the opposite end of the year), so the window wraps across
# Dec 31 / Jan 1. The padding is trimmed off again afterwards.
# smooth_window should be odd so the centred window is symmetric.
smooth_window = 31
half_window = (smooth_window - 1) // 2

climatology = (
    seas_new.pad(doy=half_window, mode='wrap')
    .rolling(doy=smooth_window, center=True)
    .mean()
)
threshold90 = (
    thresh_new.pad(doy=half_window, mode='wrap')
    .rolling(doy=smooth_window, center=True)
    .mean(skipna=True)
)

# Keep only the original doy range. An explicit stop (rather than -half_window)
# keeps the slice correct even when half_window is 0.
climatology = climatology.chunk({'doy': -1, 'yt_ocean': 50, 'xt_ocean': 50}).isel(
    doy=slice(half_window, half_window + seas_new.sizes['doy'])
)
# groupby().quantile() leaves a scalar 'quantile' coordinate on the threshold; drop it.
threshold90 = threshold90.chunk({'doy': -1, 'yt_ocean': 50, 'xt_ocean': 50}).isel(
    doy=slice(half_window, half_window + thresh_new.sizes['doy'])
).drop_vars('quantile')



In [None]:
# Report the in-memory size of the smoothed fields in decimal megabytes (MB).
# "Mb" would denote megabits; nbytes / 1e6 is megabytes.
print("Size (MB) of daily climatology = %i" % (climatology.nbytes / 1e6))
print("Size (MB) of daily threshold90 = %i" % (threshold90.nbytes / 1e6))

### save to disk

In [None]:
%%time
# Write outputs into the heatwaves project directory. os.chdir changes the working
# directory for the rest of the session. A bare os.getcwd() in the middle of a cell
# is never displayed, so print the destination explicitly.
os.chdir("/g/data/es60/pjb581/heatwaves")
print("Output directory:", os.getcwd())

print("Saving climatology and threshold to disk")
# mode='w' overwrites any existing file of the same name.
climatology.to_netcdf('Australian_SST_daily_climatology.nc', mode='w')
threshold90.to_netcdf('Australian_SST_daily_MHWthreshold.nc', mode='w')
