# Predict habitat for native species

## Set up notebook

In [None]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules
# before each cell is executed, so edits to project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import logging

# Emit log records at INFO level and above in the notebook output.
logging.basicConfig(level=logging.INFO)

## Import configuration and modules

In [None]:
from eo_insights.stac_configuration import de_australia_stac_config
from eo_insights.raster_base import RasterBase, QueryParams, LoadParams
from eo_insights.band_indices import calculate_indices
from rasterstats import zonal_stats
import geopandas as gpd
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
import xarray as xr
from pathlib import Path
import numpy as np
from skimage.segmentation import quickshift,slic
from sklearn.preprocessing import StandardScaler,MinMaxScaler
import fiona
from shapely.geometry import shape

## Load training data and define study area

In [None]:
# Filtered frog presence points from ALA, with pseudo-absence points pre-generated.
# The path below is relative to the current working directory.
path_fauna_data = "notebooks/native_species_data/records-2024-02-13-south-east-coastal-plain-southern-bell-frog_filtered_subset.geojson"

# Read the points and reproject them to geographic coordinates (EPSG:4326)
cwd = Path().resolve()
gpd_fauna = gpd.read_file(cwd / path_fauna_data)
gpd_fauna = gpd_fauna.to_crs("EPSG:4326")

# The bounding box of all points defines the study area used by the queries
bbox = gpd_fauna.total_bounds
gpd_fauna

## Query satellite products

### Get DEA configuration and list collections

In [None]:
# STAC configuration for Digital Earth Australia, reused by every query below
config = de_australia_stac_config
# Display the collections available through this configuration
config.list_collections()

### Load DEA products:
* Landsat 8 & 9 yearly geomedian

In [None]:
# Query window: calendar year 2023 over the study-area bounding box
query_params_ls = QueryParams(bbox=bbox, start_date="2023-01-01", end_date="2023-12-31")

# Landsat 8 & 9 yearly geomedian: only the geomedian spectral bands are loaded for now
params_ls_8_9_gm = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=(
        "nbart_blue",
        "nbart_green",
        "nbart_red",
        "nbart_nir",
        "nbart_swir_1",
        "nbart_swir_2",
    ),
)
raster_ls_8_9_gm = RasterBase.from_stac_query(
    load_params=params_ls_8_9_gm,
    query_params=query_params_ls,
    collections=["ga_ls8cls9c_gm_cyear_3"],
    config=config,
)

# Add the NDVI and NDWI indices, then keep the first time step only
ds_ls_8_9_gm = calculate_indices(raster_ls_8_9_gm.data, ["ndvi", "ndwi"]).isel(time=0)
ds_ls_8_9_gm

* Sentinel-2 Barest Earth and BSI

In [None]:
# Query window: calendar year 2018 over the study-area bounding box
query_params_s2_be = QueryParams(bbox=bbox, start_date="2018-01-01", end_date="2018-12-31")

# Only the bands needed for the BSI calculation are loaded
params_s2_be = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("blue", "red", "nir", "swir_1"),
)
raster_s2_be = RasterBase.from_stac_query(
    load_params=params_s2_be,
    query_params=query_params_s2_be,
    collections=["s2_barest_earth"],
    config=config,
)

# Compute BSI, discard the source bands and keep the first time step only
ds_s2_be_bsi = calculate_indices(raster_s2_be.data, ["bsi"])
ds_s2_be_bsi = ds_s2_be_bsi[["bsi"]].isel(time=0)
ds_s2_be_bsi

* SRTM DEM

In [None]:
# Query window: calendar year 2014 over the study-area bounding box
query_params_dem = QueryParams(
    bbox=bbox,
    start_date="2014-01-01",
    end_date="2014-12-31",
)

# NOTE: a one-element tuple needs the trailing comma; ("dem") is just the
# string "dem", which is inconsistent with the band tuples used for the
# other products.
params_dem = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("dem",),
)
raster_dem = RasterBase.from_stac_query(
    config=config,
    collections=["ga_srtm_dem1sv1_0"],
    query_params=query_params_dem,
    load_params=params_dem,
)

# Keep the first time step only
ds_dem = raster_dem.data.isel(time=0)
ds_dem

* Multi-scale Topographic Position Index (TPI)

In [None]:
# Query window: calendar year 2018 over the study-area bounding box
query_params_tpi = QueryParams(bbox=bbox, start_date="2018-01-01", end_date="2018-12-31")

# Load the three TPI bands: regional, intermediate and local
params_tpi = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("regional", "intermediate", "local"),
)
raster_tpi = RasterBase.from_stac_query(
    load_params=params_tpi,
    query_params=query_params_tpi,
    collections=["multi_scale_topographic_position"],
    config=config,
)

# Keep the first time step only.
# NOTE(review): `ds_tip` looks like a typo for `ds_tpi`; the name is kept
# because the stacking cell refers to it.
ds_tip = raster_tpi.data.isel(time=0)
ds_tip

* Weathering Intensity

In [None]:
# Query window: calendar year 2018 over the study-area bounding box
query_params_weathering = QueryParams(
    bbox=bbox,
    start_date="2018-01-01",
    end_date="2018-12-31",
)

# NOTE: a one-element tuple needs the trailing comma; ("intensity") is just
# the string "intensity", which is inconsistent with the band tuples used
# for the other products.
params_weathering = LoadParams(
    crs="EPSG:3577",
    resolution=30,
    bands=("intensity",),
)
raster_weathering = RasterBase.from_stac_query(
    config=config,
    collections=["weathering_intensity"],
    query_params=query_params_weathering,
    load_params=params_weathering,
)

# Keep the first time step only
ds_weathering = raster_weathering.data.isel(time=0)
ds_weathering

### Stack all product bands

In [None]:
# Combine every product loaded above into one dataset.
# compat='override' is passed so xarray does not compare conflicting
# variables between the inputs (see xarray.merge documentation).
ds_all = xr.merge(
    [
        ds_ls_8_9_gm,
        ds_s2_be_bsi,
        ds_dem,
        ds_tip,
        ds_weathering,
    ],
    compat="override",
)
ds_all

## Object-based classification for native species habitat prediction

### Image segmentation

In [None]:
# Stack the three indices that drive the segmentation into a single
# (rows, columns, bands) array; to_array() puts the variable axis first,
# so it is moved to the end.
arr_data = np.moveaxis(ds_all[["ndvi", "ndwi", "bsi"]].to_array().to_numpy(), 0, -1)

# Normalisation: StandardScaler expects (samples, features), so the two
# spatial axes are flattened for scaling and restored afterwards.
# (MinMaxScaler, imported above, is an alternative scaler.)
rows, columns, n_band = arr_data.shape
arr_data = np.reshape(arr_data, (rows * columns, n_band))
arr_data = StandardScaler().fit_transform(arr_data)
arr_data = np.reshape(arr_data, (rows, columns, n_band))

# Segmentation with quickshift (slic, imported above, is an alternative).
# convert2lab=False because the bands are indices, not RGB colour channels.
segments = quickshift(
    arr_data,
    ratio=1.0,
    kernel_size=3,
    max_dist=10,
    sigma=0,
    convert2lab=False,
)

# Wrap the label image as a DataArray on the same grid as the inputs.
# - dims/coords are taken from a 2-D variable rather than from the Dataset,
#   so the dimension order is guaranteed to match the (rows, columns) array.
# - int32 instead of int16: segment labels above 32767 would otherwise wrap
#   around silently and merge unrelated segments.
segments = xr.DataArray(
    segments.astype(np.int32),
    coords=ds_all["ndvi"].coords,
    dims=ds_all["ndvi"].dims,
    attrs=ds_all.attrs,
)

### Object-level features calculation

In [None]:
# Vectorise the segment raster into polygons and write them to disk.
# NOTE(review): xr_vectorize is not imported in the notebook's import cell;
# presumably dea_tools.spatial.xr_vectorize - confirm and import it before
# running this cell.
output_path = "native_species_data/segmentation.geojson"
gdf = xr_vectorize(segments, output_path=output_path, connectivity=8)

# Calculate zonal statistics for all bands.
# The transform is computed once from ds_all, the dataset the band arrays
# are read from, instead of being recomputed on every loop iteration.
affine = ds_all.rio.transform()
var_names = list(ds_all.keys())
list_stats = ["median", "std", "percentile_10", "percentile_90"]

with fiona.open(output_path) as zones:
    # Accumulator: one row per segment, one column per <band>_<statistic>.
    gpd_stats = None
    for var in var_names:
        band = ds_all[var].to_numpy()
        zonestats = zonal_stats(
            zones,
            band,
            stats=list_stats,
            all_touched=True,
            geojson_out=True,
            affine=affine,
        )

        # Convert the GeoJSON-like features into a GeoDataFrame.
        pd_stats = pd.DataFrame.from_dict(zonestats)
        geoms = [shape(geom) for geom in pd_stats["geometry"]]
        gpd_var = gpd.GeoDataFrame(pd_stats, geometry=geoms, crs="EPSG:3577")

        # Unpack the statistics into new attributes, band name as prefix.
        for stat in list_stats:
            gpd_var[f"{var}_{stat}"] = [props[stat] for props in gpd_var["properties"]]

        # Remove redundant attributes.
        gpd_var = gpd_var.drop(["properties", "type"], axis=1)

        # The per-band frame must be distinct from the accumulator: reusing
        # one name made the `is None` branch unreachable and discarded the
        # statistics of every band except the last one.
        if gpd_stats is None:
            gpd_stats = gpd_var.copy()
        else:
            gpd_stats = pd.concat(
                [gpd_stats, gpd_var.drop(["id", "geometry"], axis=1)],
                axis=1,
            )
    gpd_stats.to_file("native_species_data/segmentation_stats.geojson")

### Model training

#### Get presence samples

#### Generate pseudo-absence samples

### Prediction