# Chapter 18: Clustering and segmentation

In [None]:
import geopandas as gpd
from sklearn.cluster import KMeans # pip3 install scikit-learn

# Location of the shapefile holding the Hungarian city locations.
CITIES_PATH = 'data/hungary_cities.shp'

cities_gdf = gpd.read_file(CITIES_PATH)
display(cities_gdf)

Extract the point coordinates of the cities:

In [None]:
# Collect the (x, y) coordinate pair of every city geometry, in row order.
points = []
for city_point in cities_gdf.geometry:
    points.append((city_point.x, city_point.y))
print("Number of points: {0}".format(len(points)))

Cluster the points using the *K-Means algorithm*:

In [None]:
# K-Means initialises its centroids randomly; fixing random_state makes the
# cluster assignment (and the colors derived from it) identical on every run.
pred = KMeans(n_clusters=19, random_state=0).fit_predict(points)
print(pred)
print(len(pred))

Plot figure:

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(12, 8))

# Fetch list of X and Y coordinates
xs = [point[0] for point in points]
ys = [point[1] for point in points]

# Put the cluster points on the plot
plt.scatter(xs, ys, c=pred)

# Display plot
plt.title("Cluster map of the Hungarian cities")
plt.show()

---

## Clustering raster images

### Read the dataset

The `LC08_L1TP_188027_20200420_20200508_01_T1.tif` file is a Landsat 8 satellite image of Budapest and parts of Western Hungary, acquired on 20 April 2020. It should be familiar from [Chapter 10](10_spatial_raster_solutions.ipynb).

Download: https://gis.inf.elte.hu/files/public/landsat-budapest-2020 (1.4 GB)

In [None]:
import rasterio
budapest_2020 = rasterio.open('LC08_L1TP_188027_20200420_20200508_01_T1.tif')

# Print the band count, then the width and height (in pixels) of the image.
for value in (budapest_2020.count, budapest_2020.width, budapest_2020.height):
    print(value)

Define a helper function that reads a single band, resampled by a given scale factor:

In [None]:
from rasterio.enums import Resampling

def read_resampled_band(dataset, band, resample_factor, resampling=None):
    """Read one band of a raster dataset, resampled by a scale factor.

    Parameters
    ----------
    dataset : rasterio dataset
        Open dataset to read from; its ``height`` and ``width`` give the
        full-resolution size.
    band : int
        Band index passed to ``dataset.read`` (rasterio band indexes are 1-based).
    resample_factor : float
        Scale applied to both dimensions, e.g. ``1/4`` reads a quarter of the
        width and height. Output dimensions are truncated with ``int()``.
    resampling : rasterio.enums.Resampling, optional
        Resampling algorithm. ``None`` (the default) means
        ``Resampling.bilinear``.

    Returns
    -------
    numpy.ndarray
        The resampled band, as returned by ``dataset.read``.

    Raises
    ------
    ValueError
        If ``resample_factor`` yields an output with no rows or no columns.
    """
    out_rows = int(dataset.height * resample_factor)
    out_cols = int(dataset.width * resample_factor)
    if out_rows < 1 or out_cols < 1:
        raise ValueError(
            "resample_factor {0} gives an empty output shape ({1}, {2})".format(
                resample_factor, out_rows, out_cols))

    if resampling is None:
        # Resolved here rather than in the signature so the default stays
        # bilinear while the name is only looked up when actually needed.
        resampling = Resampling.bilinear

    return dataset.read(
        band,
        out_shape=(1, out_rows, out_cols),  # leading 1: a single band is read
        resampling=resampling,
    )

Read the blue, green, red and near-infrared bands into *NumPy* arrays. Resample them to a quarter of the original width and height to make further processing (especially clustering) faster.

In [None]:
# Landsat 8 band numbers: 2 = blue, 3 = green, 4 = red, 5 = near-infrared.
# Every band is read at a quarter of the original width and height.
blue, green, red, nir = (
    read_resampled_band(budapest_2020, band_index, 1/4)
    for band_index in (2, 3, 4, 5)
)

print(red.shape)

Display the near-infrared band for verification:

In [None]:
fig, ax = plt.subplots(figsize=(10, 10))
nir_image = ax.imshow(nir, cmap='Reds')
ax.set_axis_off()
# Scale bar for the raw near-infrared values.
fig.colorbar(nir_image, ax=ax)
plt.show()

Display the RGB image for verification:

In [None]:
from rasterio.plot import show
import numpy as np

# Stretch each band so its 99.99th percentile maps to 1.0; using a percentile
# instead of the maximum keeps a few extremely bright pixels from darkening
# the rest of the image.
red_max = np.percentile(red, 99.99)
blue_max = np.percentile(blue, 99.99)
green_max = np.percentile(green, 99.99)

# astype('f4') converts to 4-byte float before dividing. np.clip caps the
# pixels above the percentile at 1.0, the valid range matplotlib expects for
# float RGB data (otherwise matplotlib clips them itself and warns).
redf = np.clip(red.astype('f4') / red_max, 0, 1)
bluef = np.clip(blue.astype('f4') / blue_max, 0, 1)
greenf = np.clip(green.astype('f4') / green_max, 0, 1)

# Band-first (bands, rows, cols) array in red, green, blue order.
rgb = np.stack([redf, greenf, bluef])

plt.figure(figsize=[10,10])
show(rgb)
plt.show()

---

### Single-band clustering

Cluster the satellite image based on the near-infrared band.

In [None]:
# scikit-learn expects an (n_samples, n_features) matrix: each pixel is one
# sample whose single feature is its near-infrared value.
pixel_count = nir.shape[0] * nir.shape[1]
nir_1d = nir.reshape((pixel_count, 1))
print(nir_1d.shape)

In [None]:
# Fix random_state so the clustering (and the label-to-color assignment in
# the plot) is reproducible across runs.
pred = KMeans(n_clusters=6, random_state=0).fit_predict(nir_1d)
# Fold the flat per-pixel labels back into the image's row/column layout.
img_clusters = pred.reshape(nir.shape)

In [None]:
import matplotlib.colors as mc

# One color per cluster label, listed in label order (label k -> color k).
cluster_colors = ['red', 'black', 'gray', 'green', 'white', 'blue']
# A ListedColormap gives each label exactly one of the listed colors, whereas a
# LinearSegmentedColormap would build a continuous gradient between them.
cmap = mc.ListedColormap(cluster_colors)

plt.figure(figsize=[12,12])
# Pin vmin/vmax to the label range so label k always gets cluster_colors[k].
# Nearest-neighbour interpolation keeps the downscaled display from blending
# neighbouring labels into colors that belong to no cluster.
plt.imshow(img_clusters, cmap=cmap, vmin=0, vmax=len(cluster_colors) - 1,
           interpolation='nearest')
plt.axis('off')
plt.show()

---

### Multi-band clustering

Cluster the satellite image based on the RGB bands.

In [None]:
import numpy as np

# Flatten each band into a single column, one row per pixel.
red_1d   = red.reshape(red.shape[0] * red.shape[1], 1)
green_1d = green.reshape(green.shape[0] * green.shape[1], 1)
blue_1d  = blue.reshape(blue.shape[0] * blue.shape[1], 1)

# Put the columns side by side, so rgb_1d[i] holds the (red, green, blue) values
# of pixel i in scikit-learn's (n_samples, n_features) layout. This vectorised
# stack replaces a Python loop that built one tuple per pixel.
rgb_1d = np.hstack((red_1d, green_1d, blue_1d))
print(rgb_1d[200000])

In [None]:
# Fix random_state so the RGB clustering is reproducible across runs.
pred = KMeans(n_clusters=8, random_state=0).fit_predict(rgb_1d)
# Fold the flat per-pixel labels back into the image's row/column layout.
img_clusters = pred.reshape(red.shape)

In [None]:
# One color per cluster label, listed in label order (label k -> color k).
cluster_colors = ['red', 'black', 'gray', 'green', 'white', 'yellow', 'blue', 'purple']
# A ListedColormap gives each label exactly one of the listed colors instead of
# a continuous gradient between them.
cmap = mc.ListedColormap(cluster_colors)

plt.figure(figsize=[15,15])
# Pin vmin/vmax to the label range and use nearest-neighbour interpolation so
# the downscaled display never mixes neighbouring labels.
plt.imshow(img_clusters, cmap=cmap, vmin=0, vmax=len(cluster_colors) - 1,
           interpolation='nearest')
plt.axis('off')
plt.show()