## Importing packages and defining functions

In [1]:
from dask.distributed import LocalCluster, Client
import dask.array
import datetime
from datetime import date 
from datetime import datetime
import glob
import numpy as np
import scipy.ndimage as ndimage
from scipy.ndimage.measurements import label, find_objects
import xarray as xr

In [2]:
# This is for running on Gadi

import os
import dask.distributed
threads_per_worker = 1
try:
    c # Already running
except NameError:
    c = dask.distributed.Client(
        n_workers=int(os.environ['PBS_NCPUS'])//threads_per_worker,
        threads_per_worker=threads_per_worker,
        memory_limit=f'{4*threads_per_worker}gb',
        local_directory=os.path.join(os.environ['PBS_JOBFS'],
                                     'dask-worker-space')
    )
c

KeyError: 'PBS_NCPUS'

In [3]:
def atleastn(da, n, dim='time'):
    """
    Return values with at least n contiguous points around them
    """

    def atleastn_helper(array, n, axis):
        count = np.zeros_like(np.take(array, 0,axis=axis), dtype='i4')
        mask = np.empty_like(np.take(array, 0,axis=axis), dtype='bool')
        mask = True
    
        for i in range(array.shape[axis]):
            array_slice = np.take(array, i, axis=axis)
        
            # Increase the count when there is a valid value, reset when there is not
            # This was initially set to 0, now I have changed it to 1 to detect only valid heatwave days 
            # The previous way was fine as long as I masked values less than or equal to 1, and they were white on the colour bar
            count = np.where(array_slice > 1, count + 1, 0)
        
            # Add new points when the contiguous count exceeds the threshold
            mask = np.where(count >= n, False, mask)
            
        out_slice = np.take(array, array.shape[axis]//2, axis=axis)
        return np.where(mask, np.nan, out_slice)
    
    def atleastn_dask_helper(array, axis, **kwargs):
        r = dask.array.map_blocks(atleastn_helper, array, drop_axis=axis, axis=axis, n=n, dtype=array.dtype)
        return r
    
    if isinstance(da.data, dask.array.Array):
        reducer = atleastn_dask_helper
    else:
        reducer = atleastn_helper
        
    return da.rolling({dim: n*2-1}, center=True, min_periods=n).reduce(reducer, n=n)


## Opening files 

In [4]:
files = sorted(glob.glob('/g/data/e14/cp3790/Charuni/MHW-sev/mhw_severity.pc90.*.nc'))

mhw = xr.open_mfdataset(files, combine='by_coords').sel(time=slice('1982', '2018'))

In [5]:
mhw_sev = mhw.severity

## Calculations 

In [6]:
%%time
%matplotlib inline
#hws_preDur = xarray.open_dataset('/scratch/w35/saw562/helpdesk/severity.nc', chunks={'time': 100, 'latitude': 100, 'longitude': 100})


candidates = mhw_sev.where(mhw_sev > 1)

CPU times: user 386 ms, sys: 35 ms, total: 421 ms
Wall time: 558 ms


Masking out points where there are less than 5 contiguous points in the time dimension

In [7]:
%%time

oscar = atleastn(candidates, n=5)

CPU times: user 53 ms, sys: 3 ms, total: 56 ms
Wall time: 54.5 ms


## Saving to netcdf

In [8]:
%%time

# Save to a file

xr.Dataset({'severity': oscar}).to_netcdf('/g/data/e14/cp3790/Charuni/filtered_severity_mhw.nc',
                                              encoding={'severity': 
                                                        {'chunksizes': (100, oscar.shape[1], oscar.shape[2]),
                                                         'zlib': True,
                                                         'shuffle': True, 
                                                         'complevel': 2}})   

# compression level (complevel) up to 6 is fine, >6 and it starts giving trouble 

  return func(*(_execute_task(a, cache) for a in args))


MemoryError: Unable to allocate 4.35 GiB for an array with shape (13514, 80, 120, 9) and data type float32