In [None]:
#@title Copyright 2023 Google LLC. { display-mode: "form" }
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="ee-notebook-buttons" align="left"><td>
<a target="_blank"  href="http://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_Decision_Forests.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" /> Run in Google Colab</a>
</td><td>
<a target="_blank"  href="https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_Decision_Forests.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source on GitHub</a></td></table>

# Using a TensorFlow Decision Forest model in Earth Engine

[TensorFlow Decision Forests (TF-DF)](https://www.tensorflow.org/decision_forests) is an implementation of popular tree-based machine learning models in TensorFlow.  These models can be trained, saved and hosted on Vertex AI, as with TensorFlow neural networks.  This notebook demonstrates how to install TF-DF, train a random forest, host the model on Vertex AI and get interactive predictions in Earth Engine.  The demonstration model produces a map of land cover from Landsat image data and pre-generated training data.

To get started, import the necessary libraries and authenticate.

#### **Warning!** This demo consumes billable resources of Google Cloud, including Earth Engine, Vertex AI and Cloud Storage.

## Setup

Note that a specific version of TF-DF needs to be installed.  This is because Vertex AI only supports [specific versions of TensorFlow in pre-built containers](https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers).  Although it's possible to build a custom container image with a more recent version, here we'll just use TF-DF version 1.4 so that we don't have to build a custom container image.  See [the TF-DF compatibility table](https://www.tensorflow.org/decision_forests/known_issues#compatibility_table) for more info.

In [None]:
!pip3 install -q tensorflow_decision_forests==1.4.0

In [None]:
import ee
import folium
from google.colab import auth
import google
import tensorflow as tf
import tensorflow_decision_forests as tfdf

Ensure the TensorFlow version printed below matches the TF-DF version installed above.  (It should be 2.12.0 to match TF-DF 1.4.0.)

In [None]:
# Show the version of the imported TensorFlow package so it can be checked
# against the pinned TF-DF release.
print(tf.__version__)

In [None]:
# Authenticate this Colab session, then fetch the default credentials.
# `credentials` is passed to ee.Initialize() in a later cell.
auth.authenticate_user()
credentials, project = google.auth.default()

In [None]:
# REPLACE WITH YOUR CLOUD PROJECT!
# Used to initialize Earth Engine and in every gcloud command in this notebook.
PROJECT = 'your-project'
# A Cloud Storage bucket into which you can write saved model artifacts.
# Give the bucket name only: the 'gs://' prefix is added when MODEL_DIR is
# built.
OUTPUT_BUCKET = 'your-bucket'

In [None]:
# Initialize the Earth Engine client with the credentials obtained above,
# using PROJECT as the Cloud project.
ee.Initialize(credentials, project=PROJECT)

In [None]:
# Cloud Storage folder that receives the saved model.
MODEL_DIR = f'gs://{OUTPUT_BUCKET}/tfdf_demo'

# Names under which the model and the endpoint are registered.
MODEL_NAME = 'tfdf-demo-14'
ENDPOINT_NAME = 'tfdf-endpoint-14'

# Container image used by the host machine to serve the model.
CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-12:latest'

# Cloud Storage bucket that holds the training and testing datasets.
DATA_BUCKET = 'ee-docs-demos'

# Region in which the model is hosted.
REGION = 'us-central1'

# Training and testing dataset files inside DATA_BUCKET.
TRAIN_FILE_PREFIX = 'Training_demo'
TEST_FILE_PREFIX = 'Testing_demo'
file_extension = '.tfrecord.gz'
TRAIN_FILE_PATH = f'gs://{DATA_BUCKET}/{TRAIN_FILE_PREFIX}{file_extension}'
TEST_FILE_PATH = f'gs://{DATA_BUCKET}/{TEST_FILE_PREFIX}{file_extension}'

# Property, set on each point, that stores the label as a consecutive
# integer index starting from zero.
LABEL = 'landcover'
# Number of distinct label values, i.e. classes in the classification.
N_CLASSES = 3
CLASS_NAMES = ['bare', 'veg', 'water']

# Landsat 8 surface reflectance collection used for the predictors.
L8SR = ee.ImageCollection('LANDSAT/LC08/C02/T1_L2')
# Bands used for prediction.
BANDS = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']

# Property names used when exporting training/testing data, and when mapping
# names to data while reading into TensorFlow datasets: every predictor band
# followed by the label.
FEATURE_NAMES = BANDS + [LABEL]

# Every feature is a fixed-length float32 of shape [1].
FEATURES_DICT = {
    name: tf.io.FixedLenFeature([1], tf.float32) for name in FEATURE_NAMES
}

ATTRIBUTION = 'Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>'

## Load datasets

In [None]:
# Report whether each input file is visible in Cloud Storage.
if tf.io.gfile.exists(TRAIN_FILE_PATH):
  print('Found training file.')
else:
  print('No training file found.')

if tf.io.gfile.exists(TEST_FILE_PATH):
  print('Found testing file.')
else:
  print('No testing file found.')

In [None]:
# Create one dataset from BOTH the training and the testing TFRecord files in
# Cloud Storage.  Despite its name, `train_dataset` contains the records of
# the two files.
train_dataset = tf.data.TFRecordDataset([TRAIN_FILE_PATH, TEST_FILE_PATH],
                                        compression_type='GZIP')

# Show the first (still serialized) record to check.  The built-in next() is
# the standard way to pull one element from an iterator.
next(iter(train_dataset))

In [None]:
def parse_tfrecord(example_proto):
  """Decode one serialized Example into predictors and an integer label.

  The record is parsed according to FEATURES_DICT, after which the LABEL
  entry is removed from the parsed features and returned separately.

  Args:
    example_proto: a serialized Example.

  Returns:
    A tuple of the predictors dictionary and the LABEL, cast to an `int32`.
  """
  predictors = tf.io.parse_single_example(example_proto, FEATURES_DICT)
  label = tf.cast(predictors.pop(LABEL), tf.int32)
  return predictors, label

# Map the parsing function over the dataset.
parsed_dataset = train_dataset.map(parse_tfrecord, num_parallel_calls=4)

# Show the first parsed record (predictors dictionary, label) to check.  The
# built-in next() is the standard way to pull one element from an iterator.
next(iter(parsed_dataset))

## Fit a TensorFlow Decision Forest

In [None]:
# Random forest with the library's default hyperparameters; only the
# verbosity is set (verbose=2 turns on detailed training logs).
rf_model = tfdf.keras.RandomForestModel(verbose=2)

# Train the model.  You can ignore AutoGraph warnings.
# The parsed records are fed in batches of one example each.
rf_model.fit(x=parsed_dataset.batch(1))

In [None]:
# Plot the first tree of the forest (tree_idx=0), truncated at max_depth=3.
tfdf.model_plotter.plot_model_in_colab(rf_model, tree_idx=0, max_depth=3)

In [None]:
# Print a summary of the trained model.
rf_model.summary()

## Reshape inputs/outputs

Earth Engine sends and receives data as dictionaries keyed by band, where each band stores a patch.  TF-DF (as a non-spatial model) doesn't know about patches, and instead just receives flattened arrays.  Add some reshaping layers to flatten the inputs and reshape the model outputs into a patch.

In [None]:
# Custom Keras layer that reshapes the input data.
class ReshapeInputEE(tf.keras.layers.Layer):
  """Reshapes every tensor of an input dictionary to shape [-1, 1]."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, tensor_dict):
    # Flatten each entry into a single column, keeping the dictionary keys.
    return {
        name: tf.reshape(band, [-1, 1])
        for name, band in tensor_dict.items()
    }

  def get_config(self):
    return super().get_config()


# Keras layer that reshapes the model predictions to what EE requires.
class ReshapeOutputEE(tf.keras.layers.Layer):
  """Reshapes predictions to [-1, height, width, N_CLASSES].

  `call` takes a two-element sequence: the original input dictionary and the
  model predictions.  Height and width are read from dimensions 1 and 2 of
  the first tensor in the input dictionary.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, inputs):
    input_dict = inputs[0]
    model_predictions = inputs[1]
    # The original input is expected to be [batch, height, width, channels];
    # any entry of the dictionary can supply the patch size.
    first_band = next(iter(input_dict.values()))
    patch_shape = tf.shape(first_band)
    height, width = patch_shape[1], patch_shape[2]
    return tf.reshape(model_predictions, [-1, height, width, N_CLASSES])

  def get_config(self):
    return super().get_config()


# Layers that convert between Earth Engine patches and flat TF-DF rows.
input_reshaper = ReshapeInputEE(name="input_reshaper")
output_reshaper = ReshapeOutputEE(name="output_reshaper")

# Symbolic inputs: a dictionary keyed by band name in which each entry is a
# [H, W, 1] patch of inputs for that band.
inputs = {
    band_name: tf.keras.Input(shape=(None, None, 1), name=band_name)
    for band_name in BANDS
}

# Chain flatten -> random forest -> reshape, then wrap the result in a model.
flat_inputs = input_reshaper(inputs)
flat_predictions = rf_model(flat_inputs)
patch_predictions = output_reshaper([inputs, flat_predictions])
wrapped_model = tf.keras.Model(
    inputs, patch_predictions, name="RF_with_reshaping")

## De/serialization

De/serialization prepares the model for hosting on Google Cloud.  Specifically, you need to provide image data to the Vertex AI API by sending the image data as base64-encoded text ([reference](https://cloud.google.com/vertex-ai/docs/general/base64)).  Wrap the trained model in de/serialization layers to handle the conversion.

In [None]:
class DeSerializeInput(tf.keras.layers.Layer):
  """Decodes a dictionary of base64 strings into float32 tensors."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, inputs_dict):
    def decode_all(encoded):
      # Base64 text -> serialized tensor bytes -> float32 tensor, applied to
      # each element along the leading dimension.
      return tf.map_fn(lambda item: tf.io.parse_tensor(item, tf.float32),
                       tf.io.decode_base64(encoded),
                       fn_output_signature=tf.float32)

    return {key: decode_all(value) for (key, value) in inputs_dict.items()}

  def get_config(self):
    return super().get_config()


class ReSerializeOutput(tf.keras.layers.Layer):
  """Serializes each element of the output tensor to a base64 string."""

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, output_tensor):
    def encode_one(element):
      # Tensor -> serialized bytes -> base64 text.
      serialized = tf.io.serialize_tensor(element)
      return tf.io.encode_base64(serialized)

    return tf.map_fn(encode_one, output_tensor, fn_output_signature=tf.string)

  def get_config(self):
    return super().get_config()

# De/serialization layers.  Note that `output_deserializer` holds the layer
# that re-serializes the model output.
input_deserializer = DeSerializeInput()
output_deserializer = ReSerializeOutput()

# One string input per band, keyed by band name.
serialized_inputs = {
    band_name: tf.keras.Input(shape=[], dtype='string', name=band_name)
    for band_name in BANDS
}

# Chain decode -> reshaping model -> encode, then wrap the result in a model.
updated_model_input = input_deserializer(serialized_inputs)
decoded_predictions = wrapped_model(updated_model_input)
encoded_predictions = output_deserializer(decoded_predictions)
updated_model = tf.keras.Model(serialized_inputs, encoded_predictions)

In [None]:
# Draw the layer graph of the model that includes the de/serialization layers.
tf.keras.utils.plot_model(updated_model)

## Save the model

In [None]:
# You may ignore compiler and absl warnings.
# The model is written to MODEL_DIR, a gs:// path inside OUTPUT_BUCKET.
updated_model.save(MODEL_DIR)

## Deploy the model to Vertex AI

In [None]:
# Upload the saved model in MODEL_DIR to Vertex AI, to be served with
# CONTAINER_IMAGE.  MODEL_NAME is reused as the description, the display name
# and the model ID.
!gcloud ai models upload \
  --artifact-uri={MODEL_DIR} \
  --project={PROJECT} \
  --region={REGION} \
  --container-image-uri={CONTAINER_IMAGE} \
  --description={MODEL_NAME} \
  --display-name={MODEL_NAME} \
  --model-id={MODEL_NAME}

In [None]:
# Create the endpoint that will host the model.
# NOTE(review): re-running this cell presumably creates another endpoint with
# the same display name -- confirm before re-running.
!gcloud ai endpoints create \
  --display-name={ENDPOINT_NAME} \
  --region={REGION} \
  --project={PROJECT}

In [None]:
# Look up the endpoint by display name.  Assigning a `!` command captures its
# output as a list of lines.
ENDPOINT_ID = !gcloud ai endpoints list \
  --project={PROJECT} \
  --region={REGION} \
  --filter=displayName:{ENDPOINT_NAME} \
  --format="value(ENDPOINT_ID.scope())"
# Keep only the last output line as the endpoint ID.
ENDPOINT_ID = ENDPOINT_ID[-1]
print(ENDPOINT_ID)

In [None]:
# Deploy the uploaded model (referenced by MODEL_NAME) to the endpoint found
# above, on an n1-standard-4 machine.
!gcloud ai endpoints deploy-model {ENDPOINT_ID} \
  --project={PROJECT} \
  --region={REGION} \
  --model={MODEL_NAME} \
  --machine-type=n1-standard-4 \
  --display-name={MODEL_NAME}

## Connect to the hosted model from Earth Engine

1. Generate the input imagery.  This should be done in exactly the same way as the training data were generated.  See [this example notebook](http://colab.research.google.com/github/google/earthengine-api/blob/master/python/examples/ipynb/TF_demo1_keras.ipynb) for details.
2. Connect to the hosted model.
3. Use the model to make predictions.
4. Display the results.

Note that it may take the model a couple of minutes to spin up before it is ready to make predictions.

In [None]:
def mask_L8_sr(image):
  """Cloud masking function for Landsat 8, collection 2.

  Rescales the optical (SR_B*) and thermal (ST_B*) bands, masks pixels
  flagged in the QA_PIXEL or QA_RADSAT bands, and keeps only the optical
  bands.

  Args:
    image: an image exposing the QA_PIXEL, QA_RADSAT, SR_B* and ST_B* bands.

  Returns:
    The rescaled, masked image restricted to the bands matching 'SR_B.'.
  """
  # 31 == 0b11111: a pixel is kept only when the five lowest QA_PIXEL bits are
  # all unset (fill, dilated cloud, cirrus, cloud and cloud shadow according
  # to the dataset's QA bit table).
  qa_mask = image.select('QA_PIXEL').bitwiseAnd(31).eq(0)
  # Keep only pixels for which no band is flagged as saturated.
  saturation_mask = image.select('QA_RADSAT').eq(0)

  # Apply the scale and offset to the optical and thermal bands.
  optical_bands = image.select('SR_B.').multiply(0.0000275).add(-0.2)
  thermal_bands = image.select('ST_B.*').multiply(0.00341802).add(149.0)

  # addBands(..., None, True) overwrites the raw bands with the scaled ones.
  # The final selection uses the same 'SR_B.' regex as the scaling step; the
  # previous 'SR_B*.' applied the '*' quantifier to the letter 'B'.
  return (image.addBands(optical_bands, None, True)
               .addBands(thermal_bands, None, True)
               .updateMask(qa_mask)
               .updateMask(saturation_mask)
               .select('SR_B.'))

# The image input data is a 2018 cloud-masked median composite.
image = L8SR.filterDate(
    '2018-01-01', '2018-12-31').map(mask_L8_sr).select(BANDS).median().float()

# Get a URL to serve image tiles.
mapid = image.getMapId({'bands': ['SR_B4', 'SR_B3', 'SR_B2'], 'min': 0, 'max': 0.3})

# Use folium to visualize the imagery.  The map object is deliberately not
# named `map`, which would shadow the Python built-in for the rest of the
# kernel session.
folium_map = folium.Map(location=[37.6, -122.3], zoom_start=11)

# Inputs.
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr=ATTRIBUTION,
    overlay=True,
    name='median composite',
  ).add_to(folium_map)

# Full resource name of the endpoint deployed above.
endpoint_path = (
    'projects/' + PROJECT + '/locations/' + REGION + '/endpoints/' + str(ENDPOINT_ID))

# Connect to the hosted model.
vertex_model = ee.Model.fromVertexAi(
  endpoint=endpoint_path,
  inputTileSize=[8, 8],
  proj=ee.Projection('EPSG:4326').atScale(30),
  fixInputProj=True,
  outputBands={'output': {
      'type': ee.PixelType.float(),
      'dimensions': 1
    }
  })

# Predict on the composite, then split the per-pixel output array into one
# band per class, named after CLASS_NAMES.
predictions = vertex_model.predictImage(image.select(BANDS).float())
probabilities = predictions.arrayFlatten([CLASS_NAMES])
probability_vis = {
    'bands': CLASS_NAMES, 'min': 0.2, 'max': 0.5, 'format': 'png'
}
probability_mapid = probabilities.getMapId(probability_vis)
folium.TileLayer(
    tiles=probability_mapid['tile_fetcher'].url_format,
    attr=ATTRIBUTION,
    overlay=True,
    name='predictions',
  ).add_to(folium_map)

# Add a layer switcher and render the map in the notebook.
folium_map.add_child(folium.LayerControl())
display(folium_map)

#### **Warning!** This demo consumes billable resources of Google Cloud, including Earth Engine, Vertex AI and Cloud Storage.  Be sure to shut down any prediction nodes to avoid ongoing charges.