# Image Cluster Notebook

## Imports and Functions

In [1]:
from ipyleaflet import basemaps
import ipysheet
from IPython.display import Markdown
import ipywidgets as widgets
from joblib import cpu_count
import leafmap.leafmap as leafmap
import numpy
from osgeo import gdal
import pandas
from pathlib import Path
from sklearn import cluster

# ----------------------------------------------------------------------------
# getCluster
# ----------------------------------------------------------------------------
def getCluster(ds: gdal.Dataset, bands: list=None, n_cluster: int=5):
    """
    Cluster pixels by their band values with mini-batch k-means.

    Parameters
    ----------
    ds : gdal.Dataset
        Not used by this function; kept so existing calls keep working.
    bands : list of numpy.ndarray
        Equally shaped arrays, one per band.  A pixel's values across these
        arrays form the feature vector that is clustered.  Required.
    n_cluster : int
        Number of clusters to compute.

    Returns
    -------
    numpy.ndarray
        Array shaped like a single band, holding each pixel's cluster label
        as reported by ``MiniBatchKMeans.labels_``.

    Raises
    ------
    ValueError
        If no bands are supplied.

    Note: every pixel is clustered, including no-data pixels.
    """
    if bands is None or len(bands) == 0:
        raise ValueError('getCluster requires at least one band.')

    # Stack the bands last, giving one feature vector per pixel, then flatten
    # to (nPixels, nBands) as scikit-learn expects.
    img = numpy.stack(bands, axis=-1)
    img1d = img.reshape(-1, img.shape[-1])

    params = {
        'n_clusters' : n_cluster,
        'random_state' : 0,                # Reproducible clusters.
        'batch_size' : 256*cpu_count()     # Lets MiniBatchKMeans use all cores.
    }

    model = cluster.MiniBatchKMeans(**params).fit(img1d)

    # Restore the spatial layout of the band arrays.
    return model.labels_.reshape(img.shape[:-1])

# ----------------------------------------------------------------------------
# handleClick
# ----------------------------------------------------------------------------
def handleClick(change: dict) -> None:
    """
    React to a press of the toggle buttons ``bt``.

    'Next' stores the current selection as a group and removes those
    clusters from the selector.  'Done' stores whatever is left as a final
    group and prints the groups.  'Start Over' restores every option and
    clears the groups.  Any other value (e.g. 'Select:') is ignored.  'Next'
    and 'Start Over' put the buttons back on 'Select:' so the same button
    can fire again.

    Relies on the notebook globals ``output``, ``sl``, ``bt``, ``opts`` and
    ``table``.

    change : the notification passed by ``bt.observe``; only its ``new``
        attribute (the newly selected button label) is read.
    """
    with output:

        pressed = change.new

        if pressed == 'Next':

            remaining = updateList(list(sl.options), list(sl.value))
            updateDict('N')
            sl.options = remaining
            bt.value = 'Select:'

        elif pressed == 'Done':

            updateDict('D')

        elif pressed == 'Start Over':

            sl.options = opts
            updateDict('S')
            bt.value = 'Select:'

# ----------------------------------------------------------------------------
# labelsToGeotiff
# These renderers seem to write tiles or pyramids to disk behind the scenes.  
# To make rendering code easier, write the labels as a geotiff; then the
# renderer will not need to do it.
# ----------------------------------------------------------------------------
def labelsToGeotiff(labelsFile: Path, labels: numpy.ndarray,
                    refDs: gdal.Dataset = None) -> gdal.Dataset:
    """
    Write a label array to a single-band Float32 GeoTIFF.

    Parameters
    ----------
    labelsFile : Path
        GeoTIFF to create.
    labels : numpy.ndarray
        2-D array written to band 1; must match the reference raster's size.
    refDs : gdal.Dataset, optional
        Dataset whose size, spatial reference and geotransform the output
        copies.  Defaults to the notebook's global input dataset ``ds``.

    Returns
    -------
    gdal.Dataset
        The open output dataset, with statistics computed (so the band's
        min/max are available to the renderer) and overviews built.

    Raises
    ------
    RuntimeError
        If GDAL cannot create the file.
    """
    refDs = ds if refDs is None else refDs

    labelsDs = gdal.GetDriverByName('GTiff').Create(
                    str(labelsFile),
                    xsize=refDs.RasterXSize,
                    ysize=refDs.RasterYSize,
                    eType=gdal.GDT_Float32
               )

    # Unless gdal.UseExceptions() is on, Create reports failure as None.
    if labelsDs is None:
        raise RuntimeError(f'Unable to create {labelsFile}')

    labelsDs.SetSpatialRef(refDs.GetSpatialRef())
    labelsDs.SetGeoTransform(refDs.GetGeoTransform())
    outBand = labelsDs.GetRasterBand(1)
    outBand.WriteArray(labels)
    outBand.ComputeStatistics(0)  # For min, max used in color table
    labelsDs.FlushCache()
    labelsDs.BuildOverviews()

    return labelsDs

# ----------------------------------------------------------------------------
# relabel
# ----------------------------------------------------------------------------
def relabel(labelArray: numpy.ndarray, lookup: dict) -> numpy.ndarray:
    """
    Map groups of original labels onto new class labels.

    Parameters
    ----------
    labelArray : numpy.ndarray
        Original labels; not modified.
    lookup : dict
        Maps each new class label ``k`` to the list ``v`` of original labels
        that should become ``k``.  Empty lists are allowed and change nothing.

    Returns
    -------
    numpy.ndarray
        A relabeled copy of ``labelArray``.

    Membership is always tested against the ORIGINAL labels, so a value
    written for one group is never picked up and re-mapped by a later group:
    ``{5: [9], 1: [5]}`` sends 9 -> 5 and 5 -> 1, never 9 -> 1.
    """
    newLab = labelArray.copy()

    for k, v in lookup.items():

        # A one-member group that maps to itself changes nothing.  Logical
        # ``and`` is required here: ``&`` binds tighter than ``==``.
        if len(v) == 1 and k == v[0]:
            continue

        newLab = numpy.where(numpy.isin(labelArray, v), k, newLab)

    return newLab
            
# ----------------------------------------------------------------------------
# updateDict
# ----------------------------------------------------------------------------
def updateDict(op: str) -> None:
    """
    Update the global cluster-group table ``table`` from the selector ``sl``.

    op : str
        'N' (next)       - store the selected clusters as a group keyed by
                           the first selected cluster.  With nothing
                           selected, print a hint and change nothing.
        'D' (done)       - store every cluster still offered by ``sl`` as a
                           group keyed by its first cluster, then print the
                           groups.
        'S' (start over) - empty the table.
        Any other value is ignored.
    """
    if op == 'N':

        selected = list(sl.value)

        # Indexing an empty selection would raise IndexError.
        if not selected:
            print('Select at least one cluster before pressing Next.')
            return

        table[selected[0]] = selected

    elif op == 'D':

        remaining = list(sl.options)

        if remaining:
            table[remaining[0]] = remaining

        print('Final Groups : ', table)

    elif op == 'S':

        table.clear()
    
# ----------------------------------------------------------------------------
# updateList
# ----------------------------------------------------------------------------
def updateList(old: list, out: list) -> list:
    """
    Return the elements of ``old``, in their original order, that do not
    appear in ``out``.  Neither argument is modified.
    """
    return list(filter(lambda element: element not in out, old))


## Specify inputs using the variables below.

In [2]:
inFile = Path('/explore/nobackup/people/rlgill/innovation-lab-repositories/ImageCluster/my4326.tif')
redBandId = 3
greenBandId = 2
blueBandId = 1

# Alternative inputs: uncomment one set to use it instead of the one above.
#
# inFile = Path('/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_isos.tif')
# redBandId = 1
# greenBandId = 1
# blueBandId = 1

# inFile = Path('/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_data.tif')
# redBandId = 5
# greenBandId = 4
# blueBandId = 2

# Fallback no-data value for when the input file does not supply one.  Set
# it to None if the image has no no-data pixels.
noDataValue = -9999.0

# Output rasters are written next to the input image.
outDirectory = inFile.parent

## Open the file, create overviews, and compute the clusters.

In [3]:
# ---
# Build overviews for rendering.  Cloud-optimized GeoTIFFs (COGs) rely on
# tiling and overview images.  Before we bother creating COGs, which entails
# an additional image file, see how adding overviews to the input image helps.
# Assigning 'dummy' prevents extraneous output in the notebook.
# ---
ds = gdal.Open(str(inFile))

# Unless gdal.UseExceptions() is on, Open reports failure as None.
if ds is None:
    raise RuntimeError(f'Unable to open {inFile}')

if ds.GetRasterBand(1).GetOverviewCount() == 0:
    dummy = ds.BuildOverviews()

# ---
# Read the bands.  Prefer the file's own no-data value.  Compare it against
# None explicitly: 0 is a legitimate (and common) no-data value, and a plain
# ``or`` would discard it.
# ---
redBand = ds.GetRasterBand(redBandId).ReadAsArray()
greenBand = ds.GetRasterBand(greenBandId).ReadAsArray()
blueBand = ds.GetRasterBand(blueBandId).ReadAsArray()

fileNoData = ds.GetRasterBand(redBandId).GetNoDataValue()

if fileNoData is not None:
    noDataValue = fileNoData

# Display range for the input image, ignoring no-data pixels.
forExtremes = redBand[redBand != noDataValue]

if forExtremes.size == 0:
    raise ValueError(f'Band {redBandId} of {inFile} contains only no-data pixels.')

minValue = forExtremes.min()
maxValue = forExtremes.max()

# ---
# Compute the clusters and prepare to render.
# ---
labels = getCluster(ds, bands=[redBand, greenBand, blueBand], n_cluster=30)
labelsFile = outDirectory / (inFile.stem + '-labels' + inFile.suffix)
lDs = labelsToGeotiff(labelsFile, labels)

## Review the input image.

In [8]:
# ---
# We could use leafmap.add_raster, but this way we get a handle to the tile
# layer.
# ---
inLayer = leafmap.get_local_tile_layer(str(inFile),
                                       band=[redBandId, greenBandId, blueBandId],
                                       vmin=minValue,
                                       vmax=maxValue,
                                       nodata=noDataValue,
                                       opacity=0.5,
                                       layer_name='Input Image')

clusterLayer = leafmap.get_local_tile_layer(str(labelsFile),
                                            vmin=lDs.GetRasterBand(1).GetMinimum(),
                                            vmax=lDs.GetRasterBand(1).GetMaximum(),
                                            nodata=noDataValue,
                                            opacity=0.5,
                                            palette='viridis',
                                            layer_name='Clusters')

# ---
# The only reason I am using leafmap and not ipyleaflet directly is to get
# the inspector tool under the wrench button on leafmap's map.
#
# There are comments that ipyleaflet.zoom_to_bounds() only works with Mercator
# projection.  This is consistent with what I see.  However, it seems to work
# when you set the center, so center the map on the middle of the image's
# bounds before fitting to them.
# ---
corner0 = [inLayer.bounds[0][0], inLayer.bounds[0][1]]
corner1 = [inLayer.bounds[1][0], inLayer.bounds[1][1]]
center = [(corner0[0] + corner1[0]) / 2, (corner0[1] + corner1[1]) / 2]

m = leafmap.Map(fullscreen_control=False,
                layers_control=True,
                search_control=False,
                draw_control=False,
                measure_control=False,
                scale_control=False,
                google_map='SATELLITE',
                center=center)

m.add(inLayer)
m.add(clusterLayer)

m.fit_bounds([corner0, corner1])

m

Map(center=[66.346511472519, -140.3222662042015], controls=(ZoomControl(options=['position', 'zoom_in_text', '…

## Update the labels.
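Select multiple values by clicking with the mouse, or by using the arrow keys, while holding Shift, Control, or Command. Press **Next** to save the selection as a group, **Done** to put all remaining clusters into a final group, or **Start Over** to discard the groups and begin again.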

In [5]:
# Offer each cluster ID as a plain Python int (via tolist) so the printed
# groups read e.g. 3 rather than a NumPy scalar repr such as np.int32(3)
# (the NumPy >= 2 repr).
opts = numpy.unique(labels).tolist()

sl = widgets.SelectMultiple(options=opts,
                            layout=widgets.Layout(height='200px', width='150px'))

bt = widgets.ToggleButtons(options=['Select:', 'Next', 'Done', 'Start Over'],
                           value='Select:')

output = widgets.Output()
display(sl, bt, output)

# Groups built with the buttons: first cluster ID of a group -> its IDs.
table = {}
bt.observe(handleClick, names='value')

SelectMultiple(layout=Layout(height='200px', width='150px'), options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12…

ToggleButtons(options=('Select:', 'Next', 'Done', 'Start Over'), value='Select:')

Output()

## Edit the groups.
Edit cluster IDs in each group.  When finished, proceed to the next cell.

In [7]:
# One row per group: the group's key and its cluster IDs as editable text.
strTab = {group: ', '.join(map(str, ids)) for group, ids in table.items()}

df = pandas.DataFrame(strTab.items(), columns=['Class', 'Cluster ID'])
sheet = ipysheet.from_dataframe(df)
sheet.column_width = [1, 5]
sheet

Sheet(cells=(Cell(column_end=0, column_start=0, numeric_format='0[.]0', row_start=0, squeeze_row=False, type='…

In [8]:
editedDf = ipysheet.to_dataframe(sheet)

# Key each group by its 'Class' value.  The frame's index holds only the
# sheet's row labels (the row numbers of the frame built from strTab), so
# keying by it would renumber the classes 0, 1, 2, ... whatever the sheet shows.
finalClusters = {}

for className, clusterText in zip(editedDf['Class'], editedDf['Cluster ID']):

    # Tolerate blank cells, stray spaces and trailing commas (e.g. "3, 7, ").
    pieces = [] if pandas.isna(clusterText) else str(clusterText).split(',')
    finalClusters[int(className)] = [int(i) for i in pieces if i.strip()]

newLabels = relabel(labels, finalClusters)

## Review the updated map.

In [9]:
clusterMapFile = outDirectory / (inFile.stem + '-cluster-map' + inFile.suffix)
nDs = labelsToGeotiff(clusterMapFile, newLabels)

# Reuse the original cluster raster's min/max so a given label value gets the
# same colour here as on the original cluster map.
newClusterLayer = leafmap.get_local_tile_layer(str(clusterMapFile),
                                               vmin=lDs.GetRasterBand(1).GetMinimum(),
                                               vmax=lDs.GetRasterBand(1).GetMaximum(),
                                               nodata=noDataValue,
                                               opacity=0.5,
                                               palette='viridis',
                                               layer_name='New Clusters')

# Center on the middle of the image's bounds, then fit to them.
corner0 = [inLayer.bounds[0][0], inLayer.bounds[0][1]]
corner1 = [inLayer.bounds[1][0], inLayer.bounds[1][1]]
center = [(corner0[0] + corner1[0]) / 2, (corner0[1] + corner1[1]) / 2]

n = leafmap.Map(fullscreen_control=False,
                layers_control=True,
                search_control=False,
                draw_control=False,
                measure_control=False,
                scale_control=False,
                google_map='SATELLITE',
                center=center)

n.add(newClusterLayer)

n.fit_bounds([corner0, corner1])

display(n)

Markdown(f'<h3>The cluster map is at<br><br>{clusterMapFile}</h3>')

Map(center=[12.766262399623551, -16.502342467486244], controls=(ZoomControl(options=['position', 'zoom_in_text…

<h3>The cluster map is at<br><br>/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_isos-cluster-map.tif</h3>