
**To do**

* ~~Handle all NaN case for land-sea.~~
* ~~Test on interpolated daily NDVI values.~~
* ~~Implement trends/slopes~~
* Have answer exported as netcdf given long processing time.
* Consider breaking Aus into six tiles and running these separately so each run only takes a few hours. This could prevent the situation where several thousand SU are used only for the process to crash halfway through.

<!-- 
```
#performing regression
a = np.zeros((72,144,3))
for i in range(len(data.lat)):
    for j in range(len(data.lon)):
        a[i,j,:] = (Ridge().fit(np.array((dataU.isel(lev=2).x1.values[:,i,j].reshape(-1,1),
                    dataU.isel(lev=2).x2.values[:,i,j].reshape(-1,1),
                    dataU.isel(lev=2).x3.values[:,i,j].reshape(-1,1))).reshape(108,3),
                    data2U.y.values[:,i,j].reshape(108)).coef_)
dataU = data.assign_coords(varname=['x1','x2','x3'])
dataU['multiple_reg_coeff'] = (('lat','lon','varname'), a)
``` -->

In [None]:
import os
import gc
import sys
import warnings

import scipy
import numpy as np
import xarray as xr
import pandas as pd
from scipy import stats

import dask
import dask.array
from dask import delayed

import seaborn as sb
import contextily as ctx
import matplotlib.pyplot as plt
import matplotlib.colors as colors

from odc.geo.xr import assign_crs
# from scipy.stats import gaussian_kde
# from sklearn.metrics import r2_score
import pymannkendall as mk
# from xarrayMannKendall import Mann_Kendall_test

import sys
sys.path.append('/g/data/os22/chad_tmp/AusEFlux/src/')

%matplotlib inline

## Dask cluster

Local or Dynamic?

In [None]:
# '-q normal','-P w97','-l ncpus='+str(cores),'-l mem='+str(memory), '-l storage=gdata/os22+gdata/w97'

In [None]:
# import dask.config
# from dask.distributed import Client,LocalCluster
# from dask_jobqueue import PBSCluster
# walltime = '01:00:00'
# cores = 24
# memory = '50GB'

# cluster = PBSCluster(walltime=str(walltime), cores=cores, memory=str(memory),processes=cores,
#                      job_extra_directives=['-q normal','-P w97','-l ncpus='+str(cores),'-l mem='+str(memory),
#                                 '-l storage=gdata/os22+gdata/w97'],
#                      local_directory='$TMPDIR',
#                      job_directives_skip=["select"],
#                      # python=os.environ["DASK_PYTHON"]
# #                     )
# cluster.scale(jobs=2)
# client = Client(cluster)

# Start a local dask cluster via the project helper in `_utils`
# (found through the `sys.path.append` of the AusEFlux `src` directory).
# NOTE(review): `memory_limit` is presumably the total for the cluster rather
# than per worker -- confirm against `_utils.start_local_dask`.
from _utils import start_local_dask
start_local_dask(n_workers=24, threads_per_worker=1, memory_limit='93GiB')

## Open datasets

In [None]:
# ds = xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/NDVI_smooth_AusENDVI-clim_MCD43A4.nc')['NDVI']
# ds = ds.isel(latitude=slice(200,250), longitude=slice(200,250)) #testing sample
# # ds = ds.dropna(dim='time',
# #             how='all').resample(time='1W').interpolate(kind='quadratic')
# # ds = assign_crs(ds, crs='EPSG:4326')
# ds = ds.rename('NDVI')

# covars = xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/covars.nc')

## Per pixel phenometrics with dask.delayed


In [None]:
def extract_peaks_troughs(da,
                      rolling=90,
                      distance=90,
                      prominence='auto',
                      plateau_size=10
                     ):
    
    """
    Identify peaks and troughs in a vegetation index time series.
    
    The algorithm builds upon those described by Broich et al. (2014).  
    The steps are as follows:
    
    1. Calculate rolling minimums
    2. Calculate rolling maximums
    3. Using scipy.signal.find_peaks, identify peaks and troughs in rolling max
        and min time-series using a minimum distance between peaks, a peak plateau size,
        and the peak/trough prominence.
    4. Remove peaks or troughs where troughs (peaks) occur sequentially without
       a peak (trough) between them. So enforce the pattern peak-valley-peak-valley etc.
       This is achieved by taking the maximum (minimum) peak (trough) where two peaks (troughs)
       occur sequentially.
    
    params:
    -------
    da : xr.DataArray
        One-dimensional time series with a 'time' coordinate. Every other
        coordinate is dropped before processing.
    rolling : int
        Window length (in time steps) of the centred rolling min/max.
    distance : int
        Minimum number of time steps between neighbouring peaks (troughs),
        passed to scipy.signal.find_peaks.
    prominence : float or 'auto'
        Required prominence of a peak/trough. With 'auto' the value is chosen
        from the range of the monthly climatology: 0.01 when the range is
        >= 0.05, otherwise 0.005.
    plateau_size : int
        Passed to scipy.signal.find_peaks as `plateau_size`.
    
    returns:
    --------
    A pandas.DataFrame indexed by time with two columns, 'peaks' and 'troughs'.
    Each row is either a peak or a trough; the matching column holds the value
    of `da` at that time and the other column is NaN.
    
    raises:
    -------
    TypeError if `da` is not an xr.DataArray.
    
    """
    # Import the submodule explicitly: `import scipy` alone does not guarantee
    # that `scipy.signal` is loaded on every scipy version.
    from scipy.signal import find_peaks
    
    #check its an array
    if not isinstance(da, xr.DataArray):
        raise TypeError(
            "This function only accepts an xr.DataArray"
        )
    
    # It doesn't matter what we call the variable, we just need it to be
    # predictable. Rename a copy so the caller's array is not mutated.
    da = da.rename('NDVI')
    
    #ensure da has only time coordinates
    extra_coords = [name for name in da.coords if name != 'time']
    da = da.drop_vars(extra_coords)
    
    #calculate rolling min/max to find local minima maxima
    roll_max = da.rolling(time=rolling, min_periods=1, center=True).max()
    roll_min = da.rolling(time=rolling, min_periods=1, center=True).min()
    
    if isinstance(prominence, str) and prominence == 'auto':
        #dynamically determine how prominent a peak needs to be
        #based upon typical range of seasonal cycle
        clim = da.groupby('time.month').mean()
        seasonal_range = (clim.max() - clim.min()).values.item()
        # A single conditional guarantees a numeric prominence even when the
        # range is NaN (both of the former separate comparisons were False
        # in that case, leaving the string 'auto' in place).
        prominence = 0.01 if seasonal_range >= 0.05 else 0.005
    
    #find peaks and valleys
    peaks = find_peaks(roll_max.data,
                       distance=distance,
                       prominence=prominence,
                       plateau_size=plateau_size)[0]
    
    troughs = find_peaks(roll_min.data*-1,#invert
                         distance=distance,
                         prominence=prominence,
                         plateau_size=plateau_size)[0]
    
    #--------------cleaning-------
    # Identify where two peaks or two valleys occur one after another and remove.
    # i.e. enforcing the pattern peak-valley-peak-valley etc.
    # First get the peaks and troughs into a dataframe with matching time index
    frame = da.to_dataframe()
    frame['peaks'] = da.isel(time=peaks).to_series()
    frame['troughs'] = da.isel(time=troughs).to_series()
    df_peaks_troughs = frame.drop('NDVI', axis=1).dropna(how='all')
    
    #find where two peaks or two troughs occur sequentially
    peaks_num_nans = df_peaks_troughs.peaks.isnull().rolling(2).sum()
    troughs_num_nans = df_peaks_troughs.troughs.isnull().rolling(2).sum()
    
    # Grab the indices where the rolling sum of NaNs equals 2.
    # The labelling is inverted here because two NaNs in the trough column
    # mean two peaks occur consecutively, and vice versa
    idx_consecutive_peaks = troughs_num_nans[troughs_num_nans==2.0]
    idx_consecutive_troughs = peaks_num_nans[peaks_num_nans==2.0]
    
    # Loop through locations with two sequential peaks and drop
    # the smaller of the two peaks. Only the 'peaks' column is inspected:
    # calling idxmin on the whole frame also evaluates the all-NaN 'troughs'
    # column, which pandas has deprecated.
    for idx in idx_consecutive_peaks.index:
        
        loc = df_peaks_troughs.index.get_loc(idx)
        pair = df_peaks_troughs['peaks'].iloc[[loc-1,loc]]
        
        df_peaks_troughs = df_peaks_troughs.drop(pair.idxmin(skipna=True))
    
    # Loop through locations with two sequential troughs and drop
    # the higher of the two troughs (less prominent trough)
    for idx in idx_consecutive_troughs.index:
        
        loc = df_peaks_troughs.index.get_loc(idx)
        pair = df_peaks_troughs['troughs'].iloc[[loc-1,loc]]
        
        df_peaks_troughs = df_peaks_troughs.drop(pair.idxmax(skipna=True))
    
    return df_peaks_troughs

@dask.delayed
def phenometrics(da,
             rolling=90,
             distance=90,
             prominence='auto',
             plateau_size=10,
             amplitude=0.20,  
             verbose=True
            ):
    """
    Calculate statistics that describe the phenology cycle of
    a vegetation condition time series.
    
    Identifies the start and end points of each cycle using 
    the `seasonal amplitude` method: the start (end) of season is the time
    when the series is closest to `amplitude` (default 20%) of the seasonal
    amplitude between the first minimum and the peak (the peak and the
    second minimum).
    
    To ensure we are measuring only full cycles we enforce the time series to
    start and end with a trough.
    
    Because of the `dask.delayed` decorator, calling this function returns a
    lazy object; the Dataset described below is produced when it is computed.
    
    Phenometrics calculated (each also has a `<name>_year` companion for
    POS, TOS, SOS and EOS so the calendar year can be tracked):
        * ``'SOS'``: DOY of start of season
        * ``'POS'``: DOY of peak of season
        * ``'EOS'``: DOY of end of season
        * ``'vSOS'``: Value at start of season
        * ``'vPOS'``: Value at peak of season
        * ``'vEOS'``: Value at end of season
        * ``'TOS'``: DOY of the minimum at the beginning of cycle (left of peak)
        * ``'vTOS'``: Value at the beginning of cycle (left of peak)
        * ``'LOS'``: Length of season (days between SOS and EOS)
        * ``'AOS'``: Amplitude of season (vPOS - vTOS, in value units)
        * ``'IOS'``: Integral of season between SOS and EOS (value units x days)
        * ``'ROG'``: Rate of greening (value units per day)
        * ``'ROS'``: Rate of senescence (value units per day)
    
    params:
    -------
    da : xr.DataArray
        Single-pixel time series with a 'time' coordinate and scalar
        'latitude' and 'longitude' coordinates.
    rolling, distance, prominence, plateau_size :
        Passed straight through to `extract_peaks_troughs`.
    amplitude : float
        Fraction of the seasonal amplitude that defines SOS and EOS.
    verbose : bool
        Currently unused; kept for interface compatibility.
    
    returns:
    --------
    xr.Dataset of float32 variables (one per phenometric) with dimensions
    (latitude, longitude, index), where `index` numbers the detected seasons
    and latitude/longitude have length one.
    
    NOTE: assumes at least one peak/trough was detected; an empty
    peak/trough table fails at the first `.iloc[0]` lookup.
    
    """
    #Extract peaks and troughs in the timeseries
    peaks_troughs = extract_peaks_troughs(
                         da,
                         rolling=rolling,
                         distance=distance,
                         prominence=prominence,
                         plateau_size=plateau_size)
    
    # start the timeseries with trough
    p_t = peaks_troughs
    if np.isnan(p_t.iloc[0].troughs):
        p_t = p_t.iloc[1:]
    
    # end the timeseries with trough
    if np.isnan(p_t.iloc[-1].troughs):
        p_t = p_t.iloc[0:-1]
        
    # Store phenology stats, keyed by season number
    pheno = {}
    
    peak_times = p_t.peaks.dropna().index
    for idx, peak_time in enumerate(peak_times):
        # First we extract the troughs either side of the peak
        loc = p_t.index.get_loc(peak_time)
        start_trough = p_t.iloc[loc-1]
        start_time = start_trough.name
        end_time = p_t.iloc[loc+1].name
    
        # now extract the NDVI time series for the cycle
        ndvi_cycle = da.sel(time=slice(start_time, end_time))
        
        # stats for this cycle (insertion order sets the variable order)
        metrics = {}
       
        # --Extract phenometrics---------------------------------
        pos = ndvi_cycle.idxmax(skipna=True)
        metrics['POS_year'] = pos.dt.year.item() #so we can keep track
        metrics['POS'] = pos.dt.dayofyear.values
        metrics['vPOS'] = ndvi_cycle.max().values
        #we want the trough values from the beginning of the season only (left side)
        metrics['TOS_year'] = start_time.year
        metrics['TOS'] = start_time.dayofyear
        metrics['vTOS'] = start_trough.troughs
        metrics['AOS'] = (metrics['vPOS'] - metrics['vTOS'])
        
        #SOS ------ 
        # Find the greening cycle (left of the POS)
        greenup = ndvi_cycle.where(ndvi_cycle.time <= pos)
        # Find absolute distance between 20% of the AOS and the values of the greenup, then
        # find the NDVI value that's closest to 20% of AOS, this is our SOS date
        sos = np.abs(greenup - (metrics['AOS'] * amplitude + metrics['vTOS'])).idxmin(skipna=True)
        metrics['SOS_year'] = sos.dt.year.item() #so we can keep track
        metrics['SOS'] = sos.dt.dayofyear.values
        metrics['vSOS'] = ndvi_cycle.sel(time=sos).values
        
        #EOS ------
        # Find the senescing cycle (right of the POS)
        browning = ndvi_cycle.where(ndvi_cycle.time >= pos)
        # Find absolute distance between 20% of the browning amplitude and the values of
        # the browning, then find the NDVI value that's closest, this is our EOS date
        ampl_browning = browning.max() - browning.min()
        eos = np.abs(browning - (ampl_browning * amplitude + browning.min())).idxmin(skipna=True)
        metrics['EOS_year'] = eos.dt.year.item() #so we can keep track
        metrics['EOS'] = eos.dt.dayofyear.values
        metrics['vEOS'] = ndvi_cycle.sel(time=eos).values
    
        # Timestamps reused by the length and rate calculations
        sos_time = pd.to_datetime(sos.values)
        pos_time = pd.to_datetime(pos.values)
        eos_time = pd.to_datetime(eos.values)
    
        # LOS ---
        metrics['LOS'] = (eos_time - sos_time).days
    
        #Integral of season
        season = ndvi_cycle.sel(time=slice(sos, eos))
        metrics['IOS'] = season.integrate(coord='time', datetime_unit='D').item()
    
        # Rate of growth and senescence (NDVI per day)
        metrics['ROG'] = (metrics['vPOS'] - metrics['vSOS']) / ((pos_time - sos_time).days)
        metrics['ROS'] = (metrics['vEOS'] - metrics['vPOS']) / ((eos_time - pos_time).days)
        
        pheno[idx] = metrics
    
    # rows = seasons, columns = phenometrics, everything stored as float32
    ds = pd.DataFrame(pheno).astype('float32').transpose().to_xarray()
    
    # Tag the result with the pixel location by adding length-one
    # latitude/longitude dimensions to every variable.
    lat = da.latitude.item()
    lon = da.longitude.item()
    for var in list(ds.data_vars):
        ds[var] = ds[var].expand_dims(latitude=[lat], longitude=[lon])
    
    return ds


### Handle NaNs
Due to issues with xarray quadratic interpolation, we need to remove every NaN or else the daily interpolation function will fail

In [None]:
# Open the smoothed NDVI variable and subset to a 75 x 75 pixel test tile.
ds = xr.open_dataset('/g/data/os22/chad_tmp/Aus_phenology/data/NDVI/NDVI_smooth_AusENDVI-clim_MCD43A4.nc')['NDVI']
ds = ds.isel(latitude=slice(175,250), longitude=slice(175,250))

# Remove the last ~6 timesteps that are all-NaN (from S-G smoothing):
# keep only the times whose tile-wide mean is not NaN.
times_to_keep = ds.mean(['latitude','longitude']).dropna(dim='time',how='any').time
ds = ds.sel(time=times_to_keep)

# Flag pixels where NaNs make up at least 10 % of the time series; this mask
# is reused later to remove those pixels.
nan_mask = np.isnan(ds).sum('time') >= len(ds.time) / 10

# Replace the flagged (mostly-NaN) pixels with a fill value of -99
ds = xr.where(nan_mask, -99, ds)

# Interpolate (and extrapolate at the ends) any remaining NaNs along time
ds = ds.interpolate_na(dim='time', method='cubic', fill_value="extrapolate")

# Now we can finally interpolate to daily
ds = ds.resample(time='1D').interpolate(kind='quadratic')

# Export so we can test the function
ds.to_netcdf('/g/data/os22/chad_tmp/Aus_phenology/data/ndvi_test.nc')

### Prepare data

In [None]:
# Load the daily NDVI test tile
path = '/g/data/os22/chad_tmp/Aus_phenology/data/ndvi_test.nc'
ds = xr.open_dataarray(path)

# Collapse latitude/longitude into a single 'spatial' dimension so that
# pixels can be visited with one loop, ordered as (time, spatial).
y_stack = ds.stack(spatial=('latitude', 'longitude'))
Y = y_stack.transpose('time', 'spatial')

# Positions along 'spatial' of the mostly-NaN pixels (mostly the land-sea
# mask), i.e. where the nan_mask created earlier is True.
idx_all_nan = np.where(
    nan_mask.stack(spatial=('latitude', 'longitude')) == True
)[0]

# Template dataset substituted whenever an all-NaN pixel is encountered.
# It was created from one of the output results:
# bb = xr.full_like(results[0], fill_value=-99, dtype='float32')
template_path = '/g/data/os22/chad_tmp/Aus_phenology/data/template.nc'
ds_template = xr.open_dataset(template_path)

### Apply phenometrics perpixel

In [None]:
# Lazily loop through the spatial indexes and append the returned
# xarray objects (or delayed computations) to a list.

# `i in <ndarray>` scans the whole array on every iteration, which makes the
# loop quadratic in the number of pixels; build a set once for O(1) lookups.
nan_pixels = {int(k) for k in idx_all_nan}

results=[]
for i in range(Y.shape[1]):

    #select pixel
    data = Y.isel(spatial=i)
    
    # First, check if spatial index has data. If it's one of 
    # the all-NaN indexes then use the template filled with -99 values
    if i in nan_pixels:
        xx = ds_template.copy() #use our template    
        xx['latitude'] = [data.latitude.values.item()] #update coords
        xx['longitude'] = [data.longitude.values.item()]
    
    else:
        # delayed call: nothing is computed until dask.compute below
        xx = phenometrics(data,
                          rolling=90,
                          distance=90,
                          prominence='auto',
                          plateau_size=10,
                          amplitude=0.20
                         )

    #append results, either a delayed computation or the -99 template
    results.append(xx)

#bring into memory
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=FutureWarning)
    results = dask.compute(results)[0]

### Find average phenology

To do:

Add another metric for the number of seasons

df['n_seasons'] = len(df)

In [None]:
%%time
@dask.delayed
def _mean(ds):
    """Average every variable of a per-pixel result over the 'index' dimension."""
    return ds.mean('index')

# Build one lazy mean per pixel, evaluate them together, then stitch the
# single-pixel datasets back into one gridded dataset.
p_average = dask.compute([_mean(d) for d in results])[0]
p_average = xr.combine_by_coords(p_average)

# Mask out the pixels that carry the -99 fill value
p_average = p_average.where(p_average > -99)

### Plot average phenology

In [None]:
fig,axes=plt.subplots(4,3, figsize=(11,11), sharey=True, sharex=True, layout='constrained')

# One panel per phenometric, paired with its colour map
pheno_stats=[
    'SOS', 'vSOS', 'LOS',
    'POS', 'vPOS', 'ROG',
    'EOS', 'vEOS', 'ROS',
    'IOS', 'vTOS', 'AOS',
]
cmaps = [
    'twilight', 'gist_earth_r', 'viridis',
    'twilight', 'gist_earth_r', 'magma',
    'twilight', 'gist_earth_r', 'magma_r',
    'inferno', 'gist_earth_r', 'plasma',
]

for ax,pheno,cmap in zip(axes.ravel(), pheno_stats, cmaps):
    # Colour limits and colourbar label: the specific metrics take priority
    # over the generic "value" (name contains 'v') and day-of-year defaults.
    if 'ROS' in pheno:
        vmin, vmax, label = -0.0025, -0.00025, 'NDVI/\nday'
    elif 'ROG' in pheno:
        vmin, vmax, label = 0.00025, 0.0025, 'NDVI/\nday'
    elif 'IOS' in pheno:
        vmin, vmax, label = 20, 200, 'NDVI/\n season'
    elif 'AOS' in pheno:
        vmin, vmax, label = 0.05, 0.4, 'NDVI'
    elif 'LOS' in pheno:
        vmin, vmax, label = 160, 300, 'days'
    elif "v" in pheno:
        vmin, vmax, label = 0.1, 0.85, 'NDVI'
    else:
        vmin, vmax, label = 0, 365, 'DOY'

    im=p_average[pheno].plot(ax=ax, add_colorbar=False, cmap=cmap, vmin=vmin, vmax=vmax, add_labels=False)
    ctx.add_basemap(ax, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)

    # hide the tick marks on both axes
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])

    # The colourbar is built manually so that its label can sit on top
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    cbar = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    ax_cbar = fig.colorbar(cbar, ax=ax, shrink=0.7)
    ax_cbar.ax.set_title(label, fontsize=8)
    ax.set_title(f'{pheno} 1982-2022')

## Find trends in phenology

In [None]:
@dask.delayed
def phenology_trends(ds, vars):
    """
    Mann-Kendall trend of each requested variable along the 'index' dimension.
    
    `mk.original_test` is applied (vectorized) over 'index' for every name in
    `vars`; of its nine outputs only the p-value and the slope are kept.
    Because of the `dask.delayed` decorator the call is lazy.
    
    params:
    -------
    ds : dataset-like object indexable by variable name, whose variables
        have an 'index' dimension.
    vars : iterable of str
        Names of the variables in `ds` to test.
    
    returns:
    --------
    A merged float32 dataset holding `<var>_slope` and `<var>_p_value`
    for every entry of `vars`.
    """
    n_outputs = 9  # mk.original_test returns nine values
    slope_arrays = []
    pvalue_arrays = []
    
    for name in vars:
        # run the test over 'index'; each of the nine outputs has no core dims
        mk_out = xr.apply_ufunc(
            mk.original_test,
            ds[name],
            input_core_dims=[["index"]],
            output_core_dims=[[] for _ in range(n_outputs)],
            vectorize=True,
        )
        
        # keep just the p-value (position 2) and the slope (position 7)
        pvalue_arrays.append(mk_out[2].rename(name+'_p_value'))
        slope_arrays.append(mk_out[7].rename(name+'_slope'))
    
    # merge slopes and p-values into one dataset
    merged_slopes = xr.merge(slope_arrays)
    merged_pvalues = xr.merge(pvalue_arrays)
    return xr.merge([merged_slopes, merged_pvalues]).astype('float32')

In [None]:
%%time
# Phenometrics for which a trend is calculated
trend_vars = [
    'POS', 'vPOS', 'TOS', 'vTOS', 'AOS', 'SOS',
    'vSOS', 'EOS', 'vEOS', 'LOS', 'IOS', 'ROG', 'ROS',
]

# One lazy trend calculation per pixel, evaluated together and then
# combined back into a single gridded dataset.
p_trends = dask.compute([phenology_trends(d, trend_vars) for d in results])[0]
p_trends = xr.combine_by_coords(p_trends)

# Mask the pixels that are NaN in the average-phenology result
p_trends = p_trends.where(~np.isnan(p_average.vPOS))

### Plot phenology trends

In [None]:
fig,axes=plt.subplots(4,3, figsize=(11,11),  layout='constrained')#sharey=True, sharex=True,
pheno_stats=['SOS','vSOS', 'LOS', 
             'POS', 'vPOS', 'ROG',
             'EOS', 'vEOS', 'ROS',
             'AOS', 'vTOS', 'IOS'
      ]
cmaps = ['coolwarm','BrBG','PRGn',
         'coolwarm','BrBG','Spectral',
         'coolwarm','BrBG','Spectral_r',
         'PiYG','BrBG','PuOr'
        ]
for ax,pheno,cmap in zip(axes.ravel(), pheno_stats, cmaps):
   
    # Colour limits and colourbar label: the specific metrics take priority
    # over the generic "value" (name contains 'v') and day-of-year defaults.
    if 'IOS' in pheno:
        vmin, vmax = -0.5, 0.5
        label = 'NDVI/\nyear'
    elif 'ROS' in pheno or 'ROG' in pheno:
        vmin, vmax = -2.0e-05, 2.0e-05
        label = 'NDVI/day/\nyear'
    elif 'AOS' in pheno:
        vmin, vmax = -0.002, 0.002
        label = 'NDVI/\nyear'  # was missing the '/' used by every other label
    elif 'LOS' in pheno:
        vmin, vmax = -1.5, 1.5
        label = 'days/\nyear'
    elif "v" in pheno:
        vmin, vmax = -0.0015, 0.0015
        label = 'NDVI/\nyear'
    else:
        vmin, vmax = -1.5, 1.5
        label = 'days/\nyear'

    d_to_plot = p_trends[pheno+'_slope']    
    im=d_to_plot.plot(ax=ax, add_colorbar=False, cmap=cmap, vmin=vmin, vmax=vmax, add_labels=False)
                             
    # significance plotting: hatch the pixels whose p-value is <= 0.05.
    # The grid positions of those pixels are computed once and reused for
    # the coordinates and for the slope values.
    lons, lats = np.meshgrid(d_to_plot.longitude, d_to_plot.latitude)
    sig_area = np.where(p_trends[pheno+'_p_value'] <= 0.05)
    ax.hexbin(x=lons[sig_area],
             y=lats[sig_area],
             C=d_to_plot.data[sig_area],
             hatch='XXXX',
             alpha=0,
             gridsize=50#(10,10)
            )    

    ctx.add_basemap(ax, source=ctx.providers.CartoDB.VoyagerNoLabels, crs='EPSG:4326', attribution='', attribution_size=1)
    ax.axes.get_xaxis().set_ticks([])
    ax.axes.get_yaxis().set_ticks([])

    #need to create colorbar manually to have label on top
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    cbar = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    ax_cbar = fig.colorbar(cbar, ax=ax, shrink=0.7)
    ax_cbar.ax.set_title(label, fontsize=8)
    ax.set_title(f'{pheno}, 1982-2022')
    # print(pheno)
    
# axes[3,2].axis('off');
# fig.savefig(f'/g/data/os22/chad_tmp/Aus_phenology/results/figs/trends_phenometrics_map_{region_type}_{product}.png',
#             bbox_inches='tight', dpi=300)

## Regression between annual means and climate

testing per pixel PLS regression with dask delayed

### Annual means

In [None]:
# ds = ds.resample(time='YS').mean()

In [None]:
# rain = covars['rain']
# rain = rain.resample(time='YS').sum()

# covars = covars.drop('rain')
# covars = covars.resample(time='YS').mean()
# covars['rain'] = rain

# covars = covars.drop('trees')

### test dask.delayed method

This works fast but it has a serious memory leak...RAM explodes once you get close to 10,000 pixels. ~100 GiB to run a 100x100 tile.

In [None]:
# from sklearn.cross_decomposition import PLSRegression

In [None]:
# def apply_pls_regression_dask(x, y, n_components=2):
#     """
#     Perform PLS regression along the time dimension of two xarray DataArrays using Dask for parallelization.
    
#     Parameters:
#     x (xr.DataArray): The independent variables with dimensions (time, latitude, longitude, variable).
#     y (xr.DataArray): The dependent variable with dimensions (time, latitude, longitude).
#     n_components (int): Number of PLS components to use.
    
#     Returns:
#     xr.Dataset: A dataset containing the PLS regression coefficients, and R-squared values.
#     """
#     # Ensure x has the shape (time, latitude, longitude, variables)
#     if 'variable' not in x.dims:
#         raise ValueError("x must have a 'variable' dimension representing independent variables.")
    
#     # Align the arrays to ensure they have matching time, latitude, longitude dimensions
#     x, y = xr.align(x, y)
    
#     # Stack latitude and longitude into a single dimension to simplify the regression calculation
#     x_stack = x.stack(spatial=('latitude', 'longitude'))
#     y_stack = y.stack(spatial=('latitude', 'longitude'))

#     # Convert to numpy arrays or dask arrays
#     X = x_stack.transpose('time', 'spatial', 'variable').data
#     Y = y_stack.transpose('time', 'spatial').data

#     # Perform the PLS regression at each spatial point using Dask
#     def compute_pls(i):
#         Xi = X[:, i, :]
#         Yi = Y[:, i]

#         #check for NaNs
#         mask = ~np.isnan(Yi)
        
#         Xi = Xi[mask]
#         Yi = Yi[mask]
        
#         # Initialize the PLS regression model
#         pls = PLSRegression(n_components=n_components)
        
#         # Fit the model
#         pls.fit(Xi, Yi)

#          #predict
#         Y_fit = pls.predict(Xi)

#         #find the slope of predicted and original
#         s_pred = mk.original_test(Y_fit, alpha=0.05)
#         s_actual = mk.original_test(Yi, alpha=0.05)

#         r2 = r2_score(Yi, Y_fit)
        
#         return pls.coef_.ravel(), r2, s_pred.slope, s_actual.slope

#     # Use Dask to parallelize the computation across spatial points
#     results = dask.array.compute([delayed(compute_pls)(i) for i in range(X.shape[1])])

#     # Unpack results
#     coefs, r_squared, slopes_pred, slopes_actual = zip(*results[0])
    
#     # Convert lists to arrays
#     coefs = np.array(coefs).T
#     r_squared = np.array(r_squared)
#     slopes_pred = np.array(slopes_pred)
#     slopes_actual = np.array(slopes_actual)
    
#     #retrun the results as xarrays
#     coefs = xr.DataArray(coefs, dims=('variable', 'spatial'),
#                          coords={'variable': x_stack.mean('time').variable,
#                                   'spatial': x_stack.spatial})
#     coefs = coefs.unstack('spatial')
#     coefs =  xr.Dataset({
#                 'CO2':coefs.variable[0],
#                 'srad':coefs.variable[1],
#                 'tavg':coefs.variable[2],
#                 'vpd':coefs.variable[3],
#                 'rain':coefs.variable[4]},
#                 coords={'latitude': coefs.latitude,
#                               'longitude':coefs.longitude})
    
#     r_squared = xr.DataArray(r_squared, dims=('spatial',), coords={'spatial': x_stack.spatial})
#     r_squared = r_squared.unstack('spatial')
    
#     slopes_pred = xr.DataArray(slopes_pred, dims=('spatial',), coords={'spatial': x_stack.spatial})
#     slopes_pred = slopes_pred.unstack('spatial')
    
#     slopes_actual = xr.DataArray(slopes_actual, dims=('spatial',), coords={'spatial': x_stack.spatial})
#     slopes_actual = slopes_actual.unstack('spatial')
    
#     r = xr.Dataset({
#         'r_squared': r_squared,
#         'slopes_predicted': slopes_pred,
#         'slopes_actual': slopes_actual,
#     })
    
#     return coefs, r

In [None]:
# ds_test = ds.isel(latitude=slice(200,300), longitude=slice(200,300))
# covars_test = covars.isel(latitude=slice(200,300), longitude=slice(200,300)).to_dataarray()

In [None]:
# %%time
# c, r = apply_pls_regression_dask(covars_test, ds, n_components=5)

## test ufunc

doesn't seem to run??

In [None]:
# ds_test = ds.isel(latitude=slice(200,300), longitude=slice(200,300))
# covars_test = covars.isel(latitude=slice(200,300), longitude=slice(200,300)).to_dataarray()

# ds_test = ds_test.chunk({'latitude': 20, 'longitude': 20})
# covars_test = covars_test.chunk({'latitude': 20, 'longitude': 20})


In [None]:
# %%time
# # Example usage (assuming your data is already loaded into x and y):
# result = apply_pls_regression_xarray(covars_test, ds_test, n_components=2)#.compute()
