# Calibration Tutorial - Crane, OR - Irrigated Flux Plot

## Step 1: Uncalibrated Model Run

This tutorial focuses on calibrating SWIM-RS for a single irrigated alfalfa plot at the S2 flux station in Crane, Oregon. Unlike the unirrigated Fort Peck example, this site is actively irrigated.

This notebook demonstrates:
1. Loading pre-built model input data
2. Running the uncalibrated SWIM model
3. Comparing model output with the OpenET ensemble (PT-JPL, SIMS, SSEBop, geeSEBAL)
4. Validation against flux tower observations using multiple metrics (R², r, RMSE, bias)
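The four metrics answer different questions, and R² and r are easy to conflate. Below is a minimal plain-NumPy sketch with a hypothetical helper name, `et_metrics`. It assumes flux ET is the reference and bias is model minus observation, which matches how the comparison code later in this notebook calls `r2_score(flux, model)` and computes bias:

```python
import numpy as np


def et_metrics(obs, sim):
    """Hypothetical helper: R², Pearson r, RMSE, and bias of sim against obs."""
    obs = np.asarray(obs, dtype=float)
    sim = np.asarray(sim, dtype=float)
    resid = sim - obs                      # positive = model overestimates
    ss_res = np.sum(resid ** 2)
    ss_tot = np.sum((obs - obs.mean()) ** 2)
    return {
        "r2": 1.0 - ss_res / ss_tot,       # coefficient of determination, can be < 0
        "r": np.corrcoef(obs, sim)[0, 1],  # Pearson correlation
        "rmse": np.sqrt(np.mean(resid ** 2)),
        "bias": resid.mean(),
    }


# A perfectly correlated series with a constant +1 mm/day offset
m = et_metrics([1.0, 2.0, 3.0, 4.0], [2.0, 3.0, 4.0, 5.0])
print(m)
```

With a constant +1 mm/day offset, r stays at 1 while R² drops to 0.2. R² penalizes bias, but Pearson r only measures linear association, so a model can correlate perfectly with the tower and still be systematically high.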

**Input Data:** The model inputs are pre-computed in `data/prepped_input.json`. The file ships zipped as `data/prepped_input.zip` and is extracted in Section 1 if needed.

### Data Pipeline

The full data workflow uses two scripts and can be re-run if needed:

1. **`extract_data.py`** - Extracts raw data from Earth Engine and GridMET to CSV/parquet files
2. **`build_inputs.py`** - Processes extracted data through SwimContainer and exports to `prepped_input.json`

To reproduce the input data from scratch:

```bash
cd data
python extract_data.py    # Extract from EE/GridMET (requires authentication)
python build_inputs.py    # Build container and export JSON
```

See `data/extract_data.py` for extraction options and `data/build_inputs.py` for container workflow details.

In [1]:
import os
import sys
import time
import zipfile

import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics import mean_squared_error, r2_score

root = os.path.abspath('../..')
sys.path.append(root)

from swimrs.swim.config import ProjectConfig
from swimrs.swim.sampleplots import SamplePlots
from swimrs.model.obs_field_cycle import field_day_loop

from swimrs.viz.swim_timeseries import plot_swim_timeseries

%matplotlib inline

## 1. Project Setup

Define paths and unzip pre-built data if needed.

In [2]:
project_ws = os.path.abspath('.')
data = os.path.join(project_ws, 'data')

config_file = os.path.join(project_ws, '3_Crane.toml')
prepped_input = os.path.join(data, 'prepped_input.json')

# Unzip data files if they haven't been extracted
prepped_zip = os.path.join(data, 'prepped_input.zip')

if os.path.exists(prepped_zip) and not os.path.exists(prepped_input):
    print("Extracting prepped_input.zip...")
    with zipfile.ZipFile(prepped_zip, 'r') as z:
        z.extractall(data)

# Unzip flux tower data if needed
flux_zip = os.path.join(data, 'S2_daily_data.zip')
flux_csv = os.path.join(data, 'S2_daily_data.csv')

if os.path.exists(flux_zip) and not os.path.exists(flux_csv):
    print("Extracting S2_daily_data.zip...")
    with zipfile.ZipFile(flux_zip, 'r') as z:
        z.extractall(data)

print(f"Project workspace: {project_ws}")
print(f"Config file: {config_file}")
print(f"Input data: {prepped_input}")

Project workspace: /home/dgketchum/code/swim-rs/examples/3_Crane
Config file: /home/dgketchum/code/swim-rs/examples/3_Crane/3_Crane.toml
Input data: /home/dgketchum/code/swim-rs/examples/3_Crane/data/prepped_input.json


In [3]:
# Load the project configuration
config = ProjectConfig()
config.read_config(config_file, project_ws)

### Initial Parameter Values

The model will run with the following default parameter values and bounds:

In [4]:
from swimrs.calibrate.pest_builder import PestBuilder

def show_parameter_table(config):
    """Display parameter bounds and initial values from PestBuilder."""
    builder = PestBuilder(config)
    params = builder.initial_parameter_dict()
    
    print("=" * 80)
    print("INITIAL PARAMETER VALUES AND BOUNDS")
    print("=" * 80)
    print(f"{'Parameter':<12} {'Initial':>12} {'Lower':>10} {'Upper':>10} {'Std':>8}  Description")
    print("-" * 80)
    
    descriptions = {
        'aw': 'Available water capacity (mm)',
        'ks_alpha': 'Soil evap stress damping',
        'kr_alpha': 'Root zone stress damping', 
        'ndvi_k': 'NDVI-Kcb slope',
        'ndvi_0': 'NDVI-Kcb intercept',
        'mad': 'Management allowable depletion',
        'swe_alpha': 'Snow melt temp coefficient',
        'swe_beta': 'Snow melt rate coefficient',
    }
    
    for name, p in params.items():
        init = p['initial_value']
        if init is None:
            init_str = 'auto'
        elif isinstance(init, str):
            init_str = init[:12]
        else:
            init_str = f"{init:.2f}"
        print(f"{name:<12} {init_str:>12} {p['lower_bound']:>10.2f} {p['upper_bound']:>10.2f} {p['std']:>8.2f}  {descriptions.get(name, '')}")
    
    print("=" * 80)

show_parameter_table(config)

Using default Python script at: /home/dgketchum/code/swim-rs/src/swimrs/calibrate/custom_forward_run.py
INITIAL PARAMETER VALUES AND BOUNDS
Parameter         Initial      Lower      Upper      Std  Description
--------------------------------------------------------------------------------
aw                   auto     100.00     400.00    50.00  Available water capacity (mm)
ks_alpha             0.50       0.01       1.00     0.15  Soil evap stress damping
kr_alpha             0.50       0.01       1.00     0.15  Root zone stress damping
ndvi_k               7.00       4.00      10.00     0.75  NDVI-Kcb slope
ndvi_0               0.40       0.10       0.70     0.25  NDVI-Kcb intercept
mad                  auto       0.01       0.90     0.15  Management allowable depletion
swe_alpha            0.30      -0.50       1.00     0.20  Snow melt temp coefficient
swe_beta             1.50       0.50       2.50     0.30  Snow melt rate coefficient
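Although the table calls `ndvi_k` and `ndvi_0` a "slope" and "intercept", a literal line with these defaults (Kcb = 7 × NDVI + 0.4) would exceed 1 for almost any vegetated pixel. One reading consistent with the defaults is a logistic curve. In that curve, `ndvi_k` sets the steepness and `ndvi_0` is the NDVI at the midpoint. The sketch below assumes that form and adds a hypothetical `kc_max` cap. It is an illustration only, not the SWIM-RS source:

```python
import numpy as np


def ndvi_to_kcb(ndvi, ndvi_k=7.0, ndvi_0=0.4, kc_max=1.0):
    """Hypothetical logistic NDVI-to-Kcb curve (assumed form, not taken from SWIM-RS).

    ndvi_k: steepness of the curve
    ndvi_0: NDVI at which Kcb equals kc_max / 2
    kc_max: hypothetical upper limit on Kcb
    """
    ndvi = np.asarray(ndvi, dtype=float)
    return kc_max / (1.0 + np.exp(-ndvi_k * (ndvi - ndvi_0)))


kcb = ndvi_to_kcb([0.2, 0.4, 0.8])
print(kcb)  # Kcb rises with NDVI and equals 0.5 at NDVI = ndvi_0
```

Under this assumption, a larger `ndvi_k` sharpens the transition from sparse to full cover, and a larger `ndvi_0` shifts it toward denser canopy. Both parameters therefore control how quickly Kcb responds to alfalfa regrowth after cutting.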




## 2. About the Study Site

The S2 site is an irrigated alfalfa field in Crane, Oregon. According to IrrMapper data, this location has been irrigated since about 1996, making it a good test case for the irrigation scheduling component of SWIM-RS.
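The `mad` parameter in the table above controls when irrigation is triggered. The following is a minimal sketch of a generic FAO-56-style rule. It assumes irrigation fires when root-zone depletion exceeds `mad` × total available water (TAW) and refills the root zone to field capacity. The helper names, the fixed TAW, and the constant ET with no rain are hypothetical, and the SWIM-RS scheduler may differ in its details:

```python
def irrigation_depth(depl_root, taw, mad):
    """Hypothetical trigger: refill to field capacity once depletion exceeds mad * taw (mm)."""
    return depl_root if depl_root > mad * taw else 0.0


def simulate_irrigation(et_daily, taw, mad, depl0=0.0):
    """Track root-zone depletion (mm) under daily ET with no rain; return irrigation events."""
    depl = depl0
    events = []
    for day, et in enumerate(et_daily):
        depl += et                        # ET increases depletion
        irr = irrigation_depth(depl, taw, mad)
        if irr > 0:
            events.append((day, irr))
            depl -= irr                   # irrigation brings depletion back to zero
    return events


# 30 days at 6 mm/day, TAW = 150 mm, MAD = 0.5 -> trigger above 75 mm depletion
events = simulate_irrigation([6.0] * 30, taw=150.0, mad=0.5)
print(events)  # [(12, 78.0), (25, 78.0)]
```

Under this rule, a larger `mad` lets depletion grow further before each event, which produces fewer but larger irrigations. This makes `mad` one of the levers calibration can adjust when simulated irrigation is too low.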

In [5]:
selected_feature = 'S2'

print(f"Site: {selected_feature}")
print(f"Location: Crane, Oregon")
print(f"Crop: Irrigated alfalfa")
print(f"Date range: {config.start_dt} to {config.end_dt}")

Site: S2
Location: Crane, Oregon
Crop: Irrigated alfalfa
Date range: 1987-01-01 00:00:00 to 2022-12-31 00:00:00


## 3. Run the Uncalibrated Model

We define a helper function to run the SWIM model and capture its output.

In [6]:
def run_fields(ini_path, project_ws, selected_feature, output_csv, forecast=False):
    """Run SWIM model and save combined input/output to CSV."""
    start_time = time.time()

    config = ProjectConfig()
    config.read_config(ini_path, project_ws, forecast=forecast)

    fields = SamplePlots()
    fields.initialize_plot_data(config)
    fields.output = field_day_loop(config, fields, debug_flag=True)

    end_time = time.time()
    print(f'\nExecution time: {end_time - start_time:.2f} seconds\n')

    out_df = fields.output[selected_feature].copy()
    in_df = fields.input_to_dataframe(selected_feature)
    
    # Drop columns from input that already exist in output to avoid duplicates
    overlap_cols = out_df.columns.intersection(in_df.columns)
    if len(overlap_cols) > 0:
        in_df = in_df.drop(columns=overlap_cols)
    
    df = pd.concat([out_df, in_df], axis=1, ignore_index=False)
    
    # Cut out nan output from before the start of the model run
    df = df.loc[config.start_dt:config.end_dt]
    
    df.to_csv(output_csv)
    return df
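Why drop the overlapping columns? `pd.concat(..., axis=1)` keeps both copies of a shared column name. After that, selecting the column returns a two-column DataFrame instead of a Series. Here is a minimal illustration with toy frames; the column names are only examples:

```python
import pandas as pd

idx = pd.date_range("2004-06-01", periods=3, freq="D")
out_df = pd.DataFrame({"et_act": [5.0, 5.5, 6.0], "etref": [7.0, 7.2, 7.4]}, index=idx)
in_df = pd.DataFrame({"etref": [7.0, 7.2, 7.4], "ndvi": [0.70, 0.72, 0.75]}, index=idx)

# Naive concat: 'etref' appears twice
naive = pd.concat([out_df, in_df], axis=1)
print(naive.columns.tolist())  # ['et_act', 'etref', 'etref', 'ndvi']

# Same approach as run_fields: keep the model output's copy, drop the input's
overlap = out_df.columns.intersection(in_df.columns)
merged = pd.concat([out_df, in_df.drop(columns=overlap)], axis=1)
print(merged.columns.tolist())  # ['et_act', 'etref', 'ndvi']
```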

In [7]:
selected_feature = 'S2'
out_csv = os.path.join(project_ws, f'combined_output_{selected_feature}_uncalibrated.csv')

df = run_fields(config_file, project_ws, selected_feature=selected_feature, output_csv=out_csv)

USING PARAMETER DEFAULTS

Execution time: 1.79 seconds



In [8]:
print(f"Output shape: {df.shape}")
print(f"Date range: {df.index[0]} to {df.index[-1]}")
print(f"\nKey output columns:")
key_cols = ['et_act', 'etref', 'kc_act', 'kc_bas', 'ks', 'ke', 'melt', 'rain', 
            'depl_root', 'swe', 'ppt', 'irrigation', 'soil_water']
for col in key_cols:
    if col in df.columns:
        print(f"  {col}: mean={df[col].mean():.3f}, max={df[col].max():.3f}")

Output shape: (1826, 73)
Date range: 2003-01-01 00:00:00 to 2007-12-31 00:00:00

Key output columns:
  et_act: mean=1.653, max=6.797
  etref: mean=2.884, max=8.528
  kc_act: mean=0.589, max=0.963
  kc_bas: mean=0.426, max=0.967
  ks: mean=0.989, max=1.000
  ke: mean=0.337, max=0.698
  melt: mean=0.140, max=9.366
  rain: mean=0.586, max=25.900
  depl_root: mean=37.361, max=157.927
  swe: mean=1.541, max=59.716
  ppt: mean=0.736, max=25.900
  irrigation: mean=1.017, max=25.400
  soil_water: mean=252.309, max=336.000


## 4. Visualize Model Output

Let's examine a single year (2004) to see the model's behavior.

In [9]:
ydf = df.loc['2004-01-01': '2004-12-31']
print(f'Total irrigation: {ydf.irrigation.sum():.1f} mm')
print(f'Total ET: {ydf.et_act.sum():.1f} mm')
print(f'Total precipitation: {ydf.ppt.sum():.1f} mm')

plot_swim_timeseries(ydf, ['et_act', 'etref', 'rain', 'melt', 'irrigation'], 
                     start='2004-01-01', end='2004-12-31', png_dir='et_uncalibrated.png')

Total irrigation: 321.2 mm
Total ET: 553.4 mm
Total precipitation: 267.7 mm
et_uncalibrated.png
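The 2004 totals support a rough water-balance check. Precipitation plus irrigation (267.7 + 321.2 = 588.9 mm) exceeds simulated ET (553.4 mm) by 35.5 mm. In the balance, that residual goes to storage change (soil water and snowpack), runoff, or deep percolation. The short sketch below repeats the check for every year. It assumes a frame with a DatetimeIndex and the `ppt`, `irrigation`, and `et_act` columns printed earlier, and the helper name is hypothetical:

```python
import pandas as pd


def annual_balance(df):
    """Annual sums of ppt, irrigation, and et_act (mm) plus the inputs-minus-ET residual."""
    totals = df[["ppt", "irrigation", "et_act"]].groupby(df.index.year).sum()
    totals["residual"] = totals["ppt"] + totals["irrigation"] - totals["et_act"]
    return totals.round(1)


# Toy check on synthetic data; in the notebook, call annual_balance(df)
idx = pd.date_range("2003-01-01", "2004-12-31", freq="D")
toy = pd.DataFrame({"ppt": 1.0, "irrigation": 1.0, "et_act": 1.5}, index=idx)
print(annual_balance(toy))
```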


## 5. Compare with Flux Tower Observations

We compare three estimates of actual ET (mm/day):

1. **SWIM ET**: Model-estimated actual evapotranspiration (daily)
2. **OpenET Ensemble ET**: Remote sensing retrievals from OpenET (PT-JPL, SIMS, SSEBop, geeSEBAL), computed as the ensemble-mean ETf multiplied by reference ET (ETf × ETo)
3. **Flux ET**: Independent observations from the S2 eddy covariance tower (Volk et al.)

We show two comparisons:
- **Capture dates only**: SWIM and OpenET compared only on dates with an OpenET ETf retrieval (Landsat overpasses)
- **Full time series**: SWIM (daily) vs OpenET (interpolated between Landsat dates) on all flux tower days
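The two modes differ only in how OpenET is filled between overpasses. The minimal sketch below uses toy data and assumes a daily ETf series that is NaN except on overpass dates. Multiplying ETf by ETo gives sparse ET, and time-weighted interpolation with `limit_area='inside'` fills only the gaps between retrievals. Without `limit_area`, pandas' default interpolation would also carry the last value forward past the final overpass:

```python
import numpy as np
import pandas as pd


def openet_daily_et(etf, etref):
    """Sparse ET (ETf x ETo on overpass days) and a daily series interpolated between overpasses."""
    et_sparse = etf * etref
    # 'time' weights by date spacing; 'inside' leaves days outside the first/last overpass as NaN
    et_interp = et_sparse.interpolate(method="time", limit_area="inside")
    return et_sparse, et_interp


idx = pd.date_range("2004-07-01", periods=6, freq="D")
etf = pd.Series([np.nan, 0.5, np.nan, np.nan, 0.8, np.nan], index=idx)  # two overpasses
etref = pd.Series(5.0, index=idx)                                      # mm/day
sparse, daily = openet_daily_et(etf, etref)
print(pd.DataFrame({"sparse": sparse, "daily": daily}))
```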

In [10]:
def compare_et_estimates(combined_output_path, flux_data_path, irr=True):
    """Compare model ET and OpenET ensemble ET against flux tower observations.
    
    Returns two comparison DataFrames:
    1. Capture dates only: Both methods on Landsat overpass dates only
    2. Full time series: SWIM daily, OpenET interpolated, on all flux tower days
    
    Reports R², Pearson r, bias, and RMSE for each comparison.
    """
    flux_data = pd.read_csv(flux_data_path, index_col='date', parse_dates=True)
    flux_et = flux_data['ET']  # Actual ET from flux tower (mm/day)

    output = pd.read_csv(combined_output_path, index_col=0)
    output.index = pd.to_datetime(output.index)

    # Determine suffix based on irrigation mask
    mask_suffix = 'irr' if irr else 'inv_irr'
    
    # OpenET ensemble models - compute mean ETf across available models
    ensemble_models = ['ptjpl', 'sims', 'ssebop', 'geesebal']
    etf_cols = []
    for model in ensemble_models:
        col_name = f'etf_{model}_{mask_suffix}'
        if col_name in output.columns:
            etf_cols.append(col_name)
    
    # Fallback to single SSEBop if ensemble columns not available
    if not etf_cols:
        etf_col = f'etf_{mask_suffix}'
        if etf_col in output.columns:
            etf_cols = [etf_col]
            print(f"Using single ETf column: {etf_col}")
        else:
            print(f"Warning: No ETf columns found for mask '{mask_suffix}'")
            return pd.DataFrame(), pd.DataFrame()
    
    print(f"Using ETf columns: {etf_cols}")
    
    # Compute ensemble mean ETf (ignoring NaN)
    ensemble_etf = output[etf_cols].mean(axis=1, skipna=True)
    
    # Calculate actual ET from OpenET ensemble: ETf × ETo (sparse, only on Landsat dates)
    openet_et_sparse = ensemble_etf * output['etref']
    
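    # Interpolate OpenET between overpass dates to get daily values; limit_area='inside'
    # leaves days before the first and after the last overpass as NaN instead of extrapolating
    openet_et_interp = openet_et_sparse.interpolate(method='time', limit_area='inside')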
    
    # Count original OpenET observations
    n_openet_obs = openet_et_sparse.notna().sum()

    # CAPTURE DATES ONLY comparison (OpenET sparse)
    capture_df = pd.DataFrame({
        'swim_et': output['et_act'],
        'openet_et': openet_et_sparse,
        'flux_et': flux_et
    }).dropna()

    # FULL TIME SERIES comparison (OpenET interpolated)
    full_df = pd.DataFrame({
        'swim_et': output['et_act'],
        'openet_et': openet_et_interp,
        'flux_et': flux_et
    }).dropna()

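    def calc_metrics(df, col1, col2):
        """R², Pearson r, RMSE, and bias (col2 - col1); NaN when fewer than 2 paired days."""
        if len(df) < 2:
            # scipy.stats.pearsonr raises ValueError for fewer than 2 samples
            return np.nan, np.nan, np.nan, np.nan
        r, _ = stats.pearsonr(df[col1], df[col2])
        r2 = r2_score(df[col1], df[col2])
        rmse = np.sqrt(mean_squared_error(df[col1], df[col2]))
        bias = (df[col2] - df[col1]).mean()
        return r2, r, rmse, bias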

    # Capture dates metrics
    r2_swim_cap, r_swim_cap, rmse_swim_cap, bias_swim_cap = calc_metrics(capture_df, 'flux_et', 'swim_et')
    r2_openet_cap, r_openet_cap, rmse_openet_cap, bias_openet_cap = calc_metrics(capture_df, 'flux_et', 'openet_et')

    # Full time series metrics
    r2_swim_full, r_swim_full, rmse_swim_full, bias_swim_full = calc_metrics(full_df, 'flux_et', 'swim_et')
    r2_openet_full, r_openet_full, rmse_openet_full, bias_openet_full = calc_metrics(full_df, 'flux_et', 'openet_et')

    print("="*70)
    print(f"CAPTURE DATES ONLY ({len(capture_df)} Landsat overpass dates)")
    print("="*70)
    print(f"{'Metric':<12} {'SWIM ET':>12} {'OpenET ET':>12}")
    print("-" * 38)
    print(f"{'R²':<12} {r2_swim_cap:>12.3f} {r2_openet_cap:>12.3f}")
    print(f"{'Pearson r':<12} {r_swim_cap:>12.3f} {r_openet_cap:>12.3f}")
    print(f"{'Bias (mm/d)':<12} {bias_swim_cap:>12.3f} {bias_openet_cap:>12.3f}")
    print(f"{'RMSE (mm/d)':<12} {rmse_swim_cap:>12.3f} {rmse_openet_cap:>12.3f}")
    
    print()
    print("="*70)
    print(f"FULL TIME SERIES ({len(full_df)} days, OpenET interpolated from {n_openet_obs} obs)")
    print("="*70)
    print(f"{'Metric':<12} {'SWIM ET':>12} {'OpenET ET':>12}")
    print("-" * 38)
    print(f"{'R²':<12} {r2_swim_full:>12.3f} {r2_openet_full:>12.3f}")
    print(f"{'Pearson r':<12} {r_swim_full:>12.3f} {r_openet_full:>12.3f}")
    print(f"{'Bias (mm/d)':<12} {bias_swim_full:>12.3f} {bias_openet_full:>12.3f}")
    print(f"{'RMSE (mm/d)':<12} {rmse_swim_full:>12.3f} {rmse_openet_full:>12.3f}")

    return full_df, capture_df

In [11]:
# Use irrigated mask since this is an irrigated site
flux_data = os.path.join(data, 'S2_daily_data.csv')
full_df, capture_df = compare_et_estimates(out_csv, flux_data, irr=True)

Using single ETf column: etf_irr
Using ETf columns: ['etf_irr']


ValueError: `x` and `y` must have length at least 2.
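This `ValueError` comes from `scipy.stats.pearsonr`, which needs at least two paired values. After `dropna()`, fewer than two days had SWIM, OpenET, and flux ET all present. A quick diagnostic is to check which dates the series share, sketched here with a hypothetical helper `common_dates`. One possibility worth checking (not confirmed here) is that the 2003 to 2007 model period does not overlap the flux tower record. The toy dates below are invented for illustration:

```python
import pandas as pd


def common_dates(*series):
    """Return (first, last, count) of dates on which every series has a non-NaN value."""
    common = series[0].dropna().index
    for s in series[1:]:
        common = common.intersection(s.dropna().index)
    if len(common) == 0:
        return None, None, 0
    return common.min(), common.max(), len(common)


# Toy example: a model run and a flux record that do not overlap
model = pd.Series(1.0, index=pd.date_range("2003-01-01", "2007-12-31", freq="D"))
flux = pd.Series(2.0, index=pd.date_range("2013-01-01", "2014-12-31", freq="D"))
print(common_dates(model, flux))             # (None, None, 0)
print(common_dates(model, model.iloc[:10]))  # 10 shared days starting 2003-01-01
```

In this notebook, `common_dates(df['et_act'], pd.read_csv(flux_data, index_col='date', parse_dates=True)['ET'])` would report whether the model output and the tower data share any days.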

In [None]:
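import matplotlib.pyplot as plt

# Each scatter panel and its metrics need at least two paired days
if len(capture_df) < 2 or len(full_df) < 2:
    raise ValueError(
        f"Too few overlapping days to plot (capture dates: {len(capture_df)}, full series: {len(full_df)}); "
        "check that the model run period overlaps the flux tower record."
    )

# Create 2x2 scatter plots for both comparisons
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Helper function to calculate R², Pearson r, and RMSE (flux ET as reference)
def calc_metrics(df, col1, col2):
    r, _ = stats.pearsonr(df[col1], df[col2])
    r2 = r2_score(df[col1], df[col2])
    rmse = np.sqrt(mean_squared_error(df[col1], df[col2]))
    return r2, r, rmse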

# Determine axis limits
max_et = max(full_df['flux_et'].max(), full_df['swim_et'].max(), 
             full_df['openet_et'].max()) * 1.1

# TOP ROW: Capture dates only
r2_swim_cap, r_swim_cap, rmse_swim_cap = calc_metrics(capture_df, 'flux_et', 'swim_et')
r2_openet_cap, r_openet_cap, rmse_openet_cap = calc_metrics(capture_df, 'flux_et', 'openet_et')

ax = axes[0, 0]
ax.scatter(capture_df['flux_et'], capture_df['swim_et'], alpha=0.5, s=15)
ax.plot([0, max_et], [0, max_et], 'r--', label='1:1 line')
ax.set_xlabel('Flux ET (mm/day)')
ax.set_ylabel('SWIM ET (mm/day)')
ax.set_title(f'SWIM vs Flux - Capture Dates (n={len(capture_df)})\n'
             f'R² = {r2_swim_cap:.3f}, r = {r_swim_cap:.3f}, RMSE = {rmse_swim_cap:.2f} mm/day')
ax.legend()
ax.set_xlim(0, max_et)
ax.set_ylim(0, max_et)

ax = axes[0, 1]
ax.scatter(capture_df['flux_et'], capture_df['openet_et'], alpha=0.5, s=15)
ax.plot([0, max_et], [0, max_et], 'r--', label='1:1 line')
ax.set_xlabel('Flux ET (mm/day)')
ax.set_ylabel('OpenET Ensemble ET (mm/day)')
ax.set_title(f'OpenET vs Flux - Capture Dates (n={len(capture_df)})\n'
             f'R² = {r2_openet_cap:.3f}, r = {r_openet_cap:.3f}, RMSE = {rmse_openet_cap:.2f} mm/day')
ax.legend()
ax.set_xlim(0, max_et)
ax.set_ylim(0, max_et)

# BOTTOM ROW: Full time series comparison
r2_swim, r_swim, rmse_swim = calc_metrics(full_df, 'flux_et', 'swim_et')
r2_openet, r_openet, rmse_openet = calc_metrics(full_df, 'flux_et', 'openet_et')

ax = axes[1, 0]
ax.scatter(full_df['flux_et'], full_df['swim_et'], alpha=0.3, s=10)
ax.plot([0, max_et], [0, max_et], 'r--', label='1:1 line')
ax.set_xlabel('Flux ET (mm/day)')
ax.set_ylabel('SWIM ET (mm/day)')
ax.set_title(f'SWIM vs Flux - Full Series (n={len(full_df)})\n'
             f'R² = {r2_swim:.3f}, r = {r_swim:.3f}, RMSE = {rmse_swim:.2f} mm/day')
ax.legend()
ax.set_xlim(0, max_et)
ax.set_ylim(0, max_et)

ax = axes[1, 1]
ax.scatter(full_df['flux_et'], full_df['openet_et'], alpha=0.3, s=10)
ax.plot([0, max_et], [0, max_et], 'r--', label='1:1 line')
ax.set_xlabel('Flux ET (mm/day)')
ax.set_ylabel('OpenET Ensemble ET (mm/day)')
ax.set_title(f'OpenET vs Flux - Full Series, interpolated (n={len(full_df)})\n'
             f'R² = {r2_openet:.3f}, r = {r_openet:.3f}, RMSE = {rmse_openet:.2f} mm/day')
ax.legend()
ax.set_xlim(0, max_et)
ax.set_ylim(0, max_et)

plt.tight_layout()
plt.savefig('comparison_scatter_uncalibrated.png', dpi=150)
plt.show()

## Summary

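The uncalibrated model, run with default parameters, establishes baseline performance before calibration. We compared both SWIM and OpenET against independent flux tower observations. OpenET ET is the mean of the PT-JPL, SIMS, SSEBop, and geeSEBAL ETf columns when they are present in the model output; otherwise the single `etf_irr` column is used, as happened in this run.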

**Two comparison modes:**
- **Capture dates**: Only Landsat overpass dates where we have satellite observations
- **Full time series**: All flux tower days, with OpenET values interpolated between satellite dates

**Key observations:**
- The model appears to apply too little irrigation
- The NDVI-to-Kcb relationship (`ndvi_k`, `ndvi_0`) needs tuning for alfalfa
- Soil parameters (e.g., `aw`) may not match actual site conditions
- The OpenET ensemble provides a robust remote sensing benchmark

**Next step:** In notebook `02_calibration.ipynb`, we'll use PEST++ to calibrate the model parameters using SSEBop ETf and SNODAS SWE observations.

**Key insight:** We're not using the flux data for calibration; it's only for validation. For calibration, we rely solely on widely available remote sensing data (ETf and SNODAS SWE).

## Optional: Query Data from SwimContainer

If you've built the container using `build_inputs.py`, you can query ingested data directly:

In [None]:
# Query container data (optional - requires build_inputs.py to have been run)
from swimrs.container import SwimContainer

container_path = os.path.join(data, '3_Crane.swim')

if os.path.exists(container_path):
    container = SwimContainer.open(container_path, mode='r')
    
    # List available fields
    print(f"Fields in container: {container.field_uids}")
    
    # Get all time series for a single field using field_timeseries
    ts_df = container.query.field_timeseries('S2')
    print(f"\nTime series shape: {ts_df.shape}")
    print(f"Variables: {list(ts_df.columns)[:10]}...")
    
    # Query specific data using dataframe with zarr paths
    # Path structure: remote_sensing/{type}/{instrument}/[{model}/]{mask}
    # ({model} appears in the ETf path below but not in the NDVI path)
    ndvi_df = container.query.dataframe("remote_sensing/ndvi/landsat/irr", fields=['S2'])
    print(f"\nNDVI observations: {ndvi_df.notna().sum().values[0]}")
    
    etf_df = container.query.dataframe("remote_sensing/etf/landsat/ssebop/irr", fields=['S2'])
    print(f"ETf observations: {etf_df.notna().sum().values[0]}")
    
    # Show container status
    print("\n" + container.query.status())
    
    container.close()
else:
    print(f"Container not found at {container_path}")
    print("Run: cd data && python build_inputs.py")
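The two queries above follow the stated path layout: ETf paths include a model segment and NDVI paths omit it. The small hypothetical helper below builds such paths, for example to loop over the ensemble models. Whether a path exists for each model depends on what `build_inputs.py` ingested into the container, which is an assumption here:

```python
def rs_path(kind, instrument, mask, model=None):
    """Hypothetical helper: build 'remote_sensing/{type}/{instrument}/[{model}/]{mask}'."""
    parts = ["remote_sensing", kind, instrument]
    if model is not None:
        parts.append(model)  # e.g. an ETf model name; omitted for NDVI
    parts.append(mask)
    return "/".join(parts)


print(rs_path("ndvi", "landsat", "irr"))
for m in ["ptjpl", "sims", "ssebop", "geesebal"]:
    print(rs_path("etf", "landsat", "irr", model=m))
```

Each path could then be passed to `container.query.dataframe(path, fields=['S2'])`, as in the cell above.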