In [None]:
import io
import json
import os
from typing import List, Union

import ee
from google.auth.transport.requests import AuthorizedSession
from google.oauth2 import service_account
import numpy as np
import tensorflow as tf

# Fill these in with your own values before running the rest of the notebook.
os.environ.update(
    # Path to the service-account JSON key file.
    GA_AUTH_FILE="",
    # Service-account address (e-mail), passed to ee.ServiceAccountCredentials.
    GEE_SERVICE_ACCOUNT="",
)

def authenticate(key_file: Union[str, None] = None) -> AuthorizedSession:
    """Initialise Earth Engine and return an authorised HTTP session.

    The same service-account key file is used both for Earth Engine
    (``ee.Initialize``) and for the returned ``AuthorizedSession``.

    Args:
        key_file: Path to the service-account JSON key file. When omitted,
            the ``GA_AUTH_FILE`` environment variable is read at call time.

    Returns:
        An ``AuthorizedSession`` scoped to
        ``https://www.googleapis.com/auth/cloud-platform``.

    Raises:
        ValueError: if no key file is given and ``GA_AUTH_FILE`` is unset or
            empty.
        KeyError: if the ``GEE_SERVICE_ACCOUNT`` environment variable is unset.
    """
    # Resolve the default here rather than in the signature: a default of
    # os.environ[...] would be frozen at definition time and ignore any
    # value set afterwards.
    if key_file is None:
        key_file = os.environ.get("GA_AUTH_FILE", "")
    if not key_file:
        raise ValueError(
            "No service-account key file: pass key_file or set GA_AUTH_FILE."
        )

    gcs_credentials = service_account.Credentials.from_service_account_file(key_file)
    ee_creds = ee.ServiceAccountCredentials(os.environ["GEE_SERVICE_ACCOUNT"], key_file)
    ee.Initialize(ee_creds)
    scoped_credentials = gcs_credentials.with_scopes(
        ["https://www.googleapis.com/auth/cloud-platform"]
    )

    return AuthorizedSession(scoped_credentials)


# REST endpoint that evaluates a serialized ee.Image expression into pixels.
compute_url = (
    "https://earthengine.googleapis.com/v1beta"
    "/projects/earthengine-public/image:computePixels"
)


def get_asset_url(asset_id):
    """Return the Earth Engine REST (v1beta) URL of a public asset.

    ``asset_id`` is placed under ``projects/earthengine-public/assets/``.
    """
    api_root = "https://earthengine.googleapis.com/v1beta"
    asset_path = "projects/earthengine-public/assets/{}".format(asset_id)
    return "/".join((api_root, asset_path))


def get_asset_info(asset_id, session):
    """Fetch a public Earth Engine asset's metadata as a dict.

    Args:
        asset_id: Asset ID under ``projects/earthengine-public/assets/``.
        session: Authorised HTTP session, e.g. the one from ``authenticate()``.

    Returns:
        The decoded JSON response body.

    Raises:
        requests.HTTPError: (for a requests-based session such as
            ``AuthorizedSession``) if the request returns an error status.
            Without this check, the error payload would be returned as if it
            were metadata and fail later with an unrelated KeyError.
    """
    response = session.get(get_asset_url(asset_id))
    response.raise_for_status()
    return response.json()


def get_chip(
    coords: List,
    image: Union[str, ee.Image],
    scale: float,
    session: AuthorizedSession,
    size: int = 512,
):
    """Fetch a square ``size`` x ``size`` chip of pixels as a float32 array.

    Args:
        coords: ``[x, y]`` used as the grid translation (the chip's origin),
            in the coordinates of the request grid.
        image: Either an ``ee.Image``, computed server-side through the
            ``image:computePixels`` endpoint, or an asset ID string, read
            through that asset's ``:getPixels`` endpoint.
        scale: Pixel size, used for both ``scaleX`` and ``scaleY``.
        session: Authorised HTTP session, e.g. the one from ``authenticate()``.
        size: Width and height of the chip in pixels (default 512).

    Returns:
        The chip as a float32 NumPy array, with every negative value
        replaced by 0.0.

    Raises:
        requests.HTTPError: if the Earth Engine request returns an error status.
    """
    query = {
        "fileFormat": "NPY",
        "grid": {
            "affineTransform": {
                "scaleX": scale,
                # NOTE(review): a positive scaleY makes rows advance in +y from
                # translateY. North-up rasters usually use a negative scaleY;
                # confirm the intended orientation.
                "scaleY": scale,
                "translateX": coords[0],
                "translateY": coords[1],
            },
            "dimensions": {"width": size, "height": size},
        },
    }

    if isinstance(image, ee.Image):
        url = compute_url
        query["expression"] = ee.serializer.encode(image)
    else:
        url = get_asset_url(image) + ":getPixels"

    # json= serialises the body and sets Content-Type: application/json.
    chip_response = session.post(url, json=query)
    # Fail here with the HTTP error. Otherwise the error body would reach
    # np.load and surface as a confusing NPY parsing error.
    chip_response.raise_for_status()

    chip = np.load(io.BytesIO(chip_response.content)).astype("float32")
    # Negative values are treated as nodata and replaced with 0.
    # NOTE(review): this also zeroes genuine negative values (e.g. elevations
    # below sea level); confirm that is acceptable.
    return np.where(chip < 0.0, 0.0, chip)


def get_chips(pt_tf, feature_image, label_image, scale, session):
    """Fetch the feature and label chips anchored at a single point.

    Args:
        pt_tf: Eager tensor holding the point's coordinates. It is converted
            with ``.numpy().tolist()`` and passed to ``get_chip`` as ``coords``.
        feature_image: Image (``ee.Image`` or asset ID) sampled for the model input.
        label_image: Image (``ee.Image`` or asset ID) sampled for the target.
        scale: Pixel size forwarded to ``get_chip``.
        session: Authorised HTTP session forwarded to ``get_chip``.

    Returns:
        A ``(feature, label)`` tuple. Each chip gains a new leading axis and
        a new trailing axis, both of length 1.
    """
    coords = pt_tf.numpy().tolist()
    chips = [
        get_chip(coords, img, scale, session)
        for img in (feature_image, label_image)
    ]
    # chip[np.newaxis, ..., np.newaxis] is expand_dims on axis 0, then on axis -1.
    return tuple(chip[np.newaxis, ..., np.newaxis] for chip in chips)


def get_points(n=100, country: str = "Germany", seed: Union[int, None] = None):
    """Sample ``n`` random points inside a country's boundary.

    The boundary comes from the FAO GAUL simplified (500 m) level-0
    collection, filtered on ``ADM0_NAME``.

    Args:
        n: Number of points to sample.
        country: ``ADM0_NAME`` value to filter on (default "Germany").
        seed: Optional seed forwarded to ``ee.FeatureCollection.randomPoints``.
            When None, no seed is passed and Earth Engine's default is used.
            Pass distinct seeds when separate draws must differ.

    Returns:
        A tensor built from the coordinate list of the sampled collection's
        geometry (presumably shape (n, 2) of [x, y]; verify for your data).
    """
    countries = ee.FeatureCollection("FAO/GAUL_SIMPLIFIED_500m/2015/level0")
    region = countries.filter(ee.Filter.eq("ADM0_NAME", country))
    sample_args = {"region": region, "points": n}
    # Only pass seed when given, so the default call is unchanged.
    if seed is not None:
        sample_args["seed"] = seed
    pts = ee.FeatureCollection.randomPoints(**sample_args)
    return tf.convert_to_tensor(pts.geometry().coordinates().getInfo())


In [8]:
from functools import partial

import ee
import tensorflow as tf
from segmentation_models import Unet

def get_loaded_chips(pt_tf):
    """Return the (feature, label) chip pair for one point, for use in tf.data.

    ``Dataset.map`` passes a single argument, so the remaining ``get_chips``
    arguments are bound with ``functools.partial``. They come from the
    notebook globals ``dem_id``, ``slope``, ``scale`` and ``session``, which
    must be defined before the dataset is iterated.
    """
    fetch = partial(
        get_chips,
        feature_image=dem_id,
        label_image=slope,
        scale=scale,
        session=session,
    )
    # get_chips needs eager tensors (it calls .numpy()) and returns NumPy
    # arrays, so it runs through tf.py_function instead of being traced.
    return tf.py_function(func=fetch, inp=[pt_tf], Tout=[tf.float32, tf.float32])


def get_dataset(points):
    """Build a tf.data pipeline of (feature, label) chips, one per point.

    Points are mapped through ``get_loaded_chips`` in parallel. Results are
    cached in memory after the first pass, so later epochs do not
    re-download chips. Prefetch is applied last so that producing the next
    element overlaps with training on the current one.
    """
    return (
        tf.data.Dataset.from_tensor_slices(points)
        .map(get_loaded_chips, num_parallel_calls=tf.data.AUTOTUNE)
        .cache()
        .prefetch(tf.data.AUTOTUNE)
    )

session = authenticate()
# Shuttle Radar Topography Mission Digital Elevation Model (model input).
dem_id = "CGIAR/SRTM90_V4"
# Label image: terrain slope computed from the same DEM.
slope = ee.Terrain.slope(ee.Image(dem_id))

# Request chips at the DEM's native scale. The DEM is in EPSG:4326, so its
# pixel size is in the same units as the sampled point coordinates.
dem_info = get_asset_info(dem_id, session)
scale = dem_info["bands"][0]["grid"]["affineTransform"]["scaleX"]

# Draw training and validation points in one call, then split them.
# Two separate get_points() calls would both use randomPoints' default
# seed and could return overlapping points, leaking training locations
# into validation.
N_TRAIN, N_VAL = 1200, 400
all_points = get_points(N_TRAIN + N_VAL)
dataset = get_dataset(all_points[:N_TRAIN])
v_dataset = get_dataset(all_points[N_TRAIN:])

# U-Net with a ResNet-34 backbone. The feature data has only one band, so the
# input shape is (None, None, 1) and `encoder_weights` is set to None (no
# pretrained weights). The activation is linear with a single output class
# because the target (slope) is a continuous quantity, not a set of classes.
model = Unet(
    "resnet34",
    input_shape=(None, None, 1),
    activation="linear",
    classes=1,
    encoder_weights=None,
)

model.compile(
    optimizer="SGD",
    loss="MeanSquaredError",
    metrics=["RootMeanSquaredError"],
)
# No `batch_size` argument: Keras does not use it for tf.data inputs, which
# supply their own batches. Each dataset element already carries a leading
# axis of size 1 (added in get_chips), so each step trains on a single chip.
# NOTE(review): if larger batches were intended, batch the dataset instead.
history = model.fit(dataset, epochs=5, validation_data=v_dataset)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<keras.callbacks.History at 0x7f5a7b727990>