In [None]:
import os
from glob import glob
import rasterio as rio
import numpy as np
from tqdm import tqdm,tqdm_notebook
from sklearn.cluster import MiniBatchKMeans
from rasterio.plot import reshape_as_image


# `display` and `HTML` are part of the public `IPython.display` API; importing
# them from `IPython.core.display` is deprecated.
from IPython.display import display, HTML
# Inject a CSS rule that sets elements with class "container" to 90% width.
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:

def compute_rws(stack_arr,rws_type):
    """Keep only the pixels of a band stack that a threshold rule flags.

    Parameters
    ----------
    stack_arr : numpy.ndarray
        Band-first stack indexed as (band, row, col). Values are divided by
        10000 before thresholding. The hard-coded band indices below
        presumably follow the B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12 order used
        elsewhere in this notebook -- TODO confirm against the input files.
    rws_type : str
        Which rule to apply:

        * ``'WIW'`` -- band 7 <= 0.1804 and band 9 <= 0.1131.
        * ``'RWS'`` -- normalised difference of bands 1 and 8 > 0.3, and the
          per-pixel minimum of bands 1, 2 and 6 strictly between 0 and 0.15.

    Returns
    -------
    numpy.ndarray
        Array with the shape of ``stack_arr`` holding the original values
        where the rule is met and no band is 0 at that pixel; 0 elsewhere.

    Raises
    ------
    ValueError
        If ``rws_type`` is neither ``'WIW'`` nor ``'RWS'``.
    """
    # Fail fast with a clear message instead of hitting an unbound local later.
    if rws_type not in ('WIW', 'RWS'):
        raise ValueError(f"unknown rws_type {rws_type!r}; expected 'WIW' or 'RWS'")

    # NOTE: np.seterr changes NumPy's floating-point error handling for the
    # whole process, not just this call. Kept as in the original because later
    # cells may rely on divide/invalid warnings being silenced.
    np.seterr(divide='ignore', invalid='ignore')

    if rws_type == 'WIW':
        b8a = stack_arr[7].astype(np.float32)/10000
        b12 = stack_arr[9].astype(np.float32)/10000
        flagged = (b8a <= 0.1804) & (b12 <= 0.1131)
    else:
        b3 = stack_arr[1].astype(np.float32)/10000
        b11 = stack_arr[8].astype(np.float32)/10000
        # 0/0 yields NaN here; NaN > 0.3 is False, so such pixels are not flagged.
        mndwi = (b3 - b11)/(b3 + b11)
        mgrn = (stack_arr[[1,2,6]].astype(np.float32)/10000).min(0)
        flagged = (mndwi > 0.3) & (mgrn > 0) & (mgrn < 0.15)

    # A pixel is treated as nodata as soon as any band is exactly 0.
    valid = (stack_arr != 0).all(0)
    # The 2-D mask broadcasts across the band axis of stack_arr.
    return np.where(valid & flagged, stack_arr, 0)
        
def compute_cluster(rws_rgb_img,k=4):
    """Assign every pixel of a band-first image to a mini-batch k-means cluster.

    Parameters
    ----------
    rws_rgb_img : numpy.ndarray
        Image indexed as (band, row, col); each pixel becomes one sample with
        one feature per band.
    k : int, optional
        The model is fitted with ``k + 1`` clusters.

    Returns
    -------
    numpy.ndarray
        (row, col) array of cluster labels cast to ``rio.uint16``.
    """
    # NOTE(review): the extra cluster is presumably meant to absorb the
    # zeroed-out (masked) pixels, but k-means label numbering is arbitrary, so
    # nothing here guarantees that cluster gets label 0 -- confirm with callers.
    n_bands, n_rows, n_cols = rws_rgb_img.shape

    # One row per pixel, one column per band.
    pixel_rows = reshape_as_image(rws_rgb_img).reshape(-1, n_bands)

    clusterer = MiniBatchKMeans(
        n_clusters=k+1,
        random_state=42,
        max_iter=10,
        batch_size=10000,
        reassignment_ratio=0,
    )
    clusterer.fit(pixel_rows)

    label_img = clusterer.labels_.reshape(n_rows, n_cols)
    return label_img.astype(rio.uint16)

def compute_mnws(img,cluster_img,bands=('B2','B3','B4','B8','B11','B12')):
    """Per-pixel minimum, over cluster labels, of an RMS z-score distance.

    For every label ``1..cluster_img.max()`` the mean and standard deviation
    of each selected band are computed from that label's valid pixels. Each
    pixel's distance to the label is the root-mean-square, across bands, of
    ``(value - mean) / std``. The result is the smallest distance over all
    usable labels.

    Parameters
    ----------
    img : numpy.ndarray
        Band-first stack indexed as (band, row, col), with bands in the order
        B2,B3,B4,B5,B6,B7,B8,B8A,B11,B12.
    cluster_img : numpy.ndarray
        (row, col) integer label image. Label 0 is never used as a reference.
    bands : sequence of str, optional
        Names of the bands to use; each must be one of the names listed above.

    Returns
    -------
    numpy.ndarray
        Array of shape (1, row, col) with the minimum distance; 0 wherever any
        selected band of ``img`` is 0.

    Raises
    ------
    ValueError
        If ``bands`` is empty or contains an unknown name, or if no label has
        usable statistics.
    """
    band_names = ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12']
    band_pos = [band_names.index(band) for band in bands]
    n_bands = len(band_pos)
    if n_bands == 0:
        raise ValueError("bands must contain at least one band name")

    # A pixel is nodata as soon as any selected band is exactly 0.
    selected = img[band_pos]
    valid = (selected != 0).all(0)
    raw_img = np.where(valid, selected, 0)
    # Loop-invariant: flatten once instead of once per label.
    flat_raw = raw_img.reshape(n_bands, -1)

    mnws_img = None
    for label in range(1, int(cluster_img.max()) + 1):
        members = (cluster_img == label) & valid
        # A label with no valid pixels would give NaN statistics, and NaN would
        # then propagate through the minimum and wipe out the whole result.
        if not members.any():
            continue

        # calculate band stats
        band_means = np.array([np.mean(band[members], dtype=np.float32)
                               for band in raw_img]).reshape(n_bands, -1)
        band_std = np.array([np.std(band[members], dtype=np.float32)
                             for band in raw_img]).reshape(n_bands, -1)
        # A zero (or NaN) std would produce inf/NaN distances; skip such labels.
        if not np.all(band_std > 0):
            continue

        # calculate nws
        nws = ((((flat_raw - band_means) / band_std) ** 2).sum(0) / n_bands) ** 0.5
        nws = nws.reshape(img.shape[1], img.shape[2])

        # Running minimum avoids holding one full-size image per label.
        mnws_img = nws if mnws_img is None else np.minimum(mnws_img, nws)

    if mnws_img is None:
        raise ValueError("no cluster label >= 1 has valid pixels with non-zero spread")

    return np.array([np.where(valid, mnws_img, 0)])

def hot(blue_band,red_band):
    """Return ``blue - 0.5 * red`` after scaling both bands by 1/10000.

    Parameters
    ----------
    blue_band, red_band : numpy.ndarray
        Arrays of matching (or broadcastable) shape; each is cast to float32
        and divided by 10000 before being combined.

    Returns
    -------
    numpy.ndarray
        float32 array holding the scaled blue band minus half the scaled red
        band.
    """
    # NOTE(review): the name suggests a haze-related index -- confirm the
    # intended formula and the 0.5 coefficient with the author.
    scale = 10000
    blue = blue_band.astype(np.float32) / scale
    red = red_band.astype(np.float32) / scale
    return blue - red * 0.5

In [None]:
# Input .tif files are looked up in an "images" folder inside the parent of the
# current working directory. os.path.join supplies the path separator that
# plain string concatenation would leave out.
mosaic_dir = os.path.join(os.path.abspath(".."), "images")
# glob returns matches in arbitrary order; sorting makes positional picks such
# as mosaics[-1] reproducible between runs and machines.
mosaics = sorted(glob(os.path.join(mosaic_dir, "*.tif")))

In [None]:
%%time
band_names = ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12']

# Fail with a clear message instead of an IndexError on mosaics[-1].
if not mosaics:
    raise FileNotFoundError(f'no .tif mosaics found in {mosaic_dir}')

# Read everything needed from the source file, then close it before the heavy
# computation so the dataset handle is not held open longer than necessary.
with rio.open(mosaics[-1]) as src:
    profile = src.profile.copy()
    img = src.read()

# Keep the first three bands of the masked stack as clustering features
# (B2, B3, B4 if the file follows the band_names order above -- TODO confirm).
rws_rgb_img = compute_rws(img,'RWS')[[0,1,2]]

cluster_img = compute_cluster(rws_rgb_img,k=8)
# compute_mnws is the function defined in this notebook; the previously
# referenced compute_mnws2 is not defined anywhere visible and would raise
# NameError on a fresh kernel.
mnws_img = compute_mnws(img,cluster_img)

# Reuse the source georeferencing, but describe the single-band result.
profile.update({'dtype':mnws_img.dtype,'nodata':0,'count':mnws_img.shape[0]})
with rio.open('mnws_last.tif','w',**profile) as dst:
    dst.write(mnws_img)
