## Floods dataset visualization

In [1]:
from datetime import datetime
import numpy as np
import pandas as pd
import os
import rasterio
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
import datashader as ds
from datashader.colors import Elevation
import datashader.transfer_functions as tf
from datashader.transfer_functions import shade
from datashader.transfer_functions import stack
from datashader.transfer_functions import dynspread
from datashader.transfer_functions import set_background
from datashader.transfer_functions import Images, Image
from datashader.utils import orient_array
import xarray as xr
import xrspatial.multispectral as ms
from xrspatial import hillshade

  import pandas.util.testing as tm


### Read in all the data

In [3]:
# Dataset root. Set the S1FLOODS_ROOT environment variable to point the
# notebook at another copy of the data; the original location is the default.
root_dir = os.environ.get('S1FLOODS_ROOT', '/home/k3blu3/datasets/s1floods')
# metadata CSV, resolved relative to root_dir when it is read
meta_file = 'flood-training-metadata.csv'
# directories holding the per-chip feature rasters and the label rasters
image_dir = os.path.join(root_dir, 'train_features')
label_dir = os.path.join(root_dir, 'train_labels')

In [4]:
# load the metadata table and preview its first rows
meta_path = os.path.join(root_dir, meta_file)
df = pd.read_csv(meta_path)
df.head(5)

Unnamed: 0,image_id,chip_id,flood_id,polarization,location,scene_start
0,awc00_vh,awc00,awc,vh,Bolivia,2018-02-15
1,awc00_vv,awc00,awc,vv,Bolivia,2018-02-15
2,awc01_vh,awc01,awc,vh,Bolivia,2018-02-15
3,awc01_vv,awc01,awc,vv,Bolivia,2018-02-15
4,awc02_vh,awc02,awc,vh,Bolivia,2018-02-15


In [6]:
# filename suffixes of the image types read for every chip id
extensions = [
    'vv',
    'vh',
    'nasadem',
    'jrc-gsw-occurrence',
]

In [7]:
# keep only the first metadata row seen for each chip id
is_repeat = df.duplicated(subset=['chip_id'])
df = df[~is_repeat]

In [8]:
# row count of the metadata table
df.shape[0]

542

In [None]:
def _read_layer(path, name):
    """Read one GeoTIFF and rasterize it onto a 512x512 datashader canvas.

    Parameters
    ----------
    path : str
        Path of the .tif file to open.
    name : str
        Name assigned to the loaded array before rasterizing.

    Returns
    -------
    The array returned by ``Canvas.raster`` (mean aggregation), with its
    data replaced by ``orient_array`` of itself.
    """
    # NOTE(review): a new Canvas per read mirrors the original loop. Canvas.raster
    # may remember the first raster's extent on the canvas object, so confirm
    # before sharing one canvas across different chips.
    cvs = ds.Canvas(plot_width=512, plot_height=512)
    # [0] keeps the first entry along the leading dimension (presumably the band axis)
    layer = xr.open_rasterio(path).load()[0]
    layer.name = name
    layer = cvs.raster(layer, agg='mean')
    layer.data = orient_array(layer)
    return layer


# one dict per chip, appended in the row order of df
all_layers = list()
for chip_id in tqdm(df['chip_id'], total=len(df)):
    # read in all image types
    layers = {
        ext: _read_layer(os.path.join(image_dir, f"{chip_id}_{ext}.tif"), ext)
        for ext in extensions
    }

    # read in target
    layers['label'] = _read_layer(os.path.join(label_dir, f"{chip_id}.tif"), 'label')

    all_layers.append(layers)

### Visualize the data

In [None]:
# write a function to rescale an input image with percentile scaling
def rescale_img(img, min_val=0.0, max_val=1.0, dtype=np.float32, pmin=0.0, pmax=100.0, vmin=None, vmax=None):
    """Linearly rescale an image from [vmin, vmax] to [min_val, max_val] and clip.

    Parameters
    ----------
    img : array-like or numpy masked array
        Input image. Masked pixels stay masked in the result.
    min_val, max_val : float
        Bounds of the output range; values outside are clipped to them.
    dtype : numpy dtype
        dtype of the returned array.
    pmin, pmax : float
        Percentiles (0-100) used to derive ``vmin`` / ``vmax`` when they
        are not given.
    vmin, vmax : float or None
        Input values mapped to ``min_val`` / ``max_val``. Each one is
        taken from the matching percentile only when it is ``None``, so
        an explicit ``0`` is honoured.

    Returns
    -------
    Array of ``dtype`` with the same shape as ``img``.
    """
    img = np.asanyarray(img)

    # compute min and max percentile ranges to scale with
    if vmin is None or vmax is None:
        # the percentile routines ignore a masked array's mask, so measure
        # only the unmasked samples
        samples = img.compressed() if isinstance(img, np.ma.MaskedArray) else img
        if samples.size == 0:
            # nothing valid to measure; fall through to the flat-image case
            lo = hi = 0.0
        else:
            lo, hi = np.nanpercentile(samples, [pmin, pmax])
        if vmin is None:
            vmin = lo
        if vmax is None:
            vmax = hi

    # a constant image has no range to stretch: map it to min_val instead
    # of dividing by zero
    span = vmax - vmin
    scale = (max_val - min_val) / span if span != 0 else 0.0

    # rescale & clip
    img_rescale = ((img - vmin) * scale + min_val).astype(dtype)
    return np.clip(img_rescale, min_val, max_val)

In [None]:
# directory that receives the rendered figures; exist_ok makes the cell safe
# to re-run and avoids the check-then-create race
outdir = os.path.join(root_dir, 'training_viz')
os.makedirs(outdir, exist_ok=True)

In [None]:
# close every matplotlib figure that is still open
plt.close('all')

In [None]:
def _show_panel(ax, image, title, **imshow_kwargs):
    """Draw `image` on `ax` with a small title and both axes hidden."""
    ax.imshow(image, **imshow_kwargs)
    ax.set_title(title, fontsize=6)
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)


# Render one 2x2 figure per entry of all_layers and save it as
# image_001.jpg, image_002.jpg, ... in outdir. Enumerating all_layers
# directly replaces the hand-maintained counter used as an index.
for ctr, layers in enumerate(tqdm(all_layers), start=1):
    # render hillshade on top of elevation
    el = layers['nasadem']
    hs = hillshade(el, azimuth=100, angle_altitude=50)
    dem = stack(shade(hs, cmap=['white', 'gray']), shade(el, cmap=Elevation, alpha=128))

    # create matplotlib figure
    fig, axes = plt.subplots(2, 2, dpi=200, figsize=(6, 6))

    # display VH: mask zeros, then stretch between the 1st and 99th percentile.
    # The stretch is applied to the VH band itself so the panel matches its title.
    vh = np.ma.masked_equal(layers['vh'].to_numpy(), 0)
    vh = rescale_img(vh, pmin=1, pmax=99)
    _show_panel(axes[0, 0], vh, 'Sentinel 1 VH', cmap='cividis')

    # display target, with the value 255 masked out
    label = np.ma.masked_equal(layers['label'].to_numpy(), 255)
    label = rescale_img(label, vmin=0, vmax=1)
    _show_panel(axes[0, 1], label, 'Flood Label', cmap='inferno')

    # display DEM (already an image, so no colormap is applied)
    nasadem = np.asarray(dem.to_pil())
    _show_panel(axes[1, 0], nasadem, 'NASA DEM')

    # display JRC
    jrc = rescale_img(layers['jrc-gsw-occurrence'].to_numpy(), vmin=0, vmax=100)
    _show_panel(axes[1, 1], jrc, 'JRC Occurrence', cmap='Blues')

    fout = os.path.join(outdir, 'image_{:03d}.jpg'.format(ctr))
    fig.savefig(fout, bbox_inches='tight')

    # close the figure explicitly so figures do not accumulate across the loop
    plt.close(fig)