# Image Cluster Notebook

## Imports and Functions

In [1]:
from ipyleaflet import basemaps
import ipysheet
import ipywidgets as widgets
from IPython.display import Markdown
from joblib import cpu_count
import leafmap.leafmap as leafmap
import numpy
from osgeo import gdal
import pandas
from pathlib import Path
from sklearn import cluster

# ----------------------------------------------------------------------------
# getCluster
# ----------------------------------------------------------------------------
def getCluster(ds: gdal.Dataset, bands: list=None, n_cluster: int=5) -> numpy.ndarray:
    """
    Cluster the pixels of an image with mini-batch k-means.

    Parameters
    ----------
    ds : gdal.Dataset
        Open dataset.  It is only read when ``bands`` is None, in which case
        every raster band of the dataset becomes one feature.
    bands : list, optional
        Equally shaped 2-D arrays (rows, cols), one per band.  Each band is
        one feature of a pixel.
    n_cluster : int
        Number of clusters to compute.

    Returns
    -------
    numpy.ndarray
        2-D array (rows, cols) holding the cluster label of every pixel.
    """

    # The default bands=None used to crash inside numpy.moveaxis and ds was
    # never used; fall back to reading every band from the dataset.
    if bands is None:
        bands = [ds.GetRasterBand(i).ReadAsArray()
                 for i in range(1, ds.RasterCount + 1)]

    # (band, row, col) -> (row, col, band), then flatten to one row per pixel.
    img = numpy.moveaxis(numpy.asarray(bands), 0, -1)
    img1d = img.reshape(-1, img.shape[-1])

    params = {
        'n_clusters' : n_cluster,
        'random_state' : 0,  # fixed seed so repeated runs give the same labels
        'batch_size' : 256*cpu_count()
    }

    cl = cluster.MiniBatchKMeans(**params)
    # cl = cluster.KMeans(**params)
    model = cl.fit(img1d)

    # Fold the per-pixel labels back into the image's (row, col) layout.
    imgCl = model.labels_.reshape(img.shape[:-1])

    return imgCl

# ----------------------------------------------------------------------------
# handleClick
# ----------------------------------------------------------------------------
def handleClick(change: dict) -> None:
    """
    React to a change of the toggle buttons ``bt``.

    ``change.new`` holds the newly selected button label.  All work happens
    inside the ``output`` widget context and relies on the notebook-level
    objects ``sl``, ``bt``, ``opts`` and ``output``.

    'Next'       : record the current selection as a group, remove it from
                   the selectable options and reset the buttons.
    'Done'       : record whatever options remain as the final group.
    'Start Over' : restore every option, empty the group table and reset the
                   buttons.
    """

    with output:

        choice = change.new

        if choice == 'Next':

            # Compute the remaining options before the selection is recorded;
            # the options are only replaced after updateDict has read sl.value.
            remaining = updateList(list(sl.options), list(sl.value))
            updateDict('N')
            sl.options = remaining
            bt.value = 'Select:'

        elif choice == 'Done':

            updateDict('D')

        elif choice == 'Start Over':

            sl.options = opts
            updateDict('S')
            bt.value = 'Select:'

# ----------------------------------------------------------------------------
# labelsToGeotiff
# These renderers seem to write tiles or pyramids to disk behind the scenes.  
# To make rendering code easier, write the labels as a geotiff; then the
# renderer will not need to do it.
# ----------------------------------------------------------------------------
def labelsToGeotiff(labelsFile: Path, labels: numpy.ndarray) -> gdal.Dataset:
    """
    Write a label array to a Float32 GeoTIFF and return the open dataset.

    The raster size, spatial reference and geotransform are copied from the
    notebook-level dataset ``ds``, so ``ds`` must be open before this is
    called.

    Parameters
    ----------
    labelsFile : Path
        Destination file; it is passed to the GTiff driver's Create.
    labels : numpy.ndarray
        Array written to band 1 of the new file.

    Returns
    -------
    gdal.Dataset
        The newly created dataset, flushed to disk.
    """

    gtiff = gdal.GetDriverByName('GTiff')

    outDs = gtiff.Create(str(labelsFile),
                         xsize=ds.RasterXSize,
                         ysize=ds.RasterYSize,
                         eType=gdal.GDT_Float32)

    # Georeference the output exactly like the input image.
    outDs.SetSpatialRef(ds.GetSpatialRef())
    outDs.SetGeoTransform(ds.GetGeoTransform())

    band = outDs.GetRasterBand(1)
    band.WriteArray(labels)
    band.ComputeStatistics(0)  # For min, max used in color table
    outDs.FlushCache()

    # NOTE(review): called without overview levels -- confirm that this
    # really builds overviews with the installed GDAL version.
    outDs.BuildOverviews()

    return outDs

# ----------------------------------------------------------------------------
# relabel
# ----------------------------------------------------------------------------
def relabel(labelArray: numpy.ndarray, lookup: dict) -> numpy.ndarray:
    """
    Merge labels into new classes.

    Parameters
    ----------
    labelArray : numpy.ndarray
        Array of original labels.  It is not modified.
    lookup : dict
        Maps a new label to the list of original labels it replaces, for
        example ``{3: [1, 2]}`` turns every 1 and 2 into 3.

    Returns
    -------
    numpy.ndarray
        Copy of ``labelArray`` with the replacements applied.  Labels that
        appear in no list are left as they are.
    """

    newLab = labelArray.copy()

    for k, v in lookup.items():

        # A label mapped only to itself needs no work.  This must be a
        # logical 'and': the former 'len(v)==1 & k==v[0]' parsed as the
        # chained comparison 'len(v) == (1 & k) == v[0]', which skipped
        # valid merges and raised IndexError on an empty list.
        if len(v) == 1 and k == v[0]:
            continue

        # Always test membership against the ORIGINAL labels.  Testing the
        # partially relabelled array would cascade, e.g. {2: [1], 3: [2]}
        # would turn the original 1s into 3s.
        newLab = numpy.where(numpy.isin(labelArray, v), k, newLab)

    return newLab
            
# ----------------------------------------------------------------------------
# updateDict
# ----------------------------------------------------------------------------
def updateDict(op: str) -> None:
    """
    Update the notebook-level group ``table`` from the selection widget ``sl``.

    Each group is stored as ``table[first ID of the group] = list of IDs``.

    op == 'N' : store the currently selected IDs as a group.
    op == 'D' : store the options still listed in ``sl`` as the last group
                and print the final table.
    op == 'S' : empty the table and print it.

    Any other value of ``op`` does nothing.
    """

    if op == 'N':

        selected = list(sl.value)

        # 'Next' pressed with nothing selected: indexing the empty selection
        # used to raise IndexError, so report it and leave the table as is.
        if not selected:

            print('No clusters selected; nothing to group.')
            return

        table[selected[0]] = selected

    elif op == 'D':

        remaining = list(sl.options)

        if remaining:
            table[remaining[0]] = remaining

        print('Final Groups : ', table)

    elif op == 'S':

        table.clear()
        print('Start Over : ', table)
    
# ----------------------------------------------------------------------------
# updateList
# ----------------------------------------------------------------------------
def updateList(old: list, out: list) -> list:
    """Return the elements of ``old``, in order, that do not occur in ``out``."""

    kept = []

    for candidate in old:

        if candidate in out:
            continue

        kept.append(candidate)

    return kept

## Step 1: specify inputs using the variables below.

In [6]:
# ---
# Input image and the band numbers, as passed to GetRasterBand, that are used
# as red, green and blue.
# ---
inFile = Path('/explore/nobackup/people/rlgill/innovation-lab-repositories/ImageCluster/my4326.tif')
redBandId = 3
greenBandId = 2
blueBandId = 1

# inFile = Path('/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_isos.tif')
# redBandId = 1
# greenBandId = 1
# blueBandId = 1

# inFile = Path('/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_data.tif')
# redBandId = 1
# greenBandId = 2
# blueBandId = 3

# ---
# No-data value to use for rendering.  Set this to a number, or to None when
# there is none.  The former '-9999.0 or None' always evaluated to -9999.0.
# ---
noDataValue = -9999.0

# Output files are written next to the input image.
outDirectory = inFile.parent

## Step 2: open the file, create overviews, and compute the clusters.

In [7]:
# ---
# Build overviews for rendering.  Cloud-optimized Geotiffs, CoGs, rely on
# tiling and overview images. Before we bother creating CoGs, which entails an
# additional image file, see how adding overviews to the input image helps.
# Assigning 'dummy' prevents extraneous output in the notebook.
# ---
ds = gdal.Open(str(inFile))

# gdal.Open returns None instead of raising unless GDAL exceptions are
# enabled; fail here with a clear message rather than on the next line.
if ds is None:
    raise RuntimeError(f'Unable to open {inFile}')

if ds.GetRasterBand(1).GetOverviewCount() == 0:
    dummy = ds.BuildOverviews()
    
# ---
# Read the bands.
# ---
redBand = ds.GetRasterBand(redBandId).ReadAsArray()
greenBand = ds.GetRasterBand(greenBandId).ReadAsArray()
blueBand = ds.GetRasterBand(blueBandId).ReadAsArray()

# ---
# Prefer the no-data value declared in the file.  Compare against None
# explicitly: a declared no-data value of 0 is falsy and 'or' would discard it.
# ---
fileNoDataValue = ds.GetRasterBand(redBandId).GetNoDataValue()

if fileNoDataValue is not None:
    noDataValue = fileNoDataValue

# ---
# Compute the clusters and prepare to render.
# ---
labels = getCluster(ds, bands=[redBand, greenBand, blueBand], n_cluster=30)
labelsFile = outDirectory / (inFile.stem + '-labels' + inFile.suffix)
lDs = labelsToGeotiff(labelsFile, labels)

## Step 3: review the input image.

In [8]:
# ---
# If the input image is not in EPSG:4326, the map will not be centered.
# Compute the bounding box, and transform it to EPSG:4326.
# ---
# from osgeo import osr
# inUlx, xres, xskew, inUly, yskew, yres  = ds.GetGeoTransform()
# inLrx = inUlx + (ds.RasterXSize * xres)
# inLry = inUly + (ds.RasterYSize * yres)

# target = osr.SpatialReference()
# # target.ImportFromEPSG(4326)
# target.ImportFromEPSG(3857)
# transform = osr.CoordinateTransformation(ds.GetSpatialRef(), target)
# uly, ulx, ulz = transform.TransformPoint(inUlx, inUly)
# lry, lrx, lrz = transform.TransformPoint(inLrx, inLry)

# ---
# Render the map.
# ---
# leftArgs = {'band': [redBandId, greenBandId, blueBandId],
#             'vmin': redBand.min(),
#             'vmax': redBand.max(),
#             'nodata': noDataValue,
#             'opacity': 0.5
#            }

# rightArgs = {'vmin': ds.GetRasterBand(1).GetMinimum(),
#              'vmax': ds.GetRasterBand(1).GetMaximum(),
#              'nodata': noDataValue,
#              'opacity': 0.5,
#              'palette': 'viridis',
#             }            

# # The lat/lon bounds in the form [[south, west], [north, east]].
# leafmap.linked_maps(rows=1, 
#                     cols=2, 
#                     layers=[str(inFile), str(labelsFile)],
#                     layer_args=[leftArgs, rightArgs],
#                     fit_bounds=[[lry, ulx], [uly, lrx]],
#                     # center=[xcx, xcy],
#                     # zoom=12,
#                     # basemap=basemaps.Esri.WorldImagery,
#                    )

# ---
# Input image layer, stretched over the value range of the red band.
# ---
left = leafmap.get_local_tile_layer(str(inFile),
                                    band=[redBandId, greenBandId, blueBandId],
                                    vmin=redBand.min(),
                                    vmax=redBand.max(),
                                    nodata=noDataValue,
                                    opacity=0.5,
                                    layer_name='Input Image')

# ---
# Cluster layer.  Its color range must come from the labels dataset, lDs,
# whose statistics were computed when it was written.  It formerly used
# band 1 of the input image, ds, which is unrelated to the label values.
# ---
labelsBand = lDs.GetRasterBand(1)

right = leafmap.get_local_tile_layer(str(labelsFile),
                                     vmin=labelsBand.GetMinimum(),
                                     vmax=labelsBand.GetMaximum(),
                                     nodata=noDataValue,
                                     opacity=0.5,
                                     palette='viridis',
                                     layer_name='Clusters')


m = leafmap.Map(fullscreen_control=False,
                layers_control=True,
                search_control=False,
                draw_control=False,
                measure_control=False,
                scale_control=False)

m.add(left)
m.add(right)

# m.add_raster(str(inFile), 
#              band=[redBandId, greenBandId, blueBandId],
#              vmin=redBand.min(),
#              vmax=redBand.max(),
#              nodata=noDataValue,
#              opacity=0.5,
#              layer_name='Input Image')
 
# m.add_raster(str(labelsFile),
#              vmin=ds.GetRasterBand(1).GetMinimum(),
#              vmax=ds.GetRasterBand(1).GetMaximum(),
#              nodata=noDataValue,
#              opacity=0.5,
#              palette='viridis',
#              layer_name='Clusters')
 
# ---
# Zoom the map to the extent of the input image layer.  The two corners are
# taken from the first and second entries of the layer's bounds.
# ---
firstCorner = left.bounds[0]
secondCorner = left.bounds[1]

m.fit_bounds([[firstCorner[0], firstCorner[1]],
              [secondCorner[0], secondCorner[1]]])

m

Map(center=[20, 0], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_out_text…

## Step 5: review the cluster map.

In [None]:
# ---
# Show the cluster labels over a satellite base map.  The color range spans
# the minimum and maximum recorded for band 1 of the labels dataset.
# ---
l = leafmap.Map(google_map='SATELLITE', zoom=25)

labelsRasterArgs = {
    'vmin': lDs.GetRasterBand(1).GetMinimum(),
    'vmax': lDs.GetRasterBand(1).GetMaximum(),
    'opacity': 0.5,
    'palette': 'viridis',
}

l.add_raster(str(labelsFile), **labelsRasterArgs)

l

## Step 6: update the labels.
Select multiple values by clicking the mouse or using the arrow keys while pressing shift, control or command.

In [None]:
# ---
# State shared with handleClick and updateDict: the selectable cluster IDs,
# the table of groups built so far, and the widget that captures output.
# ---
opts = list(numpy.unique(labels))
table = {}
output = widgets.Output()

# ---
# Multi-select list of cluster IDs and the buttons that drive the grouping.
# ---
listLayout = widgets.Layout(height='200px', width='150px')
sl = widgets.SelectMultiple(options=opts, layout=listLayout)

buttonLabels = ['Select:', 'Next', 'Done', 'Start Over']
bt = widgets.ToggleButtons(options=buttonLabels, value='Select:')

display(sl, bt, output)
bt.observe(handleClick, names='value')

## Step 7: edit the groups.
Edit cluster IDs in each group.  When finished, proceed to the next cell.

In [None]:
# ---
# Render each group's cluster IDs as one comma-separated string so that the
# group can be edited in a single sheet cell.
# ---
strTab = {item: ', '.join(map(str, group)) for item, group in table.items()}

df = pandas.DataFrame(strTab.items(), columns=['Class', 'Cluster ID'])

# ---
# Show the groups in an editable sheet.
# ---
sheet = ipysheet.from_dataframe(df)
sheet.column_width = [1, 5]
sheet

In [None]:
# ---
# Read the edited sheet back.  to_dict() gives {column: {index: value}}, so
# the keys below are the data frame's index.
# NOTE(review): the class label is therefore the row index, not the value in
# the 'Class' column -- confirm that this is intended.
# ---
editedDf = ipysheet.to_dataframe(sheet)
strClusters = editedDf.to_dict()['Cluster ID']

finalClusters = {}

for key, strCluster in strClusters.items():

    # Drop blank tokens before converting.  Testing the stripped token also
    # rejects whitespace-only ones, such as the one left by a trailing ", ",
    # which formerly reached int() and raised ValueError.
    finalClusters[int(key)] = [int(i)
                               for i in strCluster.split(',')
                               if i.strip()]

newLabels = relabel(labels, finalClusters)

## Step 8: review the updated map.

In [None]:
# ---
# Write the regrouped labels next to the input image and render them.
# ---
clusterMapFile = outDirectory / (inFile.stem + '-cluster-map' + inFile.suffix)
nDs = labelsToGeotiff(clusterMapFile, newLabels)

n = leafmap.Map(google_map='SATELLITE', zoom=25)

# vmax must be the band maximum.  It was formerly set from GetMinimum(),
# which collapsed the color range to a single value.
n.add_raster(str(clusterMapFile), 
             vmin=nDs.GetRasterBand(1).GetMinimum(), 
             vmax=nDs.GetRasterBand(1).GetMaximum(), 
             opacity=0.5, 
             palette='viridis')

n

In [None]:
# Report where the final cluster map was written.
completionMessage = f'<h3>Complete: the cluster map is at<br><br>{clusterMapFile}</h3>'
Markdown(completionMessage)