# Prediction <img align="right" src="../Supplementary_data/DE_Africa_Logo_Stacked_RGB_small.jpg">

## Background

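A machine-learning classifier that was trained and saved beforehand (`results/ml_model.joblib`) is applied here to new Sentinel-2 observations to produce per-pixel predictions. **TODO:** expand this section with background on the classification problem and the training data.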

## Description
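This notebook does three things:
1. It loads the trained model.
2. It builds the feature layers the model expects from Sentinel-2 data: two-epoch geomedians with NDVI and LAI, median absolute deviations, and terrain slope.
3. It uses the model to make per-pixel predictions over the area of interest.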

### Load Packages

In [None]:
import datacube
from odc.algo import xr_geomedian
import xarray as xr
import subprocess as sp
import numpy as np
from joblib import load

import sys
sys.path.append('../Scripts')
from deafrica_datahandling import load_ard
from deafrica_classificationtools import predict_xr
from deafrica_dask import create_local_dask_cluster

### Set up a dask cluster
This will help keep our memory use down and conduct the analysis in parallel. If you'd like to view the dask dashboard, click on the hyperlink that prints below the cell. You can use the dashboard to monitor the progress of calculations.

In [None]:
# Start a local Dask cluster for this session; the dashboard hyperlink printed
# below this cell can be used to monitor the progress of computations.
create_local_dask_cluster()

## Analysis parameters

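* `ncpus`: The number of CPUs available for parallel processing. It is detected automatically from the sandbox environment. Values > 1 allow work to be parallelised, e.g. `ncpus=8`.
* `model_path`: Path to the trained model saved with `joblib`, e.g. `'results/ml_model.joblib'`.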

In [None]:
# Automatically detect the number of CPUs from the sandbox's CPU environment
# variable(s). `env | grep CPU` prints 'NAME=VALUE' lines; taking the value
# after '=' on the last matching line (instead of a fixed-width character
# slice) works whether the value has one or two digits, so nothing needs to be
# hand-edited between sandbox sizes. Falls back to 1 CPU when no CPU variable
# is set.
cpu_env = sp.getoutput('env | grep CPU').strip()
ncpus = int(float(cpu_env.splitlines()[-1].partition('=')[2])) if cpu_env else 1

# Trained classifier saved with joblib
model_path = 'results/ml_model.joblib'


### Connect to the datacube

In [None]:
# Connect to the datacube; `app` is a label identifying this notebook's connection.
dc = datacube.Datacube(app='prediction')

## Open the model



In [None]:
# Load the trained classifier saved with joblib at `model_path`.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code;
# only load model files from a trusted source.
model = load(model_path)

## Extract feature layers from datacube

In [None]:
# Load the epoch-1 Sentinel-2 (s2_l2a) observations in 2000 x 2000 pixel dask
# chunks, keeping the product's native dtype.
# FIXME: `query` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. Define it in the "Analysis parameters" cell
# (extent, time range, measurements, resolution, output_crs), presumably
# matching the epoch-1 query used when the model was trained.
ds = load_ard(dc=dc, 
              products=['s2_l2a'],
              dask_chunks={'x':2000, 'y':2000},
              dtype='native',
              **query)


In [None]:
def _tmad_features(data, suffix=''):
    """Return the TernaryMAD layers (sdev, edev, bcdev) of `data` on its x/y grid.

    Parameters
    ----------
    data : xarray.Dataset
        Time series of observations passed straight to `TernaryMAD.compute`.
    suffix : str, optional
        When non-empty, appended to 'sdev', 'edev' and 'bcdev' so the layers
        of different epochs do not collide when merged.
    """
    mad = TernaryMAD(num_threads=1).compute(data=data)
    # Re-attach the input's x/y coordinates so the MAD layers line up with the
    # geomedian layers when everything is merged.
    mad.coords['x'] = data.x
    mad.coords['y'] = data.y
    if suffix:
        mad = mad.rename({name: name + suffix for name in ('sdev', 'edev', 'bcdev')})
    return mad


def two_epochs_MADS(ds, dc=None, epoch2_time=('2019-07', '2019-12')):
    """Build the two-epoch feature stack passed to the classifier.

    Parameters
    ----------
    ds : xarray.Dataset
        Epoch-1 Sentinel-2 observations. Its geobox sets the extent of the
        epoch-2 load and the grid of the slope layer.
    dc : datacube.Datacube, optional
        Connection used for the epoch-2 and SRTM loads. A new connection is
        opened when omitted.
    epoch2_time : tuple of str, optional
        Time range of the second epoch (default: second half of 2019).

    Returns
    -------
    xarray.Dataset
        The merge of five layer groups, squeezed of length-1 dimensions:
        - the epoch-1 geomedian with NDVI/LAI;
        - the epoch-1 MADs;
        - the epoch-2 geomedian with NDVI/LAI (suffixed '_2');
        - the epoch-2 MADs (suffixed '_2');
        - terrain slope.
    """
    # FIXME: int_geomedian, calculate_indices, TernaryMAD, GeoMedian and
    # xr_terrain are not imported in the "Load Packages" cell; import them
    # there before calling this function.
    if dc is None:
        dc = datacube.Datacube(app='prediction')

    bands = ['blue', 'green', 'red', 'nir', 'swir_1', 'swir_2']

    # Epoch 1: geomedian + indices and MADs of the observations passed in.
    epoch1_gm = calculate_indices(int_geomedian(ds),
                                  index=['NDVI', 'LAI'],
                                  drop=False,
                                  collection='s2')
    epoch1_mad = _tmad_features(ds)

    # Epoch 2: load the same extent for the second time period.
    q = {
        'geopolygon': ds.geobox.extent,
        'time': epoch2_time,
        'measurements': list(bands),
        'resolution': (-20, 20),
        'group_by': 'solar_day',
        'output_crs': 'epsg:6933',
    }

    print('epoch 2')
    ds2 = load_ard(dc=dc, products=['s2_l2a'], **q)
    epoch2_gm = calculate_indices(GeoMedian().compute(ds2),
                                  index=['NDVI', 'LAI'],
                                  drop=False,
                                  collection='s2')
    # Suffix every epoch-2 layer with '_2' so it can sit alongside epoch 1.
    epoch2_gm = epoch2_gm.rename({name: f'{name}_2' for name in bands + ['NDVI', 'LAI']})
    epoch2_mad = _tmad_features(ds2, suffix='_2')

    print('slope...')
    elevation = dc.load(product='srtm', like=ds.geobox).squeeze().elevation
    slope = xr_terrain(elevation, 'slope_riserun').to_dataset(name='slope')

    print('Merging...')
    result = xr.merge([epoch1_gm, epoch1_mad, epoch2_gm, epoch2_mad, slope],
                      compat='override')
    return result.squeeze()


## Make a prediction

In [None]:
# Build the feature layers from the epoch-1 observations `ds`:
# two-epoch geomedians with indices, MADs and slope.
# `data` was previously never defined, so this cell failed on a fresh kernel.
# NOTE(review): predict_xr needs these variables to match the features the
# model was trained on (names and order); confirm against the training notebook.
data = two_epochs_MADS(ds)
predicted = predict_xr(model, data, progress=True)