<a href="https://colab.research.google.com/github/ericslevenson/arctic-surface-water/blob/main/Predictor.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Colab setup: authenticate the Google account (only required for exporting to
# Drive/GEE/GCP) and mount Google Drive, which holds the helper modules, the input
# imagery, the model file and the prediction outputs used below.
from google.colab import auth, drive

auth.authenticate_user()
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
from psutil import virtual_memory

# Report the runtime's RAM in decimal gigabytes.
# NOTE(review): `.total` is presumably the installed memory, not what is currently free,
# despite the wording of the message.
ram_gb = virtual_memory().total / 1e9
print(f'Your runtime has {ram_gb:.1f} gigabytes of available RAM\n')

Your runtime has 54.8 gigabytes of available RAM



In [None]:
# Complete the environment.
# Copy the project's helper modules from Drive into /content so that
# `import mightymosaic` and `import gs_utils` in the imports cell can find them
# (presumably /content is the kernel's working directory -- confirm).
!cp /content/drive/MyDrive/Colab\ Notebooks/UNET_lake_identifier/utils/mightymosaic.py /content
!cp /content/drive/MyDrive/Colab\ Notebooks/UNET_lake_identifier/utils/gs_utils.py /content
# Install the third-party packages imported below.
# NOTE(review): installs are unpinned; the recorded run installed rasterio 1.2.10 and
# rioxarray 0.9.1 (Python 3.7 wheels), so later runs may pick up different versions.
!pip install rasterio
!pip install rioxarray
!pip install segmentation_models

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rasterio
  Downloading rasterio-1.2.10-cp37-cp37m-manylinux1_x86_64.whl (19.3 MB)
[K     |████████████████████████████████| 19.3 MB 27.9 MB/s 
[?25hCollecting snuggs>=1.4.1
  Downloading snuggs-1.4.7-py3-none-any.whl (5.4 kB)
Collecting affine
  Downloading affine-2.3.1-py2.py3-none-any.whl (16 kB)
Collecting cligj>=0.5
  Downloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Collecting click-plugins
  Downloading click_plugins-1.1.1-py2.py3-none-any.whl (7.5 kB)
Installing collected packages: snuggs, cligj, click-plugins, affine, rasterio
Successfully installed affine-2.3.1 click-plugins-1.1.1 cligj-0.7.2 rasterio-1.2.10 snuggs-1.4.7
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rioxarray
  Downloading rioxarray-0.9.1.tar.gz (47 kB)
[K     |████████████████████████████████| 47 kB 4.3 MB/s 
[?25h  Installing build d

In [None]:
import os
import tensorflow.keras as keras
import numpy as np
import rasterio
import segmentation_models as sm
from tensorflow.keras import models
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model
import mightymosaic as MightyMosaic
import gs_utils as gs_utils
import rioxarray as rxr

Segmentation Models: using `keras` framework.


In [None]:
# ---- Model --------------------------------------------------------------------
BACKBONE = 'efficientnetb7'
# preprocessing function matching the encoder backbone
preprocess_input = sm.get_preprocessing(BACKBONE)

model_file = '/content/drive/MyDrive/waterbody-mask/UNET/single_class_slope_best_model.h5'
slope_input = True  # True if using the slope model, False otherwise
classes = 'single'  # either 'single' or 'multi'

# ---- Data locations -----------------------------------------------------------
image_in_dir = '/content/drive/MyDrive/S2_BGRN/'  # directory of images to predict
image_descriptor = 'BGRN_SR.tif'  # filename substring used to select input images
path_to_slope = '/content/drive/MyDrive/AK_6N_slope.tif'  # slope raster
prediction_output_dir = '/content/drive/MyDrive/UNET_outputs/'  # where prediction masks are saved

# ---- Moving-window prediction -------------------------------------------------
# Overlap factor for moving window prediction. Higher values increasingly mitigate
# image tiling artifacts but can potentially introduce noise.
MM_overlap_factor = 2
# Batch size for model prediction. Number of tiles in the image must be divisible by this number.
prediction_batch_size = 2

# ---- tf.data options ----------------------------------------------------------
# NOTE(review): `options` is not referenced anywhere else in this notebook.
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF

In [None]:
def _prediction_name(image):
    """Output file name for an input image: the part before `image_descriptor` + '_unet_pred.tif'."""
    return image.split(image_descriptor)[0] + '_unet_pred.tif'


def main():
    """
        Run u-net prediction for every image in `image_in_dir` that has no prediction yet.

        Configuration comes from the module-level settings (`model_file`, `image_in_dir`,
        `image_descriptor`, `slope_input`, `path_to_slope`, `classes`, `MM_overlap_factor`,
        `prediction_batch_size`, `prediction_output_dir`). One GeoTIFF named
        `<name before image_descriptor>_unet_pred.tif` is written per image. Images whose
        output file already exists are skipped, so an interrupted run can simply be restarted.

        Every image goes through the slope-aware
        `create_moving_window_data(image_dir, image, slope, preprocess_input)` signature;
        `None` is passed as `slope` when `slope_input` is False.

        :raises ValueError: if `classes` is neither 'single' nor 'multi'.
    """
    # fail fast instead of silently writing nothing for an unknown mode
    if classes not in ('single', 'multi'):
        raise ValueError(f"classes must be 'single' or 'multi', got {classes!r}")

    print('compiling model')
    model = compile_cnn(model_file)

    # get list of image names to predict
    images = [i for i in os.listdir(image_in_dir) if image_descriptor in i]
    print(images)
    # remove image names that already have predictions
    # (list the output directory once instead of once per image)
    existing_predictions = set(os.listdir(prediction_output_dir))
    images = [j for j in images if _prediction_name(j) not in existing_predictions]
    print(images)

    # open the slope raster once for all images; None when the model has no slope band
    xds_slope = rxr.open_rasterio(path_to_slope) if slope_input else None

    for image in images:
        print('formatting input data')
        full_im, crs, gt, image_shape = create_moving_window_data(image_in_dir, image, xds_slope, preprocess_input)

        print('generating Mighty Mosaic')
        mosaic = MightyMosaic.from_array(full_im, (256, 256), overlap_factor=MM_overlap_factor, fill_mode='reflect')
        # drop references to large intermediates as soon as they are no longer needed,
        # instead of keeping them alive until the next iteration overwrites them
        del full_im

        print('predicting for input data')
        prediction_mosaic = mosaic.apply(model.predict, progress_bar=True, batch_size=prediction_batch_size)
        del mosaic
        prediction_mosaic = prediction_mosaic.get_fusion()

        if classes == 'single':
            prediction = prediction_mosaic[:, :, 0]
        else:  # classes == 'multi' (validated above)
            # 1 where channel 0 < 0.5 and channel 1 or channel 2 >= 0.5, else 0
            prediction = np.zeros(np.shape(prediction_mosaic[:, :, 0]))
            mask = ((prediction_mosaic[:, :, 0] < 0.5)
                    & ((prediction_mosaic[:, :, 1] >= 0.5) | (prediction_mosaic[:, :, 2] >= 0.5)))
            prediction[mask] = 1

        # write output image
        gs_utils.write_geotiff(prediction_output_dir + _prediction_name(image),
                               image_shape, gt, crs, prediction, 0)
        del prediction_mosaic, prediction

        print('')


# NOTE(review): compile_cnn and create_moving_window_data are defined in later cells;
# on a fresh kernel, run those cells before this one or main() raises NameError.
if __name__ == "__main__":
    main()

In [None]:
# TODO: on a previous run main() stopped early -- compare it with the original script.
# Suspected cause: an incorrectly indented `if` statement around the `for image in images:` loop.

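## Separating out the components of the process

The cells below run the steps of `main()` one at a time, so that large intermediates can be deleted between steps to save RAM.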

In [None]:
# Functions
# compile cnn from .h5 model file
def compile_cnn(model_file):
    """
        Load the trained u-net from an .h5 model file.

        The segmentation_models losses/metrics referenced by the saved model are not
        built into Keras, so they are registered through ``custom_objects``.

        :param model_file: path to .h5 model file.
        :type model_file: `str` representing the path to .h5 model file.
        :return: u-net model
        :rtype: 'tf.keras.Model'
        :raises ValueError: if the global ``classes`` is neither 'single' nor 'multi'
            (previously this fell through and raised UnboundLocalError on ``model``).
    """
    if classes not in ('single', 'multi'):
        raise ValueError(f"classes must be 'single' or 'multi', got {classes!r}")

    # metrics registered for both the single- and multi-class models
    custom_objects = {'precision': sm.metrics.Precision,
                      'recall': sm.metrics.Recall,
                      'f1-score': sm.metrics.FScore,
                      'iou_score': sm.metrics.IOUScore}
    if classes == 'single':
        # the single-class model also refers to a loss saved under this name
        custom_objects['binary_focal_loss_plus_jaccard_loss'] = sm.losses.binary_focal_jaccard_loss

    return keras.models.load_model(model_file, custom_objects=custom_objects)
def create_moving_window_data(image_dir, image, preprocess_input):
    """
        Format input data for u-net prediction (image bands only, no slope band).

        :param image_dir: path to directory containing images to be predicted.
        :type image_dir: `str` representing path to directory containing images to be predicted.
        :param image: name of image to predict.
        :type image: `str` representing name of image, must end in .tif.
        :param preprocess_input: u-net preprocessing function specific to model backbone (efficientnet-b7).
        :type preprocess_input: 'function'.
        :return: full_im
        :rtype: 'np.array' of preprocessed image input data, bands last (y, x, bands)
        :return: crs (projection)
        :rtype: 'str' wkt representation of image projection
        :return: gt (geotransform)
        :rtype: 'tuple' array of image geotransform in gdal format
        :return: shape
        :rtype: 'tuple' dimensions of input image in (y,x)
    """
    raster = rxr.open_rasterio(os.path.join(image_dir, image))

    # georeferencing of the source image, handed back for writing the prediction
    crs = raster.rio.crs.wkt
    gt = raster.rio.transform().to_gdal()
    shape = np.shape(raster[0])

    # map 0-10000 onto 0-255; pixels that are exactly 0 are replaced by 0.5
    scaled = raster.values / 10000.0 * 255
    scaled[scaled == 0] = 0.5

    # (bands, y, x) -> (y, x, bands), then backbone-specific preprocessing
    bands_last = scaled.transpose(1, 2, 0)
    return preprocess_input(bands_last), crs, gt, shape

In [None]:
# Load the trained model once; the prediction cells below reuse `model`.
model = compile_cnn(model_file=model_file)

In [None]:
# Every entry in the input directory (no image_descriptor filter is applied here).
images = list(os.listdir(image_in_dir))
print(images)

['20200817YKF_BGRN_SR-0000000000-0000000000.tif', '20200817YKF_BGRN_SR-0000000000-0000023296.tif']


In [None]:
# Predict every file in image_in_dir with the model loaded above (no slope band) and
# write channel 0 of each fused prediction as '<name before the first dot>_unet_pred.tif'.
# crs, gt and image_shape come from create_moving_window_data, so the image is not
# opened a second time here (the previous separate open was immediately overwritten).
# NOTE(review): this is the 3-argument call; re-running this cell after the slope-aware
# create_moving_window_data(image_dir, image, slope, preprocess_input) has been defined
# raises TypeError.
images = [i for i in os.listdir(image_in_dir)]
for image in images:
  full_im, crs, gt, image_shape = create_moving_window_data(image_in_dir, image, preprocess_input)
  mosaic = MightyMosaic.from_array(full_im, (256,256), overlap_factor=MM_overlap_factor, fill_mode='reflect')
  del full_im  # drop the input stack before prediction to lower RAM use
  prediction_mosaic = mosaic.apply(model.predict, progress_bar=True, batch_size=prediction_batch_size)
  del mosaic
  prediction_mosaic = prediction_mosaic.get_fusion()
  gs_utils.write_geotiff(prediction_output_dir + image.split('.')[0] + '_unet_pred.tif', 
                    image_shape, gt, crs, prediction_mosaic[:,:,0], 0)
  del prediction_mosaic

NameError: ignored

In [None]:
# Pick one image to step through the pipeline cell by cell.
# os.listdir returns entries in arbitrary order, so sort first to make the choice of
# images[1] reproducible across runs and machines.
images = sorted(os.listdir(image_in_dir))
image = images[1]

In [None]:
# show the name of the image selected above
image

'201906_dj_BGRN_SR-0000000000-0000000000.tif'

In [None]:
# Open the selected image and read its georeferencing: WKT projection and GDAL-style geotransform.
# NOTE: create_moving_window_data recomputes crs and gt, so this cell is only for inspection.
xds_full_im = rxr.open_rasterio(os.path.join(image_in_dir, image))
crs, gt = xds_full_im.rio.crs.wkt, xds_full_im.rio.transform().to_gdal()

In [None]:
# Build the preprocessed, bands-last input stack for the selected image.
# NOTE(review): this is the 3-argument call; it fails with TypeError once the slope-aware
# create_moving_window_data(image_dir, image, slope, preprocess_input) has been defined.
full_im, crs, gt, image_shape = create_moving_window_data(
    image_in_dir,
    image,
    preprocess_input,
)

In [None]:
# Split the input into overlapping 256x256 windows; MM_overlap_factor controls the overlap
# used to mitigate tiling artifacts, and edges are padded with reflected values.
mosaic = MightyMosaic.from_array(
    full_im,
    (256, 256),
    overlap_factor=MM_overlap_factor,
    fill_mode='reflect',
)

In [None]:
# drop this reference to the input stack now that the mosaic has been built, to reduce RAM use
# NOTE(review): this only frees memory if MightyMosaic did not keep a reference to the same array
del full_im

In [None]:
# Run the model over every tile, then fuse the overlapping tile predictions into a single
# (y, x, channels) array. The tiled input is deleted before fusion to keep RAM use down.
prediction_mosaic = mosaic.apply(
    model.predict,
    progress_bar=True,
    batch_size=prediction_batch_size,
)
del mosaic
prediction_mosaic = prediction_mosaic.get_fusion()

  0%|          | 0/2760 [00:00<?, ?it/s]

In [None]:
# Write channel 0 of the fused prediction as a GeoTIFF.
# When the name contains image_descriptor, keep the part before it (the naming main() uses);
# otherwise just drop the extension. Without this fallback, a tiled name such as
# '..._BGRN_SR-0000000000-0000000000.tif' (which does not contain image_descriptor) was
# written as '....tif_unet_pred.tif'.
gs_utils.write_geotiff(
    prediction_output_dir
    + (image.split(image_descriptor)[0] if image_descriptor in image else os.path.splitext(image)[0])
    + '_unet_pred.tif',
    image_shape, gt, crs, prediction_mosaic[:, :, 0], 0)

#### Slope model: the same pipeline with separated components

In [None]:
# Functions
###########################################################################################################################################################

# compile cnn from .h5 model file
def compile_cnn(model_file):
    """
        Load the trained u-net from an .h5 model file.

        The segmentation_models losses/metrics referenced by the saved model are not
        built into Keras, so they are registered through ``custom_objects``.

        :param model_file: path to .h5 model file.
        :type model_file: `str` representing the path to .h5 model file.
        :return: u-net model
        :rtype: 'tf.keras.Model'
        :raises ValueError: if the global ``classes`` is neither 'single' nor 'multi'
            (previously this fell through and raised UnboundLocalError on ``model``).
    """
    if classes not in ('single', 'multi'):
        raise ValueError(f"classes must be 'single' or 'multi', got {classes!r}")

    # metrics registered for both the single- and multi-class models
    custom_objects = {'precision': sm.metrics.Precision,
                      'recall': sm.metrics.Recall,
                      'f1-score': sm.metrics.FScore,
                      'iou_score': sm.metrics.IOUScore}
    if classes == 'single':
        # the single-class model also refers to a loss saved under this name
        custom_objects['binary_focal_loss_plus_jaccard_loss'] = sm.losses.binary_focal_jaccard_loss

    return keras.models.load_model(model_file, custom_objects=custom_objects)
    
def create_moving_window_data(image_dir, image, slope, preprocess_input):
    """
        Format input data for u-net prediction.

        Reads the image, maps its values from 0-10000 onto 0-255, marks zero-valued
        (nodata) pixels with 0.5, optionally stacks a slope band reprojected onto the
        image grid, moves the band axis last and applies the backbone preprocessing.

        :param image_dir: path to directory containing images to be predicted.
        :type image_dir: `str` representing path to directory containing images to be predicted.
        :param image: name of image to predict.
        :type image: `str` representing name of image, must end in .tif.
        :param slope: slope raster; only used when the global ``slope_input`` is True
            (may be None otherwise).
        :type slope: `xarray.core.dataarray.DataArray` of slope raster, or None.
        :param preprocess_input: u-net preprocessing function specific to model backbone (efficientnet-b7).
        :type preprocess_input: 'function'.
        :return: full_im
        :rtype: 'np.array' of preprocessed image input data, (y, x, bands), with the slope
            band last when ``slope_input`` is True
        :return: crs (projection)
        :rtype: 'str' wkt representation of image projection
        :return: gt (geotransform)
        :rtype: 'tuple' array of image geotransform in gdal format
        :return: shape
        :rtype: 'tuple' dimensions of input image in (y,x)
        :raises ValueError: if ``slope_input`` is True but ``slope`` is None.
    """
    # fail with a clear message instead of an AttributeError deep inside rioxarray
    if slope_input and slope is None:
        raise ValueError('slope_input is True but no slope raster was passed')

    # open image
    xds_full_im = rxr.open_rasterio(os.path.join(image_dir, image))
    # get projection information (crs, geotransform, and raster shape)
    crs = xds_full_im.rio.crs.wkt
    gt = xds_full_im.rio.transform().to_gdal()
    shape = np.shape(xds_full_im[0])
    # band-first array; map 0-10000 onto 0-255
    full_im = (xds_full_im.values / 10000.0) * 255
    # set areas where image=0 (treated as nodata) to 0.5
    image_zero_mask = full_im == 0
    full_im[image_zero_mask] = 0.5

    # stack on optional slope raster
    if slope_input:
        # reproject to match image extent, resolution, and crs, then force identical
        # x/y coordinates so the two grids line up exactly
        slope_data = slope.rio.reproject_match(xds_full_im)
        slope_data = slope_data.assign_coords({
            "x": xds_full_im.x,
            "y": xds_full_im.y,
        })
        # cast to float so the 0.5 fill below is not truncated to 0 if the slope raster
        # is integer-typed (float inputs were upcast to float64 by the stacking anyway)
        slope_data = np.asarray(slope_data.values[0], dtype=np.float64)
        # set pixels where the first image band is nodata to 0.5 in the slope band too
        slope_data[image_zero_mask[0]] = 0.5
        # add slope as an extra band
        full_im = np.concatenate([full_im, slope_data[np.newaxis]], axis=0)

    # (bands, y, x) -> (y, x, bands), then backbone-specific preprocessing
    full_im = full_im.transpose(1, 2, 0)
    full_im = preprocess_input(full_im)

    return full_im, crs, gt, shape

In [None]:
# Separating main() into the components below, run cell by cell, to reduce RAM usage

In [None]:
# Load the trained (slope) model once; the prediction loop below reuses `model`.
model = compile_cnn(model_file=model_file)

In [None]:
# List the input images whose names contain image_descriptor.
images = list(filter(lambda name: image_descriptor in name, os.listdir(image_in_dir)))
for image in images:
    print(image)

20180909_06VVM_BGRN_SR.tif
20200817_06VVR_BGRN_SR.tif
20190629_06VVN_BGRN_SR.tif
20180613_06VWM_BGRN_SR.tif
20190808_06VUM_BGRN_SR.tif
20200817_06VVP_BGRN_SR.tif
20160831_06VUR_BGRN_SR.tif
20210524_06VUP_BGRN_SR.tif
20210524_06VVQ_BGRN_SR.tif
20210524_06VUN_BGRN_SR.tif


In [None]:
# Predict every image whose name contains image_descriptor with the slope model and write
# channel 0 of each fused prediction as '<name before image_descriptor>_unet_pred.tif'.
# crs, gt and image_shape come from create_moving_window_data, so the image is not opened
# a second time here (the previous separate open was immediately overwritten).
images = [i for i in os.listdir(image_in_dir) if image_descriptor in i]
for image in images:
  # The slope raster is re-opened for each image and deleted right after use.
  # NOTE(review): presumably so slope data read during reprojection is not held in RAM
  # across images -- confirm before hoisting this out of the loop.
  xds_slope = rxr.open_rasterio(path_to_slope)
  full_im, crs, gt, image_shape = create_moving_window_data(image_in_dir, image, xds_slope, preprocess_input)
  del xds_slope
  mosaic = MightyMosaic.from_array(full_im, (256,256), overlap_factor=MM_overlap_factor, fill_mode='reflect')
  del full_im  # drop the input stack before prediction to lower RAM use
  prediction_mosaic = mosaic.apply(model.predict, progress_bar=True, batch_size=prediction_batch_size)
  del mosaic  # free the tiled input before fusing, as in the step-by-step cells
  prediction_mosaic = prediction_mosaic.get_fusion()
  gs_utils.write_geotiff(prediction_output_dir + image.split(image_descriptor)[0] + '_unet_pred.tif', 
                      image_shape, gt, crs, prediction_mosaic[:,:,0], 0)
  del prediction_mosaic

  0%|          | 0/3784 [00:00<?, ?it/s]

  0%|          | 0/3784 [00:00<?, ?it/s]

  0%|          | 0/3698 [00:00<?, ?it/s]

  0%|          | 0/3698 [00:00<?, ?it/s]

  0%|          | 0/3698 [00:00<?, ?it/s]

  0%|          | 0/3698 [00:00<?, ?it/s]

  0%|          | 0/3784 [00:00<?, ?it/s]

  0%|          | 0/3784 [00:00<?, ?it/s]

  0%|          | 0/3784 [00:00<?, ?it/s]