In [1]:
%matplotlib inline

import datacube
import numpy as np
import xarray as xr
import subprocess as sp
import pandas as pd
import geopandas as gpd
from odc.io.cgroups import get_cpu_quota
from odc.geo.xr import assign_crs

from deafrica_tools.plotting import map_shapefile
from deafrica_tools.bandindices import calculate_indices
from deafrica_tools.classification import collect_training_data

In [2]:
path = 'peatlands.geojson' 
field = 'Class'

In [3]:
ncpus=round(get_cpu_quota())
print('ncpus = '+str(ncpus))

ncpus = 4


In [4]:
# Load input data shapefile
input_data = gpd.read_file(path)

# Transform the class column datatype to int.
input_data[field] = input_data[field].astype(int)

# Plot first five rows
input_data.head()

input_data

Unnamed: 0,id,Class,geometry
0,2,1,"MULTIPOLYGON (((32.89122 -26.59649, 32.88572 -..."
1,3,1,"MULTIPOLYGON (((30.10178 -25.558, 30.10228 -25..."
2,4,0,"MULTIPOLYGON (((30.11356 -25.54435, 30.11162 -..."
3,5,1,"MULTIPOLYGON (((27.60395 -24.46435, 27.60384 -..."
4,6,0,"MULTIPOLYGON (((27.54494 -24.45915, 27.54494 -..."
5,7,1,"MULTIPOLYGON (((31.03985 -22.88845, 31.03988 -..."
6,8,0,"MULTIPOLYGON (((31.07475 -22.91915, 31.07482 -..."
7,12,0,"MULTIPOLYGON (((27.15714 -26.47669, 27.15714 -..."
8,13,0,"MULTIPOLYGON (((27.13739 -26.482, 27.1374 -26...."
9,14,0,"MULTIPOLYGON (((22.76795 -34.04311, 22.78504 -..."


In [5]:
# Plot training data in an interactive map
map_shapefile(input_data, attribute=field)

Label(value='')

Map(center=[-28.494402730498905, 27.82899858879331], controls=(ZoomControl(options=['position', 'zoom_in_text'…

In [6]:
#set up our inputs to collect_training_data
zonal_stats = 'mean'

# Set up the inputs for the ODC query
time = ('2019')
measurements =  ['blue','green','red','nir','swir_1','swir_2','red_edge_1',
                 'red_edge_2', 'red_edge_3', 'BCMAD', 'EMAD', 'SMAD']
resolution = (-20,20)
output_crs='epsg:6933'

In [7]:
query = {
    'time': time,
    'resolution': resolution,
    'output_crs': output_crs
}

In [8]:
from datacube.testutils.io import rio_slurp_xarray


def feature_layers(query):
    #connect to the datacube
    dc = datacube.Datacube(app='feature_layers')
    
    #load s2 annual geomedian
    ds = dc.load(product='gm_s2_annual',
                 measurements =['blue','green','red','nir','swir_1','swir_2','red_edge_1',
                 'red_edge_2', 'red_edge_3', 'BCMAD', 'EMAD', 'SMAD'],
                 **query)
    
    #calculate some band indices
    da = calculate_indices(ds,
                           index=['NDVI', 'LAI', 'MNDWI'],
                           drop=False,
                           satellite_mission='s2')

    ds = dc.load(product="s1_rtc",
             measurements=['vv','vh'],
             group_by="solar_day",
             **query)
    
    # print(ds.vv)
    median_ds = ds[["vv","vh"]].median()
    print(median_ds)
    #add slope dataset
    url_slope = "https://deafrica-input-datasets.s3.af-south-1.amazonaws.com/srtm_dem/srtm_africa_slope.tif"
    slope = rio_slurp_xarray(url_slope, gbox=ds.geobox)
    slope = slope.to_dataset(name='slope')
    
    #merge results into single dataset 
    result = xr.merge([da, slope, median_ds],compat='override')

    return result.squeeze()

In [9]:
column_names, model_input = collect_training_data(
                                    gdf=input_data,
                                    dc_query=query,
                                    ncpus=1, # number of cpus to use
                                    field=field, # label you want to give to the class
                                    zonal_stats=None, # Descriptive statistics
                                    feature_func=feature_layers # 
                                    )

Collecting training data in serial mode
<xarray.Dataset> Size: 20B
Dimensions:      ()
Coordinates:
    spatial_ref  int32 4B 6933
Data variables:
    vv           float64 8B 0.04468
    vh           float64 8B 0.003488
<xarray.Dataset> Size: 20B
Dimensions:      ()
Coordinates:
    spatial_ref  int32 4B 6933
Data variables:
    vv           float64 8B 0.03881
    vh           float64 8B 0.008046
<xarray.Dataset> Size: 20B
Dimensions:      ()
Coordinates:
    spatial_ref  int32 4B 6933
Data variables:
    vv           float64 8B 0.03107
    vh           float64 8B 0.005916
<xarray.Dataset> Size: 20B
Dimensions:      ()
Coordinates:
    spatial_ref  int32 4B 6933
Data variables:
    vv           float64 8B 0.07666
    vh           float64 8B 0.01502
<xarray.Dataset> Size: 20B
Dimensions:      ()
Coordinates:
    spatial_ref  int32 4B 6933
Data variables:
    vv           float64 8B 0.07977
    vh           float64 8B 0.01721
<xarray.Dataset> Size: 20B
Dimensions:      ()
Coordinates:
  

In [10]:
df = pd.DataFrame(model_input, columns=column_names)

In [11]:
df.head()

Unnamed: 0,Class,blue,green,red,nir,swir_1,swir_2,red_edge_1,red_edge_2,red_edge_3,BCMAD,EMAD,SMAD,NDVI,LAI,MNDWI,slope,vv,vh
0,1.0,377.0,585.0,534.0,2515.0,1673.0,876.0,993.0,2018.0,2403.0,0.049633,514.807922,0.001842,0.649721,1.271919,-0.481842,15.365908,0.044678,0.003488
1,1.0,392.0,599.0,583.0,2464.0,1690.0,895.0,1010.0,1968.0,2340.0,0.046304,494.613434,0.002013,0.617329,1.188531,-0.476627,16.298006,0.044678,0.003488
2,1.0,389.0,610.0,580.0,2482.0,1719.0,906.0,1043.0,2003.0,2352.0,0.049173,530.746826,0.001946,0.621163,1.200839,-0.47617,16.749792,0.044678,0.003488
3,1.0,549.0,772.0,833.0,2457.0,2527.0,1586.0,1268.0,1991.0,2281.0,0.048083,602.628357,0.001914,0.493617,0.983337,-0.531979,5.892557,0.044678,0.003488
4,1.0,544.0,771.0,839.0,2350.0,2392.0,1483.0,1280.0,1951.0,2224.0,0.046251,564.475464,0.001787,0.473816,0.909285,-0.512488,5.892557,0.044678,0.003488


In [13]:
import pandas as pd

# Filter out rows where 'Class' is 1.0
class_1_df = df[df['Class'] == 1.0]

# Filter out rows where 'Class' is 0.0
class_0_df = df[df['Class'] == 0.0]

# Randomly sample from class_0_df to match the count of class_1_df
class_0_sampled = class_0_df.sample(n=len(class_1_df), random_state=42)

# Combine the balanced dataset
df_balanced = pd.concat([class_0_sampled, class_1_df])

# Shuffle the DataFrame
df_balanced = df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

# Verify the class distribution
print(df_balanced['Class'].value_counts())

Class
1.0    65318
0.0    65318
Name: count, dtype: int64


In [18]:
df_balanced.to_csv("training_data.csv")
df2 = df.drop(df.columns[0], axis=1)
df2

Unnamed: 0,blue,green,red,nir,swir_1,swir_2,red_edge_1,red_edge_2,red_edge_3,BCMAD,EMAD,SMAD,NDVI,LAI,MNDWI,slope,vv,vh
0,1307.0,1670.0,1897.0,1633.0,1176.0,855.0,2033.0,1722.0,1805.0,0.124956,1300.032715,0.012037,-0.074788,-0.298729,0.173577,0.000000,0.019267,0.001569
1,1186.0,1546.0,1701.0,1215.0,775.0,595.0,1770.0,1342.0,1402.0,0.117070,1053.252319,0.011100,-0.166667,-0.468940,0.332184,0.000000,0.019267,0.001569
2,1098.0,1469.0,1627.0,1096.0,614.0,498.0,1655.0,1208.0,1263.0,0.118511,914.377258,0.011653,-0.195006,-0.498488,0.410466,0.000000,0.019267,0.001569
3,1023.0,1384.0,1554.0,996.0,567.0,470.0,1598.0,1141.0,1196.0,0.109593,855.793884,0.013830,-0.218824,-0.517060,0.418760,0.000000,0.019267,0.001569
4,1015.0,1371.0,1553.0,985.0,545.0,454.0,1578.0,1110.0,1167.0,0.110707,849.534424,0.012447,-0.223798,-0.522835,0.431106,0.000000,0.019267,0.001569
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1688247,507.0,608.0,497.0,500.0,461.0,345.0,604.0,540.0,564.0,0.291916,1096.545898,0.024121,0.003009,-0.115197,0.137512,1.863390,0.088388,0.021386
1688248,481.0,578.0,466.0,417.0,394.0,301.0,547.0,446.0,483.0,0.321916,962.884583,0.023219,-0.055493,-0.164141,0.189300,1.863390,0.088388,0.021386
1688249,457.0,548.0,448.0,337.0,339.0,261.0,507.0,379.0,405.0,0.311370,798.966492,0.022858,-0.141401,-0.222610,0.235626,2.635231,0.088388,0.021386
1688250,449.0,539.0,441.0,308.0,311.0,239.0,486.0,342.0,366.0,0.289455,706.476196,0.019269,-0.177570,-0.243487,0.268235,4.249183,0.088388,0.021386


## Normalization

In [19]:
from sklearn.preprocessing import MinMaxScaler

min_max_scaler =MinMaxScaler(feature_range=(0,1))
df_scaled = min_max_scaler.fit_transform(df2)
df_scaled

array([[0.37335184, 0.36284767, 0.35816498, ..., 0.        , 0.05558809,
        0.02726129],
       [0.33136711, 0.33120694, 0.31691919, ..., 0.        , 0.05558809,
        0.02726129],
       [0.30083276, 0.31155907, 0.3013468 , ..., 0.        , 0.05558809,
        0.02726129],
       ...,
       [0.07841777, 0.07655014, 0.05324074, ..., 0.05898192, 0.56048605,
        0.48691759],
       [0.07564192, 0.07425364, 0.05176768, ..., 0.09510547, 0.56048605,
        0.48691759],
       [0.07980569, 0.07757081, 0.0530303 , ..., 0.09510547, 0.56048605,
        0.48691759]])

In [20]:
df_scaled = pd.DataFrame(df_scaled, columns=column_names[1:])
df_scaled

Unnamed: 0,blue,green,red,nir,swir_1,swir_2,red_edge_1,red_edge_2,red_edge_3,BCMAD,EMAD,SMAD,NDVI,LAI,MNDWI,slope,vv,vh
0,0.373352,0.362848,0.358165,0.286470,0.160835,0.119848,0.347992,0.299665,0.306740,0.251360,0.282850,0.155538,0.321676,0.156770,0.601048,0.000000,0.055588,0.027261
1,0.331367,0.331207,0.316919,0.207483,0.099294,0.078630,0.297706,0.228981,0.233333,0.232778,0.222469,0.143282,0.256170,0.104268,0.712579,0.000000,0.055588,0.027261
2,0.300833,0.311559,0.301347,0.184996,0.074586,0.063253,0.275717,0.204055,0.208015,0.236173,0.188490,0.150519,0.235966,0.095154,0.767625,0.000000,0.055588,0.027261
3,0.274809,0.289870,0.285985,0.166100,0.067373,0.058814,0.264818,0.191592,0.195811,0.215161,0.174156,0.178982,0.218985,0.089425,0.773457,0.000000,0.055588,0.027261
4,0.272033,0.286553,0.285774,0.164021,0.063996,0.056278,0.260994,0.185826,0.190528,0.217787,0.172625,0.160901,0.215438,0.087644,0.782139,0.000000,0.055588,0.027261
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1688247,0.095767,0.091860,0.063552,0.072373,0.051105,0.038998,0.074761,0.079799,0.080692,0.644747,0.233062,0.313565,0.377142,0.213380,0.575688,0.041707,0.560486,0.486918
1688248,0.086745,0.084205,0.057029,0.056689,0.040823,0.032023,0.063862,0.062314,0.065938,0.715432,0.200358,0.301769,0.335432,0.198284,0.612105,0.041707,0.560486,0.486918
1688249,0.078418,0.076550,0.053241,0.041572,0.032382,0.025682,0.056214,0.049851,0.051730,0.690585,0.160252,0.297047,0.274183,0.180249,0.644680,0.058982,0.560486,0.486918
1688250,0.075642,0.074254,0.051768,0.036092,0.028085,0.022194,0.052199,0.042969,0.044627,0.638948,0.137622,0.250107,0.248397,0.173809,0.667611,0.095105,0.560486,0.486918


In [21]:
df_scaled.to_csv("scaled_training_data.csv")

#set the name and location of the output file
output_file = "scaled_training_data.csv"
model_input = pd.read_csv(output_file)

In [22]:
model_input

Unnamed: 0.1,Unnamed: 0,blue,green,red,nir,swir_1,swir_2,red_edge_1,red_edge_2,red_edge_3,BCMAD,EMAD,SMAD,NDVI,LAI,MNDWI,slope,vv,vh
0,0,0.373352,0.362848,0.358165,0.286470,0.160835,0.119848,0.347992,0.299665,0.306740,0.251360,0.282850,0.155538,0.321676,0.156770,0.601048,0.000000,0.055588,0.027261
1,1,0.331367,0.331207,0.316919,0.207483,0.099294,0.078630,0.297706,0.228981,0.233333,0.232778,0.222469,0.143282,0.256170,0.104268,0.712579,0.000000,0.055588,0.027261
2,2,0.300833,0.311559,0.301347,0.184996,0.074586,0.063253,0.275717,0.204055,0.208015,0.236173,0.188490,0.150519,0.235966,0.095154,0.767625,0.000000,0.055588,0.027261
3,3,0.274809,0.289870,0.285985,0.166100,0.067373,0.058814,0.264818,0.191592,0.195811,0.215161,0.174156,0.178982,0.218985,0.089425,0.773457,0.000000,0.055588,0.027261
4,4,0.272033,0.286553,0.285774,0.164021,0.063996,0.056278,0.260994,0.185826,0.190528,0.217787,0.172625,0.160901,0.215438,0.087644,0.782139,0.000000,0.055588,0.027261
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1688247,1688247,0.095767,0.091860,0.063552,0.072373,0.051105,0.038998,0.074761,0.079799,0.080692,0.644747,0.233062,0.313565,0.377142,0.213380,0.575688,0.041707,0.560486,0.486918
1688248,1688248,0.086745,0.084205,0.057029,0.056689,0.040823,0.032023,0.063862,0.062314,0.065938,0.715432,0.200358,0.301769,0.335432,0.198284,0.612105,0.041707,0.560486,0.486918
1688249,1688249,0.078418,0.076550,0.053241,0.041572,0.032382,0.025682,0.056214,0.049851,0.051730,0.690585,0.160252,0.297047,0.274183,0.180249,0.644680,0.058982,0.560486,0.486918
1688250,1688250,0.075642,0.074254,0.051768,0.036092,0.028085,0.022194,0.052199,0.042969,0.044627,0.638948,0.137622,0.250107,0.248397,0.173809,0.667611,0.095105,0.560486,0.486918


In [23]:
df

Unnamed: 0,class,blue,green,red,nir,swir_1,swir_2,red_edge_1,red_edge_2,red_edge_3,BCMAD,EMAD,SMAD,NDVI,LAI,MNDWI,slope,vv,vh
0,0.0,1307.0,1670.0,1897.0,1633.0,1176.0,855.0,2033.0,1722.0,1805.0,0.124956,1300.032715,0.012037,-0.074788,-0.298729,0.173577,0.000000,0.019267,0.001569
1,0.0,1186.0,1546.0,1701.0,1215.0,775.0,595.0,1770.0,1342.0,1402.0,0.117070,1053.252319,0.011100,-0.166667,-0.468940,0.332184,0.000000,0.019267,0.001569
2,0.0,1098.0,1469.0,1627.0,1096.0,614.0,498.0,1655.0,1208.0,1263.0,0.118511,914.377258,0.011653,-0.195006,-0.498488,0.410466,0.000000,0.019267,0.001569
3,0.0,1023.0,1384.0,1554.0,996.0,567.0,470.0,1598.0,1141.0,1196.0,0.109593,855.793884,0.013830,-0.218824,-0.517060,0.418760,0.000000,0.019267,0.001569
4,0.0,1015.0,1371.0,1553.0,985.0,545.0,454.0,1578.0,1110.0,1167.0,0.110707,849.534424,0.012447,-0.223798,-0.522835,0.431106,0.000000,0.019267,0.001569
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1688247,0.0,507.0,608.0,497.0,500.0,461.0,345.0,604.0,540.0,564.0,0.291916,1096.545898,0.024121,0.003009,-0.115197,0.137512,1.863390,0.088388,0.021386
1688248,0.0,481.0,578.0,466.0,417.0,394.0,301.0,547.0,446.0,483.0,0.321916,962.884583,0.023219,-0.055493,-0.164141,0.189300,1.863390,0.088388,0.021386
1688249,0.0,457.0,548.0,448.0,337.0,339.0,261.0,507.0,379.0,405.0,0.311370,798.966492,0.022858,-0.141401,-0.222610,0.235626,2.635231,0.088388,0.021386
1688250,0.0,449.0,539.0,441.0,308.0,311.0,239.0,486.0,342.0,366.0,0.289455,706.476196,0.019269,-0.177570,-0.243487,0.268235,4.249183,0.088388,0.021386


In [25]:
model_input['Unnamed: 0'] = df['class']
model_input.rename(columns={'Unnamed: 0': 'class'}, inplace=True) # Renaming column name 'Unnamed: 0' to 'Class'
model_input.to_csv('model_input.csv', index=False)
model_input

Unnamed: 0,class,blue,green,red,nir,swir_1,swir_2,red_edge_1,red_edge_2,red_edge_3,BCMAD,EMAD,SMAD,NDVI,LAI,MNDWI,slope,vv,vh
0,0.0,0.373352,0.362848,0.358165,0.286470,0.160835,0.119848,0.347992,0.299665,0.306740,0.251360,0.282850,0.155538,0.321676,0.156770,0.601048,0.000000,0.055588,0.027261
1,0.0,0.331367,0.331207,0.316919,0.207483,0.099294,0.078630,0.297706,0.228981,0.233333,0.232778,0.222469,0.143282,0.256170,0.104268,0.712579,0.000000,0.055588,0.027261
2,0.0,0.300833,0.311559,0.301347,0.184996,0.074586,0.063253,0.275717,0.204055,0.208015,0.236173,0.188490,0.150519,0.235966,0.095154,0.767625,0.000000,0.055588,0.027261
3,0.0,0.274809,0.289870,0.285985,0.166100,0.067373,0.058814,0.264818,0.191592,0.195811,0.215161,0.174156,0.178982,0.218985,0.089425,0.773457,0.000000,0.055588,0.027261
4,0.0,0.272033,0.286553,0.285774,0.164021,0.063996,0.056278,0.260994,0.185826,0.190528,0.217787,0.172625,0.160901,0.215438,0.087644,0.782139,0.000000,0.055588,0.027261
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1688247,0.0,0.095767,0.091860,0.063552,0.072373,0.051105,0.038998,0.074761,0.079799,0.080692,0.644747,0.233062,0.313565,0.377142,0.213380,0.575688,0.041707,0.560486,0.486918
1688248,0.0,0.086745,0.084205,0.057029,0.056689,0.040823,0.032023,0.063862,0.062314,0.065938,0.715432,0.200358,0.301769,0.335432,0.198284,0.612105,0.041707,0.560486,0.486918
1688249,0.0,0.078418,0.076550,0.053241,0.041572,0.032382,0.025682,0.056214,0.049851,0.051730,0.690585,0.160252,0.297047,0.274183,0.180249,0.644680,0.058982,0.560486,0.486918
1688250,0.0,0.075642,0.074254,0.051768,0.036092,0.028085,0.022194,0.052199,0.042969,0.044627,0.638948,0.137622,0.250107,0.248397,0.173809,0.667611,0.095105,0.560486,0.486918
