# MODIS Water Classification Spot Check

Version: 0.1.4

Date: 02.01.22

In [None]:
import folium
from folium import plugins
import glob
import numpy as np
import os
import rasterio as rio
import tempfile

from rasterio.warp import calculate_default_transform, reproject, Resampling
from pyproj import Transformer 

In [None]:
# Which model run, MODIS tile, year and period to spot check.
modelType = 'Caleb'
tile = 'h11v02'
year = 2020
month = 'AnnualMap'

## Find the correct data and get the paths

In this case we're looking for five products:
1. Annual water mask output from a model (random forest, classic, etc.).
2. Annual water probability output from the same model.
3. Annual summation of land output from the same model.
4. DEM (digital elevation model) for the corresponding tile.
5. Annual burn scar mask.

The annual burn scar mask is an unofficial product derived from the monthly MCD64A1 burned-area product.

In [None]:
product_path = "/att/nobackup/zwwillia/MODIS_water/model_outputs/{0}/{1}/{2}/{3}/".format(modelType, year, tile, month)
gmted_path = '/adapt/nobackup/projects/ilab/scratch/mcarrol2/data/GMTED-MODIS/MODIS_tiles'
burn_scar_path = '/att/nobackup/cssprad1/projects/modis_water/data/burn_scar_products/MCD64A1-BurnArea-Annual/{}'.format(year)


def _find_first(directory, pattern):
    """Return the first (alphabetically) file in *directory* matching *pattern*.

    glob returns matches in arbitrary order, so they are sorted to make the
    choice deterministic when several files match.

    Raises:
        FileNotFoundError: naming the full pattern when nothing matches,
            instead of the bare IndexError produced by ``glob(...)[0]``.
    """
    full_pattern = os.path.join(directory, pattern)
    matches = sorted(glob.glob(full_pattern))
    if not matches:
        raise FileNotFoundError('No file matches {}'.format(full_pattern))
    return matches[0]


annual_mask_product = _find_first(product_path, '*RandomForest-Mask.tif')
annual_probWater_product = _find_first(product_path, '*RandomForest-ProbWater.tif')
annual_sumLand_product = _find_first(product_path, '*RandomForest-SumLand.tif')
# The year in the burn-scar file name must follow `year`; it was previously
# hard-coded to 2020 and silently broke for any other year.
annual_burn_scar_product = _find_first(burn_scar_path, 'MCD64A1-BurnArea_Annual_A{}.{}.tif'.format(year, tile))
gmted = _find_first(gmted_path, 'GMTED.{}.med.tif'.format(tile))

print('Found {}'.format(annual_mask_product))
print('Found {}'.format(annual_probWater_product))
print('Found {}'.format(annual_sumLand_product))
print('Found {}'.format(annual_burn_scar_product))
print('Found {}'.format(gmted))

In [None]:
# -----------------------------------------------------------------------------
# Uses rasterio to open a raster and list the names of its subdatasets.
# This is very useful for container formats such as MODIS hdfs.
# -----------------------------------------------------------------------------
def print_subdatasets(filename, search_term=''):
    """Print and return the subdataset names of *filename* containing *search_term*.

    Args:
        filename: path to a raster rasterio can open (e.g. an HDF file).
        search_term: substring filter; the default '' matches every name.
            (Previously this was an undefined global, so every call raised
            NameError.)

    Returns:
        list of matching subdataset names, in the order rasterio reports them.
    """
    with rio.open(filename) as dataset:
        matches = [name for name in dataset.subdatasets if search_term in name]
    print(matches)
    return matches

# -----------------------------------------------------------------------------
# Gets a tiff that has the correct metadata for that tile, gets the metadata
# from that source tif and copies it to each annual product tiff.
# -----------------------------------------------------------------------------
def add_metadata_to_annual_product(filepath, model_type, year, tile):
    """Copy CRS and transform from a per-period MOD tiff onto the year/tile tiffs.

    Args:
        filepath: directory holding both the metadata source tiff and the
            product tiffs to update.
        model_type: not used; kept so existing calls keep working.
        year: year used in both file-name patterns.
        tile: MODIS tile id (e.g. 'h11v02') used in both file-name patterns.

    Raises:
        FileNotFoundError: if no tiff matches the metadata source pattern.
    """
    src_pattern = os.path.join(filepath, "{}-1*-{}-MOD-*.tif".format(year, tile))
    # Sort so the chosen source is deterministic when several files match.
    candidates = sorted(glob.glob(src_pattern))
    if not candidates:
        raise FileNotFoundError('No metadata source tiff matches {}'.format(src_pattern))
    metadata_pull_src = candidates[0]

    with rio.open(metadata_pull_src) as src:
        src_meta = src.meta

    # Update every tiff whose name contains "<year>-<tile>". Non-tiff files
    # (e.g. sidecars) are skipped because copy_meta opens targets in 'r+'
    # mode as rasters.
    target_key = "{0}-{1}".format(year, tile)
    for fn in sorted(os.listdir(filepath)):
        if target_key in fn and fn.lower().endswith(('.tif', '.tiff')):
            copy_meta(os.path.join(filepath, fn), src_meta, metadata_pull_src)

# -----------------------------------------------------------------------------
# Stamp the CRS and affine transform from an existing metadata dict onto a
# tiff in place (used for product tiffs written without georeferencing).
# -----------------------------------------------------------------------------
def copy_meta(dst_path, src_meta, src_name):
    """Overwrite the CRS and transform of *dst_path* with those in *src_meta*.

    Args:
        dst_path: tiff to modify in place (opened in 'r+' mode).
        src_meta: rasterio ``meta`` dict supplying 'crs' and 'transform'.
        src_name: shown only in the progress message.
    """
    message = 'Copying metadata from {} to {}'.format(src_name, dst_path)
    print(message)
    with rio.open(dst_path, 'r+') as target:
        target.crs = src_meta['crs']
        target.transform = src_meta['transform']

# -----------------------------------------------------------------------------
# Given a tiff file as input, compute the transform needed to reproject from
# the tiff's source crs to EPSG:3857, reproject every band with
# nearest-neighbour resampling, and write the result to a temporary file.
# Return the path to the temp file.
# -----------------------------------------------------------------------------
def reproject_to_3857(input_tiff):
    """Reproject *input_tiff* to EPSG:3857 and return the path of the new tiff.

    The output is written to ``<tempdir>/<stem>-3857<ext>``, where *stem*
    and *ext* come from the input's basename. The caller is responsible for
    deleting it.
    """
    dst_crs = "EPSG:3857"

    # Only the extension is split off. A str.replace('.tif', ...) would
    # rewrite every '.tif' in the name and leave other extensions unchanged,
    # which could collide with an input that already lives in the temp dir.
    stem, ext = os.path.splitext(os.path.basename(input_tiff))
    out_path_rproj = os.path.join(tempfile.gettempdir(),
                                  '{}-3857{}'.format(stem, ext or '.tif'))

    with rio.open(input_tiff) as src:
        # Output grid (transform and size) covering the source bounds in 3857.
        transform, width, height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds)
        kwargs = src.meta.copy()
        kwargs.update({'crs': dst_crs,
                       'transform': transform,
                       'width': width,
                       'height': height})

        # Reproject band by band (rasterio bands are 1-based) into the new file.
        with rio.open(out_path_rproj, 'w', **kwargs) as dst:
            for i in range(1, src.count + 1):
                reproject(source=rio.band(src, i),
                          destination=rio.band(dst, i),
                          src_transform=src.transform,
                          src_crs=src.crs,
                          dst_transform=transform,
                          dst_crs=dst_crs,
                          resampling=Resampling.nearest)
    return out_path_rproj

# -----------------------------------------------------------------------------
# In order for folium to work properly we need to pass it the bounding box
# of the tiff in the form of lat and lon. rasterio reads the bounds in the
# tiff's own crs and pyproj converts the two corners to EPSG:4326.
# -----------------------------------------------------------------------------
def get_bounds(tiff_3857):
    """Return the lat/lon bounding box and centre of a georeferenced tiff.

    Returns:
        dict with
          'bounds': [[lat, lon] of the lower-left corner,
                     [lat, lon] of the upper-right corner] in EPSG:4326;
          'center': (lon, lat) midpoint of those corners (note: lon first).
    """
    with rio.open(tiff_3857) as src:
        src_crs = src.crs
        # Projected coordinates in the tiff's crs (not lon/lat yet).
        left, bottom, right, top = src.bounds

    # Build the transformer once, from the full CRS definition. Parsing
    # crs['init'] into an integer EPSG code only works when rasterio can
    # express the crs as 'epsg:NNNN'.
    proj = Transformer.from_crs(src_crs.to_wkt(), 'EPSG:4326', always_xy=True)

    bounds = []
    for x, y in ((left, bottom), (right, top)):
        lon_n, lat_n = proj.transform(x, y)  # always_xy: returns (lon, lat)
        bounds.append([lat_n, lon_n])

    center_lon = bounds[0][1] + (bounds[1][1] - bounds[0][1]) / 2
    center_lat = bounds[0][0] + (bounds[1][0] - bounds[0][0]) / 2
    return {'bounds': bounds, 'center': (center_lon, center_lat)}

# -----------------------------------------------------------------------------
# Read one band of a raster into a numpy array with rasterio.
# -----------------------------------------------------------------------------
def open_and_get_band(file_name, band_num=1):
    """Return band *band_num* (1-based, as rasterio numbers bands) of *file_name*."""
    with rio.open(file_name) as raster:
        return raster.read(band_num)

# -----------------------------------------------------------------------------
# Wrap an image array and its lat/lon bounds as a folium ImageOverlay layer
# that can be added to a map.
# -----------------------------------------------------------------------------
def get_overlay(band, meta_dict, name, opacity=1.0, show=True):
    """Build a folium ImageOverlay of *band* positioned at ``meta_dict['bounds']``.

    Args:
        band: image array to display.
        meta_dict: dict whose 'bounds' entry is [[lat, lon], [lat, lon]],
            as produced by get_bounds.
        name: layer name shown in the layer control.
        opacity: layer opacity.
        show: whether the layer is visible initially.
    """
    overlay_cls = folium.raster_layers.ImageOverlay
    return overlay_cls(band,
                       bounds=meta_dict['bounds'],
                       name=name,
                       opacity=opacity,
                       show=show)

# -----------------------------------------------------------------------------
# We don't need to keep those temp files we made for the reprojections around.
# -----------------------------------------------------------------------------
def cleanup(filename):
    """Delete *filename*; print a notice instead of failing if it is missing."""
    # EAFP: a single os.remove avoids the race between checking for the file
    # and deleting it.
    try:
        os.remove(filename)
    except FileNotFoundError:
        print('No file: {} exists.'.format(filename))

### Add metadata from a tiff we know has the correct metadata for the tile to the product tiffs.

In [None]:
add_metadata_to_annual_product(filepath=product_path, model_type=modelType,
                               year=year, tile=tile)

### Reproject all of our layers into the EPSG:3857 (Web Mercator) projection

In [None]:
# Each call writes a reprojected copy to the temp dir and returns its path.
mask_3857, probW_3857, sumL_3857, bs_3857, dem_3857 = [
    reproject_to_3857(source_path)
    for source_path in (annual_mask_product,
                        annual_probWater_product,
                        annual_sumLand_product,
                        annual_burn_scar_product,
                        gmted)
]

### Get all the bounding boxes for each product in lat,lon format.

In [None]:
# Lat/lon bounds and centre for every reprojected layer.
mask_d, probw_d, suml_d, bs_d, dem_d = map(
    get_bounds, (mask_3857, probW_3857, sumL_3857, bs_3857, dem_3857))

### Sanity check to make sure we're all working in the right part of the world...

In [None]:
from pprint import pprint

# Print each layer's bounds; they should roughly agree if every product
# covers the same tile.
for layer_label, layer_bounds in (('Water Mask', mask_d),
                                  ('Prob Water', probw_d),
                                  ('Sum Land', suml_d),
                                  ('Burn Scar', bs_d),
                                  ('DEM', dem_d)):
    print(layer_label)
    pprint(layer_bounds)

In [None]:
# First band of each reprojected raster as a numpy array.
mask_b1, probw_b1, suml_b1, bs_b1, gmted_b1 = (
    open_and_get_band(reprojected, 1)
    for reprojected in (mask_3857, probW_3857, sumL_3857, bs_3857, dem_3857))
# Empty channel with the water mask's shape/dtype.
zeros = np.zeros_like(mask_b1)

In [None]:
# Build false-colour RGB stacks for display. Each padding channel is made
# with np.zeros_like on the band it pads. The rasters were reprojected from
# different source grids (e.g. the GMTED DEM), so their shapes need not
# match the water mask's, and np.dstack fails on mismatched shapes. The
# burn-scar line already did this; now every layer does.
mask_rgb = np.dstack((mask_b1, np.zeros_like(mask_b1), np.zeros_like(mask_b1)))
probw_rgb = np.dstack((np.zeros_like(probw_b1), np.zeros_like(probw_b1), probw_b1))
suml_rgb = np.dstack((np.zeros_like(suml_b1), suml_b1, np.zeros_like(suml_b1)))
bs_rgb = np.dstack((bs_b1, np.zeros_like(bs_b1), bs_b1))
gmted_rgb = np.dstack((np.zeros_like(gmted_b1), gmted_b1, gmted_b1))

Create a folium map centered on the water-mask tile, using Google satellite tiles as the base map.

In [None]:
# get_bounds returns the centre as (lon, lat); folium wants [lat, lon].
center_lon, center_lat = mask_d['center']
m = folium.Map(location=[center_lat, center_lon],
               tiles='https://mt1.google.com/vt/lyrs=s&x={x}&y={y}&z={z}',
               attr='Google',
               zoom_start=6)

Add each product as an image_overlay to the map.

In [None]:
# Add each product as an image overlay; only the water mask is shown initially.
m.add_child(get_overlay(mask_rgb, mask_d, '{}-{} model water mask'.format(year, tile), opacity=0.6))
m.add_child(get_overlay(probw_rgb, probw_d, '{}-{} model proba water'.format(year, tile), opacity=0.8, show=False))
m.add_child(get_overlay(suml_rgb, suml_d, '{}-{} model sum land'.format(year, tile), opacity=0.8, show=False))
m.add_child(get_overlay(bs_rgb, bs_d, '{}-{} MCD burn scar'.format(year, tile), opacity=0.8, show=False))
# Layer name corrected from the "GEMTED" typo to the dataset's name, GMTED.
# NOTE(review): gmted_rgb is built in the previous cell but the raw band is
# displayed here; confirm which was intended.
m.add_child(get_overlay(gmted_b1, dem_d, '{} GMTED'.format(tile), opacity=1, show=False))
m.add_child(plugins.MousePosition())
# Keep this as the cell's last expression so the notebook renders the map.
m.add_child(folium.LayerControl())

In [None]:
# Delete the temporary reprojected rasters.
for temp_tiff in (mask_3857, probW_3857, sumL_3857, bs_3857, dem_3857):
    cleanup(temp_tiff)