In [None]:
import os
from glob import glob
import rasterio as rio
import numpy as np
from tqdm import tqdm,tqdm_notebook
from rasterio.plot import reshape_as_image
import re
import pandas as pd
import json
import geopandas as gpd


import seaborn as sns
from rasterio.features import rasterize
import matplotlib.pyplot as plt

# Widen the notebook container to 90% of the browser width.
# `display` lives in IPython.display; importing it from IPython.core.display
# is deprecated since IPython 7.14.
from IPython.display import display, HTML
display(HTML("<style>.container { width:90% !important; }</style>"))

In [None]:
from sklearn.cluster import MiniBatchKMeans
from skimage.morphology import reconstruction
from rasterio.features import sieve
from rasterio.plot import reshape_as_image

def compute_rws(stack_arr, rws_type):
    """Compute a water-candidate mask and the band stack restricted to it.

    Band indices assume the order
    ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12'] (the ``band_names``
    list used by ``compute_mnws``). Values are divided by 10000 before
    thresholding, i.e. integer-scaled reflectance is assumed.

    Parameters
    ----------
    stack_arr : ndarray, shape (bands, rows, cols)
        Band stack.
    rws_type : {'WIW', 'RWS'}
        'WIW': 0 < B8A <= 0.1804 and 0 < B12 <= 0.1131.
        'RWS': MNDWI(B3, B11) > 0.3 and 0 < min(B3, B4, B8) < 0.15.

    Returns
    -------
    rws_img : ndarray
        ``stack_arr`` with every pixel set to 0 that is outside the candidate
        mask or is 0 in any band.
    rws : ndarray of int, shape (rows, cols)
        Candidate mask (1 = candidate, 0 = not).

    Raises
    ------
    ValueError
        If ``rws_type`` is not 'WIW' or 'RWS'.
    """
    # Scope the error suppression to this call. The previous np.seterr call
    # silently changed numpy's global error state for the whole kernel.
    # 0/0 in MNDWI yields NaN, which fails the > 0.3 test, so it is harmless.
    with np.errstate(divide='ignore', invalid='ignore'):
        if rws_type == 'WIW':
            b8a = stack_arr[7].astype(np.float32) / 10000
            b12 = stack_arr[9].astype(np.float32) / 10000
            rws = np.where(((b8a > 0) & (b8a <= 0.1804)) & ((b12 > 0) & (b12 <= 0.1131)), 1, 0)

        elif rws_type == 'RWS':
            b3 = stack_arr[1].astype(np.float32) / 10000
            b11 = stack_arr[8].astype(np.float32) / 10000
            mndwi = (b3 - b11) / (b3 + b11)
            # per-pixel minimum of B3, B4, B8
            mgrn = (stack_arr[[1, 2, 6]].astype(np.float32) / 10000).min(0)
            rws = np.where((mndwi > 0.3) & ((mgrn > 0) & (mgrn < 0.15)), 1, 0)

        else:
            # previously fell through and crashed with UnboundLocalError on `rws`
            raise ValueError(f"rws_type must be 'WIW' or 'RWS', got {rws_type!r}")

    # a pixel is nodata if any band is 0
    nodata_mask = np.where(stack_arr == 0, 1, 0).max(0)
    rws_img = np.where((nodata_mask != 1) & (rws == 1), stack_arr, 0)
    return rws_img, rws
        
def compute_cluster(rws_rgb_img, k=4):
    """Cluster the pixels of a band stack with MiniBatchKMeans.

    Parameters
    ----------
    rws_rgb_img : ndarray, shape (bands, rows, cols)
        Band stack; each pixel's band values form one sample.
    k : int
        Number of clusters requested by the caller. ``k + 1`` clusters are
        actually fitted. The extra one presumably absorbs the zeroed
        background pixels (``compute_mnws`` drops the largest cluster).
        TODO: confirm.

    Returns
    -------
    ndarray of uint16, shape (rows, cols)
        Cluster label per pixel.
    """
    n_bands, n_rows, n_cols = rws_rgb_img.shape
    # (bands, rows, cols) -> (rows, cols, bands) -> one row per pixel
    pixel_table = reshape_as_image(rws_rgb_img).reshape(-1, n_bands)

    model = MiniBatchKMeans(
        n_clusters=k + 1,
        random_state=42,
        max_iter=10,
        batch_size=10000,
        reassignment_ratio=0,
    )
    model.fit(pixel_table)

    label_grid = model.labels_.reshape(n_rows, n_cols)
    return label_grid.astype(rio.uint16)

def compute_mnws(img, cluster_img, bands=('B2', 'B3', 'B4', 'B8', 'B11', 'B12')):
    """Compute the minimum normalized water score (MNWS) image.

    For every cluster except the most frequent one, the per-band mean and
    standard deviation of its valid pixels are computed. Each pixel's score
    against that cluster is
    ``sqrt(mean_over_bands(((x_b - mean_b) / std_b) ** 2))``.
    MNWS is the per-pixel minimum of that score over the clusters.

    Parameters
    ----------
    img : ndarray, shape (bands, rows, cols)
        Band stack in the order of ``band_names`` below.
    cluster_img : ndarray of int, shape (rows, cols)
        Cluster labels, e.g. from ``compute_cluster``. The most frequent
        label is excluded; presumably it is the zeroed background.
        TODO: confirm.
    bands : sequence of str
        Names of the bands used for the score.

    Returns
    -------
    ndarray, shape (1, rows, cols)
        MNWS. Pixels that are 0 in any selected band are set to 0.

    Raises
    ------
    ValueError
        If no cluster besides the most frequent one has valid pixels.
    """
    band_names = ['B2', 'B3', 'B4', 'B5', 'B6', 'B7', 'B8', 'B8A', 'B11', 'B12']
    band_pos = [band_names.index(b) for b in bands]
    n_bands = len(band_pos)

    selected = img[band_pos]
    # a pixel is nodata if any selected band is 0
    nodata = np.where(selected == 0, 1, 0).max(0)
    valid = nodata != 1
    raw_img = np.where(valid, selected, 0)
    pixels = raw_img.reshape(n_bands, -1)  # loop invariant

    # Iterate over the labels that actually occur. The previous version
    # indexed range(0, max+1) with the argmax position inside np.unique's
    # output. When a label was missing, that excluded the wrong label, and the
    # missing label's empty cluster produced NaN stats that turned the whole
    # min() result into NaN.
    present, counts = np.unique(cluster_img, return_counts=True)
    background = present[np.argmax(counts)]

    mnws = []
    # std can be 0 for degenerate clusters; keep numpy quiet locally
    with np.errstate(divide='ignore', invalid='ignore'):
        for label in present:
            if label == background:
                continue
            # (bands, n_pixels) values of this cluster's valid pixels
            members = raw_img[:, (cluster_img == label) & valid]
            if members.shape[1] == 0:
                continue
            band_means = members.mean(axis=1, dtype=np.float32).reshape(n_bands, 1)
            band_std = members.std(axis=1, dtype=np.float32).reshape(n_bands, 1)

            nws = ((((pixels - band_means) / band_std) ** 2).sum(0) / n_bands) ** 0.5
            mnws.append(nws.reshape(img.shape[1], img.shape[2]))

    if not mnws:
        raise ValueError("compute_mnws: no non-background cluster with valid pixels")

    mnws_img = np.array(mnws).min(0)
    mnws_img_clip = np.array([np.where(valid, mnws_img, 0)])
    return mnws_img_clip




In [None]:
# Project root is the notebook's parent directory. os.path.abspath() returns
# no trailing separator, so f'{root}images' pointed at "<root>images" rather
# than "<root>/images"; always build paths with os.path.join.
project_root = os.path.abspath("..")
mosaic_dir = os.path.join(project_root, 'images')

# glob() order is arbitrary and OS-dependent. Sort it so that output order
# and any positional indexing are reproducible.
mosaics = sorted(glob(os.path.join(mosaic_dir, '*.tif')))
assert mosaics, f'no .tif mosaics found in {mosaic_dir}'

# First 8-digit token in each file name, presumably YYYYMMDD (TODO: confirm).
# Use only the basename so digits in directory names are ignored.
dates = [re.findall(r"(\d{8})", os.path.basename(mos))[0] for mos in mosaics]
band_names = ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12']

In [None]:
%%time
#cloud masks computation
band_names = ['B2','B3','B4','B5','B6','B7','B8','B8A','B11','B12']


def _cloud_mask(mosaic, ref):
    """Classify cloud shadow (1) and thick cloud (2) against a reference.

    ``mosaic`` and ``ref`` are (bands, rows, cols) stacks in ``band_names``
    order. ``ref`` should be a signed integer array so that ``mosaic - ref``
    cannot wrap around. Returns a sieved uint8 (rows, cols) mask. A value of
    3 marks pixels that pass both tests.
    """
    # shadow: B8A and B11 > 400 darker than the reference, and B2 < 1000
    cl_shadow = np.where(((mosaic[7] - ref[7]) < -400) & ((mosaic[8] - ref[8]) < -400) & (mosaic[0] < 1000), 1, 0)
    # thick cloud: B2, B3 and B4 all > 800 brighter than the reference
    cl_thick = np.where(((mosaic[0] - ref[0]) > 800) & ((mosaic[1] - ref[1]) > 800) & ((mosaic[2] - ref[2]) > 800), 2, 0)
    # Cast before sieving: rasterio's sieve only accepts small integer dtypes,
    # not the int64 that np.where produces on most platforms (max value is 3).
    cl_masks = (cl_shadow + cl_thick).astype(np.uint8)
    # remove connected patches smaller than 30 pixels
    return sieve(cl_masks, size=30).astype(np.uint8)


project_root = os.path.abspath("..")
ref_file = os.path.join(project_root, 'results', '42mosaics_median_composite.tif')
cl_mask_dir = os.path.join(project_root, 'results', 'cloud_masks')
os.makedirs(cl_mask_dir, exist_ok=True)

# Read the reference once and close it right away. astype(int) makes it
# signed, so mosaic - mask cannot underflow for uint16 inputs.
with rio.open(ref_file) as src_mask:
    mask = src_mask.read().astype(int)

for file in tqdm(mosaics, position=0, leave=True):
    # Rename only the file name. Replacing on the whole path would also
    # rewrite any 'mosaic' occurring in directory names.
    outf = os.path.join(cl_mask_dir, os.path.basename(file).replace('mosaic', 'cl_mask'))

    with rio.open(file) as src:
        profile = src.profile.copy()
        mosaic = src.read()          #target image (mosaic)

    cl_masks = _cloud_mask(mosaic, mask)

    #export as geotiff
    profile.update({'dtype': cl_masks.dtype, 'nodata': 0, 'count': 1})
    with rio.open(outf, 'w', **profile) as dst:
        dst.write(cl_masks[np.newaxis])

In [None]:
%%time
#mnws computation
# os.path.join: abspath() has no trailing separator, so string concatenation
# produced "<root>results/..." instead of "<root>/results/...".
mnws_dir = os.path.join(os.path.abspath(".."), 'results', 'wiw_mnws')
os.makedirs(mnws_dir, exist_ok=True)

for file in tqdm(mosaics, position=0, leave=True):
    # rename only the file name, never directory components
    outf = os.path.join(mnws_dir, 'wiw_' + os.path.basename(file).replace('mosaic', 'mnws'))

    # read, then release the file before the heavy computation
    with rio.open(file) as src:
        profile = src.profile.copy()
        img = src.read()

    # WIW water candidates -> k-means on the first three bands -> MNWS on six bands
    rws_img, _ = compute_rws(img, 'WIW')
    cluster_img = compute_cluster(rws_img[[0, 1, 2]], k=8)
    mnws_img = compute_mnws(img, cluster_img, ['B2','B3','B4','B8A','B11','B12'])

    #export as geotiff
    profile.update({'dtype': mnws_img.dtype, 'nodata': 0, 'count': mnws_img.shape[0]})
    with rio.open(outf, 'w', **profile) as dst:
        dst.write(mnws_img)
        

In [None]:
# TODO: the WIW rule isn't separating water well on these mosaics. Consider adding an MGRN or MNDWI threshold (as the 'RWS' rule in compute_rws does).

In [None]:
def mnws_boxplot(data, title='MNWS land cover types 25 Feb 2017 ', threshold=3,
                 order=('Wa','Sh','Ds','Ve','Ag','Se','Bs','Cl'), tick_step=5):
    """Draw a horizontal box plot of MNWS values per land-cover class.

    Parameters
    ----------
    data : DataFrame
        One column of MNWS values per land-cover class. Shorter columns may
        be NaN-padded.
    title : str
        Plot title (defaults to the original hard-coded title).
    threshold : float
        x position of the dashed vertical reference line.
    order : sequence of str
        Columns to plot, top to bottom. Every name must be a column of ``data``.
    tick_step : int
        Spacing between x ticks.
    """
    order = list(order)
    sns.set_style('darkgrid', rc={"xtick.bottom": True, "ytick.left": True, 'axes.edgecolor': 'black'})
    fig, ax = plt.subplots(figsize=(10, 5))
    flierprops = dict(marker='o', markerfacecolor='red', markersize=5, markeredgecolor='white')
    # whiskers at the 5th/95th percentiles
    sns.boxplot(data=data, order=order, flierprops=flierprops, whis=[5, 95], orient='h',
                palette='husl', linewidth=1, ax=ax)
    ax.set_xlabel('MNWS', fontsize=16)
    ax.set_ylabel('Land cover', fontsize=16)
    ax.set_title(title, fontsize=20)
    # Ticks up to and including the largest plotted value. Previously they
    # were keyed only to the 'Cl' column and excluded the end point.
    x_max = np.nanmax(data[order].to_numpy(dtype=float))
    ax.set_xticks(range(0, int(x_max) + 1, tick_step))
    # Full-height dashed threshold line. axvline's ymin/ymax are axes
    # fractions in [0, 1]; the old call axvline(3, 8, 0) passed ymin=8.
    ax.axvline(threshold, ls='--', color='black', lw=2)
    plt.show()

def plot_bands(band_data_mean, band_data_std,
               title='Spectral signature land cover types 25 Feb 2017 '):
    """Plot the mean spectral signature of each land-cover class.

    Parameters
    ----------
    band_data_mean : DataFrame
        One row per class, one column per band. Values are divided by 10000
        for display; presumably they are integer-scaled reflectance
        (TODO: confirm).
    band_data_std : DataFrame
        Per-class band standard deviations, same layout as ``band_data_mean``.
    title : str
        Plot title (defaults to the original hard-coded title).

    The first class is drawn dashed. A +/-1 std band is shaded for the first
    and the last class.
    """
    scale = 10000
    colors = ['blue','orange','brown','yellow','green','purple','red','black']
    # Use the arguments and their own columns. The old body read the globals
    # `band_data` and `band_names`, so it only worked because of notebook state.
    mean = band_data_mean / scale
    std = band_data_std / scale
    bands = list(mean.columns)
    n_classes = len(mean.index)

    sns.set_style('darkgrid', rc={"xtick.bottom": True, "ytick.left": True, 'axes.edgecolor': 'black'})
    plt.figure(figsize=(10, 7))
    ax = None
    for i in range(n_classes):
        # modulo: more than 8 classes no longer raises IndexError
        ax = sns.lineplot(data=mean.iloc[i], sort=False, color=colors[i % len(colors)])
    if ax is not None:
        ax.lines[0].set_linestyle("--")
    plt.legend(mean.index, loc='upper right')
    plt.xlabel('Band', fontsize=16)
    plt.ylabel('Top of Atmosphere (TOA) Reflectance', fontsize=16)
    plt.title(title, fontsize=20)

    if n_classes:
        # shade the last class in the same color as its line
        last_color = colors[(n_classes - 1) % len(colors)]
        plt.fill_between(bands, mean.iloc[0] - std.iloc[0], mean.iloc[0] + std.iloc[0], alpha=.3)
        plt.fill_between(bands, mean.iloc[-1] - std.iloc[-1], mean.iloc[-1] + std.iloc[-1], alpha=.3, color=last_color)

    plt.show()

In [None]:
%%time

# MNWS and band stats example 25 Feb 2017
EXAMPLE_DATE = '20170225'  # date token as extracted into `dates`

sample_files = sorted(glob('./data/sample/sample*.geojson'))
if not sample_files:
    raise FileNotFoundError('no ./data/sample/sample*.geojson found')
geo_sample = gpd.read_file(sample_files[0])

# code -> description, taken pairwise from the same rows. Zipping two
# independent unique() lists misaligns whenever the mapping is not 1:1.
labels = geo_sample.drop_duplicates('code').set_index('code')['desc'].to_dict()
geom_code = list(zip(geo_sample['geometry'].tolist(), geo_sample['code'].tolist()))

# Select the example mosaic by date. mosaics[3] depended on glob order.
if EXAMPLE_DATE not in dates:
    raise ValueError(f'no mosaic dated {EXAMPLE_DATE} among {len(mosaics)} mosaics')
example_mosaic = mosaics[dates.index(EXAMPLE_DATE)]

with rio.open(example_mosaic) as src:
    profile = src.profile.copy()
    img = src.read()
    # raster of sample class codes (0 outside the sample polygons)
    sample_mask = rasterize(shapes=geom_code, out_shape=src.shape, transform=src.transform)

#band stats: rows = classes, columns = bands
band_data = pd.DataFrame({desc: img[:, sample_mask == code].mean(1) for code, desc in labels.items()},
                         index=band_names).T
band_data_std = pd.DataFrame({desc: img[:, sample_mask == code].std(1) for code, desc in labels.items()},
                             index=band_names).T

#mnws computations
rws_img, rws = compute_rws(img, 'RWS')
cluster_img = compute_cluster(rws_img[[0, 1, 2]], k=8)
mnws_img = compute_mnws(img, cluster_img, ['B2','B3','B4','B8A','B11','B12'])
# one column of MNWS values per class; shorter columns are NaN-padded
mnws_data = pd.DataFrame({desc: pd.Series(mnws_img[0, sample_mask == code]) for code, desc in labels.items()})



In [None]:
plot_bands(band_data, band_data_std)
mnws_boxplot(mnws_data)
# summary table: rounded stats per class, ordered by mean MNWS
(
    mnws_data.describe()
    .round(3)
    .loc[['mean', 'std', 'min', 'max']]
    .T
    .sort_values('mean')
)