# Prediction <img align="right" src="../figs/DE_Africa_Logo_Stacked_RGB_small.jpg">

## Description

Using the model we created in `3_Train_fit_evaluate_classifier.ipynb`, this notebook makes predictions on new data to generate a cropland mask for Eastern Africa. The notebook will create both pixel-wise classifications and classification probabilities. Results are saved to disk as Cloud-Optimised GeoTIFFs.

1. Open and inspect the shapefile which delineates the extent we're classifying
2. Import the model
3. Make predictions on new data loaded through the ODC. The pixel classification also undergoes a post-processing step in which steep slopes and water are masked using an SRTM derivative and WOfS, respectively. Pixels labelled as crop above 3600 metres ASL are also masked.
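
The post-processing in step 3 amounts to setting a pixel to zero wherever any exclusion layer exceeds its threshold. Below is a minimal NumPy sketch of that idea, using small made-up arrays in place of the real WOfS, slope and elevation rasters. The function name `apply_exclusion_masks` is hypothetical; the default thresholds mirror the ones used later in this notebook (WOfS frequency 0.2, slope 35, elevation 3600 m).

```python
import numpy as np

def apply_exclusion_masks(pred, wofs_freq, slope, elevation,
                          wofs_max=0.2, slope_max=35, elev_max=3600):
    """Set predictions to 0 where water, steep slopes or high elevation occur."""
    # A pixel is excluded if ANY of the three layers exceeds its threshold
    exclude = (wofs_freq > wofs_max) | (slope > slope_max) | (elevation > elev_max)
    return np.where(exclude, 0, pred)

# Toy 2x2 rasters (illustrative values only, not real data)
pred = np.array([[1, 1], [1, 0]])
wofs_freq = np.array([[0.0, 0.5], [0.1, 0.0]])
slope = np.array([[5, 5], [40, 5]])
elevation = np.array([[100, 100], [100, 4000]])

masked = apply_exclusion_masks(pred, wofs_freq, slope, elevation)
print(masked)
```

Only the top-left pixel stays classified as crop in this toy case; the others are removed by the water, slope and elevation masks respectively.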

***
## Getting started

To run this analysis, run all the cells in the notebook, starting with the "Load packages" cell. 

### Load Packages

In [1]:
import warnings
warnings.filterwarnings("ignore")

import sys
import os
from osgeo import gdal
import shutil
import datacube
import numpy as np
import xarray as xr
import geopandas as gpd
import subprocess as sp
from joblib import load
from datacube.utils import geometry
from datacube.utils.cog import write_cog
from rsgislib.segmentation import segutils
from scipy.ndimage.measurements import _stats
from datacube.utils.geometry import assign_crs
from datacube.testutils.io import rio_slurp_xarray

sys.path.append('../../Scripts')
from deafrica_datahandling import load_ard
from deafrica_dask import create_local_dask_cluster
from deafrica_classificationtools import HiddenPrints, predict_xr
from deafrica_plotting import map_shapefile
from deafrica_spatialtools import xr_rasterize

#import our feature layer function for prediction
from feature_layer_functions import gm_mads_two_seasons_production
%load_ext autoreload
%autoreload 2

### Set up a dask cluster
This will help keep our memory use down and conduct the analysis in parallel. If you'd like to view the dask dashboard, click on the hyperlink that prints below the cell. You can use the dashboard to monitor the progress of calculations.

In [2]:
create_local_dask_cluster()

Client
  Scheduler: tcp://127.0.0.1:33451
  Dashboard: /user/chad/proxy/8787/status
Cluster
  Workers: 1
  Cores: 31
  Memory: 254.70 GB


## Analysis parameters

* `model_path`: The path to the location where the model exported from the previous notebook is stored
* `training_data`: Name and location of the training data `.txt` file output from running `1_Extract_training_data.ipynb`
* `test_shapefile`: A shapefile containing polygons that represent regions where you want to test your model. The shapefile should have a unique identifier as this will be used to export classification results to disk as geotiffs. Alternatively, this could be a shapefile that defines the extent of the entire AOI you want to classify.
* `results`: A folder location to store the classified geotiffs 
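
The prediction loop later builds output paths inside sub-folders of `results` (`ndvi/`, `tiles/`, `segmented/` and `proba/` all appear in its file names), and `os.mkdir` fails if a parent folder is missing. Below is a small sketch for creating those folders up front. The helper name `ensure_output_dirs` is hypothetical and not part of the DE Africa scripts.

```python
import os

def ensure_output_dirs(results, subdirs=("ndvi", "tiles", "segmented", "proba")):
    """Create the results folder and its sub-folders if they do not exist."""
    created = []
    for sub in subdirs:
        path = os.path.join(results, sub)
        # makedirs also creates missing parents; exist_ok avoids errors on re-runs
        os.makedirs(path, exist_ok=True)
        created.append(path)
    return created
```

Calling `ensure_output_dirs(results)` once, after the parameters are set, would be enough for the paths used in this notebook.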

In [3]:
model_path = 'results/gm_mads_two_seasons_ml_model_20210401.joblib'

training_data = "results/training_data/gm_mads_two_seasons_training_data_20210401.txt"

# test_shapefile = 'data/s2_tiles_eastern_aez.geojson'
test_shapefile = 'data/imagesegtiles.shp'

results = 'results/classifications/predicted/20210401/'#/g/data/u23/data/crop-mask/prediction/'

model_type='gm_mads_two_seasons'

### Open and inspect test_shapefile

In [4]:
gdf = gpd.read_file(test_shapefile)

In [5]:
# gdf.head()
# map_shapefile(gdf, attribute='title')

## Open the model

The code below loads the trained model we exported from `3_Train_fit_evaluate_classifier.ipynb`.

In [6]:
model = load(model_path)

## Making a prediction


### Set up datacube query

These query options should exactly match the query parameters in `1_Extract_training_data.ipynb`, unless there are measurements that no longer need to be loaded because they were dropped during a feature selection process (which we didn't conduct).

In [7]:
#set up our inputs to collect_training_data
products = ['s2_l2a']
time = ('2019-01', '2019-12')
measurements = [
    'red', 'blue', 'green', 'nir', 'swir_1', 'swir_2', 'red_edge_1',
    'red_edge_2', 'red_edge_3'
]
resolution = (-20, 20)
output_crs = 'epsg:6933'
dask_chunks = {'x':'auto', 'y': 'auto'}
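
One way to guard against the training and prediction queries drifting apart is to compare them key by key before loading any data. Below is a minimal sketch using plain dictionaries. The helper `query_mismatches` and the `training_query` values are illustrative assumptions, not taken from `1_Extract_training_data.ipynb`.

```python
def query_mismatches(training_query, prediction_query):
    """Return {key: (training_value, prediction_value)} for every key that differs."""
    keys = set(training_query) | set(prediction_query)
    return {
        k: (training_query.get(k), prediction_query.get(k))
        for k in sorted(keys)
        if training_query.get(k) != prediction_query.get(k)
    }

prediction_query = {
    "time": ("2019-01", "2019-12"),
    "resolution": (-20, 20),
    "output_crs": "epsg:6933",
}
# Hypothetical training query that differs only in resolution
training_query = dict(prediction_query, resolution=(-10, 10))

print(query_mismatches(training_query, prediction_query))
```

An empty dictionary means the two queries agree on every key.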


### Connect to the datacube

In [8]:
dc = datacube.Datacube(app='prediction')

### Loop through test tiles and predict

For every tile we list in the `test_shapefile`, we calculate the feature layers, and then use the DE Africa function `predict_xr` to classify the data.

The results are exported to file as Cloud-Optimised Geotiffs.
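
The loop derives integer tile indices from each polygon's `title` attribute by slicing the string at fixed positions. Below is a minimal sketch of that parsing, assuming titles follow the `+0035,-0004` pattern (signed four-digit x, a comma, signed four-digit y) that appears in this notebook's scratch output. `parse_tile_title` is a hypothetical helper that splits on the comma instead of relying on character positions.

```python
def parse_tile_title(title):
    """Split a tile title such as '+0035,-0004' into integer (x, y) indices."""
    x_str, y_str = title.split(",")
    # int() accepts a leading sign and leading zeros
    return int(x_str), int(y_str)

print(parse_tile_title("+0035,-0004"))  # (35, -4)
```

For titles of this fixed width the result matches `int(title[:5])` and `int(title[6:])`.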

In [9]:
%%time
i=1
for index, row in gdf[0:4].iterrows():
    #get id for labelling
    g_id=row['title']
    
    print('working on tile: '+g_id+". ","Tile: "+str(i)+"/"+str(len(gdf)),end='\r')

    #grab tile ids: x is the first five characters, y follows the comma
    x=int(g_id[:5])
    y=int(g_id[6:])

    #load the precomputed tifs (20m) and generate features
    try:
        data=gm_mads_two_seasons_production(x=x, y=y)
    except Exception:
        print('tile ' + g_id + ' failed')
        i+=1 #keep the tile counter in step when a tile is skipped
        continue
    
    #predict using the imported model
    with HiddenPrints():
        predicted = predict_xr(model,
                           data.chunk(dask_chunks),
                           proba=True,
                           persist=True,
                           clean=True,
                           return_input=True
                          ).compute()
    
    # Mask dataset to set pixels outside the polygon to `NaN`
    with HiddenPrints():
        mask = xr_rasterize(gdf.iloc[[index]], data)
        predicted = predicted.where(mask).astype('float32')
    
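    #write out ndvi for image seg
    #NOTE: the write below is commented out, so the NDVI GeoTIFF that the
    #segmentation step reads from results+'ndvi/' must already exist on disk
    ndvi = predicted[['NDVI_S1', 'NDVI_S2']]
    #write_cog(ndvi.to_array(), results+'predicted/20210401/ndvi/Eastern_tile_'+g_id+'_NDVI.tif',overwrite=True)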
    
    #grab predictions and proba for post process filtering
    predict=predicted.Predictions
    proba=predicted.Probabilities
    proba=proba.where(predict==1, 100-proba) #convert to probability of the crop class
    
    #-----------------image seg---------------------------------------------
    print('   image segmentation...')
    #store temp files somewhere
    directory=results+'tmp_'+g_id
    if not os.path.exists(directory):
        os.mkdir(directory)
    
    tmp='tmp_'+g_id+'/'

    #inputs to image seg
    tiff_to_segment = results+'ndvi/Eastern_tile_'+g_id+'_NDVI.tif'
    kea_file = results+'ndvi/Eastern_tile_'+g_id+'_NDVI.kea'
    segmented_kea_file = results+'ndvi/Eastern_tile_'+g_id+'_segmented.kea'

    #convert tiff to kea
    gdal.Translate(destName=kea_file,
                   srcDS=tiff_to_segment,
                   format='KEA',
                   outputSRS='EPSG:6933')
    
    #run image seg
    with HiddenPrints():
        segutils.runShepherdSegmentation(inputImg=kea_file,
                                             outputClumps=segmented_kea_file,
                                             tmpath=results+tmp,
                                             numClusters=60,
                                             minPxls=50)
    
    #open segments
    segments=xr.open_rasterio(segmented_kea_file).squeeze().values
    
    #calculate mode
    print('   calculating mode...')
    count, _sum =_stats(predict, labels=segments, index=segments)
    mode = _sum > (count/2)
    mode = xr.DataArray(mode, coords=predict.coords, dims=predict.dims, attrs=predict.attrs)
    
    #remove the tmp folder
    shutil.rmtree(results+tmp)
    os.remove(kea_file)
    os.remove(segmented_kea_file)
    #os.remove(tiff_to_segment)
    
    #--Post processing---------------------------------------------------------------
    print("     post processing")
    #mask with WOFS
    wofs=dc.load(product='ga_ls8c_wofs_2_summary',like=data.geobox)
    wofs=wofs.frequency > 0.2 # threshold
    predict=predict.where(~wofs, 0)
    proba=proba.where(~wofs, 0)
    mode=mode.where(~wofs, 0)

    #mask steep slopes
    url_slope="https://deafrica-data.s3.amazonaws.com/ancillary/dem-derivatives/cog_slope_africa.tif"
    slope=rio_slurp_xarray(url_slope, gbox=data.geobox)
    slope=slope > 35
    predict=predict.where(~slope, 0)
    proba=proba.where(~slope, 0)
    mode=mode.where(~slope, 0)

    #mask where the elevation is above 3600m
    elevation=dc.load(product='srtm', like=data.geobox)
    elevation=elevation.elevation > 3600 # threshold
    predict=predict.where(~elevation.squeeze(), 0)
    proba=proba.where(~elevation.squeeze(), 0)
    mode=mode.where(~elevation.squeeze(), 0)
    
    #set dtype
    predict=predict.astype(np.int8)
    proba=proba.astype(np.float32)
    mode=mode.astype(np.int8)
    
    #----export classifications to disk------------------------------------
#     write_cog(predict, results+'tiles/Eastern_tile_'+g_id+'_prediction_pixel_'+model_type+'_20210401.tif',
#               overwrite=True)
    
#     write_cog(mode, results+'segmented/Eastern_tile_'+g_id+'_prediction_filtered_'+model_type+'.tif', overwrite=True)
    
#     write_cog(proba, results+'proba/Eastern_tile_'+g_id+'_proba_'+model_type+'_20210401.tif',
#               overwrite=True)

#     #also save to g/data
# #     output_path='/g/data/crop_mask/eastern/classifications/gm_mads_two_seasons_20210203/'
# #     write_cog(predict, output_path+'predicted/Eastern_tile_'+g_id+'_prediction_pixel_'+model_type+'_20210203.tif',
# #               overwrite=True)
    
    i+=1


   image segmentation...0000.  Tile: 1/12
   calculating mode...
     post processing
   image segmentation...0001.  Tile: 2/12
   calculating mode...
     post processing
   image segmentation...0002.  Tile: 3/12
   calculating mode...
     post processing
   image segmentation...0000.  Tile: 4/12
   calculating mode...
     post processing
CPU times: user 10min 58s, sys: 1min 24s, total: 12min 22s
Wall time: 16min 9s
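
The "calculating mode" step assigns every pixel the majority class of the segment it belongs to: a segment becomes crop (1) when more than half of its pixels were predicted as crop. The loop does this with SciPy's private `_stats` helper. The sketch below reproduces the same idea with `np.bincount` on toy arrays. The function name `segment_majority` is hypothetical, and it assumes a binary 0/1 prediction without missing values and non-negative integer segment labels.

```python
import numpy as np

def segment_majority(pred, segments):
    """Return a 0/1 array where each pixel takes its segment's majority class."""
    labels = segments.ravel()
    count = np.bincount(labels)                       # pixels per segment label
    crop = np.bincount(labels, weights=pred.ravel())  # crop pixels per segment label
    majority = crop > (count / 2)                     # True where crop dominates
    return majority[segments].astype(np.int8)         # map back onto the pixel grid

# Toy 3x3 prediction and segmentation (illustrative values only)
pred = np.array([[1, 1, 0],
                 [1, 0, 0],
                 [0, 0, 1]])
segments = np.array([[1, 1, 2],
                     [1, 1, 2],
                     [2, 2, 2]])

print(segment_majority(pred, segments))
```

Segment 1 holds three crop pixels out of four, so all of it becomes crop; segment 2 holds one out of five, so its isolated crop pixel is removed.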


In [None]:
# working on tile: +0035,-0004.  Tile: 125/390

In [None]:
predict.plot(size=12)

## Next steps

To continue working through the notebooks in this `Eastern Africa Cropland Mask` workflow, go to the next notebook `5_Object-based_filtering.ipynb`.

1. [Extracting_training_data](1_Extracting_training_data.ipynb) 
2. [Inspect_training_data](2_Inspect_training_data.ipynb)
3. [Train_fit_evaluate_classifier](3_Train_fit_evaluate_classifier.ipynb)
4. **Predict (this notebook)**
5. [Object-based_filtering](5_Object-based_filtering.ipynb)


***

## Additional information

**License:** The code in this notebook is licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0). 
Digital Earth Africa data is licensed under the [Creative Commons by Attribution 4.0](https://creativecommons.org/licenses/by/4.0/) license.

**Contact:** If you need assistance, please post a question on the [Open Data Cube Slack channel](http://slack.opendatacube.org/) or on the [GIS Stack Exchange](https://gis.stackexchange.com/questions/ask?tags=open-data-cube) using the `open-data-cube` tag (you can view previously asked questions [here](https://gis.stackexchange.com/questions/tagged/open-data-cube)).
If you would like to report an issue with this notebook, you can file one on [Github](https://github.com/digitalearthafrica/deafrica-sandbox-notebooks).

**Last modified:** Dec 2020
