# LFMC Mapping Tool

### Enter the CSV path and the column names (lat, lon and date) of the observations to predict LFMC for each row of the CSV. Also enter an output path; the input is written there with the prediction columns appended

In [129]:
# --- User settings ----------------------------------------------------------
# CSV file holding the observations to predict LFMC for (one row each).
csv_path = "../inputs/input.csv"

# Names of the CSV columns that hold latitude, longitude and observation date.
lat_col = "Location_Latitude"
lon_col = "Location_Longitude"
date_col = "Collection_Date_YYYY-MM-DD"

# Destination CSV: the input rows with the prediction columns appended.
csv_out_path = "../output_fmc_predictions.csv"

### To produce LFMC predictions from good-quality Sentinel-2 satellite data within 6 days before or after each observation, open the 'Run' dropdown menu above and click 'Run All Cells'
In the output CSV, as well as the prediction nearest in space and time ('fmc_center'), there is a sequence of 9 values representing the nearest 3x3 pixels ('fmc_patch'), and the time of the Sentinel-2 observation used for prediction ('fmc_sentinel_time'). Running the model may take a few minutes or more, depending on the number of observations.

In [130]:
import pandas as pd
import xarray as xr
import rioxarray
import numpy as np
from pyproj import Transformer
import datetime
import joblib
import os
import dask.distributed
from pystac_client import Client
from odc.stac import configure_rio, stac_load
from datetime import timedelta
import matplotlib.pyplot as plt
import shapely.geometry as geom


# # def compute_fmc_from_csv(csv_path, lat_col='lat', lon_col='lon', date_col='date'):
#     # """
#     # Given a CSV of coordinates and dates, retrieve LFMC for each point.
#     # Only uses Sentinel-2 images within ±6 days of the observation date.
#     # Returns dataframe with:
#     #   - 'fmc_center': LFMC of closest pixel
#     #   - 'fmc_patch': list of LFMC values in 3x3 patch
#     # """

# Predict LFMC for every row of the input CSV.
#
# For each (lat, lon, date) observation the script:
#   1. searches the STAC catalog for Sentinel-2 items within +/-6 days,
#   2. loads a ~60 m window around the point, reprojected to EPSG:3577 at 20 m,
#   3. masks pixels using 'oa_fmask' and picks the acquisition closest in time,
#   4. runs the random-forest model on the remaining pixels.
# Three columns are appended to `df` ('fmc_center', 'fmc_patch',
# 'fmc_sentinel_time') and the result is written to `csv_out_path`.
df = pd.read_csv(csv_path)

# Use unsigned AWS requests when reading the imagery.
configure_rio(cloud_defaults=True, aws={"aws_unsigned": True})

# Load RF model
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only point this at a model file from a trusted source.
rf = joblib.load('../inputs/rf_s2fmc_forest.joblib')

# STAC catalog
catalog = Client.open("https://explorer.dea.ga.gov.au/stac")

# The CSV coordinates are already lat/lon (EPSG:4326), so the search bbox is
# built directly in degrees.
edge_deg = 60 / 111320.0  # ~60 m in degrees at equator (rough approximation)

# Coordinate transformer (lon/lat -> EPSG:3577), used to locate the point in
# the reprojected cube.
transformer = Transformer.from_crs("EPSG:4326", "EPSG:3577", always_xy=True)

# Sentinel-2 bands to load (loop-invariant, so defined once).
s2_bands = [
    'nbart_red', 'nbart_green', 'nbart_blue',
    'nbart_red_edge_1', 'nbart_red_edge_2', 'nbart_red_edge_3',
    'nbart_nir_1', 'nbart_nir_2',
    'nbart_swir_2', 'nbart_swir_3'
]

# Number of values in the NaN placeholder patch (3x3 pixels).
N_PATCH_PIXELS = 9

cache = {}
fmc_center_list = []
fmc_patch_list = []
fmc_sentinel_time_list = []


def _record_result(key, center, patch, sentinel_time):
    """Append one row's result to the output lists and cache it under *key*."""
    fmc_center_list.append(center)
    fmc_patch_list.append(patch)
    fmc_sentinel_time_list.append(sentinel_time)
    cache[key] = {'center': center, 'patch': patch, 'sentinel_time': sentinel_time}


def _record_missing(key):
    """Record an all-NaN result for *key* (no usable Sentinel-2 data)."""
    _record_result(key, np.nan, [np.nan] * N_PATCH_PIXELS, np.nan)


for _, row in df.iterrows():
    lat, lon = row[lat_col], row[lon_col]
    obs_date = pd.to_datetime(row[date_col])
    key = (lat, lon, obs_date)

    # Duplicate observations reuse the earlier result instead of re-querying.
    if key in cache:
        cached = cache[key]
        fmc_center_list.append(cached['center'])
        fmc_patch_list.append(cached['patch'])
        fmc_sentinel_time_list.append(cached['sentinel_time'])
        continue

    # Define small 60 m bbox directly in EPSG:4326
    min_lon, max_lon = lon - edge_deg/2, lon + edge_deg/2
    min_lat, max_lat = lat - edge_deg/2, lat + edge_deg/2

    # Query STAC for ±6 days
    start_date = (obs_date - pd.Timedelta(days=6)).strftime('%Y-%m-%d')
    end_date = (obs_date + pd.Timedelta(days=6)).strftime('%Y-%m-%d')

    query = catalog.search(
        collections=["ga_s2am_ard_3", "ga_s2bm_ard_3", "ga_s2cm_ard_3"],
        datetime=(start_date, end_date),
        bbox=(min_lon, min_lat, max_lon, max_lat)
    )

    items = list(query.items())
    if not items:
        print('Query returned no data for',key)
        _record_missing(key)
        continue
    print('Query did return data for',key)

    # Make bbox polygon for loading (in EPSG:4326)
    bbox_poly = geom.box(min_lon, min_lat, max_lon, max_lat)

    # Load only the small window, reprojected to EPSG:3577
    s2_cube = stac_load(
        items,
        bands=s2_bands + ['oa_fmask'],
        chunks={'time': 1},
        groupby="solar_day",
        cloud_mask='fmask',
        crs="EPSG:3577",         # <-- reprojection happens here
        resolution=20,
        geopolygon=bbox_poly     # <-- crop to tiny bbox
    )

    # Convert lat, lon to projected coordinates (EPSG:3577)
    x, y = transformer.transform(lon, lat)

    # Keep only the pixels within 30 m of the point (3x3 at 20 m resolution).
    # The y slice runs high -> low, which assumes a descending y coordinate.
    s2_cube = s2_cube.sel(y=slice(y+30, y-30), x=slice(x-30, x+30))

    # Keep pixels whose oa_fmask equals 1 (presumably the 'valid/clear' class --
    # confirm against the product docs), then drop fully masked time steps.
    s2_cube = (
        s2_cube.where(s2_cube['oa_fmask'] == 1)
        .drop_vars('oa_fmask')
        .dropna(dim='time', how='all')
    )
    s2_cube = s2_cube.load()

    if len(s2_cube.time) == 0:
        print('Masking removed all data for',key)
        _record_missing(key)
        continue

    s2_cube['ndii'] = (s2_cube.nbart_nir_1 - s2_cube.nbart_swir_2) / (s2_cube.nbart_nir_1 + s2_cube.nbart_swir_2)
    s2_cube['ndvi'] = (s2_cube.nbart_nir_1 - s2_cube.nbart_red) / (s2_cube.nbart_nir_1 + s2_cube.nbart_red)

    # Select closest in time to obs
    s2_cube = s2_cube.sel(time=obs_date, method='nearest')
    t = pd.to_datetime(s2_cube.time.values)

    # One row per pixel, columns in the order the model was trained on; pixels
    # with any missing feature are dropped.
    data_df = s2_cube[rf.feature_names_in_].drop_vars(['spatial_ref', 'time']).to_dataframe().dropna()
    if data_df.empty:
        _record_missing(key)
        continue

    preds = rf.predict(data_df)

    # Put the predictions back on the (y, x) grid so the pixel nearest to the
    # observation can be looked up by coordinate.
    data_df['fmc'] = preds
    s2_cube_pred = data_df.to_xarray()

    # Closest pixel, stored as a plain float rather than a 0-d array.
    fmc_center = float(s2_cube_pred.sel(y=y, x=x, method='nearest')['fmc'].values)

    # NOTE(review): 'fmc_patch' holds predictions only for pixels that survived
    # masking, so it can contain fewer than 9 values.
    _record_result(key, fmc_center, preds.tolist(), t)


df['fmc_center'] = fmc_center_list
df['fmc_center'] = df['fmc_center'].astype('float32')
df['fmc_patch'] = fmc_patch_list
df['fmc_sentinel_time'] = fmc_sentinel_time_list

df.to_csv(csv_out_path)

Query did return data for (-29.6199439, 153.2933406, Timestamp('2021-04-12 00:00:00'))
Query did return data for (-29.81746551, 153.2574998, Timestamp('2021-04-15 00:00:00'))
Query did return data for (-29.6199439, 153.2933406, Timestamp('2021-05-03 00:00:00'))
Masking removed all data for (-29.6199439, 153.2933406, Timestamp('2021-05-03 00:00:00'))
Query did return data for (-29.81746551, 153.2574998, Timestamp('2021-05-03 00:00:00'))
Masking removed all data for (-29.81746551, 153.2574998, Timestamp('2021-05-03 00:00:00'))
Query did return data for (-29.81746551, 153.2574998, Timestamp('2021-05-31 00:00:00'))
Query did return data for (-29.6199439, 153.2933406, Timestamp('2021-05-31 00:00:00'))
Query did return data for (-29.6199439, 153.2933406, Timestamp('2021-06-30 00:00:00'))
Query did return data for (-29.81746551, 153.2574998, Timestamp('2021-06-30 00:00:00'))
Query did return data for (-29.6199439, 153.2933406, Timestamp('2021-11-10 00:00:00'))
Query did return data for (-30.4