# VegFrax

## Initialise VegFrax

### Load packages

In [None]:
%matplotlib inline
%load_ext autoreload

import os
import sys
import gdal
import pandas as pd
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

import datacube
sys.path.append('../../../Scripts')
from dea_datahandling import load_ard
from dea_dask import create_local_dask_cluster
from dea_plotting import display_map, rgb

sys.path.append('../../modules')
import vegfrax

sys.path.append('../../shared')
import satfetcher, tools

#sys.path.append('./gdvtools')
#import gdvtools

### Set up a dask cluster

In [None]:
# start a local dask cluster to back the lazy (use_dask=True) loads below;
# paste the dashboard url it reports into the dask panel for more detail
create_local_dask_cluster()

## Load and validate high-res classified local raster

In [None]:
# high-resolution classified raster(s) covering the study area
# (e.g. 10m Sentinel 2 or 1m WV classification)
rast_class = [r'../../data/vegfrax/class/roy_veg_class.tif']

# sanity-check basic raster properties before loading anything
tools.validate_rasters(rast_class)

# open the raster(s) lazily as an xarray dataset, forcing nodata to -9999
ds_class = satfetcher.load_local_rasters(
    rast_path_list=rast_class,
    use_dask=True,
    conform_nodata_to=-9999,
)

# materialise the lazy dataset in memory
ds_class = ds_class.compute()

# give the first data variable a stable name ('classes') for the steps below
ds_class = ds_class.rename(
    {list(ds_class.data_vars)[0]: 'classes'}
)

# collect the unique class values present in the classified data
np_classes = vegfrax.get_dataset_classes(ds_class, nodata_value=-9999)

### Load low-res raw imagery (local raster approach — currently disabled)

In [None]:
"""
# set path to pre-existing low-resolution raw image of same area (e.g. landsat)
rast_raw = [
    './data/fraser/raw/modis_b1_500m_201905.tif',
    './data/fraser/raw/modis_b2_500m_201905.tif',
    './data/fraser/raw/modis_b3_500m_201905.tif',
    './data/fraser/raw/modis_b4_500m_201905.tif',
    './data/fraser/raw/modis_b5_500m_201905.tif',
    './data/fraser/raw/modis_b6_500m_201905.tif',
    './data/fraser/raw/modis_b7_500m_201905.tif'
]

# sort rasters in band order
rast_raw.sort()

# validate basic raster properties
gdvtools.validate_rasters(rast_raw)

# load raster as an xarray dataset and set nodata to -9999
ds_raw = gdvtools.rasters_to_dataset(rast_raw, nodata_value=-9999)

# display dataset
#print(ds_raw)
"""

## Load lower-res Sentinel data (ODC approach)

### Set study area, time range, show map

In [None]:
# connect to the open data cube
# (note: `dc` is not passed to satfetcher.load_dea_ard in the cells below)
dc = datacube.Datacube(app='vegfrax')

In [None]:
# test study area extent for the Roy Hill site, as (min, max) degree pairs
lat_extent = (-22.63461, -22.33461)
lon_extent = (119.88111, 120.18111)

# preview the study area extent on an interactive map
display_map(x=lon_extent, y=lat_extent)

In [None]:
# short name for the study area
study_area = 'royhill'

# query period as (start, end) year-month strings
time_range = ('2016-10', '2016-11')

# datacube query parameters
platform = 'sentinel'

# sentinel band names to load; the alternative (presumably landsat) set was:
# ['nbart_blue', 'nbart_green', 'nbart_red', 'nbart_nir', 'nbart_swir_1', 'nbart_swir_2']
bands = [
    'nbart_blue',
    'nbart_green',
    'nbart_red',
    'nbart_nir_1',
    'nbart_swir_2',
    'nbart_swir_3',
]

# good-data threshold handed to satfetcher.load_dea_ard
# (presumably scenes with a lower valid-pixel fraction are dropped)
min_gooddata = 0.90

# pull the requested bands for the extent and period from the dea ard product,
# lazily via dask
ds_raw = satfetcher.load_dea_ard(
    platform=platform,
    bands=bands,
    x_extent=lon_extent,
    y_extent=lat_extent,
    time_range=time_range,
    min_gooddata=min_gooddata,
    use_dask=True,
)

# display dataset
#ds_raw

### Conform band names

In [None]:
# replace the platform-specific dea band names with a common standard
ds_raw = satfetcher.conform_dea_ard_band_names(
    ds=ds_raw,
    platform=platform,
)

# display dataset
#ds_raw

### Calculate tasselled cap bands

In [None]:
# calculate the tasselled cap indices: greenness (tcg), brightness (tcb)
# and wetness (tcw). With drop=True, the source bands are presumably
# discarded so only the indices remain — confirm in tools.calculate_indices.
ds_raw = tools.calculate_indices(ds=ds_raw,
                                 index=['tcg', 'tcb', 'tcw'],
                                 custom_name=None,
                                 rescale=False,
                                 drop=True)

# display dataset
#ds_raw

### Reduce to a median across all time steps

In [None]:
# collapse the time dimension to a per-pixel median (attributes preserved)
# and load the result into memory in one step
ds_raw = ds_raw.median('time', keep_attrs=True).compute()

In [None]:
# quick look at the wetness index using a robust (percentile-clipped) colour range
ds_raw['tcw'].plot(robust=True)

## Prepare random samples

### Generate random samples within the overlap of the raw and classified rasters

In [None]:
# how many random training points to draw
num_samples = 1000

# draw the random points inside the area shared by the raw and classified rasters
df_samples = vegfrax.generate_random_samples(
    ds_raw,
    ds_class,
    num_samples=num_samples,
    nodata_value=-9999,
)

# display result
#print(df_samples)

### Extract values from raw rasters (e.g. low resolution image bands)

In [None]:
# sample the low-resolution raw raster at every random point
# (keep_xy=True presumably keeps the point coordinates in the output)
df_extract = tools.extract_xr_values(
    ds=ds_raw,
    coords=df_samples,
    keep_xy=True,
    nodata_value=-9999,
)

# display result
#print(df_extract)

### Remove any no data values from extraction

In [None]:
# discard every sampled record that contains the -9999 nodata value
df_extract_clean = tools.remove_nodata_records(
    df_extract,
    nodata_value=-9999,
)

# display result
#print(df_extract_clean)

## Generate and prepare frequency windows and analysis data

### Build raw focal window extents from random sample points

In [None]:
# build a focal window around each clean sample point and pull the
# classified-raster pixels that fall inside it
df_windows = vegfrax.create_frequency_windows(
    ds_raw,
    ds_class,
    df_extract_clean,
)

# display result
#print(df_windows)

### Convert raw windows to class frequency information

In [None]:
# turn each point's raw window pixel classes/counts into unique classes
# and their frequencies
df_freqs = vegfrax.convert_window_counts_to_freqs(
    df_windows,
    nodata_value=-9999,
)

# display result
#print(df_freqs)

### Select desired analysis classes and prepare data for analysis

In [None]:
# classes to model, given as string labels (e.g. ['1', '2'] for classes 1 and 2);
# leave the list empty to produce every class
override_classes = ['2', '3', '10']

# get the frequencies ready for modelling: pick classes, deal with nulls,
# normalise frequencies
df_data = vegfrax.prepare_freqs_for_analysis(
    ds_raw,
    ds_class,
    df_freqs,
    override_classes,
    nodata_value=-9999,
)

# display result
#print(df_data)

## Perform Fractional Cover Analysis (FCA)

### Set FCA modelling parameters

In [None]:
# hyperparameter grid for the fit optimiser, in the same shape as the
# param_grid of sklearn's GridSearchCV
grid_params = dict(
    max_depth=[3, 5, 7],
    n_estimators=[10, 100, 250],
)

# number of train/test cross-validation rounds
# (more gives a steadier estimate, at the cost of run time)
validation_iters = 50

### Perform the FCA 

In [None]:
# run the fractional cover analysis; the result is indexed by class label below
ds_preds = vegfrax.perform_fca(
    ds_raw, ds_class, df_data, df_extract_clean,
    grid_params, validation_iters,
    nodata_value=-9999,
)

## Display a vegetation class

In [None]:
# string label of the predicted class layer to map; class codes depend on the
# input classification (the original note gave e.g. 6 = shrubland,
# 1 = Eucalyptus woodland — confirm against the classified raster)
class_label = '3'

# draw that class's prediction layer on a fresh 9 x 9 figure
fig = plt.figure(figsize=(9, 9))
ds_preds[class_label].plot(cmap='terrain_r')

In [None]:
from datacube.utils.cog import write_cog

# Indexing a Dataset returns a DataArray that shares its underlying variable
# with the Dataset. Take a shallow copy so that replacing the attributes
# below does not also overwrite the attributes stored in ds_preds.
da = ds_preds[class_label].copy(deep=False)

# Carry the raw dataset's attributes across. These are assumed to hold the
# georeferencing metadata write_cog needs — TODO confirm.
da.attrs = dict(ds_raw.attrs)

# export the selected class layer as a cloud-optimised geotiff, e.g. class_3.tif
write_cog(geo_im=da,
          fname='class_{0}.tif'.format(class_label),
          overwrite=True,
          nodata=-9999)