# Model experiments

In [3]:
import os
import multiprocessing

os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count={}".format(
    multiprocessing.cpu_count()
)

In [4]:
import pandas as pd
import pymc as pm
import numpy as np
import polars as pl
import arviz as az
import xarray as xr
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Union
import pytensor.tensor as pt

In [10]:
# Load black for formatting
import jupyter_black

jupyter_black.load()

In [11]:
import jax
from jax import devices

print(jax.devices())
print(jax.local_device_count())
# gpu_device = jax.devices("METAL")[0]
# cpu_device = jax.devices("cpu")[0]
# print(gpu_device)
# print(cpu_device)

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]
10


In [12]:
import warnings

warnings.filterwarnings("ignore", category=FutureWarning, module="arviz")

In [13]:
plt.rcParams["grid.color"] = "grey"
plt.rcParams["grid.linestyle"] = "-"
plt.rcParams["grid.linewidth"] = 0.5
plt.rcParams["axes.grid"] = True
plt.rcParams["axes.axisbelow"] = True

## Load training data

In [5]:
# Load model data
df = pl.read_parquet("../../data/runs/run_folder_2024-04-16_18.14/model_data.parquet")

In [15]:
df.head()

SS,SSB,SSBS,Biome,Realm,Ecoregion,Biome_Realm,Biome_Realm_Ecoregion,Max_scaled_abundance,Primary vegetation_Light use,Primary vegetation_Intense use,Secondary vegetation_Minimal use,Secondary vegetation_Light use,Secondary vegetation_Intense use,Cropland_Minimal use,Cropland_Light_Intense,Pasture_Minimal use,Pasture_Light_Intense,Urban_All uses,Pop_density_1km_log,Road_density_50km_cbrt,Mean_pop_density_1km_log,Annual_mean_temp_1km,Temp_seasonality_1km,Max_temp_warmest_month_1km,Min_temp_coldest_month_1km,Annual_precip_1km,Precip_wettest_month_1km,Precip_driest_month_1km,Precip_seasonality_1km,Longitude,Latitude,Primary vegetation_Light use x Pop_density_1km_log,Primary vegetation_Intense use x Pop_density_1km_log,Secondary vegetation_Minimal use x Pop_density_1km_log,Secondary vegetation_Light use x Pop_density_1km_log,Secondary vegetation_Intense use x Pop_density_1km_log,Cropland_Minimal use x Pop_density_1km_log,Cropland_Light_Intense x Pop_density_1km_log,Pasture_Minimal use x Pop_density_1km_log,Pasture_Light_Intense x Pop_density_1km_log,Urban_All uses x Pop_density_1km_log,Primary vegetation_Light use x Road_density_50km_cbrt,Primary vegetation_Intense use x Road_density_50km_cbrt,Secondary vegetation_Minimal use x Road_density_50km_cbrt,Secondary vegetation_Light use x Road_density_50km_cbrt,Secondary vegetation_Intense use x Road_density_50km_cbrt,Cropland_Minimal use x Road_density_50km_cbrt,Cropland_Light_Intense x Road_density_50km_cbrt,Pasture_Minimal use x Road_density_50km_cbrt,Pasture_Light_Intense x Road_density_50km_cbrt,Urban_All uses x Road_density_50km_cbrt
str,str,str,str,str,str,str,str,f64,u8,u8,u8,u8,u8,u8,i32,u8,i32,i32,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64
"""AD1_2009__Verg…","""AD1_2009__Verg…","""AD1_2009__Verg…","""Tropical & Sub…","""Neotropic""","""Veracruz Moist…","""Tropical & Sub…","""Tropical & Sub…",0.424419,0,0,0,1,0,0,0,0,0,0,0.525598,-0.256285,0.975725,-1.084831,0.861193,-0.351666,-1.276192,-0.739761,-0.330739,-0.372573,0.528716,-1.342541,1.246605,0.0,0.0,0.0,0.525598,0.0,0.0,0.0,0.0,0.0,0.0,-0.0,-0.0,-0.0,-0.256285,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0
"""AD1_2013__Gras…","""AD1_2013__Gras…","""AD1_2013__Gras…","""Tropical & Sub…","""Afrotropic""","""Kwazulu-Cape C…","""Tropical & Sub…","""Tropical & Sub…",0.217391,1,0,0,0,0,0,0,0,0,0,-1.027548,0.301278,-0.973692,-1.028981,1.09135,-0.933438,-1.233357,-1.339244,-0.932684,-0.53846,-0.420373,0.293829,-2.805349,-1.027548,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.301278,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""AD1_2013__Gras…","""AD1_2013__Gras…","""AD1_2013__Gras…","""Tropical & Sub…","""Afrotropic""","""Kwazulu-Cape C…","""Tropical & Sub…","""Tropical & Sub…",0.206634,0,0,0,0,0,0,0,0,0,0,-0.842682,0.166074,-0.973692,-1.194542,0.961351,-1.14395,-1.263549,-1.328775,-0.92383,-0.508164,-0.380644,0.293347,-2.812924,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
"""AR1_2008__John…","""AR1_2008__John…","""AR1_2008__John…","""Tropical & Sub…","""Neotropic""","""Tumbes-Piura D…","""Tropical & Sub…","""Tropical & Sub…",0.037826,0,0,0,0,0,0,0,0,0,1,1.616002,-0.615943,-0.001067,-0.035174,1.161973,0.759621,-0.527562,-2.403766,-1.316992,-0.968259,1.614289,-1.133619,-0.737392,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.616002,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.615943
"""AR1_2008__Nava…","""AR1_2008__Nava…","""AR1_2008__Nava…","""Tropical & Sub…","""Neotropic""","""Petén-Veracruz…","""Tropical & Sub…","""Tropical & Sub…",0.21582,0,0,0,0,0,0,0,0,0,0,-1.854169,-1.23998,-1.632876,0.758963,0.35421,1.219464,0.316739,0.83193,0.334331,0.315482,0.095061,-1.265546,1.001619,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0,-0.0


## Generic functions

In [16]:
def transform_response_variable(df: pl.DataFrame, method: str) -> pl.DataFrame:

    adjust = 0.001
    original_col_name = "Max_scaled_abundance"
    transformed_col_name = original_col_name

    # Small adjustment to align with support for Beta distribution
    if method == "adjust" or method == "logit":
        df = df.with_columns(
            pl.when(pl.col(original_col_name) == 0)
            .then(adjust)
            .when(pl.col(original_col_name) == 1)
            .then(1 - adjust)
            .otherwise(pl.col(original_col_name))
            .alias(original_col_name)
        )
        if method == "logit":
            transformed_col_name += "_logit"
            df = df.with_columns(
                pl.col(original_col_name)
                .map_elements(lambda x: logit(x))
                .alias(transformed_col_name)
            )

    # Square root transformation
    elif method == "sqrt":
        transformed_col_name += "_sqrt"
        df = df.with_columns(
            pl.col(original_col_name).sqrt().alias(transformed_col_name)
        )

    # Replace original column with transformed one, if the name has changed
    if transformed_col_name != original_col_name:
        original_col_index = df.columns.index(original_col_name)
        new_col = df.get_column(transformed_col_name)
        df = df.drop([original_col_name, transformed_col_name])
        df = df.insert_column(index=original_col_index, column=new_col)

    return df

In [6]:
def format_data_for_model(
    df: pl.DataFrame,
    response_var: str,
    categorical_vars: list[str],
    continuous_vars: list[str],
    interaction_terms: list[str],
):

    # Sort dataframe for consistent operations below
    df = df.sort(["SS", "SSB", "SSBS"])

    # Extract studies and blocks as indices
    studies = df.get_column("SS").unique().to_list()
    study_idx = df.get_column("SS").cast(pl.Categorical).to_physical().to_numpy()

    blocks = df.get_column("SSB").unique().to_list()
    block_idx = df.get_column("SSB").cast(pl.Categorical).to_physical().to_numpy()

    # Do the same for biomes, realms and ecoregions
    biomes = df.get_column("Biome").unique().to_list()
    biome_idx = df.get_column("Biome").cast(pl.Categorical).to_physical().to_numpy()
    biome_realm = df.get_column("Biome_Realm").unique().to_list()
    biome_realm_idx = (
        df.get_column("Biome_Realm").cast(pl.Categorical).to_physical().to_numpy()
    )
    biome_realm_eco = df.get_column("Biome_Realm_Ecoregion").unique().to_list()
    biome_realm_eco_idx = (
        df.get_column("Biome_Realm_Ecoregion")
        .cast(pl.Categorical)
        .to_physical()
        .to_numpy()
    )

    # Create an array with block-to-study indices
    block_to_study_idx = (
        df.select(["SS", "SSB"])
        .unique()
        .sort(["SS", "SSB"])
        .get_column("SS")
        .cast(pl.Categorical)
        .to_physical()
        .to_numpy()
    )

    # Similarly for biomes and realms
    realm_to_biome_idx = (
        df.select(["Biome", "Biome_Realm"])
        .unique()
        .sort(["Biome", "Biome_Realm"])
        .get_column("Biome")
        .cast(pl.Categorical)
        .to_physical()
        .to_numpy()
    )

    eco_to_realm_idx = (
        df.select(["Biome_Realm", "Biome_Realm_Ecoregion"])
        .unique()
        .sort(["Biome_Realm", "Biome_Realm_Ecoregion"])
        .get_column("Biome_Realm")
        .cast(pl.Categorical)
        .to_physical()
        .to_numpy()
    )

    # Get a vector of output values
    y = (
        df.select([col for col in df.columns if response_var in col])
        .to_numpy()
        .flatten()
    )

    # List of interaction terms created previously, for selection
    full_interaction_terms = []
    for col_1 in interaction_terms:
        for col_2 in categorical_vars:
            full_interaction_terms.append(f"{col_2} x {col_1}")

    # Create the fixed effects design matrix
    x = df.select(
        categorical_vars + continuous_vars + full_interaction_terms
    ).to_numpy()

    # Create a coordinate vector
    long = df.get_column("Longitude").to_numpy().flatten()
    lat = df.get_column("Latitude").to_numpy().flatten()
    site_coords = np.stack((long, lat), axis=1)

    # Random effects design matrices
    z_study = df.select(categorical_vars + continuous_vars).to_numpy()
    z_block = df.select(categorical_vars + continuous_vars).to_numpy()

    # Covariates that are in x but not in z
    x_z_study_diff = df.select(full_interaction_terms).to_numpy()
    x_z_block_diff = df.select(full_interaction_terms).to_numpy()

    # Convert numpy array to the precision actually needed, and no more,
    # to increase sampling speed
    y = y.astype(np.float32)
    x = x.astype(np.float32)
    study_idx = study_idx.astype(np.uint16)
    block_idx = block_idx.astype(np.uint16)
    block_to_study_idx.astype(np.uint16)
    x_z_study_diff.astype(np.float32)
    x_z_block_diff.astype(np.float32)
    site_coords.astype(np.float32)
    idx = np.arange(len(y)).astype(np.uint32)

    coords = {
        "idx": idx,
        "x_vars": categorical_vars + continuous_vars + full_interaction_terms,
        "x_vars_int": ["Intercept"]
        + categorical_vars
        + continuous_vars
        + full_interaction_terms,
        "z_study_vars": categorical_vars + continuous_vars,
        "z_block_vars": categorical_vars + continuous_vars,
        "x_z_study_diff_vars": full_interaction_terms,
        "x_z_block_diff_vars": full_interaction_terms,
        "biomes": biomes,
        "biome_realm": biome_realm,
        "biome_realm_eco": biome_realm_eco,
        "studies": studies,
        "blocks": blocks,
        "site_coords": ["Longitude", "Latitude"],
    }

    output_dict = {
        "coords": coords,
        "y": y,
        "x": x,
        "z_study": z_study,
        "z_block": z_block,
        "x_z_study_diff": x_z_study_diff,
        "x_z_block_diff": x_z_block_diff,
        "biome_idx": biome_idx,
        "biome_realm_idx": biome_realm_idx,
        "realm_to_biome_idx": realm_to_biome_idx,
        "biome_realm_eco_idx": biome_realm_eco_idx,
        "eco_to_realm_idx": eco_to_realm_idx,
        "study_idx": study_idx,
        "block_idx": block_idx,
        "block_to_study_idx": block_to_study_idx,
        "site_coords": site_coords,
    }

    return output_dict

In [18]:
def plot_prior_distribution(prior_samples, category, variable):

    if category == "prior":
        data = prior_samples.prior
    elif category == "prior_predictive":
        data = prior_samples.prior_predictive
    else:
        data = prior_samples.observed_data

    az.plot_dist(
        data[variable],
        figsize=(6, 3),
        kind="hist",
        color="C1",
        hist_kwargs=dict(alpha=0.6, bins=50),
    )

    plt.title(f"{category}: {variable}", fontsize=12)

    plt.tick_params(axis="x", labelsize=10)
    plt.tick_params(axis="y", labelsize=10)

    max_ticks = 15
    ax = plt.gca()
    ax.xaxis.set_major_locator(plt.MaxNLocator(max_ticks))
    plt.xticks(rotation=45)

    plt.show()

In [19]:
def run_sampling(
    model: pm.Model, sampler_settings: dict[str, Union[str, int, float]]
) -> az.InferenceData:
    with model:
        trace = pm.sample(
            draws=sampler_settings["draws"],
            tune=sampler_settings["tune"],
            cores=sampler_settings["cores"],
            chains=sampler_settings["chains"],
            target_accept=sampler_settings["target_accept"],
            nuts_sampler=sampler_settings["nuts_sampler"],
            # idata_kwargs={"log_likelihood": True},
        )

    return trace

In [20]:
def summarize_sampling_statistics(trace: az.InferenceData) -> None:

    var_names = list(trace.posterior.data_vars)
    idata = az.convert_to_dataset(trace)

    # Divergences
    divergences = np.sum(trace.sample_stats["diverging"].values)
    print(f"There are {divergences} divergences in the sampling chains.")

    # Acceptance rate
    accept_rate = np.mean(trace.sample_stats["acceptance_rate"].values)
    print(f"The mean acceptance rate was {accept_rate:.3f}")

    # R-hat statistics
    for var in var_names:
        try:
            r_hat = az.summary(idata, var_names=var, round_to=2)["r_hat"]
            mean_r_hat = np.mean(r_hat)
            min_r_hat = np.min(r_hat)
            max_r_hat = np.max(r_hat)
            print(
                f"R-hat for {var} are: {mean_r_hat:.3f} (mean) | {min_r_hat:.3f} (min) | {max_r_hat:.3f} (max)"
            )
        except KeyError:
            continue

    # ESS statistics
    for var in var_names:
        try:
            ess = az.summary(idata, var_names=var, round_to=2)["ess_bulk"]
            mean_ess = np.mean(ess)
            min_ess = np.min(ess)
            max_ess = np.max(ess)
            print(
                f"ESS for {var} are: {int(mean_ess)} (mean) | {int(min_ess)} (min) | {int(max_ess)} (max)"
            )
        except KeyError:
            continue

In [21]:
def forest_plot(trace, var_names):

    axes = az.plot_forest(
        data=trace,
        var_names=var_names,
        combined=True,
        hdi_prob=0.95,
    )

    ax = axes[0]
    labels = [item.get_text() for item in ax.get_yticklabels()]
    new_labels = []
    for label in labels:
        new_label = (
            label.replace("[", "")
            .replace("]", "")
            .replace(",", ":")
            .replace("beta", "")
        )
        new_labels.append(new_label)

    for label in ax.get_yticklabels():
        label.set_fontsize(10)
    for label in ax.get_xticklabels():
        label.set_fontsize(10)

    # Set the new labels to the y-axis
    ax.set_yticklabels(new_labels)

    plt.tight_layout()
    plt.show()

In [22]:
def bayesian_r2(pred_cond, pred_cond_oos, pred_marg, trace):
    cond_r2_values = []
    cond_oos_r2_values = []
    marg_r2_values = []

    sigmas = trace.posterior["eps"].mean(dim=["chain"]).values

    # Calculate one R^2 score for each draw
    for s in range(sigmas.shape[0]):
        cond = pred_cond[s, :]
        cond_oos = pred_cond_oos[s, :]
        marg = pred_marg[s, :]
        sigma = sigmas[s]

        # Compute variance of fitted values
        var_fit_cond = np.var(cond)
        var_fit_cond_oos = np.var(cond_oos)
        var_fit_marg = np.var(marg)

        # Residual variance
        # Needs to be adapted for more complex models
        var_res = sigma**2

        # Compute Bayesian R^2
        cond_r2 = var_fit_cond / (var_fit_cond + var_res)
        cond_oos_r2 = var_fit_cond_oos / (var_fit_cond + var_res)
        marg_r2 = var_fit_marg / (var_fit_cond + var_res)

        cond_r2_values.append(cond_r2)
        cond_oos_r2_values.append(cond_oos_r2)
        marg_r2_values.append(marg_r2)

    return cond_r2_values, cond_oos_r2_values, marg_r2_values

In [23]:
def create_stratification_column(
    df: pd.DataFrame, stratify_groups: list[str]
) -> pd.DataFrame:
    """
    Create a new column for stratification by concatenating the
    specified group columns.
    """

    if len(stratify_groups) > 1:
        df["Stratify_group"] = df[stratify_groups].astype(str).agg("_".join, axis=1)
    else:
        df["Stratify_group"] = df[stratify_groups[0]]

    return df

In [24]:
def generate_kfolds(
    df: pd.DataFrame,
    y_var: str,
    x_vars: list[str],
    groups: list[str],
    k: int = 5,
    stratify: bool = False,
) -> tuple[list[pd.DataFrame], list[pd.DataFrame], list[pd.Series], list[pd.Series]]:

    # Lists for storing the train and test datasets
    x_train_list = []
    x_test_list = []
    y_train_list = []
    y_test_list = []
    group_train_list = []
    group_test_list = []

    # Set up stratified k-fold sampler object and sample using the
    # stratify code (as the "y class label") for stratification
    if stratify:
        kfold = StratifiedKFold(n_splits=k, shuffle=True)
        strat_col = df["Stratify_group"]
    else:
        kfold = KFold(n_splits=k, shuffle=True)
        strat_col = None

    for train_index, test_index in kfold.split(X=df, y=strat_col):
        x_train, x_test = df.iloc[train_index][x_vars], df.iloc[test_index][x_vars]
        y_train, y_test = df.iloc[train_index][y_var], df.iloc[test_index][y_var]
        group_train, group_test = (
            df.iloc[train_index][groups],
            df.iloc[test_index][groups],
        )

        # Store the data for this fold
        x_train_list.append(x_train)
        x_test_list.append(x_test)
        y_train_list.append(y_train)
        y_test_list.append(y_test)
        group_train_list.append(group_train)
        group_test_list.append(group_test)

    return (
        x_train_list,
        x_test_list,
        y_train_list,
        y_test_list,
        group_train_list,
        group_test_list,
    )

In [25]:
def calculate_smape(y_true: np.array, y_pred: np.array, divisor: int = 1) -> float:
    # Check that the divisor argument is valid
    assert divisor in [1, 2], "Divisor argument must be 1 or 2"

    # Perform the accuracy calculation of 1 - sMAPE
    smape = 100 * (
        np.mean(np.abs(y_pred - y_true) / ((np.abs(y_pred) + np.abs(y_true)) / divisor))
    )
    return smape

In [26]:
def percentage_correct_quintile(y_true, y_pred):
    # Convert arrays to a DataFrame
    df = pd.DataFrame({"True": y_true, "Pred": y_pred})

    # Calculate deciles for the true values
    deciles = np.percentile(y_true, np.arange(20, 100, 20))

    # Function to find the decile index
    def find_decile(value):
        return 1 + np.digitize(value, deciles)

    # Apply function to true values
    df["True_quintile"] = df["True"].apply(find_decile)
    df["Pred_quintile"] = df["Pred"].apply(find_decile)
    df["Correct_quintile"] = np.where(df["True_quintile"] == df["Pred_quintile"], 1, 0)

    # Calculate the percentage of correct predictions
    pct_correct = df["Correct_quintile"].mean() * 100

    return pct_correct

In [7]:
response_var = "Max_scaled_abundance"

categorical_vars = [
    "Primary vegetation_Light use",
    "Primary vegetation_Intense use",
    "Secondary vegetation_Minimal use",
    "Secondary vegetation_Light use",
    "Secondary vegetation_Intense use",
    "Cropland_Minimal use",
    "Cropland_Light_Intense",
    "Pasture_Minimal use",
    "Pasture_Light_Intense",
    "Urban_All uses",
]

continuous_vars = [
    "Pop_density_1km_log",
    "Road_density_50km_cbrt",
    "Mean_pop_density_1km_log",
    "Annual_mean_temp_1km",
    "Temp_seasonality_1km",
    "Max_temp_warmest_month_1km",
    "Min_temp_coldest_month_1km",
    "Annual_precip_1km",
    "Precip_wettest_month_1km",
    "Precip_driest_month_1km",
    "Precip_seasonality_1km",
]

interaction_terms = ["Pop_density_1km_log", "Road_density_50km_cbrt"]

response_var_transform = "sqrt"  # Options: adjust, sqrt, logit, null

In [8]:
# Transform the response variable
df = transform_response_variable(df, method="sqrt")

NameError: name 'transform_response_variable' is not defined

In [9]:
# Generate formatted model data dictionary
model_data = format_data_for_model(
    df,
    response_var,
    categorical_vars,
    continuous_vars,
    interaction_terms,
)

## Spatial and environmental covariance function (GP)

In [30]:
import pytensor.tensor as pt


class Matern32Haversine(pm.gp.cov.Matern32):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def haversine_dist(self, X, Xs):
        if Xs is None:
            Xs = X

        # Constants
        R = 6371  # Radius of the Earth in kilometers
        pi_over_180 = np.pi / 180

        # Convert degrees to radians
        X_rad = X * pi_over_180
        Xs_rad = Xs * pi_over_180

        # Extract longitude and latitude
        lon1 = pt.reshape(X_rad[:, 0], (-1, 1))
        lat1 = pt.reshape(X_rad[:, 1], (-1, 1))
        lon2 = pt.reshape(Xs_rad[:, 0], (1, -1))
        lat2 = pt.reshape(Xs_rad[:, 1], (1, -1))

        # Haversine formula
        dlon = lon2 - lon1
        dlat = lat2 - lat1
        a = pt.sin(dlat / 2) ** 2 + pt.cos(lat1) * pt.cos(lat2) * pt.sin(dlon / 2) ** 2
        c = 2 * pt.arctan2(pt.sqrt(a), pt.sqrt(1 - a))

        return R * c  # Distance in kilometers

    def full(self, X, Xs=None):
        X, Xs = self._slice(X, Xs)
        h_dist = self.haversine_dist(X, Xs)
        return self.full_from_distance(h_dist, squared=False)

### Test with gp.MarginalApprox

In [34]:
# Unpack the required model data from dictionary
coords = model_data["coords"]
y = model_data["y"]
site_coords = model_data["site_coords"]

with pm.Model(coords=coords) as gp_corr:
    # Inducing points using K-means clustering
    Xu = pm.gp.util.kmeans_inducing_points(100, site_coords)

    # Observed data
    y_obs = pm.Data("y_obs", y, dims="idx")
    long_lat = pm.Data("long_lat", site_coords, dims=("idx", "site_coords"))

    # Gaussian Process regression to capture spatial correlation
    ls_spat = pm.Gamma("ls_spat", mu=200, sigma=50)
    gp_spat = pm.gp.MarginalApprox(cov_func=Matern32Haversine(input_dim=2, ls=ls_spat))
    sigma_y = pm.HalfNormal("sigma_y", sigma=1)
    y_like = gp_spat.marginal_likelihood(
        "y_like", X=long_lat, Xu=Xu, y=y, sigma=sigma_y
    )

In [35]:
# Sampler settings
sampler_settings = {
    "draws": 500,
    "tune": 500,
    "cores": 4,
    "chains": 4,
    "target_accept": 0.95,
    "nuts_sampler": "numpyro",
}

# Run sampling
gp_corr_trace = run_sampling(gp_corr, sampler_settings)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [36]:
model_data["site_coords"].shape

(7732, 2)

In [37]:
test_data = model_data["site_coords"][:100, :]

In [38]:
test_data.shape

(100, 2)

In [40]:
thinned_trace = gp_corr_trace.sel(draw=slice(None, None, 10))

In [None]:
site_coords = model_data["site_coords"]

with gp_corr:
    f_pred = gp_spat.conditional("f_pred_3", site_coords)
    pred_samples = pm.sample_posterior_predictive(thinned_trace, var_names=["f_pred_3"])

  pred_samples = pm.sample_posterior_predictive(thinned_trace, var_names=["f_pred_3"])
Sampling: [f_pred_3]
INFO:pymc.sampling.forward:Sampling: [f_pred_3]


Output()

In [42]:
pred_samples

### Combine hierarchical model with GP

In [None]:
def abund_gp_corr(model_data) -> pm.Model:
    """Docstring."""

    # Unpack the required model data from dictionary
    coords = model_data["coords"]
    y = model_data["y"]
    x = model_data["x"]
    site_coords = model_data["site_coords"]
    biome_realm_idx = model_data["biome_realm_idx"]
    realm_to_biome_idx = model_data["realm_to_biome_idx"]

    with pm.Model(coords=coords) as model:

        # Observed data that be changed later on for train-test runs
        y_obs = pm.MutableData("y_obs", y, dims="idx")
        x_obs = pm.MutableData("x_obs", x, dims=("idx", "x_vars"))
        long_lat = pm.MutableData("long_lat", site_coords, dims=("idx", "site_coords"))

        # Hyperpriors for biome-level intercept terms
        mu_a = pm.Normal("mu_a", mu=0.5, sigma=0.25)
        sigma_a = pm.HalfNormal("sigma_a", sigma=0.25)

        # Hyperpriors for slope terms
        mu_b = pm.Normal("mu_b", mu=0, sigma=0.25, dims="x_vars")
        sigma_b = pm.HalfNormal("sigma_b", sigma=0.25, dims="x_vars")

        # Biome-level priors (non-centered parameterization)
        biome_offset_1 = pm.Normal("biome_offset_1", mu=0, sigma=1, dims="biomes")
        biome_offset_2 = pm.Normal(
            "biome_offset_2", mu=0, sigma=1, dims=("biomes", "x_vars")
        )
        alpha_biome = pm.Deterministic("alpha_biome", mu_a + biome_offset_1 * sigma_a)
        beta_biome = pm.Deterministic("beta_biome", mu_b + biome_offset_2 * sigma_b)

        # Realm-level intercepts and slopes, sampled from the corresponding biomes
        mu_a_realm = pm.Deterministic("mu_a_realm", alpha_biome[realm_to_biome_idx])
        mu_b_realm = pm.Deterministic("mu_b_realm", beta_biome[realm_to_biome_idx])
        realm_offset_1 = pm.Normal("realm_offset_1", mu=0, sigma=1, dims="biome_realm")
        realm_offset_2 = pm.Normal(
            "realm_offset_2", mu=0, sigma=1, dims=("biome_realm", "x_vars")
        )
        sigma_realm = pm.HalfNormal("sigma_realm", sigma=0.25)
        alpha_realm = pm.Deterministic(
            "alpha_realm", mu_a_realm + realm_offset_1 * sigma_realm
        )
        beta_realm = pm.Deterministic(
            "beta_realm", mu_b_realm + realm_offset_2 * sigma_realm
        )

        # Variance assumed independent within and between groups
        sigma_y = pm.HalfNormal("sigma_y", sigma=0.25)

        # Gaussian Process regression to capture spatial correlation
        ls_spat = pm.Gamma("ls_spat", mu=200, sigma=50)
        gp_spat = pm.gp.Latent(cov_func=Matern32Haversine(input_dim=2, ls=ls_spat))
        f_spat = gp_spat.prior("f_spat", X=long_lat)

        # TODO: GP for environmental correlation

        # Expected values
        y_hat = pm.Deterministic(
            "y_hat",
            alpha_realm[biome_realm_idx]
            + pm.math.sum(x_obs * beta_realm[biome_realm_idx], axis=1)
            + f_spat,
        )

        # Likelihood function
        y_like = pm.Normal(  # noqa: F841
            "y_like", mu=y_hat, sigma=sigma_y, observed=y_obs, dims="idx"
        )

        return model

## High-level multiple model comparison

### LOO cross validation and WAIC

In [None]:
trace_dict = {
    "bii_pooled": bii_pooled_trace,
    "bii_abund_study_intercept": bii_abund_study_intercept_trace,
    "bii_abund_study_block_intercept": bii_abund_study_block_intercept_trace,
    "bii_abund_study_slope": bii_abund_study_slope_trace,
    "bii_abund_study_block_slope": bii_abund_study_block_slope_trace,
    "abund_base": abund_base_trace,
}
df_comp = az.compare(trace_dict)
df_comp

In [None]:
az.plot_compare(df_comp, insample_dev=True);

### R^2 score comparison

In [115]:
pred_dict = {
    # "abund_base": abund_base_r2,
    "abund_ecoregions": abund_eco_r2,
}

print("Comparison of R^2 values for different models \n")
for model in pred_dict.keys():
    print(f"{model}:")
    for mode in pred_dict[model].keys():
        r2 = pred_dict[model][mode]
        mean, std = np.mean(r2), np.std(r2)
        print(f"{mode}: {mean:.3f} (mean) | {std:.3f} (std)")
    print("\n")

Comparison of R^2 values for different models 

abund_ecoregions:
conditional: 0.468 (mean) | 0.004 (std)
conditional_oos: 0.468 (mean) | 0.004 (std)
marginal: 0.496 (mean) | 0.252 (std)




### Cross-validation

In [None]:
# Stratify the data on study
df = create_stratification_column(df, ["SS"])

## Prior predictive checks

In [122]:
# Selecting prior samples

# abund_base_prior
# gp_corr_prior

prior_samples = gp_corr_prior

In [123]:
plot_prior_distribution(prior_samples, category="observed_data", variable="y_like")

AttributeError: 'InferenceData' object has no attribute 'observed_data'

In [None]:
plot_prior_distribution(prior_samples, category="prior_predictive", variable="y_like")

## Model output summary

In [None]:
# Update the right trace object here

# bii_abund_study_intercept_trace
# bii_abund_study_block_intercept_trace
# bii_abund_study_slope_trace
# bii_abund_study_block_slope_trace
# abund_base_trace

trace = abund_base_trace

In [None]:
# Summary output for posteror of hyperprior
az.summary(trace, var_names=["mu_a", "mu_b"], round_to=2)

In [None]:
forest_plot(trace, var_names=["mu_b"])

### Hyperprior posterior distributions

In [None]:
az.plot_posterior(trace, var_names=["mu_a", "sigma_a"])
plt.tight_layout()
plt.show()

In [None]:
az.plot_posterior(trace, var_names=["mu_b"])
plt.tight_layout()
plt.show()

## Model fit

### Conditional and marginal R^2 scores

In [None]:
def plot_r2_distribution(r2_values):
    plt.figure(figsize=(8, 4))
    plt.hist(r2_values, bins=30, alpha=0.6, color="g", edgecolor="black")
    mean_r2 = np.mean(r2_values)
    median_r2 = np.median(r2_values)
    plt.axvline(
        mean_r2,
        color="r",
        linestyle="dashed",
        linewidth=1,
        label=f"Mean: {mean_r2:.2f}",
    )
    plt.axvline(
        median_r2,
        color="b",
        linestyle="dashed",
        linewidth=1,
        label=f"Median: {median_r2:.2f}",
    )
    plt.title("Distribution of Bayesian R² Values")
    plt.xlabel("R²")
    plt.ylabel("Frequency")
    plt.legend()
    plt.show()

In [None]:
# Calculate R^2 values
cond_r2_values, marg_r2_values = bayesian_r2(trace)

In [None]:
plot_r2_distribution(cond_r2_values)

In [None]:
plot_r2_distribution(marg_r2_values)

### Leave-one-out cross validation (PSIS-LOO-CV)

In [None]:
loo = az.loo(trace, var_name="y_like")
loo

### Widely applicable information criterion (WAIC)

In [None]:
waic = az.waic(trace, var_name="y_like")
waic

### Prediction and residual plots

In [None]:
def plot_predictions_and_residuals(y_true: np.array, y_pred: np.array) -> None:

    # Scatter plot of predictions vs actuals
    plt.figure(figsize=(8, 4))
    plt.scatter(y_true, y_pred, alpha=0.2)
    plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], "--k", linewidth=1)
    # plt.text(0.05, 0.95, f"R² = {r2:.3f}", fontsize=12, transform=plt.gca().transAxes)
    plt.xlabel("Actual values")
    plt.ylabel("Predicted values")
    plt.title("Predictions vs actuals")
    plt.grid(True)
    plt.show()

    # Residuals vs actual scatter plot
    residuals = y_true - y_pred
    plt.figure(figsize=(8, 4))
    plt.scatter(y_true, residuals, alpha=0.2)
    plt.axhline(y=0, color="r", linestyle="--", linewidth=2)
    plt.xlabel("Actual values")
    plt.ylabel("Residuals")
    plt.title("Residuals vs actual values")
    plt.grid(True)
    plt.show()

    # Residual Density Plot (Distribution of Residuals)
    plt.figure(figsize=(8, 4))
    sns.kdeplot(residuals, fill=True)
    plt.xlabel("Residuals")
    plt.ylabel("Density")
    plt.title("Residual density Plot")
    plt.grid(True)
    plt.show()

In [None]:
pred_with_re = trace.posterior_predictive["y_like"].mean(dim=["chain", "draw"]).values
y_true = trace.observed_data["y_like"].values

In [None]:
plot_predictions_and_residuals(y_true, pred_with_re)

## Posterior predictive distribution

In [None]:
ax = az.plot_ppc(trace)
for label in ax.get_xticklabels():
    label.set_fontsize(10)
plt.xlabel("")
ax.legend(fontsize=10)
plt.show()

In [None]:
ax = az.plot_ppc(trace, kind="cumulative")
for label in ax.get_xticklabels():
    label.set_fontsize(10)
plt.xlabel("")
ax.legend(fontsize=10)
plt.show()

## Debugging: Posterior predictive checks

### Trace plots

In [None]:
def trace_plot(
    trace,
    var_names,
):
    axes = az.plot_trace(
        data=trace,
        var_names=var_names,
        divergences="bottom",
        compact=False,  # Plot multidimensional variables in one plot
        combined=False,  # Combine chains into a single line
    )

    plt.tight_layout()
    plt.show()

In [None]:
# Posteriors of hyperpriors: mu_beta
trace_plot(trace, var_names=["beta"])

### ESS plots

In [None]:
# ESS for posterior of hyperparameters
az.plot_ess(trace, kind="evolution", var_names=["beta"])
plt.tight_layout()
plt.show()

### Energy plot

In [None]:
az.plot_energy(trace)
plt.show()

## Detailed debugging

### Autocorrelation in chains

In [None]:
az.plot_autocorr(trace, var_names=["beta"])
plt.tight_layout()
plt.show()

In [10]:
site_coords = model_data["site_coords"]

In [15]:
def compute_rbf(X, centers, length_scale, rbf_type="gaussian"):
        dist = pt.sqrt(((X[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2))
        print(dist.eval().shape)
        if rbf_type == "gaussian":
            return pt.exp(-0.5 * (dist / length_scale) ** 2)
        elif rbf_type == "multiquadric":
            return pt.sqrt(1 + (dist / length_scale) ** 2)
        elif rbf_type == "inverse_multiquadric":
            return 1 / pt.sqrt(1 + (dist / length_scale) ** 2)
        else:
            raise ValueError("Unsupported RBF type")

In [13]:
n_centers = 30
centers = site_coords[
            np.random.choice(site_coords.shape[0], n_centers, replace=False), :
        ]

In [14]:
length_scale = 10

In [17]:
rbf_matrix = compute_rbf(site_coords, centers, length_scale, rbf_type="gaussian")

(7732, 30)
