In [None]:
import os
from glob import glob
import rasterio as rio
import numpy as np
from tqdm import tqdm,tqdm_notebook
from rasterio.plot import reshape_as_image,reshape_as_raster
import re
import pandas as pd
import json
import geopandas as gpd

import seaborn as sns
from rasterio.features import rasterize
import matplotlib.pyplot as plt
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

from sklearn.cluster import MiniBatchKMeans
from rasterio.features import sieve
from rasterio.plot import reshape_as_image
from affine import Affine

In [None]:
from scipy.ndimage.filters import uniform_filter
from scipy.ndimage.measurements import variance
from skimage.filters import threshold_otsu,threshold_multiotsu

def lee_filter(img, size):
    """Apply a Lee speckle filter to a single 2-D image.

    Each output pixel is a blend of the local mean over a ``size x size``
    window and the original pixel value. The blend weight is
    ``local_var / (local_var + global_var)``, so pixels in locally smooth
    areas move towards the local mean.

    Parameters
    ----------
    img : array_like
        2-D image (one band). Integer input is promoted to float64 so the
        squared values cannot overflow and the moving-window means are not
        truncated to integers.
    size : int
        Side length of the square moving window.

    Returns
    -------
    numpy.ndarray
        Filtered image with the same shape as ``img``. Where the local plus
        global variance is zero (e.g. a constant or all-zero image), the
        local mean is returned instead of NaN.
    """
    img = np.asarray(img)
    if not np.issubdtype(img.dtype, np.floating):
        img = img.astype(np.float64)

    img_mean = uniform_filter(img, (size, size))
    img_sqr_mean = uniform_filter(img ** 2, (size, size))
    # E[x^2] - E[x]^2 can come out slightly negative through floating-point
    # cancellation; a variance can never be negative.
    img_variance = np.maximum(img_sqr_mean - img_mean ** 2, 0)

    overall_variance = variance(img)

    denom = img_variance + overall_variance
    # 0/0 would yield NaN for flat/empty tiles; use weight 0 (local mean) there.
    img_weights = np.divide(
        img_variance, denom, out=np.zeros_like(denom), where=denom > 0
    )
    return img_mean + img_weights * (img - img_mean)

def compute_s1_rws(band):
    """Flag the pixels of ``band`` that fall in the lowest multi-Otsu class.

    ``threshold_multiotsu`` splits the value range into three classes; pixels
    that ``np.digitize`` assigns to class 0 are marked 1, every other pixel
    is marked 0.

    Returns
    -------
    numpy.ndarray
        uint8 mask with the same shape as ``band``.
    """
    class_bounds = threshold_multiotsu(band, classes=3)
    class_index = np.digitize(band, bins=class_bounds)
    lowest_class = class_index == 0
    return lowest_class.astype(rio.uint8)

def compute_s1_cluster(s1_rws, k=4):
    """Assign every pixel of a band-first stack to a MiniBatchKMeans cluster.

    Parameters
    ----------
    s1_rws : numpy.ndarray
        Stack of shape (bands, rows, cols); the band values of each pixel
        form one feature vector.
    k : int, default 4
        The model is fitted with ``k + 1`` clusters (presumably ``k`` clusters
        plus one for the zeroed-out pixels — TODO confirm).

    Returns
    -------
    numpy.ndarray
        (rows, cols) array of uint16 cluster labels.
    """
    n_bands, n_rows, n_cols = s1_rws.shape
    # (bands, rows, cols) -> (rows, cols, bands) -> one row per pixel
    pixel_features = reshape_as_image(s1_rws).reshape(-1, n_bands)

    model = MiniBatchKMeans(
        n_clusters=k + 1,
        random_state=42,  # fixed seed so labels are reproducible
        max_iter=10,
        batch_size=10000,
        reassignment_ratio=0,
    )
    model.fit(pixel_features)

    label_grid = model.labels_.reshape(n_rows, n_cols)
    return label_grid.astype(rio.uint16)

def compute_s1_mnws(img, cluster_img):
    """Per-pixel minimum normalized distance to the non-background clusters.

    For every cluster except the most populous one (treated as background),
    the per-band mean and population standard deviation of its valid pixels
    are computed. Each pixel's distance to that cluster is
    ``sqrt(mean_over_bands(((x - mean) / std) ** 2))``. The result is the
    minimum of these distances over all such clusters.

    Parameters
    ----------
    img : numpy.ndarray
        Band-first stack of shape (bands, rows, cols). A pixel is treated as
        nodata when any of its bands is exactly 0.
    cluster_img : numpy.ndarray
        (rows, cols) integer cluster labels for the same grid.

    Returns
    -------
    numpy.ndarray
        float32 array of shape (1, rows, cols); nodata pixels are 0.

    Raises
    ------
    ValueError
        If no non-background cluster contains any valid pixel.
    """
    n_bands = len(img)
    nodata = (img == 0).any(axis=0)
    raw_img = np.where(nodata, 0, img)
    pixels = raw_img.reshape(n_bands, -1).astype(np.float32)

    labels, counts = np.unique(cluster_img, return_counts=True)
    # Take the label VALUE of the largest cluster; argmax is only a position
    # in `labels`, which skips any cluster id that does not occur.
    background = labels[np.argmax(counts)]

    mnws = []
    for label in labels:  # only labels that actually occur
        if label == background:
            continue
        in_region = (cluster_img == label) & ~nodata
        if not in_region.any():
            # Empty cluster: its mean/std would be NaN and poison the min.
            continue

        region_pixels = raw_img[:, in_region]  # (bands, n_region_pixels)
        band_means = region_pixels.mean(axis=1, dtype=np.float32, keepdims=True)
        band_std = region_pixels.std(axis=1, dtype=np.float32, keepdims=True)

        # NOTE(review): a cluster with zero spread in a band still yields
        # inf/NaN distances here.
        z = (pixels - band_means) / band_std
        nws = np.sqrt((z ** 2).sum(axis=0) / n_bands).reshape(img.shape[1], img.shape[2])
        mnws.append(nws)

    if not mnws:
        raise ValueError(
            "cluster_img has no non-background cluster with valid pixels"
        )

    mnws_img = np.min(mnws, axis=0)
    return np.where(nodata, 0, mnws_img)[np.newaxis]

In [None]:
%%time

#Speckle filter (Lee's)
s1_files = glob(f'{os.path.abspath("..")}s1_2017/*.tif')

for i in tqdm(range(len(s1_files)),position=0, leave=True):
    file = s1_files[i]

    with rio.open(file) as src:
        profile = src.profile.copy()
        img = src.read()

        vv,vh = img[0].copy(),img[1].copy()
        vv[np.isnan(vv)]=0
        vh[np.isnan(vh)]=0

        vh_lee = lee_filter(vh,5)
        vv_lee = lee_filter(vv,5)
        
        s1_stack = np.where(np.isnan(img[0])!=1,np.array([vv_lee,vh_lee]),0)
        
        s1_rws = np.where(compute_s1_rws(vv_lee)==1,s1_stack,0)
        cluster_img = compute_s1_cluster(s1_rws)

        s1_mnws = compute_s1_mnws(s1_stack,cluster_img)
        
        outf = f'{os.path.abspath("..")}/results/s1_mnws_vv/s1_mnws_vv_{file.split("_")[-1].split(".")[0]}.tif'
        
        #export as geotiff
        profile.update({'dtype':s1_mnws.dtype,'nodata':0,'count':1 })
        with rio.open(outf,'w',**profile) as dst:
            dst.write(s1_mnws)


In [None]:
# Per-pixel water frequency across all MNWS scenes (MNWS < 1 => water).
project_dir = os.path.abspath("..")
mnws_dir = os.path.join(project_dir, "results", "s1_mnws_vv")
s1_mnws_files = sorted(glob(os.path.join(mnws_dir, "*.tif")))
if not s1_mnws_files:
    raise FileNotFoundError(f"no .tif files found in {mnws_dir}")

wd_sum = None    # per-pixel number of scenes classified as water
observed = None  # per-pixel: valid (non-zero) in at least one scene
profile = None

# Assumes every scene shares one raster grid (shapes must match).
for file in tqdm(s1_mnws_files, position=0, leave=True):
    with rio.open(file) as src:
        if profile is None:
            profile = src.profile.copy()
        mnws_img = src.read()

    # 0 is the nodata value of these files; without this guard it would be
    # counted as water because 0 < 1.
    valid = mnws_img != 0
    water = valid & (mnws_img < 1)
    if wd_sum is None:
        wd_sum = np.zeros(mnws_img.shape, dtype=np.float32)
        observed = np.zeros(mnws_img.shape, dtype=bool)
    wd_sum += water
    observed |= valid

# Fraction of scenes flagged as water, scaled by 12 (presumably to express it
# in months per year — TODO confirm).
wf = (wd_sum / len(s1_mnws_files)) * 12
# Mask with the union of all scenes' valid areas, not just the last file's.
wf_r = np.where(observed, np.round(wf, 2), 0)

# export as geotiff
# NOTE(review): nodata is 0, so observed-but-never-water pixels are also
# written as 0 and will read back as nodata.
outf = os.path.join(project_dir, "results", "s1_wf.tif")
profile.update({'dtype': wf_r.dtype, 'nodata': 0, 'count': 1})
with rio.open(outf, 'w', **profile) as dst:
    dst.write(wf_r)