# In [None]:
###############################################################

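# KDE entropy (bits) on PC1 / PC1–PC2 / PC1–PC2–PC3
# vs. Shannon diversity (bits) computed from species counts.
# Bandwidth sensitivity: h = α·h*, α ∈ {1/2, 1/√2, 1, √2},
# where h* is the leave-one-out likelihood cross-validated bandwidth.
#
# Simulates communities and plots the results.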

###############################################################

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from scipy.stats import pearsonr, t
from matplotlib.ticker import MaxNLocator, AutoMinorLocator

# Model parameters
RNG_SEED = None          # None -> unseeded RNG / PCA; set an int for reproducible runs
N_COMM_TRAIN = 100       # number of simulated communities
N_GRAINS = 50            # specimens (grains) drawn per community
MIN_SPECIES = 1          # lower bound on sampled species richness
MAX_SPECIES = 30         # upper bound on sampled species richness
MAX_PC_FOR_PCA = 3       # maximum number of principal components kept
# Candidate bandwidth grid searched by leave-one-out CV for h*
HSTAR_GRID_KW = {"h_min": 1e-2, "h_max": 1e1, "num": 400, "logspace": True}

OUTPUT_PDF_SENS = "kde_entropy_vs_shannon_h_sensitivity_with_CI_PI_lines.pdf"
# Bandwidth multipliers α = 2^k: 1/2, 1/√2, 1, √2
H_SCALES = [2.0 ** k for k in (-1.0, -0.5, 0.0, 0.5)]

# Figure layout and typography
ALPHA_X = -0.15          # x position (axes fraction) of each row's α label
ALPHA_FONTSIZE = 13
YLABEL_PAD = 36          # y-label padding, leaving room for the α label
POINT_SIZE = 18
LINE_WIDTH = 2.0         # width of the OLS fit line
LABEL_FONTSIZE = 14
STATS_FONTSIZE = 10

# Grey palette for points, fit line, confidence/prediction bands and grid
COLOR_POINTS = "#6b7280"
COLOR_LINE = "#111827"
COLOR_CI = "#6b7280"
COLOR_PI = "#9ca3af"
GRID_COLOR = "#d1d5db"

# Helpers
def to_numpy_array(X):
    """Return X as a NumPy array.

    Tensor-like objects exposing ``.detach()`` (e.g. torch tensors) are
    detached and moved to CPU first; objects with ``.to_numpy()`` (e.g.
    pandas) use that; anything else goes through ``np.asarray``.
    """
    if hasattr(X, "detach"):
        return X.detach().cpu().numpy()
    return X.to_numpy() if hasattr(X, "to_numpy") else np.asarray(X)

def build_species_index(labels):
    """Group sample positions by species label.

    Returns a dict mapping ``int(label)`` to an int array of the positions at
    which that label occurs, with keys in order of first appearance.
    """
    grouped = {}
    for position, label in enumerate(labels):
        grouped.setdefault(int(label), []).append(position)
    return {label: np.asarray(rows, dtype=int) for label, rows in grouped.items()}

# Simulating different communities
def sample_community_indices(species_index, N, rng, min_species, max_species):
    """Draw the row indices of one simulated community of exactly N specimens.

    Each attempt (up to 2000, until a valid draw is found):
      1. Picks a species richness S uniformly from
         [max(1, min_species), min(max_species, number of species available)].
      2. Picks S distinct species and draws relative abundances from a
         symmetric Dirichlet with total concentration tau, tau log-uniform on
         [0.01, 20]; smaller tau concentrates abundance on fewer species.
      3. Draws multinomial specimen counts, then caps each species at the
         number of specimens available for it, redistributing the excess to
         species that still have spare specimens.
      4. Samples that many specimens per species without replacement.

    Species that receive a zero count are skipped, so the realised richness
    can be lower than S.

    Parameters
    ----------
    species_index : dict
        Maps species label -> integer array of row indices for that species.
    N : int
        Number of specimens in the community.
    rng : numpy.random.Generator
    min_species, max_species : int
        Bounds on S, clamped to what is available.

    Returns
    -------
    numpy.ndarray of int, length N
        Row indices drawn from the arrays in ``species_index``.

    Raises
    ------
    ValueError
        If the clamped richness bounds are empty.
    RuntimeError
        If no valid community is found within 2000 attempts (e.g. the picked
        species jointly have fewer than N specimens too often).
    """
    species_list = list(species_index.keys())
    minS = max(1, min_species)
    maxS = min(max_species, len(species_list))
    if minS > maxS:
        raise ValueError("min_species > max_species given available species.")

    for _ in range(2000):
        S = int(rng.integers(minS, maxS + 1))
        picks = rng.choice(species_list, size=S, replace=False)

        tau = float(np.exp(rng.uniform(np.log(0.01), np.log(20.0))))
        alpha = np.full(S, tau / S, dtype=float)
        p = rng.dirichlet(alpha)
        # A multinomial draw always sums to exactly N.
        counts = rng.multinomial(N, p).astype(int)

        # Cap each species at its availability and hand the excess to species
        # with spare capacity, proportionally to that capacity. The random
        # assignment of the rounding remainder can overshoot a cap again,
        # hence several passes.
        avail = np.array([len(species_index[s]) for s in picks], dtype=int)
        for _inner in range(5):
            over = counts - avail
            if np.all(over <= 0):
                break
            excess = int(over[over > 0].sum())
            counts = np.minimum(counts, avail)
            cap = np.maximum(avail - counts, 0)
            if cap.sum() == 0:
                # Picked species cannot hold N specimens; rejected below.
                break
            w = cap / cap.sum()
            add = np.floor(excess * w).astype(int)
            rem = excess - add.sum()
            if rem > 0:
                idxs = rng.choice(np.arange(S), size=rem, replace=True, p=w)
                # np.add.at accumulates repeated indices; `add[idxs] += 1`
                # would count a species drawn twice only once and lose grains.
                np.add.at(add, idxs, 1)
            counts += add

        if np.any(counts > avail) or counts.sum() != N:
            continue

        chosen = []
        for s, cnt in zip(picks, counts):
            if cnt > 0:
                chosen.extend(rng.choice(species_index[s], size=cnt, replace=False))
        chosen = np.asarray(chosen, dtype=int)
        if len(chosen) == N:
            return chosen

    raise RuntimeError("Could not sample valid community (availability too tight?).")

def shannon_diversity_bits_from_counts(counts):
    """Shannon diversity H = -Σ p_i log2 p_i (bits) of a vector of counts.

    Parameters
    ----------
    counts : array-like of non-negative numbers
        Per-species counts; zero entries are ignored.

    Returns
    -------
    float
        Diversity in bits (0.0 for a single-species sample), or nan when the
        total count is not positive.
    """
    counts = np.asarray(counts, dtype=float)
    n = counts.sum()
    if not n > 0:
        return np.nan
    p = counts[counts > 0] / n
    # "+ 0.0" turns the -0.0 produced by a single species (p = 1) into 0.0.
    return float(-(p * np.log2(p)).sum()) + 0.0

# Estimating entropy from KDE in PC space
def _pairwise_sq_dists(X):
    """Return the (n, n) matrix of squared Euclidean distances between rows of X.

    Uses ||a - b||² = ||a||² + ||b||² - 2·a·b, which needs only one matrix
    product. That identity cancels badly when the points sit far from the
    origin relative to their spread. The rows are therefore centred first
    (distances are translation invariant), and round-off negatives are
    clipped to 0. The diagonal is set to exactly 0.
    """
    X = np.asarray(X, float)
    if X.shape[0] > 0:
        X = X - X.mean(axis=0, keepdims=True)
    norms = np.sum(X * X, axis=1, keepdims=True)
    D2 = norms + norms.T - 2.0 * (X @ X.T)
    np.maximum(D2, 0.0, out=D2)
    np.fill_diagonal(D2, 0.0)
    return D2

def loo_mean_loglik_gaussian_iso(X, h):
    """Mean leave-one-out log-likelihood (nats) of X under an isotropic Gaussian KDE.

    For each row x_i the density is estimated from the other n-1 rows,
        p̂_{-i}(x_i) = 1/(n-1) Σ_{j≠i} (2π)^{-d/2} h^{-d} exp(-||x_i - x_j||² / (2h²)),
    and the function returns (1/n) Σ_i log p̂_{-i}(x_i). The kernel sum is
    evaluated with a max-shifted log-sum-exp for numerical stability.

    Returns -inf (which ranks last when maximising over h) if X is not 2-D,
    has fewer than two rows, or h is not a positive finite number.
    """
    pts = np.asarray(X, float)
    usable = pts.ndim == 2 and pts.shape[0] > 1 and h > 0 and np.isfinite(h)
    if not usable:
        return -np.inf

    n_pts, n_dim = pts.shape
    logits = -_pairwise_sq_dists(pts) / (2.0 * h * h)
    # Leave-one-out: a point never contributes to its own density estimate.
    np.fill_diagonal(logits, -np.inf)

    peak = logits.max(axis=1)
    log_kernel_sum = peak + np.log(np.exp(logits - peak[:, None]).sum(axis=1))
    log_norm = -np.log(n_pts - 1) - n_dim * np.log(h) - 0.5 * n_dim * np.log(2 * np.pi)
    return float((log_norm + log_kernel_sum).mean())

def find_h_star_lcv(X, h_min=1e-2, h_max=1e1, num=400, logspace=True):
    """Select the Gaussian-KDE bandwidth maximising leave-one-out likelihood.

    Scores ``loo_mean_loglik_gaussian_iso(X, h)`` on ``num`` candidate
    bandwidths between ``h_min`` and ``h_max``. The grid is log-spaced when
    ``logspace`` is True and linear otherwise. The function returns the first
    grid value attaining the maximum score.

    Returns
    -------
    float
        The selected bandwidth, or nan when no candidate yields a usable score
        (e.g. ``X`` has fewer than two rows, or ``num`` is 0). The result is
        restricted to the grid; a value equal to ``h_min`` or ``h_max``
        suggests the optimum lies outside it.
    """
    X = np.asarray(X, float)  # convert once instead of once per candidate
    hs = (np.logspace(np.log10(h_min), np.log10(h_max), num=num)
          if logspace else np.linspace(h_min, h_max, num=num))
    scores = np.array([loo_mean_loglik_gaussian_iso(X, h) for h in hs], dtype=float)
    if scores.size == 0:
        return float("nan")
    # A nan score must never be selected; rank it with the unusable ones.
    scores[np.isnan(scores)] = -np.inf
    best = int(np.argmax(scores))  # first maximum, same as a strict '>' scan
    if scores[best] == -np.inf:
        return float("nan")
    return float(hs[best])

def kde_entropy_loo_bits(X, h):
    """Leave-one-out Gaussian-KDE entropy estimate of the sample X, in bits.

    Uses the resubstitution estimator
        Ĥ = -(1/n) Σ_i log2 p̂_{-i}(x_i),
    where p̂_{-i} is an isotropic Gaussian KDE with bandwidth h built from all
    rows except x_i. That is exactly the negated leave-one-out mean
    log-likelihood, converted from nats to bits. The computation is therefore
    delegated to ``loo_mean_loglik_gaussian_iso`` rather than duplicated.

    Parameters
    ----------
    X : array-like, shape (n, d) or (n,)
        Sample points; a 1-D array is treated as n points in one dimension.
    h : float
        Kernel bandwidth.

    Returns
    -------
    float
        Entropy in bits (differential entropy, so it may be negative), or nan
        if X has fewer than two points, is not 1-D/2-D, or h is not a
        positive finite number.
    """
    X = np.asarray(X, float)
    if X.ndim == 1:
        X = X[:, None]
    if X.ndim != 2 or X.shape[0] <= 1 or not (np.isfinite(h) and h > 0):
        return np.nan
    return float(-loo_mean_loglik_gaussian_iso(X, h) / np.log(2.0))

# Running the simulation on modern specimens
def simulate_communities_on_modern(X_train, Y):
    """Project modern specimens onto leading PCs and simulate N_COMM_TRAIN communities.

    Features are standardised and reduced with PCA to at most MAX_PC_FOR_PCA
    components. The count is also limited by n_features and n_samples - 1.
    Each community is a set of N_GRAINS row indices drawn with
    ``sample_community_indices``. Its Shannon diversity (bits) is computed
    from the species counts of those rows.

    Parameters
    ----------
    X_train : array-like, tensor or DataFrame, shape (n_samples, n_features)
    Y : array-like of int, length n_samples
        Species label of each specimen.

    Returns
    -------
    Z : ndarray, shape (n_samples, n_pcs)
        PC scores of every specimen.
    comm_indices : list of ndarray
        Row indices (into Z / Y) of each simulated community.
    shannon_div_bits : ndarray of float, length N_COMM_TRAIN
    Y_modern : ndarray of int
        The labels as an integer array.
    """
    X_modern = to_numpy_array(X_train)
    Y_modern = np.asarray(Y, int)

    X_scaled = StandardScaler().fit_transform(X_modern)
    n_samples, n_features = X_scaled.shape
    n_pcs = min(MAX_PC_FOR_PCA, n_features, n_samples - 1)
    Z = PCA(n_components=n_pcs, random_state=RNG_SEED).fit_transform(X_scaled)

    sp_index = build_species_index(Y_modern)
    rng = np.random.default_rng(RNG_SEED)

    comm_indices = [
        sample_community_indices(sp_index, N_GRAINS, rng, MIN_SPECIES, MAX_SPECIES)
        for _ in range(N_COMM_TRAIN)
    ]
    shannon_div_bits = np.array(
        [
            shannon_diversity_bits_from_counts(
                np.unique(Y_modern[members], return_counts=True)[1]
            )
            for members in comm_indices
        ],
        dtype=float,
    )
    return Z, comm_indices, shannon_div_bits, Y_modern

def compute_kde_entropy_bits_per_comm_multi_h(Z, comm_indices, h_star_by_dim, scales):
    """KDE entropy (bits) for every community, PC subspace and bandwidth scale.

    For each community (row indices into Z) and each d in (1, 2, 3), the
    community's scores on the first d columns of Z are passed to
    ``kde_entropy_loo_bits``. This is done at bandwidth
    h = alpha * h_star_by_dim[d] for every alpha in ``scales``.

    A (community, d) pair is skipped when Z has fewer than d columns, the
    community has at most d points, or h_star_by_dim[d] is missing,
    non-finite or non-positive. A scale is skipped when the resulting h is
    not positive and finite.

    Returns
    -------
    pandas.DataFrame
        Columns ``community_id``, ``dim``, ``scale``, ``H_bits``. The columns
        are present even when no row qualifies, so downstream column access
        does not raise KeyError.
    """
    columns = ["community_id", "dim", "scale", "H_bits"]
    Z = np.asarray(Z)
    rows = []
    for cid, idx in enumerate(comm_indices):
        pts_all = Z[idx, :]
        for d in (1, 2, 3):
            if d > Z.shape[1]:
                continue
            pts = pts_all[:, :d]
            if pts.shape[0] <= d:
                continue
            h_star = h_star_by_dim.get(d, np.nan)
            if not (np.isfinite(h_star) and h_star > 0):
                continue
            for alpha in scales:
                h = alpha * h_star
                if not (np.isfinite(h) and h > 0):
                    continue
                rows.append(dict(community_id=cid, dim=d, scale=float(alpha),
                                 H_bits=kde_entropy_loo_bits(pts, h=h)))
    return pd.DataFrame(rows, columns=columns)

# Plotting
def _polish_axes(ax):
    """Apply the shared panel styling to a matplotlib Axes.

    Hides the top and right spines. Both axes get at most 5 major bins with
    one minor tick per major interval. Ticks are drawn inward.
    """
    for side in ("top", "right"):
        ax.spines[side].set_visible(False)

    for axis in (ax.xaxis, ax.yaxis):
        axis.set_major_locator(MaxNLocator(nbins=5))
        axis.set_minor_locator(AutoMinorLocator(2))

    major_style = dict(length=6, width=1.0, direction='in', labelsize=12)
    minor_style = dict(length=3, width=0.8, direction='in')
    ax.tick_params(axis='both', which='major', **major_style)
    ax.tick_params(axis='both', which='minor', **minor_style)

def _alpha_math_label_inline(alpha: float) -> str:
    """Return a mathtext label for a bandwidth multiplier α.

    The standard multipliers 1/2, 1/√2, 1 and √2 get exact symbolic labels
    (matched with ``np.isclose``). Any other value is printed with 3
    significant digits.
    """
    known_labels = (
        (0.5, r'$1/2$'),
        (2**(-0.5), r'$1/\sqrt{2}$'),
        (1.0, r'$1$'),
        (2**0.5, r'$\sqrt{2}$'),
    )
    for value, label in known_labels:
        if np.isclose(alpha, value):
            return label
    return rf'${alpha:.3g}$'

def _format_p(p):
    """Format a p-value for panel annotations (threshold 1e-4, else 3 decimals)."""
    if p < 1e-4:
        return "p < 0.0001"
    return f"p = {p:.3f}"

def _ols_ci_pi(x, y, xs, conf=0.95):
    """Fit y = a·x + b by ordinary least squares and compute pointwise bands.

    Parameters
    ----------
    x, y : array-like, length n
        Observations.
    xs : array-like
        Abscissae at which the fitted line and bands are evaluated.
    conf : float
        Two-sided coverage of the bands (default 0.95).

    Returns
    -------
    a, b, yhat, ci_lo, ci_hi, pi_lo, pi_hi
        Slope, intercept, and the fitted line at ``xs``. ``ci_*`` is the
        Student-t confidence band for the mean response and ``pi_*`` the
        prediction band for a new observation.

        If the slope is not identifiable (n < 2 or constant x), every output
        is nan. If n == 2, the line is returned but the bands are nan because
        there are no residual degrees of freedom.

    Raises
    ------
    ValueError
        If x and y differ in length.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    xs = np.asarray(xs, dtype=float)
    if x.size != y.size:
        raise ValueError("x and y must have the same length")
    n = x.size

    def _nan_bands(k):
        return [np.full(xs.shape, np.nan) for _ in range(k)]

    if n < 2:
        return (np.nan, np.nan, *_nan_bands(5))

    xbar, ybar = x.mean(), y.mean()
    dx = x - xbar
    Sxx = np.sum(dx ** 2)
    if not Sxx > 0:
        # Constant x: the slope is undefined (would be 0/0).
        return (np.nan, np.nan, *_nan_bands(5))

    a = np.sum(dx * (y - ybar)) / Sxx
    b = ybar - a * xbar
    yhat = a * xs + b

    dof = n - 2
    if dof < 1:
        return (a, b, yhat, *_nan_bands(4))

    resid = y - (a * x + b)
    s = np.sqrt(np.sum(resid ** 2) / dof)
    tcrit = t.ppf(1.0 - (1.0 - conf) / 2.0, df=dof)
    lever = (xs - xbar) ** 2 / Sxx

    se_mean = s * np.sqrt(1.0 / n + lever)
    se_pred = s * np.sqrt(1.0 + 1.0 / n + lever)
    return (a, b, yhat,
            yhat - tcrit * se_mean, yhat + tcrit * se_mean,
            yhat - tcrit * se_pred, yhat + tcrit * se_pred)

def _draw_fit_with_bands(ax, x, y):
    """Overlay the OLS fit, 95% CI (solid) and PI (dashed) bands, and Pearson r/p on ax."""
    xs = np.linspace(x.min(), x.max(), 200)
    _, _, yhat, ci_lo, ci_hi, pi_lo, pi_hi = _ols_ci_pi(x, y, xs, conf=0.95)

    # Draw order (PI, then CI, then fit) keeps the fit line on top.
    for band in (pi_lo, pi_hi):
        ax.plot(xs, band, color=COLOR_PI, lw=1.6, ls="--")
    for band in (ci_lo, ci_hi):
        ax.plot(xs, band, color=COLOR_CI, lw=1.6, ls="-")
    ax.plot(xs, yhat, lw=LINE_WIDTH, color=COLOR_LINE)

    r_p, p_p = pearsonr(x, y)
    ax.text(0.02, 0.98,
            f"r = {r_p:.2f}, {_format_p(p_p)}",
            transform=ax.transAxes, va="top", ha="left",
            fontsize=STATS_FONTSIZE, color=COLOR_LINE)


def plot_sensitivity_grid(df_H, shannon_div_bits, scales, output_pdf=None):
    """Plot Shannon diversity against KDE entropy for each (α, PC subspace) pair.

    Rows correspond to the bandwidth multipliers in ``scales``; columns to
    PC1, PC1–PC2 and PC1–PC2–PC3. Each panel shows one point per community.
    Panels with at least 3 finite points and non-constant entropy also get
    the OLS line, its 95% confidence and prediction bands, and Pearson r with
    its p-value. Panels without data are switched off.

    Parameters
    ----------
    df_H : pandas.DataFrame
        Columns ``community_id``, ``dim``, ``scale``, ``H_bits``. A frame
        lacking these columns is treated as containing no data.
    shannon_div_bits : array-like of float
        Shannon diversity (bits), indexed by community id.
    scales : sequence of float
        Bandwidth multipliers, one grid row each.
    output_pdf : str or None
        If given, the figure is also saved to this path as PDF.

    Returns
    -------
    matplotlib.figure.Figure
    """
    shannon_div_bits = np.asarray(shannon_div_bits, dtype=float)
    n_rows, n_cols = len(scales), 3
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(16, 4.35 * n_rows),
        dpi=180,
        sharey=True,
        squeeze=False,  # always a 2-D array of axes, even for a single row
        gridspec_kw={"wspace": 0.13, "hspace": 0.22},
    )

    name_dim = {
        1: "KDE entropy (PC1) [bits]",
        2: "KDE entropy (PC1–PC2) [bits]",
        3: "KDE entropy (PC1–PC2–PC3) [bits]",
    }
    dim_list = (1, 2, 3)

    # A result table with no qualifying rows may have no columns at all; treat
    # it as "no data" instead of raising KeyError on column access.
    has_data = {"community_id", "dim", "scale", "H_bits"}.issubset(df_H.columns)
    if has_data:
        dim_col = df_H["dim"].to_numpy()
        scale_col = df_H["scale"].to_numpy(dtype=float)

    for r, alpha in enumerate(scales):
        alpha_lbl = _alpha_math_label_inline(alpha)
        for c, d in enumerate(dim_list):
            ax = axes[r, c]
            if has_data:
                # Tolerant float match: exact equality fails when α is computed
                # differently from the stored value (1/np.sqrt(2) and 2**-0.5
                # can differ in the last bit).
                subd = df_H.loc[(dim_col == d) & np.isclose(scale_col, alpha)]
            else:
                subd = None
            if subd is None or subd.empty:
                ax.axis("off")
                continue

            x = subd["H_bits"].to_numpy(dtype=float)
            y = shannon_div_bits[subd["community_id"].to_numpy(dtype=int)]
            mask = np.isfinite(x) & np.isfinite(y)
            x, y = x[mask], y[mask]

            ax.scatter(x, y, s=POINT_SIZE, alpha=0.6, color=COLOR_POINTS)

            # The fit needs n - 2 >= 1 residual dof and a non-constant x.
            if x.size >= 3 and np.ptp(x) > 0:
                _draw_fit_with_bands(ax, x, y)

            if r == n_rows - 1:
                ax.set_xlabel(name_dim[d], fontsize=LABEL_FONTSIZE, labelpad=12)
            else:
                ax.set_xlabel("")

            if c == 0:
                ax.set_ylabel("Shannon diversity (bits)", fontsize=LABEL_FONTSIZE, labelpad=YLABEL_PAD)
                ax.text(ALPHA_X, 0.5, r"$\alpha$ = " + alpha_lbl,
                        transform=ax.transAxes, rotation=90,
                        va='center', ha='center', fontsize=ALPHA_FONTSIZE, color=COLOR_LINE,
                        clip_on=False)
            else:
                ax.set_ylabel("")

            _polish_axes(ax)
            ax.grid(True, alpha=0.45, which='major', color=GRID_COLOR)
            ax.grid(True, alpha=0.25, which='minor', color=GRID_COLOR)

    fig.subplots_adjust(bottom=0.08, left=0.12, right=0.985, top=0.985)

    if output_pdf:
        fig.savefig(output_pdf, format="pdf", bbox_inches="tight")
        print(f"\nSaved sensitivity grid to: {output_pdf}")
    plt.show()
    return fig

# Run
# Inputs (presumably defined in an earlier notebook cell — confirm):
#   X_train: feature matrix (n_samples × n_features), CNN embeddings
#            (e.g., from MIP images, patches, or both)
#   Y:       integer species labels per sample (length = n_samples)

print("Starting simulation (KDE-entropy sensitivity grid with CI/PI boundaries) ...")
print(f"Modern data shape: {X_train.shape}; unique species: {len(np.unique(Y))}")

# PC scores, simulated communities and their Shannon diversities.
Z_all, comm_indices, shannon_div_bits, Y_all = simulate_communities_on_modern(X_train, Y)

# Reference bandwidth h* for each available PC subspace, selected by
# leave-one-out likelihood cross-validation over all modern specimens.
h_star_by_dim = {}
for d in (1, 2, 3):
    if d <= Z_all.shape[1]:
        Zd = Z_all[:, :d]
        h_star_by_dim[d] = find_h_star_lcv(Zd, **HSTAR_GRID_KW)
        print(f"h* for PC{'1' if d == 1 else '1–2' if d == 2 else '1–3'}: "
              f"{h_star_by_dim[d]:.4g}")

# KDE entropy of every community at each scaled bandwidth α·h*, then the grid plot.
df_H_sens = compute_kde_entropy_bits_per_comm_multi_h(
    Z_all, comm_indices, h_star_by_dim, H_SCALES
)
plot_sensitivity_grid(
    df_H_sens, shannon_div_bits, H_SCALES, output_pdf=OUTPUT_PDF_SENS
)

print("\nDone.")