# Prediction <img align="right" src="../Supplementary_data/DE_Africa_Logo_Stacked_RGB_small.jpg">

## Background

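Once a machine learning model has been trained, it can be applied to satellite data to classify every pixel in an area of interest. This notebook applies a previously trained model to tiles of Sentinel-2 data to produce classified maps.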

## Description
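This notebook:

1. Opens a trained model that has been saved to disk
2. Opens a shapefile of tiles and selects the tiles to classify
3. Loads satellite data for each tile and calculates the same feature layers that were used to train the model
4. Classifies each tile using the model and saves the result as a GeoTIFF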

### Load Packages

In [None]:
import os
import subprocess as sp
import sys

import datacube
from odc.algo import xr_geomedian
import xarray as xr
import numpy as np
from joblib import load
import geopandas as gpd
from datacube.utils import geometry
from datacube.utils.cog import write_cog
from datacube.utils.geometry import assign_crs

sys.path.append('../Scripts')
from deafrica_datahandling import load_ard
from deafrica_classificationtools import predict_xr, predict_proba_xr
from deafrica_dask import create_local_dask_cluster
from deafrica_plotting import map_shapefile
from deafrica_bandindices import calculate_indices
from deafrica_temporal_statistics import temporal_statistics

### Set up a dask cluster
This will help keep our memory use down and conduct the analysis in parallel. If you'd like to view the dask dashboard, click on the hyperlink that prints below the cell. You can use the dashboard to monitor the progress of calculations.

In [None]:
# start the local dask cluster using the helper imported from ../Scripts/deafrica_dask.py
create_local_dask_cluster()

## Analysis parameters

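* `ncpus`: The number of CPUs available for parallel processing, e.g. `ncpus=8`. The cell below detects this value automatically from the environment.
* `model_path`: The path to the trained model, saved as a `.joblib` file, that will be used to make the predictions.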

In [None]:
# Automatically detect the number of CPUs from the environment.
# The value is taken from the last line of `env` that mentions CPU, which is
# expected to look like `NAME=VALUE`. Reading everything after the '=' works
# for values of any length, so no manual slice adjustment is needed on the
# default Sandbox.
cpu_env = sp.getoutput('env | grep CPU').strip()
try:
    ncpus = max(1, int(float(cpu_env.splitlines()[-1].split('=')[-1])))
except (IndexError, ValueError, OverflowError):
    # no CPU variable is set, or its value is not a usable number:
    # fall back to a single CPU rather than failing
    ncpus = 1

# path to the trained model that is opened later in the notebook
model_path = 'results/ml_model.joblib'

print('ncpus = '+str(ncpus))

### Connect to the datacube

In [None]:
# datacube handle used by the data-loading calls further down; `app` is a label for this session
dc = datacube.Datacube(app='prediction')

## Open the model



In [None]:
# Load the trained model from `model_path`.
# NOTE: joblib.load unpickles the file, which can execute arbitrary code,
# so only open model files that come from a trusted source.
model = load(model_path)

## Open 'tiles' shapefile

In [None]:
# boundary of the Southern AEZ region
aez = gpd.read_file('../crop_mask/data/AEZs/Southern.shp')

# read the tiles and keep only the parts that intersect the region
gdf = gpd.overlay(
    gpd.read_file('../crop_mask/data/tiles.shp'),
    aez,
    how='intersection',
)

# number the remaining tiles 0..n-1 in an 'id' column
gdf['id'] = range(len(gdf))


In [None]:
# row positions (used with gdf.iloc) of the tiles to classify
list_of_tiles = [3,6,9,11,17,26]

# show the selected tiles using the map_shapefile helper
map_shapefile(gdf.iloc[list_of_tiles], 'id', hover_col='id')

## Make a prediction

Extract data from the datacube that exactly matches the feature layers we created during the extraction of training data in the notebook `1_Extract_training_data.ipynb`.

In [None]:
def temporalStats_and_elevation(ds):
    """
    Build the feature layers used for prediction from a time series dataset.

    The result combines three sources into one dataset:
    the median of every variable in `ds` over the 'time' dimension,
    four temporal statistics of the NDVI time series, and the 'srtm'
    product loaded onto the same geobox as `ds`.

    Parameters
    ----------
    ds : xarray.Dataset
        Dataset with a 'time' dimension and a `geobox`. The NDVI is
        calculated with collection='s2', so the bands are presumably
        Sentinel-2 surface reflectance (TODO confirm against caller).

    Returns
    -------
    xarray.Dataset
        The merged and squeezed feature layers, with the CRS of `ds`
        assigned.

    Notes
    -----
    Relies on the notebook-level `dc` datacube connection.
    """
    # median of each variable through time, computed eagerly
    band_medians = ds.median('time').compute()

    # NDVI time series derived from the input dataset
    ndvi_series = calculate_indices(ds, index=['NDVI'], drop=True, collection='s2')

    # temporal statistics of the NDVI, computed eagerly
    print('temporal')
    stat_names = ['f_mean', 'abs_change', 'complexity', 'central_diff']
    ndvi_stats = temporal_statistics(ndvi_series.NDVI, stats=stat_names).compute()

    # elevation loaded on the same grid as the input dataset
    target_geobox = ds.geobox
    elevation = dc.load(product='srtm', like=target_geobox).squeeze()

    # combine everything into a single xarray.Dataset
    merged = xr.merge([ndvi_stats, band_medians, elevation], compat='override')

    # re-attach the crs of the input dataset to the merged result
    merged = assign_crs(merged, crs=target_geobox.crs)

    return merged.squeeze()

In [None]:
# products and time range to load for the prediction
# (these are passed to load_ard in the prediction loop below)
products =  ['s2_l2a']
time = ('2019-01','2019-12')

# Set up the inputs for the ODC query
measurements =  ['red', 'nir', 'blue', 'swir_1', 'swir_2']
resolution = (-30,30)
output_crs='epsg:6933'
# a chunk size of -1 on 'time' keeps the whole time series of a pixel block in one chunk
dask_chunks={'x':1000,'y':1000,'time':-1}

In [None]:
# Classify every selected tile and save each result as a GeoTIFF.
# The classified arrays are also kept in memory in `tiles_classified`.
tiles_classified = []

# create the output folder up front so that a missing folder cannot cause
# a failure after the expensive processing of a tile has finished
output_dir = 'results/classifications'
os.makedirs(output_dir, exist_ok=True)

# datacube query settings shared by every tile; only the geometry changes
base_query = {
    'time': time,
    'measurements': measurements,
    'resolution': resolution,
    'output_crs': output_crs,
    'group_by' : 'solar_day',
}

# the CRS of the tiles is the same for every row, so resolve it once
tiles_epsg = gdf.crs.to_epsg() if gdf.crs is not None else None
if tiles_epsg is None:
    raise ValueError('Could not determine an EPSG code for the CRS of the tiles shapefile')
tiles_crs = geometry.CRS(f'EPSG:{tiles_epsg}')

for index, row in gdf.iloc[list_of_tiles].iterrows():

    tile_id = str(row['id'])
    print("Working on tile: "+tile_id)

    # Get the geometry of this tile
    geom = geometry.Geometry(row.geometry.__geo_interface__, tiles_crs)

    # load the data for this tile
    ds = load_ard(dc=dc,
                  products=products,
                  dask_chunks=dask_chunks,
                  geopolygon=geom,
                  **base_query)

    # calculate the feature layers
    data = temporalStats_and_elevation(ds)

    #predict using the imported model
    print('predicting...')
    predicted = predict_xr(model, data.squeeze(), progress=True)
    tiles_classified.append(predicted)

    # overwrite=True lets the notebook be re-run without failing on
    # the files written by a previous run
    write_cog(predicted,
              os.path.join(output_dir, 'Southern_'+tile_id+'_prediction.tif'),
              overwrite=True)

    # Optional: uncomment to also save the output of predict_proba_xr
    # predicted_proba = predict_proba_xr(model, data.squeeze(), progress=True)
    # write_cog(predicted_proba,
    #           os.path.join(output_dir, 'Southern_'+tile_id+'_prediction_proba.tif'),
    #           overwrite=True)
    