# This notebook demonstrates the use of post-classification probability filters

In land cover classification, each pixel is assigned a probability of belonging to each land cover class. Often the class with the highest probability is chosen as the final classification. However, some land cover products apply rules to these class probabilities in order to increase the final accuracy, such as the [10m Sentinel-2 Based European Land Cover map](http://s2glc.cbk.waw.pl/extension) created by [Malinowski et al. 2020](https://www.mdpi.com/2072-4292/12/21/3523/htm).
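
As a minimal illustration of the "highest probability wins" rule described above, the sketch below picks the most probable class for three pixels with NumPy. The class values and probabilities are made up for the example and are not taken from the Dynamic World data.

```python
import numpy as np

# Hypothetical class values (water=1, trees=2, crops=5) and made-up
# per-pixel probabilities: one row per pixel, one column per class.
class_values = np.array([1, 2, 5])
probabilities = np.array([
    [0.70, 0.20, 0.10],
    [0.10, 0.50, 0.40],
    [0.20, 0.35, 0.45],
])

# Final classification = class whose probability is highest for each pixel.
classification = class_values[np.argmax(probabilities, axis=1)]

# The winning probability itself is what a post-classification filter inspects.
winning_probability = probabilities.max(axis=1)

print(classification)
print(winning_probability)
```

Note that the third pixel is labelled crops with a probability below 0.5; low-confidence cases like this are the ones a probability filter is meant to revisit.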

This notebook demonstrates post-classification probability filters that allow the user to define rules based on performance on the training data. After a short setup (Step 0), the notebook has 4 steps:

1. Load Land Cover Classifications and Label Data
2. Calculate Accuracy and Confusion Matrix for Original Classifications on Label Data
3. Define Probability Filters and Apply to Land Cover Probabilities
4. Calculate Accuracy and Confusion Matrix for Post-Filtered Classifications on Label Data



## Step 0: Load libraries and initialize Earth Engine

In [None]:
#Load necessary libraries
import sys
import os
import ee
import geemap
import numpy as np
import pandas as pd
from IPython.display import HTML, display
from ipyleaflet import Map, basemaps
import random
import json
import time
import ast

# relative import for this folder hierarchy, credit: https://stackoverflow.com/a/35273613
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    
from wri_change_detection import preprocessing as npv
from wri_change_detection import gee_classifier as gclass
from wri_change_detection import post_classification_filters as pcf



<font size="4">Initialize Earth Engine and Google Cloud authentication</font>

In [None]:
#Initialize Earth Engine; if initialization fails (e.g. no stored credentials), authenticate and retry
try:
    ee.Initialize()
except Exception:
    ee.Authenticate()
    ee.Initialize()

<font size="4">Define a seed number to ensure reproducibility across random processes. This seed will be used in all subsequent sampling as well. We'll also define seeds for sampling the training, validation, and test sets.</font>

In [None]:
num_seed=30
random.seed(num_seed)


## Step 1: Load Land Cover Classifications and Label Data


<font size="4">

Define land cover classification image collection, with one image for each time period. Each image should have one band representing the classification in that pixel for one time period.</font>

In [None]:
#Load collection
#This collection holds monthly Dynamic World land cover classifications; later we'll squash it to annual
dynamic_world_classifications_monthly = ee.ImageCollection('projects/wings-203121/assets/dynamic-world/v3-5_stack_tests/wri_test_goldsboro')

#Get classes from first image
dw_classes = dynamic_world_classifications_monthly.first().bandNames()
dw_classes_str = dw_classes.getInfo()
full_dw_classes_str = ['No Data']+dw_classes_str

#Get dictionary of classes and values
#Define array of land cover classification values
dw_class_values = np.arange(1,10).tolist()
dw_class_values_ee = ee.List(dw_class_values)
#Create dictionary representing land cover classes and land cover class values
dw_classes_dict = ee.Dictionary.fromLists(dw_classes, dw_class_values_ee)

#Make sure the dictionary looks good
print(dw_classes_dict.getInfo())


<font size="4">Define color palettes to map land cover</font>

In [None]:
change_detection_palette = ['#ffffff', # no_data=0
                              '#419bdf', # water=1
                              '#397d49', # trees=2
                              '#88b053', # grass=3
                              '#7a87c6', # flooded_vegetation=4
                              '#e49535', # crops=5
                              '#dfc25a', # scrub_shrub=6
                              '#c4291b', # builtup=7
                              '#a59b8f', # bare_ground=8
                              '#a8ebff', # snow_ice=9
                              '#616161', # clouds=10
]
statesViz = {'min': 0, 'max': 10, 'palette': change_detection_palette}

oneChangeDetectionViz = {'min': 0, 'max': 1, 'palette': ['696a76','ff2b2b']} #gray = 0, red = 1
consistentChangeDetectionViz = {'min': 0, 'max': 1, 'palette': ['0741df','df07b5']} #blue = 0, pink = 1



<font size="4">Gather projection and geometry information from the land cover classifications</font>

In [None]:
projection_ee = dynamic_world_classifications_monthly.first().projection()
projection = projection_ee.getInfo()
crs = projection.get('crs')
crsTransform = projection.get('transform')
scale = dynamic_world_classifications_monthly.first().projection().nominalScale().getInfo()
print('CRS and Transform: ',crs, crsTransform)

geometry = dynamic_world_classifications_monthly.first().geometry().bounds()


<font size="4">Convert the land cover collection to a multiband image, one band for each year, and reduce the monthly probability predictions to annual probabilities.</font>

In [None]:
#Define years to get annual classifications for
years = np.arange(2016,2020)

#Squash scenes from monthly to annual
dynamic_world_classifications = npv.squashScenesToAnnualClassification(dynamic_world_classifications_monthly,years,method='median',image_name='dw_{}')

#Squash monthly class probabilities to annual probabilities
dynamic_world_probabilites = npv.squashScenesToAnnualProbability(dynamic_world_classifications_monthly,years,method='median',image_name='dw_probs_{}')

#Get image names 
dw_band_names = dynamic_world_classifications.aggregate_array('system:index').getInfo()
#Convert to a multiband image and rename using dw_band_names
dynamic_world_classifications_image = dynamic_world_classifications.toBands().rename(dw_band_names)
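
To make the monthly-to-annual reduction concrete, here is a rough NumPy sketch of one plausible reading of `method='median'`: take the per-class median probability across the months of a year, then label each pixel with the most probable class. This is an assumption for illustration only; the actual behavior of `squashScenesToAnnualClassification` and `squashScenesToAnnualProbability` is defined in `wri_change_detection.preprocessing` and may differ. The function name `squash_to_annual` and the data are hypothetical.

```python
import numpy as np

def squash_to_annual(monthly_probs, class_values):
    """Reduce a (n_months, n_classes, H, W) stack to annual products.

    Returns the per-class median probability (ignoring NaN months, e.g.
    cloudy scenes) and the class with the highest median probability.
    """
    annual_probs = np.nanmedian(monthly_probs, axis=0)
    annual_classes = np.asarray(class_values)[np.argmax(annual_probs, axis=0)]
    return annual_probs, annual_classes

# Made-up example: 3 months, 2 classes, a 1x2 pixel image.
monthly_probs = np.array([
    [[[0.9, 0.2]], [[0.1, 0.8]]],
    [[[0.6, 0.4]], [[0.4, 0.6]]],
    [[[0.2, np.nan]], [[0.8, np.nan]]],   # second pixel missing in month 3
])
annual_probs, annual_classes = squash_to_annual(monthly_probs, class_values=[1, 2])
print(annual_probs)
print(annual_classes)
```

A median is less sensitive than a mean to a single anomalous month, which is one reason it is a common choice for this kind of temporal compositing.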


<font size="4">
Load the label data and export the labelled points as an asset so that they can be compared with the land cover classifications later.
</font>

In [None]:
#only labels for regions in Modesto CA, Goldsboro NC, the Everglades in FL, and one region in Brazil have been
#uploaded to this collection
labels = ee.ImageCollection('projects/wri-datalab/DynamicWorld_CD/DW_Labels')

#Filter to where we have DW classifications
labels_filtered = labels.filterBounds(dynamic_world_classifications_monthly.geometry())
print('Number of labels that overlap classifications', labels_filtered.size().getInfo())

#Save labels projection
labels_projection = labels_filtered.first().projection()
#Define geometry to sample points from 
labels_geometry = labels_filtered.geometry().bounds()

#Compress labels by majority vote
labels_filtered = labels_filtered.reduce(ee.Reducer.mode())
#Remove pixels that were classified as no data
labels_filtered = labels_filtered.updateMask(labels_filtered.neq(0))
#Rename band
labels_filtered = labels_filtered.rename(['labels'])


#Sample points from label image at every pixel
labelPoints = labels_filtered.sample(region=labels_geometry, projection=labels_projection, 
                                     factor=1, 
                                     seed=num_seed, dropNulls=True,
                                     geometries=True)

#Export sampled points
labelPoints_export_name = 'goldsboro'
labelPoints_assetID = 'projects/wri-datalab/DynamicWorld_CD/DW_LabelPoints_{}'
labelPoints_description = 'DW_LabelPoints_{}'

export_results_task = ee.batch.Export.table.toAsset(
    collection=labelPoints, 
    description = labelPoints_description.format(labelPoints_export_name), 
    assetId = labelPoints_assetID.format(labelPoints_export_name))
export_results_task.start()
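
The export above runs asynchronously, and Step 2 reads the exported asset, so the asset must exist before that cell is run. The helper below is a small sketch for waiting on the task. It only assumes that the task object has a `status()` method returning a dictionary with a `'state'` key, as `ee.batch.Task` does; the name `wait_for_task` and the default polling interval and timeout are illustrative choices.

```python
import time

def wait_for_task(task, poll_seconds=30, timeout_seconds=3600):
    """Poll `task.status()` until the task finishes and return its final state.

    `task` is any object whose status() returns a dict with a 'state' key.
    Raises TimeoutError if the task is still pending after `timeout_seconds`.
    """
    finished_states = {'COMPLETED', 'FAILED', 'CANCELLED'}
    deadline = time.monotonic() + timeout_seconds
    while True:
        state = task.status()['state']
        if state in finished_states:
            return state
        if time.monotonic() >= deadline:
            raise TimeoutError('Task still in state {} after {} s'.format(state, timeout_seconds))
        time.sleep(poll_seconds)
```

In this notebook it could be called as `wait_for_task(export_results_task)` before moving on to Step 2, checking that the returned state is `'COMPLETED'`.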


<font size="4">Map land cover classifications and labels</font>

In [None]:
#Map years to check them out!
center = [35.410769, -78.100163]
zoom = 12
Map1 = geemap.Map(center=center, zoom=zoom,basemap=basemaps.Esri.WorldImagery,add_google_map = False)
Map1.addLayer(dynamic_world_classifications_image.select('dw_2016'),statesViz,name='2016 DW LC')
Map1.addLayer(dynamic_world_classifications_image.select('dw_2017'),statesViz,name='2017 DW LC')
Map1.addLayer(dynamic_world_classifications_image.select('dw_2018'),statesViz,name='2018 DW LC')
Map1.addLayer(dynamic_world_classifications_image.select('dw_2019'),statesViz,name='2019 DW LC')
Map1.addLayer(labels_filtered,statesViz,name='Labels')
display(Map1)


## Step 2: Calculate Accuracy and Confusion Matrix for Original Classifications on Label Data

In [None]:
#Load label points
labelPointsFC = ee.FeatureCollection(labelPoints_assetID.format('goldsboro'))

#Save 2019 DW classifications and rename to "dw_classifications"
dw_2019 = dynamic_world_classifications_image.select('dw_2019').rename('dw_classifications')

#Sample the 2019 classifications at each label point
labelPointsWithDW = dw_2019.sampleRegions(collection=labelPointsFC, projection = projection_ee, 
                                          tileScale=4, geometries=True)

#Calculate confusion matrix, which we will use for an accuracy assessment
originalErrorMatrix = labelPointsWithDW.errorMatrix('labels', 'dw_classifications')

#Get the confusion matrix as a list
errorMatrixValues = originalErrorMatrix.getInfo()

#Print the confusion matrix with the class names as a dataframe
#Axis 0 (the rows) of the matrix corresponds to the actual values, and Axis 1 (the columns) to the predicted values.
errorMatrixDf = pd.DataFrame(errorMatrixValues, index = full_dw_classes_str, columns = full_dw_classes_str)
print('Axis 0 (the rows) of the matrix corresponds to the actual values, and Axis 1 (the columns) to the predicted values.')
display(errorMatrixDf)

#You can also print further accuracy scores from the confusion matrix, however each one takes a couple minutes 
#to load
print('Accuracy',originalErrorMatrix.accuracy().getInfo())
# print('Consumers Accuracy',originalErrorMatrix.consumersAccuracy().getInfo())
# print('Producers Accuracy',originalErrorMatrix.producersAccuracy().getInfo())
# print('Kappa',originalErrorMatrix.kappa().getInfo())
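
Because the confusion matrix has already been fetched as a list, the other accuracy scores can also be computed locally instead of with further server round trips. The sketch below uses the standard remote sensing definitions with rows as actual values and columns as predicted values; the function name `accuracy_from_matrix` and the example matrix are hypothetical, and the results are not guaranteed to match the shape or orientation of the arrays Earth Engine returns.

```python
import numpy as np

def accuracy_from_matrix(matrix):
    """Compute accuracy scores from a square confusion matrix.

    Rows are actual classes and columns are predicted classes.
    Classes with no samples get NaN for their per-class scores.
    """
    m = np.asarray(matrix, dtype=float)
    correct = np.diag(m)
    total = m.sum()
    actual_totals = m.sum(axis=1)      # row sums
    predicted_totals = m.sum(axis=0)   # column sums
    with np.errstate(divide='ignore', invalid='ignore'):
        producers = correct / actual_totals      # share of each actual class labelled correctly
        consumers = correct / predicted_totals   # share of each predicted class that is correct
    overall = correct.sum() / total
    # Cohen's kappa: agreement beyond what the marginals would give by chance.
    expected = (actual_totals * predicted_totals).sum() / total ** 2
    kappa = (overall - expected) / (1 - expected)
    return {'overall': overall, 'producers': producers,
            'consumers': consumers, 'kappa': kappa}

# Made-up 3-class matrix; class 0 (no data) has no samples.
scores = accuracy_from_matrix([[0, 0, 0],
                               [0, 8, 2],
                               [0, 1, 9]])
print(scores)
```

With the objects defined in this step, it could be applied as `accuracy_from_matrix(errorMatrixValues)`.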



## Step 3: Define Probability Filters and Apply to Land Cover Probabilities

In [None]:
#{'bare_ground': 8, 'built_area': 7, 'crops': 5, 'flooded_vegetation': 4, 'grass': 3, 'scrub': 6, 'snow_and_ice': 9, 'trees': 2, 'water': 1}

#Define list of dictionaries to pass to applyProbabilityCutoffs
#applyProbabilityCutoffs:
# Function to apply a probability filter to land cover probabilities in each image of imageCollection. 
# The user defines which classes will be filtered and how to filter them in the params list.
# The params list is a list of dictionaries, one for each class the user wants to filter.
# Each dictionary in the params list has the form {'class_name': String, 'class_value': Int, 'filter': String, 'threshold': Float}

# If the filter is 'gt' or 'gte' (representing 'greater than' or 'greater than or equal to') and the pixel's class probability is greater than
#     (or greater than or equal to) the threshold, then the final classification is replaced by the value of that class.
# If the filter is 'lt' or 'lte' (representing 'less than' or 'less than or equal to') and the pixel's class probability is less than
#     (or less than or equal to) the threshold, and the pixel was classified as that class, then the final classification is replaced by
#     the majority class of the neighborhood, where the neighborhood is a square kernel of size 1.

#Here we will apply two filters: 
#First if the probability of the tree class is <0.5 and the pixel was classified as a tree, 
#then we replace it with the majority of the neighbor pixel classes

#Second if the probability of the built-area class is >0.3, then the pixel is classified as built-area

params = [{'class_name': 'trees', 'class_value': 2, 'filter': 'lt', 'threshold': 0.5},
          {'class_name': 'built_area', 'class_value': 7, 'filter': 'gt', 'threshold': 0.3}]

classifications_filtered = pcf.applyProbabilityCutoffs(dynamic_world_probabilites, params)
image_names = classifications_filtered.aggregate_array('system:index')
classifications_filtered = classifications_filtered.toBands().rename(image_names)
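
To show what the two kinds of rule do to individual pixels, here is a small NumPy sketch that follows the description in the comments above. It is not the implementation of `pcf.applyProbabilityCutoffs`: the function name `apply_probability_cutoffs`, the 3x3 window used for "a square kernel of size 1", the order in which rules are applied, and the tie-breaking are all assumptions made for illustration.

```python
import numpy as np

def apply_probability_cutoffs(probs, class_values, params):
    """Apply 'gt'/'gte' and 'lt'/'lte' rules to a (n_classes, H, W) probability stack.

    class_values[i] is the class value of band i. Rules are applied in list order.
    """
    class_values = np.asarray(class_values)
    # Baseline classification: the class with the highest probability.
    original = class_values[np.argmax(probs, axis=0)]
    filtered = original.copy()
    for rule in params:
        value, kind, threshold = rule['class_value'], rule['filter'], rule['threshold']
        p = probs[int(np.flatnonzero(class_values == value)[0])]
        if kind in ('gt', 'gte'):
            # Confident enough: force the pixel into this class.
            hit = p > threshold if kind == 'gt' else p >= threshold
            filtered[hit] = value
        elif kind in ('lt', 'lte'):
            # Not confident enough: replace with the majority class of the
            # 3x3 window (assumed reading of "square kernel of size 1").
            low = p < threshold if kind == 'lt' else p <= threshold
            for r, c in zip(*np.nonzero(low & (original == value))):
                window = original[max(r - 1, 0):r + 2, max(c - 1, 0):c + 2]
                values, counts = np.unique(window, return_counts=True)
                filtered[r, c] = values[np.argmax(counts)]
        else:
            raise ValueError('Unknown filter: {}'.format(kind))
    return filtered

# Made-up 3x3 scene with bands trees=2, crops=5, built_area=7; mostly crops.
class_values = [2, 5, 7]
probs = np.empty((3, 3, 3))
probs[0], probs[1], probs[2] = 0.1, 0.8, 0.1
probs[:, 1, 1] = [0.40, 0.35, 0.25]   # weakly "trees", surrounded by crops
probs[:, 0, 0] = [0.05, 0.60, 0.35]   # crops, but built_area probability > 0.3
probs[:, 2, 2] = [0.70, 0.20, 0.10]   # confidently trees

example_params = [{'class_name': 'trees', 'class_value': 2, 'filter': 'lt', 'threshold': 0.5},
                  {'class_name': 'built_area', 'class_value': 7, 'filter': 'gt', 'threshold': 0.3}]
filtered = apply_probability_cutoffs(probs, class_values, example_params)
print(filtered)
```

In this toy scene the weak tree pixel in the center takes the majority class of its neighbors (crops), the top-left pixel is forced to built-area, and the confident tree pixel is left unchanged.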


## Step 4: Calculate Accuracy and Confusion Matrix for Post-Filtered Classifications on Label Data

In [None]:
#Load label points
labelPointsFC = ee.FeatureCollection(labelPoints_assetID.format('goldsboro'))

#Save 2019 post-filtered DW classifications and rename to "dw_filterd_classifications"
classifications_filtered_2019 = classifications_filtered.select('dw_probs_2019').rename('dw_filterd_classifications')

#Sample the 2019 classifications at each label point
labelPointsWithFilteredDW = classifications_filtered_2019.sampleRegions(collection=labelPointsFC, 
                                                                        projection = projection_ee, 
                                                                        tileScale=4, geometries=True)

#Calculate confusion matrix, which we will use for an accuracy assessment
filteredErrorMatrix = labelPointsWithFilteredDW.errorMatrix('labels', 'dw_filterd_classifications')

#Get the confusion matrix as a list
filteredErrorMatrixValues = filteredErrorMatrix.getInfo()

#Print the confusion matrix with the class names as a dataframe
#Axis 0 (the rows) of the matrix corresponds to the actual values, and Axis 1 (the columns) to the predicted values.
filteredErrorMatrixDf = pd.DataFrame(filteredErrorMatrixValues, index = full_dw_classes_str, 
                                     columns = full_dw_classes_str)
print('Axis 0 (the rows) of the matrix corresponds to the actual values, and Axis 1 (the columns) to the predicted values.')
display(filteredErrorMatrixDf)

#You can also print further accuracy scores from the confusion matrix, however each one takes a couple minutes 
#to load
print('Accuracy',filteredErrorMatrix.accuracy().getInfo())
# print('Consumers Accuracy',filteredErrorMatrix.consumersAccuracy().getInfo())
# print('Producers Accuracy',filteredErrorMatrix.producersAccuracy().getInfo())
# print('Kappa',filteredErrorMatrix.kappa().getInfo())
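
A single overall accuracy number can hide the fact that a filter helps one class while hurting another. The sketch below compares two confusion matrices class by class, using the share of each actual class that was labelled correctly (rows are actual values, columns are predicted values). The function name `compare_matrices` and the example matrices are hypothetical.

```python
import numpy as np
import pandas as pd

def compare_matrices(before, after, class_names):
    """Tabulate per-class accuracy (correct / actual total) before and after filtering."""
    def per_class_accuracy(matrix):
        df = pd.DataFrame(matrix, index=class_names, columns=class_names)
        correct = pd.Series(np.diag(df.to_numpy()), index=class_names)
        return correct / df.sum(axis=1)   # NaN for classes with no samples

    summary = pd.DataFrame({'before': per_class_accuracy(before),
                            'after': per_class_accuracy(after)})
    summary['change'] = summary['after'] - summary['before']
    return summary

# Made-up matrices for two classes.
summary = compare_matrices(before=[[8, 2], [3, 7]],
                           after=[[9, 1], [2, 8]],
                           class_names=['trees', 'built_area'])
print(summary)
```

With the matrices from Steps 2 and 4, it could be called as `compare_matrices(errorMatrixValues, filteredErrorMatrixValues, full_dw_classes_str)`.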


<font size="4">Map the classifications before and after filtering, along with the probabilities, as a quality check.</font>

In [None]:
#Map results to check them out!
treesViz = {'min': 0, 'max': 1, 'palette': ['93ff8f','117e29']} #light green = 0, dark green = 1
builtAreaViz = {'min': 0, 'max': 1, 'palette': ['ff8f8f','c30000']} #light red = 0, dark red = 1

#Select probabilities for 2019
dynamic_world_probabilites_2019 = dynamic_world_probabilites.filterDate('2019-01-01','2019-12-31').first()

#Find where classifications changed after filtering
changed_with_filter = dynamic_world_classifications_image.select('dw_2019').neq(classifications_filtered.select('dw_probs_2019'))

Map2 = geemap.Map(center=center, zoom=zoom,basemap=basemaps.Esri.WorldImagery,add_google_map = False)
Map2.addLayer(dynamic_world_probabilites_2019.select('trees'),treesViz,name='2019 DW Trees Probability')
Map2.addLayer(dynamic_world_probabilites_2019.select('built_area'),builtAreaViz,name='2019 DW Built Area Probability')
Map2.addLayer(dynamic_world_classifications_image.select('dw_2019'),statesViz,name='2019 DW LC')
Map2.addLayer(classifications_filtered.select('dw_probs_2019'),statesViz,name='2019 DW LC Post Filter')
Map2.addLayer(changed_with_filter,oneChangeDetectionViz,name='Changed with Filter')
display(Map2)
