## Begin to Automate Your Workflow With Functions

Functions can be tricky to learn. In this lesson, you will see how to break down a workflow and create 
a new function that can be used over and over on your data. 

You will want to use functions for this automation assignment. 

In [None]:
# Import necessary packages
import os
from glob import glob
import matplotlib.pyplot as plt
import numpy as np
from shapely.geometry import box
import geopandas as gpd
import rioxarray as rxr
import xarray as xr
from rasterio.plot import plotting_extent
from shapely.geometry import mapping
import earthpy as et
import earthpy.spatial as es
import earthpy.plot as ep

# Get the NDVI automation teaching data and set the working directory.
# The cells below read from the "ndvi-automation/sites" folder, so that is
# the dataset that needs to be downloaded.
data = et.data.get_data('ndvi-automation')
os.chdir(os.path.join(et.io.HOME,
                      "earth-analytics",
                      "data"))

In [None]:
# Path to the folder that holds one sub-directory per study site
path = os.path.join("ndvi-automation", "sites")

# Get a list of both site directories (We will talk more about automation next week).
# glob() returns matches in arbitrary (filesystem-dependent) order, so sort
# the list to make sites[0] - and therefore site_name - deterministic.
sites = sorted(glob(path + "/*/"))
# Get the site name from the first site directory
site_name = os.path.basename(os.path.normpath(sites[0]))


In [None]:
# The site's vector folder holds the boundary used to clip the landsat data
vector_dir = os.path.join(sites[0], "vector")

# Read the "<site_name>-crop.shp" crop boundary for the study area
site_boundary_path = os.path.join(vector_dir, f"{site_name}-crop.shp")
crop_bound = gpd.read_file(site_boundary_path)

In [None]:
# Folder with the cropped landsat scenes for this site
landsat_dir = os.path.join(sites[0], "landsat-crop")

# Scene folders (names starting with "LC08"), each containing that scene's
# .tif files - listed in sorted name order
landsat_dirs = sorted(glob(os.path.join(landsat_dir, "LC08*")))

In [None]:
# Work with a single scene folder while building the workflow
adir = landsat_dirs[3]

# Collect (but do not yet open) the paths to bands 4 and 5, sorted so that
# band 4 comes first and band 5 second
band_pattern = os.path.join(adir, "*band*[4-5].tif")
band_paths = sorted(glob(band_pattern))
band_paths

In [None]:
# Cloud no data vals for Landsat 8 - pixels whose QA value is in this list
# are masked out of the bands later (each value listed once)
vals = [328, 392, 840, 904, 1350, 352, 368, 416,
        432, 480, 864, 880, 928, 944, 992]

# Get cloud mask layer - sort the matches so qa_r[0] is deterministic if the
# folder ever contains more than one "*qa*" file
qa_r = sorted(glob(os.path.join(adir, "*qa*")))

# Open the cloud mask layer and clip it to the study area
cl_mask = rxr.open_rasterio(qa_r[0], masked=True).squeeze()
cl_mask_crop = cl_mask.rio.clip(crop_bound.geometry.apply(mapping))

# View unique values in the data - note that not every landsat scene will have clouds to mask
np.unique(cl_mask_crop.values)

In [None]:
# Specify the valid range of values for landsat
valid_range = (0, 10000)

# Open the first band path (band 4) with rioxarray and clip it to the study area
band = rxr.open_rasterio(band_paths[0], masked=True).squeeze()
band_crop = band.rio.clip(crop_bound.geometry.apply(mapping))

if valid_range:
    # Flag pixels outside the valid range; .where() sets flagged pixels to NaN
    mask = (band_crop < valid_range[0]) | (band_crop > valid_range[1])
    band_crop = band_crop.where(~mask)

# Apply the cloud mask: pixels whose QA value is in `vals` become NaN
band_crop = band_crop.where(~cl_mask_crop.isin(vals))
band_crop

In [None]:
def open_clean_bands(band_path,
                     crop_bound,
                     valid_range=None,
                     a_mask=None,
                     vals=None):
    """Open, crop and mask a single landsat band.

    The band is opened with rioxarray (``masked=True``), clipped to the
    geometry of ``crop_bound``, then optionally masked to ``valid_range``
    and by a pixel QA / cloud mask layer. Masked pixels are set to NaN.

    Parameters
    -----------
    band_path : string
        A path to the raster band to be opened.
    crop_bound : geopandas GeoDataFrame
        A geopandas dataframe whose geometry is used to clip the raster
        data with ``rio.clip()``.
    valid_range : tuple (optional)
        A tuple of (min, max) valid values for the data. Pixels outside
        this range are set to NaN. Default = None (no range masking).
    a_mask : xarray DataArray (optional)
        A pixel QA / cloud mask layer that lines up with the clipped band
        (e.g. clipped to the same ``crop_bound``). A leading singleton
        dimension is squeezed away. Default = None (no cloud masking).
    vals : list (optional)
        The values in ``a_mask`` that should be masked out of the band.
        Cloud masking is only applied when both ``a_mask`` and ``vals``
        are provided. Default = None.

    Returns
    -----------
    band_crop : xarray DataArray
        The clipped band with out-of-range and masked pixels set to NaN.
    """
    # TODO add tests to ensure the arrays are the same .shape
    band = rxr.open_rasterio(band_path, masked=True).squeeze()
    band_crop = band.rio.clip(crop_bound.geometry.apply(mapping))

    # Only run this step if a valid range tuple is provided
    if valid_range:
        out_of_range = ((band_crop < valid_range[0]) |
                        (band_crop > valid_range[1]))
        # .where() keeps pixels where the condition is True, NaN elsewhere
        band_crop = band_crop.where(~out_of_range)

    # Only apply the cloud mask when there is a mask and values to mask.
    # (Previously a_mask=None crashed on .shape.)
    if a_mask is not None and vals is not None:
        # Use `and`, not `&`: `&` binds tighter than `==`, which turned the
        # old check into a chained comparison that could never be True.
        if a_mask.ndim == 3 and a_mask.shape[0] == 1:
            a_mask = a_mask.squeeze()

        band_crop = band_crop.where(~a_mask.isin(vals))

    return band_crop

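## Which approach is fastest?

Below you can see different versions of the same workflow timed for speed with `%%timeit`.
The most efficient approach seems to be calculating NDVI first and then clipping and masking
the single NDVI layer, rather than clipping and masking each band! Keep in mind that the last
version does not apply the 0-10000 valid-range mask, so it does slightly less work than the others.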

In [None]:
%%timeit
# Open and clean a single band
band_4 = open_clean_bands(band_path=band_paths[0],
                          crop_bound=crop_bound,
                          # The range of valid values for landsat can include negative values
                          # for this week let's stick with 0-10000
                          valid_range=(0, 10000),
                          a_mask=cl_mask_crop,
                          vals=vals)

# Open and clean a single band
band_5 = open_clean_bands(band_path=band_paths[1],
                          crop_bound=crop_bound,
                          # The range of valid values for landsat can include negative values
                          # for this week let's stick with 0-10000
                          valid_range=(0, 10000),
                          a_mask=cl_mask_crop,
                          vals=vals)

#  Then  calculate NDVI
ndvi = es.normalized_diff(band_5.values.astype('f4'), 
                   band_4.values.astype('f4'))


In [None]:
%%timeit
# Alternatively do this in a loop and create a list output - cleaner code - same  amount of time
all_bands = []
for aband in band_paths:
    cleaned_band = open_clean_bands(band_path=aband,
                          crop_bound=crop_bound,
                          # The range of valid values for landsat can include negative values
                          # for this week let's stick with 0-10000
                          valid_range=(0, 10000),
                          a_mask=cl_mask_crop,
                          vals=vals)
    all_bands.append(cleaned_band)


# Then calculate NDVI - note ndvi here is not an xarray oobject - i  don't love that. 
ndvi_2 = es.normalized_diff(all_bands[1].values.astype('f4'), 
                   all_bands[0].astype('f4'))


In [None]:
%%timeit
# Open Data - note that t his is not accounting for funky landsat values >10000 or <0
band_4 = rxr.open_rasterio(band_paths[0], masked=True).squeeze()
band_5 = rxr.open_rasterio(band_paths[1], masked=True).squeeze()

# Calculate NDVI
ndvi = (band_5 - band_4) / (band_5 + band_4)

ndvi_crop = ndvi.rio.clip(crop_bound.geometry.apply(mapping))

ndvi_crop = ndvi_crop.where(~cl_mask_crop.isin(vals))