This notebook loads satellite data, predicts land cover using pre-trained models, and calculates the majority vote of the predicted classification results. For Lesotho, this notebook is used twice in the workflow: first to produce a reference/baseline land cover map using the unfiltered training data, and second to produce a land cover map for a target year (2021) using the filtered training data.

In [1]:
%matplotlib inline
import os
import datacube
import warnings
import time
import numpy as np
from scipy import stats
import pandas as pd
import geopandas as gpd
import matplotlib.pyplot as plt
from odc.algo import xr_geomedian
import xarray as xr
from joblib import load
from deafrica_tools.classification import predict_xr
from deafrica_tools.dask import create_local_dask_cluster
from deafrica_tools.datahandling import load_ard
from deafrica_tools.bandindices import calculate_indices
from datacube.utils.cog import write_cog

# ---- inputs & settings -------------------------------------------------------
# Areas to classify. Active choice: randomly selected small regions (stratified sampling AOIs).
# Alternatives for classifying the whole country as tiles of different sizes:
# lesotho_tiles_shp='Data/Lesotho_boundaries_projected_epsg32735_tiles.shp'
# lesotho_tiles_shp='Data/Lesotho_boundaries_projected_epsg32735_tiles_smaller.shp'
# lesotho_tiles_shp='Data/Lesotho_boundaries_projected_epsg32735_tiles_bigger.shp'
lesotho_tiles_shp = 'Results/stratified_sampling_AOIs.geojson'

# Trained random forest models. Active choice: models trained on the filtered training data.
# To produce the reference/baseline map, use the models trained on the unfiltered data instead:
# rf_model_path='Results/RF_models_GEE_replicate.joblib'
rf_model_path = 'Results/RF_models_using_filtered_td_GEE_replicate.joblib'

class_name = 'LC_Class_I'  # class label field (integer-coded classes)
crs = 'epsg:4326'          # crs of the query bounding boxes: WGS84
output_crs = 'epsg:32735'  # crs of the loaded data: WGS84 / UTM zone 35S

# Sentinel-2 band measurements requested from the datacube
measurements = [
    'blue', 'green', 'red',
    'red_edge_1', 'red_edge_2', 'red_edge_3',
    'nir_1', 'nir_2',
    'swir_1', 'swir_2',
]

# ---- load inputs ---------------------------------------------------------------
# bounding boxes (minx, miny, maxx, maxy) of every tile/AOI, expressed in `crs`
lesotho_tiles = gpd.read_file(lesotho_tiles_shp).to_crs(crs)
tile_bboxes = lesotho_tiles.bounds
print('tile boundaries for Lesotho: \n',tile_bboxes)

# NOTE: joblib.load unpickles the file; only load model files produced by this workflow
rf_models = load(rf_model_path)
print('loaded random forest models:\n',rf_models)

# ---- compute -------------------------------------------------------------------
# local dask cluster with a single worker
create_local_dask_cluster(n_workers=1)

  _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
  _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)
  _numeric_index_types = (pd.Int64Index, pd.Float64Index, pd.UInt64Index)


ncpus = 62
tile boundaries for Lesotho: 
          minx       miny       maxx       maxy
0   27.635700 -29.477456  27.646070 -29.468382
1   28.847423 -29.636949  28.857913 -29.627784
2   28.130782 -28.910120  28.141136 -28.901010
3   28.102968 -28.898528  28.113318 -28.889419
4   27.557625 -29.382681  27.567979 -29.373612
..        ...        ...        ...        ...
85  29.460999 -29.928268  29.471572 -29.919058
86  29.482590 -30.455901  29.493226 -30.446687
87  28.046525 -30.684943  28.057061 -30.675836
88  29.462675 -29.582384  29.473209 -29.573175
89  27.846329 -30.699937  27.856849 -30.690845

[90 rows x 4 columns]
loaded random forest models:
 [RandomForestClassifier(max_samples=0.5, n_estimators=50, n_jobs=-1), RandomForestClassifier(max_samples=0.5, n_estimators=50, n_jobs=-1), RandomForestClassifier(max_samples=0.5, n_estimators=50, n_jobs=-1), RandomForestClassifier(max_samples=0.5, n_estimators=50, n_jobs=-1), RandomForestClassifier(max_samples=0.5, n_estimators=50, n_jobs=

0,1
Client  Scheduler: tcp://127.0.0.1:35007  Dashboard: /user/whusggliuqx@gmail.com/proxy/8787/status,Cluster  Workers: 1  Cores: 62  Memory: 512.40 GB


In [None]:
# define a function to build the stacked feature layers used for classification
def feature_layers(query):
    """Load Sentinel-2 ARD for one query and build a stacked multi-temporal feature dataset.

    Steps: load masked Sentinel-2 L2A data, add a scaled NDVI band, compute a
    geomedian composite for every two-month interval, then flatten the time
    dimension into separate variables named ``<band>_<k>`` (e.g. ``red_0``,
    ``red_1``, ...), ordered band by band and, within a band, by time step.

    Parameters
    ----------
    query : dict
        Keyword arguments forwarded to ``load_ard`` (x, y, time, measurements,
        resolution, crs, output_crs, dask_chunks, ...).

    Returns
    -------
    xarray.Dataset
        One variable per (band, time step) combination, cast to int16.
    """
    # connect to the datacube
    dc = datacube.Datacube(app='feature_layers')
    # load the requested bands
    ds = load_ard(dc=dc,
                  products=['s2_l2a'],
                  group_by='solar_day',
                  verbose=False,
                  #mask_filters=[("opening", 2)], # morphological opening by 2 pixels to remove small masked regions
                  **query)
    # calculate NDVI; drop=False keeps the original bands alongside it
    ds_index = calculate_indices(ds, index=['NDVI'], drop=False, satellite_mission='s2')
    del ds
    # NDVI is a normalised ratio (-1..1); scale it by 10000 so its fractional
    # part is not lost by the int16 cast below
    ds_index['NDVI'] = ds_index['NDVI'] * 10000
    # geomedian composite within each two-month interval ('2MS' = 2-month bins starting at month start)
    # NOTE(review): an interval with no valid observations may produce NaN, which has no
    # well-defined int16 value after the cast — confirm this cannot happen for the query period
    ds_geomedian = ds_index.resample(time='2MS').map(xr_geomedian).astype(np.int16)
    del ds_index
#     # rechunk to a single array along time dimension so that interpolate_na can be applied: note: this may consume more memory
#     ds_geomedian=ds_geomedian.chunk({'time':-1})
#     # interpolate nodata using mean of previous and next observation
#     ds_geomedian=ds_geomedian.interpolate_na(dim='time',method='linear',use_coordinate=False)
    # Stack the multi-temporal measurements: one renamed variable per (band, time step).
    # Use `sizes`, not `dims[...]`: mapping-style access to Dataset.dims is deprecated in xarray.
    # The band-major / time-minor order presumably has to match the feature order the
    # models were trained on — keep it unchanged.
    n_time = ds_geomedian.sizes['time']
    list_stack_measures = [
        ds_geomedian[measurement].isel(time=k).rename(measurement + '_' + str(k))
        for measurement in list(ds_geomedian.data_vars)
        for k in range(n_time)
    ]
    ds_stacked = xr.merge(list_stack_measures, compat='override')
    return ds_stacked

In [None]:
# loop through all tiles/AOIs and predict land cover for each one
# At least one model is required: the majority vote below uses the last model's
# prediction as the template for coordinates, dims and attrs.
assert len(rf_models) > 0, 'no random forest models loaded'
for i, (minx, miny, maxx, maxy) in enumerate(tile_bboxes.itertuples(index=False)):
    # bounding box of this tile, in `crs` coordinates
    print('bounding box ',': minx: ',minx,'miny: ',miny,'maxx: ',maxx,'maxy: ',maxy)

    # define ODC query
    query = {
        'x': (minx,maxx),
        'y': (miny,maxy),
        'time': ('2021-01', '2021-12'),
        'measurements': measurements,
        'resolution': (-10, 10),
        'crs':crs,
        'output_crs':output_crs,
        'dask_chunks' : {'x':1700, 'y':1700}, # chunk size should be adjusted based on sandbox instance
#         'dask_chunks' : {'x':-1, 'y':-1}
    }

    # calculate features (builds a lazy dask graph)
    all_data = feature_layers(query)
    print('stacked Sentinel-2 dataset:\n',all_data)
    start_time = time.time() # start timing how long it takes for the prediction

    # Compute the features once and keep them in cluster memory. Without this, every
    # model in the loop below re-runs the whole load + geomedian graph for this tile.
    # NOTE(review): the tile's full feature stack must fit in worker memory — check this
    # when switching to the larger tile layers.
    all_data = all_data.persist()

    # predict land cover with each RF model (one prediction per model)
    list_xr_merge=[] # initialise a list of predicted results
    for j, rf_model in enumerate(rf_models):
        print('predicting split ',j)
        predicted = predict_xr(rf_model,all_data,persist=False,clean=True).compute()
        prediction_single=predicted.Predictions.rename('prediction_'+str(j)) # rename each result
        list_xr_merge.append(prediction_single) # append to the list of predicted results
    predictions_all=xr.merge(list_xr_merge,compat='override') # merge as a multi-variable dataset
    print("%s seconds spent on predicting" % (time.time() - start_time)) # print time spent on prediction
    print('predictions:\n',predictions_all)

    # majority vote of the predicted results
    # to_array stacks the per-model predictions along a new leading 'variable' dimension
    arr_predictions_all=predictions_all.to_array()
    # Take the mode over axis 0 (the model axis), then reshape to the single-prediction grid
    # instead of using squeeze(). The result's shape depends on the SciPy version (keepdims),
    # and squeeze() would also drop a length-1 y/x dimension and break the DataArray below.
    # Ties resolve to the smallest class value (scipy.stats.mode behaviour).
    predictions_mode=stats.mode(arr_predictions_all.values, axis=0).mode.reshape(prediction_single.shape)

    # write final prediction as cog file
#     outname_prediction='Results/Land_cover_prediction_basline_Lesotho_2021_GEE_replicate_tile'+str(i)+'.tif'
#     outname_prediction='Results/Land_cover_prediction_Lesotho_2021_GEE_replicate_tile'+str(i)+'.tif'
    outname_prediction='Results/Land_cover_prediction_Lesotho_2021_GEE_replicate_sampling_AOI_'+str(i)+'.tif'
    xr_predictions_mode=xr.DataArray(predictions_mode, coords=prediction_single.coords, dims=prediction_single.dims,
                                     attrs=prediction_single.attrs).astype(np.int16) # numpy array to data array
    print('writing majority vote prediction cog file...')
    write_cog(xr_predictions_mode, outname_prediction, overwrite=True) # write as cog

In [1]:
# merge multiple tiles as a mosaic tif
# - inputs: every Results/Land_cover_prediction_Lesotho_2021_GEE_replicate_sampling_AOI_*.tif written by the prediction loop
#   (the output name contains 'sampling_mosiac_AOI', not 'sampling_AOI_', so a re-run does not pick up the mosaic as an input)
# - per the gdal_merge documentation, where inputs overlap, the last image processed overwrites earlier ones
# - -ot Byte writes the int16 predictions as 8-bit; NOTE(review): assumes every class code fits in 0-255 — confirm
# - the 'mosiac' spelling is left as-is in case other steps reference this file name
! gdal_merge.py -o Results/Land_cover_prediction_Lesotho_2021_GEE_replicate_sampling_mosiac_AOI.tif -co COMPRESS=Deflate -ot Byte Results/Land_cover_prediction_Lesotho_2021_GEE_replicate_sampling_AOI_*.tif

0...10...20...30...40...50...60...70...80...90...100 - done.
