In [None]:
from pathlib import Path
import os
import re
import time

import numpy as np
import pandas as pd

from b3alien import b3cube, simulation, griis

import pyspark.sql.functions as f
from pyspark.sql.window import Window
from pyspark.sql import SparkSession
from sedona.spark import SedonaContext

# ---- Environment ----
# setdefault: keep any value already exported in the shell/kernel environment
# and only fall back to these local defaults when the variable is unset.
# NOTE(review): "/opt/miniconda3" is a machine-specific JAVA_HOME; on other
# machines export JAVA_HOME before starting the kernel instead of editing this.
os.environ.setdefault("JAVA_HOME", "/opt/miniconda3")
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
os.environ.setdefault("SPARK_LOCAL_IP", "127.0.0.1")

# Input occurrence table (parquet) read once and shared by all later cells.
COUNTRY_PARQUET = Path("./global_level1_country.parquet")

# ---- Spark/Sedona (once) ----
# getOrCreate() reuses an existing session when this cell is re-run.
spark_session = (SparkSession.builder
    .appName("SedonaApp")
    .master("local[*]")
    .config("spark.driver.memory", "4g")
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    .config("spark.kryo.registrator", "org.apache.sedona.core.serde.SedonaKryoRegistrator")
    .config("spark.jars.packages", "org.apache.sedona:sedona-spark-shaded-3.5_2.12:1.7.0")
    .getOrCreate())

spark = SedonaContext.create(spark_session)

# ---- Load parquet once ----
# The WKB "geometry" column is decoded to a Sedona geometry a single time here.
df = (
    spark.read.format("parquet")
    .load(str(COUNTRY_PARQUET))
    .withColumn("geometry", f.expr("ST_GeomFromWKB(geometry)"))
)


In [None]:
ROOT = Path("./gbif_downloads_by_country")
RESULTS_CSV = Path("./country_simulation_results.csv")

# Country folder names: a two-letter uppercase code, optionally followed by
# "_<n>" for additional download batches (e.g. "ES", "ES_2", "ES_10").
COUNTRY_DIR_RE = re.compile(r"^(?P<cc>[A-Z]{2})(?:_(?P<n>\d+))?$")

def list_country_folders(root: Path) -> dict[str, list[Path]]:
    """
    Group the country download folders directly under ``root`` by country code.

    Returns a mapping like {"ES": [Path("ES"), Path("ES_2"), Path("ES_10")], ...}.
    Only directories whose name matches COUNTRY_DIR_RE are included; plain files
    and other directories are ignored.

    Within a country, the un-suffixed base folder comes first, then numbered
    folders in *numeric* order, so "ES_10" sorts after "ES_2" (a plain string
    sort would put it first).

    Raises FileNotFoundError / NotADirectoryError (from Path.iterdir) if
    ``root`` does not exist or is not a directory.
    """
    keyed: dict[str, list[tuple[tuple[bool, int, str], Path]]] = {}
    for p in root.iterdir():
        if not p.is_dir():
            continue
        m = COUNTRY_DIR_RE.match(p.name)
        if not m:
            continue
        n = m.group("n")
        # (is_numbered, batch number, name): base folder first, then numeric
        # batch order; the name only breaks ties such as "ES_2" vs "ES_02".
        sort_key = (n is not None, int(n) if n is not None else 0, p.name)
        keyed.setdefault(m.group("cc"), []).append((sort_key, p))

    return {
        cc: [p for _, p in sorted(entries, key=lambda entry: entry[0])]
        for cc, entries in keyed.items()
    }

def find_merged_checklist_file(folder: Path) -> Path | None:
    """
    Accept both spellings: merged_dist.txt and merged_distr.txt
    """
    for name in ("merged_dist.txt", "merged_distr.txt"):
        fp = folder / name
        if fp.exists():
            return fp
    return None

def load_merged_specieskeys_for_country(country_folders: list[Path]) -> list[int]:
    """
    Union the checklist species keys of every folder belonging to one country.

    Each folder (e.g. ES, ES_2, ES_3) is searched for its merged_dist(r).txt
    checklist. Folders without one are skipped silently. The result is the
    sorted list of unique keys, each converted with ``int()``.
    """
    # Lazily resolve each folder's checklist path; None means "no checklist here".
    checklist_paths = (find_merged_checklist_file(folder) for folder in country_folders)

    merged_keys: set[int] = set()
    for checklist_path in checklist_paths:
        if checklist_path is None:
            continue
        # NOTE(review): CheckList.species is assumed to yield values int() accepts.
        checklist = griis.CheckList(str(checklist_path))
        merged_keys |= {int(key) for key in checklist.species}

    return sorted(merged_keys)


In [None]:
def _cumulative_by_time(new_species_pdf: pd.DataFrame) -> pd.DataFrame:
    """
    Turn a per-month new-species table into a chronological cumulative curve.

    Parses ``first_appearance`` (format "%Y-%m") into a ``time`` column and drops
    rows that fail to parse. Sorts by ``time`` and recomputes
    ``cumulative_species`` as the running sum of ``new_species_count`` over the
    kept rows only. Columns are returned in the order first_appearance,
    new_species_count, cumulative_species, time, with a fresh 0..n-1 index.
    """
    parsed = new_species_pdf.assign(
        time=pd.to_datetime(new_species_pdf["first_appearance"], format="%Y-%m", errors="coerce")
    )
    curve = (
        parsed.dropna(subset=["time"])
        .sort_values("time", kind="stable")
        .reset_index(drop=True)
        # The running total is taken *after* parsing and sorting. It therefore
        # follows real chronological order rather than the string order of
        # first_appearance, and it excludes months that were dropped above.
        .assign(cumulative_species=lambda d: d["new_species_count"].cumsum())
    )
    return curve[["first_appearance", "new_species_count", "cumulative_species", "time"]]


def _discovery_curve(country_code: str, specieskeys: list[int]) -> pd.DataFrame:
    """
    Build the monthly first-record curve for one country's checklist species.

    Uses the module-level Spark DataFrame ``df``. Returns the pandas frame
    produced by ``_cumulative_by_time``.
    """
    # 1) Filter parquet by country. 2) Filter by checklist species (cast to long
    #    so the comparison with Python ints is type-consistent).
    species_mask = f.col("specieskey").cast("long").isin([int(x) for x in specieskeys])
    df_filtered = (
        df.filter(f.col("countrycode") == country_code)
        .filter(species_mask)
        # Species seen only with a null yearmonth would otherwise form a null
        # first_appearance group. Spark sorts that group first in ascending
        # order, so it would leak into the running totals.
        .filter(f.col("yearmonth").isNotNull())
    )

    # 3) First appearance per species. 4) Number of species first seen each month.
    # NOTE(review): min() on yearmonth is only chronological if the values are
    # zero-padded "YYYY-MM" strings; confirm the parquet encoding.
    new_species_per_month = (
        df_filtered.groupBy("specieskey")
        .agg(f.min("yearmonth").alias("first_appearance"))
        .groupBy("first_appearance")
        .count()
        .withColumnRenamed("count", "new_species_count")
    )

    # 5) At most one row per month survives, so finish in pandas. This also
    #    avoids an un-partitioned Spark Window, which runs on a single partition.
    return _cumulative_by_time(new_species_per_month.toPandas())


def run_country_pipeline(country_code: str, specieskeys: list[int], n_bootstrap: int = 500) -> dict:
    """
    Run the Spark filtering, cumulative discovery curve and Solow-Costello fit
    for one country.

    Parameters
    ----------
    country_code : str
        Code compared against the ``countrycode`` column of the module-level ``df``.
    specieskeys : list[int]
        Checklist species keys. Only occurrences of these species are used.
    n_bootstrap : int, default 500
        Number of bootstrap iterations passed to
        ``simulation.parallel_bootstrap_solow_costello``.

    Returns
    -------
    dict
        One CSV row. It always has ``country_code``, ``ok``, ``message`` and
        ``n_species``; ``n_species`` is the checklist size, not the number of
        species observed. On success the row also carries the fitted rate, its
        bootstrap CI bounds and summary statistics of the bootstrap samples.

    Raises
    ------
    KeyError
        If the bootstrap results lack ``beta1_ci`` or ``beta1_samples``.
    """
    if not specieskeys:
        return {
            "country_code": country_code,
            "ok": False,
            "message": "No checklist specieskeys found (missing merged_dist(r).txt?)",
            "n_species": 0
        }

    df_final = _discovery_curve(country_code, specieskeys)

    if df_final.empty or df_final["cumulative_species"].max() <= 1:
        return {
            "country_code": country_code,
            "ok": False,
            "message": "Insufficient data after filtering (empty or too few points)",
            "n_species": len(specieskeys)
        }

    # 6) Rate calculation + simulation
    time_arr, rate = b3cube.calculate_rate(df_final)

    # NOTE(review): assumes the first value returned is the scalar fitted rate.
    # float() below raises if it is not; confirm against b3alien.simulation.
    fitted_rate, _ = simulation.simulate_solow_costello_scipy(time_arr, rate, vis=False)

    results = simulation.parallel_bootstrap_solow_costello(time_arr, rate, n_iterations=n_bootstrap)

    if "beta1_ci" not in results or "beta1_samples" not in results:
        raise KeyError("Expected keys 'beta1_ci' and 'beta1_samples' not found in bootstrap results.")

    rate_lo, rate_hi = map(float, results["beta1_ci"])
    rate_samples = np.asarray(results["beta1_samples"], dtype=float)

    return {
        "country_code": country_code,
        "ok": True,
        "message": "success",
        "n_species": len(specieskeys),
        "fitted_rate": float(fitted_rate),
        "fitted_rate_ci_low": rate_lo,
        "fitted_rate_ci_high": rate_hi,
        "fitted_rate_samples_n": int(rate_samples.size),
        # Mean/std of the samples instead of the raw samples keeps the CSV compact.
        "fitted_rate_samples_mean": float(np.mean(rate_samples)) if rate_samples.size else np.nan,
        "fitted_rate_samples_std": float(np.std(rate_samples, ddof=1)) if rate_samples.size > 1 else np.nan,
    }


In [None]:
# Fixed column order for the results CSV. Failed rows only carry the first four
# keys. Building the frame with this schema means every run writes the same
# columns in the same order, even when every country fails.
RESULT_COLUMNS = [
    "country_code", "ok", "message", "n_species",
    "fitted_rate", "fitted_rate_ci_low", "fitted_rate_ci_high",
    "fitted_rate_samples_n", "fitted_rate_samples_mean", "fitted_rate_samples_std",
]

country_folders = list_country_folders(ROOT)

all_rows = []
for cc, folders in sorted(country_folders.items()):
    print(f"\n=== {cc} ({len(folders)} folder(s)) ===")

    specieskeys: list[int] = []
    try:
        # Checklist loading is inside the try so one unreadable checklist marks
        # this country as failed instead of aborting the whole loop.
        specieskeys = load_merged_specieskeys_for_country(folders)
        print(f"Checklist specieskeys: {len(specieskeys)}")
        row = run_country_pipeline(cc, specieskeys)
    except Exception as e:
        print(f"Failed: {type(e).__name__}: {e}")
        row = {
            "country_code": cc,
            "ok": False,
            "message": f"Exception: {e}",
            "n_species": len(specieskeys),
        }

    all_rows.append(row)

res_df = pd.DataFrame(all_rows, columns=RESULT_COLUMNS)

# Append mode: write a header only when starting a new (or empty) file. When
# appending, align to the header already on disk so values land under the
# right column names.
# NOTE(review): each full re-run appends another row per country; deduplicate
# on country_code downstream if only the latest result is wanted.
if RESULTS_CSV.exists() and RESULTS_CSV.stat().st_size > 0:
    existing_columns = pd.read_csv(RESULTS_CSV, nrows=0).columns
    res_df.reindex(columns=existing_columns).to_csv(RESULTS_CSV, mode="a", header=False, index=False)
else:
    res_df.to_csv(RESULTS_CSV, index=False)

res_df
