In [None]:
import os
from glob import glob
import rasterio as rio
import numpy as np
from tqdm import tqdm,tqdm_notebook
# `display`/`HTML` are public in IPython.display; importing `display` from
# IPython.core.display is deprecated (since IPython 7.14) and emits a warning.
from IPython.display import display, HTML
# Widen the notebook container to 90% of the window.
# NOTE(review): the `.container` selector targets the classic Notebook UI;
# presumably it has no effect in JupyterLab — confirm if that matters.
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
# Mosaics live in ../images relative to the notebook's working directory.
# os.path.abspath() strips any trailing separator, so plain string
# concatenation would yield ".../parentimages"; os.path.join inserts it.
mosaic_dir = os.path.join(os.path.abspath(".."), "images")
# glob() order is filesystem-dependent; sort so that mosaics[0] is reproducible.
mosaics = sorted(glob(os.path.join(mosaic_dir, "*.tif")))

In [None]:
from sklearn.cluster import MiniBatchKMeans
from rasterio.plot import reshape_as_image
from matplotlib.pyplot import imshow,show

def compute_rws(stack_arr,rws_type):
    """Compute a binary water mask from a 10-band raster stack.

    Bands are indexed assuming the order
    ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12'] (the ``band_names``
    list used in ``extract_rws``). Digital numbers are divided by 10000 to get
    reflectance (assumes that scale factor for the input product — confirm).

    Parameters
    ----------
    stack_arr : numpy.ndarray
        Array of shape (10, rows, cols).
    rws_type : str
        'WIW' : water where B8A <= 0.1804 and B12 <= 0.1131.
        'RWS' : water where MNDWI = (B3 - B11) / (B3 + B11) > 0.3 and
                0 < min(B3, B4, B8) < 0.15.

    Returns
    -------
    numpy.ndarray
        uint16 array of shape (rows, cols): 1 for water, 0 otherwise. Pixels
        where any band equals 0 (treated as nodata) are always 0.

    Raises
    ------
    ValueError
        If ``rws_type`` is not 'WIW' or 'RWS'.
    """
    # 0/0 in MNDWI yields NaN, which simply fails the threshold test. Silence
    # those warnings locally instead of via np.seterr, which would change the
    # caller's global numpy error state.
    with np.errstate(divide='ignore', invalid='ignore'):
        if rws_type == 'WIW':
            b8a = stack_arr[7].astype(float) / 10000
            b12 = stack_arr[9].astype(float) / 10000
            water = (b8a <= 0.1804) & (b12 <= 0.1131)
        elif rws_type == 'RWS':
            b3 = stack_arr[1].astype(float) / 10000
            b11 = stack_arr[8].astype(float) / 10000
            mndwi = (b3 - b11) / (b3 + b11)
            # Per-pixel minimum of B3, B4 and B8 reflectance.
            mgrn = (stack_arr[[1, 2, 6]].astype(float) / 10000).min(axis=0)
            water = (mndwi > 0.3) & (mgrn > 0) & (mgrn < 0.15)
        else:
            raise ValueError(
                f"rws_type must be 'WIW' or 'RWS', got {rws_type!r}")

    # A zero in any band marks the pixel as nodata.
    nodata = (stack_arr == 0).any(axis=0)
    return np.where(water & ~nodata, 1, 0).astype(np.uint16)
        
def extract_rws(stack_arr,rws,bands=('B2','B3','B4')):
    """Select bands from a stack and zero every pixel outside a water mask.

    Parameters
    ----------
    stack_arr : numpy.ndarray
        Array of shape (10, rows, cols) in the band order
        ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12'].
    rws : numpy.ndarray
        Mask of shape (rows, cols); pixels equal to 1 are kept.
    bands : iterable of str
        Band names to select, in output order. Defaults to B2, B3, B4.

    Returns
    -------
    numpy.ndarray
        Array of shape (len(bands), rows, cols) with the selected bands where
        ``rws == 1`` and 0 elsewhere.

    Raises
    ------
    ValueError
        If any requested band name is not in the band list.
    """
    band_names = ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12']
    unknown = [name for name in bands if name not in band_names]
    if unknown:
        raise ValueError(f"unknown band name(s) {unknown}; expected any of {band_names}")
    band_pos = [band_names.index(name) for name in bands]
    img = stack_arr[band_pos]
    # The (rows, cols) mask broadcasts across the leading band axis.
    return np.where(rws == 1, img, 0)

def compute_cluster(rws_rgb_img,k=4):
    """Label every pixel of a band-first image with a mini-batch k-means cluster.

    Each pixel's band values form one sample. The model is fitted with a fixed
    ``random_state`` of 42, so repeated runs on the same input give the same labels.

    Parameters
    ----------
    rws_rgb_img : numpy.ndarray
        Array of shape (bands, rows, cols).
    k : int
        Number of clusters.

    Returns
    -------
    numpy.ndarray
        uint16 label image of shape (rows, cols). Label numbering is arbitrary.
    """
    n_bands, n_rows, n_cols = rws_rgb_img.shape
    # (bands, rows, cols) -> (rows, cols, bands) -> one row per pixel.
    pixel_table = reshape_as_image(rws_rgb_img).reshape(-1, n_bands)
    model = MiniBatchKMeans(
        n_clusters=k,
        random_state=42,
        max_iter=10,
        batch_size=10000,
        reassignment_ratio=0,
    )
    model.fit(pixel_table)
    label_grid = model.labels_.reshape(n_rows, n_cols)
    return label_grid.astype(rio.uint16)

In [None]:
%%time
band_names = ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12']
# Bands used for the per-cluster similarity statistics.
nws_bands = ['B2', 'B3', 'B4', 'B8', 'B11', 'B12']
nws_idx = [band_names.index(name) for name in nws_bands]
n_nws = len(nws_idx)

if not mosaics:
    raise FileNotFoundError(f"No .tif mosaics found in {mosaic_dir}")

with rio.open(mosaics[0]) as src:
    profile = src.profile.copy()
    img = src.read()

rws = compute_rws(img,'RWS')
rws_rgb_img = extract_rws(img,rws)
cluster_img = compute_cluster(rws_rgb_img,9)

# extract_rws zeroes every non-water pixel, so those pixels share one feature
# vector and fall into a single k-means cluster. K-means label numbering is
# arbitrary, so identify that background label explicitly instead of assuming
# it is label 0.
background = rws == 0
if background.any():
    background_label = np.bincount(cluster_img[background]).argmax()
    labels = np.setdiff1d(np.unique(cluster_img), [background_label])
else:
    labels = np.unique(cluster_img)

# Loop-invariant: the selected bands of the whole image as (bands, pixels).
raw_img = img[nws_idx].reshape(n_nws, -1)
mnws_sum = np.zeros(img.shape[1:], dtype=np.float64)
# Single-pixel clusters have std == 0, so the z-score below can be inf/nan.
# Silence those numpy warnings for this computation only.
with np.errstate(divide='ignore', invalid='ignore'):
    for label in labels:
        region = (cluster_img == label).ravel()
        region_img = np.where(region, raw_img, 0)
        # Zeros (outside the region) are excluded from the band statistics.
        region_vals = np.where(region_img != 0, region_img, np.nan)
        band_means = np.nanmean(region_vals, axis=1, keepdims=True)
        band_std = np.nanstd(region_vals, axis=1, keepdims=True)

        z = np.subtract(raw_img, band_means, dtype=np.float32) / band_std
        nws_img = np.sqrt(np.square(z).sum(0) / n_nws)
        mnws_sum += nws_img.reshape(img.shape[1:])

mnws_img = mnws_sum.astype(np.float32)
# A zero in any band marks the pixel as nodata; write 0 there (the nodata value).
nodata_mask = (img == 0).any(axis=0)
mnws_img_clip = np.where(nodata_mask, np.float32(0), mnws_img)[np.newaxis]

profile.update({'dtype':np.float32,'nodata':0,'count':1})
with rio.open('mnws.tif','w',**profile) as dst:
    dst.write(mnws_img_clip)
