# Image Cluster Notebook

## Imports and Functions

In [2]:
from ipyleaflet import basemaps
import ipysheet
import ipywidgets as widgets
from IPython.display import Markdown
from joblib import cpu_count
import leafmap.leafmap as leafmap
import numpy
from osgeo import gdal
import pandas
from pathlib import Path
from sklearn import cluster

# ----------------------------------------------------------------------------
# getCluster
# ----------------------------------------------------------------------------
def getCluster(ds: gdal.Dataset,
               bands: list=None,
               n_cluster: int=5) -> numpy.ndarray:
    """
    Cluster the pixels of an image with mini-batch K-means.

    Parameters
    ----------
    ds : image dataset; only read when `bands` is None.
    bands : list of equally-shaped 2-D arrays, one per band.  When None,
        every raster band of `ds` is read and used.
    n_cluster : number of clusters to fit.

    Returns
    -------
    2-D array of cluster labels with the same rows and columns as one band.
    """

    # The original default of None crashed in moveaxis; fall back to the
    # dataset so the default is usable.
    if bands is None:

        bands = [ds.GetRasterBand(i).ReadAsArray()
                 for i in range(1, ds.RasterCount + 1)]

    # (band, row, col) -> (row, col, band), then one row per pixel.
    img = numpy.moveaxis(numpy.asarray(bands), 0, -1)
    img1d = img.reshape(-1, img.shape[-1])

    params = {
        'n_clusters' : n_cluster,
        'random_state' : 0,
        'batch_size' : 256*cpu_count()
    }

    cl = cluster.MiniBatchKMeans(**params)
    model = cl.fit(img1d)

    # Fold the per-pixel labels back into the image's row/column layout.
    return model.labels_.reshape(img.shape[:-1])

# ----------------------------------------------------------------------------
# handleClick
# ----------------------------------------------------------------------------
def handleClick(change: dict) -> None:
    """
    Observer for the toggle buttons `bt`.

    Reads the newly chosen button label from `change.new` and acts on the
    notebook-level widgets `sl` and `bt`.  Everything printed while handling
    the click goes to the `output` widget.
    """

    with output:

        action = change.new

        if action == 'Next':

            # Record the selected group, then offer only what is left.
            remaining = updateList(list(sl.options), list(sl.value))
            updateDict('N')
            sl.options = remaining
            bt.value = 'Select:'

        elif action == 'Done':

            updateDict('D')

        elif action == 'Start Over':

            # Restore the full option list and discard recorded groups.
            sl.options = opts
            updateDict('S')
            bt.value = 'Select:'

# ----------------------------------------------------------------------------
# labelsToGeotiff
# These renderers seem to write tiles or pyramids to disk behind the scenes.  
# To make rendering code easier, write the labels as a geotiff; then the
# renderer will not need to do it.
# ----------------------------------------------------------------------------
def labelsToGeotiff(labelsFile: Path, labels: numpy.ndarray) -> gdal.Dataset:
    """
    Write a label array to a single-band Float32 GeoTIFF.

    The raster size, spatial reference and geotransform are copied from the
    notebook-level dataset `ds`.

    Parameters
    ----------
    labelsFile : path of the GeoTIFF to create.
    labels : 2-D array written to band 1.

    Returns
    -------
    The open dataset for the new file.
    """

    driver = gdal.GetDriverByName('GTiff')

    labelsDs = driver.Create(str(labelsFile),
                             xsize=ds.RasterXSize,
                             ysize=ds.RasterYSize,
                             eType=gdal.GDT_Float32)

    labelsDs.SetSpatialRef(ds.GetSpatialRef())
    labelsDs.SetGeoTransform(ds.GetGeoTransform())
    outBand = labelsDs.GetRasterBand(1)
    outBand.WriteArray(labels)
    outBand.ComputeStatistics(0)  # For min, max used in color table
    labelsDs.FlushCache()

    # ---
    # BuildOverviews() without a level list requests zero overviews, so
    # name the levels explicitly.  Nearest-neighbour resampling keeps the
    # overview pixels valid label values.
    # ---
    labelsDs.BuildOverviews('NEAREST', [2, 4, 8, 16])

    # Flush again so the overviews are on disk before a renderer reads them.
    labelsDs.FlushCache()

    return labelsDs

# ----------------------------------------------------------------------------
# relabel
# ----------------------------------------------------------------------------
def relabel(labelArray: numpy.ndarray, lookup: dict) -> numpy.ndarray:
    """
    Merge labels into groups.

    Parameters
    ----------
    labelArray : array of labels; it is not modified.
    lookup : maps each new label to the list of existing labels that
        should be replaced by it.

    Returns
    -------
    A copy of `labelArray` with every label listed in a group replaced by
    that group's key.  Labels not listed in any group are left as they are.
    """

    newLab = labelArray.copy()

    for k, v in lookup.items():

        # ---
        # A group containing only its own key maps a label to itself.
        # Use "and": with "&" the test parsed as len(v) == (1 & k) == v[0].
        # ---
        if len(v) == 1 and k == v[0]:
            continue

        # ---
        # Build the mask from the untouched input, so a label assigned by an
        # earlier group is never picked up again by a later group.
        # ---
        newLab = numpy.where(numpy.isin(labelArray, v), k, newLab)

    return newLab
            
# ----------------------------------------------------------------------------
# updateDict
# ----------------------------------------------------------------------------
def updateDict(op: str) -> None:
    """
    Update the notebook-level grouping dictionary `table` from the selection
    widget `sl`.

    op == 'N' : store the current selection as a group, keyed by its first
                item.  Does nothing when no item is selected.
    op == 'D' : store all remaining options as a final group, keyed by the
                first of them, then print the groups.
    op == 'S' : discard every group and print the emptied table.

    Any other value of `op` is ignored.
    """

    if op == 'N':

        selected = list(sl.value)

        # An empty selection used to raise IndexError on selected[0].
        if selected:
            table[selected[0]] = selected

    elif op == 'D':

        remaining = list(sl.options)

        if remaining:
            table[remaining[0]] = remaining

        print('Final Groups : ', table)

    elif op == 'S':

        table.clear()
        print('Start Over : ', table)
    
# ----------------------------------------------------------------------------
# updateList
# ----------------------------------------------------------------------------
def updateList(old: list, out: list) -> list:
    """Return the items of `old`, in order, that do not appear in `out`."""

    kept = []

    for item in old:

        if item in out:
            continue

        kept.append(item)

    return kept

## Step 1: specify inputs using the variables below.

In [3]:
# Input image; outputs are written next to it.
inFile = Path('/explore/nobackup/people/rlgill/innovation-lab-repositories/ImageCluster/my4326-small.tif')
outDirectory = inFile.parent

# 1-based band numbers of the red, green and blue bands in the input image.
redBandId = 3
greenBandId = 2
blueBandId = 1

# ---
# No-data value passed to the map renderer.  Assign None instead when the
# image has no no-data value.  (The former "-9999.0 or None" always
# evaluated to -9999.0.)
# ---
noDataValue = -9999.0

## Step 2: open the file and create overviews, if they do not exist.

In [4]:
# ---
# Build overviews for rendering.  Cloud-optimized Geotiffs, CoGs, rely on
# tiling and overview images. Before we bother creating CoGs, which entails an
# additional image file, see how adding overviews to the input image helps.
# Assigning 'dummy' prevents extraneous output in the notebook.
# ---
ds = gdal.Open(str(inFile))

# Unless GDAL exceptions are enabled, a failed open returns None.
if ds is None:
    raise RuntimeError(f'Unable to open {inFile} with GDAL.')

# ---
# BuildOverviews() without a level list requests zero overviews, so name
# the decimation levels explicitly.
# ---
if ds.GetRasterBand(1).GetOverviewCount() == 0:
    dummy = ds.BuildOverviews('NEAREST', [2, 4, 8, 16])

## Step 3: review the input image.

In [None]:
# ---
# https://ipyleaflet.readthedocs.io/en/latest/map_and_basemaps/map.html?highlight=map#ipyleaflet.Map
# Map <- TileLayer -> RasterLayer -> Layer 
# The map is supposed to zoom to the extent of the raster, but it does not.
# ---
m = leafmap.Map(google_map='SATELLITE', zoom=25)

# Display the bands chosen in Step 1 rather than a hard-coded [3, 2, 1].
m.add_raster(str(inFile), 
             band=[redBandId, greenBandId, blueBandId], 
             vmin=0, 
             vmax=1, 
             nodata=noDataValue,
             opacity=0.5)
m

## Step 4: run K-means clustering.

In [None]:
# Read the bands configured in Step 1, in red, green, blue order.
redBand, greenBand, blueBand = [
    ds.GetRasterBand(bandId).ReadAsArray()
    for bandId in (redBandId, greenBandId, blueBandId)
]

# Cluster the pixels of the three bands into 30 groups.
rgbBands = [redBand, greenBand, blueBand]
labels = getCluster(ds, n_cluster=30, bands=rgbBands)

## Step 5: review the cluster map.

In [None]:
# https://gis.stackexchange.com/questions/384581/raster-to-geopandas
# Get the UL and LR of the image.
import geopandas

xform = ds.GetGeoTransform()
xScale = xform[1]
yScale = xform[5]
width = ds.RasterXSize
height = ds.RasterYSize
ulx = xform[0]
uly = xform[3]
lrx = ulx + width * xScale
lry = uly + height * yScale

# ---
# Create an xy grid holding the centre of every pixel.  A linspace from
# corner to corner with one point per pixel spaces the points at
# extent / (n - 1) rather than the pixel size, and starts on the pixel corner.
# The rotation terms of the geotransform are ignored, as before.
# ---
x = ulx + (numpy.arange(width) + 0.5) * xScale
y = uly + (numpy.arange(height) + 0.5) * yScale
xs, ys = numpy.meshgrid(x, y)

# One row per pixel: its coordinates and its cluster label.
data = {'X': pandas.Series(xs.ravel()),
        'Y': pandas.Series(ys.ravel()),
        'Z': pandas.Series(labels.ravel())}

df = pandas.DataFrame(data=data)
geometry = geopandas.points_from_xy(df.X, df.Y)

gdf = geopandas.GeoDataFrame(df, 
                             crs=ds.GetSpatialRef().ExportToProj4(), 
                             geometry=geometry)

In [None]:
# ---
# Show the labelled points as a GeoDataFrame layer, which provides the
# "hover" tool.  This is ridiculously slow.
# ---
l = leafmap.Map(zoom=25, google_map='SATELLITE')
l.add_gdf(gdf)
l

## Step 6: update the labels.
Select multiple values by clicking the mouse or using the arrow keys while pressing Shift, Control or Command.

In [None]:
# Groups recorded by the button callbacks: group key -> list of cluster IDs.
table = {}

# Every distinct cluster ID starts out as a selectable option.
opts = list(numpy.unique(labels))

sl = widgets.SelectMultiple(
    options=opts,
    layout=widgets.Layout(height='200px', width='150px'),
)

bt = widgets.ToggleButtons(
    options=['Select:', 'Next', 'Done', 'Start Over'],
    value='Select:',
)

# Messages printed by the callbacks are shown in this widget.
output = widgets.Output()
display(sl, bt, output)

bt.observe(handleClick, names='value')

## Step 7: edit the groups.
Edit cluster IDs in each group.  When finished, proceed to the next cell.

In [None]:
# Show each group as one editable row: its key and a comma-separated ID list.
strTab = {
    item: ', '.join(str(i) for i in members)
    for item, members in table.items()
}

df = pandas.DataFrame(strTab.items(), columns=['Class', 'Cluster ID'])
sheet = ipysheet.from_dataframe(df)
sheet.column_width = [1, 5]
sheet

In [None]:
editedDf = ipysheet.to_dataframe(sheet)

# ---
# to_dict() maps column -> {index label: cell value}, so the keys below are
# the sheet's row labels.
# NOTE(review): the 'Class' column is not read here -- confirm that the row
# label, not the edited 'Class' value, is the intended new label.
# ---
strClusters = editedDf.to_dict()['Cluster ID']

finalClusters = {}

for key, strCluster in strClusters.items():

    # ---
    # Tolerate a cleared cell, a cell holding a bare number, and stray
    # blanks or trailing commas left behind by hand editing.
    # ---
    text = '' if strCluster is None else str(strCluster)
    tokens = [tok.strip() for tok in text.split(',')]
    finalClusters[int(key)] = [int(tok) for tok in tokens if tok]

newLabels = relabel(labels, finalClusters)

## Step 8: review the updated map.

In [None]:
# Write the merged labels next to the input image, then render that file.
clusterMapFile = outDirectory / (inFile.stem + '-cluster-map' + inFile.suffix)
nDs = labelsToGeotiff(clusterMapFile, newLabels)

n = leafmap.Map(google_map='SATELLITE', zoom=25)

# ---
# Stretch the palette over the full label range.  vmax previously used
# GetMinimum(), which collapsed the range to a single value.
# ---
n.add_raster(str(clusterMapFile), 
             vmin=nDs.GetRasterBand(1).GetMinimum(), 
             vmax=nDs.GetRasterBand(1).GetMaximum(), 
             opacity=0.5, 
             palette='viridis')

n

In [None]:
# Report where the cluster map was written; the cell displays this Markdown.
Markdown(f'<h3>Complete: the cluster map is at<br><br>{clusterMapFile}</h3>')