# Point-based and parallel processing of the Water Observations from Space (WOfS) product in Africa using Sentinel-2  <img align="right" src="../Supplementary_data/DE_Africa_Logo_Stacked_RGB_small.jpg">

* **Products used:**
[s2_l2a](https://explorer.digitalearth.africa/s2_l2a)

## Description 
[Water Observations from Space (WOfS)](https://www.ga.gov.au/scientific-topics/community-safety/flood/wofs/about-wofs) is a product derived from Landsat 8 satellite observations, as part of the provisional Landsat 8 Collection 2 surface reflectance, and shows surface water detected in Africa. In this notebook, the WOfS classifier is applied to Sentinel-2 (`s2_l2a`) data instead.
Individual water classified images are called Water Observation Feature Layers (WOFLs), and are created in a 1-to-1 relationship with the input satellite data. 
Hence there is one WOFL for each satellite dataset processed for the occurrence of water.

The data in a WOFL is stored as a bit field. This is a binary number, where each digit of the number is independently set or not based on the presence (1) or absence (0) of a particular attribute (water, cloud, cloud shadow etc). In this way, the single decimal value associated with each pixel can provide information on a variety of features of that pixel. 
For more information on the structure of WOFLs and how to interact with them, see [Water Observations from Space](../Datasets/Water_Observations_from_Space.ipynb) and [Applying WOfS bitmasking](../Frequently_used_code/Applying_WOfS_bitmasking.ipynb) notebooks.
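
As an illustration of how a bit field packs several yes/no attributes into one integer, the sketch below decodes a single pixel value into flag names. The flag-to-bit mapping used here is an assumption made for illustration only; the authoritative flag definitions are in the notebooks linked above.

```python
# Hypothetical flag layout for illustration; check the product's own
# flag definitions before relying on these bit positions.
FLAG_BITS = {
    "nodata": 0,
    "cloud_shadow": 5,
    "cloud": 6,
    "water": 7,
}


def decode_flags(value):
    """Return the names of all flags whose bit is set in `value`."""
    return [name for name, bit in FLAG_BITS.items() if value & (1 << bit)]


print(decode_flags(128))  # only the water bit is set
print(decode_flags(192))  # water and cloud bits are both set (128 + 64)
print(decode_flags(0))    # no bits set
```

Because each attribute has its own bit, a value such as 192 can only mean "water and cloud", which is what makes a single integer per pixel sufficient.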

This notebook explains how to derive a WOfS classification for each validation point collected in Africa, using a point-based sampling approach. 

The notebook demonstrates how to:

1. Load the validation points for each partner institution, following the cleaning stage described in notebook 1
2. Query Sentinel-2 data for the validation points and capture the WOfS-defined class using point-based sampling and multiprocessing
3. Extract a lookup table (LUT) for each point that contains the validation point information, the WOfS class and the number of clear observations in each month 
***

## Getting started

To run this analysis, run all the cells in the notebook, starting with the "Load packages" cell.

### Load packages
Import Python packages that are used for the analysis.

In [1]:
%matplotlib inline

import datacube
from datacube.utils import masking, geometry 
import sys
import os
import rasterio
import xarray
import glob
import numpy as np
import pandas as pd
import seaborn as sn
import geopandas as gpd
import matplotlib.pyplot as plt
import multiprocessing as mp
import scipy, scipy.ndimage
import warnings
warnings.filterwarnings("ignore") #this will suppress the warnings for multiple UTM zones in your AOI 

sys.path.append("../Scripts")
from geopandas import GeoSeries, GeoDataFrame
from shapely.geometry import Point
from sklearn.metrics import confusion_matrix, accuracy_score 
from sklearn.metrics import plot_confusion_matrix, f1_score  
from deafrica_plotting import map_shapefile,display_map, rgb
from deafrica_spatialtools import xr_rasterize
from deafrica_datahandling import wofs_fuser, mostcommon_crs,load_ard,deepcopy
from deafrica_dask import create_local_dask_cluster
from tqdm import tqdm

### Analysis parameters

To analyse validation points collected by each partner institution, we need to obtain WOfS surface water observation data that corresponds with the labelled input data locations. 
- `path2csv`: the path to the CEO validation points labelled by each partner institution in Africa, in CSV format 
- `ValPoints`: the CEO validation points labelled by each partner institution in Africa, as a GeoDataFrame that is written to an ESRI shapefile 
- `path`: the direct path to the ESRI shapefile, if the shapefile is already available
- `input_data`: a geopandas GeoDataFrame of the CEO validation points labelled by each partner institution in Africa

*** Note: Run the following three cells only if you do not have an ESRI shapefile for the validation points. 

In [2]:
path2csv = '../Data/Processed/AGRYHMET/AGRYHMET_ValidationPoints.csv'
df = pd.read_csv(path2csv,delimiter=",")

In [3]:
geometries = [Point(xy) for xy in zip(df.LON, df.LAT)]
crs = 'EPSG:4326'  # the {'init': 'epsg:4326'} dict form is deprecated
ValPoints = GeoDataFrame(df, crs=crs, geometry=geometries)

In [4]:
ValPoints.to_file(filename='../Data/Processed/AGRYHMET/AGRYHMET_ValidationPoints.shp') 

*** Note: If you already have an ESRI shapefile for the validation points, continue from this point onward. 

In [5]:
path = '../Data/Processed/AGRYHMET/AGRYHMET_ValidationPoints.shp'

In [6]:
#reading the table and converting CRS to metric
input_data = gpd.read_file(path).to_crs('epsg:6933')  
input_data.columns

Index(['Unnamed_ 0', 'PLOT_ID', 'LON', 'LAT', 'FLAGGED', 'ANALYSES',
       'SENTINEL2Y', 'STARTDATE', 'ENDDATE', 'WATER', 'NO_WATER', 'BAD_IMAGE',
       'NOT_SURE', 'CLASS', 'COMMENT', 'MONTH', 'WATERFLAG', 'geometry'],
      dtype='object')

In [7]:
input_data= input_data.drop(['Unnamed_ 0'], axis=1)

In [8]:
#Checking the size of the input data 
input_data.shape

(8724, 17)

### Sample WOfS at the ground truth coordinates
To load the data, we first create a reusable query that defines two parameters. The first is `group_by`, set to `solar_day`, which ensures that data from adjacent scenes is combined correctly. The second is the `resampling` method, set to `nearest`. The query is extended later in the script with the remaining parameters needed to load the data correctly, such as the location of each point and the time period we are interested in. 

A WOFL bit field can be converted into a binary array containing True and False values. This allows the WOFL data to be used as a mask that can be applied to other datasets. The `make_mask` function creates such a mask from the flag labels (e.g. "wet" or "dry") rather than from the underlying binary numbers. For more details on how to do masking on WOfS, see the [Applying_WOfS_bit_masking](../Frequently_used_code/Applying_WOfS_bitmasking.ipynb) notebook in the Africa sandbox.
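
A minimal sketch of the masking idea with plain NumPy, without `make_mask`, is given here. It assumes, for illustration, that a WOFL value of 128 means clear and wet and a value of 0 means clear and dry; the array values are made up.

```python
import numpy as np

# Made-up 2x3 WOFL tile; 128 = wet and 0 = dry are assumed flag values,
# any other value stands for a flagged (e.g. cloudy) pixel.
wofl = np.array([[128, 0, 64],
                 [0, 192, 128]], dtype=np.uint8)

wet = wofl == 128          # True where the pixel is clear and wet
dry = wofl == 0            # True where the pixel is clear and dry
clear = wet | dry          # True where the pixel has a usable observation

# Apply the mask to another (made-up) band: keep clear pixels, blank the rest
band = np.array([[10.0, 20.0, 30.0],
                 [40.0, 50.0, 60.0]])
masked_band = np.where(clear, band, np.nan)

print(wet)
print(masked_band)
```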

In [9]:
#generate query object 
query ={'group_by':'solar_day',
        'resampling':'nearest'}
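
The base query is meant to be copied and extended for each point rather than modified in place. A minimal sketch of that pattern follows; `build_point_query` and its arguments are hypothetical names used for illustration and are not part of datacube.

```python
from copy import deepcopy

base_query = {'group_by': 'solar_day', 'resampling': 'nearest'}


def build_point_query(base, x, y, time):
    """Return a new query dict for one point; `base` is left unchanged."""
    q = deepcopy(base)
    q.update({'x': (x, x), 'y': (y, y), 'time': time})
    return q


point_query = build_point_query(base_query, 1500.0, -2500.0,
                                ('2018-08-01', '2018-08-06'))
print(point_query)
print(base_query)  # still only the two shared settings
```

Copying first means that a parameter set for one point can never leak into the query for the next one, which matters once the points are processed in parallel.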

## WOfS Classifier

This section defines the decision tree that determines whether a particular pixel is wet or not wet.
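
The classifier walks the tree with boolean masks rather than per-pixel `if` statements: each split produces a mask, and masks are combined with `&` and `~` to select the pixels that reach a leaf. A toy sketch of that technique follows. The pixel values are made up, only the first two thresholds are borrowed from the full tree, and the `rest` pixels are left unresolved here, whereas the full tree keeps splitting them.

```python
import numpy as np

# Made-up reflectances for four pixels
blue = np.array([2500.0, 300.0, 400.0, 150.0])
green = np.array([2400.0, 500.0, 300.0, 200.0])
swir1 = np.array([1000.0, 100.0, 900.0, 600.0])

# Normalised difference of swir1 and green, as in the classifier's ndi_52
ndi_52 = (swir1 - green) / (swir1 + green)

UNSET = 255  # placeholder for pixels this toy tree does not resolve
classified = np.full(blue.shape, UNSET, dtype='uint8')

r1 = ndi_52 <= -0.01        # first split: left branch where True
r2 = blue <= 2083.5         # second split, applied within the left branch

classified[r1 & ~r2] = 0    # left branch, bright blue: leaf "not water"
rest = (r1 & r2) | ~r1      # pixels the full tree would keep splitting

print(classified)
print(rest)
```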

In [10]:
# wofs classifier
# (unused command-line imports, including the deprecated top-level `import gdal`, removed)
import gc
import numpy as np
import xarray as xr


def wofs_classify(dataset_in, clean_mask=None, x_coord='longitude', y_coord='latitude',
                  time_coord='time', no_data=-9999, mosaic=False, enforce_float64=False):
    """
    Description:
      Performs WOfS algorithm on given dataset.
    Assumption:
      - The WOfS algorithm is defined for Landsat 5/Landsat 7
    References:
      - Mueller, et al. (2015) "Water observations from space: Mapping surface water from
        25 years of Landsat imagery across Australia." Remote Sensing of Environment.
      - https://github.com/GeoscienceAustralia/eo-tools/blob/stable/eotools/water_classifier.py
    -----
    Inputs:
      dataset_in (xarray.Dataset) - dataset retrieved from the Data Cube; should contain
        coordinates: time, latitude, longitude (or the names given by time_coord, y_coord, x_coord)
        variables: B02, B03, B04, B08, B11, B12 (Sentinel-2 blue, green, red, nir, swir1, swir2)
      x_coord, y_coord, time_coord: (str) - Names of DataArrays in `dataset_in` to use as x, y,
        and time coordinates.
    Optional Inputs:
      clean_mask (nd numpy array with dtype boolean) - true for values user considers clean;
        if user does not provide a clean mask, all values will be considered clean
      no_data (int/float) - no data pixel value; default: -9999
      mosaic (boolean) - flag to indicate if dataset_in is a mosaic. If mosaic = False, dataset_in
        should have a time coordinate and wofs will run over each time slice; otherwise, dataset_in
        should not have a time coordinate and wofs will run over the single mosaicked image
      enforce_float64 (boolean) - flag to indicate whether or not to enforce float64 calculations;
        will use float32 if false
    Output:
      dataset_out (xarray.Dataset) - dataset with one variable, 'wofs', holding the water
        classification results: 0 - not water; 1 - water; no_data where clean_mask is False
    Throws:
        ValueError - if dataset_in is an empty xarray.Dataset.
    """

    def _band_ratio(a, b):
        """
        Calculates a normalized ratio index
        """
        return (a - b) / (a + b)

    def _run_regression(band1, band2, band3, band4, band5, band7):
        """
        Regression tree based on Australia's training data.
        Returns a uint8 array with the shape of the input bands: 1 - water; 0 - not water
        """

        # Compute normalized ratio indices
        ndi_52 = _band_ratio(band5, band2)
        ndi_43 = _band_ratio(band4, band3)
        ndi_72 = _band_ratio(band7, band2)

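        # Every pixel ends in one of the tree's leaves below, so this fill value is
        # only a placeholder. 0 is used because `no_data` (e.g. NaN or -9999) cannot
        # be stored in a uint8 array.
        classified = np.full(shape, 0, dtype='uint8')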

        # Start with the tree's left branch, finishing nodes as needed

        # Left branch
        r1 = ndi_52 <= -0.01

        r2 = band1 <= 2083.5
        classified[r1 & ~r2] = 0  #Node 3

        r3 = band7 <= 323.5
        _tmp = r1 & r2
        _tmp2 = _tmp & r3
        _tmp &= ~r3

        r4 = ndi_43 <= 0.61
        classified[_tmp2 & r4] = 1  #Node 6
        classified[_tmp2 & ~r4] = 0  #Node 7

        r5 = band1 <= 1400.5
        _tmp2 = _tmp & ~r5

        r6 = ndi_43 <= -0.01
        classified[_tmp2 & r6] = 1  #Node 10
        classified[_tmp2 & ~r6] = 0  #Node 11

        _tmp &= r5

        r7 = ndi_72 <= -0.23
        _tmp2 = _tmp & ~r7

        r8 = band1 <= 379
        classified[_tmp2 & r8] = 1  #Node 14
        classified[_tmp2 & ~r8] = 0  #Node 15

        _tmp &= r7

        r9 = ndi_43 <= 0.22
        classified[_tmp & r9] = 1  #Node 17
        _tmp &= ~r9

        r10 = band1 <= 473
        classified[_tmp & r10] = 1  #Node 19
        classified[_tmp & ~r10] = 0  #Node 20

        # Left branch complete; cleanup
        del r2, r3, r4, r5, r6, r7, r8, r9, r10
        gc.collect()

        # Right branch of regression tree
        r1 = ~r1

        r11 = ndi_52 <= 0.23
        _tmp = r1 & r11

        r12 = band1 <= 334.5
        _tmp2 = _tmp & ~r12
        classified[_tmp2] = 0  #Node 23

        _tmp &= r12

        r13 = ndi_43 <= 0.54
        _tmp2 = _tmp & ~r13
        classified[_tmp2] = 0  #Node 25

        _tmp &= r13

        r14 = ndi_52 <= 0.12
        _tmp2 = _tmp & r14
        classified[_tmp2] = 1  #Node 27

        _tmp &= ~r14

        r15 = band3 <= 364.5
        _tmp2 = _tmp & r15

        r16 = band1 <= 129.5
        classified[_tmp2 & r16] = 1  #Node 31
        classified[_tmp2 & ~r16] = 0  #Node 32

        _tmp &= ~r15

        r17 = band1 <= 300.5
        _tmp2 = _tmp & ~r17
        _tmp &= r17
        classified[_tmp] = 1  #Node 33
        classified[_tmp2] = 0  #Node 34

        _tmp = r1 & ~r11

        r18 = ndi_52 <= 0.34
        classified[_tmp & ~r18] = 0  #Node 36
        _tmp &= r18

        r19 = band1 <= 249.5
        classified[_tmp & ~r19] = 0  #Node 38
        _tmp &= r19

        r20 = ndi_43 <= 0.45
        classified[_tmp & ~r20] = 0  #Node 40
        _tmp &= r20

        r21 = band3 <= 364.5
        classified[_tmp & ~r21] = 0  #Node 42
        _tmp &= r21

        r22 = band1 <= 129.5
        classified[_tmp & r22] = 1  #Node 44
        classified[_tmp & ~r22] = 0  #Node 45

        # Completed regression tree

        return classified

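    # Default to masking nothing: treat every pixel as clean.
    # (create_default_clean_mask is not defined in this notebook, so build the mask directly)
    if clean_mask is None:
        clean_mask = np.ones(dataset_in.B02.shape, dtype=bool)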
    
    # Extract dataset bands needed for calculations
    #blue = dataset_in.blue
    #green = dataset_in.green
    #red = dataset_in.red
    #nir = dataset_in.nir
    #swir1 = dataset_in.swir_1
    #swir2 = dataset_in.swir_2
    blue = dataset_in.B02
    green = dataset_in.B03
    red = dataset_in.B04
    nir = dataset_in.B08
    swir1 = dataset_in.B11
    swir2 = dataset_in.B12

    # Enforce float calculations - float64 if user specified, otherwise float32 will do
    dtype = blue.values.dtype  # This assumes all dataset bands will have
    # the same dtype (should be a reasonable
    # assumption)

    if enforce_float64:
        if dtype != 'float64':
            blue.values = blue.values.astype('float64')
            green.values = green.values.astype('float64')
            red.values = red.values.astype('float64')
            nir.values = nir.values.astype('float64')
            swir1.values = swir1.values.astype('float64')
            swir2.values = swir2.values.astype('float64')
    else:
        if dtype == 'float64':
            pass
        elif dtype != 'float32':
            blue.values = blue.values.astype('float32')
            green.values = green.values.astype('float32')
            red.values = red.values.astype('float32')
            nir.values = nir.values.astype('float32')
            swir1.values = swir1.values.astype('float32')
            swir2.values = swir2.values.astype('float32')

    shape = blue.values.shape
    #print('decision time!')
    classified = _run_regression(blue.values, green.values, red.values, nir.values, swir1.values, swir2.values)

    classified_clean = np.full(classified.shape, no_data, dtype='float64')
    classified_clean[clean_mask] = classified[clean_mask]  # Contains data for clear pixels

    # Create xarray of data
    x_coords = dataset_in[x_coord]
    y_coords = dataset_in[y_coord]

    time = None
    coords = None
    dims = None

    if mosaic:
        coords = [y_coords, x_coords]
        dims = [y_coord, x_coord]
    else:
        time_coords = dataset_in[time_coord]
        coords = [time_coords, y_coords, x_coords]
        dims = [time_coord, y_coord, x_coord]

    data_array = xr.DataArray(classified_clean, coords=coords, dims=dims)

    if mosaic:
        dataset_out = xr.Dataset({'wofs': data_array},
                                 coords={y_coord: y_coords, x_coord: x_coords})
    else:
        dataset_out = xr.Dataset(
            {'wofs': data_array},
            coords={time_coord: time_coords, y_coord: y_coords, x_coord: x_coords})

    return dataset_out

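Define a function that loads the Sentinel-2 data at each point, from the first day of the relevant calendar month to five days later, and runs the WOfS classifier on it. The commented-out lines keep the alternative window of five days either side of the start of the month. 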
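
For reference, the time window built inside the function (the first of the month to five days later, matching the tuples printed by the test loop further down) can be sketched with the standard library alone. `month_window` is a hypothetical helper name, and the year 2018 is taken from the function.

```python
from datetime import date, timedelta


def month_window(month, year=2018, days=5):
    """Return (start, end) ISO date strings: the 1st of the month and `days` later."""
    start = date(year, month, 1)
    end = start + timedelta(days=days)
    return start.isoformat(), end.isoformat()


print(month_window(8))   # ('2018-08-01', '2018-08-06')
print(month_window(12))  # ('2018-12-01', '2018-12-06')
```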

In [11]:
def get_wofs_for_point(index, row, input_data, query, results_wet, results_clear):
    dc = datacube.Datacube(app='WOfS_accuracy')
    #get the month value for each index
    month = input_data.loc[index]['MONTH'] 
    #build the time window: from the start of the month to five days later
    timeYM = '2018-'+f'{month:02d}'
    #alternative window of five days either side of the start of the month:
    #start_date = np.datetime64(timeYM) - np.timedelta64(5,'D')
    #end_date = np.datetime64(timeYM) + np.timedelta64(5,'D')
    start_date = np.datetime64(timeYM)
    end_date = np.datetime64(timeYM) + np.timedelta64(5,'D')
    time = (str(start_date),str(end_date))
    
    plot_id = input_data.loc[index]['PLOT_ID']
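    #copy the shared query so the original stays unchanged
    #(note: dc_query is built here, but dc.load below is called with explicit arguments)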
    dc_query = deepcopy(query) 
    geom = geometry.Geometry(input_data.geometry.values[index].__geo_interface__,  geometry.CRS('EPSG:6933'))
    q = {"geopolygon":geom}
    t = {"time":time} 
    #updating the query
    dc_query.update(t)
    dc_query.update(q)
    #band aliases used only by the commented-out load_ard alternative below
    bands = ['blue', 'green', 'red', 'nir', 'swir_1', 'swir_2', 'mask']
    #load the Sentinel-2 data at the point (dc.load is used here rather than load_ard)
    s2_data = dc.load(product='s2_l2a',
                       #measurements=bands,
                       y = (input_data.geometry.y[index], input_data.geometry.y[index]),
                       x =(input_data.geometry.x[index], input_data.geometry.x[index]),
                       crs = 'EPSG:6933',
                       time=time,
                       output_crs = 'EPSG:6933',
                       resolution=(-10,10))
    #s2_data = load_ard(dc=dc,
    #                   products=['s2_l2a'],
    #                   measurements=bands,
    #                   y = (input_data.geometry.y[index], input_data.geometry.y[index]),
    #                   x =(input_data.geometry.x[index], input_data.geometry.x[index]),
    #                   crs = 'EPSG:6933',
    #                   time=time,
    #                   output_crs = 'EPSG:6933',
    #                   resolution=(-10,10))
    #skip points where no Sentinel-2 data was returned for the time window
    if 'B02' not in s2_data:
        pass
    else:
        #run the classifier on clear pixels only: SCL classes 4 (vegetation),
        #5 (not vegetated), 6 (water), 7 (unclassified) and 11 (snow/ice)
        clean_mask = np.isin(s2_data['SCL'], [4,5,6,7,11])
        #TODO: also mask out NaNs, e.g. starting from
        #clean_mask = s2_data.to_array().isnull().any(dim='band')
        s2_wofs = wofs_classify(s2_data, 
                              clean_mask=clean_mask, 
                              no_data=np.nan, 
                              x_coord='x', 
                              y_coord='y')
        wofls=s2_wofs['wofs']
        #define masks for wet and dry pixels; masked (no_data) pixels match neither
        wofl_wetnocloud = (wofls==1) 
        wofl_drynocloud = (wofls==0)
        clear = (wofl_wetnocloud | wofl_drynocloud).all(dim=['x','y']).values
        #record the total number of clear observations for each point in each month and use it to filter out month with no valid data
        n_clear = clear.sum()  
        #condition to identify whether WOfS seen water in specific month for a particular location 
        if n_clear > 0:
            wet = wofl_wetnocloud.isel(time=clear).max().values  
        else:
            wet = 0 
        #updating results for both wet and clear observations 
        #print('wet/not wet')
        results_wet.update({str(int(plot_id))+"_"+str(month) : int(wet)})
        results_clear.update({str(int(plot_id))+"_"+str(month) : int(n_clear)})        
        
        return time
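
The wet/clear bookkeeping at the end of `get_wofs_for_point` can be followed on a small made-up array. This sketch uses plain NumPy rather than xarray and assumes a single pixel per time step, with NaN marking a masked observation.

```python
import numpy as np

# Made-up classifier output for one point over four dates:
# 1 = water, 0 = not water, NaN = masked out (e.g. cloud)
wofs = np.array([np.nan, 0.0, 1.0, np.nan])

wet_obs = wofs == 1
dry_obs = wofs == 0
clear = wet_obs | dry_obs          # NaN compares False to both 0 and 1

n_clear = int(clear.sum())         # number of usable observations
# The point counts as wet if any clear observation in the window is water
wet = int(wet_obs[clear].max()) if n_clear > 0 else 0

print(n_clear, wet)  # 2 1
```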


Define a function for parallel processing 

In [12]:
def _parallel_fun(input_data, query, ncpus):
    
    manager = mp.Manager()
    results_wet = manager.dict()
    results_clear = manager.dict()
   
    # progress bar
    pbar = tqdm(total=len(input_data))
        
    def update(*a):
        pbar.update()

    with mp.Pool(ncpus) as pool:
        for index, row in input_data.iterrows():
            #print('in the pool, index',index)
            pool.apply_async(get_wofs_for_point,
                                 [index,
                                 row,
                                 input_data,
                                 query,
                                 results_wet,
                                 results_clear], callback=update)
        #print('leaving pool')
        pool.close()
        pool.join()
        pbar.close()
        
    return results_wet, results_clear
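
One thing to be aware of with `apply_async` is that an exception raised inside a worker is not re-raised in the parent unless the result is fetched, so a failed point simply leaves no entry in the results. A minimal sketch of recording such failures with `error_callback` follows. It uses `multiprocessing.pool.ThreadPool`, which exposes the same `apply_async` interface as `mp.Pool`, only so that the example runs without pickling concerns; `toy_worker` and its failure condition are made up.

```python
from multiprocessing.pool import ThreadPool


def toy_worker(index, results):
    """Stand-in for a per-point query; fails on one made-up index."""
    if index == 2:
        raise ValueError(f"no data for point {index}")
    results[index] = index * 10


results = {}
failures = []

pool = ThreadPool(2)
for i in range(4):
    pool.apply_async(toy_worker, (i, results),
                     error_callback=failures.append)
pool.close()
pool.join()

print(sorted(results.items()))          # [(0, 0), (1, 10), (3, 30)]
print([str(e) for e in failures])       # ['no data for point 2']
```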

Test the function on the first five points using a simple for loop 

In [13]:
results_wet_test = dict()
results_clear_test = dict()

for index, row in input_data[0:5].iterrows():
    time = get_wofs_for_point(index, row, input_data, query, results_wet_test, results_clear_test)
    print(time)



('2018-08', '2018-08-06')
('2018-10', '2018-10-06')
('2018-01', '2018-01-06')
('2018-02', '2018-02-06')
('2018-03', '2018-03-06')


Point-based query and parallel processing on WOfS 

In [14]:
# NOTE: this takes a while longer with S2 than with landsat, grab a tea or coffee while you wait?
# takes ~1min 20sec for 10 on 2 cpus 
wet, clear = _parallel_fun(input_data, query, ncpus=15)

100%|██████████| 8724/8724 [1:01:11<00:00,  2.38it/s]


In [15]:
#extracting the final table with both CEO labels and WOfS class Wet and clear observations 
wetdf = pd.DataFrame.from_dict(wet, orient = 'index')
cleardf = pd.DataFrame.from_dict(clear,orient='index')
df2 = wetdf.merge(cleardf, left_index=True, right_index=True)
df2 = df2.rename(columns={'0_x':'CLASS_WET','0_y':'CLEAR_OBS'})
#split the index (which is plotid + month) into separate columns
for index, row in df2.iterrows():
    df2.at[index,'PLOT_ID'] = index.split('_')[0] +'.0'
    df2.at[index,'MONTH'] = index.split('_')[1]
#reset the index
df2 = df2.reset_index(drop=True)
#convert plot id and month to str to help with matching
input_data['PLOT_ID'] = input_data.PLOT_ID.astype(str)
input_data['MONTH']= input_data.MONTH.astype(str)
# merge both dataframe at locations where plotid and month match
final_df = pd.merge(input_data, df2, on=['PLOT_ID','MONTH'], how='outer')
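
The key-splitting loop in this cell can also be written without `iterrows`. The sketch below shows one vectorised alternative on made-up data; it is not the notebook's method, and all values are invented for illustration.

```python
import pandas as pd

# Made-up results keyed by "<plot_id>_<month>", as produced by the workers
wet = {'101_8': 1, '102_10': 0}
clear = {'101_8': 3, '102_10': 2}

df2 = pd.DataFrame({'CLASS_WET': pd.Series(wet),
                    'CLEAR_OBS': pd.Series(clear)})

# Split the index once into two string columns
parts = df2.index.to_series().str.split('_', expand=True)
df2['PLOT_ID'] = (parts[0] + '.0').values
df2['MONTH'] = parts[1].values
df2 = df2.reset_index(drop=True)

# Made-up validation table with one point that has no result
points = pd.DataFrame({'PLOT_ID': ['101.0', '102.0', '103.0'],
                       'MONTH': ['8', '10', '1']})
final = points.merge(df2, on=['PLOT_ID', 'MONTH'], how='outer')

print(final)
print(final['CLASS_WET'].isna().sum())  # points without a result
```

With an outer merge, validation points that produced no result are kept with NaN in `CLASS_WET` and `CLEAR_OBS`, which is what the NaN count in the following cells measures.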

In [16]:
#Defining the shape of final table 
final_df.shape

(8724, 19)

In [17]:
#Counting the number of rows in the final table with NaN values in CLASS_WET and CLEAR_OBS (optional)
#This checks that the parallel processing function returns identical results each time it runs 
countA = final_df["CLASS_WET"].isna().sum()
countB = final_df["CLEAR_OBS"].isna().sum()
countA, countB


(198, 198)

In [18]:
final_df.to_csv(('../Results/WOfS_Assessment/Point_Based/Institutions/AGRYHMET_PointBased_5D_S2.csv'))

In [19]:
print(datacube.__version__)

1.8.3


***

## Additional information

**License:** The code in this notebook is licensed under the [Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0). 
Digital Earth Africa data is licensed under the [Creative Commons by Attribution 4.0](https://creativecommons.org/licenses/by/4.0/) license.

**Contact:** If you need assistance, please post a question on the [Open Data Cube Slack channel](http://slack.opendatacube.org/) or on the [GIS Stack Exchange](https://gis.stackexchange.com/questions/ask?tags=open-data-cube) using the `open-data-cube` tag (you can view previously asked questions [here](https://gis.stackexchange.com/questions/tagged/open-data-cube)).
If you would like to report an issue with this notebook, you can file one on [Github](https://github.com/digitalearthafrica/deafrica-sandbox-notebooks).

**Last modified:** September 2020

**Compatible datacube version:** 1.8.3

## Tags
Browse all available tags on the DE Africa User Guide's [Tags Index](https://) (placeholder as this does not exist yet)