# Machine Learning using Sentinel-2 Data

This example uses training data from the
[Coast Train](https://github.com/nick-murray/coastTrain) dataset
along with Sentinel-2 data to demonstrate how to use a
machine learning classifier, in this case, Random Forest, to
assign a class to each pixel.

This notebook combines lessons from previous notebooks into
a comprehensive worked example.

## Getting started

First we load the required Python libraries and tools.

In [6]:
from pystac_client import Client
from dask.distributed import Client as DaskClient
from odc.stac import load, configure_s3_access
import geopandas as gpd
import pandas as pd
import numpy as np
import xarray as xr
import folium
import joblib


from sklearn.ensemble import RandomForestClassifier

import odc.geo.xr  # noqa: F401

## Study site configuration

Here we establish the STAC catalog we're using as well as a
spatial and temporal extent. This can be anywhere, but this location
near Kuching was chosen due to the training data having several
classes available.

In [7]:
# STAC Catalog URL
# catalog = "https://stac.staging.digitalearthpacific.org"
catalog = "https://earth-search.aws.element84.com/v1"
# Create a STAC Client
client = Client.open(catalog)

<font color='blue'>1.1. Define your area of interest. Find the coordinates of the bottom left and top right corners of your bounding box / area of interest. 

Use QGIS, Google Maps or another site to find the coordinates. Make sure to use at least 4 or 5 decimal places. Lat = latitude and lon = longitude. The min is in the bottom left and max is in the top right.   

In [8]:
# 1.1 - input your four coordinates here 
min_lat = 
min_lon = 
max_lat = 
max_lon = 

bbox = [min_lon, min_lat, max_lon, max_lat]
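A quick sanity check on the coordinate order can catch swapped values early. The sketch below is only an illustration: `make_bbox` is a hypothetical helper (not part of this notebook or any library), and the example coordinates are the ones used for the second area of interest later in this notebook.

```python
def make_bbox(min_lon, min_lat, max_lon, max_lat):
    """Return [min_lon, min_lat, max_lon, max_lat] after basic range checks."""
    if not (-180 <= min_lon < max_lon <= 180):
        raise ValueError("longitudes must satisfy -180 <= min_lon < max_lon <= 180")
    if not (-90 <= min_lat < max_lat <= 90):
        raise ValueError("latitudes must satisfy -90 <= min_lat < max_lat <= 90")
    return [min_lon, min_lat, max_lon, max_lat]


# Bottom-left corner first (min), top-right corner second (max)
example_bbox = make_bbox(177.84220, -17.64440, 177.94832, -17.56635)
print(example_bbox)
```

With this check, a swapped pair (for example a `min_lat` larger than `max_lat`) raises a `ValueError` before any data is requested.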

<font color='blue'>1.2. Define the time period of interest. 

Now define the time period you are interested in, using the format "year-month/year-month". For example, "2022-06/2024-09" selects all images from June 2022 to September 2024. You can try choosing only the flowering months in a single year, but cloud cover may leave you with too few usable images. Try different time periods and see what happens. 

In [9]:
# 1.2 - input your datetime here - recommend at least 3 months and max 3 years. 
datetime = " "
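The recommendation of at least 3 months and at most 3 years can be checked with a few lines of plain Python. This is a minimal sketch: `months_covered` and `check_period` are hypothetical helpers, and they assume the "YYYY-MM/YYYY-MM" format described above.

```python
def months_covered(period):
    """Count the calendar months in a "YYYY-MM/YYYY-MM" period, inclusive."""
    start, end = period.split("/")
    start_year, start_month = (int(part) for part in start.split("-"))
    end_year, end_month = (int(part) for part in end.split("-"))
    return (end_year - start_year) * 12 + (end_month - start_month) + 1


def check_period(period, min_months=3, max_months=36):
    """Return True if the period length is within the recommended range."""
    return min_months <= months_covered(period) <= max_months


n_months = months_covered("2022-06/2024-09")
print(n_months, check_period("2022-06/2024-09"))
```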

In [10]:
# Create local dask cluster to improve data load time. Only run this once.
dask_client = DaskClient(n_workers=1, threads_per_worker=16, memory_limit='16GB')

# Configure S3 access. Cloud defaults is an optimisation, while requester pays is required for Landsat
configure_s3_access(cloud_defaults=True, requester_pays=True)

<botocore.credentials.DeferredRefreshableCredentials at 0x7f0028b60a60>

## Training data

Next we gather training data. This could be any geospatial point dataset
with a numeric column that holds the class of each point.

If you'd like to explore the structure of this data, you can run `gdf.head()`
to see the first few rows. The `explore()` function with the `column` argument
will show the data on the map, and change the colour based on that column.

<font color='blue'>2.1. Input your data from the field.  

If you have new data, save it as a GeoJSON file in QGIS, then drag and drop it into the same folder as this notebook in DEP. Then put the name of the file, in quotes, inside the brackets below, e.g. `'name.geojson'`.   

In [11]:
# 2.1 - input your data file here inside the '' 
gdf = gpd.read_file(' ', bbox=bbox)

# gdf = gdf.fillna(0)
gdf.explore(column=" ", legend=True)

## Find and load Sentinel-2 data

Here we search for Sentinel-2 scenes over our study area and use
Dask to lazy-load them. We're only loading the red, green, blue, nir and swir
bands, along with the scene classification (scl) band.

<font color='blue'>2.2. - define the satellite image collections you wish to use: 

Input your collections name inside the brackets: 

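You may try:   
"sentinel-2-c1-l2a" (10 m x 10 m resolution pixels)   
OR   
"landsat-c2-l2" (30 m x 30 m resolution pixels)

Note that the rest of this notebook assumes Sentinel-2: the band names in `measurements` and the `scl` cloud mask would need to change for Landsat.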

<font color='blue'>2.3. - define your cloud cover threshold. We recommend a number between 10 and 50. No need for quotes. 

Write this next to "lt", e.g. `{"eo:cloud_cover": {"lt": 50}}`. 


In [12]:
# Search for Sentinel-2 data
items = client.search(
    collections=["sentinel-2-c1-l2a"],
    bbox=bbox,
    datetime=datetime,
    query={"eo:cloud_cover": {"lt": 25 }},
).item_collection()

print(f"Found {len(items)} items")

Found 27 items


In [13]:
# Load the data into an xarray Dataset
data = load(
    items,
    measurements=["red", "green", "blue", "nir08", "swir16", "scl"],
    bbox=bbox,
    chunks={"x": 2048, "y": 2048},
    groupby="solar_day",
)

data

Variables,Data type,Shape,Chunk shape,Bytes,Chunk bytes,Dask graph
"red, green, blue, nir08, swir16",uint16,"(20, 109, 77)","(1, 109, 77)",327.85 kiB,16.39 kiB,20 chunks in 1 graph layer
scl,uint8,"(20, 109, 77)","(1, 109, 77)",163.93 kiB,8.20 kiB,20 chunks in 1 graph layer


## Data preparation

Now that we have data, we need to clean it up by masking out clouds
and scaling values to the range 0 to 1, which is the valid range for
reflectance.

We add a couple of indices too, which will help the machine learning
algorithm.

Note that we still have a lazy-loaded array, and haven't transferred
any data over the network.

In [14]:
# Mask out clouds and scale values

# Apply Sentinel-2 cloud mask
# 1: defective, 3: shadow, 9: high confidence cloud, 10: thin cirrus
mask_flags = [1, 3, 9, 10]

cloud_mask = ~data.scl.isin(mask_flags)
masked = data.where(cloud_mask)

# Apply scaling and clip to valid data, from 0 to 1
scaled = (masked.where(masked != 0) * 0.0001).clip(0, 1)

# Add some indices
scaled["ndvi"] = (scaled.nir08 - scaled.red) / (scaled.nir08 + scaled.red)
# scaled["ndwi"] = (scaled.green - scaled.nir08) / (scaled.green + scaled.nir08)

scaled


Variables,Data type,Shape,Chunk shape,Bytes,Chunk bytes,Dask graph
"red, green, blue, nir08, swir16",float32,"(20, 109, 77)","(1, 109, 77)",655.70 kiB,32.79 kiB,20 chunks in 13 graph layers
scl,float32,"(20, 109, 77)","(1, 109, 77)",655.70 kiB,32.79 kiB,20 chunks in 12 graph layers
ndvi,float32,"(20, 109, 77)","(1, 109, 77)",655.70 kiB,32.79 kiB,20 chunks in 23 graph layers
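To make the masking, scaling and index steps concrete, here is a minimal NumPy sketch on a made-up 1 x 4 scene. The pixel values are invented, and `prepare` is a hypothetical helper that mimics the logic of the cell above (drop flagged SCL classes and zeros, multiply by 0.0001, clip to 0-1); it is not the notebook's xarray code.

```python
import numpy as np

# Made-up raw digital numbers and scene classification (SCL) values
red = np.array([[1200, 0, 3000, 800]], dtype="uint16")
nir = np.array([[4800, 0, 3500, 3200]], dtype="uint16")
scl = np.array([[4, 0, 9, 5]], dtype="uint8")

# 1: defective, 3: shadow, 9: high confidence cloud, 10: thin cirrus
mask_flags = [1, 3, 9, 10]
keep = ~np.isin(scl, mask_flags)


def prepare(band):
    """Set flagged and zero pixels to NaN, then scale to reflectance (0-1)."""
    values = np.where(keep & (band != 0), band.astype("float32"), np.nan)
    return np.clip(values * 0.0001, 0, 1)


red_scaled = prepare(red)
nir_scaled = prepare(nir)

# NDVI is NaN wherever either input band was masked
ndvi = (nir_scaled - red_scaled) / (nir_scaled + red_scaled)
print(ndvi)
```

In this toy scene the second pixel is dropped because it is zero (nodata) and the third because its SCL class is 9, so only the first and last pixels get an NDVI value.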


In [15]:
# Visualise one date, to make sure it looks good.
# This example shows empty areas where we've masked out nodata, but
# note that there are still a lot of clouds coming in!

scaled.isel(time=0).odc.explore(vmin=0, vmax=0.3)



## Create a cloud-free composite

The final data preparation step involves creating a temporal
median of the data bands. Here we use `compute()` to process
the data and bring it into memory.

We preview the data in the second cell below.

In [16]:
# Create a median composite, which should get rid of most of the remaining clouds
# Note that this will take a few minutes to complete

median = scaled.median("time").compute()

median
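To see why a temporal median suppresses leftover clouds, here is a minimal NumPy sketch using made-up values for a single pixel. It uses `np.nanmedian`, which ignores masked (NaN) observations; xarray's `median` likewise skips NaN values by default for floating-point data.

```python
import numpy as np

# Made-up reflectance time series with dimensions (time, y, x):
# two clear observations, one bright unmasked cloud, one masked value
stack = np.array([[[0.10]], [[0.90]], [[0.12]], [[np.nan]]])

# The median along the time axis ignores the NaN and is not pulled
# towards the single bright outlier the way a mean would be
composite = np.nanmedian(stack, axis=0)
mean_for_comparison = np.nanmean(stack, axis=0)

print(composite[0, 0], mean_for_comparison[0, 0])
```

The median works well only while cloudy observations are a minority at each pixel, which is one reason to experiment with the time period and cloud cover threshold.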

<font color='blue'>3.1. - visualise the resulting median satellite image:

`median.odc.explore(vmin=0, vmax=0.3)`

In [17]:
median.odc.explore(vmin=0, vmax=0.3)

## Prepare training data array

This next step involves extracting observed values from the satellite data
and combining them with our point data, resulting in something like this:

`class, red, green, blue ...`

This structure is then fed into the machine learning classifier.

In [18]:
# First transform the training points to the same CRS as the data
training = gdf.to_crs(median.odc.geobox.crs)

# Next get the X and Y values out of the point geometries
training_da = training.assign(x=training.geometry.x, y=training.geometry.y).to_xarray()

# Now we can use the x and y values (in the composite's CRS) to extract values from the median composite
training_values = (
    median.sel(training_da[["x", "y"]], method="nearest").squeeze().compute().to_pandas()
)

# Join the training data with the extracted values and remove unnecessary columns
training_array = pd.concat([training["Random_Forest"], training_values], axis=1)
training_array = training_array.drop(
    columns=[
        "y",
        "x",
        "spatial_ref",
    ]
)

# Drop rows where there was no data available
training_array = training_array.dropna()

# Preview our resulting training array
training_array.head()

Unnamed: 0,Random_Forest,red,green,blue,nir08,swir16,scl,ndvi
0,8,0.1322,0.16435,0.13435,0.50465,0.2639,0.0004,0.588485
1,8,0.1409,0.16525,0.1346,0.50465,0.2639,0.0004,0.566068
2,8,0.1409,0.16525,0.1346,0.50465,0.2639,0.0004,0.566068
3,8,0.1278,0.151,0.12915,0.5016,0.2559,0.0004,0.592806
4,8,0.1278,0.151,0.12915,0.5016,0.2559,0.0004,0.592806
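Conceptually, the nearest-neighbour selection used above picks, for each training point, the raster cell whose centre is closest. The sketch below shows that idea in plain NumPy on a made-up 3 x 3 band; `sample_nearest` is a hypothetical helper, not the xarray implementation.

```python
import numpy as np

# Made-up cell-centre coordinates; y is descending, as in north-up rasters
x_coords = np.array([0.0, 10.0, 20.0])
y_coords = np.array([30.0, 20.0, 10.0])
band = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]])


def sample_nearest(point_x, point_y):
    """Return the band value of the cell whose centre is nearest the point."""
    col = int(np.abs(x_coords - point_x).argmin())
    row = int(np.abs(y_coords - point_y).argmin())
    return band[row, col]


value = sample_nearest(11.0, 18.5)
print(value)
```

Because several nearby points can fall in the same cell, duplicate rows in the training array (as in the preview above) are expected.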


## Create a classifier and fit a model

We pass in simple numpy arrays to the classifier, one has the
observations (the values of the red, green, blue and so on)
while the other has the classes.

In [19]:
# The classes are the first column
classes = np.array(training_array)[:, 0]

# The observation data is everything after the first column
observations = np.array(training_array)[:, 1:]

# Create a model...
classifier = RandomForestClassifier()

# ...and fit it to the data
model = classifier.fit(observations, classes)
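One way to get a rough idea of how well a classifier generalises is to hold back some training points and score the model on them. The sketch below does this on synthetic data: the two classes (4 and 8), the feature values and the `demo_` names are all invented for illustration, and the split settings are one reasonable choice rather than the notebook's method.

```python
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the training array: two made-up classes with
# clearly different values in three made-up feature columns
rng = np.random.default_rng(0)
features_a = rng.normal(loc=0.10, scale=0.02, size=(60, 3))
features_b = rng.normal(loc=0.50, scale=0.02, size=(60, 3))
demo_observations = np.vstack([features_a, features_b])
demo_classes = np.array([4] * 60 + [8] * 60)

# Hold back 25% of the points; stratify keeps both classes in each split
X_train, X_test, y_train, y_test = train_test_split(
    demo_observations, demo_classes, test_size=0.25,
    stratify=demo_classes, random_state=0,
)

demo_model = RandomForestClassifier(n_estimators=50, random_state=0)
demo_model.fit(X_train, y_train)

accuracy = accuracy_score(y_test, demo_model.predict(X_test))
print(f"Hold-out accuracy: {accuracy:.2f}")
```

With real field data, points that share a raster cell carry identical feature values, so a random split can overstate accuracy; keeping such duplicates together in one split gives a more honest estimate.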

## Prediction

Next we predict. Again, we need a simple numpy array, this time
with just the observations. It must be a long 2D array in which
each row is one cell of the original raster and each column is one
observation variable (red, green, blue and so on).

In [20]:
# Convert to a stacked array of observations
stacked_arrays = median.to_array().stack(dims=["y", "x"]).transpose()

# Predict the classes
predicted = model.predict(stacked_arrays)

# Reshape back to the original 2D array
array = predicted.reshape(len(median.y), len(median.x))

# Convert to an xarray again, because it's easier to work with
predicted_da = xr.DataArray(
    array, coords={"y": masked.y, "x": masked.x}, dims=["y", "x"]
)
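The reshaping round trip can be illustrated with plain NumPy. In this sketch the data cube is made up, and `predict_stub` is a hypothetical stand-in for `model.predict` (a simple threshold), so that the example runs without a trained model.

```python
import numpy as np

# Made-up data cube with dimensions (variable, y, x): 2 variables, 3 x 4 cells
cube = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
n_vars, n_y, n_x = cube.shape

# Flatten y and x into one axis, then transpose so that each row is a cell
# and each column is a variable: shape (n_y * n_x, n_vars)
table = cube.reshape(n_vars, -1).T


def predict_stub(rows):
    """Stand-in for model.predict: one label per row."""
    return (rows[:, 0] >= 6).astype(int)


labels = predict_stub(table)

# One label per cell, reshaped back to the raster's (y, x) layout
label_map = labels.reshape(n_y, n_x)
print(table.shape, label_map.shape)
```

The reshape back to `(y, x)` is only valid because the rows were flattened in the same y-then-x order.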

## Visualise our results

Here we're visualising the results along with the RGB image
and the original training data points. We're doing this using
a Python library called Folium.

In [21]:
print(predicted_da.dtype)  # Check the dtype of your DataArray
predicted_da = predicted_da.astype('float32')  # Convert to float32

float64


In [22]:
# Put it all on a single interactive map
# center = [np.mean([min_lat[0], max_lat[0]]), np.mean([min_lat[1], max_lat[1]])]
# m = folium.Map(location=center, zoom_start=11)

center = [(min_lat + max_lat) / 2, (min_lon + max_lon) / 2]  # Assuming min_lon and max_lon are defined
m = folium.Map(location=center, zoom_start=11)



# RGB for the median
median.odc.to_rgba(vmin=0, vmax=0.3).odc.add_to(m, name="Median Composite")



<folium.raster_layers.ImageOverlay at 0x7f000d3a8250>

<font color='blue'>4.1. - visualise the resulting machine learning prediction:

The name of your model prediction is `predicted_da` so input this before `odc.add_to`

In [23]:
# Categorical for the predicted classes and for the training data
predicted_da.odc.add_to(m, name="Predicted")
gdf.explore(m=m, column="Random_Forest", legend=True, name="Training Data")

# Layer control
folium.LayerControl().add_to(m)

m

<font color='blue'>4.2. - write this model output map to a new file:

Input the name of your model prediction before `.odc.write_cog`

In [41]:
predicted_da.odc.write_cog("Cordia_Nadarivatu_v1_2018.tif", overwrite=True)
# predicted_da.plot.imshow()

PosixPath('Cordia_Nadarivatu_v1_2018.tif')

In [20]:
# Check the data type of the 'Random_Forest' column
print(gdf['Random_Forest'].dtype)

# Display the first few rows of the 'Random_Forest' column to inspect its contents
print(gdf['Random_Forest'].head())

int64
0    8
1    8
2    8
3    8
4    8
Name: Random_Forest, dtype: int64


In [21]:
# The column is int64 (see the output above), so compare against the integer 6, not the string '6'
african_tulip_count = gdf[gdf['Random_Forest'] == 6].shape[0]

In [22]:
# Share of training points labelled as class 6 in the Random_Forest column.
# gdf holds the training points, so these are training labels, not model predictions.
# Class 6 is assumed to be the class of interest; check it against your class legend.

# Total number of labelled training points
total_predictions = gdf['Random_Forest'].count()

# Number of training points where Random_Forest equals 6
african_tulip_count = gdf[gdf['Random_Forest'] == 6].shape[0]

# Calculate the percentage of class 6
percentage_african_tulip = (african_tulip_count / total_predictions) * 100

# Print the result
print(f"Percentage of Settlements: {percentage_african_tulip:.2f}%")


Percentage of Settlements: 29.34%


In [23]:
rf_percentage = gdf["Random_Forest"].value_counts(normalize=True) * 100
print(rf_percentage)

Random_Forest
8    49.462366
6    29.339478
4    12.903226
3     8.141321
9     0.153610
Name: proportion, dtype: float64
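The percentages above describe the training labels in `gdf`. If the goal is instead the share of predicted pixels per class, the predicted array can be counted directly. The sketch below uses a made-up 2 x 3 prediction grid and assumes 10 m x 10 m pixels (as for the Sentinel-2 collection) to convert counts to hectares.

```python
import numpy as np

# Made-up predicted class grid (y, x)
predicted_grid = np.array([[8, 8, 6], [8, 4, 6]])

classes, counts = np.unique(predicted_grid, return_counts=True)

pixel_area_ha = (10 * 10) / 10_000  # assumed 10 m pixels; 1 ha = 10,000 m2

# Percentage of pixels and area in hectares for each predicted class
share = {int(c): 100 * int(n) / predicted_grid.size for c, n in zip(classes, counts)}
area_ha = {int(c): int(n) * pixel_area_ha for c, n in zip(classes, counts)}

for c in share:
    print(f"class {c}: {share[c]:.1f}% of pixels, {area_ha[c]:.2f} ha")
```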


<font color='blue'>5.1. - download the file, load, explore and set new colours in QGIS:


In [None]:
joblib.dump(model, "cordia_v.2.0.model")
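A saved model can be loaded again in a later session with `joblib.load`. The sketch below saves and reloads a tiny model trained on made-up data in a temporary folder; the file name, data and `demo_` names are all hypothetical.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Tiny made-up training set, only to have a fitted model to save
demo_observations = np.array([[0.1, 0.2], [0.1, 0.3], [0.6, 0.7], [0.7, 0.6]])
demo_classes = np.array([4, 4, 8, 8])
demo_model = RandomForestClassifier(n_estimators=10, random_state=0)
demo_model.fit(demo_observations, demo_classes)

with tempfile.TemporaryDirectory() as folder:
    path = os.path.join(folder, "demo.model")  # hypothetical file name
    joblib.dump(demo_model, path)
    restored = joblib.load(path)

# The restored model should give the same predictions as the original
same_predictions = np.array_equal(
    demo_model.predict(demo_observations), restored.predict(demo_observations)
)
print(same_predictions)
```

A reloaded model expects the same feature columns, in the same order, as were used for training.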

## Considerations

Do the results make sense?

What are some of the limitations of the visualisation?

### Next steps

The obvious next step is to refine the training data. Perhaps download the points for this
region of interest, along with the RGB image, then add and remove points until
the training dataset is more representative.

### New AOI for your interest

Choose a new area of interest (AoI) and time of interest (ToI) based on your interests. This could be wider Tongatapu, Eua or even Vava'u.  

In [39]:
# 1.1 - input your four coordinates here 
min_lat = -17.64440
min_lon = 177.84220
max_lat = -17.56635
max_lon = 177.94832

bbox_nadala = [min_lon, min_lat, max_lon, max_lat]
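Before loading data for a new area, it can help to estimate how large the bounding box is on the ground. The sketch below uses a simple spherical approximation (roughly 111 km per degree of latitude, with longitude shrunk by the cosine of the latitude); `bbox_size_km` is a hypothetical helper and the result is only a rough guide.

```python
import math


def bbox_size_km(min_lon, min_lat, max_lon, max_lat):
    """Approximate (width_km, height_km) of a small lon/lat bounding box."""
    km_per_degree = 111.32  # approximate length of one degree of latitude
    mid_lat = math.radians((min_lat + max_lat) / 2)
    height_km = (max_lat - min_lat) * km_per_degree
    width_km = (max_lon - min_lon) * km_per_degree * math.cos(mid_lat)
    return width_km, height_km


width_km, height_km = bbox_size_km(177.84220, -17.64440, 177.94832, -17.56635)

# At 10 m resolution, one kilometre is 100 pixels
approx_pixels = (round(width_km * 100), round(height_km * 100))
print(f"{width_km:.1f} km x {height_km:.1f} km, about {approx_pixels} pixels")
```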

In [40]:
# Search for Sentinel-2 data
items = client.search(
    collections=["sentinel-2-c1-l2a"],
    bbox=bbox_nadala,
    datetime=datetime,
    query={"eo:cloud_cover": {"lt": 15 }},
).item_collection()

print(f"Found {len(items)} items")

Found 19 items


In [41]:
# Load the data into an xarray Dataset
data_nadala = load(
    items,
    measurements=["red", "green", "blue", "nir08", "swir16", "scl"],
    bbox=bbox_nadala,
    chunks={"x": 800, "y": 800},
    groupby="solar_day",
)

# data_nadala

In [42]:
# Mask out clouds and scale values

# Apply Sentinel-2 cloud mask
# 1: defective, 3: shadow, 9: high confidence cloud, 10: thin cirrus
mask_flags = [1, 3, 9, 10]

cloud_mask = ~data_nadala.scl.isin(mask_flags)
masked = data_nadala.where(cloud_mask)

# Apply scaling and clip to valid data, from 0 to 1
scaled_nadala = (masked.where(masked != 0) * 0.0001).clip(0, 1)

# Add some indices
scaled_nadala["ndvi"] = (scaled_nadala.nir08 - scaled_nadala.red) / (scaled_nadala.nir08 + scaled_nadala.red)
# scaled["ndwi"] = (scaled.green - scaled.nir08) / (scaled.green + scaled.nir08)

# Create a median composite, which should get rid of most of the remaining clouds
# Note that this will take a few minutes to complete

median_nadala = scaled_nadala.median("time").compute()

# median

median_nadala.odc.explore(vmin=0, vmax=0.3)

In [43]:
# Convert to a stacked array of observations
stacked_arrays = median_nadala.to_array().stack(dims=["y", "x"]).transpose()

# Predict the classes
predicted_nadala = model.predict(stacked_arrays)

# Reshape back to the original 2D array
array = predicted_nadala.reshape(len(median_nadala.y), len(median_nadala.x))

# Convert to an xarray again, because it's easier to work with
predicted_da = xr.DataArray(
    array, coords={"y": masked.y, "x": masked.x}, dims=["y", "x"]
)

In [44]:
print(predicted_da.dtype)  # Check the dtype of your DataArray
predicted_da = predicted_da.astype('float32')  # Convert to float32

# Put it all on a single interactive map
# center = [np.mean([min_lat[0], max_lat[0]]), np.mean([min_lat[1], max_lat[1]])]
# m = folium.Map(location=center, zoom_start=11)

center = [(min_lat + max_lat) / 2, (min_lon + max_lon) / 2]  # Assuming min_lon and max_lon are defined
m = folium.Map(location=center, zoom_start=11)

# RGB for the median
median_nadala.odc.to_rgba(vmin=0, vmax=0.3).odc.add_to(m, name="Median Composite")


float64


<folium.raster_layers.ImageOverlay at 0x7f000fa56950>

In [45]:
# Categorical for the predicted classes and for the training data
predicted_da.odc.add_to(m, name="Predicted")
gdf.explore(m=m, column="Random_Forest", legend=True, name="Training Data")

# Layer control
folium.LayerControl().add_to(m)

m

In [46]:
predicted_da.odc.write_cog("Cordia_NakorobuyaCordia_Jan_Nov_2024.tif", overwrite=True)
# predicted_da.plot.imshow()

PosixPath('Cordia_NakorobuyaCordia_Jan_Nov_2024.tif')