In [None]:
from pathlib import Path

from matplotlib import pyplot as plt
from sklearn.cluster import KMeans
import xarray as xr

In [None]:
# Read data: open the VIIRS snapshot GeoTIFF as an xarray Dataset using the
# "rasterio" backend engine.
viirs_file = '../data/input/snapshot-2025-04-12T00_00_00Z.tif'
xds = xr.open_dataset(
    viirs_file,
    engine="rasterio",
)

In [None]:
## Display false color image
# Keep only the first three bands (the last band is alpha, not colour) and order
# the dimensions as (y, x, band) -- row, column, channel -- which is the layout
# an RGB imshow expects. Ordering them (x, y, band) draws the image rotated.
da = xds.band_data.isel(band=slice(None, 3)).transpose("y", "x", "band")

# Rescale from [0, 255] to [0, 1], the range matplotlib expects for float RGB
da = da / 255

# Plot; name the axes explicitly so the scene is drawn in map orientation
# independent of the underlying dimension order.
da.plot.imshow(x="x", y="y")
plt.title('False Color VIIRS Imagery');

In [None]:
## K-means classification
# Number of clusters.
# Started with four, hoping they would correspond to water, vegetation, built-up
# and clouds. Then tried five: there still weren't enough clouds for them to get
# their own class, and the fifth class looked like mixed pixels containing water.
# The count was therefore raised substantially (to 15) to over-segment the scene.
n_clusters = 15

# Features to cluster on: the first three bands only. The last band is alpha
# (transparency), which is not a spectral measurement and should not drive the
# clustering.
features = xds.band_data.isel(band=slice(None, 3))

# Flatten to an (n_pixels, n_bands) matrix with pixels in row-major (y, x) order.
# Sizes are looked up by dimension name rather than by position in `.shape`, so
# this does not depend on how the backend orders the dimensions.
n_y = features.sizes['y']
n_x = features.sizes['x']
n_bands = features.sizes['band']
data = features.transpose('y', 'x', 'band').values.reshape(-1, n_bands)

# Perform classification (fixed random_state so labels are reproducible)
kmeans = KMeans(n_clusters=n_clusters, random_state=9)
labels = kmeans.fit_predict(data)

# Reshape labels back onto the grid, using the same (y, x) order used to flatten
results = labels.reshape(n_y, n_x)

# Add results to dataset (dims named so they align with the dataset's y/x coords)
xds['clusters'] = xr.DataArray(results, dims=('y', 'x'))

In [None]:
# Plot both images side by side
_, axes = plt.subplots(1, 2, figsize=(12, 6))

# 'tab10' has only 10 colours, so with more clusters than that several clusters
# would share a colour. Resample 'tab20' to exactly one colour per cluster
# (distinct for up to 20 clusters).
cluster_cmap = plt.get_cmap('tab20', n_clusters)

# Name the axes explicitly so both panels are drawn in map orientation,
# independent of the dimension order of the underlying arrays.
da.plot.imshow(ax=axes[0], x='x', y='y')
xds['clusters'].plot.imshow(ax=axes[1], x='x', y='y', cmap=cluster_cmap, add_colorbar=False)

axes[0].set_title('False Color VIIRS Imagery')
axes[1].set_title(f'K-means clusters (k={n_clusters})')

for ax in axes:
    ax.set_aspect('equal')


In [None]:
# Large view of the cluster map. 'tab20' is resampled to one colour per cluster
# ('tab10' would reuse colours once there are more than 10 clusters), and the
# colorbar gets one tick per cluster label.
xds['clusters'].plot.imshow(
    x='x',
    y='y',
    figsize=(15, 15),
    cmap=plt.get_cmap('tab20', n_clusters),
    add_colorbar=True,
    cbar_kwargs={'ticks': list(range(n_clusters))},
)
plt.title(f'K-means clusters (k={n_clusters})');

In [None]:
# Save output to geotiff
output_path = Path('../data/output/classified_data.tif')

# rio.to_raster does not create missing directories; make sure the output
# folder exists so the notebook also runs on a fresh checkout.
output_path.parent.mkdir(parents=True, exist_ok=True)

# GeoTIFF rows/columns correspond to (y, x); transpose explicitly so the write
# is correct whichever order the clusters array was stored in.
da_output = xds['clusters'].transpose('y', 'x')
da_output.rio.to_raster(str(output_path))

In [None]:
# Check output: re-open the GeoTIFF written above and display it for inspection
xds_new = xr.open_dataset(
    '../data/output/classified_data.tif',
    engine="rasterio",
)
xds_new