# Image Cluster Notebook

In [2]:
# inFile = Path('/explore/nobackup/people/rlgill/innovation-lab-repositories/ImageCluster/my4326.tif')
# redBandId = 3
# greenBandId = 2
# blueBandId = 1

# inFile = Path('/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_isos.tif')
# redBandId = 1
# greenBandId = 1
# blueBandId = 1

# inFile = Path('/explore/nobackup/people/mcarrol2/LCLUC_Senegal/ForKonrad/Tappan26_WV03_20210314_M1BS_10400100676F0900_data.tif')
# redBandId = 5
# greenBandId = 4
# blueBandId = 2

# noDataValue = -9999.0 or None
# outDirectory = inFile.parent

# loadImage(inFile=inFile, 
#           redBandId=redBandId, 
#           greenBandId=greenBandId, 
#           blueBandId=blueBandId, 
#           noDataValue=noDataValue)

In [3]:
from pathlib import Path

from IPython.display import Markdown
from ipyfilechooser import FileChooser
from ipyleaflet import basemaps
from ipyleaflet import TileLayer
from ipyleaflet import WidgetControl
import ipysheet
import ipywidgets as widgets
from joblib import cpu_count
# import leafmap.leafmap as leafmap
import leafmap
import numpy
from osgeo import gdal
import pandas
from sklearn import cluster

In [4]:
# ----------------------------------------------------------------------------
# getCluster
# ----------------------------------------------------------------------------
def getCluster(ds: gdal.Dataset, bands: list=None, n_cluster: int=5):
    """
    Cluster pixels by their band values with MiniBatchKMeans.

    Args:
        ds: dataset whose bands are all read when `bands` is None.
        bands: equally-shaped 2-D arrays, one per band.  A single bare 2-D
            array is treated as one band.
        n_cluster: number of clusters.

    Returns:
        numpy.ndarray of cluster labels with the shape of one band.
    """
    if bands is None:
        # Dataset.ReadAsArray() gives (bands, rows, cols), or (rows, cols)
        # for a single-band dataset.
        bands = ds.ReadAsArray()

    stack = numpy.asarray(bands)

    # A lone 2-D band would otherwise have its rows treated as bands.
    if stack.ndim == 2:
        stack = stack[numpy.newaxis, ...]

    img = numpy.moveaxis(stack, 0, -1)          # band axis last
    pixels = img.reshape(-1, img.shape[-1])     # one row per pixel

    # NOTE(review): no-data pixels are clustered like any other pixels.
    params = {
        'n_clusters' : n_cluster,
        'random_state' : 0,
        'batch_size' : 256*cpu_count()
    }

    model = cluster.MiniBatchKMeans(**params).fit(pixels)

    return model.labels_.reshape(img.shape[:-1])

# ----------------------------------------------------------------------------
# handleClick
# ----------------------------------------------------------------------------
def handleClick(change: dict) -> None:
    """
    Observer for the ToggleButtons `bt`; drives the cluster-grouping widgets.

    'Next'       -- record the clusters selected in `sl` as one group and
                    remove them from the choices.
    'Done'       -- record all remaining choices as a final group and print
                    the groups.
    'Start Over' -- restore all choices and clear the groups.

    Relies on the notebook globals `output`, `sl`, `bt` and `opts`.

    Args:
        change: traitlets change notification; only `change.new` is read.
    """
    with output:

        if change.new == 'Next':

            if not sl.value:
                # Nothing selected means no group to record; updateDict would
                # otherwise index an empty selection.
                print('Select one or more clusters before clicking Next.')

            else:
                remaining = updateList(list(sl.options), list(sl.value))
                updateDict('N')
                sl.options = remaining

            # Reset so the next click on 'Next' registers as a value change.
            bt.value = 'Select:'

        elif change.new == 'Done':
            updateDict('D')

        elif change.new == 'Start Over':
            sl.options = opts
            updateDict('S')
            bt.value = 'Select:'

# ----------------------------------------------------------------------------
# labelsToGeotiff
# These renderers seem to write tiles or pyramids to disk behind the scenes.  
# To make rendering code easier, write the labels as a geotiff; then the
# renderer will not need to do it.
# ----------------------------------------------------------------------------
def labelsToGeotiff(labelsFile: Path,
                    labels: numpy.ndarray,
                    referenceDs: gdal.Dataset = None) -> gdal.Dataset:
    """
    Write a 2-D label array to a single-band Float32 GeoTIFF, with
    statistics and overviews, so the renderer can read it directly.

    Args:
        labelsFile: path of the GeoTIFF to create.
        labels: 2-D array (rows, cols) of labels.
        referenceDs: dataset whose spatial reference and geotransform the
            output copies.  Defaults to the notebook global `ds`, which is
            what this function always used before the parameter existed.

    Returns:
        The open output dataset.  Band 1 has statistics computed, so its
        GetMinimum()/GetMaximum() can feed the color table.

    Raises:
        ValueError: labels is not the same size as referenceDs.
        RuntimeError: GDAL could not create labelsFile.
    """
    if referenceDs is None:
        referenceDs = ds

    rows, cols = labels.shape

    # The copied geotransform is only meaningful for a same-sized grid.
    if (cols, rows) != (referenceDs.RasterXSize, referenceDs.RasterYSize):
        raise ValueError(
            f'labels are {cols}x{rows} but the reference dataset is '
            f'{referenceDs.RasterXSize}x{referenceDs.RasterYSize}')

    labelsDs = gdal.GetDriverByName('GTiff').Create(
                    str(labelsFile),
                    xsize=cols,
                    ysize=rows,
                    eType=gdal.GDT_Float32)

    if labelsDs is None:
        raise RuntimeError(f'GDAL could not create {labelsFile}')

    labelsDs.SetSpatialRef(referenceDs.GetSpatialRef())
    labelsDs.SetGeoTransform(referenceDs.GetGeoTransform())
    outBand = labelsDs.GetRasterBand(1)
    outBand.WriteArray(labels)
    outBand.ComputeStatistics(0)  # For min, max used in color table
    labelsDs.FlushCache()
    labelsDs.BuildOverviews()

    return labelsDs

# ----------------------------------------------------------------------------
# relabel
# ----------------------------------------------------------------------------
def relabel(labelArray: numpy.ndarray, lookup: dict) -> numpy.ndarray:
    """
    Map groups of old labels to new labels.

    Args:
        labelArray: array of labels; it is not modified.
        lookup: {newLabel: [oldLabel, ...]}.  Every element whose ORIGINAL
            label appears in the list becomes newLabel.

    Returns:
        A relabeled copy of labelArray.  Labels not listed keep their value.
    """
    newLab = labelArray.copy()

    for newLabel, oldLabels in lookup.items():
        # Match against the original array, not newLab, so a label written
        # by an earlier group is not captured again by a later group.
        newLab[numpy.isin(labelArray, oldLabels)] = newLabel

    return newLab
            
# ----------------------------------------------------------------------------
# updateDict
# ----------------------------------------------------------------------------
def updateDict(op: str) -> None:
    """
    Update the global group table `table` from the selection widget `sl`.

    Args:
        op: 'N' (next)       -- add the current selection as a group keyed by
                                its first cluster ID; no-op when nothing is
                                selected.
            'D' (done)       -- add every remaining option as a final group
                                keyed by its first ID, then print all groups.
            'S' (start over) -- empty the table.
    """
    if op == 'N':

        selected = list(sl.value)

        if selected:
            table[selected[0]] = selected
        # print('Re-grouping : ', table)

    elif op == 'D':

        remaining = list(sl.options)

        if remaining:
            table[remaining[0]] = remaining

        print('Final Groups : ', table)

    elif op == 'S':

        table.clear()
    
# ----------------------------------------------------------------------------
# updateList
# ----------------------------------------------------------------------------
def updateList(old: list, out: list) -> list:
    """
    Return the items of `old` that are not in `out`, preserving their order.

    Used to drop the clusters just grouped from the remaining choices.
    """
    remaining = []

    for item in old:

        if item in out:
            continue

        remaining.append(item)

    return remaining


In [25]:
# ----------------------------------------------------------------------------
# Class ImageHelper
#
# TODO: add accessors using @property.
# ----------------------------------------------------------------------------
class ImageHelper(object):
    """
    Open a raster with GDAL, read the three bands used for display, and build
    (once) the leafmap tile layer that renders them.
    """

    # ------------------------------------------------------------------------
    # init
    # ------------------------------------------------------------------------
    def __init__(self,
                 inputFile: Path,
                 noDataValue: float,
                 redBandId: int = 1,
                 blueBandId: int = 1,
                 greenBandId: int = 1,
                ):
        """
        Args:
            inputFile: path of a raster GDAL can open.
            noDataValue: fallback no-data value, used only when the red band
                does not define one.  May be None.
            redBandId, blueBandId, greenBandId: GDAL band numbers (1-based)
                rendered as red, blue and green.

        Raises:
            RuntimeError: GDAL could not open inputFile.
            ValueError: the red band has no pixels other than no-data.
        """
        # TODO: check validity of the band numbers.
        self._inputFile: Path = inputFile
        self._redBandId: int = redBandId
        self._greenBandId: int = greenBandId
        self._blueBandId: int = blueBandId
        self._layer: TileLayer = None

        self._dataSet: gdal.Dataset = gdal.Open(str(self._inputFile))

        # Without gdal.UseExceptions(), a failed Open returns None.
        if self._dataSet is None:
            raise RuntimeError(f'GDAL could not open {self._inputFile}')

        # ---
        # Build overviews for rendering.  Cloud-optimized Geotiffs, CoGs, rely on
        # tiling and overview images. Before we bother creating CoGs, which
        # entails an additional image file, see how adding overviews to the input
        # image helps.
        # ---
        if self._dataSet.GetRasterBand(1).GetOverviewCount() == 0:
            self._dataSet.BuildOverviews()

        # ---
        # Read the bands.
        # ---
        self._redBand: numpy.ndarray = \
            self._dataSet.GetRasterBand(self._redBandId).ReadAsArray()

        self._greenBand: numpy.ndarray = \
            self._dataSet.GetRasterBand(self._greenBandId).ReadAsArray()

        self._blueBand: numpy.ndarray = \
            self._dataSet.GetRasterBand(self._blueBandId).ReadAsArray()

        # ---
        # Initialize the no-data value from the red band, falling back to the
        # argument.  GetNoDataValue() returns None when unset; compare with
        # None because 0 is a legitimate no-data value.
        # ---
        bandNoData = \
            self._dataSet.GetRasterBand(self._redBandId).GetNoDataValue()

        self._noDataValue = noDataValue if bandNoData is None else bandNoData

        # ---
        # Compute the minimum and maximum red-band values, ignoring no-data,
        # to help the renderer.  NaN never compares equal, so a NaN no-data
        # value needs isnan() rather than !=.
        # ---
        if self._noDataValue is None:
            valid = self._redBand

        elif numpy.isnan(self._noDataValue):
            valid = self._redBand[~numpy.isnan(self._redBand)]

        else:
            valid = self._redBand[self._redBand != self._noDataValue]

        if valid.size == 0:
            raise ValueError(f'Band {self._redBandId} of {self._inputFile} '
                             'contains only no-data pixels.')

        self._minValue: float = valid.min()
        self._maxValue: float = valid.max()

    # ------------------------------------------------------------------------
    # getLayer
    # ------------------------------------------------------------------------
    def getLayer(self) -> TileLayer:
        """
        Return the tile layer for this image, creating it on first call.

        The layer renders the red, green and blue bands between the red
        band's min and max, at 50% opacity, named after the input file.
        """
        if self._layer is None:

            bands = [self._redBandId, self._greenBandId, self._blueBandId]

            # NOTE(review): leafmap documents add_raster as a Map method;
            # confirm the installed leafmap also exposes a module-level
            # leafmap.add_raster returning a layer with `bounds`, which
            # ImageHelperCallback reads.  The commented-out alternative was
            # leafmap.get_local_tile_layer.
            # self._layer = leafmap.get_local_tile_layer(\
            self._layer = leafmap.add_raster(
                    str(self._inputFile),
                    band=bands,
                    vmin=self._minValue,
                    vmax=self._maxValue,
                    nodata=self._noDataValue,
                    opacity=0.5,
                    layer_name=self._inputFile.name)

        return self._layer
    

In [26]:
# ----------------------------------------------------------------------------
# Class ImageHelperCallback
#
# When a file is choosen, this is the call-back function for the FileDialog().
# This is necessary to provide a handle to the map, to which the chosen
# image is added.
#
# TODO: Add accessors using @property.
# TODO: Add user input for bands and no-data value.
# ----------------------------------------------------------------------------
class ImageHelperCallback(object):
    """
    FileChooser callback that adds the chosen image to a map.

    The chooser passes only itself to the callback, so this object carries
    the map the image is added to.
    """

    # ------------------------------------------------------------------------
    # __init__
    # ------------------------------------------------------------------------
    def __init__(self, aMap: leafmap.Map) -> None:

        self._map: leafmap.Map = aMap
        self._imageHelper: ImageHelper = None

    # ------------------------------------------------------------------------
    # addImage
    # ------------------------------------------------------------------------
    def addImage(self, chooser) -> None:
        """
        Load chooser.selected as an image (bands 3/2/1 as red/green/blue,
        fallback no-data -9999), add its layer to the map, zoom to it, and
        remove the file chooser from the map.
        """
        self._imageHelper = ImageHelper(inputFile=Path(chooser.selected),
                                        noDataValue=-9999.0,
                                        redBandId=3,
                                        greenBandId=2,
                                        blueBandId=1)

        imageLayer = self._imageHelper.getLayer()
        self._map.add(imageLayer)
        # self._map.add_raster(chooser.selected)

        # ---
        # fit_bounds only works for non-Mercator projections when the center
        # is set first.
        # ---
        bounds = imageLayer.bounds
        self._map.center = [bounds[0][0], bounds[0][1]]

        self._map.fit_bounds([[bounds[0][0], bounds[0][1]],
                              [bounds[1][0], bounds[1][1]]])

        # ---
        # Remove the file chooser.  Hiding it would be preferable.
        # ---
        chooserControls = [
            control for control in self._map.controls
            if isinstance(control, WidgetControl)
            and isinstance(control.widget, FileChooser)
        ]

        for control in chooserControls:
            self._map.remove_control(control)

        # NOTE(review): the return value is discarded; confirm against
        # leafmap.toolbar.inspector_gui whether it needs the map passed in to
        # have any visible effect.
        from leafmap import toolbar
        toolbar.inspector_gui()


In [27]:
# ----------------------------------------------------------------------------
# Class Clusterer
# ----------------------------------------------------------------------------
class Clusterer:
    """
    Placeholder class; it has no attributes or methods yet.
    """


In [28]:
# ---
# The only reason I am using leafmap and not ipyleaflet directly is to get
# the inspector tool under the wrench button on leafmap's map.
# ---
m = leafmap.Map(fullscreen_control=False,
                layers_control=True,
                search_control=False,
                draw_control=False,
                measure_control=False,
                scale_control=False,
                toolbar_control=True
               )

# ---
# Put a .tif file chooser on the map.  When a file is chosen,
# ImageHelperCallback.addImage adds it to the map as a layer and removes the
# chooser.  WidgetControl and FileChooser come from the imports cell.
# ---
fileDialog = FileChooser()
fileDialog.filter_pattern = ['*.tif']
inputCallBack = ImageHelperCallback(m)
fileDialog.register_callback(inputCallBack.addImage)
fileDialogControl = WidgetControl(widget=fileDialog, description='Input Image')
m.add_control(fileDialogControl)

m

Map(center=[20, 0], controls=(ZoomControl(options=['position', 'zoom_in_text', 'zoom_in_title', 'zoom_out_text…

## Update the labels.
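Select multiple cluster IDs by clicking them, or by using the arrow keys, while holding Shift, Control, or Command. Click **Next** to save the selection as a group, **Done** to put all remaining IDs into a final group, or **Start Over** to discard the groups.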

In [None]:
# ---
# Cluster the image chosen in the map above, so this cell does not depend on
# `labels` left over from a deleted cell.
# TODO: replace the private-attribute access once ImageHelper has accessors.
# ---
imageHelper = inputCallBack._imageHelper
assert imageHelper is not None, 'Choose an input image in the map above first.'

labels = getCluster(imageHelper._dataSet,
                    [imageHelper._redBand,
                     imageHelper._greenBand,
                     imageHelper._blueBand])

# ---
# Grouping widgets.  handleClick, observing bt, moves the IDs selected in sl
# into `table`.
# ---
opts = list(numpy.unique(labels))

sl = widgets.SelectMultiple(options=opts,
                            layout=widgets.Layout(height='200px', width='150px'))

bt = widgets.ToggleButtons(options=['Select:', 'Next', 'Done', 'Start Over'],
                           value='Select:')

output = widgets.Output()
display(sl, bt, output)
table = {}
bt.observe(handleClick, names='value')

## Edit the groups.
Edit the comma-separated cluster IDs in each group.  When finished, run the next cell.

In [None]:
# One row per group: the group's key cluster ID, and its member cluster IDs
# as comma-separated text that can be edited in the sheet.
strTab = {key: ', '.join(map(str, members)) for key, members in table.items()}

df = pandas.DataFrame(list(strTab.items()), columns=['Class', 'Cluster ID'])
sheet = ipysheet.from_dataframe(df)
sheet.column_width = [1, 5]
sheet

In [None]:
# ---
# Read the edited groups back.  Each row's "Cluster ID" text lists the old
# cluster IDs that become one new label.
# NOTE(review): the new label is the row's index in the sheet, not the
# "Class" value; confirm that is intended.
# ---
editedDf = ipysheet.to_dataframe(sheet)

finalClusters = {}

for rowKey, clusterText in editedDf['Cluster ID'].items():

    # Tolerate empty cells, non-string cells and blanks such as "1, 2, ".
    text = '' if pandas.isna(clusterText) else str(clusterText)
    pieces = [piece.strip() for piece in text.split(',')]
    finalClusters[int(rowKey)] = [int(piece) for piece in pieces if piece]

newLabels = relabel(labels, finalClusters)

## Review the updated map.

In [None]:
# ---
# Derive the inputs from the image chosen in the map above, so this cell does
# not depend on variables from deleted or commented-out cells.
# TODO: replace the private-attribute access once ImageHelper has accessors.
# ---
imageHelper = inputCallBack._imageHelper
assert imageHelper is not None, 'Choose an input image in the map above first.'

inFile = imageHelper._inputFile
outDirectory = inFile.parent
noDataValue = imageHelper._noDataValue
inLayer = imageHelper.getLayer()

# labelsToGeotiff reads the notebook global `ds` for georeferencing.
ds = imageHelper._dataSet

clusterMapFile = outDirectory / (inFile.stem + '-cluster-map' + inFile.suffix)
nDs = labelsToGeotiff(clusterMapFile, newLabels)

# ---
# The color range comes from the statistics computed on the cluster map just
# written (nDs).
# ---
newClusterLayer = leafmap.get_local_tile_layer(
                                     str(clusterMapFile),
                                     vmin=nDs.GetRasterBand(1).GetMinimum(),
                                     vmax=nDs.GetRasterBand(1).GetMaximum(),
                                     nodata=noDataValue,
                                     opacity=0.5,
                                     palette='viridis',
                                     layer_name='New Clusters')

n = leafmap.Map(fullscreen_control=False,
                layers_control=True,
                search_control=False,
                draw_control=False,
                measure_control=False,
                scale_control=False,
                google_map='SATELLITE',
                center=[inLayer.bounds[0][0], inLayer.bounds[0][1]])

n.add(newClusterLayer)

n.fit_bounds([[inLayer.bounds[0][0], inLayer.bounds[0][1]],
              [inLayer.bounds[1][0], inLayer.bounds[1][1]]])

display(n)

Markdown(f'<h3>The cluster map is at<br><br>{clusterMapFile}</h3>')