In [1]:
# --- Imports ---
from b3alien import b3cube
from b3alien import simulation
from b3alien import griis

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.ticker import MaxNLocator

# --- Matplotlib defaults for publication-style figures ---
# The overrides are grouped by purpose and applied in a single update call.
_PUBLICATION_RC = {
    # Canvas size and resolution (on screen vs. written to disk)
    "figure.figsize": (12, 10),
    "figure.dpi": 100,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    # Typography: enlarged axis labels and legend text
    "font.family": "DejaVu Sans",
    "axes.titlesize": 14,
    "axes.labelsize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
    # Grid lines are drawn on every axes by default
    "axes.grid": True,
}
mpl.rcParams.update(_PUBLICATION_RC)

# --------- Dataset configuration: edit the table below ---------
# One row per dataset, in panel order:
#   (panel title, path to the data cube parquet, path to the GRIIS/other checklist file)
# Each row is expanded into a dict with keys "name", "cube_path" and "cl_path".
DATASETS = [
    {"name": label, "cube_path": cube_file, "cl_path": checklist_file}
    for label, cube_file, checklist_file in (
        (
            "Australia",
            "/Users/maarten/Downloads/data_AUS_level2.parquet",
            "/Users/maarten/Downloads/griis_australia/merged_distr.txt",
        ),
        (
            "Costa Rica",
            "/Users/maarten/Documents/GIT/b3alien/tests/data/costa-rica/data_CR_level2.parquet",
            "/Users/maarten/Documents/GIT/b3alien/tests/data/costa-rica/dwca-griis-costa-rica-v1.4/merged_distr.txt",
        ),
        (
            "South Africa",
            "/Users/maarten/Downloads/data_ZA_level2.parquet",
            "/Users/maarten/Downloads/dwca-south-africa-griis-gbif-v2.7/merged_distr.txt",
        ),
        (
            "Belgium",
            "/Users/maarten/Downloads/data_BE.parquet",
            "/Users/maarten/Documents/GIT/b3alien/tests/data/belgium/dwca-unified-checklist-v1.14/merged_distr.txt",
        ),
    )
]

# Inclusive analysis window, in calendar years
START_YEAR, END_YEAR = 1970, 2022

# Number of bootstrap iterations (lower it for a faster run)
N_BOOT = 1000

# --- Helpers ---

def filter_year_series(year_like, values, start_year, end_year):
    """
    Restrict a yearly series to an inclusive window of years.

    Accepts:
      year_like:  array-like of years (anything castable to int) OR datetimes,
                  timezone-naive or timezone-aware
      values:     array-like of the same length as year_like
      start_year: first year to keep (inclusive)
      end_year:   last year to keep (inclusive)
    Returns:
      (years_int, values_filtered) as NumPy arrays within [start_year, end_year].
    Raises:
      ValueError: if year_like and values do not have the same length.
    """
    # Drop any incoming index so both series align purely by position.
    years = pd.Series(year_like).reset_index(drop=True)
    vals = pd.Series(values).reset_index(drop=True)

    # Without this check a shorter `values` is silently realigned against the
    # mask and the two returned arrays end up with different lengths.
    if len(years) != len(vals):
        raise ValueError(
            f"year_like and values must have the same length "
            f"(got {len(years)} and {len(vals)})."
        )

    # is_datetime64_any_dtype also recognises timezone-aware columns, whose
    # pandas extension dtype np.issubdtype cannot interpret (it raises TypeError).
    if pd.api.types.is_datetime64_any_dtype(years):
        years = years.dt.year.astype(int)
    else:
        years = years.astype(int)

    mask = (years >= start_year) & (years <= end_year)
    return years[mask].to_numpy(), vals[mask].to_numpy()

def compute_dataset(cube_path, cl_path, start_year, end_year, n_boot):
    """
    Run the full per-dataset pipeline and collect everything needed for plotting.

    Steps, in order: load the checklist and occurrence cube, derive the yearly
    rate series, restrict it to the analysis window, run a single point fit,
    run the bootstrap, and extract the observer-effort series.

    Accepts:
      cube_path:  path passed to b3cube.OccurrenceCube
      cl_path:    path passed to griis.CheckList
      start_year: first year of the analysis window
      end_year:   last year of the analysis window
      n_boot:     number of bootstrap iterations
    Returns:
      dict with keys "time", "rate", "c1_mean", "c1_lo", "c1_hi", "years_obs",
      "obs_vals", "fitted_rate", "fitted_rate_ci", "fitted_rate_samples".
    Raises:
      KeyError:   if the bootstrap results lack 'beta1_ci' or 'beta1_samples'.
      ValueError: if the survey-effort table has no usable year or observer column.
    """
    checklist = griis.CheckList(cl_path)
    occ_cube = b3cube.OccurrenceCube(cube_path)

    # Yearly rate series derived from the cube, limited to the checklist species.
    # Only the second output of cumulative_species is needed here.
    _, cumulative = b3cube.cumulative_species(occ_cube, checklist.species)
    all_years, all_rates = b3cube.calculate_rate(cumulative)
    rate_frame = pd.DataFrame({"year": all_years, "rate": all_rates})
    time_w, rate_w = b3cube.filter_time_window(rate_frame, start_year, end_year)

    # Single point fit, kept for reporting. Element 1 of the returned parameter
    # vector is used as the fitted rate (noted by the original author as
    # matching beta1 -- confirm against the b3alien documentation).
    _, point_params = simulation.simulate_solow_costello_scipy(time_w, rate_w, vis=False)
    fitted_rate = float(point_params[1])

    # Bootstrap; only the outputs the function itself provides are used below.
    boot = simulation.parallel_bootstrap_solow_costello(time_w, rate_w, n_iterations=n_boot)

    if not ("beta1_ci" in boot and "beta1_samples" in boot):
        raise KeyError("Expected keys 'beta1_ci' and 'beta1_samples' not found in bootstrap results.")
    # Interval bounds exactly as reported by the bootstrap function
    # (presumably a 95% interval -- verify against the library).
    ci_lower, ci_upper = (float(bound) for bound in boot["beta1_ci"])
    beta1_draws = np.asarray(boot["beta1_samples"], dtype=float)

    # Survey effort: prefer a 'date' column, fall back to 'year'.
    effort = b3cube.get_survey_effort(occ_cube, calc_type='distinct')
    if "date" in effort.columns:
        year_like = effort["date"]
    else:
        year_like = effort.get("year")
    if year_like is None:
        raise ValueError("survey_eff must have 'date' (datetime) or 'year' column.")

    # Accept a few alternative spellings of the observer-count column; the first
    # one present is renamed to the canonical 'distinct_observers'.
    if "distinct_observers" not in effort.columns:
        fallback = next(
            (col for col in ("n_observers", "observers", "distinct_recorders") if col in effort.columns),
            None,
        )
        if fallback is not None:
            effort = effort.rename(columns={fallback: "distinct_observers"})
        if "distinct_observers" not in effort.columns:
            raise ValueError("survey_eff must contain 'distinct_observers' (or equivalent).")

    years_obs, obs_vals = filter_year_series(
        year_like, effort["distinct_observers"], start_year, end_year
    )

    return {
        "time": np.asarray(time_w),
        "rate": np.asarray(rate_w),
        "c1_mean": np.asarray(boot["c1_mean"]),
        "c1_lo": np.asarray(boot["c1_lower"]),
        "c1_hi": np.asarray(boot["c1_upper"]),
        "years_obs": years_obs,
        "obs_vals": obs_vals,
        "fitted_rate": fitted_rate,
        "fitted_rate_ci": (ci_lower, ci_upper),
        "fitted_rate_samples": beta1_draws,  # raw bootstrap draws, useful for diagnostics
    }


# --- Fit every configured dataset exactly once; plotting and export reuse these results ---
all_results = []
summary_rows = []

for ds in DATASETS:
    print(f"Processing: {ds['name']}")
    res = compute_dataset(
        cube_path=ds["cube_path"],
        cl_path=ds["cl_path"],
        start_year=START_YEAR,
        end_year=END_YEAR,
        n_boot=N_BOOT,
    )

    # Store the dataset's config (name, paths) together with its computed outputs.
    rec = {**ds, **res}
    all_results.append(rec)

    # Report the point estimate with its bootstrap interval.
    lo, hi = res["fitted_rate_ci"]
    print(f"  Fitted rate: {res['fitted_rate']:.4f} / year  (95% CI [{lo:.4f}, {hi:.4f}])")

# --- Build a single multi-panel figure (two columns), consistent styling ---
# The grid is sized from the number of datasets so that none is silently
# dropped; with four datasets this is the 2x2, 12x10 inch layout.
n_panels = len(all_results)
n_cols = 2
n_rows = max(1, -(-n_panels // n_cols))  # ceiling division, at least one row
fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 5 * n_rows), sharex=False, squeeze=False)
axes = axes.ravel()

effort_color = "tab:blue"   # left axis
obs_line_color = "black"    # cumulative observed
boot_mean_color = "tab:orange"
ci_face_alpha = 0.25

for ax, res in zip(axes, all_results):
    # Left axis: natural log of the observation effort
    ax.set_title(res["name"])
    ax.set_xlabel("Year")
    ax.set_ylabel("log(Observation effort)", color=effort_color)

    obs = np.asarray(res["obs_vals"], dtype=float)
    # Nonpositive counts have no logarithm; turn them into gaps in the line.
    obs = np.where(obs <= 0, np.nan, obs)
    ax.plot(res["years_obs"], np.log(obs), marker="o", linestyle="-",
            linewidth=1.5, markersize=3.5, color=effort_color, label="Observation effort")
    ax.tick_params(axis="y", labelcolor=effort_color)
    ax.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=8))

    # Right axis (shares x with the left one): cumulative discoveries,
    # bootstrap mean and its interval band
    ax2 = ax.twinx()
    ax2.set_ylabel("Cumulative discoveries")
    ax2.plot(res["time"], np.cumsum(res["rate"]), linestyle="-", linewidth=1.8,
             color=obs_line_color, label="Observed discoveries")
    ax2.plot(res["time"], res["c1_mean"], linestyle="--", linewidth=1.6,
             color=boot_mean_color, label="Bootstrap mean")
    ax2.fill_between(res["time"], res["c1_lo"], res["c1_hi"],
                     color=boot_mean_color, alpha=ci_face_alpha, label="95% CI")
    ax2.xaxis.set_major_locator(MaxNLocator(integer=True, nbins=8))

# Hide grid cells that have no dataset (odd number of datasets, or none at all)
for spare_ax in axes[n_panels:]:
    spare_ax.set_visible(False)

# Global legend (one for the whole figure), built from proxy artists so the
# entries appear once instead of once per panel
legend_elements = [
    Line2D([0], [0], color=effort_color, lw=1.8, marker="o", markersize=4, label="Observation effort (log)"),
    Line2D([0], [0], color=obs_line_color, lw=1.8, label="Observed discoveries"),
    Line2D([0], [0], color=boot_mean_color, lw=1.8, ls="--", label="Bootstrap mean"),
    Patch(facecolor=boot_mean_color, alpha=ci_face_alpha, label="95% CI"),
]
fig.legend(handles=legend_elements, loc="lower center", ncol=2, frameon=False, bbox_to_anchor=(0.5, 0.01))

# Leave a strip at the bottom of the figure for the global legend
fig.tight_layout(rect=[0, 0.05, 1, 0.96])

# --- Save high-quality outputs (vector formats plus a 300 dpi raster) ---
fig.savefig("solow_costello_4panels_new.pdf")
fig.savefig("solow_costello_4panels_new.svg")
fig.savefig("solow_costello_4panels_new.png", dpi=300)
plt.show()

# --- Export the per-dataset rate summary from the results already in memory (nothing is refit) ---
def _summary_row(rec):
    """Flatten one dataset record into a row of the summary table."""
    ci_bounds = rec["fitted_rate_ci"]
    return {
        "dataset": rec["name"],
        "fitted_rate_per_year": rec["fitted_rate"],
        "ci_lower": ci_bounds[0],
        "ci_upper": ci_bounds[1],
    }


summary_rows = [_summary_row(rec) for rec in all_results]
summary_table = pd.DataFrame(summary_rows)
summary_table.to_csv("fitted_rate_CIs.csv", index=False)
print("Saved: fitted_rate_CIs.csv")


Processing: Australia


Bootstrapping: 100%|██████████| 1000/1000 [07:20<00:00,  2.27it/s]


  Fitted rate: -0.0031 / year  (95% CI [-0.0408, -0.0034])
Processing: Costa Rica


Bootstrapping: 100%|██████████| 1000/1000 [02:17<00:00,  7.30it/s]


  Fitted rate: -0.0269 / year  (95% CI [-0.0472, -0.0125])
Processing: South Africa


Bootstrapping: 100%|██████████| 1000/1000 [03:28<00:00,  4.81it/s]


  Fitted rate: 0.0157 / year  (95% CI [0.0020, 0.0872])
Processing: Belgium


Bootstrapping: 100%|██████████| 1000/1000 [02:47<00:00,  5.96it/s]


  Fitted rate: 0.0427 / year  (95% CI [0.0266, 0.0698])
Saved: fitted_rate_CIs.csv


  plt.show()
