# Generic Scaling-up Framework

**Generic framework with BASE_SCALE concept**

Framework for scaling up from any base scale unit to different spatial boundaries using area weighting.

## Key Concepts
- **BASE_SCALE**: The atomic unit (finest resolution) that all other scales aggregate from
- **AGGREGATION_SCALES**: Larger spatial boundaries that contain multiple base scale units
- **NULL AGGREGATOR**: BASE_SCALE acts as identity aggregator (no spatial aggregation)

In [None]:
# Standard library imports
import os
import sys
from pathlib import Path

# Third-party imports
import pandas as pd
import numpy as np
import geopandas as gpd
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Local imports
from scaling_up_framework_functions import (
    pivot_year, aggregate_measure_weighted, standardise_z, normalise,
    SpatialScale, validate_geometries, assign_functional_group, validate_csv_data,
    plot_base_scale, plot_spatial_hierarchy
)
from config import BASE_SCALE, GROUP_RULES, INCLUDE_UNMATCHED_TYPES, SPATIAL_FILES, DATA_PATH, OUTPUT_PATH, DEFAULT_CRS

## Configuration and Setup

In [None]:
# Create the output folder (and any missing parent folders) before anything
# tries to write results; an already-existing folder is not an error.
OUTPUT_PATH.mkdir(exist_ok=True, parents=True)

# Notebook-level state shared by the cells below.
base_scale = None  # replaced by the loaded base scale (atomic units)
aggregation_scales = dict()  # scale name -> base units joined to that boundary layer
data_cache = dict()  # metric name -> already-loaded metric DataFrame

## Load Base Scale Data (Atomic Units)

**Grouping Behavior:**
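- `type_field = None`: all features are aggregated together under a single group (`all_features`)
- `type_field` defined and functional groups assigned via `GROUP_RULES`: the functional groups are used
- `type_field` defined but no functional group assigned to any feature (e.g. no `GROUP_RULES` for this scale): the original type values are used as groups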

In [None]:
def load_base_scale() -> SpatialScale:
    """Load the base scale - the atomic unit for all aggregations.

    Builds a ``SpatialScale`` from the ``BASE_SCALE`` configuration, passes its
    data through ``validate_geometries`` and adds a ``grp`` column holding the
    group label of each feature:

    * ``type_field`` missing or ``None``: every feature is labelled
      ``"all_features"``.
    * ``type_field`` set: ``assign_functional_group`` is applied to each type
      value; if that leaves every feature without a group, the original type
      values are used as labels instead.

    Returns:
        The loaded ``SpatialScale`` with ``data["grp"]`` populated.

    Raises:
        KeyError: If a required ``BASE_SCALE`` key (``name``, ``file``,
            ``unique_id``, ``measure_field``) is missing.
    """
    # "type_field" is optional: read it once with .get() so that a config
    # without the key reaches the "all features together" branch below
    # instead of raising KeyError while the SpatialScale is being built.
    type_field = BASE_SCALE.get("type_field")
    scale_name = BASE_SCALE["name"]

    base = SpatialScale(
        name=scale_name,
        source=BASE_SCALE["file"],
        unique_id_field=BASE_SCALE["unique_id"],
        measure_field=BASE_SCALE["measure_field"],
        type_field=type_field,
        aggregation_method=BASE_SCALE.get("aggregation_method", "geometry"),
        is_base_scale=True
    )

    # Validate geometries
    base.data = validate_geometries(base.data)

    # Without a type field there is nothing to sub-group by.
    if type_field is None:
        print("WARNING: type_field not defined - all features will be aggregated together")
        print("Result: Single output per spatial scale (no sub-grouping)")
        base.data["grp"] = "all_features"
        return base

    # Map every raw type value to its functional group.
    base.data["grp"] = base.data[type_field].apply(
        lambda x: assign_functional_group(x, scale_name, GROUP_RULES, INCLUDE_UNMATCHED_TYPES)
    )

    # Fall back to the raw type values only when no feature received a group.
    if base.data["grp"].isna().all():
        print(f"No functional groups defined for {scale_name} - using original type field")
        base.data["grp"] = base.data[type_field]
    else:
        print(f"Using functional groups for {scale_name}")

    return base

# Load base scale data into the notebook-level `base_scale` variable, which
# create_aggregation_scale and the metric-loading code read directly.
base_scale = load_base_scale()
print(f"BASE_SCALE: {base_scale}")

# Plot base scale
plot_base_scale(base_scale, BASE_SCALE)

## Load Aggregation Scale Boundaries

In [None]:
def create_aggregation_scale(name: str, source_key: str, unique_field: str) -> pd.DataFrame:
    """Map base scale units onto one aggregation boundary layer.

    Loads the boundary layer as a ``SpatialScale`` and left-joins it onto the
    notebook-level ``base_scale`` data wherever the geometries intersect.

    Args:
        name: Name given to the aggregation scale.
        source_key: Key into ``SPATIAL_FILES`` identifying the boundary file.
        unique_field: Identifier field of the boundary layer. Callers in this
            file also pass a list of field names here.

    Returns:
        The joined rows that contain no missing values, indexed by the base
        scale's unique id field.
    """
    boundaries = SpatialScale(name, SPATIAL_FILES[source_key], unique_field)

    # Left join: every base unit is kept at first, paired with each boundary
    # feature it intersects.
    matched = gpd.sjoin(
        base_scale.data,
        boundaries.data,
        how="left",
        predicate="intersects",
    )

    # NOTE(review): dropna() without a subset drops rows with a missing value
    # in ANY column, not only base units without a boundary match - confirm
    # that this is intended.
    complete_rows = matched.dropna()

    return complete_rows.set_index(BASE_SCALE["unique_id"])

# Aggregation scales: scale name -> key into SPATIAL_FILES and the field(s)
# that identify a feature within that boundary layer.
AGGREGATION_CONFIGS = {
    "Valley": {"source_key": "Valley", "unique_field": "BWS_Region"},
    "DIWA": {"source_key": "DIWA", "unique_field": "WNAME"},
    "Ramsar": {"source_key": "Ramsar", "unique_field": ["RAMSAR_NAM", "WETLAND_NA"]},
    "NorthSouthBasin": {"source_key": "NorthSouthBasin", "unique_field": "Region"},
}

# Build, for every configured scale, both the base-unit join (used for
# aggregation) and the boundary layer object itself (used for plotting).
# NOTE(review): the boundary layer is constructed twice per scale - once
# inside create_aggregation_scale and once here.
agg_scale_objects = {}
for scale_name, config in AGGREGATION_CONFIGS.items():
    source_key = config["source_key"]
    unique_field = config["unique_field"]

    joined_units = create_aggregation_scale(scale_name, source_key, unique_field)
    aggregation_scales[scale_name] = joined_units

    agg_scale_objects[scale_name] = SpatialScale(
        scale_name, SPATIAL_FILES[source_key], unique_field
    )
    print(f"AGGREGATION_SCALE {scale_name}: {len(aggregation_scales[scale_name])} intersections")

# Plot all scales together
plot_spatial_hierarchy(base_scale, BASE_SCALE, agg_scale_objects)

## Process Metric Data

In [None]:
def load_metric_data(metric_name: str) -> pd.DataFrame:
    """Load metric data from CSV file or spatial data attribute.

    Sources are tried in this order:

    1. ``data_cache`` - a metric that was loaded before is returned as the
       same DataFrame object.
    2. A CSV file in ``DATA_PATH`` whose name matches ``*{metric_name}*.csv``,
       passed through ``validate_csv_data``. When several files match, the
       first one in sorted order is used and a warning is printed.
    3. A column called ``metric_name`` on the base scale's spatial data.

    The metric column is passed through ``normalise`` before the result is
    cached and returned.

    Args:
        metric_name: Name of the metric; used as file name fragment, as
            column name and as cache key.

    Returns:
        DataFrame containing the normalised metric column.

    Raises:
        FileNotFoundError: If no CSV matches and the base scale data has no
            column named ``metric_name``.
    """
    if metric_name in data_cache:
        return data_cache[metric_name]

    print(f"Loading {metric_name} data...")

    # First try to find CSV file. glob() yields matches in a
    # filesystem-dependent order, so sort them to make the choice of file
    # reproducible between runs and machines.
    csv_files = sorted(DATA_PATH.glob(f"*{metric_name}*.csv"))

    if csv_files:
        # Load from CSV file
        csv_file = csv_files[0]
        if len(csv_files) > 1:
            print(f"Warning: {len(csv_files)} CSV files match '{metric_name}'; using {csv_file.name}")
        print(f"Found CSV: {csv_file.name}")
        data = pd.read_csv(csv_file)
        data = validate_csv_data(data, metric_name, BASE_SCALE["unique_id"])
    else:
        # Try to load from spatial data attribute
        print(f"No CSV found, checking if '{metric_name}' exists in spatial data...")
        if metric_name not in base_scale.data.columns:
            raise FileNotFoundError(f"Metric '{metric_name}' not found in CSV files or spatial data columns.\n"
                                  f"Available spatial columns: {list(base_scale.data.columns)}")

        print(f"Using '{metric_name}' from spatial data")
        # Keep only the id and metric columns; copy so the spatial data
        # itself is not modified by the conversions below.
        data = base_scale.data[[BASE_SCALE["unique_id"], metric_name]].copy()
        data = data.reset_index(drop=True)

        # Validate metric column is numeric; unparseable values become NaN
        # and the affected rows are dropped.
        if not pd.api.types.is_numeric_dtype(data[metric_name]):
            data[metric_name] = pd.to_numeric(data[metric_name], errors='coerce')
            nan_count = data[metric_name].isna().sum()
            if nan_count > 0:
                print(f"Warning: {nan_count} non-numeric values converted to NaN")
                data = data.dropna(subset=[metric_name])

    # Normalize values
    data[metric_name] = normalise(data[metric_name])

    data_cache[metric_name] = data
    return data

# Load NDVI data (from a matching CSV in DATA_PATH, else from an "NDVI"
# column on the base scale data); the result is also kept in data_cache.
ndvi_data = load_metric_data("NDVI")
print(f"Loaded NDVI data: {len(ndvi_data)} records")

## Analysis and Output Generation

In [None]:
def process_all_scales(data: pd.DataFrame, metric_name: str) -> None:
    """Aggregate one metric at every spatial scale and write one CSV per scale.

    The metric is pivoted once with ``pivot_year``. It is then aggregated for
    the base scale (no aggregator, ``is_base_scale=True``) and for each entry
    of ``aggregation_scales``. Each result is rounded to 4 decimals and saved
    to ``OUTPUT_PATH`` as ``<metric_name>_<scale name>.csv``.

    Args:
        data: Metric records to process.
        metric_name: Name of the metric; also used in the output file names.
    """

    def _export(result, scale_label: str) -> None:
        # Round, write to <metric>_<scale>.csv and report the file written.
        destination = OUTPUT_PATH / f"{metric_name}_{scale_label}.csv"
        result.round(4).to_csv(destination)
        print(f"Saved: {destination}")

    pivoted = pivot_year(data, metric_name, BASE_SCALE["unique_id"])

    # Base scale first: identity case, no spatial aggregator involved.
    print(f"Processing {metric_name} for BASE_SCALE ({BASE_SCALE['name']})...")
    _export(
        aggregate_measure_weighted(
            pivoted, None, [], base_scale.measure_field, ["grp"], is_base_scale=True
        ),
        BASE_SCALE["name"],
    )

    # Then every aggregation scale, using its pre-computed base-unit join.
    for scale_name, aggregator in aggregation_scales.items():
        print(f"Processing {metric_name} for AGGREGATION_SCALE ({scale_name})...")
        _export(
            aggregate_measure_weighted(
                pivoted, aggregator, [scale_name], base_scale.measure_field, ["grp"]
            ),
            scale_name,
        )

# Process NDVI across all scales: writes one CSV for the base scale and one
# per aggregation scale into OUTPUT_PATH.
process_all_scales(ndvi_data, "NDVI")

## Summary and Validation

In [None]:
# Display scaling hierarchy
print("\n=== SPATIAL SCALING HIERARCHY ===")
print(f"BASE_SCALE (atomic units): {base_scale}")
print("\nAGGREGATION_SCALES:")
for name, data in aggregation_scales.items():
    # Each joined frame is indexed by the base unit id and holds one row per
    # base-unit/boundary pair, so count distinct ids rather than rows to
    # report how many base units were actually mapped.
    print(f"  {name}: {data.index.nunique()} base units mapped")

# Display year range if year column exists
if 'year' in ndvi_data.columns:
    print(f"\nMetric data years: {ndvi_data['year'].min()}-{ndvi_data['year'].max()}")
else:
    print("\nMetric data: Single time point (no year column)")

# List the output directory once and reuse it for both the count and the listing.
output_files = sorted(OUTPUT_PATH.glob("*.csv"))
print(f"Output files: {len(output_files)}")

# List output files by scale
print("\nOutput files by scale:")
# Output files are named "<metric>_<scale>.csv". Match the base scale name as
# the trailing "_<name>" part of the stem: a plain substring test would also
# label any file whose metric or scale name merely contains the base name.
base_suffix = f"_{BASE_SCALE['name']}"
for file in output_files:
    scale_type = "BASE_SCALE" if file.stem.endswith(base_suffix) else "AGGREGATION_SCALE"
    print(f"  {scale_type}: {file.name}")