# Test the functions

In [2]:
import ecoclim_tools as et
import numpy as np
import pandas as pd
import xarray as xr
import dask
import pytest

In [47]:
import dask
from dask_jobqueue import SLURMCluster
from dask.distributed import Client

# Create a SLURM cluster object
cluster = SLURMCluster(
    job_name='test_detrend',
    cores=1,
    memory='1GiB', # how much memory for each core, if apply for 10 core then in total 80G
    processes=1,
    job_cpu=1,
    job_mem='1GB',
    queue='work',
    walltime='00:30:00', # time limit to run this job
    log_directory='./log'  # Set the directory for .out files
    # local_directory='/local_scratch/slurm.$SLURM_JOB_ID/dask/spill',
    # interface='ib0'
)

Perhaps you already have a cluster running?
Hosting the HTTP server on port 38701 instead


In [48]:
# Create the client to load the Dashboard
client = Client(cluster)

# Scale the cluster to 2 workers # determine how many core your are using
cluster.scale(20) # <- adjust this number

In [30]:
print(client.dashboard_link)

http://10.0.30.61:35113/status


In [45]:
client.shutdown()

In [49]:
def test_dask_detrend():
    # Create a sample 3D Dask DataArray with a linear trend
    time = np.arange(100)
    data = xr.DataArray(
        2 * time[:, None, None] + np.random.rand(100, 720, 720),
        coords=[('time', time), ('lat', np.arange(720)), ('lon', np.arange(720))],
        dims=('time', 'lat', 'lon')
    )
    dask_data = data.chunk({'time': -1, 'lat': 72, 'lon': 72})
    
    # Apply detrend
    detrended = et.detrend(dask_data)
    
    # Compute the result
    detrended_computed = detrended.compute()
    
    # Assert: Check the slope at a specific point (e.g., lat=0, lon=0)
    # np.polyfit needs 1D arrays, so we select one pixel
    pixel_series = detrended_computed.isel(lat=0, lon=0).values
    
    slope = np.polyfit(time, pixel_series, 1)[0]
    
    # The trend (slope 2) should be removed, so slope should be ~0
    assert abs(slope) < 1e-5

In [50]:
# check whether including dask accelerates the code => yes

In [58]:
%timeit -r 5 -n 1 test_dask_detrend()

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/

10.9 s ± 1.46 s per loop (mean ± std. dev. of 5 runs, 1 loop each)


In [None]:
# check whether new detrend method accelerates the code

In [52]:
def polyval(coord, coeffs, degree_dim="degree"):
    x = coord.data

    deg_coord = coeffs[degree_dim]
    N = int(deg_coord.max()) + 1

    lhs = xr.DataArray(
        np.stack([x ** (N - 1 - i) for i in range(N)], axis=1),
        dims=(coord.name, degree_dim),
        coords={coord.name: coord, degree_dim: np.arange(deg_coord.max() + 1)[::-1]},
    )
    return (lhs * coeffs).sum(degree_dim)


# Function to detrend
# Modified from source: https://gist.github.com/rabernat/1ea82bb067c3273a6166d1b1f77d490f
def detrend_dim(da, dim, deg=1):
    """detrend along a single dimension."""
    # calculate polynomial coefficients
    p = da.polyfit(dim=dim, deg=deg, skipna=False)
    # first create a chunked version of the "ocean_time" dimension
    chunked_dim = xr.DataArray(
        dask.array.from_array(da[dim].data, chunks=da.chunksizes[dim]),
        dims=dim,
        name=dim,
    )
    fit = polyval(chunked_dim, p.polyfit_coefficients)
    # evaluate trend
    # remove the trend
    return da - fit

def test_dask_detrend_v2():
    # Create a sample 3D Dask DataArray with a linear trend
    time = np.arange(100)
    data = xr.DataArray(
        2 * time[:, None, None] + np.random.rand(100, 720, 720),
        coords=[('time', time), ('lat', np.arange(720)), ('lon', np.arange(720))],
        dims=('time', 'lat', 'lon')
    )
    dask_data = data.chunk({'time': -1, 'lat': 72, 'lon': 72})
    
    # Apply detrend
    detrended = detrend_dim(dask_data, dim='time')
    
    # Compute the result
    detrended_computed = detrended.compute()
    
    # Assert: Check the slope at a specific point (e.g., lat=0, lon=0)
    # np.polyfit needs 1D arrays, so we select one pixel
    pixel_series = detrended_computed.isel(lat=0, lon=0).values
    
    slope = np.polyfit(time, pixel_series, 1)[0]
    
    # The trend (slope 2) should be removed, so slope should be ~0
    assert abs(slope) < 1e-5

In [57]:
%timeit -r 5 -n 1 test_dask_detrend_v2()

This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/en/stable/best-practices.html#load-data-with-dask for more information.
This may cause some slowdown.
Consider loading the data with Dask directly
 or using futures or delayed objects to embed the data into the graph without repetition.
See also https://docs.dask.org/

6.37 s ± 134 ms per loop (mean ± std. dev. of 5 runs, 1 loop each)
