<a href="https://colab.research.google.com/github/boothmanrylan/canadaMSSForestDisturbances/blob/main/SpatioTemporalUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from google.colab import auth
auth.authenticate_user()

PROJECT_ID = "api-project-269347469410"
!gcloud config set project {PROJECT_ID}

In [None]:
import os
import math
import json

import tensorflow as tf
from tensorflow.python.tools import saved_model_utils
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Config

In [None]:
# Number of label classes; used as the one-hot depth of the labels and as the
# number of channels produced by the model's final layer.
NUM_OUTPUTS = 10
# Spatial size of the patches stored in the TFRecords.
EXPORT_HEIGHT = 512
EXPORT_WIDTH = 512
# Spatial size of the crops that are fed to the model.
HEIGHT = 128
WIDTH = 128

# TODO: DEM will make an uneven number of bands that doesn't split into current/past as the DEM won't have changed
BANDS = [
    'nir', 'red_edge', 'red', 'green',
    'tca', 'ndvi',
    'historical_nir', 'historical_red_edge', 'historical_red', 'historical_green',
    'historical_tca', 'historical_ndvi'
]
METADATA = ['doy', 'ecozone'] # , 'lat', 'lon']
# Number of non-historical bands (computed before BANDS is filtered below).
NUM_INPUTS = len([b for b in BANDS if "historical" not in b])
LABEL_BAND = 'label'

# Names given to the Keras Input layers of the spatial model.
IMAGE_INPUT_LAYER_NAME = 'image'
ECOZONE_INPUT_LAYER_NAME = 'ecozone'
DOY_INPUT_LAYER_NAME = 'doy'

# Vocabulary sizes of the metadata embeddings.
MAX_DOY = 110
NUM_ECOZONES = 10  # there are only seven represented in the sanity test dataset

# Data Config
BUCKET = 'rylan-mssforestdisturbances'
BASE_PATH = f'gs://{BUCKET}/scratch/test_export/ecozone*/'
# Shard 00000 of every ecozone is held out as the test set.
TEST_PATTERN = os.path.join(BASE_PATH, '*-00000-of-*.tfrecord.gz')
# NOTE(review): '[0-9][1-9]' also skips every shard whose index ends in 0
# (00010, 00020, ...), so those shards are in neither split -- confirm intended.
TRAIN_PATTERN = os.path.join(BASE_PATH, '*-000[0-9][1-9]-of*.tfrecord.gz')

BATCH_SIZE = 32
SHUFFLE_BUFFER = 100

# Used when collecting the samples the Normalization layer is adapted on.
SUBSET_SIZE = 100

# Model Config
FILTERS = [32, 64, 128, 256]
KERNELS = [7, 5, 3, 3]
DILATION_RATES = [1, 1, 2, 4]
# Passed as the kernel size of the UpSample layers.
UPSAMPLE_FILTERS = 3
METADATA_FILTERS = 32
OUTPUT_KERNEL = 3
# One (filters, kernel_size, dilation_rate) tuple per DownSample layer.
MODEL_CONFIG = list(zip(FILTERS, KERNELS, DILATION_RATES))

TWO_DOWNSTACKS = True
INCLUDE_HISTORICAL = True
INCLUDE_METADATA = True

DATA_AUGMENTATION = False

# Drop the historical bands unless both flags are set.
# NOTE(review): with TWO_DOWNSTACKS = False and INCLUDE_HISTORICAL = True the
# historical bands are removed here, yet build_single_downstack_model sizes its
# input as 2 * NUM_INPUTS when INCLUDE_HISTORICAL is set -- confirm the
# intended condition.
if not (TWO_DOWNSTACKS and INCLUDE_HISTORICAL):
    BANDS = [b for b in BANDS if "historical" not in b]

# Shared generator that supplies the seeds for the random crops.
RNG = tf.random.Generator.from_seed(42, alg="philox")

# AI Platform Hosting Config
REGION = "us-central1"
MODEL_DIR = f"gs://{BUCKET}/scratch/models/"
EEIFIED_DIR = f"gs://{BUCKET}/scratch/eeified_models/test_model_hosting/"
MODEL_NAME = "test_model"
ENDPOINT_NAME = "test_endpoint"

# Load Data

In [None]:
# Parsing specs for the serialized tf.train.Example records.
# Every image band is a float32 patch of EXPORT_HEIGHT x EXPORT_WIDTH pixels.
IMAGE_FEATURES = dict(
    (
        band,
        tf.io.FixedLenFeature(
            shape=(EXPORT_HEIGHT, EXPORT_WIDTH),
            dtype=tf.float32,
        ),
    )
    for band in BANDS
)

# The label is an int64 patch with the same spatial size as the image bands.
LABEL_FEATURES = dict([
    (
        LABEL_BAND,
        tf.io.FixedLenFeature(
            shape=(EXPORT_HEIGHT, EXPORT_WIDTH),
            dtype=tf.int64,
        ),
    ),
])

# Each metadata entry is a single int64 value per record.
METADATA_FEATURES = dict(
    (name, tf.io.FixedLenFeature(shape=1, dtype=tf.int64))
    for name in METADATA
)


def parse(example):
    """Parse one serialized tf.train.Example into an ``(x, y)`` pair.

    The record is decoded once with the combined feature spec instead of
    once per feature group. The metadata features are only requested when
    INCLUDE_METADATA is set, so records without them can still be parsed
    when metadata is disabled.

    Args:
        example: scalar string tensor holding a serialized tf.train.Example.

    Returns:
        A tuple ``(x, y)``. ``y`` is the label one-hot encoded with depth
        NUM_OUTPUTS. ``x`` is the image with the bands of BANDS stacked on
        the last axis; when INCLUDE_METADATA is set ``x`` is instead the
        tuple ``(image, *metadata)`` with one entry per name in METADATA.
    """
    features = {**IMAGE_FEATURES, **LABEL_FEATURES}
    if INCLUDE_METADATA:
        features.update(METADATA_FEATURES)

    parsed = tf.io.parse_single_example(example, features)

    image = tf.stack([parsed[b] for b in BANDS], axis=-1)
    label = tf.one_hot(parsed[LABEL_BAND], NUM_OUTPUTS)

    if INCLUDE_METADATA:
        return (image, *(parsed[m] for m in METADATA)), label

    return image, label


def non_overlapping_crop(x, y):
    """Tile one exported patch into non-overlapping HEIGHT x WIDTH crops.

    Args:
        x: the image tensor, or ``(image, *metadata)`` when INCLUDE_METADATA
            is set.
        y: the label tensor, with the same spatial size as the image.

    Returns:
        A tf.data.Dataset yielding one ``(x, y)`` pair per tile. When
        INCLUDE_METADATA is set every tile is paired with a copy of the
        patch's metadata tensors.

    Raises:
        ValueError: if the export size is not a multiple of the crop size.
    """
    if EXPORT_HEIGHT % HEIGHT != 0 or EXPORT_WIDTH % WIDTH != 0:
        raise ValueError(
            "EXPORT_HEIGHT/EXPORT_WIDTH must be multiples of HEIGHT/WIDTH"
        )

    def _crop(tensor):
        """ based on https://stackoverflow.com/a/31530106
        """
        # split both spatial axes into (block index, offset within block)
        tensor = tf.reshape(
            tensor,
            (EXPORT_HEIGHT // HEIGHT, HEIGHT, EXPORT_WIDTH // WIDTH, WIDTH, -1)
        )
        # bring the two block indices next to each other ...
        cropped = tf.experimental.numpy.swapaxes(tensor, 1, 2)

        # ... and merge them into one leading "tile" axis
        num_blocks = (EXPORT_HEIGHT // HEIGHT) * (EXPORT_WIDTH // WIDTH)
        cropped = tf.reshape(cropped, (num_blocks, HEIGHT, WIDTH, -1))
        return tf.data.Dataset.from_tensor_slices(cropped)

    if INCLUDE_METADATA:
        # from_tensors (not from_tensor_slices) keeps each metadata tensor's
        # own shape, so elements here have the same structure as the ones
        # produced by the random-crop training pipeline.
        metadata = [
            tf.data.Dataset.from_tensors(m).repeat()
            for m in x[1:]
        ]
        x = x[0]

    x = _crop(x)
    y = _crop(y)

    if INCLUDE_METADATA:
        # zip stops at the shortest dataset, i.e. after the last tile
        x = tf.data.Dataset.zip((x, *metadata))

    return tf.data.Dataset.zip((x, y))


def _apply_fn_to_xy(x, y, func):
    """Apply ``func`` jointly to the image and its label.

    The label is cast to the image dtype and concatenated onto the image's
    channel axis so that ``func`` transforms both identically; afterwards
    the two are split apart again and the label is cast back.

    Args:
        x: the image tensor, or ``(image, *metadata)`` when INCLUDE_METADATA
            is set. Metadata is passed through untouched.
        y: the label tensor. A rank-2 label gets a temporary channel axis
            which is removed again before returning.
        func: callable taking and returning the concatenated tensor. It must
            keep the channel axis last and unchanged in size.

    Returns:
        ``(x, y)`` in the same structure as the inputs.
    """
    if INCLUDE_METADATA:
        metadata = x[1:]
        x = x[0]

    # Use the static rank: adding a tuple to tf.shape(y) is an element-wise
    # addition, not a concatenation, so it cannot be used to append an axis.
    added_channel = y.shape.rank == 2
    if added_channel:  # add temporary channel dimension
        y = tf.expand_dims(y, axis=-1)

    # Prefer the static channel count so the slices below keep a fully
    # defined static shape; fall back to the dynamic one when unknown.
    num_y_bands = y.shape[-1]
    if num_y_bands is None:
        num_y_bands = tf.shape(y)[-1]

    y_type = y.dtype
    y = tf.cast(y, x.dtype)

    xy = func(tf.concat([x, y], -1))

    x = xy[:, :, :-num_y_bands]
    y = tf.cast(xy[:, :, -num_y_bands:], y_type)

    if added_channel:
        # only drop the axis added above, never a genuine size-1 axis
        y = tf.squeeze(y, axis=-1)

    if INCLUDE_METADATA:
        x = (x, *metadata)

    return x, y


def crop(x, y, seed):
    """Take the same random HEIGHT x WIDTH crop from the image and the label.

    Args:
        x: the image tensor, or ``(image, *metadata)`` when INCLUDE_METADATA
            is set.
        y: the label tensor; a rank-2 label counts as a single band.
        seed: seed forwarded to tf.image.stateless_random_crop.

    Returns:
        The cropped ``(x, y)`` as returned by _apply_fn_to_xy.
    """
    def _num_bands(tensor):
        # static channel count when known, dynamic otherwise
        static = tensor.shape[-1]
        return tf.shape(tensor)[-1] if static is None else static

    # The label is handed to _apply_fn_to_xy unchanged; only its band count
    # is needed here. The static rank is used because adding a tuple to
    # tf.shape(y) adds element-wise instead of appending an axis.
    num_y_bands = 1 if y.shape.rank == 2 else _num_bands(y)
    num_x_bands = _num_bands(x[0] if INCLUDE_METADATA else x)

    target_shape = (HEIGHT, WIDTH, num_x_bands + num_y_bands)

    def func(xy):
        return tf.image.stateless_random_crop(xy, target_shape, seed=seed)

    return _apply_fn_to_xy(x, y, func)


def crop_wrapper(x, y):
    """Randomly crop ``(x, y)`` using a fresh seed drawn from the shared RNG."""
    return crop(x, y, RNG.make_seeds(2)[0])


# Random flips and rotations, applied to image and label together.
# NOTE(review): the rotation presumably interpolates the label channels too,
# which would leave them no longer strictly one-hot -- confirm acceptable.
AUGMENTER = tf.keras.Sequential(
    layers=[
        tf.keras.layers.RandomFlip(mode="horizontal_and_vertical"),
        tf.keras.layers.RandomRotation(0.2, fill_mode="reflect"),
    ]
)


def data_augmentation(x, y):
    """Run AUGMENTER in training mode on the image and label jointly."""
    return _apply_fn_to_xy(x, y, lambda xy: AUGMENTER(xy, training=True))


def build_dataset(tfrecord_pattern, train=True):
    """Build the batched input pipeline for training or evaluation.

    Args:
        tfrecord_pattern: glob pattern of the gzip-compressed TFRecord files.
        train: if True, build the shuffled, randomly cropped, repeating
            training pipeline; otherwise tile every patch into
            non-overlapping crops in a deterministic order.

    Returns:
        When ``train`` is True, ``(dataset, subset)`` where ``subset`` stacks
        the images of SUBSET_SIZE cropped samples along a new leading axis,
        for adapting the normalization layer. Otherwise just the dataset.
    """
    tfrecords = tf.data.Dataset.list_files(tfrecord_pattern, shuffle=train)
    dataset = tfrecords.interleave(
        lambda x: tf.data.TFRecordDataset(x, compression_type='GZIP').map(parse, num_parallel_calls=1),
        cycle_length=3 * NUM_ECOZONES,
        block_length=BATCH_SIZE // 4,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not train,
    )

    dataset = dataset.cache()

    if not train:
        dataset = dataset.flat_map(non_overlapping_crop)
        dataset = dataset.batch(BATCH_SIZE)
        return dataset.prefetch(tf.data.AUTOTUNE)

    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.map(crop_wrapper, num_parallel_calls=tf.data.AUTOTUNE)

    # The dataset is not batched yet, so take SUBSET_SIZE individual samples
    # (dividing by BATCH_SIZE here would yield only a handful of crops).
    subset = tf.stack(
        [x[0] if INCLUDE_METADATA else x for x, _ in dataset.take(SUBSET_SIZE)],
        axis=0,
    )

    if DATA_AUGMENTATION:  # do this after creating subset for normalization
        dataset = dataset.map(
            data_augmentation,
            num_parallel_calls=tf.data.AUTOTUNE
        )

    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset, subset


# Training pipeline plus the image subset used to adapt the Normalization layer.
train_dataset, normalize_subset = build_dataset(
    TRAIN_PATTERN,
    train=True,
)
# Evaluation pipeline; with train=False only the dataset is returned.
test_dataset = build_dataset(
    TEST_PATTERN,
    train=False,
)

# Spatial Model


In [None]:
class TemporalFusion(tf.keras.layers.Layer):
    """ Change detection layer.

    Concatenates two feature maps along the channel axis and mixes them with
    a 1x1 ReLU convolution.

    Based on Late Fusion from Maretto et al. 2021 10.1109/LGRS.2020.2986407
    """
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            filters=filters,
            kernel_size=(1, 1),
            padding="same",
            activation="relu",
        )
        self.conv = tf.keras.layers.Conv2D(**conv_options)

    def call(self, input1, input2):
        # fuse on the channel axis, then project to `filters` channels
        return self.conv(tf.concat([input1, input2], -1))


class DownSample(tf.keras.layers.Layer):
    """ Dilated separable convolution (ReLU) followed by batch normalization.
    """
    def __init__(self, filters, kernel_size, dilation_rate, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            filters=filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding="same",
            activation="relu",
        )
        self.separable_conv2d_1 = tf.keras.layers.SeparableConv2D(**conv_options)
        self.batch_norm = tf.keras.layers.BatchNormalization()

    def call(self, x):
        return self.batch_norm(self.separable_conv2d_1(x))


class UpSample(tf.keras.layers.Layer):
    """ Transposed convolution (ReLU) followed by batch normalization.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
        )
        self.transposed_conv2d_1 = tf.keras.layers.Conv2DTranspose(**conv_options)
        self.batch_norm = tf.keras.layers.BatchNormalization()

    def call(self, x):
        return self.batch_norm(self.transposed_conv2d_1(x))


class MetadataBias(tf.keras.layers.Layer):
    """ Layer to include scalar metadata in a fully convolutional network.

    Based on LSENet from Xie, Guo, and Dong 2022 10.1109/TGRS.2022.3176635

    x += Dense(Concat([Dense(GlobalAvgPool(x)), Embedding(scalars)]))

    The pooled image summary and each of the two embeddings have
    ``num_outputs // 3`` features; the resulting per-channel bias is
    broadcast over the spatial axes of ``x``.
    """
    def __init__(self, num_outputs, **kwargs):
        super().__init__(**kwargs)

        self.num_outputs = num_outputs
        self.num_inputs = num_outputs // 3

        # one embedding table per scalar metadata input
        self.doy_embedding = tf.keras.layers.Embedding(MAX_DOY, self.num_inputs)
        self.ecozone_embedding = tf.keras.layers.Embedding(
            NUM_ECOZONES, self.num_inputs
        )

        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.dense1 = tf.keras.layers.Dense(self.num_inputs)
        self.dense2 = tf.keras.layers.Dense(self.num_outputs)

    def call(self, x, doy, ecozone):
        # [:, 0] selects the first entry along the axis after the batch axis
        # of each embedding lookup
        embedded_doy = self.doy_embedding(doy)[:, 0]
        embedded_ecozone = self.ecozone_embedding(ecozone)[:, 0]

        image_summary = self.dense1(self.pool(x))

        features = tf.concat(
            [image_summary, embedded_doy, embedded_ecozone],
            axis=-1,
        )
        bias = self.dense2(features)

        # reshape so the bias broadcasts over the two spatial axes of x
        bias = tf.reshape(bias, (-1, 1, 1, self.num_outputs))
        return x + bias

In [None]:
def build_two_downstack_model(subset):
    """Build the spatial model with separate encoders for the two band halves.

    The normalized input is split on the channel axis into its first
    NUM_INPUTS channels and the remainder. Each half runs through its own
    stack of DownSample layers and the two are fused after every level with
    a TemporalFusion layer. The fused maps feed the UpSample stack through
    skip connections.

    Args:
        subset: sample images used to adapt the Normalization layer.

    Returns:
        A tf.keras.Model with a NUM_OUTPUTS-channel softmax output. When
        INCLUDE_METADATA is set its inputs are [image, doy, ecozone],
        otherwise just the image.
    """
    normalizer = tf.keras.layers.Normalization()
    normalizer.adapt(subset)

    input_layer = tf.keras.layers.Input(
        shape=(HEIGHT, WIDTH, 2 * NUM_INPUTS),
        name=IMAGE_INPUT_LAYER_NAME,
    )
    x = normalizer(input_layer)

    # split the channels into the two halves of the input
    x1 = x[:, :, :, :NUM_INPUTS]
    x2 = x[:, :, :, NUM_INPUTS:]

    down_stack_1 = [DownSample(*config) for config in MODEL_CONFIG]
    down_stack_2 = [DownSample(*config) for config in MODEL_CONFIG]
    up_stack = [UpSample(f, UPSAMPLE_FILTERS) for f in reversed(FILTERS)]

    skips = []
    for i, (down1, down2) in enumerate(zip(down_stack_1, down_stack_2)):
        x1 = down1(x1)
        x2 = down2(x2)
        # fuse the two branches at this level; the fused map is the skip
        x = TemporalFusion(FILTERS[i])(x1, x2)
        skips.append(x)

    if INCLUDE_METADATA:
        doy_input = tf.keras.layers.Input(
            shape=1,
            dtype=tf.int64,
            name=DOY_INPUT_LAYER_NAME,
        )
        ecozone_input = tf.keras.layers.Input(
            shape=1,
            dtype=tf.int64,
            name=ECOZONE_INPUT_LAYER_NAME,
        )

        # add the metadata-derived bias to the deepest fused feature map
        metadata_bias = MetadataBias(FILTERS[-1])
        x = metadata_bias(x, doy_input, ecozone_input)

    # the deepest map is the decoder input itself, so it is not also a skip
    skips = reversed(skips[:-1])

    # zip stops at the shorter sequence, so the last entry of up_stack is
    # never applied
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = tf.keras.layers.Conv2DTranspose(
        NUM_OUTPUTS,
        kernel_size=OUTPUT_KERNEL,
        padding="same",
        activation="softmax",
    )(x)

    if INCLUDE_METADATA:
        inputs = [input_layer, doy_input, ecozone_input]
    else:
        inputs = input_layer

    model = tf.keras.Model(inputs, x)
    return model


def build_single_downstack_model(subset):
    """Build the spatial model with a single encoder over all input bands.

    Args:
        subset: sample images used to adapt the Normalization layer.

    Returns:
        A tf.keras.Model with a NUM_OUTPUTS-channel softmax output. When
        INCLUDE_METADATA is set its inputs are [image, doy, ecozone],
        otherwise just the image.
    """
    # Size the input from the bands the data pipeline actually stacks, so
    # the model agrees with the dataset whichever bands were kept in BANDS.
    input_layer = tf.keras.layers.Input(
        shape=(HEIGHT, WIDTH, len(BANDS)),
        name=IMAGE_INPUT_LAYER_NAME
    )

    image_normalizer = tf.keras.layers.Normalization()
    image_normalizer.adapt(subset)

    x = image_normalizer(input_layer)

    down_stack = [DownSample(*config) for config in MODEL_CONFIG]
    up_stack = [UpSample(f, UPSAMPLE_FILTERS) for f in reversed(FILTERS)]

    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    if INCLUDE_METADATA:
        # integer inputs, matching the int64 metadata parsed from the records
        # and the two-downstack model
        doy_input = tf.keras.layers.Input(
            shape=(1,),
            dtype=tf.int64,
            name=DOY_INPUT_LAYER_NAME,
        )
        ecozone_input = tf.keras.layers.Input(
            shape=(1,),
            dtype=tf.int64,
            name=ECOZONE_INPUT_LAYER_NAME,
        )

        # MetadataBias takes a single size argument: the channel count of
        # the feature map the bias is added to.
        metadata_bias = MetadataBias(FILTERS[-1])
        x = metadata_bias(x, doy_input, ecozone_input)

    # the deepest map is the decoder input itself, so it is not also a skip
    skips = reversed(skips[:-1])

    # zip stops at the shorter sequence, so the last entry of up_stack is
    # never applied
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = tf.keras.layers.Conv2DTranspose(
        NUM_OUTPUTS,
        kernel_size=OUTPUT_KERNEL,
        padding="same",
        activation="softmax",
    )(x)

    if INCLUDE_METADATA:
        inputs = [input_layer, doy_input, ecozone_input]
    else:
        inputs = input_layer

    model = tf.keras.Model(inputs, x)
    return model


# Reset Keras' global state before (re)building the model, then build the
# variant selected by TWO_DOWNSTACKS using the normalization subset.
tf.keras.backend.clear_session()
if TWO_DOWNSTACKS:
    model = build_two_downstack_model(normalize_subset)
else:
    model = build_single_downstack_model(normalize_subset)
# tf.keras.utils.plot_model(model, show_shapes=True)

# Temporal Model

In [None]:
class RecurrentBlock(tf.keras.layers.Layer):
    """ Three stacked LSTM layers that expose their final states.

    The first two layers return full sequences; the last one returns only
    its final output. Every layer also returns its final state so that one
    block can be initialised from the states of another.
    """
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # return_state=True is needed on every layer: without it an LSTM
        # returns only its output and there is no state to hand on.
        self.lstm1 = tf.keras.layers.LSTM(
            units,
            return_sequences=True,
            return_state=True,
        )
        self.lstm2 = tf.keras.layers.LSTM(
            units,
            return_sequences=True,
            return_state=True,
        )
        self.lstm3 = tf.keras.layers.LSTM(
            units,
            return_sequences=False,
            return_state=True,
        )

    def call(self, x, initial_states=None):
        """ Run the three LSTM layers in sequence.

        Args:
            x: input sequence tensor.
            initial_states: optional list with one entry per LSTM layer, in
                the format returned as the second value of this method;
                None starts every layer from its default state.

        Returns:
            ``(output, states)`` where ``states`` is a list of three
            ``[hidden_state, cell_state]`` pairs, one per layer.
        """
        if initial_states is None:
            initial_states = [None] * 3

        x, h1, c1 = self.lstm1(x, initial_state=initial_states[0])
        x, h2, c2 = self.lstm2(x, initial_state=initial_states[1])
        x, h3, c3 = self.lstm3(x, initial_state=initial_states[2])
        return x, [[h1, c1], [h2, c2], [h3, c3]]

In [None]:
def build_temporal_model(units, num_inputs, num_outputs):
    """Build the temporal model over lookback, target and lookahead sequences.

    Each sequence is processed by its own RecurrentBlock. The states of the
    lookback block initialise the target block, whose states in turn
    initialise the lookahead block. The three block outputs are
    concatenated on the feature axis and classified by a dense layer.

    Args:
        units: number of units of every LSTM layer.
        num_inputs: number of features per time step.
        num_outputs: number of output classes; softmax is used when greater
            than one, sigmoid otherwise.

    Returns:
        A tf.keras.Model with inputs [lookback, target, lookahead].
    """
    lookback_input = tf.keras.layers.Input(shape=(None, num_inputs))
    target_input = tf.keras.layers.Input(shape=(None, num_inputs))
    lookahead_input = tf.keras.layers.Input(shape=(None, num_inputs))

    lookback, states = RecurrentBlock(units)(lookback_input)
    target, states = RecurrentBlock(units)(target_input, initial_states=states)
    lookahead, _ = RecurrentBlock(units)(lookahead_input, initial_states=states)

    x = tf.keras.layers.Concatenate(axis=-1)([lookback, target, lookahead])

    x = tf.keras.layers.Dense(
        num_outputs,
        activation="softmax" if num_outputs > 1 else "sigmoid",
    )(x)

    model = tf.keras.Model(inputs=[lookback_input, target_input, lookahead_input], outputs=x)
    return model

temporal_model = build_temporal_model(64, 16, 3)
tf.keras.utils.plot_model(temporal_model)

# Train Model

In [None]:
# Random stand-in data: two sets of normally distributed images and uniformly
# random one-hot labels.
# NOTE(review): data_A, data_B and labels are not referenced anywhere else in
# the visible notebook -- confirm whether this cell is still needed.
rng = np.random.default_rng()

size = 50
data_A = rng.normal(0, 1, (size, HEIGHT, WIDTH, NUM_INPUTS))
data_B = rng.normal(0, 1, (size, HEIGHT, WIDTH, NUM_INPUTS))

labels = tf.one_hot(rng.integers(0, NUM_OUTPUTS, (size, HEIGHT, WIDTH)), NUM_OUTPUTS)

In [None]:
# checkpoint to save progress during training and for easier loading of the
# model later on, but need to use model.save(...) for EEification
checkpoint_path = os.path.join(MODEL_DIR, "test", "checkpoints")
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    save_weights_only=True,
)

model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4)
)

# Training is currently commented out; the weights stored at checkpoint_path
# are loaded instead.
# model.fit(
#     train_dataset,
#     steps_per_epoch=50,
#     epochs=10,
#     callbacks=[checkpoint],
# )
model.load_weights(checkpoint_path)

# Visualization


In [None]:
# One colour and one display name per class index.
class_colours = ["white", "black", "gold", "darkCyan", "darkOrange", "red",
                 "orchid", "purple", "cornsilk", "dimGrey"]
CLASS_LIST = ["None", "Non-Forest", "Forest", "Water", "Previous Burn",
              "Burn", "Previous Harvest", "Harvest", "Cloud", "Cloud Shadow"]
# The second positional parameter of ListedColormap is the colormap's name,
# so the number of entries has to be passed by keyword.
cmap = ListedColormap(class_colours, N=NUM_OUTPUTS)
# Integer boundaries 0..NUM_OUTPUTS map class index k to the k-th colour.
norm = BoundaryNorm(np.arange(NUM_OUTPUTS + 1), NUM_OUTPUTS)

def _plot_x(x, axes, i, j):
    """Show channels 0-2 of ``x`` on ``axes[i, j]``, limits at mean +/- 1 std."""
    composite = tf.gather(x, (0, 1, 2), axis=-1)
    spread = np.std(composite)
    centre = np.mean(composite)
    axes[i, j].imshow(composite, vmin=centre - spread, vmax=centre + spread)

def _plot_y(y, axes, i, j):
    """Show the per-pixel argmax of ``y`` on ``axes[i, j]`` in class colours."""
    class_map = np.argmax(y, axis=-1)
    class_map = np.squeeze(class_map)
    axes[i, j].imshow(class_map, cmap=cmap, norm=norm)

In [None]:
# Run this cell to verify that the cropping does what we expect
def crop_visualizer(pattern, count=10, deterministic_crop=False):
    """Plot full patches next to their crops to check the cropping code.

    Each row shows one patch: the full image, its crop(s), the full label
    and the label crop(s). With ``deterministic_crop`` the first four tiles
    produced by non_overlapping_crop for that patch are shown; otherwise a
    single random crop is shown.

    Args:
        pattern: glob pattern of the gzip-compressed TFRecord files.
        count: number of patches (rows) to plot.
        deterministic_crop: use non_overlapping_crop instead of random crops.
    """
    files = tf.data.Dataset.list_files(pattern, shuffle=False)
    raw_dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    dataset = raw_dataset.map(parse)
    dataset = dataset.cache()

    size = 6
    crops_per_row = 4 if deterministic_crop else 1

    if deterministic_crop:
        # Limit the tiles per patch *before* flattening; otherwise the tiles
        # plotted in row i would not belong to the patch shown in row i.
        cropped_dataset = dataset.flat_map(
            lambda x, y: non_overlapping_crop(x, y).take(crops_per_row)
        )
    else:
        cropped_dataset = dataset.map(crop_wrapper)
    cropped_dataset = cropped_dataset.take(crops_per_row * count)

    # columns: image | image crops | label | label crops
    num_cols = 2 * (1 + crops_per_row)
    label_col = 1 + crops_per_row
    # squeeze=False keeps axes two-dimensional even when count == 1
    _, axes = plt.subplots(
        count,
        num_cols,
        figsize=(num_cols * size, count * size),
        squeeze=False,
    )

    def _image(x):
        return x[0] if INCLUDE_METADATA else x

    for i, (x, y) in enumerate(dataset.take(count)):
        _plot_x(_image(x), axes, i, 0)
        _plot_y(y, axes, i, label_col)

    for i, (x, y) in enumerate(cropped_dataset):
        row, col = divmod(i, crops_per_row)
        _plot_x(_image(x), axes, row, 1 + col)
        _plot_y(y, axes, row, label_col + 1 + col)

# crop_visualizer(TRAIN_PATTERN, deterministic_crop=True)
# crop_visualizer(TRAIN_PATTERN, deterministic_crop=False)

In [None]:
# run this cell to verify that data augmentation does what we intend
def data_augmentation_visualizer(pattern, count=10):
    """Plot patches next to their augmented versions.

    Each row shows the original image and label followed by the image and
    label returned by data_augmentation.

    Args:
        pattern: glob pattern of the gzip-compressed TFRecord files.
        count: number of patches (rows) to plot.
    """
    files = tf.data.Dataset.list_files(pattern, shuffle=False)
    raw_dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    dataset = raw_dataset.map(parse)
    dataset = dataset.cache()
    dataset = dataset.take(count)

    size = 6

    augmented_dataset = dataset.map(data_augmentation)

    # squeeze=False keeps axes two-dimensional even when count == 1
    _, axes = plt.subplots(
        count, 4, figsize=(4 * size, count * size), squeeze=False
    )

    # columns 0-1 hold the originals, columns 2-3 the augmented versions
    for first_col, source in ((0, dataset), (2, augmented_dataset)):
        for i, (x, y) in enumerate(source):
            if INCLUDE_METADATA:
                x = x[0]
            _plot_x(x, axes, i, first_col)
            _plot_y(y, axes, i, first_col + 1)

data_augmentation_visualizer(TRAIN_PATTERN, 25)

In [None]:
def visualizer(dataset, model=None, count=10):
    """Plot samples of a batched dataset, optionally with model predictions.

    Each row shows the historical composite, the current composite, the
    label and, when ``model`` is given, the model's predicted classes.

    Args:
        dataset: batched dataset of ``(x, y)`` pairs.
        model: optional model called on each sample with a batch axis added.
        count: number of samples (rows) to plot.
    """
    rgb_indices = [0, 1, 2]
    # The historical bands follow the NUM_INPUTS current bands in BANDS.
    historical_rgb_indices = [NUM_INPUTS + i for i in rgb_indices]

    data = dataset.unbatch()

    num = 3 if model is None else 4
    size = 10
    # squeeze=False keeps axes two-dimensional even when count == 1
    _, axes = plt.subplots(
        count, num, figsize=(num * size, count * size), squeeze=False
    )

    def stretch(image):
        # display limits at mean +/- half a standard deviation
        centre, half_std = np.mean(image), 0.5 * np.std(image)
        return centre - half_std, centre + half_std

    def plot_row(x, hx, y, model_output, index):
        if hx is None:
            # no historical bands in this configuration
            axes[index, 0].axis("off")
        else:
            vmin_hx, vmax_hx = stretch(hx)
            axes[index, 0].imshow(hx, vmin=vmin_hx, vmax=vmax_hx)
        vmin_x, vmax_x = stretch(x)
        axes[index, 1].imshow(x, vmin=vmin_x, vmax=vmax_x)
        y = np.argmax(y, axis=-1)
        axes[index, 2].imshow(y, cmap=cmap, norm=norm)
        if model_output is not None:
            model_output = np.squeeze(np.argmax(model_output, axis=-1))
            axes[index, 3].imshow(model_output, cmap=cmap, norm=norm)

    for i, (x, y) in enumerate(data.take(count)):
        if INCLUDE_METADATA:
            _x = tf.expand_dims(x[0], axis=0)
            metadata = [tf.expand_dims(m, axis=0) for m in x[1:]]
            _x = [_x, *metadata]
        else:
            _x = tf.expand_dims(x, axis=0)

        model_output = None if model is None else model(_x)

        if INCLUDE_METADATA:
            x = x[0]

        x_rgb = tf.gather(x, rgb_indices, axis=-1)
        # only gather the historical composite when those bands exist
        if x.shape[-1] > max(historical_rgb_indices):
            hx_rgb = tf.gather(x, historical_rgb_indices, axis=-1)
        else:
            hx_rgb = None

        plot_row(x_rgb, hx_rgb, y, model_output, i)


visualizer(train_dataset, model=model, count=25)

# Assessment

In [None]:
def build_confusion_matrix(model, dataset):
    """Sum the confusion matrices of the first 10 batches of ``dataset``.

    Labels and predictions are reduced to class indices with argmax over the
    last axis and flattened before being compared.

    Returns:
        An int32 tensor of shape (NUM_OUTPUTS, NUM_OUTPUTS).
    """
    total = tf.zeros((NUM_OUTPUTS, NUM_OUTPUTS), dtype=tf.int32)

    for inputs, one_hot_labels in dataset.take(10):
        predicted = model(inputs)

        true_classes = tf.reshape(tf.argmax(one_hot_labels, -1), [-1])
        predicted_classes = tf.reshape(tf.argmax(predicted, -1), [-1])

        total = total + tf.math.confusion_matrix(
            labels=true_classes,
            predictions=predicted_classes,
            num_classes=NUM_OUTPUTS,
        )

    return total

def label_confusion_matrix(confusion_matrix, class_labels):
    """Wrap a confusion matrix in a DataFrame with labelled rows and columns.

    Rows are named ``'True <label>'`` and columns ``'Pred <label>'``.
    """
    row_names = []
    column_names = []
    for label in class_labels:
        row_names.append('True ' + label)
        column_names.append('Pred ' + label)

    return pd.DataFrame(
        confusion_matrix,
        index=row_names,
        columns=column_names,
    )

# Confusion matrix of the model on the test dataset, shown with class names.
cm = build_confusion_matrix(model, test_dataset)
label_confusion_matrix(cm, CLASS_LIST)

# EEification

See: https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_Vertex_AI.ipynb

And also: https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_tree_counting_model.ipynb


In [None]:
class Preprocessing(tf.keras.layers.Layer):
    """ Based on:
    https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_Vertex_AI.ipynb

    Turns a dict of per-band tensors into the model's input structure by
    concatenating the bands of BANDS along the last axis.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, features_dict):
        # (None, 1, 1, 1) -> (None, 1, 1, P)
        band_tensors = [features_dict[b] for b in BANDS]
        image = tf.concat(band_tensors, axis=-1, name='image')

        if not INCLUDE_METADATA:
            return image

        # metadata tensors are passed through in METADATA order
        return (image, *[features_dict[m] for m in METADATA])

    def get_config(self):
        return super().get_config()

class WrappedModel(tf.keras.Model):
    """ Based on:
    https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_Vertex_AI.ipynb

    Runs a Preprocessing layer in front of the given model.
    """
    def __init__(self, model, **kwargs):
        super().__init__(**kwargs)
        self.preprocessing = Preprocessing()
        self.model = model

    def call(self, features_dict):
        return self.model(self.preprocessing(features_dict))

    def get_config(self):
        return super().get_config()

In [None]:
class DeSerializeInput(tf.keras.layers.Layer):
    """ Decodes a dict of base64-encoded serialized tensors.

    Entries whose key is in BANDS are parsed as float32, all other entries
    as int64.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs_dict):
        def decode(encoded, dtype):
            # base64-decode the batch, then parse every element separately
            return tf.map_fn(
                lambda x: tf.io.parse_tensor(x, dtype),
                tf.io.decode_base64(encoded),
                fn_output_signature=dtype
            )

        deserialized = {
            name: decode(value, tf.float32)
            for name, value in inputs_dict.items()
            if name in BANDS
        }

        # scalar metadata should be parsed as int64 not float
        deserialized.update(
            (name, decode(value, tf.int64))
            for name, value in inputs_dict.items()
            if name not in BANDS
        )

        return deserialized


    def get_config(self):
        return super().get_config()


class ReSerializeOutput(tf.keras.layers.Layer):
    """ Serializes and base64-encodes each element along the first axis.
    """
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, output_tensor):
        def encode(element):
            serialized = tf.io.serialize_tensor(element)
            return tf.io.encode_base64(serialized)

        return tf.map_fn(
            encode,
            output_tensor,
            fn_output_signature=tf.string
        )

    def get_config(self):
        return super().get_config()

input_deserializer = DeSerializeInput()
output_reserializer = ReSerializeOutput()

# One scalar string input per band (and per metadata field when enabled),
# named after the band / field it carries.
serialized_inputs = {
    x: tf.keras.Input(shape=[], dtype='string', name=x)
    for x in (BANDS + METADATA if INCLUDE_METADATA else BANDS)
}

model.load_weights(os.path.join(MODEL_DIR, "test", "checkpoints"))

# Chain: deserialize inputs -> preprocessing + model -> reserialize output.
wrapped_model = WrappedModel(model)
updated_model_input = input_deserializer(serialized_inputs)
updated_model = wrapped_model(updated_model_input)
updated_model = output_reserializer(updated_model)
updated_model = tf.keras.Model(serialized_inputs, updated_model)

SAVED_MODEL_PATH = os.path.join(MODEL_DIR, "test", "full_model")

# Remove any previous export at this path before saving the new one.
!gsutil rm -rf {SAVED_MODEL_PATH}
updated_model.save(SAVED_MODEL_PATH)

In [None]:
!gcloud ai models delete {MODEL_NAME} --project={PROJECT_ID} --region={REGION}

In [None]:
# upload the model
CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/tf2-gpu.2-11:latest'

!gcloud ai models upload \
    --project={PROJECT_ID} \
    --artifact-uri={SAVED_MODEL_PATH} \
    --region={REGION} \
    --container-image-uri={CONTAINER_IMAGE} \
    --description={MODEL_NAME} \
    --display-name={MODEL_NAME} \
    --model-id={MODEL_NAME}

In [None]:
# create endpoint for model
!gcloud ai endpoints create \
    --display-name={ENDPOINT_NAME} \
    --region={REGION} \
    --project={PROJECT_ID}

In [None]:
# deploy the model

# may need to filter, if you have multiple of these
ENDPOINT_ID = !gcloud ai endpoints list \
    --project={PROJECT_ID} \
    --region={REGION} \
    --format="value(ENDPOINT_ID.scope())"
ENDPOINT_ID = ENDPOINT_ID[-1]

!gcloud ai endpoints deploy-model {ENDPOINT_ID} \
    --project={PROJECT_ID} \
    --region={REGION} \
    --model={MODEL_NAME} \
    --machine-type=n1-standard-8 \
    --accelerator=type=nvidia-tesla-t4,count=1 \
    --display-name={MODEL_NAME}

# Verify Model Hosting Was Successful

In [None]:
import ee
ee.Authenticate()
ee.Initialize()

In [None]:
!git clone https://github.com/boothmanrylan/canadaMSSForestDisturbances.git
%cd canadaMSSForestDisturbances

In [None]:
!pip install --quiet msslib
!pip install --quiet geemap

In [None]:
from mss_forest_disturbances import data
import geemap
from msslib import msslib

In [None]:
Map = geemap.Map()
Map

In [None]:
# Uses the first feature drawn on the map as the area of interest, so a
# feature must have been drawn before this cell is run.
aoi = Map.draw_features[0]
year = 1990

collection = msslib.getCol(
    aoi=aoi.geometry(),
    yearRange=[year, year],
    doyRange=data.DOY_RANGE,
    maxCloudCover=100
)

# First image after sorting the collection by its CLOUD_COVER property.
image = collection.sort('CLOUD_COVER').first()

Map.addLayer(image, msslib.visDn2, "Image")

In [None]:
# First ecozone feature intersecting the area of interest; its ECOZONE_ID is
# attached to the prepared image as the 'ecozone' property.
ecozone = ee.FeatureCollection(data.ECOZONES).filterBounds(aoi.geometry()).first()
ecozone_id = ecozone.getNumber('ECOZONE_ID')
prepared_image, target_label = data.prepare_image_for_export(image)
prepared_image = prepared_image.set('ecozone', ecozone_id)

In [None]:
# The endpoint is identified by a resource name, not a filesystem path, so it
# is assembled with explicit '/' separators rather than os.path.join.
endpoint_path = f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{ENDPOINT_ID}"
# Earth Engine handle to the model deployed on the Vertex AI endpoint; tiles
# of HEIGHT x WIDTH pixels are sent together with the METADATA properties.
hosted_model = ee.Model.fromVertexAi(
    endpoint=endpoint_path,
    inputTileSize=(HEIGHT, WIDTH),
    inputOverlapSize=(16, 16),
    inputProperties=METADATA,
    proj=data.get_default_projection(),
    fixInputProj=True,
    outputBands={
        'label': {
            'type': ee.PixelType.float(),
            'dimensions': 1
        },
    },
    maxPayloadBytes=3000000,
)

In [None]:
# Run the hosted model over the prepared image and export the prediction to
# an Earth Engine asset covering the footprint of the source image.
prediction = hosted_model.predictImage(prepared_image)

task = ee.batch.Export.image.toAsset(
    image=prediction,
    description="test_vertex_ai_hosting",
    assetId="projects/api-project-269347469410/assets/rylan-mssforestdisturbances/scratch/test_vertex_ai_hosting",
    pyramidingPolicy={".default": "mode"},
    region=image.geometry(),
    scale=60,
    crs=data.get_default_projection(),
)
task.start()

# TODO
* set up assessment code
* __not enough disturbances in exported data__
* Add Digital Elevation Model band back to export
* Add index to distinguish new harvest from old harvest
    * red / ndvi
    * need way to prove/argue that this is a useful spectral index
* Add index to distinguish new burn scar from old burn scar
* temporal model
    * write code
    * figure out how to export training data
* Figure out how to run colab with a paid backend
* Calling the Vertex AI hosted model through Earth Engine and exporting the result is very slow (24 minutes for one image). Batch export and running everything in Google Cloud is likely faster, but more expensive, and for the next step we need to be able to look at pixels through time, which will be more difficult outside of Earth Engine.
