# Set up the environment

Before running this notebook, create and activate the conda environment that provides the required packages by running the following commands in your terminal:

```bash
# Create the conda environment from the provided environment file
conda env create -f ../conda_env_pkgs.yml -n soc_model_env

# Activate the new environment
conda activate soc_model_env

# Launch Jupyter Notebook from within the environment
jupyter notebook
```


In [83]:
import os
import sys
from dateutil.relativedelta import relativedelta
from datetime import datetime
#import nbimporter

import ee
import geemap

import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

from pprint import pprint as pp

#Import relevant functions from soc_predictive_model.ipynb
#from soc_predictive_model import temporal_filter

# Authenticate and initialize Earth Engine.
# The Cloud project can be overridden with the EE_PROJECT environment variable so other
# users can run the notebook without editing code; it defaults to the author's project.
EE_PROJECT = os.environ.get("EE_PROJECT", "ee-christopherharrellgis")
ee.Authenticate()
ee.Initialize(project=EE_PROJECT)

# Initialize an interactive geemap map
Map = geemap.Map()

### Data Imports

In [204]:
# SOC sample points loaded from an Earth Engine table asset
soc_sample_points = ee.FeatureCollection("users/christopherharrellgis/soc_samples")

# Study area: bounding box of the sample points, buffered by 10 km (10000 m), then re-boxed
study_area = soc_sample_points.geometry().bounds().buffer(10000).bounds()

# Harmonized Sentinel-2 surface reflectance and the matching Cloud Score+ collection.
# Earth Engine asset IDs are case-sensitive; the SR harmonized collection is the one the
# B2..B12 selection and S2_SCALE_FACTOR below are written for.
s2 = ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
s2CloudScore = ee.ImageCollection('GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED')

# Landsat 9 Collection 2 Level-2 surface reflectance. This collection provides the
# SR_B* and QA_PIXEL bands that the Landsat processing functions select; the TOA
# collection only has B1..B11 bands.
ls = ee.ImageCollection("LANDSAT/LC09/C02/T1_L2")

# DEM (Copernicus GLO-30)
dem = ee.ImageCollection('COPERNICUS/DEM/GLO30')

# CHIRPS Daily precipitation
chirps = ee.ImageCollection('UCSB-CHG/CHIRPS/DAILY')



#### Define Global Variables

In [175]:
# Global variables
output_data_folder = "../soc/data/"

# Start/End dates of the two March-to-February study periods.
# End dates are EXCLUSIVE: ee.ImageCollection.filterDate(start, end) does not include
# `end`, so each period ends on the first day of the following March. This keeps the
# last day of February (including the 2024 leap day) inside the period.
p1_start_date = "2022-03-01"
p1_end_date = "2023-03-01"

p2_start_date = "2023-03-01"
p2_end_date = "2024-03-01"

# Sentinel-2
S2_QA_BAND = "cs_cdf"          # Cloud Score+ band linked onto each S2 image
S2_CLEAR_THRESHOLD = 0.60      # pixels whose cs_cdf is >= this value count as clear
S2_SCALE_FACTOR = 0.0001       # stored integer values -> reflectance

# Landsat-9
ls_QA_BAND = "QA_PIXEL"
# NOTE(review): QA_PIXEL is a bit-flag band, so a numeric ">= threshold" test is not a
# reliable clear-sky criterion; prefer testing the individual cloud/shadow bits.
ls_CLEAR_THRESHOLD = 322
ls_SCALE_FACTOR = 0.0000275    # Collection 2 Level-2 surface reflectance scale factor

In [176]:
def temporal_filter(imageCollection, start_date, end_date):
    """Restrict an ImageCollection to a date window.

    Thin wrapper around `filterDate(start_date, end_date)`; the filtered collection is
    returned unchanged otherwise.
    """
    windowed = imageCollection.filterDate(start_date, end_date)
    return windowed

def spatial_filter(imageCollection, region):
    """Keep only the images of `imageCollection` that intersect `region`.

    Thin wrapper around `filterBounds(region)`.
    """
    clipped_to_region = imageCollection.filterBounds(region)
    return clipped_to_region

def temporal_reducer(col):
    """Collapse an ImageCollection over time into per-pixel summary statistics.

    A single combined reducer (sharing its inputs) computes, for every band, the mean,
    the 10th and 90th percentiles, and the standard deviation across all images.
    Returns the resulting ee.Image.
    """
    summary = ee.Reducer.mean()
    summary = summary.combine(ee.Reducer.percentile([10, 90]), sharedInputs=True)
    summary = summary.combine(ee.Reducer.stdDev(), sharedInputs=True)
    return col.reduce(summary)

def print_dataset_info(data, region=None):
    """Print a quick summary of an ee.Image or ee.ImageCollection.

    Reports the dataset name (its 'system:id', falling back to the type name), the type,
    the band names, the CRS and nominal scale of the first band, and per-band min/max
    computed over `region`. For an ImageCollection only the first image is inspected.

    Args:
        data: ee.Image or ee.ImageCollection to describe.
        region: optional ee.Geometry used for the min/max statistics. Defaults to the
            notebook-level `study_area`, which was previously always used.

    Raises:
        TypeError: if `data` is neither an ee.Image nor an ee.ImageCollection.
    """
    if isinstance(data, ee.ImageCollection):
        kind = 'ImageCollection'
        first_image = ee.Image(data.first())
    elif isinstance(data, ee.Image):
        kind = 'Image'
        first_image = data
    else:
        raise TypeError("Input must be an ee.Image or ee.ImageCollection")

    if region is None:
        region = study_area

    # Each getInfo() below is a blocking round trip to the Earth Engine servers.
    name = data.get('system:id').getInfo() or kind
    band_names = first_image.bandNames().getInfo()

    # Projection and scale are taken from the first band only.
    proj = first_image.select([first_image.bandNames().get(0)]).projection()
    crs = proj.crs().getInfo()
    scale = proj.nominalScale().getInfo()

    stats = first_image.reduceRegion(
        reducer=ee.Reducer.minMax(),
        geometry=region,
        scale=scale,
        maxPixels=1e8
    ).getInfo()

    # minMax() output keys are '<band>_min' / '<band>_max'; missing ones show as 'N/A'.
    band_stats = {
        b: {'min': stats.get(f"{b}_min", 'N/A'), 'max': stats.get(f"{b}_max", 'N/A')}
        for b in band_names
    }

    print(f"Dataset Name: {name}")
    print(f"Type: {kind}")
    print(f"Band Names: {band_names}")
    print(f"CRS: {crs}")
    print(f"Scale: {scale} meters")
    print("Band Min/Max:")
    for b, band_stat in band_stats.items():
        print(f"  {b}: min={band_stat['min']}, max={band_stat['max']}")


#### Sentinel-2 processing

In [177]:
def s2_cloud_mask(image):
    """Build a clear-sky mask from the Cloud Score+ band.

    Returns a binary image that is 1 where the S2_QA_BAND score is at least
    S2_CLEAR_THRESHOLD and 0 elsewhere.
    """
    clear_score = image.select(S2_QA_BAND)
    return clear_score.gte(S2_CLEAR_THRESHOLD)

def calc_ndvi_s2(image):
    """Return a single-band 'NDVI' image from Sentinel-2 NIR (B8) and Red (B4)."""
    ndvi = image.normalizedDifference(['B8', 'B4'])
    return ndvi.rename('NDVI')

def calc_evi_s2(image):
    """Return a single-band 'EVI' image computed from Sentinel-2 bands.

    EVI = 2.5 * (NIR - RED) / (NIR + 6*RED - 7.5*BLUE + 1), with NIR = B8, RED = B4 and
    BLUE = B2. The constants assume reflectance-scaled inputs.
    """
    bands = {
        'NIR': image.select('B8'),
        'RED': image.select('B4'),
        'BLUE': image.select('B2'),
    }
    evi = image.expression('2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))', bands)
    return evi.rename('EVI')

def calc_nirv_s2(image):
    """Return `image` with an extra 'NIRv' band, defined as NIR (B8) multiplied by NDVI.

    Note: the full input image is returned with the new band appended, not just NIRv.
    """
    near_infrared = image.select('B8')
    nirv_band = near_infrared.multiply(calc_ndvi_s2(image)).rename('NIRv')
    return image.addBands(nirv_band)

def calc_lai_s2(image):
    """Return `image` with an extra 'LAI' band estimated linearly from Sentinel-2 EVI.

    LAI = 3.618 * EVI - 0.118. The full input image is returned with the new band appended.
    """
    evi_band = calc_evi_s2(image)
    lai_band = evi_band.expression('3.618 * EVI - 0.118', {'EVI': evi_band})
    return image.addBands(lai_band.rename('LAI'))

def calc_veg_indices_s2(image):
    """Append NDVI, EVI, NIRv and LAI bands to a Sentinel-2 reflectance image.

    calc_nirv_s2 and calc_lai_s2 return the whole input image plus their new band, so only
    the new band ('NIRv' / 'LAI') is taken from each. Adding their full outputs would add
    every input band a second time.
    """
    return image.addBands([
        calc_ndvi_s2(image),
        calc_evi_s2(image),
        calc_nirv_s2(image).select('NIRv'),
        calc_lai_s2(image).select('LAI'),
    ])

def process_s2_image(image):
    """Cloud-mask and scale a linked Sentinel-2 image, then add vegetation indices.

    Args:
        image: ee.Image carrying the Sentinel-2 reflectance bands and the S2_QA_BAND
            cloud-score band (i.e. an element of the collection produced by
            `link_s2_collections`).

    Returns:
        ee.Image with bands B2..B12 scaled by S2_SCALE_FACTOR plus the bands added by
        `calc_veg_indices_s2`, masked to clear pixels, and carrying the input image's
        'system:time_start'.
    """
    mask = s2_cloud_mask(image)
    reflectance = image.select([
        "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A", "B11", "B12"]
    ).multiply(S2_SCALE_FACTOR)
    processed = calc_veg_indices_s2(reflectance).updateMask(mask)
    # Image arithmetic (multiply) does not carry over image properties, so the timestamp
    # must be copied from the original input, not from `reflectance`.
    return ee.Image(processed.copyProperties(image, ["system:time_start"]))

def link_s2_collections(col1, col2):
    """Attach the S2_QA_BAND band of matching `col2` images to each image of `col1`.

    Wraps `col1.linkCollection(col2, [S2_QA_BAND])` and returns the linked collection.
    """
    return col1.linkCollection(col2, [S2_QA_BAND])

#### Landsat 9 Processing

In [205]:
def ls_cloud_mask(image):
    """Return a clear-sky mask for a Landsat Collection 2 image from its QA_PIXEL flags.

    QA_PIXEL is a bit field, not an ordinal score, so comparing it against a numeric
    threshold does not separate clear from cloudy pixels. A pixel is kept (mask = 1) only
    when none of the following bits are set:
        bit 0 fill, bit 1 dilated cloud, bit 2 cirrus, bit 3 cloud, bit 4 cloud shadow.
    """
    reject_bits = 0b11111  # bits 0-4 listed above
    return image.select(ls_QA_BAND).bitwiseAnd(reject_bits).eq(0)

def calc_ndvi_ls(image):
    """Return a single-band 'NDVI' image from Landsat 8/9 OLI NIR (SR_B5) and Red (SR_B4)."""
    ndvi = image.normalizedDifference(['SR_B5', 'SR_B4'])
    return ndvi.rename('NDVI')

def calc_evi_ls(image):
    """Return a single-band 'EVI' image from Landsat 8/9 surface-reflectance bands.

    EVI = 2.5 * (NIR - RED) / (NIR + 6*RED - 7.5*BLUE + 1), with NIR = SR_B5,
    RED = SR_B4 and BLUE = SR_B2. The constants assume reflectance-scaled inputs.
    """
    bands = {
        'NIR': image.select('SR_B5'),
        'RED': image.select('SR_B4'),
        'BLUE': image.select('SR_B2'),
    }
    evi = image.expression('2.5 * ((NIR - RED) / (NIR + 6 * RED - 7.5 * BLUE + 1))', bands)
    return evi.rename('EVI')

def calc_nirv_ls(image):
    """Return `image` with an extra 'NIRv' band: Landsat NIR (SR_B5) multiplied by NDVI.

    Note: the full input image is returned with the new band appended, not just NIRv.
    """
    near_infrared = image.select('SR_B5')
    nirv_band = near_infrared.multiply(calc_ndvi_ls(image)).rename('NIRv')
    return image.addBands(nirv_band)

def calc_lai_ls(image):
    """Return `image` with an extra 'LAI' band estimated linearly from Landsat EVI.

    LAI = 3.618 * EVI - 0.118. The full input image is returned with the new band appended.
    """
    evi_band = calc_evi_ls(image)
    lai_band = evi_band.expression('3.618 * EVI - 0.118', {'EVI': evi_band})
    return image.addBands(lai_band.rename('LAI'))

def calc_veg_indices_ls(image):
    """Append NDVI, EVI, NIRv and LAI bands to a Landsat reflectance image.

    calc_nirv_ls and calc_lai_ls return the whole input image plus their new band, so only
    the new band ('NIRv' / 'LAI') is taken from each. Adding their full outputs would add
    every input band a second time. (Previously NIRv was requested twice and LAI never.)
    """
    return image.addBands([
        calc_ndvi_ls(image),
        calc_evi_ls(image),
        calc_nirv_ls(image).select('NIRv'),
        calc_lai_ls(image).select('LAI'),
    ])

def process_ls_image(image):
    """Cloud-mask and scale a Landsat Collection 2 Level-2 image, then add vegetation indices.

    Args:
        image: ee.Image with SR_B2..SR_B7 surface-reflectance bands and the ls_QA_BAND
            quality band.

    Returns:
        ee.Image with SR_B2..SR_B7 converted to reflectance plus the bands added by
        `calc_veg_indices_ls`, masked to clear pixels, and carrying the input image's
        'system:time_start'.
    """
    # Collection 2 Level-2 SR: reflectance = DN * ls_SCALE_FACTOR + (-0.2).
    sr_offset = -0.2
    mask = ls_cloud_mask(image)
    reflectance = (
        image.select(['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7'])
        .multiply(ls_SCALE_FACTOR)
        .add(sr_offset)
    )
    processed = calc_veg_indices_ls(reflectance).updateMask(mask)
    # Image arithmetic drops image properties, so copy the timestamp from the original input.
    return ee.Image(processed.copyProperties(image, ["system:time_start"]))

#### CHIRPS Precipitation Processing

In [201]:
def calc_sum_precip(col):
    """Sum every image in `col` into a single 'TOTAL_PRECIP' band.

    The total covers whatever date range `col` has already been filtered to (e.g. one
    study period); no date filtering happens here.
    Returns ee.Image
    """
    summed = col.reduce(ee.Reducer.sum())
    total = summed.select("precipitation_sum")
    return total.rename('TOTAL_PRECIP')

def monthly_total_precip(col, start_date, end_date):
    """Calculates total monthly precipitation between start and end dates using CHIRPS data.

    Steps in one-month increments from `start_date`. Every month that starts before
    `end_date` yields one image holding the sum of the images dated from that month's
    start up to the next month's start.

    Args:
        col: ee.ImageCollection of daily precipitation; presumably CHIRPS with a
            'precipitation' band (the summed band read here is 'precipitation_sum').
        start_date: 'YYYY-MM-DD' string; start of the first month.
        end_date: 'YYYY-MM-DD' string; no month starting on or after it is generated.

    Returns ee.ImageCollection of single-band images named '<YEAR>_<MON>_TOTAL_PRECIP'
    (e.g. '2022_MAR_TOTAL_PRECIP').
    """
    # Fixed English month codes keep band names independent of the process locale
    # (strftime('%b') follows the active locale).
    month_codes = ('JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN',
                   'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC')
    start_dt = datetime.strptime(start_date, '%Y-%m-%d').date()
    end_dt = datetime.strptime(end_date, '%Y-%m-%d').date()
    reducer = ee.Reducer.sum()  # loop-invariant; reused for every month

    monthly_images = []
    month_offset = 0
    current = start_dt
    while current < end_dt:
        # Offset from the start date (not from `current`) so a start on e.g. the 31st
        # does not drift to the 28th after February.
        next_month = start_dt + relativedelta(months=month_offset + 1)
        band_name = f"{current.year}_{month_codes[current.month - 1]}_TOTAL_PRECIP"
        filtered = temporal_filter(col, current.isoformat(), next_month.isoformat())
        img = filtered.reduce(reducer).select("precipitation_sum").rename(band_name)
        monthly_images.append(img)
        month_offset += 1
        current = next_month

    return ee.ImageCollection(monthly_images)

def mean_daily_precip_series(start_date, end_date):
    """Build a daily precipitation time series averaged over `study_area`.

    Each CHIRPS image selected by filterDate(start_date, end_date) becomes a
    geometry-less ee.Feature with properties 'date' ('YYYY-MM-dd') and 'precip' (the
    spatial mean of the 'precipitation' band at a 5000 m scale).
    Returns ee.FeatureCollection.
    """
    def _daily_feature(day_img):
        region_mean = day_img.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=study_area,
            scale=5000,
            maxPixels=1e13
        )
        properties = {
            'date': day_img.date().format('YYYY-MM-dd'),
            'precip': region_mean.get('precipitation'),
        }
        return ee.Feature(None, properties)

    daily_images = chirps.filterDate(start_date, end_date)
    return ee.FeatureCollection(daily_images.map(_daily_feature))

def monthly_total_precip_series(start_date, end_date):
    """Returns a time series of monthly precipitation means using CHIRPS data.

    Monthly totals are built with `monthly_total_precip` from the CHIRPS collection
    filtered to [start_date, end_date). Each month's total is then averaged spatially
    over `study_area`.

    Args:
        start_date: 'YYYY-MM-DD' string, start of the series.
        end_date: 'YYYY-MM-DD' string, end of the series.

    Returns:
        ee.FeatureCollection with one geometry-less Feature per month and properties
        'month' (the monthly band name, e.g. '2022_MAR_TOTAL_PRECIP') and 'precip'
        (spatial mean of that month's total over the study area). The 'precip' key
        matches `mean_daily_precip_series`.
    """
    monthly_precip = monthly_total_precip(
        temporal_filter(chirps, start_date, end_date), start_date, end_date
    )

    def image_to_feature(image):
        # Each monthly image has exactly one band, named after its month.
        band_name = image.bandNames().get(0)
        region_mean = image.reduceRegion(
            reducer=ee.Reducer.mean(),
            geometry=study_area,
            scale=5566,  # approx. CHIRPS resolution (~5.5km)
            maxPixels=1e13
        )
        return ee.Feature(None, {'month': band_name, 'precip': region_mean.get(band_name)})

    return ee.FeatureCollection(monthly_precip.map(image_to_feature))


In [180]:
# Map visualization parameters for precipitation rasters; tune 'max' to the study
# region and time span being displayed.
precip_vis = dict(
    min=0,
    max=300,
    palette=[
        'ffffff',  # white (0 mm)
        'ccebc5',  # very light green
        'a8ddb5',  # light green-blue
        '7bccc4',  # aqua
        '4eb3d3',  # blue
        '2b8cbe',  # deep blue
        '08589e',  # darker blue
        '08306b',  # navy (heavy rainfall)
    ],
)

#### Processing Remote Sensing Data

In [206]:
# Sentinel-2 SR: restrict to the study area once, then split by period.
# (Reassigning `s2` is safe to re-run: filtering by the same bounds twice is a no-op.)
s2 = spatial_filter(s2, study_area)

# P1
s2_p1 = temporal_filter(s2, p1_start_date, p1_end_date)
s2_linked_p1 = link_s2_collections(s2_p1, s2CloudScore)
s2_processed_p1 = s2_linked_p1.map(process_s2_image)

# P2
s2_p2 = temporal_filter(s2, p2_start_date, p2_end_date)
s2_linked_p2 = link_s2_collections(s2_p2, s2CloudScore)
s2_processed_p2 = s2_linked_p2.map(process_s2_image)

# Landsat
ls = spatial_filter(ls, study_area)

# P1
ls_p1 = temporal_filter(ls, p1_start_date, p1_end_date)
ls_p1_processed = ls_p1.map(process_ls_image)

# P2
ls_p2 = temporal_filter(ls, p2_start_date, p2_end_date)
ls_p2_processed = ls_p2.map(process_ls_image)

# CHIRPS
chirps = spatial_filter(chirps, study_area)

# P1
chirps_p1 = temporal_filter(chirps, p1_start_date, p1_end_date)
total_precip_p1 = calc_sum_precip(chirps_p1)
monthly_precip_p1 = monthly_total_precip(chirps_p1, p1_start_date, p1_end_date)

# P2
chirps_p2 = temporal_filter(chirps, p2_start_date, p2_end_date)
total_precip_p2 = calc_sum_precip(chirps_p2)
monthly_precip_p2 = monthly_total_precip(chirps_p2, p2_start_date, p2_end_date)

