In [None]:
#@title Copyright 2023 Google LLC. { display-mode: "form" }
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="ee-notebook-buttons" align="left"><td>
<a target="_blank"  href="http://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_tree_counting_model.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" /> Run in Google Colab</a>
</td><td>
<a target="_blank"  href="https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_tree_counting_model.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source on GitHub</a></td></table>

# Counting trees with Earth Engine and a custom TensorFlow model

Ever wonder how many trees there are?  Of course you have.  We were wondering the same thing.  People have been counting trees in imagery for a good long time ([Wang et al. 2004](https://doi.org/10.14358/PERS.70.3.351) have a nice little summary).  This kind of counting is useful for forest inventory, to understand species composition, density, diameter and height distributions.  Given that convolutional neural networks are fashionable nowadays, we wondered if tree-crown segmentation was a thing.  It is!  [Li et al. 2023](https://doi.org/10.1093/pnasnexus/pgad076) have a great paper, just out.  It's also commendable that they publish their code and other artifacts on GitHub ([their repo](https://github.com/sizhuoli/TreeCountSegHeight/tree/main)).  Specifically, they've got a bunch of trained TensorFlow models in there.  Jackpot!

## Host a trained model on Vertex AI to get predictions in Earth Engine

If there's suitable input data in Earth Engine (check the [extensive data catalog](https://developers.google.com/earth-engine/datasets) or [upload your own](https://developers.google.com/earth-engine/guides/image_upload)), you can get interactive predictions from geospatially-aware models hosted on Vertex AI.  By geospatially aware, we mean models that accept three-dimensional (four if you include the batch dimension) inputs: `[height, width, channels]`.  For deployment in a map, you also need to know the scale of the pixel, in ground units (meters).  [Li et al. 2023](https://doi.org/10.1093/pnasnexus/pgad076) explain most of this, and the rest is discoverable from the saved model itself.

To get started, import the necessary libraries and authenticate.

#### **Warning!** This demo consumes billable resources of Google Cloud, including Earth Engine, Vertex AI and Cloud Storage.

In [None]:
import ee
import folium
import google
import tensorflow as tf

from google.colab import auth
from keras import backend as K
from oauth2client.client import GoogleCredentials
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive

In [None]:
# `import google` alone does not load the `google.auth` submodule; import it
# explicitly instead of relying on another module having imported it.
import google.auth

# Authenticate the Colab user, then obtain application-default credentials.
# `credentials` is used to initialize Earth Engine below; the returned
# default `project` is not used (the PROJECT constant is used instead).
auth.authenticate_user()
credentials, project = google.auth.default()

In [None]:
# --- Cloud resources: edit these before running ---------------------------
PROJECT = 'your-project'  # REPLACE WITH YOUR CLOUD PROJECT!
BUCKET = 'your-bucket'    # REPLACE WITH YOUR (WRITABLE) CLOUD BUCKET!

# --- Vertex AI names for the hosted model and its endpoint -----------------
MODEL_NAME = 'trees_20210620-0202_adam_e4_redgreenblue_256'
ENDPOINT_NAME = 'trees_endpoint'
REGION = 'us-central1'

# Prebuilt serving container (TensorFlow 2.11, GPU) able to run the hosted
# SavedModel. See
# https://cloud.google.com/vertex-ai/docs/predictions/pre-built-containers#tensorflow
CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest'

# HTML attribution attached to the Earth Engine tile layers on the map.
ATTRIBUTION = 'Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>'

In [None]:
# Initialize the Earth Engine client with the Colab credentials, using PROJECT
# as the Cloud project for Earth Engine requests.
ee.Initialize(credentials, project=PROJECT)

### Download a trained model from GitHub

Grab a set of trained models from the link in the [GitHub repo](https://github.com/sizhuoli/TreeCountSegHeight/blob/main/models/readme_models.md).  Note that there are several models, with different input set options.  We're assuming that all we have is RGB imagery (i.e. no NIR or DSM data).  The cells below download the zipped models from Google Drive into the Colab runtime, unzip them, and load the RGB model directly from local disk.  Note that we're passing the Keras backend as a custom object (`K`), since the saved model refers to it by that name.  Since we don't need to (re)train the model, we don't need to compile it.  The wrapped model is written to your Cloud Storage bucket later, for hosting.

In [None]:
# Authorize PyDrive with the application-default credentials.
# NOTE(review): oauth2client (GoogleCredentials) is deprecated; confirm it is
# still available in the Colab runtime.
gauth = GoogleAuth()
gauth.credentials = GoogleCredentials.get_application_default()
drive = GoogleDrive(gauth)

# Google Drive file ID of the zipped trained models linked from the authors'
# repo; download it into the Colab working directory.
file_id = '1ZNibrh6pa4-cjXLawua6L96fOKS3uwbn'
downloaded = drive.CreateFile({'id': file_id})
downloaded.GetContentFile('saved_models.zip')

In [None]:
!unzip saved_models.zip

In [None]:
# List the extracted model files (one per input-band configuration).
!ls -l saved_models

In [None]:
# Grab the RGB model.
RGB_MODEL = 'trees_20210620-0202_Adam_e4_redgreenblue_256_84_frames_weightmapTversky_MSE100_5weight_attUNet.h5'

In [None]:
# Load the RGB model from the unzipped directory. The Keras backend is
# supplied as the custom object `K` (presumably referenced by name inside the
# saved layers), and compile=False skips restoring optimizer/loss state since
# the model is only used for inference.
model_file = f'saved_models/{RGB_MODEL}'
model = tf.keras.models.load_model(
    model_file,
    compile=False,
    custom_objects={'K': K},
)

Run a few checks to investigate the model.

In [None]:
# NOTE: `model.trainable` does NOT report whether the model has been trained.
# It only says whether the model's weights would be updated by training
# (e.g. `fit()`). The trained weights themselves were loaded from the .h5 file.
if model.trainable:
  print('The model weights are trainable (they would be updated by fit()).')
else:
  print('The model weights are frozen (they would not be updated by fit()).')
print(f'{len(model.weights)} weight tensors loaded.')

In [None]:
# Print the layer-by-layer architecture, including the output layers
# `output_seg` and `output_dens`.
model.summary()

In [None]:
# Smoke-test a forward pass on a single all-zeros 256x256 RGB tile; the
# result is a list of two output tensors.
zeros = tf.zeros(shape=[1, 256, 256, 3], dtype=tf.float32)
model(zeros)

### Prepare the model for hosting

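Note that the model returns a list of two outputs, from layers called `output_seg` and `output_dens`.  To host the model on Vertex AI, wrap it in de/serialization layers.  The input layer decodes base64-encoded, serialized tensors, and the output layer re-encodes each prediction the same way.  Then save the wrapped model to Cloud Storage.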

In [None]:
class DeSerializeInput(tf.keras.layers.Layer):
  """Turns base64-encoded, serialized float32 tensors back into tensors.

  Each value in the input dict is a batch of strings. Every element is
  decoded with `tf.io.decode_base64` and then parsed with
  `tf.io.parse_tensor(..., tf.float32)`. The parsed tensors are re-stacked
  along the batch axis and returned under the same keys.

  NOTE: keep the class name. Keras derives the default layer name
  (`de_serialize_input`) from it.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, inputs_dict):
    def parse_float32(serialized):
      return tf.io.parse_tensor(serialized, tf.float32)

    return {
        key: tf.map_fn(parse_float32, tf.io.decode_base64(encoded),
                       fn_output_signature=tf.float32)
        for key, encoded in inputs_dict.items()
    }

  def get_config(self):
    return super().get_config()


class ReSerializeOutput(tf.keras.layers.Layer):
  """Serializes each model output, element by element, to base64 strings.

  For every tensor in `outputs_list`, each slice along the first (batch)
  axis is passed through `tf.io.serialize_tensor` and then
  `tf.io.encode_base64`. One string batch per output is returned, in the
  same order.

  NOTE: keep the class name. Keras derives the default layer name
  (`re_serialize_output`) from it, and the Earth Engine output band names
  match that name.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, outputs_list):
    def to_base64(single):
      return tf.io.encode_base64(tf.io.serialize_tensor(single))

    serialized = []
    for tensor in outputs_list:
      serialized.append(
          tf.map_fn(to_base64, tensor, fn_output_signature=tf.string))
    return serialized

  def get_config(self):
    return super().get_config()


# Build the model to host: string input -> DeSerializeInput -> trained model
# -> ReSerializeOutput.
#
# The wrapper layers get explicit names. Otherwise re-running this cell in
# the same kernel creates auto-uniquified names (e.g. `re_serialize_output_1`).
# The serving output keys appear to derive from the output layer's name, and
# the Earth Engine cell below expects exactly `re_serialize_output` and
# `re_serialize_output_1` (one key per model output).
input_deserializer = DeSerializeInput(name='de_serialize_input')
output_serializer = ReSerializeOutput(name='re_serialize_output')

# Key the deserialized tensor by the trained model's input name so it is
# routed to that input.
serialized_inputs = {
    model.inputs[0].name: tf.keras.Input(shape=[], dtype='string', name='array_image')
}

deserialized_inputs = input_deserializer(serialized_inputs)
raw_outputs = model(deserialized_inputs)
serialized_outputs = output_serializer(raw_outputs)
updated_model = tf.keras.Model(serialized_inputs, serialized_outputs)

In [None]:
# Cloud Storage destination for the wrapped SavedModel: a folder named after
# the .h5 file (extension stripped) at the root of BUCKET.
MODEL_DIR = 'gs://{}/{}'.format(BUCKET, RGB_MODEL[:-len('.h5')])
updated_model.save(MODEL_DIR)

### Host the model

First, delete any model with the same name.  If you get an error that the model doesn't exist, you can ignore it.  If you get an error because the model exists and is deployed to an endpoint, either rename the model or undeploy the previously deployed model from the [Cloud Console](https://console.cloud.google.com/vertex-ai/online-prediction/endpoints).

In [None]:
# Delete any existing Vertex AI model with ID MODEL_NAME so the upload below
# can reuse the ID. A "not found" error here is expected on a first run.
!gcloud ai models delete {MODEL_NAME} --project={PROJECT} --region={REGION}

In [None]:
# Register the SavedModel in MODEL_DIR as a Vertex AI model with ID
# MODEL_NAME, served by the prebuilt CONTAINER_IMAGE.
!gcloud ai models upload \
  --artifact-uri={MODEL_DIR} \
  --project={PROJECT} \
  --region={REGION} \
  --container-image-uri={CONTAINER_IMAGE} \
  --description={MODEL_NAME} \
  --display-name={MODEL_NAME} \
  --model-id={MODEL_NAME}

In [None]:
# Create an endpoint to host the model.
# NOTE(review): Vertex AI display names are not unique, so each run of this
# cell creates another endpoint named ENDPOINT_NAME.
!gcloud ai endpoints create \
  --display-name={ENDPOINT_NAME} \
  --region={REGION} \
  --project={PROJECT}

In [None]:
ENDPOINT_ID = !gcloud ai endpoints list \
  --project={PROJECT} \
  --region={REGION} \
  --filter=displayName:{ENDPOINT_NAME} \
  --format="value(ENDPOINT_ID.scope())"
ENDPOINT_ID = ENDPOINT_ID[-1]

In [None]:
# Deploy the model to the endpoint on one n1-standard-8 machine with one
# NVIDIA T4 GPU. Deployed nodes are billed until the model is undeployed.
!gcloud ai endpoints deploy-model {ENDPOINT_ID} \
  --project={PROJECT} \
  --region={REGION} \
  --model={MODEL_NAME} \
  --machine-type=n1-standard-8 \
  --accelerator=type=nvidia-tesla-t4,count=1 \
  --display-name={MODEL_NAME}

## Connect to the hosted model from Earth Engine

After deployment, it may take a few minutes before the endpoint serves predictions, so prediction tiles may not appear on the map right away.

In [None]:
# Source imagery: Estonian RGB orthophotos, mosaicked into one float image
# with bands R, G, B.
estonia_images = ee.ImageCollection("Estonia/Maamet/orthos/rgb")
image = estonia_images.mosaic().float()

# Deepest map zoom allowed. It is set explicitly on the map and on every tile
# layer. Otherwise folium/Leaflet may cap the zoom (commonly at 18), which
# silently clamps zoom_start=21 and hides the overlays at deeper zooms.
MAX_ZOOM = 22

# Get a URL to serve image tiles.
mapid = image.getMapId({'bands': ['R', 'G', 'B'], 'min': 0, 'max': 128})

# Use folium to visualize the imagery. (Not named `map`, to avoid shadowing
# the builtin.)
folium_map = folium.Map(
    location=[59.246333, 25.463968], zoom_start=21, max_zoom=MAX_ZOOM)

# Inputs. The layer shows a mosaic of the orthophotos, not a median composite.
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr=ATTRIBUTION,
    overlay=True,
    name='RGB orthophoto mosaic',
    max_zoom=MAX_ZOOM,
  ).add_to(folium_map)

endpoint_path = f'projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_ID}'

# Connect to the hosted model. Tiles are requested in a 0.2 m EPSG:4326
# projection. The output band names must match the hosted model's serving
# output keys (one per model output); each per-pixel value is a 1-D array.
vertex_model = ee.Model.fromVertexAi(
  endpoint=endpoint_path,
  inputTileSize=[128, 128],
  inputOverlapSize=[64, 64],
  proj=ee.Projection('EPSG:4326').atScale(0.2),
  fixInputProj=True,
  outputBands={
    're_serialize_output': {
      'type': ee.PixelType.float(),
      'dimensions': 1
    },
    're_serialize_output_1': {
      'type': ee.PixelType.float(),
      'dimensions': 1
    }
  }
)

# Normalization: per-pixel z-score of each band, using the mean and standard
# deviation over a square neighborhood of radius 10 m.
stats = image.reduceNeighborhood(
  reducer=ee.Reducer.mean().combine(ee.Reducer.stdDev(), None, True),
  kernel=ee.Kernel.square(10, 'meters'),
  optimization='window'
)
means = stats.select(['R_mean', 'G_mean', 'B_mean'])
sds = stats.select(['R_stdDev', 'G_stdDev', 'B_stdDev'])
input_image = (image.select(['R', 'G', 'B'])
    .subtract(means).divide(sds).float().toArray().rename('array_image'))

# Predictions. The first output is the segmentation map that is displayed.
# The second output (presumably `output_dens`, the density map) is computed
# lazily and not displayed.
predictions = vertex_model.predictImage(input_image)
seg = predictions.select('re_serialize_output').arrayGet([0])
density = predictions.select('re_serialize_output_1').arrayGet([0])

seg_mapid = seg.getMapId({'min': 0, 'max': 0.5})
folium.TileLayer(
    tiles=seg_mapid['tile_fetcher'].url_format,
    attr=ATTRIBUTION,
    overlay=True,
    name='predictions',
    max_zoom=MAX_ZOOM,
  ).add_to(folium_map)

folium_map.add_child(folium.LayerControl())
display(folium_map)

#### **Warning!** This demo consumes billable resources of Google Cloud, including Earth Engine, Vertex AI and Cloud Storage.  When you're done, be sure to undeploy the model from the endpoint, which shuts down its prediction nodes, to avoid ongoing charges.