# LFMC Mapping Tool

### Draw a polygon over an area of interest in NSW on the map below, which appears after you click 'Run All Cells' in the 'Run' menu above.
 - The maximum polygon area is 100 km² (e.g. 10 km x 10 km)

### Once you have finished drawing the polygon, the tool automatically produces LFMC maps for every available good-quality Sentinel-2 image from the last 30 days, plus a mean map across those dates.
 - LFMC percentile maps are also produced, based on statewide maps of LFMC mean and standard deviation (2015 to mid-2024).

In [None]:
def start_lfmc_map():
    """Display an interactive map on which an NSW polygon can be drawn for LFMC mapping.

    Shows latitude/longitude "Go to location" controls and an ipyleaflet map with
    a polygon/rectangle draw tool. When a shape is drawn or edited, its area is
    checked (maximum 100 km²) and, if acceptable, ``map_compute_fmc`` runs in the
    kernel (blocking) and prints its messages and plots into an Output widget.
    """
    import geopandas as gpd
    from shapely.geometry import shape, mapping
    from ipyleaflet import Map, DrawControl
    from IPython.display import display
    import ipywidgets as widgets
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap
    import xarray as xr
    import rioxarray
    import numpy as np
    import joblib
    import warnings
    from sklearn.exceptions import InconsistentVersionWarning
    import os
    import datetime
    from pystac_client import Client
    from odc.stac import configure_rio, stac_load
    import scipy.stats as st

    def _load_state_layer(url, bbox, match):
        """Read the window of a statewide raster covering ``bbox`` and regrid it onto ``match``.

        ``bbox`` is (left, bottom, right, top) in the raster's own CRS.
        ``rio.clip_box`` works whatever the direction of the y axis, whereas a
        label slice such as ``y=slice(bottom, top)`` selects nothing on a
        north-up raster, whose y coordinates decrease.
        """
        layer = rioxarray.open_rasterio(url, chunks='auto').drop_vars('band').squeeze('band')
        # Only the clipped window is read from the remote file.
        layer = layer.rio.clip_box(*bbox, auto_expand=True).load()
        return layer.rio.reproject_match(match)

    # --- LFMC mapping function ---
    def map_compute_fmc(polygon_geojson):
        """Produce LFMC and LFMC-percentile maps for one drawn polygon.

        Parameters
        ----------
        polygon_geojson : dict
            GeoJSON feature from the draw control, in EPSG:4326; only its
            ``'geometry'`` member is read.

        Side effects: deletes every file in ``Maps/`` (creating the folder if
        needed), writes per-date and mean LFMC / percentile GeoTIFFs there
        (EPSG:3308) and shows two matplotlib facet plots. Prints a message and
        returns early when there is no usable Sentinel-2 data or no successful
        prediction. Returns None.
        """
        poly = shape(polygon_geojson['geometry'])
        gdf = gpd.GeoDataFrame(index=[0], crs='EPSG:4326', geometry=[poly])
        gdf = gdf.to_crs('EPSG:3308')
        poly_3308 = gdf.geometry.iloc[0]

        # Load the mask raster and clip it to the polygon; its grid (``like=trees``)
        # becomes the pixel grid of the Sentinel-2 load and of every output.
        trees = rioxarray.open_rasterio('inputs/nonveg_mask_nve_nsw_v2.tif').squeeze('band').drop_vars('band')
        trees = trees.rio.write_crs('EPSG:3308')
        trees = trees.chunk('auto')
        trees.attrs = {}
        trees = trees.rio.clip([mapping(poly_3308)], crs='EPSG:3308').compute()

        # Date range: the 30 days up to today (clock read once so both ends agree).
        end_date = datetime.date.today()
        today = str(end_date)
        month_ago = str(end_date - datetime.timedelta(days=30))

        # Unsigned (anonymous) AWS access for reading the imagery.
        configure_rio(cloud_defaults=True, aws={'aws_unsigned': True})

        # STAC query over the polygon's lon/lat bounding box.
        catalog = Client.open('https://explorer.dea.ga.gov.au/stac')
        query = catalog.search(
            collections=['ga_s2am_ard_3', 'ga_s2bm_ard_3', 'ga_s2cm_ard_3'],
            datetime=(month_ago, today),
            bbox=gdf.to_crs('EPSG:4326').total_bounds.tolist(),
        )
        items = list(query.items())
        if not items:
            print(f'\nNo Sentinel-2 data found for this polygon between {month_ago} and {today}.\n')
            return

        s2_bands = [
            'nbart_red', 'nbart_green', 'nbart_blue', 'nbart_red_edge_1', 'nbart_red_edge_2',
            'nbart_red_edge_3', 'nbart_nir_1', 'nbart_nir_2', 'nbart_swir_2', 'nbart_swir_3',
        ]

        # Load xarray cube on the mask's grid, one chunk per date.
        s2_cube = stac_load(
            items, bands=s2_bands + ['oa_fmask'],
            chunks={'time': 1}, groupby='solar_day', cloud_mask='fmask', like=trees,
        )
        # Keep only pixels with oa_fmask == 1, then drop dates with no valid data left.
        s2_cube = s2_cube.where(s2_cube['oa_fmask'] == 1).drop_vars('oa_fmask').dropna(dim='time', how='all')
        if s2_cube.sizes['time'] == 0:
            print(f'\nNo usable (unmasked) Sentinel-2 pixels for this polygon between {month_ago} and {today}.\n')
            return
        s2_cube.load()
        s2_cube['ndii'] = (s2_cube.nbart_nir_1 - s2_cube.nbart_swir_2) / (s2_cube.nbart_nir_1 + s2_cube.nbart_swir_2)
        s2_cube['ndvi'] = (s2_cube.nbart_nir_1 - s2_cube.nbart_red) / (s2_cube.nbart_nir_1 + s2_cube.nbart_red)
        s2_cube.time.attrs = {}

        # The pickled model may have been saved with a different scikit-learn version.
        warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
        rf = joblib.load('inputs/rf_s2fmc_forest.joblib')

        print('Data prepared. Starting LFMC prediction...')

        encoding = {'foliar_moisture_content': {'zlib': True, 'complevel': 1, 'shuffle': True}}

        # Clear outputs of any previous run; create the folder on first use.
        folder_path = 'Maps/'
        os.makedirs(folder_path, exist_ok=True)
        for filename in os.listdir(folder_path):
            file_path = os.path.join(folder_path, filename)
            if os.path.isfile(file_path):
                os.remove(file_path)

        # Exact bounds are used for clipping; integer bounds only label the files.
        bounds = s2_cube.rio.bounds()
        left, bottom, right, top = (int(v) for v in bounds)

        # Statewide mean and std FMC layers (2015-2024), regridded onto the S2 grid.
        state_fmc_mean = _load_state_layer(
            'https://hie-pub.westernsydney.edu.au/projects/fmc_modelling/fmc_layers/other_layers/fmc_nsw_mean_v2.tif',
            bounds, s2_cube['ndvi'])
        state_fmc_std = _load_state_layer(
            'https://hie-pub.westernsydney.edu.au/projects/fmc_modelling/fmc_layers/other_layers/fmc_nsw_std_v2.tif',
            bounds, s2_cube['ndvi'])

        # Create colourmap consistent with Australian Flammability Monitoring System
        colors = [(0.87, 0, 0), (1, 1, 0.73), (0.165, 0.615, 0.957)]  # R -> G -> B
        cmap = LinearSegmentedColormap.from_list('fmc', colors, N=256)

        # Iterate over a copy of the dates because s2_cube is rebound inside the loop.
        for t in s2_cube.time.data[:]:
            data = s2_cube.sel(time=t)
            t_ = str(t)[:10]

            # A date whose NDII is NaN everywhere can only yield an empty map: drop it.
            if bool(data['ndii'].isnull().all()):
                s2_cube = s2_cube.drop_sel(time=t)
                continue
            df = data[rf.feature_names_in_].drop_vars(['spatial_ref', 'time']).to_dataframe()
            try:
                preds = rf.predict(df)
            except Exception as error:
                print('Skipping timestep due to error:', error)
                continue

            # Reshape predictions back onto the grid and mask invalid / masked pixels.
            df['mean_fmc'] = preds
            preds = df['mean_fmc'].to_xarray()
            preds = preds.where(np.isfinite(data['ndii']))
            preds = preds.where(trees)
            preds = preds.expand_dims('time').rename('foliar_moisture_content')

            # Percentile of each prediction under a per-pixel normal distribution of
            # the statewide mean/std. scipy combines the arrays by position, so they
            # must share the S2 grid (ensured by reproject_match above).
            percentile = st.norm.cdf(preds, loc=state_fmc_mean, scale=state_fmc_std) * 100
            percentile = xr.DataArray(percentile, coords=preds.coords)
            percentile = percentile.rename('perc_foliar_moisture_content')

            s2_cube = s2_cube.merge(preds)
            s2_cube = s2_cube.merge(percentile)

            preds = preds.rio.write_crs('EPSG:3308')
            preds.rio.to_raster(f'Maps/lfmc_{left}_{bottom}_{right}_{top}_{t_}.tif', encoding=encoding)
            percentile = percentile.rio.write_crs('EPSG:3308')
            percentile.rio.to_raster(f'Maps/lfmc_perc_{left}_{bottom}_{right}_{top}_{t_}.tif', encoding=encoding)

        # Every date was dropped or failed: nothing to plot or average.
        if 'foliar_moisture_content' not in s2_cube:
            print('\nNo LFMC predictions could be made for this polygon; no maps were saved.\n')
            return

        s2_cube['foliar_moisture_content'].attrs['units'] = '% dry weight'
        s2_cube['perc_foliar_moisture_content'].attrs['units'] = '%'

        g = s2_cube['foliar_moisture_content'].plot(col='time', col_wrap=3, robust=True, cmap='viridis_r')
        g.fig.suptitle('LFMC of tree cover in polygon')
        plt.show()
        g = s2_cube['perc_foliar_moisture_content'].plot(col='time', col_wrap=3, robust=True, cmap=cmap)
        g.fig.suptitle('LFMC percentile of tree cover in polygon')
        plt.show()

        # Mean maps across all predicted dates.
        s2_cube['mean_fmc'] = s2_cube['foliar_moisture_content'].mean('time')
        s2_cube = s2_cube.rio.write_crs('EPSG:3308')
        s2_cube['mean_fmc'].rio.to_raster(f'Maps/lfmc_{left}_{bottom}_{right}_{top}_mean_{month_ago}_to_{today}.tif', encoding=encoding)
        s2_cube['mean_perc_fmc'] = s2_cube['perc_foliar_moisture_content'].mean('time')
        s2_cube['mean_perc_fmc'].rio.to_raster(f'Maps/lfmc_perc_{left}_{bottom}_{right}_{top}_mean_{month_ago}_to_{today}.tif', encoding=encoding)

        print('\nSuccess! LFMC maps saved in Maps/\n')

    # --- Create map ---
    m = Map(center=(-33.6, 150.7), zoom=10)

    draw_ctrl = DrawControl(
        polygon={'shapeOptions': {'color': '#6bc2e5'}},
        rectangle={'shapeOptions': {'color': '#6bc2e5'}},
        circlemarker={}, circle={}, polyline={}
    )
    m.add_control(draw_ctrl)

    # --- Widgets for lat/lon input ---
    lat_text = widgets.FloatText(
        value=-33.6,
        description='Lat:',
        step=0.0001,
        layout=widgets.Layout(width='200px')
    )
    lon_text = widgets.FloatText(
        value=150.7,
        description='Lon:',
        step=0.0001,
        layout=widgets.Layout(width='200px')
    )
    go_button = widgets.Button(description="Go to location")

    def on_go_clicked(b):
        """Re-centre the map on the coordinates typed into the Lat/Lon boxes."""
        m.center = (lat_text.value, lon_text.value)

    go_button.on_click(on_go_clicked)

    controls = widgets.HBox([lat_text, lon_text, go_button])

    # Output area that receives messages and plots from each mapping run.
    out = widgets.Output()
    display(out)

    # --- Draw callback: area check, then synchronous mapping ---
    def handle_draw(self, action, geo_json):
        """Validate a drawn or edited shape's area and run LFMC mapping on it.

        Called by ipyleaflet with the DrawControl, the draw action name and the
        shape as a GeoJSON feature in EPSG:4326.
        """
        # Removing a shape with the delete tool also fires on_draw; nothing to map then.
        if action == 'deleted':
            return
        with out:
            out.clear_output()
            geom = shape(geo_json['geometry'])
            gdf = gpd.GeoDataFrame(index=[0], crs='EPSG:4326', geometry=[geom])
            # Measure in an equal-area CRS (GDA94 / Australian Albers). World Mercator
            # (EPSG:3395) overstates area by ~1/cos²(lat), about 1.4x around Sydney,
            # and would reject polygons well under the limit.
            area_m2 = gdf.to_crs('EPSG:3577').geometry.iloc[0].area
            max_area = 10_000 * 10_000  # 100 km², e.g. 10 km x 10 km

            if area_m2 > max_area:
                print(f"Polygon too large ({area_m2/1e6:.2f} km²). Maximum allowed is 100 km² (e.g. 10x10 km).")
            else:
                print("Polygon drawn. Running LFMC mapping...")
                map_compute_fmc(geo_json)

    draw_ctrl.on_draw(handle_draw)

    display(controls, m)


In [None]:
# Build and display the drawing map; LFMC mapping runs when a shape is drawn on it.
start_lfmc_map()

 - Running the tool may take a few minutes, depending on the polygon size and the amount of Sentinel-2 data available; a message will appear below the map plots when modelling is complete
 - Maps are saved as .tif files in the Maps folder (download via right click), in the EPSG:3308 coordinate reference system
 - File names contain the bounding-box coordinates of the mapped area (EPSG:3308, truncated to integers), followed by the prediction date or, for mean maps, the date range

 - If you need to stop the mapping, click the stop icon in the menu bar or select 'Kernel' > 'Interrupt Kernel', then start again. Note that all files already in the Maps folder are deleted at the start of every new mapping run, so download any maps you want to keep first