In [1]:
# Re-import the local project modules (util, events, mhwmetrics, annualcycle)
# automatically when their source files change, so edits take effect without a
# kernel restart. NOTE: this only affects this kernel; the dask workers run in
# separate SLURM jobs and import their own copies of these modules.
%load_ext autoreload
%autoreload 2

## Setup Dask

In [2]:
import dask
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Don't let dask split large output chunks produced by array indexing.
dask.config.set({"array.slicing.split_large_chunks": False})

portdash = 10062  # port for the scheduler's dashboard

# Each SLURM job runs a single dask worker process with 8 threads and 48 GB.
slurm_job_spec = {
    "queue": "batch",
    "account": "gfdl_o",
    "cores": 8,
    "processes": 1,
    "memory": "48GB",
    "walltime": "08:00:00",
    "job_name": "mhw-metrics",
    # Extra #SBATCH directive: nodes to keep the workers off.
    "job_extra_directives": ["--exclude=pp[008-010],pp[013-075]"],
    # Worker scratch directory. Passed through literally; presumably expanded
    # by the shell inside each SLURM job script.
    "local_directory": "$TMPDIR",
    # Seconds a worker waits for the scheduler before shutting itself down.
    "death_timeout": 240,
    "scheduler_options": {"dashboard_address": f":{portdash}"},
}
cluster = SLURMCluster(**slurm_job_spec)

client = Client(cluster)

In [3]:
# Show the client/cluster summary, including the dashboard link and worker count.
client

0,1
Connection method: Cluster object,Cluster type: dask_jobqueue.SLURMCluster
Dashboard: http://140.208.147.197:10062/status,

0,1
Dashboard: http://140.208.147.197:10062/status,Workers: 0
Total threads: 0,Total memory: 0 B

0,1
Comm: tcp://140.208.147.197:33283,Workers: 0
Dashboard: http://140.208.147.197:10062/status,Total threads: 0
Started: Just now,Total memory: 0 B


In [24]:
# Request worker jobs from SLURM. Per the SLURMCluster settings, one job is one
# dask worker with 8 threads and 48 GB.
# This cell previously read `cluster.scale(jobs=0)` and was executed *after*
# the compute (In [24]) to release the allocation. On Restart & Run All that
# leaves the cluster with no workers, and `result.compute()` would never
# finish. Release resources with the teardown cell instead.
# NOTE(review): N_JOBS is a placeholder; size it to the allocation and queue limits.
N_JOBS = 20
cluster.scale(jobs=N_JOBS)

In [None]:
# Teardown: release the client and the SLURM jobs when finished.
# This cell sits *above* the compute cells. It is guarded so that Restart &
# Run All does not shut the cluster down before any work is done. Set
# SHUTDOWN_DASK = True and run it deliberately when you are finished.
SHUTDOWN_DASK = False
if SHUTDOWN_DASK:
    # Close the client before the cluster, so the client is not left talking
    # to a scheduler that has already been shut down.
    client.close()
    cluster.close()

## Import Modules

In [5]:
import xarray as xr
import numpy as np
import dask.array as da

import util
import events
import mhwmetrics
import annualcycle

In [6]:
%%time
# Open the thetao Zarr store (the path name suggests detrended CM4 output,
# years 0151-0250) and pull out one small test block to develop process_chunk on:
# the full time range (times=(None, None)) over a 20 x 20 horizontal patch.
ZARR_STORE = "/xtmp/Xinru.Li/work/cm4_thetao_0151_0250_fast_detrend"
ds = xr.open_zarr(ZARR_STORE, use_cftime=True)

temp_chunk = util.load_chunk(
    ds,
    times=(None, None),
    xrange=(500, 520),
    yrange=(700, 720),
)
temp_chunk

CPU times: user 101 ms, sys: 39.5 ms, total: 140 ms
Wall time: 513 ms


Unnamed: 0,Array,Chunk
Bytes,891.11 MiB,891.11 MiB
Shape,"(36500, 16, 20, 20)","(36500, 16, 20, 20)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 891.11 MiB 891.11 MiB Shape (36500, 16, 20, 20) (36500, 16, 20, 20) Dask graph 1 chunks in 3 graph layers Data type float32 numpy.ndarray",36500  1  20  20  16,

Unnamed: 0,Array,Chunk
Bytes,891.11 MiB,891.11 MiB
Shape,"(36500, 16, 20, 20)","(36500, 16, 20, 20)"
Dask graph,1 chunks in 3 graph layers,1 chunks in 3 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [16]:
%%time
def _window_ranges(start_year=151, nyears=100, window_len=30, skip=37):
    """Return the sliding analysis windows as ``(first_year, (start, end))`` pairs.

    Windows are ``window_len`` years long and start one year apart, beginning
    at ``start_year``. Each window must end by year ``start_year + nyears - 1``.
    The dates are zero-padded strings (e.g. ``"0151-01-01"``, ``"0180-12-31"``)
    suitable for ``.sel(time=slice(start, end))``. The first ``skip`` windows
    are dropped.
    """
    last_first_year = start_year + nyears - window_len  # last window that still fits
    return [
        (year, (f"{year:04d}-01-01", f"{year + window_len - 1:04d}-12-31"))
        for year in range(start_year, last_first_year + 1)
    ][skip:]


def _window_metrics(temp, thresh, trange):
    """Compute the metrics for the part of ``temp`` between the two dates in ``trange``."""
    sub_arr = temp.sel(time=slice(*trange))
    exceedences, thresh_tiled = events.calculate_execeedences(sub_arr, thresh)
    event_index, _num_events = events.tag_events(exceedences)
    # FIXME(review): `event_index` is passed as BOTH the first and the second
    # argument. `thresh_tiled` is supplied and the outputs include intensity
    # metrics (clim_Im, clim_HSpeak), so the first argument was probably meant
    # to be the temperature slice `sub_arr`. Check mhwmetrics.calc_metrics's
    # signature before trusting the intensity fields.
    return mhwmetrics.calc_metrics(event_index, event_index, thresh_tiled)


def process_chunk(temp_chunk, start_year=151, nyears=100, window_len=30, skip=37):
    """Compute marine-heatwave metrics over sliding windows for one block.

    Intended as the ``func`` passed to ``xr.map_blocks``. It is called once per
    dask block of the temperature array. It computes the threshold once from
    the block's full record, then evaluates the metrics in each
    ``window_len``-year window.

    Parameters
    ----------
    temp_chunk : xarray.DataArray
        Temperature block with a ``time`` dimension (cftime dates in this
        notebook). Not modified; a computed copy is used.
    start_year, nyears : int
        First year and length (in years) of the record the windows are drawn
        from.
    window_len : int
        Window length in years. Consecutive windows start one year apart.
    skip : int
        Number of leading windows to drop. With the defaults this leaves 34
        windows (first years 188-221) of the 71 that fit in years 151-250.
        NOTE(review): presumably set to resume after an earlier partial run;
        confirm.

    Returns
    -------
    xarray.Dataset
        Per-window metrics concatenated along a new ``window`` dimension. Its
        coordinate is each window's first year.

    Raises
    ------
    ValueError
        If no windows remain after ``skip``, or the metrics lack ``clim_freq``.
    """
    # .compute() rather than .load(): same in-memory data, but the caller's
    # object is not loaded in place as a side effect.
    temp = temp_chunk.compute()
    # The threshold (and the mean, unused here) come from the block's full
    # record and are shared by every window.
    _mean, thresh = events.threshold_and_climo(temp)

    windows = _window_ranges(start_year, nyears, window_len, skip)
    if not windows:
        raise ValueError(
            f"process_chunk: no {window_len}-year windows left in years "
            f"{start_year}-{start_year + nyears - 1} after skipping {skip}"
        )

    dsmetrics = xr.concat(
        [_window_metrics(temp, thresh, trange) for _, trange in windows], "window"
    )
    # Label windows by their first year. With skip > 0, a bare 0..n-1 index
    # would not record which years each window covers.
    dsmetrics = dsmetrics.assign_coords(
        window=(
            "window",
            [first_year for first_year, _ in windows],
            {"long_name": f"first year of {window_len}-year window"},
        )
    )

    # Explicit check rather than `assert`, which is skipped under `python -O`.
    if "clim_freq" not in dsmetrics:
        raise ValueError(
            "process_chunk: metrics are missing 'clim_freq'.\n"
            f"Input block:\n{temp}\nMetrics:\n{dsmetrics}"
        )

    return dsmetrics


test_ds = process_chunk(temp_chunk)

Using Cython-accelerated percentile calculation
Using Cython-accelerated percentile calculation
CPU times: user 5min 31s, sys: 46.9 s, total: 6min 18s
Wall time: 5min 56s


In [17]:
# Inspect the metrics for the single test block. Its `window` coordinate is
# reused when building the map_blocks template below.
test_ds

In [18]:
# Output template for xr.map_blocks. It has the same variables, dim order,
# dtypes and window labels as the single-block test result `test_ds`, spread
# over the full vertical/horizontal domain of `ds`. These are derived rather
# than hard-coded (previously nwindow = 34, 1080 x 1440, float32), so the
# template stays in sync with process_chunk. map_blocks checks each block's
# output against it.
nwindow = test_ds.sizes["window"]

# Chunks per output dimension. All windows go in one chunk, because each
# block's task returns every window at once. The spatial dims reuse the input
# chunks, so each input block of ds.thetao maps onto exactly one output block.
out_chunks = {"window": (nwindow,)}
out_chunks.update({dim: ds.thetao.chunksizes[dim] for dim in ("z2000_l", "yh", "xh")})

# The lazy zeros are never computed; they only carry shape/chunks/dtype.
# Using test_ds's dtype keeps the lazy metadata consistent with what
# process_chunk actually returns.
template = xr.Dataset(
    {
        name: (
            var.dims,
            da.zeros(
                tuple(sum(out_chunks[dim]) for dim in var.dims),
                chunks=tuple(out_chunks[dim] for dim in var.dims),
                dtype=var.dtype,
            ),
        )
        for name, var in test_ds.data_vars.items()
    },
    coords={
        "window": test_ds.window,  # same labels every block produces
        "z2000_l": ds.z2000_l,
        "yh": ds.yh,
        "xh": ds.xh,
    },
)

In [19]:
%%time
# Lazily apply process_chunk to every dask block of the ds.thetao DataArray.
# `template` describes the full output (variables, dims, chunks, coords), so
# xarray does not have to run process_chunk on mocked-up data to infer it.
result = ds.thetao.map_blocks(process_chunk, template=template)

CPU times: user 9min 11s, sys: 7.53 s, total: 9min 19s
Wall time: 9min 15s


In [20]:
%%time
# Display the lazy result: only the dask-backed structure is shown, nothing is computed.
result

CPU times: user 3 µs, sys: 1 µs, total: 4 µs
Wall time: 4.29 µs


Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.15 GiB 850.00 kiB Shape (34, 16, 1080, 1440) (34, 16, 20, 20) Dask graph 3888 chunks in 8 graph layers Data type float32 numpy.ndarray",34  1  1440  1080  16,

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.15 GiB 850.00 kiB Shape (34, 16, 1080, 1440) (34, 16, 20, 20) Dask graph 3888 chunks in 8 graph layers Data type float32 numpy.ndarray",34  1  1440  1080  16,

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.15 GiB 850.00 kiB Shape (34, 16, 1080, 1440) (34, 16, 20, 20) Dask graph 3888 chunks in 8 graph layers Data type float32 numpy.ndarray",34  1  1440  1080  16,

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.15 GiB 850.00 kiB Shape (34, 16, 1080, 1440) (34, 16, 20, 20) Dask graph 3888 chunks in 8 graph layers Data type float32 numpy.ndarray",34  1  1440  1080  16,

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 3.15 GiB 850.00 kiB Shape (34, 16, 1080, 1440) (34, 16, 20, 20) Dask graph 3888 chunks in 8 graph layers Data type float32 numpy.ndarray",34  1  1440  1080  16,

Unnamed: 0,Array,Chunk
Bytes,3.15 GiB,850.00 kiB
Shape,"(34, 16, 1080, 1440)","(34, 16, 20, 20)"
Dask graph,3888 chunks in 8 graph layers,3888 chunks in 8 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [None]:
%%time
# Run every block on the cluster and pull the full result into this kernel's
# memory (about 3 GiB per variable, per the repr above). Rebinding `result` to
# the in-memory Dataset makes re-running this cell cheap.
# NOTE(review): the scheduler warning in the output ("Consider scattering data
# ahead of time") suggests a large task graph is being shipped. Worth checking
# what the graph embeds if submission is slow.
result = result.compute()

This may cause some slowdown.
Consider scattering data ahead of time and using futures.


In [None]:
# Include the number of windows in the file name (derived, not hard-coded),
# so outputs from different window selections don't overwrite each other.
result.to_netcdf(f"results_window_{result.sizes['window']}.nc")

In [23]:
# Sanity check: list the working directory with human-readable sizes, sorted
# by modification time with the newest last, to confirm the output file was written.
!ls -lhrt

total 49G
drwxr-xr-x 2 John.Krasting o 4.0K Apr  6 10:08 fortran
-rw-r--r-- 1 John.Krasting o 417K Apr  6 10:09 fast_percentile.cpp
-rwxr-xr-x 1 John.Krasting o  90K Apr  6 10:09 fast_percentile.cpython-311-x86_64-linux-gnu.so
drwxr-xr-x 2 John.Krasting o 4.0K Apr  6 10:09 build
-rw-r--r-- 1 John.Krasting o 475M Apr  6 15:45 results.nc
-rw-r--r-- 1 John.Krasting o   67 Apr  7 09:59 README.md
-rw-r--r-- 1 John.Krasting o  11K Apr  7 10:03 annualcycle.py
-rw-r--r-- 1 John.Krasting o 2.9K Apr  7 10:03 events.py
-rw-r--r-- 1 John.Krasting o 3.9K Apr  7 10:03 fast_percentile.pyx
-rw-r--r-- 1 John.Krasting o 1.8K Apr  7 10:03 mhwmetrics.py
-rw-r--r-- 1 John.Krasting o  327 Apr  7 10:03 setup.py
-rw-r--r-- 1 John.Krasting o 1.1K Apr  7 10:03 util.py
drwxr-xr-x 2 John.Krasting o 4.0K Apr  7 10:11 __pycache__
-rw-r--r-- 1 John.Krasting o  15K Apr  7 12:06 slurm-44417504.out
-rw-r--r-- 1 John.Krasting o 950M Apr  7 12:40 results_window.nc
-rw-r--r-- 1 John.Krasting o  38K Apr  7 15:26 slurm-4441