# Clustering

In this exercise, k-means clustering is used to find 5 clusters in a Sentinel-2 satellite image with 8 bands: 4 bands for each of two dates. The bands in the source data are:
* 'b02' / '2021-05-11'
* 'b02' / '2021-07-21'
* 'b03' / '2021-05-11'
* 'b03' / '2021-07-21'
* 'b04' / '2021-05-11'
* 'b04' / '2021-07-21'
* 'b08' / '2021-05-11'
* 'b08' / '2021-07-21'

[Bands](https://custom-scripts.sentinel-hub.com/custom-scripts/sentinel-2/bands/): b02 = blue, b03 = green, b04 = red, b08 = near-infrared (NIR)
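Because all eight bands are stored in one stack, it helps to know which array index holds which band and date. The sketch below assumes that the file stores the bands in the order listed above; verify the actual order of your file before relying on it. `band_order` and `band_index` are hypothetical names. Note that these are 0-based NumPy indexes, while band numbers in rasterio's `read()` start from 1.

```python
# Assumed band order in image.tif: the order of the list above (verify against the file)
band_order = [
    ('b02', '2021-05-11'), ('b02', '2021-07-21'),
    ('b03', '2021-05-11'), ('b03', '2021-07-21'),
    ('b04', '2021-05-11'), ('b04', '2021-07-21'),
    ('b08', '2021-05-11'), ('b08', '2021-07-21'),
]

def band_index(band, date):
    """Return the 0-based array index of a band/date pair (hypothetical helper)."""
    return band_order.index((band, date))

# July NIR, red and green, e.g. for a false colour composite
print(band_index('b08', '2021-07-21'), band_index('b04', '2021-07-21'), band_index('b03', '2021-07-21'))  # 7 5 3
```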

## Main steps

1) Read the data and reshape it into a form suitable for scikit-learn.
2) Calculate the clusters.
3) Reshape the results back to a 2D raster.
4) Plot the results.
5) Save the clustered image.

## Imports and paths

In [None]:
from sklearn.cluster import KMeans
import rasterio
import numpy as np
import os
import time
import urllib.request  # urlretrieve lives in the urllib.request submodule
from rasterio.windows import from_bounds

In [None]:
### File paths.
# Source data URLs
image_url = 'https://a3s.fi/gis-courses/gis_ml/image.tif'
multiclass_classification_url = 'https://a3s.fi/gis-courses/gis_ml/labels_multiclass.tif'

# Folders
user = os.environ.get('USER')
base_folder = os.path.join('/scratch/project_2002044', user, '2022/GeoML')
dataFolder = os.path.join(base_folder,'data')
outputBaseFolder= os.path.join(base_folder,'01_clustering')

# Source data local paths
image_file = os.path.join(dataFolder, 'image.tif')
# Labels are used only for comparison, not for clustering
multiclass_classification_file = os.path.join(dataFolder, 'labels_multiclass.tif')

# Output path
outputImage = os.path.join(outputBaseFolder,'clustering_KMeans.tif')

# Bounding box of the exercise area, in the coordinate system of the image. Only part of the full image is clustered, to keep clustering fast and the plotted results easy to inspect.
minx = 240500
miny = 6775500
maxx = 253500
maxy = 6788500 
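`from_bounds` turns the bounding box, given in map coordinates, into a pixel window of the image. For a north-up raster this is simple arithmetic on the pixel size and the image's upper-left corner. The sketch below illustrates it with a made-up origin (`x0`, `y0`) and an assumed 10 m pixel size; rasterio derives the real values from `image_dataset.transform`, and `bbox_to_window` is a hypothetical helper, not part of rasterio.

```python
# Hypothetical grid: north-up, 10 m pixels, upper-left corner at (x0, y0).
# The origin is made up for illustration; the real one comes from the image's transform.
x0, y0 = 200000.0, 6800000.0
pixel_size = 10.0

# Same bounding box as above
minx, miny, maxx, maxy = 240500, 6775500, 253500, 6788500

def bbox_to_window(minx, miny, maxx, maxy, x0, y0, pixel_size):
    """Return (col_off, row_off, width, height) in pixels for a north-up grid."""
    col_off = (minx - x0) / pixel_size
    row_off = (y0 - maxy) / pixel_size  # rows are counted downwards from the top edge
    width = (maxx - minx) / pixel_size
    height = (maxy - miny) / pixel_size
    return col_off, row_off, width, height

print(bbox_to_window(minx, miny, maxx, maxy, x0, y0, pixel_size))  # (4050.0, 1150.0, 1300.0, 1300.0)
```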

(Download input data if needed.)

In [None]:
if not os.path.isdir(dataFolder):
    os.makedirs(dataFolder)

if not os.path.exists(image_file):
    urllib.request.urlretrieve(image_url, image_file)

if not os.path.exists(multiclass_classification_file):
    urllib.request.urlretrieve(multiclass_classification_url, multiclass_classification_file)

## Read data and shape it to suitable form for scikit-learn.

The satellite image has 8 bands, so rasterio reads it in as a 3D array with the shape (bands, rows, columns).

For scikit-learn, we reshape the data into a 2D array with one row for each pixel. Each row has eight values, one for each band/date combination.

In [None]:
# Read the pixel values inside the bounding box from the .tif file as a NumPy array
with rasterio.open(image_file) as image_dataset:
    image_data = image_dataset.read(window=from_bounds(minx, miny, maxx, maxy, image_dataset.transform))

# The data has to be changed from the shape (bands, rows, columns) to (rows * columns, bands),
# so that each pixel of the original image gets its own row in the result.
# Check the shape of the input data
print('Array shape, original 3D: ', image_data.shape)

Save the number of bands for later; it is needed when reshaping the data to 2D.

In [None]:
no_bands_in_image = image_data.shape[0]
no_bands_in_image

As an intermediate step, transpose the axes so that the bands come last. Notice how the array shape changes.

In [None]:
image_data2 = np.transpose(image_data, (1, 2, 0))
# Check again the data shape, now the bands should be last.
print('Array shape after transpose, 3D: ', image_data2.shape)

Then reshape to 2D.

In [None]:
pixels = image_data2.reshape(-1, no_bands_in_image)
print('Array shape after transpose and reshape, 2D: ', pixels.shape)
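Why the transpose is needed can be seen on a tiny synthetic array: after moving the bands last, row `r * n_cols + c` of the pixel table holds all band values of pixel (r, c). Reshaping the original (bands, rows, columns) array directly would instead fill each row with neighbouring values of the same band. The array names here are made up for the demonstration.

```python
import numpy as np

# Tiny synthetic "image": 2 bands, 2 rows, 3 columns (bands x rows x columns, like rasterio)
cube = np.arange(12).reshape(2, 2, 3)
n_bands, n_rows, n_cols = cube.shape

# Correct: move the bands last, then flatten rows and columns into one pixel axis
pixels_demo = np.transpose(cube, (1, 2, 0)).reshape(-1, n_bands)

# Row r * n_cols + c holds all band values of pixel (r, c)
r, c = 1, 2
assert np.array_equal(pixels_demo[r * n_cols + c], cube[:, r, c])

# Wrong: reshaping without the transpose puts values of different pixels of one band into a row
wrong = cube.reshape(-1, n_bands)
print(pixels_demo[0], wrong[0])  # [0 6] [0 1]
```

The correct first row holds pixel (0, 0) from both bands, while the wrong one holds two neighbouring pixels from the first band.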

*How many pixels are there?*
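You can cross-check `pixels.shape[0]` against the size of the bounding box. Assuming 10 m pixels (the native resolution of the Sentinel-2 bands B02, B03, B04 and B08), the arithmetic is sketched below; if the image has a different resolution, change `pixel_size` accordingly.

```python
# Bounding box from the paths cell
minx, miny, maxx, maxy = 240500, 6775500, 253500, 6788500
pixel_size = 10  # metres; an assumption about the image resolution

n_cols = (maxx - minx) // pixel_size  # 13000 m / 10 m = 1300 columns
n_rows = (maxy - miny) // pixel_size  # 13000 m / 10 m = 1300 rows
print(n_rows * n_cols)  # 1690000
```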

## Calculate clusters

Find 5 clusters, running at most 10 iterations of the k-means algorithm (`max_iter=10`). *Also try with, for example, 7 clusters.*

*This takes a moment, please wait*

In [None]:
classes = KMeans(n_clusters=5, random_state=63, max_iter=10).fit_predict(pixels)
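The cell above relies on scikit-learn's `KMeans`. To show what the algorithm does, here is a minimal NumPy sketch of Lloyd's k-means: assign each pixel to its nearest centroid, move each centroid to the mean of its pixels, and repeat until the centroids stop moving or `max_iter` rounds have run. It is an illustration under simplifying assumptions (farthest-point initialisation, a single run), not a reimplementation of scikit-learn, which uses k-means++ initialisation and optimised routines. `kmeans_lloyd` and `_sq_dist` are hypothetical names.

```python
import numpy as np

def _sq_dist(X, centroids):
    """Squared Euclidean distances, shape (n_pixels, k), via |x|^2 - 2 x.c + |c|^2."""
    return ((X ** 2).sum(axis=1)[:, None]
            - 2 * X @ centroids.T
            + (centroids ** 2).sum(axis=1)[None, :])

def kmeans_lloyd(X, k, max_iter=10, seed=63):
    """Minimal k-means (Lloyd's algorithm) for an (n_pixels, n_bands) array.

    Illustration only: initialisation is farthest-point (one random pixel, then
    repeatedly the pixel farthest from the centroids chosen so far), whereas
    scikit-learn's KMeans uses k-means++ by default.
    """
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    centroids = X[[rng.integers(len(X))]]
    while len(centroids) < k:
        nearest = _sq_dist(X, centroids).min(axis=1)
        centroids = np.vstack([centroids, X[nearest.argmax()]])

    for _ in range(max_iter):
        # Assignment step: each pixel goes to its nearest centroid
        labels = _sq_dist(X, centroids).argmin(axis=1)
        # Update step: each centroid moves to the mean of its pixels (kept if its cluster is empty)
        new_centroids = np.array([X[labels == j].mean(axis=0) if np.any(labels == j) else centroids[j]
                                  for j in range(k)])
        if np.allclose(new_centroids, centroids):
            break  # converged: the centroids no longer move
        centroids = new_centroids

    # Final labels against the final centroids
    return _sq_dist(X, centroids).argmin(axis=1), centroids

# Usage: two well-separated synthetic groups of 8-band "pixels"
rng = np.random.default_rng(0)
X_demo = np.vstack([rng.normal(0, 1, (50, 8)), rng.normal(20, 1, (50, 8))])
labels_demo, centroids_demo = kmeans_lloyd(X_demo, k=2)
print(len(set(labels_demo[:50].tolist())), len(set(labels_demo[50:].tolist())))  # 1 1
```

The distance formula keeps memory at one value per pixel and centroid, which matters for the roughly million-pixel table used in this exercise.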

Check the results. *How many pixels does each class have?*

In [None]:
np.unique(classes, return_counts=True)[1]
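`np.unique(..., return_counts=True)` returns the class values and their counts; `[1]` keeps only the counts. Because the k-means labels are the integers 0 to k-1, `np.bincount` gives the same counts (provided every cluster has at least one pixel), and dividing by the total gives each cluster's share of the area. A small sketch with made-up labels standing in for `classes`:

```python
import numpy as np

# Made-up cluster labels standing in for `classes`
labels = np.array([0, 2, 2, 1, 0, 2, 4, 3, 2, 0])

values, counts = np.unique(labels, return_counts=True)
# For labels 0..k-1 that all occur, bincount gives the same counts, indexed by cluster number
assert np.array_equal(np.bincount(labels), counts)

shares = counts / counts.sum() * 100
for v, n, s in zip(values, counts, shares):
    print(f'cluster {v}: {n} pixels ({s:.0f} %)')
```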

## Reshape results back to 2D raster

The clustering result is a 1D array with one class label per pixel, so first reshape it back to the 2D raster shape (rows, columns).

In [None]:
print('Array shape, output, 1D: ', classes.shape)

In [None]:
# Reshape back to 2D (rows, columns)
classes2D = np.reshape(classes, (image_data.shape[1], image_data.shape[2]))
print('Array shape, output after reshape, 2D: ', classes2D.shape)
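The reshape back works because both reshapes use row-major (C) order, which is NumPy's default: pixel (r, c) became row `r * n_cols + c` of the pixel table, and reshaping the labels to (rows, columns) puts that row's label back at (r, c). A small round-trip check on synthetic data (all `_demo` names are hypothetical):

```python
import numpy as np

# Synthetic stand-ins: a 3 x 4 raster with 2 bands, flattened the same way as in the notebook
n_rows, n_cols, n_bands = 3, 4, 2
cube = np.random.default_rng(0).random((n_bands, n_rows, n_cols))
pixels_demo = np.transpose(cube, (1, 2, 0)).reshape(-1, n_bands)

# Pretend each pixel's "class" is its row number in the pixel table
classes_demo = np.arange(pixels_demo.shape[0])
classes2D_demo = classes_demo.reshape(n_rows, n_cols)

# The label of pixel (r, c) lands back at position (r, c) of the raster
r, c = 2, 1
assert classes2D_demo[r, c] == r * n_cols + c
assert np.array_equal(pixels_demo[classes2D_demo[r, c]], cube[:, r, c])
```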

## Plot results

* Satellite image
* Clustering results
* Training labels
* Clustering histogram

In [None]:
import matplotlib.pyplot as plt
import matplotlib.colors
%matplotlib inline
from rasterio.plot import show
from rasterio.plot import show_hist

In [None]:
### Helper function for Sentinel image plotting: normalizes band values and enhances contrast
### with a 2-98 % percentile stretch, similar to what QGIS does automatically.
def normalize(array):
    min_percent = 2   # Low percentile
    max_percent = 98  # High percentile
    lo, hi = np.percentile(array, (min_percent, max_percent))
    # Clip to [0, 1], the value range matplotlib expects for RGB float images
    return np.clip((array - lo) / (hi - lo), 0, 1)
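A percentile stretch maps the 2nd percentile to 0 and the 98th percentile to 1, so about 2 % of the values at each end fall outside [0, 1]; matplotlib clips such values when showing an RGB float image. Unlike a plain min-max scaling, a single very bright pixel (for example a cloud) does not squeeze all other values into a narrow range. A minimal sketch with explicit clipping on synthetic values (`stretch` is a hypothetical name):

```python
import numpy as np

def stretch(array, low=2, high=98):
    """Linear percentile stretch to [0, 1], with values outside the percentiles clipped."""
    lo, hi = np.percentile(array, (low, high))
    return np.clip((array - lo) / (hi - lo), 0, 1)

# Synthetic band: values 0..99 plus one very bright outlier
band = np.append(np.arange(100, dtype=float), 10000.0)
out = stretch(band)
print(out.min(), out.max(), out[50])  # 0.0 1.0 0.5
```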

In [None]:
### Create a subplot for 4 images  
fig, ax = plt.subplots(ncols=2, nrows=2, figsize=(10, 10))
# If you have big screen, use bigger figsize, for example 15,15.

# The Sentinel image as a false colour composite (NIR, red, green) from 2021-07-21.
# Band indexes follow the band order listed at the top: 7 = b08, 5 = b04, 3 = b03.
# Apply the normalize function to each band separately to increase contrast.
nir, red, green = image_data[7], image_data[5], image_data[3]
nirn, redn, greenn = normalize(nir), normalize(red), normalize(green)
stacked = np.stack((nirn, redn, greenn))
show(stacked, ax=ax[0, 0], title='image')

### The clustering results
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["red","yellow","darkgreen","violet","blue"]) 
show(classes2D, ax=ax[0, 1], cmap=cmap, title='Clustering classes')

### The training multiclass labels
with rasterio.open(multiclass_classification_file) as src:
    labels_data = src.read(window=from_bounds(minx, miny, maxx, maxy, src.transform))
    
cmap = matplotlib.colors.LinearSegmentedColormap.from_list("", ["white","green","orange","blue","violet"])
show(labels_data, ax=ax[1, 0], cmap=cmap, title='Training labels')

### The histogram of clustering results
show_hist(classes2D, ax=ax[1, 1], bins=[-0.5,0.5,1.5,2.5,3.5,4.5], title="Clustering histogram") 
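The bin edges -0.5, 0.5, ..., 4.5 put each integer class in the middle of its own bin, so each bar counts exactly one cluster. They are hard-coded for 5 clusters; if you try, for example, 7 clusters, the edges can be derived from the number of clusters and passed to `show_hist` as `bins`. A sketch with made-up labels (`n_clusters` is a hypothetical variable name):

```python
import numpy as np

n_clusters = 7
# One bin per class: edges at -0.5, 0.5, ..., n_clusters - 0.5
bins = np.arange(n_clusters + 1) - 0.5

# Made-up labels to check that each bin counts exactly one class
labels = np.array([0, 1, 1, 3, 6, 6, 6, 2])
counts, edges = np.histogram(labels, bins=bins)
print(counts)  # [1 2 1 1 0 0 3]
```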

## Save clustered image

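Save the result to a GeoTIFF file. First, prepare the metadata for the new file: compared to the original file, the new file has only 1 band, the int32 data type and no nodata value (every pixel has a class). Because only the bounding box window was clustered, the width, height and transform must also be those of the window, not of the full image.

In [None]:
# Re-open the image to get its metadata and the transform of the clipped window
with rasterio.open(image_file) as src:
    window = from_bounds(minx, miny, maxx, maxy, src.transform)
    meta = src.meta.copy()
    meta.update(count=1, dtype='int32', nodata=None, height=classes2D.shape[0],
                width=classes2D.shape[1], transform=src.window_transform(window))
os.makedirs(outputBaseFolder, exist_ok=True)  # create the output folder if it does not exist yet
# Save the cluster labels as a single-band GeoTIFF
with rasterio.open(outputImage, 'w', **meta) as dst:
    dst.write(classes2D.astype('int32'), 1)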
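The plots compare the clusters and the training labels visually. A numeric comparison is a cross-tabulation: for each cluster, how many of its pixels fall into each label class. Below is a minimal NumPy sketch with made-up arrays; in the notebook, `classes2D` and `labels_data[0]` would take their places, assuming the label window has the same shape as the clustered window and that the label values are small non-negative integers. `crosstab` is a hypothetical helper name.

```python
import numpy as np

def crosstab(clusters, labels):
    """Count pixels for every (cluster, label) pair; rows = clusters, columns = label values."""
    clusters = np.asarray(clusters).ravel()
    labels = np.asarray(labels).ravel()
    table = np.zeros((clusters.max() + 1, labels.max() + 1), dtype=np.int64)
    # np.add.at increments repeated (cluster, label) index pairs correctly
    np.add.at(table, (clusters, labels), 1)
    return table

clusters_demo = np.array([[0, 0, 1], [1, 2, 2]])
labels_demo = np.array([[1, 1, 0], [0, 0, 2]])
print(crosstab(clusters_demo, labels_demo))
```

Here cluster 0 coincides entirely with label 1 and cluster 1 with label 0, while cluster 2 is split between labels 0 and 2.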