In [None]:
from pathlib import Path

import arviz as az
import h3
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Import our new modules
from texas_mushrooms.pipeline import spatial, weather
from texas_mushrooms.modeling.bayesian import BayesianMushroomModel

# Set plot style
sns.set_theme(style="whitegrid")
%matplotlib inline

# Create outputs directory
OUTPUT_DIR = Path("../data/outputs")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

## 1. Load Data and Add H3 Indices

We'll load the processed geospatial data and add H3 indices at Resolution 7 (~1.2km edge length).

In [None]:
# Load the photo table. Prefer the processed geospatial export; if it has not
# been generated yet, fall back to the raw photo listing.
data_path = Path("../data/processed/photo_geospatial.csv")
if not data_path.exists():
    print("Processed data not found. Loading raw photos...")
    data_path = Path("../data/raw/photos.csv")

df = pd.read_csv(data_path)

# Every later cell reads `df_h3` (and its latitude/longitude columns), so stop
# here with an actionable error instead of printing a warning and letting the
# next cell fail with an opaque `NameError: name 'df_h3' is not defined`.
missing_coords = [col for col in ("latitude", "longitude") if col not in df.columns]
if missing_coords:
    raise ValueError(
        f"Latitude/Longitude columns missing from {data_path}: {missing_coords}. "
        "Please run the processing pipeline first."
    )

# Tag each photo with its H3 cell at resolution 7.
df_h3 = spatial.add_h3_indices(df, resolution=7)
print(f"Added H3 indices. Unique cells: {df_h3['h3_index'].nunique()}")
display(df_h3.head())

Added H3 indices. Unique cells: 99


Unnamed: 0,date,page_url,photo_url,latitude,longitude,label_species,h3_index
0,2024-11-15,https://www.texasmushrooms.org/date-en/2024-11...,https://www.texasmushrooms.org/asergeev/pictur...,30.558571,-96.209577,Unidentified,87446d05bffffff
1,2024-11-15,https://www.texasmushrooms.org/date-en/2024-11...,https://www.texasmushrooms.org/asergeev/pictur...,30.558598,-96.209812,Unidentified,87446d05bffffff
2,2024-11-14,https://www.texasmushrooms.org/date-en/2024-11...,https://www.texasmushrooms.org/asergeev/pictur...,30.561051,-96.214087,Xylodon flaviporus,87446d05bffffff
3,2024-11-14,https://www.texasmushrooms.org/date-en/2024-11...,https://www.texasmushrooms.org/asergeev/pictur...,30.561051,-96.214087,Xylodon flaviporus,87446d05bffffff
4,2024-11-14,https://www.texasmushrooms.org/date-en/2024-11...,https://www.texasmushrooms.org/asergeev/pictur...,30.561051,-96.214087,Xylodon flaviporus,87446d05bffffff


## 2. Enrich with Elevation

We will fetch elevation data for each unique H3 cell centroid using the Open-Meteo API.

In [None]:
# Distinct, non-null H3 cells present in the photo table.
unique_h3 = df_h3["h3_index"].dropna().unique()

# One centroid per cell, split into parallel latitude / longitude lists
# (index 0 of each centroid is used as latitude, index 1 as longitude).
centroids = list(map(spatial.get_h3_centroid, unique_h3))
lats = [point[0] for point in centroids]
lons = [point[1] for point in centroids]

# Single batched elevation lookup for all cell centroids.
print(f"Fetching elevation for {len(unique_h3)} unique cells...")
elevations = weather.fetch_elevation_batch(lats, lons)

# Cell id -> elevation lookup, reused by later cells.
# assumes fetch_elevation_batch returns one value per input point, in order — TODO confirm
h3_elevation = {cell: elevation for cell, elevation in zip(unique_h3, elevations)}

# Attach the per-cell elevation to every photo row.
df_h3["elevation"] = df_h3["h3_index"].map(h3_elevation)

# Persist the enriched table.
df_h3.to_csv(OUTPUT_DIR / "h3_enriched_data.csv", index=False)
print(f"H3 enriched data saved to {OUTPUT_DIR / 'h3_enriched_data.csv'}")

display(df_h3[["h3_index", "latitude", "longitude", "elevation"]].head())

NameError: name 'df_h3' is not defined

## 3. Prepare Data for Modeling

We need to aggregate counts by H3 cell and Date to create a target variable for our Poisson model.

In [None]:
# Build the modeling table: one row per (H3 cell, date) with the number of
# photo rows observed, plus the cell's elevation.
if "date" not in df_h3.columns:
    print("Date column missing. Cannot aggregate for temporal modeling.")
else:
    daily_counts = (
        df_h3.groupby(["h3_index", "date"])
        .size()
        .reset_index(name="count")
        # Elevation is static per cell, so look it up from the cell-level mapping.
        .assign(elevation=lambda counts: counts["h3_index"].map(h3_elevation))
        # Rows whose cell has no elevation are dropped (not filled).
        .dropna(subset=["elevation"])
    )

    # Persist the aggregated table.
    daily_counts.to_csv(OUTPUT_DIR / "daily_counts_by_h3.csv", index=False)
    print(f"Daily counts saved to {OUTPUT_DIR / 'daily_counts_by_h3.csv'}")

    print("Modeling Data Prepared:")
    display(daily_counts.head())

Modeling Data Prepared:


Unnamed: 0,h3_index,date,count,elevation
0,874468428ffffff,2022-11-02,1,112.0
1,874468510ffffff,2020-11-14,6,52.0
2,874468510ffffff,2021-11-27,33,52.0
3,874468514ffffff,2013-11-10,5,40.0
4,874468514ffffff,2019-11-23,9,40.0


## 4. Bayesian Modeling (PyMC)

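We will define a Zero-Inflated Poisson (ZIP) model for $y$, the daily photo count per H3 cell:

$y \sim \text{ZIP}(\psi, \mu)$

$\log(\mu) = \alpha + \beta \cdot \text{Elevation}_{\text{std}}$

where $\text{Elevation}_{\text{std}}$ is the cell elevation, mean-centered and scaled by its standard deviation.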

(Note: In a real scenario, we would merge weather data here as well).

In [None]:
# Fit a Zero-Inflated Poisson model of daily counts on standardized elevation,
# then save the posterior summary and a diagnostic figure to OUTPUT_DIR.
# Relies on `az` (arviz) being imported in the notebook's import cell.
if "date" in df_h3.columns and "daily_counts" in locals() and not daily_counts.empty:
    # Standardize elevation *before* constructing the model, so the predictor
    # column is already present even if the model class copies the frame.
    elev_mean = daily_counts["elevation"].mean()
    elev_std = daily_counts["elevation"].std()
    if elev_std == 0 or pd.isna(elev_std):
        # A single row or constant elevation cannot be scaled; center only.
        print("Warning: Elevation standard deviation is 0 or NaN. Using raw elevation (centered).")
        daily_counts["elevation_std"] = daily_counts["elevation"] - elev_mean
    else:
        daily_counts["elevation_std"] = (daily_counts["elevation"] - elev_mean) / elev_std

    # Initialize model
    model = BayesianMushroomModel(daily_counts, target_col="count")

    print("Building model...")
    model.build_zip_model(predictors=["elevation_std"])

    # Sample
    print("Sampling...")
    # Reduced draws for demonstration speed
    # cores=1 is critical for Windows to prevent hangs
    model.sample(draws=500, tune=500, chains=2, cores=1)

    # Save trace summary to CSV
    summary = az.summary(model.trace)
    summary.to_csv(OUTPUT_DIR / "model_summary.csv")
    print(f"Model summary saved to {OUTPUT_DIR / 'model_summary.csv'}")

    # One row per parameter: posterior density (left) and per-chain trace (right).
    fig, axes = plt.subplots(3, 2, figsize=(12, 10))

    var_labels = {
        "alpha": "α (Intercept)",
        "betas": "β (Elevation Effect)",
        "psi": "ψ (Zero-Inflation Prob)",
    }

    for i, var in enumerate(["alpha", "betas", "psi"]):
        # Posterior distribution (left column)
        az.plot_posterior(
            model.trace,
            var_names=[var],
            ax=axes[i, 0],
            hdi_prob=0.94,
            point_estimate="mean",
        )
        axes[i, 0].set_title(f"{var_labels[var]} - Posterior Distribution", fontsize=11)
        axes[i, 0].set_xlabel("Parameter Value")
        axes[i, 0].set_ylabel("Density")

        # Trace plot (right column). az.plot_trace needs two axes per variable
        # (density + trace), so it cannot be pointed at a single cell of this
        # grid; draw the per-chain sample paths directly instead.
        # assumes model.trace is an arviz InferenceData with a `posterior`
        # group laid out as (chain, draw, ...) — TODO confirm
        draws = model.trace.posterior[var].values
        n_chains, n_draws = draws.shape[:2]
        # Flatten any trailing parameter dimensions (e.g. one beta per predictor).
        per_chain = draws.reshape(n_chains, n_draws, -1)
        for chain_idx in range(n_chains):
            axes[i, 1].plot(per_chain[chain_idx], alpha=0.7, linewidth=0.8)
        axes[i, 1].set_title(f"{var_labels[var]} - MCMC Trace", fontsize=11)
        axes[i, 1].set_xlabel("Sample")
        axes[i, 1].set_ylabel("Parameter Value")

    plt.suptitle("Zero-Inflated Poisson Model: Parameter Estimates\n(Mushroom Count ~ Elevation)",
                 fontsize=14, fontweight='bold', y=1.02)
    plt.tight_layout()

    # Save figure
    fig.savefig(OUTPUT_DIR / "model_trace_plot.png", dpi=150, bbox_inches="tight")
    print(f"Trace plot saved to {OUTPUT_DIR / 'model_trace_plot.png'}")

    plt.show()

    # Print interpretation
    # NOTE(review): PyMC's ZeroInflatedPoisson defines psi as the expected
    # proportion of Poisson (non-structural-zero) draws. Confirm how
    # BayesianMushroomModel parameterizes psi before relying on the wording below.
    print("\n" + "="*60)
    print("PARAMETER INTERPRETATION:")
    print("="*60)
    print("α (alpha): Baseline log-rate of mushroom counts when elevation = mean")
    print("β (betas): Change in log-rate per 1 std deviation increase in elevation")
    print("ψ (psi): Probability of 'structural zeros' (no mushrooms possible)")
    print("="*60)
    display(summary)
else:
    print("Skipping modeling due to missing data or daily_counts not defined.")

Building model...
Sampling...
Sampling...


Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, betas, psi]
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [alpha, betas, psi]


: 

: 