In [None]:
#@title Copyright 2023 Google LLC. { display-mode: "form" }
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

<table class="ee-notebook-buttons" align="left"><td>
<a target="_blank"  href="http://colab.research.google.com/github/google/earthengine-community/blob/master/guides/linked/Earth_Engine_Vertex_AI_training_demo.ipynb">
    <img src="https://www.tensorflow.org/images/colab_logo_32px.png" /> Run in Google Colab</a>
</td><td>
<a target="_blank"  href="https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_Vertex_AI_training_demo.ipynb"><img width=32px src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" /> View source on GitHub</a></td></table>

# Training a model in Vertex AI

This notebook demonstrates training a convolutional neural network on Vertex AI.  The trained model is suitable for use in Earth Engine with `ee.Model.fromVertexAi`.  The model is a simple convolutional model of land
cover.  The training data generation and model setup are described in detail in the [2022 Geo for Good Deep Learning session](https://earthoutreachonair.withgoogle.com/events/geoforgood22?talk=day1-trackthree-talk2).

**Running this demo may incur charges to your Google Cloud Account!**

## Installs and imports

In [None]:
!pip install -U google-cloud-aiplatform

In [None]:
from google.cloud import aiplatform
from google.colab import auth

import ee
import folium
import google
import google.auth
import keras
import tensorflow as tf

In [None]:
# Authenticate this Colab session; the resulting credentials are retrieved
# below with google.auth.default().
auth.authenticate_user()

In [None]:
# REPLACE with your Google Cloud project ID.  It is used for Earth Engine,
# the Vertex AI SDK and every gcloud command below.
PROJECT = 'your-project'

In [None]:
# Reuse the credentials from the Colab authentication above for Earth Engine,
# talking to the Earth Engine high-volume endpoint.  The project ID returned by
# google.auth.default() is unused; it is bound to `_project_id` rather than `_`
# so IPython's `_` (last cell output) is not shadowed.
credentials, _project_id = google.auth.default()
ee.Initialize(
    credentials,
    project=PROJECT,
    opt_url='https://earthengine-highvolume.googleapis.com',
)

## Variable declarations

In [None]:
REGION = 'us-central1'
# Trained model output locations.  REPLACE WITH YOUR BUCKET!
# OUTPUT_DIR is where the training job saves the trained model (it must match
# OUTPUT_DIR in task.py) and is also used as the Vertex AI staging bucket.
# EEIFIED_DIR receives the wrapped, Earth Engine-ready model that is uploaded
# to Vertex AI; keep it separate from OUTPUT_DIR so saving the wrapper does not
# overwrite the trained model.
OUTPUT_DIR = 'your-bucket'
EEIFIED_DIR = 'your-bucket/eeified'

# Name of package containing training code.
PACKAGE_PATH = 'demo_model'
# Name of the hosted model and endpoint.
MODEL_NAME = 'demo_lc_model'
ENDPOINT_NAME = 'demo_lc_endpoint'
# A container image that can run the hosted model.
CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-cpu.2-11:latest'

# The time range over which to compute composites.
START = '2020-1-1'
END = '2021-1-1'

# A random spot near Ho Chi Minh City, Vietnam.
COORDS = [105.695, 9.883]

# Sentinel-2 bands to be used in prediction.
BANDS = ['B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7',
          'B8', 'B8A', 'B9', 'B11', 'B12']
# Class name -> display color.  Entry i is the color used for label value i
# when the predictions are visualized.
CLASSIFICATIONS = {
    "Water":              "419BDF",
    "Trees":              "397D49",
    "Grass":              "88B053",
    "Flooded vegetation": "7A87C6",
    "Crops":              "E49635",
    "Shrub and scrub":    "DFC35A",
    "Built-up areas":     "C4281B",
    "Bare ground":        "A59B8F",
    "Snow and ice":       "B39FE1",
}

ATTRIBUTION = 'Map Data &copy; <a href="https://earthengine.google.com/">Google Earth Engine</a>'

## Make a training package

Make a directory to hold the training code.  This script will be uploaded to Vertex AI in order to start the training job.

In [None]:
# Start from an empty package directory (removing any copy from a previous run),
# then list the working directory to confirm it was created.
!rm -rf {PACKAGE_PATH}
!mkdir {PACKAGE_PATH}
!ls -l

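Make a self-contained file that will load the datasets, train the model and save it.  The file has its own copies of the constants it needs, so replace its `your-bucket` placeholders as well; its `OUTPUT_DIR` must match the `OUTPUT_DIR` set above.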

In [None]:
%%writefile {PACKAGE_PATH}/task.py

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


# Ask TensorFlow to grow GPU memory on demand.  Memory growth has to be set the
# same way on every GPU and before the GPUs are initialized; if that is too
# late, the RuntimeError is printed and the script carries on.
gpus = tf.config.list_physical_devices('GPU')
try:
  for device in gpus:
    tf.config.experimental.set_memory_growth(device, True)
except RuntimeError as err:
  print(err)


# Where the trained model is saved; the notebook loads it from its own
# OUTPUT_DIR, so the two must match.
OUTPUT_DIR = 'your-bucket'

# Put gs://ee-docs-demos/g4g-tf-demos/tiles_2022/*.gz into a us-central1 bucket.
TRAINING_PATTERN = 'your-bucket/training*.tfrecord.gz'
VALIDATION_PATTERN = 'your-bucket/validation*.tfrecord.gz'

PATCH_SIZE = 128  # Pixels per patch side.
SCALE = 10        # Meters per pixel.

# Predictors: one model input per Sentinel-2 band.
BANDS = [
    'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'B7',
    'B8', 'B8A', 'B9', 'B11', 'B12',
]

# Target variable: a class index stored as a float patch (cast and one-hot
# encoded when parsed).
LABEL = 'landcover'

# Class name -> display color.  Only the number of entries is used here
# (as the number of classes).
CLASSIFICATIONS = {
    "Water":              "419BDF",
    "Trees":              "397D49",
    "Grass":              "88B053",
    "Flooded vegetation": "7A87C6",
    "Crops":              "E49635",
    "Shrub and scrub":    "DFC35A",
    "Built-up areas":     "C4281B",
    "Bare ground":        "A59B8F",
    "Snow and ice":       "B39FE1",
}

# Input stack: every band plus the label, each a PATCH_SIZE x PATCH_SIZE
# float32 patch in the serialized examples.
FEATURE_NAMES = BANDS + [LABEL]

COLUMNS = [
    tf.io.FixedLenFeature(shape=[PATCH_SIZE, PATCH_SIZE], dtype=tf.float32)
    for _ in FEATURE_NAMES
]
FEATURES_DICT = {name: column for name, column in zip(FEATURE_NAMES, COLUMNS)}


class Augment(tf.keras.layers.Layer):
  """Random horizontal/vertical flip of every input band and of the labels.

  Used as a `tf.data` map function over `(inputs, labels)` pairs, where
  `inputs` is a dict keyed by band name.  Each band and the labels get their
  own RandomFlip layer, all built with the same seed, with the intent that
  every band and the labels receive the same flip for a given example.
  NOTE(review): this relies on identically seeded RandomFlip layers drawing
  identical random sequences when each is called once per example; confirm
  for the TF/Keras version in the training container.
  """
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    # One flip layer per band (keyed like the inputs dict) ...
    self.augment_inputs = {
      b: tf.keras.layers.RandomFlip(mode="horizontal_and_vertical", seed=seed)
      for b in BANDS}
    # ... and one for the labels, seeded identically.
    self.augment_labels = tf.keras.layers.RandomFlip(
        mode="horizontal_and_vertical", seed=seed)

  def call(self, inputs, labels):
    """Return `(flipped_inputs_dict, flipped_labels)`; only BANDS keys are kept."""
    inputs = {b: self.augment_inputs[b](inputs[b]) for b in BANDS}
    labels = self.augment_labels(labels)
    return inputs, labels


def parse_tfrecord(example_proto):
  """Turn one serialized Example into an `(inputs, labels)` training pair.

  Args:
    example_proto: a serialized `tf.train.Example` string, as read from the
      TFRecord files.

  Returns:
    `inputs`: dict mapping each band name to its patch with a trailing
      channel axis added (axis 2).
    `labels`: the LABEL patch cast to uint8 and one-hot encoded over
      `len(CLASSIFICATIONS)` classes.
  """
  features = tf.io.parse_example(example_proto, FEATURES_DICT)
  label_patch = features.pop(LABEL)

  inputs = {}
  for name, patch in features.items():
    inputs[name] = tf.expand_dims(patch, axis=2)

  class_ids = tf.cast(label_patch, tf.uint8)
  return inputs, tf.one_hot(class_ids, len(CLASSIFICATIONS))


def get_dataset(pattern):
  """Build an unbatched dataset of parsed `(inputs, labels)` pairs.

  Args:
    pattern: glob pattern matching GZIP-compressed TFRecord files.

  Returns:
    A `tf.data.Dataset` that reads the matching files, parses each record with
    `parse_tfrecord`, caches the parsed examples, shuffles them with a buffer
    of 1024 and prefetches.
  """
  def read_gzip_records(filename):
    return tf.data.TFRecordDataset(filename, compression_type='GZIP')

  files = tf.data.Dataset.list_files(pattern)
  return (
      files.interleave(read_gzip_records)
      .map(parse_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
      .cache()
      .shuffle(1024)
      .prefetch(buffer_size=tf.data.AUTOTUNE)
  )


def get_model(input_depth, num_classes):
    """Build a fully convolutional encoder/decoder with residual projections.

    Args:
      input_depth: number of channels in the input tensor (one per band).
      num_classes: number of output classes; the last layer applies a softmax
        over this many channels at every pixel.

    Returns:
      A functional `keras.Model` taking a (batch, H, W, input_depth) tensor
      named 'array' and producing per-pixel class probabilities with
      num_classes channels.  Spatial dimensions are `None`, so any tile size
      is accepted.  The encoder downsamples by 8 (one stride-2 conv, then two
      stride-2 max-pools) and the decoder upsamples by 8 (three 2x
      UpSampling2D), so the output size equals the input size when H and W
      are multiples of 8.
    """
    inputs = keras.Input(shape=[None, None, input_depth], name='array')

    ### [First half of the network: downsampling inputs] ###

    # Entry block
    x = layers.Conv2D(32, 3, strides=2, padding="same")(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    # Note: the first iteration re-applies ReLU to an already-ReLU'd tensor;
    # ReLU is idempotent, so that extra activation has no effect.
    for filters in [64, 128]:
        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.SeparableConv2D(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual (1x1 conv, stride 2 to match the pooled branch)
        residual = layers.Conv2D(filters, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    ### [Second half of the network: upsampling inputs] ###

    # Conv2DTranspose is used with its default stride here, so resolution is
    # only increased by the UpSampling2D layers.
    for filters in [128, 64, 32]:
        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.Activation("relu")(x)
        x = layers.Conv2DTranspose(filters, 3, padding="same")(x)
        x = layers.BatchNormalization()(x)

        x = layers.UpSampling2D(2)(x)

        # Project residual (upsample, then 1x1 conv to match channel count)
        residual = layers.UpSampling2D(2)(previous_block_activation)
        residual = layers.Conv2D(filters, 1, padding="same")(residual)
        x = layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    # Add a per-pixel classification layer
    outputs = layers.Conv2D(num_classes, 3, activation="softmax", padding="same")(x)

    return keras.Model(inputs, outputs)


# A Layer to stack the per-band input tensors into one multi-band tensor.
class MyPreprocessing(keras.layers.Layer):
  """Concatenates the BANDS entries of a feature dict along the channel axis.

  Each dict entry is expected to be (batch, H, W, 1); the result is
  (batch, H, W, len(BANDS)) with bands in BANDS order.  Keys not in BANDS
  are ignored.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, features_dict):
    band_tensors = [features_dict[band] for band in BANDS]
    return tf.concat(band_tensors, axis=3)

  def get_config(self):
    return super().get_config()


# A Model that chains the preprocessing layer and the base (backbone) model.
class MyModel(keras.Model):
  """Applies `preprocessing` to the input feature dict, then `backbone`.

  Args:
    preprocessing: layer mapping the per-band dict to a single tensor.
    backbone: model applied to the preprocessed tensor.
    **kwargs: forwarded to `keras.Model`.
  """

  def __init__(self, preprocessing, backbone, **kwargs):
    super().__init__(**kwargs)
    self.preprocessing = preprocessing
    self.backbone = backbone

  def call(self, features_dict):
    return self.backbone(self.preprocessing(features_dict))

  def get_config(self):
    return super().get_config()


if __name__ == '__main__':

  # Training examples are randomly flipped (bands and labels together) and
  # batched by 32; validation examples are not augmented and are batched by 1.
  training_dataset = get_dataset(TRAINING_PATTERN).map(
    Augment(), num_parallel_calls=tf.data.AUTOTUNE).batch(32)
  validation_dataset = get_dataset(VALIDATION_PATTERN).batch(1)

  backbone = get_model(len(BANDS), len(CLASSIFICATIONS))
  # Wrap the backbone so the model takes the per-band dict produced by
  # parse_tfrecord directly.
  wrapped_model = MyModel(MyPreprocessing(), backbone)

  wrapped_model.compile(
      optimizer='adam',
      loss='categorical_crossentropy',
      metrics=[
          tf.keras.metrics.OneHotIoU(
              num_classes=len(CLASSIFICATIONS),
              target_class_ids=list(range(len(CLASSIFICATIONS))),
          ),
          # Labels are one-hot and predictions are softmax probabilities, so
          # compare their argmaxes.  Plain `Accuracy` tests exact equality of
          # the values and would stay near zero.
          tf.keras.metrics.CategoricalAccuracy(),
      ]
  )

  wrapped_model.fit(
      training_dataset,
      validation_data=validation_dataset,
      epochs=25,
      callbacks=[tf.keras.callbacks.TensorBoard(
          f'{OUTPUT_DIR}/logs', histogram_freq=1)],
  )

  wrapped_model.save(OUTPUT_DIR)

In [None]:
# Configure the Vertex AI SDK; OUTPUT_DIR doubles as the staging bucket.
aiplatform.init(project=PROJECT, location=REGION, staging_bucket=OUTPUT_DIR)

# Run task.py in the prebuilt TensorFlow 2.11 GPU training container.
job = aiplatform.CustomTrainingJob(
    display_name='demo-fcnn-training',
    script_path=f'{PACKAGE_PATH}/task.py',
    container_uri='us-docker.pkg.dev/vertex-ai/training/tf-gpu.2-11:latest',
)

# One n1-standard-4 machine with a single T4 GPU; task.py takes no arguments.
training_resources = dict(
    machine_type='n1-standard-4',
    accelerator_type='NVIDIA_TESLA_T4',
    accelerator_count=1,
)
job.run(args=[], **training_resources)

## Host the model

Reload the model saved by the training job and inspect the inputs it expects.  Note that the spatial dimensions are `None` to allow flexible input tile sizing.  Wrap the model in de/serialization layers to do the `base64` conversion.  Then re-save the wrapped model, which is ready to be hosted on Vertex AI.

In [None]:
# Load the model the training job saved (task.py calls save() on its OUTPUT_DIR).
trained_model = keras.models.load_model(OUTPUT_DIR)

In [None]:
# Expected: one input per band, with `None` spatial dimensions.
print(trained_model.inputs)

In [None]:
class DeSerializeInput(tf.keras.layers.Layer):
  """Decodes a dict of base64 strings into float32 tensors.

  For every entry, each element along the first axis is base64-decoded and
  then parsed with `tf.io.parse_tensor` as float32.  The output dict has the
  same keys as the input dict.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, inputs_dict):
    def parse_float32(serialized):
      return tf.io.parse_tensor(serialized, tf.float32)

    decoded = {}
    for name, encoded in inputs_dict.items():
      decoded[name] = tf.map_fn(
          parse_float32,
          tf.io.decode_base64(encoded),
          fn_output_signature=tf.float32)
    return decoded

  def get_config(self):
    return super().get_config()


class ReSerializeOutput(tf.keras.layers.Layer):
  """Serializes each item of a tensor to a base64 string.

  Every element along the first axis is serialized with
  `tf.io.serialize_tensor` and then base64-encoded, giving a string tensor
  with one entry per item.
  """

  def __init__(self, **kwargs):
    super().__init__(**kwargs)

  def call(self, output_tensor):
    def encode_one(item):
      serialized = tf.io.serialize_tensor(item)
      return tf.io.encode_base64(serialized)

    return tf.map_fn(encode_one, output_tensor, fn_output_signature=tf.string)

  def get_config(self):
    return super().get_config()

# Wrap the trained model so it accepts and returns base64 strings:
# base64 inputs -> float tensors -> trained model -> base64 outputs.
input_deserializer = DeSerializeInput()
output_serializer = ReSerializeOutput()

# One scalar string input per band, named after the band.
serialized_inputs = {
    b: tf.keras.Input(shape=[], dtype='string', name=b) for b in BANDS
}

deserialized_inputs = input_deserializer(serialized_inputs)
class_probabilities = trained_model(deserialized_inputs)
serialized_outputs = output_serializer(class_probabilities)
updated_model = tf.keras.Model(serialized_inputs, serialized_outputs)

In [None]:
# Save the base64-in/base64-out wrapper; EEIFIED_DIR is the artifact location
# uploaded to Vertex AI below.
updated_model.save(EEIFIED_DIR)

# Deploy the model on Vertex AI

Upload the model artifacts to Vertex AI to create a model.  Create an endpoint and deploy the model to the endpoint.

### [Upload the model](https://cloud.google.com/sdk/gcloud/reference/ai/models/upload)
Add an entry to the model registry that points to the location of the saved model and a container image needed to run the model.

In [None]:
!gcloud ai models delete {MODEL_NAME} --project={PROJECT} --region={REGION}

In [None]:
# Register the wrapped model saved in EEIFIED_DIR, to be served with
# CONTAINER_IMAGE, under the fixed model ID MODEL_NAME.
!gcloud ai models upload \
  --artifact-uri={EEIFIED_DIR} \
  --project={PROJECT} \
  --region={REGION} \
  --container-image-uri={CONTAINER_IMAGE} \
  --description={MODEL_NAME} \
  --display-name={MODEL_NAME} \
  --model-id={MODEL_NAME}

### [Create an endpoint](https://cloud.google.com/sdk/gcloud/reference/ai/endpoints/create)

Create an endpoint from which to serve the model.

In [None]:
# Create an endpoint identified only by its display name ENDPOINT_NAME.
# NOTE(review): display names are presumably not unique, so re-running this
# cell would add another endpoint with the same name.
!gcloud ai endpoints create \
  --display-name={ENDPOINT_NAME} \
  --region={REGION} \
  --project={PROJECT}

### [Deploy the model to the endpoint](https://cloud.google.com/sdk/gcloud/reference/ai/endpoints/deploy-model)

First, look up the endpoint ID, then deploy the model.

In [None]:
# Look up the endpoint ID by display name.  The `!` capture returns the
# command's output as a list of lines; the last line is taken as the ID
# (assumes the ID is printed after any gcloud status messages; TODO confirm).
# NOTE(review): if several endpoints share ENDPOINT_NAME, this picks whichever
# one is listed last.
ENDPOINT_ID = !gcloud ai endpoints list \
  --project={PROJECT} \
  --region={REGION} \
  --filter=displayName:{ENDPOINT_NAME} \
  --format="value(ENDPOINT_ID.scope())"
ENDPOINT_ID = ENDPOINT_ID[-1]

In [None]:
# Deploy the uploaded model (model ID MODEL_NAME) to the endpoint found above.
!gcloud ai endpoints deploy-model {ENDPOINT_ID} \
  --project={PROJECT} \
  --region={REGION} \
  --model={MODEL_NAME} \
  --display-name={MODEL_NAME}

# Connect to the hosted model from Earth Engine

1. Generate the input imagery.  This should be done in exactly the same way as the training data were generated.  See [this example notebook](http://colab.research.google.com/github/google/earthengine-api/blob/master/python/examples/ipynb/TF_demo1_keras.ipynb) for details.
2. Connect to the hosted model.
3. Use the model to make predictions.
4. Display the results.

Note that it may take the model a couple of minutes to spin up before it is ready to make predictions.

In [None]:
def get_s2_composite(roi, start, end, max_cloud_probability=50):
  """Get a cloud-masked median Sentinel-2 composite for a region and date range.

  Sentinel-2 scenes from COPERNICUS/S2_SR_HARMONIZED that intersect a 100 km
  buffer around `roi` are joined with their COPERNICUS/S2_CLOUD_PROBABILITY
  images by `system:index`.  Cloudy pixels are masked, and the per-pixel median
  is taken.

  Args:
    roi: ee.Geometry to composite around.
    start: start of the date range, as accepted by `filterDate`.
    end: end of the date range, as accepted by `filterDate`.
    max_cloud_probability: pixels whose cloud probability is greater than or
      equal to this value are masked out.  Defaults to 50.

  Returns:
    A float ee.Image with the BANDS bands divided by 10000.  Pixels with no
    unmasked observation are filled with 0.
  """
  # Gather scenes from a 100 km buffer (1 km max error) around the ROI.
  search_area = roi.buffer(100 * 1000, 1000)
  s2c = (ee.ImageCollection('COPERNICUS/S2_CLOUD_PROBABILITY')
         .filterBounds(search_area)
         .filterDate(start, end))
  s2sr = (ee.ImageCollection('COPERNICUS/S2_SR_HARMONIZED')
          .filterBounds(search_area)
          .filterDate(start, end))

  def index_join(collection_a, collection_b, property_name):
    """Add to each image in a the bands of its same-`system:index` match in b."""
    joined = ee.ImageCollection(
        ee.Join.saveFirst(property_name).apply(
            primary=collection_a,
            secondary=collection_b,
            condition=ee.Filter.equals(
                leftField='system:index',
                rightField='system:index')))
    return joined.map(
        lambda image: image.addBands(ee.Image(image.get(property_name))))

  def mask_image(image):
    """Scale the B* bands by 1/10000 and mask pixels that are too cloudy."""
    prob = image.select('probability')
    return (image.select('B.*').divide(10000)
            .updateMask(prob.lt(max_cloud_probability)))

  with_cloud_probability = index_join(s2sr, s2c, 'cloud_probability')
  masked = ee.ImageCollection(with_cloud_probability.map(mask_image))
  return masked.select(BANDS).median().float().unmask(0)

image = get_s2_composite(ee.Geometry.Point(COORDS), START, END)

# Get a URL to serve image tiles.
mapid = image.getMapId({'bands': ['B4', 'B3', 'B2'], 'min': 0, 'max': 0.3})
# Use folium to visualize the imagery.  The map is named `folium_map` rather
# than `map` so the Python builtin `map` is not shadowed in the notebook.
folium_map = folium.Map(location=[COORDS[1], COORDS[0]], zoom_start=14)

# Inputs.
folium.TileLayer(
    tiles=mapid['tile_fetcher'].url_format,
    attr=ATTRIBUTION,
    overlay=True,
    name='median composite',
).add_to(folium_map)

endpoint_path = (
    f'projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_ID}')

# Connect to the hosted model.  Input tiles are 64x64 pixels with a 32-pixel
# overlap, requested in EPSG:4326 at 10 m scale (fixInputProj=True).
vertex_model = ee.Model.fromVertexAi(
    endpoint=endpoint_path,
    inputTileSize=[64, 64],
    inputOverlapSize=[32, 32],
    proj=ee.Projection('EPSG:4326').atScale(10),
    fixInputProj=True,
    outputBands={'output': {
        'type': ee.PixelType.float(),
        'dimensions': 1,
    }},
)

# Each output pixel is a 1-D array of class probabilities; its argmax is the
# predicted class index.
predictions = vertex_model.predictImage(image.select(BANDS).float())
labels = predictions.arrayArgmax().arrayGet(0).rename('label')

vis_params = {
    'min': 0,
    'max': len(CLASSIFICATIONS) - 1,
    'palette': list(CLASSIFICATIONS.values()),
    'bands': ['label'],
}

tile_url = labels.getMapId(vis_params)['tile_fetcher'].url_format
folium.TileLayer(
    tiles=tile_url,
    attr=ATTRIBUTION,
    overlay=True,
    name='labels',
).add_to(folium_map)

folium_map.add_child(folium.LayerControl())
display(folium_map)