In [1]:
import dask.array as da
import xarray as xr
import numpy as np
import pandas as pd

In [15]:
# --- Function to create realistic, large-scale synthetic data ---
def _add_heat_to_chunk(data, chunk_location, amplitude):
    """Add a constant `amplitude` to the single Dask chunk at `chunk_location`.

    `chunk_location` is the (time, lat, lon) index of the target chunk in the
    chunk grid. Every other chunk is returned unchanged. If the chunk grid has
    no chunk at that location, the result equals `data` (nothing is added).
    """
    target = tuple(chunk_location)

    def _inject(block, block_info=None):
        # Dask fills `block_info` per block; entry 0 describes the first
        # (only) input array, and 'chunk-location' is this block's grid index.
        # The boolean acts as a 0/1 multiplier on the amplitude.
        is_target = tuple(block_info[0]['chunk-location']) == target
        return block + amplitude * is_target

    # dtype is passed explicitly so Dask does not have to infer it.
    return data.map_blocks(_inject, dtype=data.dtype)


def create_synthetic_3d_mhw_data(shape, chunks, seed=None):
    """
    Generate a lazy 3D Dask array holding a synthetic temperature record.

    Each grid cell's series is the sum of:
      * a constant base value of 15.0,
      * an annual cosine cycle of amplitude 5 that peaks at day-of-year 150,
      * a linear trend rising from 0 to 0.2 over the full record,
      * Gaussian noise (mean 0, standard deviation 0.5), independent per cell,
      * two embedded warm events ("MHWs"), each a constant offset applied to
        one whole Dask chunk: +4.0 at chunk (1, 0, 0) and +3.0 at chunk (5, 1, 1).

    Because the events are addressed by chunk index, an event is silently
    skipped when the chunk grid implied by `shape`/`chunks` does not contain
    that chunk (e.g. fewer than six chunks along time for the second event).

    Parameters
    ----------
    shape : tuple of int
        (time, lat, lon) size of the array. The time axis is daily, starting
        on 2000-01-01.
    chunks : tuple of int
        Dask chunk sizes along (time, lat, lon).
    seed : int, optional
        Seed for the noise field. When given, the noise is drawn from a
        dedicated `da.random.RandomState(seed)` so repeated calls return the
        same data. When None (default), the global `da.random` state is used,
        exactly as before this parameter existed.

    Returns
    -------
    dask.array.Array
        Lazy array of the requested shape and chunking.
    """
    n_time = shape[0]
    time_chunk = chunks[0]

    # 1. Day-of-year for every time step, as a Dask array chunked like the data
    time_coords = pd.date_range("2000-01-01", periods=n_time)
    doy = (
        xr.DataArray(time_coords.dayofyear, dims=['time'], coords={'time': time_coords})
        .chunk({'time': time_chunk})
        .data
    )

    # 2. Deterministic 1D components (time only)
    seasonal_cycle = 5 * np.cos(2 * np.pi * (doy - 150) / 365.25)
    trend = 0.2 * da.linspace(0, 1, n_time, chunks=time_chunk)
    base_temp = 15.0

    # Trailing singleton axes let the 1D series broadcast across lat/lon.
    base_timeseries = base_temp + seasonal_cycle[:, None, None] + trend[:, None, None]

    # 3. Random noise; a private RandomState is used only when a seed is given
    #    so that the default path keeps honouring the global da.random state.
    if seed is None:
        noise = da.random.normal(0, 0.5, size=shape, chunks=chunks)
    else:
        noise = da.random.RandomState(seed).normal(0, 0.5, size=shape, chunks=chunks)

    # 4. Final temperature field
    temp_data = base_timeseries + noise

    # 5. Embed two large MHW events, each confined to one chunk
    # MHW 1: 2nd time chunk, 1st lat chunk, 1st lon chunk
    temp_data = _add_heat_to_chunk(temp_data, (1, 0, 0), 4.0)
    # MHW 2: 6th time chunk, 2nd lat chunk, 2nd lon chunk
    temp_data = _add_heat_to_chunk(temp_data, (5, 1, 1), 3.0)

    return temp_data

# --- Setup the Benchmark ---
# Seed the global Dask RNG before the noise field is built so that a
# Restart & Run All regenerates the same synthetic dataset every time.
RANDOM_SEED = 42
da.random.seed(RANDOM_SEED)

# Shape and chunking of the large dataset.
# 365 * 30 daily steps = 30 nominal (365-day) years on a 200x200 grid. Leap
# days are not added, so starting from 2000-01-01 the record ends in late
# December 2029 rather than covering 30 full calendar years.
shape = (365 * 30, 200, 200)
chunks = (365, 50, 50)  # one nominal year per time chunk, 50x50 spatial tiles

print("Creating large synthetic dataset with a 30-year baseline...")
dask_data = create_synthetic_3d_mhw_data(shape, chunks)

# Coordinates. The start date must stay in sync with the one hard-coded in
# create_synthetic_3d_mhw_data, which derives its seasonal cycle from the
# day-of-year of the same daily date range.
time = pd.date_range("2000-01-01", periods=shape[0])
lat = np.arange(shape[1])
lon = np.arange(shape[2])

# Wrap the lazy Dask array in a labelled xarray DataArray
ds_temp = xr.DataArray(
    dask_data,
    dims=["time", "lat", "lon"],
    coords={"time": time, "lat": lat, "lon": lon},
    name="temperature"
)
# nbytes is derived from shape and dtype, so this does not trigger computation
print(f"Dataset size: {ds_temp.nbytes / 1e9:.2f} GB")

Creating large synthetic dataset with a 30-year baseline...
Dataset size: 3.50 GB


In [16]:
import mhw3d.bipolarMhwToolBox as ben_mhw

In [17]:
%%time
# Stage 1: climatology and threshold. Both are computed eagerly here; they are
# presumably indexed by day of year, since they are selected with
# `dayofyear=` below -- confirm against mhw3d.
print("Step 1: Calculating climatology and threshold...")
seas = ben_mhw.smoothedClima_mhw(ds_temp).compute()
thresh = ben_mhw.smoothedThresh_mhw(ds_temp).compute()

# Stage 2: daily anomaly (temperature minus climatology) and severity
# (anomaly scaled by the threshold-to-climatology gap).
print("Step 2: Preparing daily anomaly and severity data...")
doy_per_timestep = ds_temp['time.dayofyear']
ssta = ds_temp.groupby('time.dayofyear') - seas
# Expand the per-day-of-year fields onto the full time axis
seas_aligned = seas.sel(dayofyear=doy_per_timestep)
thresh_aligned = thresh.sel(dayofyear=doy_per_timestep)
# The small epsilon keeps the division finite where threshold == climatology
severity = ssta / (thresh_aligned - seas_aligned + 1e-9)

detection_inputs = {
    'ssta': ssta,
    'severity': severity,
    'time': ds_temp.time
}
ds_for_detection = xr.Dataset(detection_inputs)

# Stage 3: collapse 'time' into a single chunk (-1) and let Dask pick the
# lat/lon chunk sizes. Per the original author's note, the detection ufunc
# treats 'time' as a core dimension and therefore needs it unchunked.
print("Step 3: Re-chunking data for the detection algorithm...")
single_time_block = {"time": -1, "lat": "auto", "lon": "auto"}
ds_for_detection = ds_for_detection.chunk(single_time_block)

# Stage 4: build the detection result; named "lazy" because it is only
# materialised by the .compute() call in stage 5.
print("Step 4: Building the Dask graph for MHW detection...")
mhw_results_lazy = ben_mhw.calculate_MHWs_metrics(ds_for_detection)

# Stage 5: execute the graph -- this is the part the benchmark is timing
print("Step 5: Triggering Dask computation. This is the main workload...")
mhw_results_computed = mhw_results_lazy.compute()

print("Benchmark complete.")
display(mhw_results_computed)

Step 1: Calculating climatology and threshold...
Step 2: Preparing daily anomaly and severity data...
Step 3: Re-chunking data for the detection algorithm...
Step 4: Building the Dask graph for MHW detection...
Step 5: Triggering Dask computation. This is the main workload...
Benchmark complete.


CPU times: user 1h 24min 29s, sys: 10min 45s, total: 1h 35min 15s
Wall time: 29min 12s
