### Spectral Clustering Example

This is a modification of TomAugspurger's Spectral Clustering example (noted in https://github.com/dask/dask-ml/issues/151), adapted to process GeoTIFFs using xarray.

I also followed the example from https://www.planet.com/docs/api-quickstart-examples/step-2-download/ to download a GeoTIFF file.

In addition to `dask-ml`, we'll use `rasterio` and `xarray` to read the data and `matplotlib` to plot the figures. I'm just working on my laptop, so we could use either the threaded or the distributed scheduler. I'll use the distributed scheduler for its diagnostics.

In [1]:
%matplotlib inline

In [2]:
import sys
import rasterio
import matplotlib.pyplot as plt
import dask.array as da
from dask_ml.cluster import SpectralClustering
from dask.distributed import Client

In [3]:
import os

fname = 'Midwest_Mosaic.tif'
url = 'https://github.com/ebo/pangeo-tutorials/raw/master/data/Landsat_Mosaics/'+fname

#fname = 'QB02_20081221115153_1010010008F1C600_08DEC21115153-M1BS-500324873040_01_P002_u16rf32628_pan_pansh_stack.tif'
#url = 'file:/data1/new_stacked_images_meta/'+fname

if not os.path.exists(fname):
    import urllib.request
    print("downloading test image file '%s'"%fname)
    urllib.request.urlretrieve(url, fname)

In [4]:
#client = Client(memory_limit=24e9, n_workers=1, threads_per_worker=4)

# when I use processes=False I get a different error 
client = Client(memory_limit=24e9, processes=False) 
client

Client
  Scheduler: inproc://169.154.136.32/17577/1
  Dashboard: http://localhost:44432/status
Cluster
  Workers: 1
  Cores: 8
  Memory: 24.00 GB


In [5]:
# Read the profile and tags up front; both are reused when writing the result.
# `chunks` is not a rasterio.open option, so it is dropped here.
with rasterio.open(fname, 'r') as src:
    profile = src.profile
    tags = src.tags()
    
import xarray as xr
arr = xr.open_rasterio(fname, chunks={'band': 1, 'x': 2048, 'y': 2048})
arr

<xarray.DataArray (band: 3, y: 1227, x: 1343)>
dask.array<shape=(3, 1227, 1343), dtype=uint8, chunksize=(1, 1227, 1343)>
Coordinates:
  * band     (band) int64 1 2 3
  * y        (y) float64 4.986e+06 4.985e+06 4.985e+06 4.984e+06 4.984e+06 ...
  * x        (x) float64 1.939e+05 1.944e+05 1.948e+05 1.953e+05 1.957e+05 ...
Attributes:
    transform:   (193671.75, 456.0, 0.0, 4986075.0, 0.0, -456.0)
    crs:         +init=epsg:32615
    res:         (456.0, 456.0)
    is_tiled:    0
    nodatavals:  (nan, nan, nan)

In [6]:
arr = arr.astype(float)

# Rescale for the clustering algorithm
arr = (arr - arr.mean()) / arr.std()

In [7]:
# Subsample to the upper-left quadrant for viewing
#plt.imshow(arr[:2048, :2048].compute()) # distributed.protocol.core - CRITICAL - Failed to Serialize
# try this: plt.imshow(arr.compute()[:2048, :2048])

We'll reshape the image into the layout dask-ml and scikit-learn expect: `(n_samples, n_features)`, where `n_features` is 1 in this case. Then we'll persist it in memory. The dataset is still small at this point; the large one, which dask helps us manage, is the intermediate `n_samples x n_samples` array that spectral clustering operates on.
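
To get a feel for why that intermediate array is the large one, here is a rough back-of-the-envelope sketch. It assumes one sample per pixel of the 1227 x 1343 image shown above and a dense `float64` array; the helper name `dense_affinity_bytes` is made up for illustration and is not part of dask-ml.

```python
def dense_affinity_bytes(n_samples, itemsize=8):
    """Bytes needed to hold a dense n_samples x n_samples array.

    itemsize=8 assumes float64 entries (an assumption for this estimate).
    """
    return n_samples * n_samples * itemsize


# Image dimensions taken from the DataArray repr above (y: 1227, x: 1343)
height, width = 1227, 1343
n_pixels = height * width

# Size of a hypothetical dense pixel-by-pixel array, in terabytes
size_tb = dense_affinity_bytes(n_pixels) / 1e12
print(f"{n_pixels:,} samples -> about {size_tb:.1f} TB as dense float64")
```

Even for this modest image, a fully materialized pairwise array would be far beyond laptop memory, which is why the work has to be chunked or approximated.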

In [20]:
import numpy as np
#X = da.from_array(arr.reshape(-1, 1), chunks=100_000)
X = arr.reduce(np.ndarray.flatten)

TypeError: 'axis' is an invalid keyword argument for this function
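
The error above probably arises because `reduce` forwards an `axis` keyword to the function it is given, and `ndarray.flatten` does not accept one. A minimal sketch of the reshape itself follows, using a small NumPy stand-in instead of the real raster. The array sizes and the names `X_flat` and `X_pixels` are illustrative assumptions; the same `reshape` and `.T` calls also exist on dask arrays, so the idea should carry over to `arr.data`, though that is untested here.

```python
import numpy as np

# Stand-in for the raster: 3 bands, 4 rows, 5 columns (hypothetical sizes)
bands, height, width = 3, 4, 5
image = np.arange(bands * height * width, dtype=float).reshape(bands, height, width)

# Option A: every band value is its own sample, giving n_features == 1
X_flat = image.reshape(-1, 1)

# Option B: every pixel is a sample and the bands are its features
X_pixels = image.reshape(bands, -1).T

# With one label per pixel (Option B), labels map back to the image grid
labels = np.zeros(X_pixels.shape[0], dtype=np.uint8)  # placeholder labels
label_image = labels.reshape(height, width)

print(X_flat.shape, X_pixels.shape, label_image.shape)
```

Option A matches the `n_features` of 1 described above; Option B is an alternative that clusters pixels by their band values and yields a single-band label image.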

In [None]:
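# persist keeps X as a lazy collection backed by worker memory;
# client.compute would return a Future, which fit() cannot consume
X = client.persist(X)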

And we'll fit the estimator.

In [None]:
clf = SpectralClustering(n_clusters=4, random_state=0,
                         kmeans_params={'init_max_iter': 5})

In [None]:
%time clf.fit(X)

In [None]:
labels = clf.assign_labels_.labels_.compute()

c = labels.reshape(arr.shape)

fig, axes = plt.subplots(ncols=2, figsize=(12, 6))
# arr and c are both (band, y, x), so show the first band of each
axes[0].imshow(arr[0, :2500, :2500])
axes[1].imshow(c[0, :2500, :2500])

axes[0].set_title("Image")
axes[1].set_title("Clustered")

for ax in axes:
    ax.xaxis.set_visible(False)
    ax.yaxis.set_visible(False)

In [None]:
# save the results back as a GeoTIFF
outfile = os.path.splitext(fname)[0]+"_out.tif"

# create the output as a single band, not as the 3 that came in
profile['count'] = 1

with rasterio.open(outfile, 'w', **profile) as dst:
    # skip the band-only tags and propagate the image-level tags
    dst.update_tags(**tags)
    
    # output the classified array
    # c is (band, y, x); writing to band 1 needs a 2-D array, so take the first band
    dst.write(c[0].astype(profile['dtype']), 1)