In [1]:
import xarray as xr
import torch
import numpy as np 
import pandas as pd
from hython.models.convLSTM import ConvLSTM
from hython.datasets.datasets import get_dataset
from hython.sampler import SamplerBuilder, CubeletsDownsampler
from hython.trainer import HythonTrainer, RNNTrainParams, train_val
from hython.metrics import MSEMetric
from hython.losses import RMSELoss
from hython.normalizer import Normalizer
from hython.utils import write_to_zarr, read_from_zarr, set_seed


import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

EXPERIMENT  = "exp1"

SURROGATE_INPUT = "https://eurac-eo.s3.amazonaws.com/INTERTWIN/SURROGATE_INPUT/adg1km_eobs_original.zarr/"

SURROGATE_MODEL_OUTPUT = f"path/to/model/output/directory/{EXPERIMENT}.pt"

TMP_STATS = "path/to/temporary/stats/directory" 

# === FILTER ==============================================================

# train/test temporal range
train_temporal_range = slice("2012-01-01","2022-12-31")
test_temporal_range = slice("2019-01-01", "2020-12-31")

# variables
dynamic_names = ["precip", "pet", "temp"] 
static_names = [ "thetaS", "thetaR", "KsatVer", "SoilThickness", "RootingDepth", "f", "Swood", "Sl", "Kext"]
target_names = ["vwc"]# ["vwc", "actevap", "snow", "snowwater"] 

# === MASK ========================================================================================

mask_names = ["mask_missing", "mask_lake"] # names depends on preprocessing application

# === DATASET ========================================================================================

DATASET = "CubeletsDataset" 

XSIZE,YSIZE, TSIZE = 10, 10, 360
XOVER,YOVER,TOVER = 5, 5, 220

MISSING_POLICY = 0.05 # "any", "all"

# == MODEL  ========================================================================================

HIDDEN_SIZE = 36 # 
DYNAMIC_INPUT_SIZE = len(dynamic_names)
STATIC_INPUT_SIZE = len(static_names)
KERNEL_SIZE = (3, 3) # height, width
NUM_LSTM_LAYER = 1
OUTPUT_SIZE = len(target_names)

TARGET_WEIGHTS = {t:1/len(target_names) for t in target_names}


# === SAMPLER/TRAINER ===================================================================================

# downsampling
DONWSAMPLING = False

TEMPORAL_FRAC = [0.8, 0.8] # train, test
SPATIAL_FRAC = [1, 1]  # train, test

# gradient clipping
gradient_clip = {"max_norm":1} # None

SEED = 42
EPOCHS = 20
BATCH = 32


assert (sum(v for v in TARGET_WEIGHTS.values()) - 1) < 0.01, "check target weights"

In [3]:
set_seed(SEED)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [4]:
# === READ TRAIN ===================================================================
Xd = (
    read_from_zarr(url=SURROGATE_INPUT , group="xd")
    .sel(time=train_temporal_range)[dynamic_names]
)
Xs = read_from_zarr(url=SURROGATE_INPUT , group="xs")[static_names]

Y = (
    read_from_zarr(url=SURROGATE_INPUT , group="y")
    .sel(time=train_temporal_range)[target_names]
)

SHAPE = Xd.attrs["shape"]

# === READ TEST ===================================================================

Y_test = (
    read_from_zarr(url=SURROGATE_INPUT , group="y")
    .sel(time=test_temporal_range)[target_names]
)
Xd_test = (
    read_from_zarr(url=SURROGATE_INPUT , group="xd")
    .sel(time=test_temporal_range)[dynamic_names]
)

In [5]:
Y

In [6]:
import numpy as np
import xarray as xr

__all__ = ["MSEMetric", "RMSEMetric"]


def metric_decorator(y_true, y_pred, target_names, sample_weight=None):
    def target(wrapped):
        def wrapper():
            metrics = {}
            for idx, target in enumerate(target_names):
                metrics[target] = wrapped(y_true[:, idx], y_pred[:, idx], sample_weight)
            return metrics 
        return wrapper
    return target

class Metric:
    """
    Hython is currently supporting sequence-to-one training (predicting the last time step of the sequence). Therefore it assumes that
    the shape of y_true and y_pred is (N, C).

    In the future it will also support sequence-to-sequence training for forecasting applications.

    TODO: In forecasting, the shape of y_true and y_pred is going to be (N,T,C), where T is the n of future time steps.

    """
    def __init__(self):
        pass

class MSEMetric(Metric):
    """
    Mean Squared Error (MSE)

    Parameters
    ----------
    y_pred (numpy.array): The true values.
    y_true (numpy.array): The predicted values.
    target_names: List of targets that contribute in the loss computation.

    Returns
    -------
    Dictionary of MSE metric for each target. {'target': mse_metric}
    
    """
    def __call__(self, y_pred, y_true, target_names: list[str]):
        return metric_decorator(y_pred, y_true, target_names)(compute_mse)()

class RMSEMetric(Metric):
    def __call__(self, y_pred, y_true, target_names: list[str]):
        return metric_decorator(y_pred, y_true, target_names)(compute_rmse)()
    

# == METRICS
# The metrics below should work for both numpy or xarray inputs. The usage of xarray inputs is supported as it is handy for lazy computations
# e.g. compute_mse(y_true.chunk(lat=100,lon=100), y_pred.chunk(lat=100,lon=100)).compute()



# DISCHARGE 

def compute_fdc_fms():
    """
    """
    pass 

def compute_fdc_fhv():
    """
    """
    pass 

def compute_fdc_flv():
    """
    """
    pass


# SOIL MOISTURE


def compute_hr():
    """Hit Rate, proportion of time soil is correctly simulated as wet.
        Wet threshold is when x >= 0.8 percentile
        Dry threshold is when x <= 0.2 percentile
    """
    pass 

def compute_far():
    """False Alarm Rate"""
    pass 

def compute_csi():
    """Critical success index"""
    pass


# GENERAL

def compute_variance(ds,dim="time", axis=0, std=False):
    if isinstance(ds, xr.DataArray):
        return ds.std(dim=dim) if std else ds.var(dim=dim)
    else:
        return np.std(ds, axis=axis) if std else np.var(ds, axis=axis) 
    
def compute_gamma(y_true: xr.DataArray, y_pred, axis=0):
    m1, m2 = np.mean(y_true, axis=axis), np.mean(y_pred, axis=axis)
    return (np.std(y_pred, axis=axis) / m2) / (np.std(y_true, axis=axis) / m1)
    
def compute_pbias(y_true: xr.DataArray, y_pred, dim="time", axis=0):
    if isinstance(y_true, xr.DataArray) or isinstance(y_pred, xr.DataArray):
         return 100 * ( (y_pred - y_true).mean(dim=dim, skipna=False) / np.abs(y_true).mean(dim=dim, skipna=False))
    else:
        return 100 * ( np.mean(y_pred - y_true, axis=axis) / np.mean(np.abs(y_true), axis=axis) )

def compute_bias(y_true: xr.DataArray, y_pred, dim="time", axis=0):
    if isinstance(y_true, xr.DataArray) or isinstance(y_pred, xr.DataArray):
         return  (y_pred - y_true).mean(dim=dim, skipna=False)
    else:
        return np.mean(y_pred - y_true, axis=axis) 

def compute_rmse(y_true, y_pred, dim="time", axis=0):
    if isinstance(y_true, xr.DataArray) or isinstance(y_pred, xr.DataArray):
        return np.sqrt(((y_pred - y_true) ** 2).mean(dim=dim, skipna=False))
    else:
        return np.sqrt(np.mean((y_pred - y_true) ** 2, axis=axis))
    
def compute_mse(y_true, y_pred, axis=0, dim="time", sample_weight=None):
    if isinstance(y_true, xr.DataArray) or isinstance(y_pred, xr.DataArray):
        return ((y_pred - y_true) ** 2).mean(dim=dim, skipna=False)
    else:
        return np.average((y_pred - y_true) ** 2, axis=axis, weights=sample_weight)




def kge_metric(y_true, y_pred, target_names):
    """
    The Kling Gupta efficiency metric

    Parameters:
    y_pred (numpy.array): The true values.
    y_true (numpy.array): The predicted values.
    targes: List of targets that contribute in the loss computation.

    Shape
    y_true: numpy.array of shape (N, T).
    y_pred: numpy.array of shape (N, T).

    Returns:
    Dictionary of kge metric for each target. {'target': kge_value}
    """

    metrics = {}

    for idx, target in enumerate(target_names):
        observed = y_true[:, idx]
        simulated = y_pred[:, idx]
        r = np.corrcoef(observed, simulated)[1, 0]
        alpha = np.std(simulated, ddof=1) / np.std(observed, ddof=1)
        beta = np.mean(simulated) / np.mean(observed)
        kge = 1 - np.sqrt(
            np.power(r - 1, 2) + np.power(alpha - 1, 2) + np.power(beta - 1, 2)
        )
        metrics[target] = kge

    return metrics



In [7]:
Y

In [8]:
import xarray as xr
import numpy as np

def create_random_dataset(original_ds: xr.Dataset) -> xr.Dataset:
    """
    Create a random xarray.Dataset based on the structure of an existing xarray.Dataset.
    
    Parameters:
    original_ds (xr.Dataset): The original dataset to copy structure from.
    
    Returns:
    xr.Dataset: A new dataset with the same dimensions and coordinates but with random data.
    """
    random_data_vars = {}
    
    # Loop through each variable in the original dataset
    for var_name, var_data in original_ds.data_vars.items():
        # Generate random data of the same shape as the original variable
        random_data = np.random.random(var_data.shape)
        
        # Create a new DataArray with the same coordinates and assign the random data
        random_data_vars[var_name] = xr.DataArray(
            data=random_data,
            dims=var_data.dims,
            coords=var_data.coords
        )
    
    # Create a new dataset with the same coordinates and dimensions
    random_ds = xr.Dataset(data_vars=random_data_vars, coords=original_ds.coords)
    
    return random_ds


randomised = create_random_dataset(Y)
randomised

In [9]:
import xarray as xr

def compute_hr(observed: xr.DataArray, simulated: xr.DataArray) -> tuple:
    """
    Hit Rate: Proportion of time soil is correctly simulated as wet and dry.
    
    Wet threshold is when x >= 80th percentile.
    Dry threshold is when x <= 20th percentile.
    
    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).
    
    Returns:
    tuple: Wet threshold hit rate (%), Dry threshold hit rate (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_80th = observed.quantile(0.8, dim='time')
    simulated_80th = simulated.quantile(0.8, dim='time')
    
    observed_20th = observed.quantile(0.2, dim='time')
    simulated_20th = simulated.quantile(0.2, dim='time')

    # Create masks for "wet" periods (80th percentile) and "dry" periods (20th percentile)
    observed_wet = observed >= observed_80th
    simulated_wet = simulated >= simulated_80th
    
    observed_dry = observed <= observed_20th
    simulated_dry = simulated <= simulated_20th

    # Calculate the hit rate for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    total_wet_periods = observed_wet.sum(dim='time')

    total_wet_hits = wet_hits[list(wet_hits.data_vars)[0]].sum().values  # Convert to numpy array
    total_wet_periods_sum = total_wet_periods[list(total_wet_periods.data_vars)[0]].sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "wet" periods in observed data
    if total_wet_periods_sum == 0:
        wet_hit_rate = 0.0
    else:
        wet_hit_rate = (total_wet_hits / total_wet_periods_sum) * 100

    # Calculate the hit rate for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    total_dry_periods = observed_dry.sum(dim='time')

    total_dry_hits = dry_hits[list(dry_hits.data_vars)[0]].sum().values  # Convert to numpy array
    total_dry_periods_sum = total_dry_periods[list(total_dry_periods.data_vars)[0]].sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "dry" periods in observed data
    if total_dry_periods_sum == 0:
        dry_hit_rate = 0.0
    else:
        dry_hit_rate = (total_dry_hits / total_dry_periods_sum) * 100

    return float(wet_hit_rate), float(dry_hit_rate)


In [10]:
Y_array = Y.vwc
randomised_array = randomised.vwc 

In [15]:
import xarray as xr

def compute_hr(observed: xr.DataArray, simulated: xr.DataArray, wet_threshold_percentile: float = 0.8, dry_threshold_percentile: float = 0.2) -> tuple:
    """
    Hit Rate: Proportion of time soil is correctly simulated as wet and dry.
    
    Wet threshold is when x >= 80th percentile.
    Dry threshold is when x <= 20th percentile.
    
    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).
    
    Returns:
    tuple: Wet threshold hit rate (%), Dry threshold hit rate (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_wet_quan = observed.quantile(wet_threshold_percentile, dim='time')
    simulated_wet_quan = simulated.quantile(wet_threshold_percentile, dim='time')
    
    observed_dry_quan = observed.quantile(dry_threshold_percentile, dim='time')
    simulated_dry_quan = simulated.quantile(dry_threshold_percentile, dim='time')

    # Create masks for "wet" periods (80th percentile) and "dry" periods (20th percentile)
    observed_wet = observed >= observed_wet_quan
    simulated_wet = simulated >= simulated_wet_quan
    
    observed_dry = observed <= observed_dry_quan
    simulated_dry = simulated <= simulated_dry_quan

    # Calculate the hit rate for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    total_wet_periods = observed_wet.sum(dim='time')

    total_wet_hits = wet_hits.sum().values  # Convert to numpy array
    total_wet_periods_sum = total_wet_periods.sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "wet" periods in observed data
    if total_wet_periods_sum == 0:
        wet_hit_rate = 0.0
    else:
        wet_hit_rate = (total_wet_hits / total_wet_periods_sum) * 100

    # Calculate the hit rate for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    total_dry_periods = observed_dry.sum(dim='time')

    total_dry_hits = dry_hits.sum().values  # Convert to numpy array
    total_dry_periods_sum = total_dry_periods.sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "dry" periods in observed data
    if total_dry_periods_sum == 0:
        dry_hit_rate = 0.0
    else:
        dry_hit_rate = (total_dry_hits / total_dry_periods_sum) * 100

    return float(wet_hit_rate), float(dry_hit_rate)


rate = compute_hr(Y_array, randomised_array)
rate

  return fnb._ureduce(a,


(20.01238316276588, 20.026956379807622)

In [11]:
import xarray as xr

def compute_hr(observed: xr.DataArray, simulated: xr.DataArray) -> tuple:
    """
    Hit Rate: Proportion of time soil is correctly simulated as wet and dry.
    
    Wet threshold is when x >= 80th percentile.
    Dry threshold is when x <= 20th percentile.
    
    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).
    
    Returns:
    tuple: Wet threshold hit rate (%), Dry threshold hit rate (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_80th = observed.quantile(0.8, dim='time')
    simulated_80th = simulated.quantile(0.8, dim='time')
    
    observed_20th = observed.quantile(0.2, dim='time')
    simulated_20th = simulated.quantile(0.2, dim='time')

    # Create masks for "wet" periods (80th percentile) and "dry" periods (20th percentile)
    observed_wet = observed >= observed_80th
    simulated_wet = simulated >= simulated_80th
    
    observed_dry = observed <= observed_20th
    simulated_dry = simulated <= simulated_20th

    # Calculate the hit rate for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    total_wet_periods = observed_wet.sum(dim='time')

    total_wet_hits = wet_hits.sum().values  # Convert to numpy array
    total_wet_periods_sum = total_wet_periods.sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "wet" periods in observed data
    if total_wet_periods_sum == 0:
        wet_hit_rate = 0.0
    else:
        wet_hit_rate = (total_wet_hits / total_wet_periods_sum) * 100

    # Calculate the hit rate for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    total_dry_periods = observed_dry.sum(dim='time')

    total_dry_hits = dry_hits.sum().values  # Convert to numpy array
    total_dry_periods_sum = total_dry_periods.sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "dry" periods in observed data
    if total_dry_periods_sum == 0:
        dry_hit_rate = 0.0
    else:
        dry_hit_rate = (total_dry_hits / total_dry_periods_sum) * 100

    return float(wet_hit_rate), float(dry_hit_rate)


rate = compute_hr(Y_array, randomised_array)
rate

  return fnb._ureduce(a,


(20.01238316276588, 20.026956379807622)

In [18]:
import xarray as xr

def compute_hr(observed: xr.DataArray, simulated: xr.DataArray, wet_threshold_percentile: float = 0.8, dry_threshold_percentile: float = 0.2) -> dict:
    """
    Hit Rate: Proportion of time soil is correctly simulated as wet and dry.
    
    Wet threshold is when x >= 80th percentile.
    Dry threshold is when x <= 20th percentile.
    
    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).
    
    Returns:
    tuple: Wet threshold hit rate (%), Dry threshold hit rate (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_wet_quan = observed.quantile(wet_threshold_percentile, dim='time')
    simulated_wet_quan = simulated.quantile(wet_threshold_percentile, dim='time')
    
    observed_dry_quan = observed.quantile(dry_threshold_percentile, dim='time')
    simulated_dry_quan = simulated.quantile(dry_threshold_percentile, dim='time')

    # Create masks for "wet" periods (80th percentile) and "dry" periods (20th percentile)
    observed_wet = observed >= observed_wet_quan
    simulated_wet = simulated >= simulated_wet_quan
    
    observed_dry = observed <= observed_dry_quan
    simulated_dry = simulated <= simulated_dry_quan

    # Calculate the hit rate for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    total_wet_periods = observed_wet.sum(dim='time')

    total_wet_hits = wet_hits.sum().values  # Convert to numpy array
    total_wet_periods_sum = total_wet_periods.sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "wet" periods in observed data
    if total_wet_periods_sum == 0:
        wet_hit_rate = 0.0
    else:
        wet_hit_rate = (total_wet_hits / total_wet_periods_sum) * 100

    # Calculate the hit rate for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    total_dry_periods = observed_dry.sum(dim='time')

    total_dry_hits = dry_hits.sum().values  # Convert to numpy array
    total_dry_periods_sum = total_dry_periods.sum().values  # Convert to numpy array

    # Avoid division by zero in case there are no "dry" periods in observed data
    if total_dry_periods_sum == 0:
        dry_hit_rate = 0.0
    else:
        dry_hit_rate = (total_dry_hits / total_dry_periods_sum) * 100

    hit_rate = {
    f'wet_threshold_{wet_threshold_percentile}_hit_rate': float(wet_hit_rate),
    f'dry_threshold_{dry_threshold_percentile}_hit_rate': float(dry_hit_rate)
    }

    print(hit_rate)
    
    return hit_rate

rate = compute_hr(Y_array, randomised_array)
rate

  return fnb._ureduce(a,


{'wet_threshold_0.8_hit_rate': 20.01238316276588, 'dry_threshold_0.2_hit_rate': 20.026956379807622}


{'wet_threshold_0.8_hit_rate': 20.01238316276588,
 'dry_threshold_0.2_hit_rate': 20.026956379807622}

In [19]:
import xarray as xr

def compute_far(observed: xr.DataArray, simulated: xr.DataArray, wet_threshold_percentile: float = 0.8, dry_threshold_percentile: float = 0.2) -> tuple:
    """
    Compute False Alarm Rate (FAR) for wet and dry predictions.

    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).

    Returns:
    tuple: FAR for wet predictions (%), FAR for dry predictions (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_wet_quan = observed.quantile(wet_threshold_percentile, dim='time')
    simulated_wet_quan = simulated.quantile(wet_threshold_percentile, dim='time')

    observed_dry_quan = observed.quantile(dry_threshold_percentile, dim='time')
    simulated_dry_quan = simulated.quantile(dry_threshold_percentile, dim='time')

    # Create masks for "wet" and "dry" periods based on the percentiles
    observed_wet = observed >= observed_wet_quan
    simulated_wet = simulated >= simulated_wet_quan

    observed_dry = observed <= observed_dry_quan
    simulated_dry = simulated <= simulated_dry_quan

    # Calculate hits and false alarms for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    wet_false_alarms = (simulated_wet & ~observed_wet).sum(dim='time')

    # Sum hits and false alarms for "wet" across all spatial dimensions (lat, lon)
    total_wet_hits = wet_hits.sum().values  # Convert to numpy array
    total_wet_false_alarms = wet_false_alarms.sum().values  # Convert to numpy array

    # Calculate False Alarm Rate for wet conditions
    if (total_wet_hits + total_wet_false_alarms) == 0:
        wet_far = 0.0
    else:
        wet_far = (total_wet_false_alarms / (total_wet_false_alarms + total_wet_hits)) * 100  # As a percentage

    # Calculate hits and false alarms for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    dry_false_alarms = (simulated_dry & ~observed_dry).sum(dim='time')

    # Sum hits and false alarms for "dry" across all spatial dimensions (lat, lon)
    total_dry_hits = dry_hits.sum().values  # Convert to numpy array
    total_dry_false_alarms = dry_false_alarms.sum().values  # Convert to numpy array

    # Calculate False Alarm Rate for dry conditions
    if (total_dry_hits + total_dry_false_alarms) == 0:
        dry_far = 0.0
    else:
        dry_far = (total_dry_false_alarms / (total_dry_false_alarms + total_dry_hits)) * 100  # As a percentage

    far = {
    f'wet_threshold_{wet_threshold_percentile}_far': float(wet_far),
    f'dry_threshold_{dry_threshold_percentile}_far': float(dry_far)
    }

    print(far)
    
    return far


far = compute_far(Y_array, randomised_array)
far

  return fnb._ureduce(a,


{'wet_threshold_0.8_far': 92.23914681698632, 'dry_threshold_0.2_far': 92.26686076344734}


{'wet_threshold_0.8_far': 92.23914681698632,
 'dry_threshold_0.2_far': 92.26686076344734}

In [16]:
import xarray as xr

def compute_far(observed: xr.DataArray, simulated: xr.DataArray, wet_threshold_percentile: float = 0.8, dry_threshold_percentile: float = 0.2) -> tuple:
    """
    Compute False Alarm Rate (FAR) for wet and dry predictions.

    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).

    Returns:
    tuple: FAR for wet predictions (%), FAR for dry predictions (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_wet_quan = observed.quantile(wet_threshold_percentile, dim='time')
    simulated_wet_quan = simulated.quantile(wet_threshold_percentile, dim='time')

    observed_dry_quan = observed.quantile(dry_threshold_percentile, dim='time')
    simulated_dry_quan = simulated.quantile(dry_threshold_percentile, dim='time')

    # Create masks for "wet" and "dry" periods based on the percentiles
    observed_wet = observed >= observed_wet_quan
    simulated_wet = simulated >= simulated_wet_quan

    observed_dry = observed <= observed_dry_quan
    simulated_dry = simulated <= simulated_dry_quan

    # Calculate hits and false alarms for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    wet_false_alarms = (simulated_wet & ~observed_wet).sum(dim='time')

    # Sum hits and false alarms for "wet" across all spatial dimensions (lat, lon)
    total_wet_hits = wet_hits.sum().values  # Convert to numpy array
    total_wet_false_alarms = wet_false_alarms.sum().values  # Convert to numpy array

    # Calculate False Alarm Rate for wet conditions
    if (total_wet_hits + total_wet_false_alarms) == 0:
        wet_far = 0.0
    else:
        wet_far = (total_wet_false_alarms / (total_wet_false_alarms + total_wet_hits)) * 100  # As a percentage

    # Calculate hits and false alarms for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    dry_false_alarms = (simulated_dry & ~observed_dry).sum(dim='time')

    # Sum hits and false alarms for "dry" across all spatial dimensions (lat, lon)
    total_dry_hits = dry_hits.sum().values  # Convert to numpy array
    total_dry_false_alarms = dry_false_alarms.sum().values  # Convert to numpy array

    # Calculate False Alarm Rate for dry conditions
    if (total_dry_hits + total_dry_false_alarms) == 0:
        dry_far = 0.0
    else:
        dry_far = (total_dry_false_alarms / (total_dry_false_alarms + total_dry_hits)) * 100  # As a percentage

    return float(wet_far), float(dry_far)

far = compute_far(Y_array, randomised_array)
far

  return fnb._ureduce(a,


(92.23914681698632, 92.26686076344734)

In [20]:
import xarray as xr

def compute_csi(observed: xr.DataArray, simulated: xr.DataArray, wet_threshold_percentile: float = 0.8, dry_threshold_percentile: float = 0.2) -> dict:
    """
    Compute the Critical Success Index (CSI) for wet and dry predictions.

    Parameters:
    observed (xr.DataArray): Observed soil moisture data (lat, lon, time).
    simulated (xr.DataArray): Simulated soil moisture data (lat, lon, time).

    Returns:
    tuple: CSI for wet predictions (%), CSI for dry predictions (%).
    """
    
    # Compute the 80th and 20th percentiles for observed and simulated data along the time dimension
    observed_wet_quan = observed.quantile(wet_threshold_percentile, dim='time')
    simulated_wet_quan = simulated.quantile(wet_threshold_percentile, dim='time')

    observed_dry_quan = observed.quantile(dry_threshold_percentile, dim='time')
    simulated_dry_quan = simulated.quantile(dry_threshold_percentile, dim='time')

    # Create masks for "wet" and "dry" periods based on the percentiles
    observed_wet = observed >= observed_wet_quan
    simulated_wet = simulated >= simulated_wet_quan

    observed_dry = observed <= observed_dry_quan
    simulated_dry = simulated <= simulated_dry_quan

    # Calculate hits, false alarms, and misses for "wet" periods
    wet_hits = (observed_wet & simulated_wet).sum(dim='time')
    wet_false_alarms = (simulated_wet & ~observed_wet).sum(dim='time')
    wet_misses = (~simulated_wet & observed_wet).sum(dim='time')

    # Sum hits, false alarms, and misses for "wet" across all spatial dimensions (lat, lon)
    total_wet_hits = wet_hits.sum().values  # Convert to numpy array
    total_wet_false_alarms = wet_false_alarms.sum().values  # Convert to numpy array
    total_wet_misses = wet_misses.sum().values  # Convert to numpy array

    # Calculate Critical Success Index for wet conditions
    csi_wet = (total_wet_hits / (total_wet_hits + total_wet_false_alarms + total_wet_misses)) * 100 if (total_wet_hits + total_wet_false_alarms + total_wet_misses) > 0 else 0.0

    # Calculate hits, false alarms, and misses for "dry" periods
    dry_hits = (observed_dry & simulated_dry).sum(dim='time')
    dry_false_alarms = (simulated_dry & ~observed_dry).sum(dim='time')
    dry_misses = (~simulated_dry & observed_dry).sum(dim='time')

    # Sum hits, false alarms, and misses for "dry" across all spatial dimensions (lat, lon)
    total_dry_hits = dry_hits.sum().values  # Convert to numpy array
    total_dry_false_alarms = dry_false_alarms.sum().values  # Convert to numpy array
    total_dry_misses = dry_misses.sum().values  # Convert to numpy array

    # Calculate Critical Success Index for dry conditions
    csi_dry = (total_dry_hits / (total_dry_hits + total_dry_false_alarms + total_dry_misses)) * 100 if (total_dry_hits + total_dry_false_alarms + total_dry_misses) > 0 else 0.0

    csi = {
    f'wet_threshold_{wet_threshold_percentile}_csi': float(csi_wet),
    f'dry_threshold_{dry_threshold_percentile}_csi': float(csi_dry)
    }

    print(csi)
    
    return csi

csi = compute_csi(Y_array, randomised_array)
csi

  return fnb._ureduce(a,


{'wet_threshold_0.8_csi': 5.923438122105611, 'dry_threshold_0.2_csi': 5.908549007857572}


{'wet_threshold_0.8_csi': 5.923438122105611,
 'dry_threshold_0.2_csi': 5.908549007857572}