# VegFrax

## Initialise VegFrax

### Load packages

In [None]:
%matplotlib inline
%load_ext autoreload

import os
import sys
import gdal
import pandas as pd
import numpy as np
import xarray as xr
import datacube
import matplotlib.pyplot as plt

sys.path.append('../Scripts')
from dea_datahandling import load_ard
from dea_dask import create_local_dask_cluster
from dea_plotting import display_map, rgb

sys.path.append('./scripts')
import vegfrax

sys.path.append('./gdvtools')
import gdvtools

### Set up a dask cluster

In [None]:
# Initialise a local dask cluster using the DEA helper imported from dea_dask.
# Paste the cluster url into the dask panel for more info.
create_local_dask_cluster()

## Load classified and raw raster data

### Load and validate high-res classified raster

In [None]:
# High-resolution classified image (e.g. 10m Sentinel 2 or 1m WV). It is kept
# in a list because the gdvtools helpers below are called with a list of paths.
rast_class = [r'./data/fraser/class/fraser_10m_class.tif']

# check basic raster properties before loading
gdvtools.validate_rasters(rast_class)

# read the raster into an xarray dataset, with -9999 as the nodata value
ds_class = gdvtools.rasters_to_dataset(rast_class, nodata_value=-9999)

# give the first data variable the fixed name 'classes'
class_band = list(ds_class.data_vars)[0]
ds_class = ds_class.rename({class_band: 'classes'})  # todo - own func?

# collect the unique classes present in the data
np_classes = vegfrax.get_dataset_classes(ds_class, nodata_value=-9999)

### Load low-res raw (local raster approach)

In [None]:
# Pre-existing low-resolution raw rasters of the same area, one file per band
# (the files listed here are named as MODIS 500m bands 1-7). The list is
# sorted so the rasters are in band order.
rast_raw = sorted([
    './data/fraser/raw/modis_b1_500m_201905.tif',
    './data/fraser/raw/modis_b2_500m_201905.tif',
    './data/fraser/raw/modis_b3_500m_201905.tif',
    './data/fraser/raw/modis_b4_500m_201905.tif',
    './data/fraser/raw/modis_b5_500m_201905.tif',
    './data/fraser/raw/modis_b6_500m_201905.tif',
    './data/fraser/raw/modis_b7_500m_201905.tif'
])

# check basic raster properties before loading
gdvtools.validate_rasters(rast_raw)

# read the rasters into an xarray dataset, with -9999 as the nodata value
ds_raw = gdvtools.rasters_to_dataset(rast_raw, nodata_value=-9999)

# display dataset
#print(ds_raw)

### Load low-res raw (ODC approach)

In [None]:
# Open a datacube connection; 'vegfrax' is the app name passed to the datacube.
# The resulting dc handle is passed to load_ard in the ODC load cell below.
dc = datacube.Datacube(app='vegfrax')

In [None]:
# Study area extent as a pair of latitudes and a pair of longitudes.
# This covers all (and more) of the high-res raster.
lat_ext = (-24.675736, -25.838889)
lon_ext = (152.765831, 153.420994)

# display the extent on an interactive map
display_map(x=lon_ext, y=lat_ext)

In [None]:
# bands to request from the datacube
measurements = [
    'nbart_blue',
    'nbart_green',
    'nbart_red',
    'nbart_nir_1',
    'nbart_swir_2',
    'nbart_swir_3'
]

# products to load from
s2_products = ['s2a_ard_granule', 's2b_ard_granule']

# spatial/temporal query: study extent, date range, bands, and a 500m output
# grid in EPSG:32756, with observations grouped by solar day
query = dict(
    x=lon_ext,
    y=lat_ext,
    time=('2019-03', '2019-07'),
    measurements=measurements,
    output_crs='EPSG:32756',
    resolution=(-500, 500),
    group_by='solar_day',
    # align=(0, 0)
)

# load the downscaled data lazily (one dask chunk per time step), with
# min_gooddata set to 0.90
ds_raw = load_ard(
    dc=dc,
    products=s2_products,
    min_gooddata=0.90,
    dask_chunks={'time': 1},
    **query
)

# Combine all time steps into a single per-pixel median composite and bring
# the result into memory. keep_attrs=True carries the dataset attributes
# through the reduction (xarray drops them by default); the export cell
# further down copies ds_raw.attrs onto the output, so they need to survive.
ds_raw = ds_raw.median('time', keep_attrs=True).compute()

# display dataset
#print(ds_raw)

# plot rgb of data
#rgb(ds_raw[['nbart_red', 'nbart_green', 'nbart_blue']])

## Prepare random samples

### Generate random samples from raw and classified raster overlaps

In [None]:
# number of random sample points to generate for training
num_samples = 1000

# draw random samples from the area where the raw and classified rasters overlap
df_samples = vegfrax.generate_random_samples(
    ds_raw,
    ds_class,
    num_samples=num_samples,
    nodata_value=-9999
)

# display result
#print(df_samples)

### Extract values from raw rasters (e.g. low resolution image bands)

In [None]:
# pull the raw, low resolution pixel values at each sample point, keeping the
# x/y columns alongside the extracted values
df_extract = gdvtools.extract_dataset_values(
    ds=ds_raw,
    coords=df_samples,
    keep_xy=True,
    nodata_value=-9999
)

# display result
#print(df_extract)

### Remove any no data values from extraction

In [None]:
# Remove any sample points containing a nodata value (-9999) from the
# extracted values. The cleaned table is what the window-building and FCA
# cells below are given.
df_extract_clean = gdvtools.remove_nodata_records(df_extract, nodata_value=-9999)

# display result
#print(df_extract_clean)

## Generate and prepare frequency windows and analysis data

### Build raw focal window extents from random sample points

In [None]:
# Generate a focal window at each cleaned sample point and extract the
# pixels from the class raster within it.
# NOTE(review): window size is presumably derived from the raw raster's pixel
# size - confirm in vegfrax.create_frequency_windows.
df_windows = vegfrax.create_frequency_windows(ds_raw, ds_class, df_extract_clean)

# display result
#print(df_windows)

### Convert raw windows to class frequency information

In [None]:
# Transform the raw focal window pixel classes and counts into unique classes
# and frequencies at each point, using -9999 as the nodata value.
df_freqs = vegfrax.convert_window_counts_to_freqs(df_windows, nodata_value=-9999)

# display result
#print(df_freqs)

### Select desired analysis classes and prepare data for analysis

In [None]:
# Classes to produce output for. Leave empty to produce all classes; list
# specific labels (e.g. 1 and 2) to restrict the output to those classes.
# NOTE(review): confirm whether labels should be ints or strings - the plot
# cell later indexes the predictions with the string '6'.
override_classes = []

# prepare data for analysis - prepare classes, nulls, normalise frequencies
df_data = vegfrax.prepare_freqs_for_analysis(
    ds_raw,
    ds_class,
    df_freqs,
    override_classes,
    nodata_value=-9999
)

# display result
#print(df_data)

## Perform Fractional Cover Analysis (FCA)

### Set FCA modelling parameters

In [None]:
# Candidate values for the grid search fit optimiser (see the sklearn
# GridSearchCV docs for more); every listed value is a candidate to try.
grid_params = dict(
    max_depth=[3, 5, 7, 9],
    n_estimators=[10, 100, 250, 500],
)

# number of train/test set cross-validations (higher is better, but slower)
validation_iters = 50

### Perform the FCA 

In [None]:
# Perform the fractional cover analysis (FCA) from the raw and classified
# datasets, the prepared class frequencies and the cleaned extracted values,
# using the grid search parameters and validation iteration count set above.
# The result is indexed by class label in the cells below.
ds_preds = vegfrax.perform_fca(ds_raw, ds_class, df_data, df_extract_clean, 
                               grid_params, validation_iters, nodata_value=-9999)

## Display a vegetation class

In [None]:
# label of the class to plot (e.g. 6 = shrubland, 1 = Eucalyptus woodland)
class_label = '6'

# create a portrait figure for the map
fig = plt.figure(figsize=(6, 9))

# select the predictions for this class and plot them on the current figure
da_class = ds_preds[class_label]
da_class.plot(cmap='terrain_r')

In [None]:
from datacube.utils.cog import write_cog

# take the predictions for the selected class and give them the raw
# dataset's attributes before export
da = ds_preds[class_label]
da.attrs = ds_raw.attrs

# write the class to a cloud-optimised GeoTIFF named after its label,
# replacing any existing file and flagging -9999 as nodata
out_fname = f'class_{class_label}.tif'
write_cog(geo_im=da, fname=out_fname, overwrite=True, nodata=-9999)

In [None]:
# Copy the CRS entry from the raw dataset's attributes onto the predictions.
# attrs is a plain dict, so the entry has to be looked up by key (attribute
# access such as attrs.crs raises AttributeError), and only that key is set
# so the rest of ds_preds.attrs is left intact. Nothing is copied when the
# raw dataset carries no 'crs' attribute.
if 'crs' in ds_raw.attrs:
    ds_preds.attrs['crs'] = ds_raw.attrs['crs']

In [None]:
# inspect the attributes attached to the exported class array
da.attrs