# Image Clustering Notebook
This notebook uses K-means clustering for unsupervised classification of multi-band raster data <br>
GeoViews (with the Bokeh backend) serves as the visualization backend <br>

<br>
Basic workflow: <br>
<ul>
    <li> Read raster data </li>
    <li> Plot raster data </li>
    <li> Run clustering </li>
    <li> Plot result with raster </li>
    <li> Re-label clusters </li>
    <li> Plot updated result </li>
    <li> Write out clustering labels </li>
</ul>

<br>

<i>version 1.0  Date: Mar-2022</i>

<br>
<i>version 1.1  Date: Apr.-2022:</i>
<ul> 
    <li> Use geoViews for vis </li>
    <li> Add widget to select labels </li> 
</ul>


In [None]:
import rasterio as rio
import rioxarray as rxr
import xarray as xr
from sklearn import cluster
import numpy as np
import geopandas as gpd

from rasterio.plot import reshape_as_image
import geoviews as gv

import holoviews.operation.datashader as hd

import ipywidgets as widgets
import panel as pn
pn.extension()

In [None]:
import hvplot.xarray

In [None]:
import warnings
# Suppress every warning for the rest of the session (keeps cell output clean);
# comment this out when debugging.
warnings.filterwarnings("ignore")

In [None]:
#using bokeh backend
gv.extension('bokeh')

In [None]:
from cartopy import crs
from joblib import cpu_count

In [None]:
import matplotlib.pyplot as plt


In [None]:
def get_cluster(raster_file, bands=None, n_cluster=5):
    """Cluster the pixels of a raster file with mini-batch K-means.

    Parameters
    ----------
    raster_file :
        Anything accepted by ``rasterio.open`` (typically a path to the raster).
    bands :
        Optional band indexes forwarded to the dataset's ``read``; when falsy
        (``None`` or empty) every band is read.
    n_cluster :
        Number of clusters requested from K-means.

    Returns
    -------
    2-D array of cluster labels with the same shape as ``img.shape[:2]``,
    where ``img`` is the band stack after ``reshape_as_image``.
    """
    # The context manager guarantees the dataset handle is closed, even if
    # reading fails (previously the handle was never closed).
    with rio.open(raster_file) as src:
        arr = src.read(bands) if bands else src.read()

    # One row per pixel, one column per band.
    img = reshape_as_image(arr)
    pixels = img.reshape(-1, img.shape[-1])

    # parameters for KMeans
    params = {
        'n_clusters': n_cluster,
        'random_state': 0,                 # fixed seed -> repeatable labels
        'batch_size': 256 * cpu_count(),   # scale batch size with available cores
    }
    model = cluster.MiniBatchKMeans(**params).fit(pixels)

    # Fold the flat label vector back onto the image grid.
    return model.labels_.reshape(img.shape[:2])


def relabel(lab_arr, lookup):
    """Return a copy of ``lab_arr`` with labels renamed/merged via ``lookup``.

    Parameters
    ----------
    lab_arr :
        NumPy array of labels; it is not modified.
    lookup :
        Dict mapping ``new_label -> sequence of old labels``. Every element of
        ``lab_arr`` found in the sequence is replaced by ``new_label``.

    Returns
    -------
    New array of the same shape as ``lab_arr``.

    Notes
    -----
    Membership is always tested against the *original* labels, so one entry's
    output is never re-mapped by a later entry (no cascading, no dependence on
    dict order). If an old label appears in several entries, the last one wins.
    """
    new_lab = lab_arr.copy()
    for k, v in lookup.items():
        # Identity mapping (k -> [k]) needs no work. `and` is required here:
        # the former `&` bound tighter than `==` and turned the test into
        # `len(v) == (1 & k) == v[0]`, silently skipping entries like {3: [1]}.
        if len(v) == 1 and k == v[0]:
            continue
        new_lab = np.where(np.isin(lab_arr, v), k, new_lab)
    return new_lab

def np2ds(x, y, arr):
    """Wrap a 2-D array in a gv.Dataset and report its value range.

    x   : vector of x-coordinates (longitude)
    y   : vector of y-coordinates (latitude)
    arr : 2-D array of values on the x/y grid
          (the original note gives its shape as (x, y) -- TODO confirm axis order)

    Returns the tuple (min(arr), max(arr), dataset), where the dataset has the
    key dimensions 'longitude' and 'latitude' and the value dimension 'class'.
    """
    lowest = arr.min()
    highest = arr.max()
    dataset = gv.Dataset((x, y, arr), ['longitude', 'latitude'], 'class')
    return lowest, highest, dataset

In [None]:
from datashader.utils import ngjit

@ngjit
def normalize_data(agg):
    """Linearly rescale a 2-D array to the 0..255 range.

    The minimum and maximum are taken with ``np.nanmin`` / ``np.nanmax`` so
    NaN entries do not affect the scaling. The result has the same shape and
    dtype as ``agg`` (it is allocated with ``np.zeros_like``); callers cast it
    to uint8 themselves.

    A constant input (max == min) returns all zeros instead of dividing by
    zero.

    NOTE(review): NaN pixels go through the same arithmetic as valid ones; what
    ``round`` yields for them is not pinned down here -- confirm before relying
    on it.
    """
    out = np.zeros_like(agg)
    min_val = np.nanmin(agg)
    max_val = np.nanmax(agg)
    range_val = max_val - min_val

    # Nothing to stretch when every valid pixel has the same value; bail out
    # before the division below.
    if range_val == 0:
        return out

    rows, cols = agg.shape
    # Optional sigmoid contrast stretch kept from the original (disabled):
    #    c = 40
    #    th = .125
    #    norm = 1/(1+np.exp(c*(th-norm)))
    for i in range(rows):
        for j in range(cols):
            norm = (agg[i, j] - min_val) / range_val
            out[i, j] = round(norm * 255.0)
    return out

In [None]:
# path to raster
#rfile = "landsat.tif"
rfile = "HLS_3000_06-01_09-15_2014_2014.tif"

# load raster with rioxarray (can be replaced by other libs)
ds = rxr.open_rasterio(rfile)

# show raster profile & spatial reference
print(ds)
print(ds.rio.crs)

## Show raster map

In [None]:
# band numbers used for the red, green and blue channels
rb, gb, bb = 3, 2, 1

# pull each selected band out of the dataset as a plain numpy array
r, g, b = (np.array(ds.sel(band=idx).squeeze()) for idx in (rb, gb, bb))

# extra step for HLS data: its metadata doesn't mark -9999 as undefined,
# so negative values are replaced with NaN by hand
r, g, b = (np.where(chan < 0, np.nan, chan) for chan in (r, g, b))

# stretch every channel to 0..255 and store it as uint8
r, g, b = (normalize_data(chan).astype(np.uint8) for chan in (r, g, b))


In [None]:
# set basemap: Esri imagery tiles, sized 800 x 800
# other available basemap samples: https://geoviews.org/user_guide/Working_with_Bokeh.html
base_map = gv.tile_sources.EsriImagery.opts(width=800, height=800)

In [None]:
# set options for raster map (plot width/height shared by the RGB layers below)
img_opts = dict(width=600, height=600)


### Trying to show raster RGB band


### op1: reproject to AlbersEqualArea
Observation: the rendered image looks odd around its bounds

In [None]:
# RGB layer on the raster's native x/y coordinates, declared as Albers Equal Area.
# NOTE(review): crs.AlbersEqualArea() is built with no arguments, so it may not match
# ds.rio.crs -- confirm.
raster_layer = gv.RGB((ds.x, ds.y, r, g, b), vdims=list('RGB'), crs=crs.AlbersEqualArea()).opts(**img_opts)
raster_layer

### op2: reproject coords (x/y -> lat/lon)


In [None]:
# Build points from the raster's x and y coordinate vectors (in the raster CRS),
# then reproject them to EPSG:4326 (lon/lat).
# NOTE(review): points_from_xy presumably pairs x[i] with y[i], which would require
# ds.x and ds.y to have the same length -- confirm for non-square rasters.
geom = gpd.points_from_xy(ds.x.values, ds.y.values, crs=ds.rio.crs)
geom_reproj = geom.to_crs(4326)

In [None]:
# RGB layer on the reprojected lon/lat coordinates (no explicit crs argument is passed here)
raster_layer_2 = gv.RGB((geom_reproj.x, geom_reproj.y, r, g, b), vdims=list('RGB')).opts(**img_opts)

In [None]:
# display the op2 (reprojected) RGB layer
raster_layer_2

In [None]:
# overlay the op1 RGB layer with the basemap
# (web rendering is slow for large images)
pn.Column(raster_layer*base_map)

## Run K-Means Clustering for Image
Users can run the clustering on a subset of bands by passing a list of band indexes; if none is specified, all bands are included.

In [None]:
## TODO: adding method to remove/fill missing data of image

In [None]:
%%time
# get_cluster takes the path to the raster file (not a dataset object or np.array)
# and an (optional) list of bands as inputs
labels = get_cluster(rfile, bands=(3,2,1), n_cluster=30)
labels.shape

## Show cluster map


In [None]:

# Wrap the label array in a Dataset.
#
# Note: the image projection matters here. The reprojected lon/lat
# coordinates (geom_reproj) are used rather than the raster's native
# ds.x / ds.y, because the layer below is declared as PlateCarree.
minv, maxv, cls_ds = np2ds(geom_reproj.x, geom_reproj.y, labels)

# display options for the cluster map; the colour limits span the label range
cls_opts = {
    "width": 600,
    "height": 600,
    "cmap": "Category20b",
    "colorbar": False,
    "clim": (minv, maxv),
    "tools": ["hover"],
    "active_tools": ['wheel_zoom'],
}
cluster_layer = gv.Image(cls_ds, crs=crs.PlateCarree()).opts(**cls_opts)

# To preview on the basemap:
#pn.Column(cluster_layer.opts(**cls_opts)*base_map.opts(height=600, width=600))

In [None]:
# display the cluster map on its own
cluster_layer

## Display maps side-by-side 


In [None]:
import time
# NOTE(review): start timestamp; nothing visible in this notebook reads `s` afterwards
s = time.time()
# two panels: clusters overlaid with the RGB raster and basemap, and clusters overlaid with the basemap only
pn.Column(cluster_layer*raster_layer*base_map + cluster_layer*base_map)


## Update labels 
Users can aggregate classes into new groups through a look-up table

In [None]:
# Look-up table built through the widget: new_label -> list of original labels
tab = {}
def update_dict(op):
    """Update the global look-up table ``tab`` from the selection widget ``sl``.

    op == 'N' : store the currently selected labels as one group, keyed by the
                first selected label. With an empty selection nothing is stored.
    op == 'D' : store all labels still listed in the widget as a final group,
                keyed by the first of them (skipped when none are left).
    op == 'S' : clear the table to start over.

    Any other value is ignored. The resulting table is printed each time.
    """
    if op == 'N':
        selected = list(sl.value)
        # Guard: indexing an empty selection used to raise IndexError.
        if not selected:
            print("Nothing selected; no group added.")
            return
        tab[selected[0]] = selected
        print("Re-grouping : ", tab)
    elif op == 'D':
        remaining = list(sl.options)
        if remaining:
            tab[remaining[0]] = remaining
        print("Final Groups : ", tab)
    elif op == 'S':
        tab.clear()
        print("Start Over : ", tab)
    
def update_list(old, out):
    """Return the items of ``old`` that do not occur in ``out``, keeping their order."""
    kept = []
    for item in old:
        if item in out:
            continue
        kept.append(item)
    return kept

print("Multiple values can be selected with shift and/or ctrl (or command) pressed and mouse clicks or arrow keys.")
# one selectable entry per distinct cluster label
opts = list(np.unique(labels))
# multi-select list of the labels that have not been grouped yet
sl = widgets.SelectMultiple(options=opts, layout=(widgets.Layout(height='200px', width='150px')))
# action buttons; 'Select:' is the state the click handler resets the buttons to
bt = widgets.ToggleButtons(options=['Select:', 'Next', 'Done', 'Start Over'], value='Select:')
# output area used by the click handler for its printed messages
output = widgets.Output()
display(sl, bt, output)
def handle_click(change):
    """React to a change of the toggle buttons.

    "Next"       : record the current selection as a group, drop the selected
                   labels from the list, and reset the buttons to 'Select:'.
    "Done"       : record the labels still listed as the final group.
    "Start Over" : restore the full label list, clear the table, and reset
                   the buttons to 'Select:'.

    Everything runs inside the ``output`` widget context.
    """
    with output:
        action = change.new
        if action == "Next":
            # compute the remaining labels before the selection is recorded
            remaining = update_list(list(sl.options), list(sl.value))
            update_dict("N")
            sl.options = remaining
            bt.value = 'Select:'
        elif action == "Done":
            update_dict("D")
        elif action == "Start Over":
            sl.options = opts
            update_dict("S")
            bt.value = 'Select:'
            

            #sl.value = n[-1]

# run handle_click whenever the toggle buttons' value changes
bt.observe(handle_click, names='value')

In [None]:
# wait for the user to confirm before applying the look-up table `tab`
input("> Press Enter to Continue...\n")

# This cell updates clusters with new labels & converts labels into sequential ones
new_labels = relabel(labels, tab)

# second look-up table: position of each remaining label among the unique values -> that label
t = {}
for i, v in enumerate(np.unique(new_labels)):
    t[i] = [v]

# after this step the labels are 0..n-1 with no gaps
new_labels_seq = relabel(new_labels, t)
print("New Classes", np.unique(new_labels_seq))

In [None]:
# load data into a Dataset
# Use the reprojected lon/lat coordinates (geom_reproj), as the first cluster
# map does: the layer below is declared as PlateCarree, so the raster's native
# projected ds.x / ds.y would not match that declaration.
minv, maxv, cls_ds_new = np2ds(geom_reproj.x, geom_reproj.y, new_labels_seq)

# set options for cluster map (colour limits span the new label range)

cls_opts = dict(
    width=600, 
    height=600, 
    cmap="Category20b",
    colorbar=False,
    clim = (minv, maxv),   
    tools=["hover"], active_tools=['wheel_zoom']
)

cluster_layer_new = gv.Image(cls_ds_new, crs=crs.PlateCarree()).opts(**cls_opts)


In [None]:
# two panels with the regrouped clusters: overlaid with the RGB raster and basemap, and with the basemap only
pn.Column(cluster_layer_new*raster_layer*base_map+cluster_layer_new*base_map)

## Export labels to geoTiff

In [None]:
# path to output file
save_path = "./landsat_cluster.tif"

# labels to write out (the regrouped, sequential labels)
rslt = new_labels_seq

In [None]:
# Export is disabled: the code below sits inside a string literal.
# Remove the surrounding triple quotes to actually write the file.
'''
# using profile from input raster
with rio.open(rfile) as src:
    profile = src.meta
    # update number of bands
    profile.update(count=1)
    # NOTE(review): profile keeps the input raster's dtype/nodata; confirm they suit the label array
    
    with rio.open(save_path, 'w', **profile) as dst:
        # write through the opened output dataset `dst` (not the xarray object `ds`)
        dst.write(rslt, indexes = 1)
'''

In [None]:
# Disabled check (string literal): reopens the written file and prints its
# metadata and the unique values it contains.
'''
# (optional) double check write-out labels
with rio.open(save_path) as ck:
    print(ck.meta)
    print(np.unique(ck.read()))
'''