In [1]:
# ensemble modeller (Dempster-Shafer evidence combination)

# import various evidence layers
# build frame of discernment
# e.g. rescale images > 0.8 for each image
# which layers are gdv (site) areas, which layers are non-gdv (non-site) areas?

# superset = [L, P, S, F, C]

# L = GDV Likelihood
# P = Phenometric (e.g. LIOT)
# S = SDM
# F = Fractional Map
# C = CHM

# determine if each layer is site or non-site evidence.
# assign BPAs (basic probability assignments)
# e.g. if a bird prefers sw aspects, sw aspect gets 0.75, flat is given 0.5, and the rest (superset) gets 0.25
# e.g. use an increasing sigmoidal function to rescale values from 0 to 1
# combine site hypotheses (e.g. site layers)
# e.g. site (vege layer) = 0.8, rest (ignorance) gets 0.2.
# e.g. site (aspect) = 0.9, rest (ignorance) gets 0.1
# prepare matrix for site hypotheses (vege, aspect)
# cross-multiply sites (veg x aspect) and ignorance
# next matrix for non-sites 1, 2: cross-multiply, sum values
# new matrix for non-sites 2, 3: cross-multiply, sum values
# add results of those two into a matrix, cross-multiply, then sum values
# take result from site, non-site, superset, and put in new matrix
# cross-multiply, sum up, and normalise by (1 - K), where K is the mass assigned to empty (conflicting) intersections
# calc belief by summing any site values, calc disbelief by summing any non-site values
# do this per pixel
# you get belief, disbelief, plausibility, and the belief interval (plausibility - belief)
# you get disbelief, plaus, belief interval (plaus - belief) and belief


# Ensemble

## Initialise Ensemble

### Load packages

In [2]:
%matplotlib inline
%load_ext autoreload

import os, sys
import pandas as pd
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt

import datacube
sys.path.append('../../../Scripts')
from dea_datahandling import load_ard
from dea_dask import create_local_dask_cluster
from dea_plotting import display_map, rgb

sys.path.append('../../modules')
import gdvspectra, phenolopy, nicher, vegfrax, canopy, ensemble

sys.path.append('../../shared')
import satfetcher, tools

  shapely_geos_version, geos_capi_version_string


### Set up a dask cluster and ODC

In [3]:
# initialise the cluster. paste url into dask panel for more info.
# later cells load data lazily (use_dask=True) and call .persist()/.compute(),
# so this cell must run before them.
create_local_dask_cluster()

# open up a datacube connection
# NOTE(review): `dc` is not referenced by any later cell shown here (data is
# fetched through satfetcher) — confirm whether it is still needed.
dc = datacube.Datacube(app='ensemble')

0,1
Client  Scheduler: tcp://127.0.0.1:46545  Dashboard: /user/lewis/proxy/8787/status,Cluster  Workers: 1  Cores: 2  Memory: 13.11 GB


  username=username, password=password,


## Study area and data setup

### Set study area and show map

In [4]:
# study area extents as (min, max) pairs for the test sites.
# to run the yandi site instead, swap in:
#   lat (-22.82901, -22.67901), lon (118.94980, 119.29979)
lat_extent = (-22.63461, -22.33461)  # royhill latitude range
lon_extent = (119.88111, 120.18111)  # royhill longitude range

# show the study area on an interactive map
display_map(y=lat_extent, x=lon_extent)

# Generate GDV Likelihood

### Load and prepare DEA ODC satellite data

In [None]:
# provide study area name
# (in this notebook it is only used to build output file names, e.g. out_like_nc;
# the spatial extent comes from lat_extent / lon_extent — keep the two in sync)
study_area = 'royhill'

# select start and end year range
time_range = ('2016', '2020')

# set datacube query parameters
platform = 'landsat'
bands = ['nbart_blue', 'nbart_green', 'nbart_red', 'nbart_nir', 'nbart_swir_1', 'nbart_swir_2']
#bands = ['nbart_blue', 'nbart_green', 'nbart_red', 'nbart_nir_1', 'nbart_swir_2'] # sentinel
# presumably the minimum fraction of good (unmasked) pixels a scene needs to be kept — confirm in satfetcher
min_gooddata = 0.90

# fetch satellite data from dea ard product (lazily, via dask)
ds = satfetcher.load_dea_ard(platform=platform, 
                             bands=bands, 
                             x_extent=lon_extent, 
                             y_extent=lat_extent, 
                             time_range=time_range, 
                             min_gooddata=min_gooddata, 
                             use_dask=True)

# rename dea bands to common standard
ds = satfetcher.conform_dea_ard_band_names(ds=ds, platform=platform)

# take a copy of dataset for cva later
# (the gdv pipeline below rewrites `ds` with inplace=True calls; the phenometric
# and vegfrax sections start again from this untouched copy)
ds_backup = ds.copy(deep=True)

# display dataset
#ds

### Calculate standardised vege/moist data

In [None]:
# set wet and dry season month(s). we will use several per season
# (wet = jan-mar, dry = sep-nov; both lists are passed to every gdvspectra step below)
wet_month, dry_month = [1, 2, 3], [9, 10, 11]

# get subset of data for wet and dry season months
ds = gdvspectra.subset_months(ds=ds, 
                              month=wet_month + dry_month,
                              inplace=True)

# calculate veg (mavi) and moist (ndmi) indices, renamed to veg_idx / mst_idx
ds = tools.calculate_indices(ds=ds, 
                             index=['mavi', 'ndmi'], 
                             custom_name=['veg_idx', 'mst_idx'], 
                             rescale=True, 
                             drop=True)

# perform resampling to wet and dry season medians
# (presumably one wet and one dry median per year — confirm in gdvspectra)
ds = gdvspectra.resample_to_wet_dry_medians(ds=ds, 
                                            wet_month=wet_month, 
                                            dry_month=dry_month,
                                            inplace=True)

# we have some calcs to make, persist now
ds = ds.persist()

# drop any years from dataset where wet and dry seasons missing
ds = gdvspectra.drop_incomplete_wet_dry_years(ds)

# todo - remove this compute when bug fixed in ver > 0.18.2
# NOTE(review): the package whose 0.18.2 bug this works around is not named here.
# every step after this point operates on in-memory data.
ds = ds.compute()

# fill any empty first, last years using back/forward fill
ds = gdvspectra.fill_empty_wet_dry_edges(ds=ds,
                                         wet_month=wet_month, 
                                         dry_month=dry_month,
                                         inplace=True)

# interpolate all missing pixels using full linear interpolation
ds =  gdvspectra.interp_empty_wet_dry(ds=ds,
                                      wet_month=wet_month, 
                                      dry_month=dry_month,
                                      method='full', 
                                      inplace=True)

# standardise data to invariant targets derived from dry times
# (q_upper / q_lower are the quantiles used to pick those targets)
ds = gdvspectra.standardise_to_dry_targets(ds=ds, 
                                           dry_month=dry_month, 
                                           q_upper=0.99, 
                                           q_lower=0.05,
                                           inplace=True)

# calculate standardised seasonal similarity (diff between wet, dry per year)
ds_similarity = gdvspectra.calc_seasonal_similarity(ds=ds,
                                                    wet_month=wet_month,
                                                    dry_month=dry_month,
                                                    q_mask=0.9,
                                                    inplace=True)

### Generate likelihood model

In [None]:
# build the gdv likelihood model from the standardised wet, dry and similarity variables
ds_like = gdvspectra.calc_likelihood(ds=ds,
                                     ds_similarity=ds_similarity,
                                     wet_month=wet_month,
                                     dry_month=dry_month)

# preview the all-time median likelihood (jet colormap: red = high likelihood)
like_median = ds_like['like'].median('time')
fig = plt.figure(figsize=(10, 5))
like_median.plot(robust=False, cmap='jet')

# output netcdf name: <study area>_<start year>_<end year>_<platform>_like.nc
# NOTE(review): out_like_nc is not passed to any export call in this notebook.
out_like_nc = '../../{0}_{1}_{2}_{3}_like.nc'.format(
    study_area, time_range[0], time_range[1], platform)

### Generate field occurrence points for thresholding

In [None]:
# field presence/absence records used to threshold and evaluate the likelihood.
# GDV_ACT is the presence/absence column kept by subset_records (with x, y).
shp_path = r'../../data/gdvspectra/royhill_2_final_albers.shp'
df_records = tools.subset_records(df_records=tools.read_shapefile(shp_path=shp_path),
                                  p_a_column='GDV_ACT')

### Threshold likelihood

In [None]:
# threshold the all-time median likelihood using standard deviations, guided by the field records
like_median_ds = ds_like.median('time', keep_attrs=True)
ds_thresh = gdvspectra.threshold_likelihood(ds=like_median_ds,
                                            df=df_records,
                                            num_stdevs=2.5,
                                            res_factor=3,
                                            if_nodata='any')

# preview the thresholded likelihood; pixels removed by the threshold (NaN) are drawn as 0.001
ds_thresh_preview = ds_thresh.where(~ds_thresh.isnull(), 0.001)
fig = plt.figure(figsize=(10, 5))
ds_thresh_preview['like'].plot(robust=False, cmap='jet')

### Export GDV likelihood and threshold for later

In [None]:
# export median likelihood to nc
#tools.export_xr_as_nc(ds=ds_like['like'].median('time'), filename='ds_like.nc')

# export threshold as nc
#tools.export_xr_as_nc(ds_thresh, filename='ds_threshold.nc')

### Calculate trends using Mann-Kendall trend analysis

In [None]:
# boolean mask of pixels that survived the likelihood threshold (non-null = gdv highly likely)
ds_mask = ds_thresh.notnull()

# mann-kendall test for significant increasing or decreasing trends in those pixels only
ds_mk = gdvspectra.perform_mk_original(ds=ds_like.where(ds_mask), pvalue=None, direction='both')

# plot kendall's tau. blue is increasing, red is decreasing
fig = plt.figure(figsize=(10, 5))
ds_mk['tau'].plot(robust=True, cmap='Spectral')

### Calculate slope using Theil-Sen

In [None]:
# boolean mask of pixels that survived the likelihood threshold (non-null = gdv highly likely)
ds_mask = ds_thresh.notnull()

# theil-sen slope of likelihood over time, computed only inside the mask
ds_ts = gdvspectra.perform_theilsen_slope(ds=ds_like.where(ds_mask), alpha=0.95)

# show theil-sen slopes. blue is increasing, red is decreasing
fig = plt.figure(figsize=(10, 5))
ds_ts['theilsen'].plot(robust=True, cmap='Spectral')

# Generate Phenometrics

### Calculate and pre-process vegetation

In [None]:
# build a single-year, weekly vegetation index (mavi) curve from the untouched
# ard copy. each step feeds the next, so the order below matters.
# takes our dask ds and calculates veg index from spectral bands
ds = tools.calculate_indices(ds=ds_backup, 
                             index='mavi', 
                             custom_name='veg_idx', 
                             rescale=False, 
                             drop=True)

# conform edges
ds = phenolopy.conform_edge_dates(ds=ds)

# resample to weekly medians
ds = phenolopy.resample(ds=ds, 
                        interval='1W',
                        inplace=True)

# interpolate missing values
ds = phenolopy.interpolate(ds=ds, 
                           method='full', 
                           inplace=True)

# group into single year of weekly all-time medians
ds = phenolopy.group(ds=ds, 
                     interval='week',
                     inplace=True)

# takes our dask ds and remove outliers from data using median method
ds = phenolopy.remove_outliers(ds=ds, 
                               method='median', 
                               user_factor=2, 
                               z_pval=0.05)

# take dataset and resample data to weekly medians again (interval '1W')
# NOTE(review): this is the second weekly resample; it follows group() and
# outlier removal — confirm it is still required.
ds = phenolopy.resample(ds=ds, 
                        interval='1W',
                        inplace=True)

# remove any years outside of dominant year
ds = phenolopy.remove_overshoot_times(ds=ds, max_times=3)

# use Savitzky-Golay filter to smooth across time dimension
# (the method key is spelled 'savitsky' by phenolopy)
ds = phenolopy.smooth(ds=ds, 
                      method='savitsky', 
                      window_length=3, 
                      polyorder=1, 
                      sigma=1)

### Calculate phenometrics

In [None]:
# compute into memory
ds = ds.compute()

# set desired metrics
# NOTE: this list is commented out and is not passed to calc_phenometrics;
# later cells read 'liot_values' and 'vos_values' from the result.
#metrics = ['sos', 'eos', 'lios', 'sios', 'liot', 'siot']

# calc phenometrics via phenolopy!
# season amplitude method: peak = pos, base = vos, factor 0.2, one-sided threshold
ds = phenolopy.calc_phenometrics(ds=ds,
                                 peak_metric='pos', 
                                 base_metric='vos', 
                                 method='seasonal_amplitude', 
                                 factor=0.2, 
                                 thresh_sides='one_sided', 
                                 abs_value=0.1)

### Display phenometric

In [None]:
# pick one phenometric layer and plot it with a robust colour stretch
metric_name = 'liot_values'
da_metric = ds[metric_name]

fig = plt.figure(figsize=(9, 7), dpi=85)
da_metric.plot(robust=True, cmap='terrain_r')

### Generate field occurrence points for thresholding

In [None]:
# field presence/absence records used to evaluate the phenometrics.
# GDV_ACT is the presence/absence column kept by subset_records (with x, y).
shp_path = r'../../data/gdvspectra/royhill_2_final_albers.shp'
df_records = tools.subset_records(df_records=tools.read_shapefile(shp_path=shp_path),
                                  p_a_column='GDV_ACT')

### Check AUC of metrics

In [None]:
# prepare metric for auc
# the metric's 'variable' dim is renamed to 'time' and the metric itself to
# 'like' so it has the same layout as the likelihood passed to
# threshold_likelihood earlier (assumed requirement — confirm in gdvspectra).
# NOTE: despite its name, `da` is an xarray Dataset here (to_dataset).
metric = 'vos_values'
da = ds[metric].rename({'variable': 'time'}).to_dataset(promote_attrs=True)
da = da.rename({metric: 'like'})

# threshold to get auc (the result is only displayed, not stored)
gdvspectra.threshold_likelihood(ds=da,
                                df=df_records, 
                                num_stdevs=2.5, 
                                res_factor=3, 
                                if_nodata='any')

In [None]:
# export to nc
#tools.export_xr_as_nc(ds=ds, filename='ds_phenometrics.nc')

# Generate Species Distribution Model

### Set up data

In [None]:
# get continuous rasters. we will use lidar
folder_path = r'../../data/nicher/roy_lidar'
rast_cont_list = nicher.get_files_from_path(folder_path)

# drop any undesirable vars 
# (list.remove raises ValueError if any of these files is not in the folder)
rast_cont_list.remove(folder_path + '/' + 'chm_lidar_10m.tif')
rast_cont_list.remove(folder_path + '/' + 'dem_lidar_10m_fill.tif')
rast_cont_list.remove(folder_path + '/' + 'dem_lidar_10m.tif')
rast_cont_list.remove(folder_path + '/' + 'dem_lidar_10m_tri.tif')

# load rasters as individual dataset variables, nodata conformed to -999
ds_sdm = satfetcher.load_local_rasters(rast_path_list=rast_cont_list, 
                                       use_dask=True, 
                                       conform_nodata_to=-999)

# compute dask - we need to make calculations
ds_sdm = ds_sdm.compute()

# set path to shapefile
shp_path = r'../../data/nicher/presence_points/presence_points.shp'

# extract point x and y from shapefile as pandas dataframe
# NOTE: this overwrites the gdv field records held in df_records by earlier cells
df_records = tools.read_shapefile(shp_path=shp_path)

# subset columns
df_presence = tools.subset_records(df_records=df_records, 
                                   p_a_column=None)

# drop presence column
# (assumes subset_records leaves an 'actual' column when p_a_column=None — confirm in tools)
df_presence = df_presence.drop('actual', axis='columns')

# generate absences using dataset pixels and occurrence coords
df_absence = nicher.generate_absences(ds=ds_sdm, 
                                      occur_shp_path=shp_path,
                                      buff_m=250, 
                                      res_factor=3)

# extract values for presence points
df_presence_data = tools.extract_xr_values(ds=ds_sdm, 
                                           coords=df_presence, 
                                           keep_xy=False, 
                                           res_factor=3)

# do same for absence points
df_absence_data = tools.extract_xr_values(ds=ds_sdm, 
                                          coords=df_absence, 
                                          keep_xy=False, 
                                          res_factor=3)

# remove all presence records containing nodata values
df_presence_data = tools.remove_nodata_records(df_records=df_presence_data,
                                               nodata_value=ds_sdm.nodatavals)

# remove all absence records containing nodata values
df_absence_data = tools.remove_nodata_records(df_records=df_absence_data,
                                               nodata_value=ds_sdm.nodatavals)

# take pres and abse records and combine, add new pres/abse column
df_pres_abse_data = nicher.combine_pres_abse_records(df_presence=df_presence_data, 
                                                     df_absence=df_absence_data)

# generate the matrix. < 0.6 weak collinearity, 0.6-0.8 moderate, >= 0.8 strong
nicher.generate_correlation_matrix(df_records=df_pres_abse_data,
                                   show_fig=True,
                                   show_text=False)

# generate vif scores. 1 = No multicollinearity, 1-5 = moderate, > 5 = high, > 10 = Remove
nicher.generate_vif_scores(df_records=df_pres_abse_data)
plt.show()

# create a random forest estimator with 100 trees
estimator = nicher.create_estimator(estimator_type='rf', 
                                    n_estimators=100)

# generate SDM with 5 replicates and 10% training-testing split
ds_sdm = nicher.generate_sdm(ds=ds_sdm, 
                             df_records=df_pres_abse_data, 
                             estimator=estimator, 
                             rast_cont_list=rast_cont_list, 
                             rast_cate_list=None, 
                             replicates=5, 
                             test_ratio=0.1, 
                             equalise_test_set=False, 
                             calc_accuracy_stats=True)

# show results (mean sdm across replicates)
fig = plt.figure(figsize=(9, 7), dpi=85)
ds_sdm['sdm_mean'].plot(robust=False, cmap='jet')

In [None]:
# export to nc
#tools.export_xr_as_nc(ds=ds_sdm, filename='ds_sdm.nc')

# Generate Veg Frax

### Prepare Landsat satellite image

In [None]:
# tasselled cap indices (tcg / tcb / tcw — presumably greenness, brightness and
# wetness) computed from the untouched ard copy, reduced to an all-time median
# and loaded into memory now because later steps modify the values.
ds_raw = (tools.calculate_indices(ds=ds_backup,
                                  index=['tcg', 'tcb', 'tcw'],
                                  custom_name=None,
                                  rescale=False,
                                  drop=True)
          .median('time', keep_attrs=True)
          .compute())

### Prepare classified hi-def raster

In [None]:
# set path to high-resolution classified image (e.g. 10m Sentinel 2 or 1m WV)
rast_class = r'../../data/vegfrax/class/Vegetation_Mapping_Mine_20181121_rasterised_albers.tif'
ds_class = satfetcher.load_local_rasters(rast_path_list=[rast_class], 
                                         use_dask=True, 
                                         conform_nodata_to=-128)

# do basic preparations (dtype, rename, checks)
ds_class = vegfrax.prepare_classified_xr(ds=ds_class)

# subset high to low extent (clip the classified raster to the extent of ds_raw)
ds_class = tools.clip_xr_to_xr(ds_a=ds_class, 
                               ds_b=ds_raw)

# set required classes manually (make sure you include 0 if reclassifying)
# (merge_classes=True presumably merges all other classes into 0 — confirm in vegfrax)
req_class = [1, 3, 13, 14, 0]
ds_class = vegfrax.reclassify_xr(ds=ds_class, 
                                 req_class=req_class,
                                 merge_classes=True,
                                 inplace=True)

# get list of all classes...
# NOTE: req_class is replaced here by the classes actually present after reclassifying
req_class = vegfrax.get_xr_classes(ds_class)

# load into memory now - we have values to modify!
ds_class = ds_class.compute()

### Prepare frequencies

In [None]:
# generate random samples within area overlap between raw and classified rasters
# (presumably stratified by req_class — confirm in vegfrax)
num_samples = 500
df_samples = vegfrax.generate_strat_random_samples(ds_raw=ds_raw,
                                                   ds_class=ds_class, 
                                                   req_class=req_class,
                                                   num_samples=num_samples)

# extract pixel values from raw, low resolution rasters at each point
df_extract = tools.extract_xr_values(ds=ds_raw, 
                                     coords=df_samples, 
                                     keep_xy=True)

# remove any points containing a nodata value
df_extract_clean = tools.remove_nodata_records(df_extract, 
                                               nodata_value=ds_raw.nodatavals)

# generate focal windows and extract pixels from class raster
df_windows = vegfrax.create_frequency_windows(ds_raw=ds_raw, 
                                              ds_class=ds_class, 
                                              df_records=df_extract_clean)

# transform raw focal window pixel classes and counts to unique classes and frequencies at each point
df_freqs = vegfrax.convert_window_counts_to_freqs(df_windows=df_windows, 
                                                  nodata_value=ds_class.nodatavals)

### Perform FCA

In [None]:
# set desired output classes. keep empty to produce all classes. could put 1, 2 for classes 1 and 2.
# (class labels are strings, e.g. ['1', '2'])
override_classes = ['1']

# prepare data for analysis - prepare classes, nulls, normalise frequencies
df_data = vegfrax.prepare_freqs_for_analysis(ds_raw=ds_raw, 
                                             ds_class=ds_class, 
                                             df_freqs=df_freqs, 
                                             override_classes=override_classes)

# perform fca
ds_preds = vegfrax.perform_fca(ds_raw=ds_raw, 
                               ds_class=ds_class, 
                               df_data=df_data, 
                               df_extract_clean=df_extract_clean, 
                               n_estimators=100,
                               n_validations=10)

# create fig
# NOTE(review): class_label is presumably required to be one of override_classes — keep them in sync
class_label = '1'
fig = plt.figure(figsize=(12, 9))
ds_preds[class_label].plot(robust=False, cmap='terrain_r')

### Generate field occurrence points for thresholding

In [None]:
# field presence/absence records used to evaluate the fractional cover prediction.
# GDV_ACT is the presence/absence column kept by subset_records (with x, y).
shp_path = r'../../data/gdvspectra/royhill_2_final_albers.shp'
df_records = tools.subset_records(df_records=tools.read_shapefile(shp_path=shp_path),
                                  p_a_column='GDV_ACT')

In [None]:
# prepare the fca prediction for auc. threshold_likelihood is given a dataset
# whose variable is named 'like' (as for the likelihood layer earlier).
# attrs (presumably crs / resolution metadata — confirm) are copied from ds_raw,
# the landsat median the predictions were generated on, NOT from `ds`, which is
# rebound by several other cells and so holds an unrelated layer here.
# the attrs are set on the renamed copy so ds_preds itself is left unmodified.
da = ds_preds.rename({class_label: 'like'})
da.attrs = dict(ds_raw.attrs)

# threshold to get auc (the result is only displayed, not stored)
gdvspectra.threshold_likelihood(ds=da,
                                df=df_records, 
                                num_stdevs=2.5, 
                                res_factor=3, 
                                if_nodata='any')

## Perform Ensemble

In [5]:
# input evidence rasters (exported from arcgis) for the royhill ensemble.
# set any entry to '' or None to leave that layer out of the ensemble.
like_path = r'../../data/ensemble/royhill/like.tif'
pheno_path = r'../../data/ensemble/royhill/pheno_liot.tif'
sdm_path = r'../../data/ensemble/royhill/sdm_lidar.tif'
frax_path = r'../../data/ensemble/royhill/vegfrax_class_3_vs_0.tif'
chm_path = r'../../data/ensemble/royhill/chm_lidar_10m.tif'

# map each '<layer>_path' key to its raster path, keeping only layers that were given a path
in_dict = {
    key: path
    for key, path in (('like_path', like_path),
                      ('pheno_path', pheno_path),
                      ('sdm_path', sdm_path),
                      ('frax_path', frax_path),
                      ('chm_path', chm_path))
    if path not in ['', None]
}

# the ensemble needs at least one input layer
if not in_dict:
    raise ValueError('No valid paths provided.')

In [6]:
# lazily load each input raster and store it as a squeezed data array keyed by
# layer name ('like_path' -> 'like', etc.). the per-file dataset uses its own
# name (ds_layer) so this loop does not rebind the notebook-level `ds` that
# earlier cells produced.
ds_dict = {}
for path_key, rast_path in in_dict.items():
    var_name = path_key.replace('_path', '')
    ds_layer = satfetcher.load_local_rasters(rast_path_list=rast_path, 
                                             use_dask=True, 
                                             conform_nodata_to=np.nan)

    # collapse the single-variable dataset into a data array
    ds_dict[var_name] = ds_layer.to_array().squeeze(drop=True)

Converting rasters to an xarray dataset.
Converted raster to xarray data array: like
Rasters converted to dataset successfully.

Converting rasters to an xarray dataset.
Converted raster to xarray data array: pheno_liot
Rasters converted to dataset successfully.

Converting rasters to an xarray dataset.
Converted raster to xarray data array: sdm_lidar
Rasters converted to dataset successfully.

Converting rasters to an xarray dataset.
Converted raster to xarray data array: vegfrax_class_3_vs_0
Rasters converted to dataset successfully.

Converting rasters to an xarray dataset.
Converted raster to xarray data array: chm_lidar_10m
Rasters converted to dataset successfully.



In [7]:
# choose the grid every other layer is resampled onto: pheno if present, otherwise like
for candidate in ('pheno', 'like'):
    if candidate in ds_dict:
        resampler_var = candidate
        break
else:
    raise ValueError('Ensemble must have a likelihood or phenometric var(s).')

In [8]:
# nearest-neighbour resample every layer except the target onto the target layer's grid
for layer_name in [name for name in ds_dict if name != resampler_var]:
    ds_dict[layer_name] = tools.resample_xr(ds_from=ds_dict[layer_name],
                                            ds_to=ds_dict[resampler_var],
                                            resampling='nearest')

In [9]:
# rescale each evidence layer with a sigmoidal membership function.
# each entry maps a layer name to a callable that rescales that layer; layers
# without an entry are left untouched. like / pheno / sdm use an increasing
# sigmoid ending at the layer max, chm a bell sigmoid, frax a decreasing sigmoid
# spanning the layer min..max.
# NOTE(review): meanings of a / b / bc / c / d are defined in canopy — confirm there.
sigmoid_rescalers = {
    'like':  lambda da: canopy.inc_sigmoid(ds=da, a=0.3, b=float(da.max())),
    'pheno': lambda da: canopy.inc_sigmoid(ds=da, a=5, b=float(da.max())),
    'sdm':   lambda da: canopy.inc_sigmoid(ds=da, a=0.1, b=float(da.max())),
    'chm':   lambda da: canopy.bell_sigmoid(ds=da, a=1, bc=6, d=11),
    'frax':  lambda da: canopy.dec_sigmoid(ds=da, c=float(da.min()), d=float(da.max())),
}

for layer_name in list(ds_dict):
    rescale = sigmoid_rescalers.get(layer_name)
    if rescale is not None:
        ds_dict[layer_name] = rescale(ds_dict[layer_name])

In [26]:
# inspect the resampled, sigmoid-rescaled evidence layers
ds_dict

{'pheno': <xarray.DataArray 'stack-234e19b4f96e460d73156c9da24c7378' (y: 1208, x: 1121)>
 dask.array<where, shape=(1208, 1121), dtype=float32, chunksize=(1208, 1121), chunktype=numpy.ndarray>
 Coordinates:
   * y        (y) float64 -2.46e+06 -2.46e+06 -2.46e+06 ... -2.496e+06 -2.496e+06
   * x        (x) float64 -1.236e+06 -1.236e+06 ... -1.202e+06 -1.202e+06,
 'sdm': <xarray.DataArray 'stack-bf5225d3e75bdd754da8cd996582e654' (y: 1208, x: 1121)>
 dask.array<where, shape=(1208, 1121), dtype=float32, chunksize=(1208, 1121), chunktype=numpy.ndarray>
 Coordinates:
   * x        (x) float64 -1.236e+06 -1.236e+06 ... -1.202e+06 -1.202e+06
   * y        (y) float64 -2.46e+06 -2.46e+06 -2.46e+06 ... -2.496e+06 -2.496e+06}

In [22]:
def perform_dempster_shafer(ds_dict):
    """
    Combine evidence layers with the Dempster-Shafer routine matching their names.

    The exact set of keys in ``ds_dict`` selects the routine in the ``ensemble``
    module. Extra or missing layers do not fall back to a smaller combination.

    Supported key sets and routines:
        {like, pheno, sdm, chm, frax} -> ensemble.bpa_all
        {like, sdm, chm, frax}        -> ensemble.bpa_lscv
        {pheno, sdm, chm, frax}       -> ensemble.bpa_pscv
        {pheno, sdm, frax}            -> ensemble.bpa_psv
        {pheno, chm, frax}            -> ensemble.bpa_pcv
        {pheno, frax}                 -> ensemble.bpa_pv

    Parameters
    ----------
    ds_dict : dict
        Mapping of layer name to layer. In this notebook the layers are data
        arrays that have already been resampled and sigmoid-rescaled.

    Returns
    -------
    Whatever the selected ``ensemble.bpa_*`` routine returns for ``ds_dict``.

    Raises
    ------
    ValueError
        If the key set of ``ds_dict`` matches none of the supported combinations,
        so an unsupported mix (e.g. pheno + sdm) fails loudly instead of
        producing None.
    """

    # exact key set -> (ensemble routine name, status message). the routine is
    # looked up on the ensemble module only once a combination has matched.
    combinations = {
        frozenset(['like', 'pheno', 'sdm', 'chm', 'frax']):
            ('bpa_all', 'Performing Dempster-Shafer on all variables.'),
        frozenset(['like', 'sdm', 'chm', 'frax']):
            ('bpa_lscv', 'Performing Dempster-Shafer on like, sdm, chm and frax variables.'),
        frozenset(['pheno', 'sdm', 'chm', 'frax']):
            ('bpa_pscv', 'Performing Dempster-Shafer on pheno, sdm, chm and frax variables.'),
        frozenset(['pheno', 'sdm', 'frax']):
            ('bpa_psv', 'Performing Dempster-Shafer on pheno, sdm and frax variables.'),
        frozenset(['pheno', 'chm', 'frax']):
            ('bpa_pcv', 'Performing Dempster-Shafer on pheno, chm and frax variables.'),
        frozenset(['pheno', 'frax']):
            ('bpa_pv', 'Performing Dempster-Shafer on pheno and frax variables.'),
    }

    # todo - add pheno/sdm/chm, pheno/sdm and pheno/chm combinations
    provided = frozenset(ds_dict)
    if provided not in combinations:
        supported = '; '.join(sorted(', '.join(sorted(combo)) for combo in combinations))
        raise ValueError('No Dempster-Shafer combination for variables {0}. '
                         'Supported combinations: {1}.'.format(sorted(provided), supported))

    func_name, message = combinations[provided]
    print(message)
    return getattr(ensemble, func_name)(ds_dict)


# TODO: remove (mask out) chm values < 2 before combining

In [25]:
# combine the evidence layers; the ensemble.bpa_* routine is chosen from the keys of ds_dict
ds_dempster = perform_dempster_shafer(ds_dict)

In [24]:
# display the dempster-shafer result
# NOTE(review): this cell was executed (In [24]) before the cell that computes
# ds_dempster (In [25]) — restart and run top to bottom to reproduce.
ds_dempster

In [None]:
# legacy per-layer loads of the same royhill ensemble rasters, each loaded
# lazily (dask) with nodata conformed to NaN.
# NOTE: this rebinds ds_like and ds_sdm, replacing the results of earlier sections.
ensemble_dir = r'../../data/ensemble/royhill'

def _load_ensemble_raster(file_name):
    """Lazily load one raster from ensemble_dir with nodata conformed to NaN."""
    return satfetcher.load_local_rasters(rast_path_list=ensemble_dir + '/' + file_name,
                                         use_dask=True,
                                         conform_nodata_to=np.nan)

ds_like = _load_ensemble_raster('like.tif')                  # gdv likelihood
ds_liot = _load_ensemble_raster('pheno_liot.tif')            # phenometric (liot)
ds_sdm = _load_ensemble_raster('sdm_lidar.tif')              # species distribution model
ds_frax = _load_ensemble_raster('vegfrax_class_3_vs_0.tif')  # vegetation fraction
ds_chm = _load_ensemble_raster('chm_lidar_10m.tif')          # canopy height model

In [None]:
# resample higher to lower
#sdm to like

