# 0.1.5.1: Build sPlot functional diversity maps

## Imports and config

In [2]:
from pathlib import Path
from typing import Any

import dask.dataframe as dd
import numpy as np
import pandas as pd
import xarray as xr
from box import ConfigBox

from src.conf.conf import get_config
from src.conf.environment import detect_system, log
from src.utils.dask_utils import close_dask, init_dask
from src.utils.df_utils import rasterize_points, reproject_geo_to_xy
from src.utils.raster_utils import xr_to_raster
from src.utils.trait_utils import clean_species_name, filter_pft

# Project configuration (data directories, dataset/trait settings, per-system
# resource settings) that every cell below reads from.
cfg = get_config()

## Load the data

In [None]:
log.info("Starting sPlot map generation...")
# Resource settings for this notebook on the machine we are running on
sys_cfg = cfg[detect_system()][cfg.model_res]["build_splot_maps"]

# Setup ################
log.info("=== STAGE 1: Initial Setup ===")
splot_dir = Path(cfg.interim_dir, cfg.splot.interim.dir, cfg.splot.interim.extracted)
log.info("Using sPlot data from: %s", splot_dir)

# Functional-diversity metrics supported by this notebook. The FD workflow is
# used when at least one configured trait stat is one of them.
trait_stats = cfg.datasets.Y.trait_stats
fd_metrics = ["f_ric", "f_eve", "f_div", "f_red", "sp_ric", "f_ric_ses"]
use_fd_approach = any(metric in fd_metrics for metric in trait_stats)


def _repartition_if_set(df: dd.DataFrame, npartitions: int | None) -> dd.DataFrame:
    """Repartition ``df`` to ``npartitions``; return it untouched when that is None."""
    if npartitions is None:
        return df
    return df.repartition(npartitions=npartitions)


# Forward only the Dask settings that are explicitly configured (skip None values)
dask_kws = dict((key, val) for key, val in sys_cfg.dask.items() if val is not None)
log.info("=== STAGE 3: Dask Initialization ===")
log.info("Initializing Dask client with parameters: %s", dask_kws)
client, _ = init_dask(dashboard_address=cfg.dask_dashboard, **dask_kws)
# /Setup ################

[94m2025-04-22 15:11:15 CEST - src.conf.environment - INFO - Starting sPlot map generation...[0m
[94m2025-04-22 15:11:15 CEST - src.conf.environment - INFO - Detected system: nemo2[0m
[94m2025-04-22 15:11:15 CEST - src.conf.environment - INFO - === STAGE 1: Initial Setup ===[0m
[94m2025-04-22 15:11:15 CEST - src.conf.environment - INFO - Using sPlot data from: data/interim/splot/extracted[0m
[94m2025-04-22 15:11:15 CEST - src.conf.environment - INFO - === STAGE 3: Dask Initialization ===[0m
[94m2025-04-22 15:11:15 CEST - src.conf.environment - INFO - Initializing Dask client with parameters: {'n_workers': 40, 'threads_per_worker': 5, 'memory_limit': '40GB'}[0m
[94m2025-04-22 15:11:18 CEST - src.conf.environment - INFO - Dask dashboard URL: http://127.0.0.1:39143/status[0m


In [4]:
# Later cells (trait loading and the plot-level FD stats) only implement the
# functional-diversity workflow and rely on `pca_cols` / `needed_columns`, so
# fail early with a clear message instead of a NameError further down.
if not use_fd_approach:
    raise ValueError(
        "None of the configured trait_stats are functional diversity metrics "
        f"({', '.join(fd_metrics)}); this notebook only builds FD maps."
    )

# Discover all PCA component columns of the filtered TRY traits; only the
# column metadata is needed here, so nothing is computed.
log.info("Discovering PCA component columns for functional diversity metrics...")
trait_file_path = (
    Path(cfg.interim_dir, cfg.trydb.interim.dir) / cfg.trydb.interim.filtered
)
trait_meta = dd.read_parquet(trait_file_path, calculate_divisions=False)
pca_cols = [col for col in trait_meta.columns if col.startswith("PC")]
if not pca_cols:
    raise ValueError(f"No PCA columns (prefix 'PC') found in {trait_file_path}")

log.info("Found %d PCA columns: %s", len(pca_cols), ", ".join(pca_cols))

needed_columns = ["speciesname"] + pca_cols


[94m2025-04-22 15:11:18 CEST - src.conf.environment - INFO - Discovering PCA component columns for functional diversity metrics...[0m
[94m2025-04-22 15:11:18 CEST - src.conf.environment - INFO - Found 4 PCA columns: PC1, PC2, PC3, PC4[0m


Load the sPlot header data. This contains plot information and coordinates.

In [65]:
# Inspect which columns the sPlot header table offers
dd.read_parquet(Path(splot_dir, "header.parquet")).columns


Index(['PlotObservationID', 'Dataset', 'GIVD_NU', 'RESURVEY', 'RS_PLOT',
       'RS_OBSERV', 'Nested_in', 'Longitude', 'Latitude',
       'Location_uncertainty', 'Location_origin', 'Locality', 'Country',
       'Subregion', 'Continent', 'Date', 'Releve_area', 'Cover_scale',
       'Plants_recorded', 'Lichens_identified', 'Mosses_identified', 'FL_full',
       'FL_name', 'FL_code', 'FL_first', 'FL_second', 'FL_third', 'EUNIS',
       'EUNIS_old', 'EUNIS_coal', 'EUNIS_first', 'EUNIS_second', 'EUNIS_third',
       'Naturalness', 'Grassland', 'Shrubland', 'Forest', 'Wetland',
       'Sparse_vegetation', 'Cover_total', 'Cover_cryptogams', 'Cover_forbs',
       'Cover_bare_soil', 'Cover_bare_rock', 'Cover_open_water',
       'Cover_layer_litter', 'Cover_layer_algae', 'Cover_layer_lichen',
       'Cover_layer_moss', 'Cover_layer_herb', 'Cover_layer_shrub',
       'Cover_layer_tree', 'Max_height_cryptogams_mm',
       'Avg_height_low_herbs_cm', 'Avg_height_high_herbs_cm',
       'Max_height_he

In [None]:
def _filter_certain_plots(df: pd.DataFrame, givd_nu: str) -> pd.DataFrame:
    """Filter out certain plots."""
    return df[df["GIVD_NU"] != givd_nu]


log.info("Loading sPlot header data...")
# Plot IDs, coordinates and the GIVD database code used for filtering
header = dd.read_parquet(
    Path(splot_dir, "header.parquet"),
    columns=["PlotObservationID", "Longitude", "Latitude", "GIVD_NU"],
).astype({"PlotObservationID": "uint32[pyarrow]", "GIVD_NU": "category"})
header = _repartition_if_set(header, sys_cfg.npartitions)
# NOTE(review): reason for excluding GIVD "00-RU-008" is not documented here
header = _filter_certain_plots(header, "00-RU-008").drop(columns=["GIVD_NU"])
# Index by plot and project lon/lat into x/y of the target CRS
header = (
    header.astype({"Longitude": np.float64, "Latitude": np.float64})
    .set_index("PlotObservationID")
    .map_partitions(reproject_geo_to_xy, to_crs=cfg.crs, x="Longitude", y="Latitude")
    .drop(columns=["Longitude", "Latitude"])
)
log.info("Header data loaded and processed")

[94m2025-04-22 15:11:18 CEST - src.conf.environment - INFO - Loading sPlot header data...[0m


[94m2025-04-22 15:11:18 CEST - src.conf.environment - INFO - Header data loaded and processed[0m


In [6]:
# Preview the first rows of the projected header table
header.head(n=5)

Unnamed: 0_level_0,x,y
PlotObservationID,Unnamed: 1_level_1,Unnamed: 2_level_1
1,-14433690.0,6833109.0
2,-14433580.0,6833128.0
3,-14433890.0,6833086.0
4,-14434150.0,6833099.0
5,-14433930.0,6833053.0


Load the trait data

In [7]:
log.info("Loading trait data for columns: %s", ", ".join(needed_columns))

# Read the pre-cleaned, filtered TRY traits (species name + PCA columns),
# optionally repartition, and index by species for the joins further down
traits = dd.read_parquet(
    Path(cfg.interim_dir, cfg.trydb.interim.dir, cfg.trydb.interim.filtered),
    columns=needed_columns,
)
traits = _repartition_if_set(traits, sys_cfg.npartitions).set_index("speciesname")
log.info("Trait data loaded and indexed")

[94m2025-04-22 15:11:26 CEST - src.conf.environment - INFO - Loading trait data for columns: speciesname, PC1, PC2, PC3, PC4[0m
[94m2025-04-22 15:11:26 CEST - src.conf.environment - INFO - Trait data loaded and indexed[0m


In [8]:
# Preview the species-indexed PCA trait table
traits.head(n=5)

Unnamed: 0_level_0,PC1,PC2,PC3,PC4
speciesname,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
abarema adenophora,5.840701,0.398256,0.723679,4.098356
abarema adenophorum,2.897819,0.518998,-0.611865,2.608451
abarema alexandri,2.395,-0.759245,-1.042815,2.819884
abarema barbouriana,1.6963,-1.297582,-1.096845,2.978906
abarema cochleata,2.251713,-0.984658,-1.032982,2.98829


Load PFT data

In [None]:
# Load PFT data, filter by desired PFT, clean species names, and set them as index
# for joining
log.info("Loading and processing PFT data...")
pft_path = Path(cfg.raw_dir, cfg.trydb.raw.pfts)
if pft_path.suffix == ".csv":
    pfts = dd.read_csv(Path(cfg.raw_dir, cfg.trydb.raw.pfts), encoding="latin-1")
elif pft_path.suffix == ".parquet":
    pfts = dd.read_parquet(Path(cfg.raw_dir, cfg.trydb.raw.pfts))
else:
    raise ValueError(f"Unsupported PFT file format: {pft_path.suffix}")

pfts = (
    pfts.astype(
        {
            "AccSpeciesName": "string[pyarrow]",
            "pft": "category",
        }
    )
    .pipe(_repartition_if_set, sys_cfg.npartitions)
    .pipe(filter_pft, cfg.PFT)
    .drop(columns=["AccSpeciesID"])
    .dropna(subset=["AccSpeciesName"])
    .pipe(clean_species_name, "AccSpeciesName", "speciesname")
    .drop(columns=["AccSpeciesName"])
    .drop_duplicates(subset=["speciesname"])
    .set_index("speciesname")
)
log.info("PFT data loaded and processed")

[94m2025-04-22 15:11:28 CEST - src.conf.environment - INFO - Loading and processing PFT data...[0m
[94m2025-04-22 15:11:28 CEST - src.conf.environment - INFO - PFT data loaded and processed[0m


In [10]:
# Preview the species-indexed PFT table
pfts.head(n=5)



Unnamed: 0_level_0,pft
speciesname,Unnamed: 1_level_1
aa argyrolepis,Grass
aa calceata,Shrub
aa colombiana,Grass
aa fiebrigii,Shrub
aa hartwegii,Grass


Load the vegetation data and merge it with the PFT and trait data. The result is computed into an in-memory pandas DataFrame (`.compute()`) because it is used in several later calculations.

In [None]:
log.info("Loading and processing sPlot vegetation data...")
# Species records per plot with their relative abundance
merged = dd.read_parquet(
    Path(splot_dir, "vegetation.parquet"),
    columns=["PlotObservationID", "Species", "Rel_Abund_Plot"],
).astype(
    {
        "PlotObservationID": "uint32[pyarrow]",
        "Species": "string[pyarrow]",
        "Rel_Abund_Plot": "float64[pyarrow]",
    }
)
merged = _repartition_if_set(merged, sys_cfg.npartitions).dropna(subset=["Species"])
# Standardise species names so they line up with the PFT and trait indices
merged = (
    clean_species_name(merged, "Species", "speciesname")
    .drop(columns=["Species"])
    .set_index("speciesname")
)
# Inner joins keep only species with the desired PFT and available traits;
# the result is materialised as an in-memory pandas DataFrame
merged = (
    merged.join(pfts, how="inner")
    .join(traits, how="inner")
    .reset_index()
    .drop(columns=["pft", "speciesname"])
    .compute()
)
log.info("Data merging complete")

[94m2025-04-22 15:11:35 CEST - src.conf.environment - INFO - Loading and processing sPlot vegetation data...[0m
[94m2025-04-22 15:14:25 CEST - src.conf.environment - INFO - Data merging complete[0m


In [12]:
# Size of the merged table, followed by a preview of its first rows
print("Merged DF shape: ", merged.shape, sep=" ")
merged.head(n=5)

Merged DF shape:  (40316038, 6)


Unnamed: 0,PlotObservationID,Rel_Abund_Plot,PC1,PC2,PC3,PC4
0,2552892,0.035714,5.840701,0.398256,0.723679,4.098356
1,2555092,0.047619,5.840701,0.398256,0.723679,4.098356
2,2555108,0.018868,5.840701,0.398256,0.723679,4.098356
3,2552804,0.071429,2.101126,-1.562004,-1.777728,1.707156
4,2553185,0.058824,2.101126,-1.562004,-1.777728,1.707156


Group by vegetation plot and remove plots whose TRY-matched observations cover less than 80% of the relative abundance, or that have fewer than 3 matched observations.

In [74]:
# Keep plots whose TRY-matched observations cover at least this share of the
# plot's relative abundance AND that have at least this many observations
MIN_ABUND_COVERAGE = 0.8
MIN_OBS_PER_PLOT = 3

# Vectorised equivalent of groupby(...).filter(lambda ...): broadcasting the
# per-plot aggregates back to the rows avoids one Python call per plot.
# "size" counts all rows, matching len(group) in the original filter.
plot_abund = merged.groupby("PlotObservationID")["Rel_Abund_Plot"]
keep_rows = (plot_abund.transform("sum") >= MIN_ABUND_COVERAGE) & (
    plot_abund.transform("size") >= MIN_OBS_PER_PLOT
)
# Boolean masking preserves row order and index, as groupby.filter does
df_by_plots = merged.loc[keep_rows.to_numpy(dtype=bool)]

In [None]:
# Count distinct plots once each; every nunique() scans tens of millions of rows
n_plots_before = merged["PlotObservationID"].nunique()
n_plots_after = df_by_plots["PlotObservationID"].nunique()

print("Original DF shape:", merged.shape)
print("Filtered DF shape:", df_by_plots.shape)
print("Number of plots removed:", n_plots_before - n_plots_after)
print(
    "Percentage of plots kept:",
    np.round(n_plots_after / n_plots_before * 100, 2),
    "%",
)


Original DF shape: (40316038, 6)
Filtered DF shape: (22701202, 6)
Number of plots removed: 1251767
Percentage of plots kept: 50.05 %


## Define functional diversity equations

In [68]:
def _fd_stats(
    g: pd.DataFrame,
    pca_cols: list,
    stats: list[str],
    include_ses: bool = False,
    random_seed: int | None = None,
) -> pd.Series:
    """Calculate all functional diversity stats per plot.

    Args:
        g: DataFrame group containing species observations for a single plot
        pca_cols: List of column names for PCA components
        include_ses: Whether to include standardized effect size for functional richness
        random_seed: Random seed for SES calculations

    Returns:
        Series with functional diversity metrics
    """
    # Check if we have enough data to calculate metrics
    if g.empty or len(g) < 2:
        result = pd.Series(
            [np.nan] * len(stats),
            index=stats,
        )
        return result

    # Extract trait matrix and normalize abundances
    trait_matrix = g[pca_cols].values

    # Convert abundances to numpy array before normalization to avoid ArrowExtensionArray issues
    abundances_np = g["Rel_Abund_Plot"].to_numpy()
    abundances_sum = abundances_np.sum()
    normalized_abund = (
        abundances_np / abundances_sum
        if abundances_sum > 0
        else np.ones_like(abundances_np) / len(abundances_np)
    )

    calculated_stats = {s: 0.0 for s in stats}
    if "sp_ric" in stats:
        # Calculate species richness (number of species in the plot)
        sp_richness = len(g)
        calculated_stats["sp_ric"] = sp_richness
    if "f_ric" in stats:
        # Calculate functional richness
        f_ric = _calculate_fric(trait_matrix, plot_id=g.name, S=len(g))
        calculated_stats["f_ric"] = f_ric
    if "f_eve" in stats:
        # Calculate functional evenness
        f_eve = calculate_functional_evenness(trait_matrix, normalized_abund)
        calculated_stats["f_eve"] = f_eve
    if "f_div" in stats:
        # Calculate functional divergence
        f_div = calculate_mean_pairwise_dissimilarity(trait_matrix, normalized_abund)
        calculated_stats["f_div"] = f_div
    if "f_red" in stats:
        # Calculate functional redundancy (1 - functional divergence)
        f_red = 1 - f_div if not np.isnan(f_div) else np.nan
        calculated_stats["f_red"] = f_red

    result = pd.Series(
        calculated_stats,
        index=stats,
    )

    # # Calculate standardized effect size for functional richness if requested
    # if include_ses and not np.isnan(f_ric) and len(g) > len(pca_cols):
    #     f_ric_ses = _calculate_ses_fric(trait_matrix, random_seed=random_seed)
    #     result["cf_ric_ses"] = f_ric_ses

    return result


def _calculate_fric(
    trait_matrix: np.ndarray,
    plot_id: int | str,
    S: int,
    n_permutations: int = 999,
    random_seed: int | None = None,
) -> float:
    """
    Calculate functional richness (FRic) for a given plot.

    Args:
        trait_matrix: Array of trait values for species in the plot
        plot_id: ID of the plot
        S: Number of species in the plot
        n_permutations: Number of permutations for null distribution
    """

    # Extract trait data for species in this plot
    n_dims = trait_matrix.shape[1]

    # Skip if not enough species for convex hull
    if S <= n_dims:
        log.warning("Not enough species for convex hull for plot %s", plot_id)
        return np.nan

    # Calculate observed functional richness
    try:
        from scipy.spatial import ConvexHull

        hull = ConvexHull(trait_matrix)
        observed_fric = hull.volume
    except Exception:
        # Cannot compute hull
        log.warning("Cannot compute hull for plot %s", plot_id)
        return np.nan

    # Generate null distributions by randomizing trait values
    null_fric_values = []
    rng = np.random.RandomState(random_seed)

    for _ in range(n_permutations):
        # Randomly shuffle each trait column independently
        null_matrix = np.zeros_like(trait_matrix)
        for j in range(n_dims):
            null_matrix[:, j] = rng.choice(trait_matrix[:, j], size=S, replace=False)

        try:
            null_hull = ConvexHull(null_matrix)
            null_fric = null_hull.volume
            null_fric_values.append(null_fric)
        except Exception:
            null_fric_values.append(np.nan)

    # Calculate standardized effect size
    null_fric_values = np.array([v for v in null_fric_values if not np.isnan(v)])

    if len(null_fric_values) > 0:
        null_mean = np.mean(null_fric_values)
        null_std = np.std(null_fric_values)
        if null_std > 0:
            ses = (observed_fric - null_mean) / null_std
        else:
            log.warning(
                "Standard deviation of null FRic values is zero for plot %s", plot_id
            )
            ses = np.nan
    else:
        log.warning("No null FRic values for plot %s", plot_id)
        ses = np.nan

    return ses


def calculate_functional_evenness(
    trait_matrix: np.ndarray, abundances: np.ndarray | None = None
) -> float:
    from scipy.spatial.distance import pdist, squareform
    from scipy.sparse.csgraph import minimum_spanning_tree
    import numpy as np

    # Calculate pairwise distances between species in trait space
    dist_matrix = squareform(pdist(trait_matrix, metric="euclidean"))

    # Get minimum spanning tree
    mst = minimum_spanning_tree(dist_matrix).toarray()

    # Get branch lengths from MST
    branch_lengths = []
    for i in range(len(trait_matrix)):
        for j in range(i + 1, len(trait_matrix)):
            if mst[i, j] > 0:
                # Weight branch by species abundances if provided
                weight = 1.0
                if abundances is not None:
                    weight = abundances[i] * abundances[j]
                branch_lengths.append((mst[i, j], weight))

    # Sort branch lengths
    branch_lengths.sort()

    # Number of species
    S = len(trait_matrix)

    # Sum of minimum weighted branch lengths
    weighted_distances = [min(bl[0], 1 / (S - 1)) * bl[1] for bl in branch_lengths]

    # Calculate evenness as deviation from regular spacing
    PEW = np.sum(weighted_distances) / (S - 1)
    FEve = (PEW - 1 / (S - 1)) / (1 - 1 / (S - 1))

    return FEve


def calculate_mean_pairwise_dissimilarity(
    trait_matrix: np.ndarray, abundances: np.ndarray, metric: Any = "euclidean"
) -> float:
    """
    Abundance-weighted Mean Pairwise Dissimilarity (MPD).

    Computed as Rao's quadratic entropy divided by the Simpson index.

    Args:
        trait_matrix: Species-by-trait matrix
        abundances: Abundance weight of each species
        metric: Distance metric passed on to Rao's entropy (default: "euclidean")

    Returns:
        MPD, or NaN if the Simpson index is not positive
    """
    rao = calculate_rao_quadratic_entropy(trait_matrix, abundances, metric)
    simpson = calculate_simpson_diversity(abundances)

    # Guard clause: a non-positive (or NaN) Simpson index would break the ratio
    if not simpson > 0:
        log.warning("Simpson index is zero, returning NaN")
        return np.nan
    return rao / simpson


def calculate_simpson_diversity(abundances: np.ndarray) -> float:
    """
    Calculate the Gini-Simpson diversity index.

    Args:
        abundances: Species abundances (array-like). They are normalised to
            relative abundances unless they already sum to ~1.

    Returns:
        Simpson diversity, 1 - sum(p_i^2), as a Python float
    """
    # Accept lists and any numeric dtype (float32 etc.)
    abundances = np.asarray(abundances, dtype=float)
    total = abundances.sum()
    rel_abundances = abundances if np.isclose(total, 1.0) else abundances / total

    # float() instead of an isinstance assert: a numpy float32 result is not a
    # `float` subclass, and asserts vanish under `python -O`
    return float(1.0 - np.sum(rel_abundances**2))


def calculate_rao_quadratic_entropy(
    trait_matrix: np.ndarray, abundances: np.ndarray, metric: Any = "euclidean"
) -> float:
    """
    Calculate Rao's quadratic entropy.

    Q = sum_i sum_j p_i * p_j * d_ij, i.e. the expected trait distance between
    two individuals drawn at random (with replacement) from the community.
    Both orderings (i, j) and (j, i) are part of the definition. This full sum
    is what makes MPD = Q / Simpson hold, since sum_{i != j} p_i p_j = 1 - sum p_i^2.

    Args:
        trait_matrix: Matrix of species trait values (species x traits)
        abundances: Abundance weights for each species (normalised internally)
        metric: Distance metric passed to scipy's pdist (default: "euclidean")

    Returns:
        Rao's quadratic entropy, or NaN for fewer than two species
    """
    if len(trait_matrix) < 2:
        return np.nan

    from scipy.spatial.distance import pdist, squareform

    abundances = np.asarray(abundances, dtype=float)
    total = abundances.sum()
    rel_abundances = abundances if np.isclose(total, 1.0) else abundances / total

    dist_matrix = squareform(pdist(trait_matrix, metric=metric))

    # Vectorised double sum p^T D p
    return float(rel_abundances @ dist_matrix @ rel_abundances)

## Calculate plot-wise FD metrics

In [78]:
# How many plots survived the abundance/observation filter
print("Number of plots in df_by_plots:", df_by_plots["PlotObservationID"].nunique())

Number of plots in df_by_plots: 1254124


In [77]:
# Reproducible random subset of plots for a quick trial of the FD stats
N_SAMPLE_PLOTS = 10_000
SAMPLE_SEED = 42

# Order of first appearance, as Series.unique(); to_numpy works for both numpy
# and Arrow-backed dtypes
plot_ids = df_by_plots["PlotObservationID"].drop_duplicates().to_numpy()
# A local RandomState(42) draws the same sample as np.random.seed(42) followed
# by np.random.choice, without mutating NumPy's global random state
sample_plots = np.random.RandomState(SAMPLE_SEED).choice(
    plot_ids, size=min(N_SAMPLE_PLOTS, len(plot_ids)), replace=False
)

In [79]:
# Restrict to the sampled plots and group the rows per plot
df_by_plots_sample = df_by_plots[
    df_by_plots["PlotObservationID"].isin(sample_plots)
].groupby("PlotObservationID")

In [80]:
# Override the configured trait_stats with the metrics to compute in this run
trait_stats = [
    "f_ric",  # functional richness
    "f_eve",  # functional evenness
    "f_div",  # functional divergence
    "f_red",  # functional redundancy
    "sp_ric",  # species richness
]

In [81]:
# Compute the FD metrics per sampled plot. Passing a seed makes the
# permutation-based FRic SES reproducible across runs.
fd_df = df_by_plots_sample.apply(
    _fd_stats,
    pca_cols=pca_cols,
    stats=trait_stats,
    random_seed=42,
    include_groups=False,
)

fd_df



TypeError: remove: path should be string, bytes or os.PathLike, not NoneType

: 