<a href="https://colab.research.google.com/github/boothmanrylan/canadaMSSForestDisturbances/blob/main/SpatioTemporalUNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup


In [None]:
# Authenticate the Colab user so later cells can use their Google credentials.
from google.colab import auth
auth.authenticate_user()

# Cloud project id; also interpolated into the `earthengine` and
# `gcloud ai-platform` commands further down the notebook.
PROJECT_ID = "api-project-269347469410"
!gcloud config set project {PROJECT_ID}

In [None]:
import os
import math
import json

import tensorflow as tf
from tensorflow.python.tools import saved_model_utils
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm

# Config

In [None]:
# Number of label classes: used as the one-hot depth of the labels and as the
# channel count of the model's softmax output.
NUM_OUTPUTS = 10
# Size of the patches stored in the TFRecords.
EXPORT_HEIGHT = 512
EXPORT_WIDTH = 512
# Size of the crops fed to the model; the exported patches are cropped down
# to this size (randomly for training, as a non-overlapping grid otherwise).
HEIGHT = 256
WIDTH = 256

# TODO: DEM will make an uneven number of bands that doesn't split into current/past as the DEM won't have changed
# Band order matters: parse() stacks the bands in this order, and the two
# downstack model treats the first NUM_INPUTS channels as "current" and the
# remaining channels as "historical".
BANDS = [
    'nir', 'red_edge', 'red', 'green',
    'tca', 'ndvi',
    'historical_nir', 'historical_red_edge', 'historical_red', 'historical_green',
    'historical_tca', 'historical_ndvi'
]
# Scalar per-patch features, yielded alongside the image in this order.
METADATA = ['doy', 'ecozone'] # , 'lat', 'lon']
# Number of non-historical bands. Computed before BANDS is optionally
# filtered below, so it is unaffected by that filtering.
NUM_INPUTS = len([b for b in BANDS if "historical" not in b])
LABEL_BAND = 'label'

# Vocabulary sizes of the metadata embeddings (see MetadataBias); every doy /
# ecozone value must be a valid index into an embedding of this size.
MAX_DOY = 110
NUM_ECOZONES = 10  # there are only seven represented in the sanity test dataset

# Data Config
BUCKET = 'rylan-mssforestdisturbances'
BASE_PATH = f'gs://{BUCKET}/scratch/test_export/ecozone*/'
# Shard 00000 of every ecozone is held out for testing.
TEST_PATTERN = os.path.join(BASE_PATH, '*-00000-of-*.tfrecord.gz')
# NOTE(review): '000[0-9][1-9]' matches shards 00001-00099 whose last digit is
# not 0, so shards 00010, 00020, ... (and any shard >= 00100) are in neither
# the train nor the test pattern -- confirm this is intended.
TRAIN_PATTERN = os.path.join(BASE_PATH, '*-000[0-9][1-9]-of*.tfrecord.gz')

BATCH_SIZE = 32
SHUFFLE_BUFFER = 100

# Sizes the sample of training images used to adapt the Normalization layer.
SUBSET_SIZE = 100

# Model Config
# One (filters, kernel size, dilation rate) triple per DownSample layer.
FILTERS = [32, 64, 128, 256]
KERNELS = [7, 5, 3, 3]
DILATION_RATES = [1, 1, 2, 4]
# NOTE: despite the name, this is passed to UpSample as its kernel size.
UPSAMPLE_FILTERS = 3
METADATA_FILTERS = 32
# Kernel size of the final Conv2DTranspose that produces the class scores.
OUTPUT_KERNEL = 3
MODEL_CONFIG = list(zip(FILTERS, KERNELS, DILATION_RATES))

# Feature switches selecting the model builder and the dataset structure.
TWO_DOWNSTACKS = True
INCLUDE_HISTORICAL = True
INCLUDE_METADATA = True

# The historical bands are only loaded when BOTH flags are set; in every other
# combination (including single downstack with INCLUDE_HISTORICAL=True) they
# are dropped from BANDS.
if not (TWO_DOWNSTACKS and INCLUDE_HISTORICAL):
    BANDS = [b for b in BANDS if "historical" not in b]

# Seeded generator that supplies the seeds for the stateless random crops.
RNG = tf.random.Generator.from_seed(42, alg="philox")

# AI Platform Hosting Config
REGION = "us-central1"
# NOTE(review): MODEL_DIR ends with '/', and later cells build paths as
# f"{MODEL_DIR}/test/...", so those paths contain '//'. Every cell builds the
# path the same way, so saving and loading agree with each other.
MODEL_DIR = f"gs://{BUCKET}/scratch/models/"
EEIFIED_DIR = f"gs://{BUCKET}/scratch/eeified_models/test_model_hosting/"
MODEL_NAME = "test_model"
VERSION_NAME = "v0"

# Load Data

In [None]:
def _fixed_feature(shape, dtype):
    """Describe one fixed-length TFRecord feature of the given shape/dtype."""
    return tf.io.FixedLenFeature(shape=shape, dtype=dtype)


# Raster features are stored at the full export resolution.
_EXPORT_SHAPE = (EXPORT_HEIGHT, EXPORT_WIDTH)

# One float32 raster per input band.
IMAGE_FEATURES = {
    band: _fixed_feature(_EXPORT_SHAPE, tf.float32) for band in BANDS
}

# A single int64 raster holding the class id of every pixel.
LABEL_FEATURES = {LABEL_BAND: _fixed_feature(_EXPORT_SHAPE, tf.int64)}

# Every metadata value is stored as a length-1 int64 vector.
METADATA_FEATURES = {
    name: _fixed_feature(1, tf.int64) for name in METADATA
}


def parse(example):
    """Decode one serialized tf.train.Example into model inputs and a label.

    Args:
        example: scalar string tensor holding a serialized tf.train.Example.

    Returns:
        A tuple (x, y). y is the label band one-hot encoded to NUM_OUTPUTS
        classes. x is the image with the bands stacked on the last axis in
        BANDS order; when INCLUDE_METADATA is set x is instead the tuple
        (image, *metadata) with the metadata values in METADATA order.
    """
    # Parse the record once rather than once per feature group, and only
    # request the metadata features when they are actually used so records
    # exported without them can still be read with INCLUDE_METADATA = False.
    features = {**IMAGE_FEATURES, **LABEL_FEATURES}
    if INCLUDE_METADATA:
        features.update(METADATA_FEATURES)
    parsed = tf.io.parse_single_example(example, features)

    x = tf.stack([parsed[b] for b in BANDS], axis=-1)
    y = tf.one_hot(parsed[LABEL_BAND], NUM_OUTPUTS)

    if INCLUDE_METADATA:
        x = (x, *(parsed[m] for m in METADATA))

    return x, y


def non_overlapping_crop(x, y):
    """Split one exported patch into a grid of HEIGHT x WIDTH tiles.

    Intended for use with `Dataset.flat_map`: one (x, y) element becomes a
    dataset of (EXPORT_HEIGHT // HEIGHT) * (EXPORT_WIDTH // WIDTH) elements
    with the same structure as the input element.

    Args:
        x: the image tensor, or the tuple (image, *metadata) when
            INCLUDE_METADATA is set.
        y: the label tensor covering the same pixels as the image.

    Returns:
        A tf.data.Dataset of (x, y) tiles. When INCLUDE_METADATA is set, every
        tile is paired with an unchanged copy of the patch's metadata.
    """
    assert EXPORT_HEIGHT % HEIGHT == 0
    assert EXPORT_WIDTH % WIDTH == 0

    def _crop(tensor):
        """ based on https://stackoverflow.com/a/31530106
        """
        # Split rows and columns into (block index, offset within block) ...
        tensor = tf.reshape(
            tensor,
            (EXPORT_HEIGHT // HEIGHT, HEIGHT, EXPORT_WIDTH // WIDTH, WIDTH, -1)
        )
        # ... bring the two block indices next to each other ...
        cropped = tf.experimental.numpy.swapaxes(tensor, 1, 2)

        # ... and merge them into a single leading "tile" axis.
        num_blocks = (EXPORT_HEIGHT // HEIGHT) * (EXPORT_WIDTH // WIDTH)
        cropped = tf.reshape(cropped, (num_blocks, HEIGHT, WIDTH, -1))
        return tf.data.Dataset.from_tensor_slices(cropped)

    if INCLUDE_METADATA:
        # from_tensors (not from_tensor_slices) keeps each metadata value's
        # original shape; slicing would strip the length-1 axis and make the
        # elements differ in shape from the ones the random-crop path yields.
        metadata = [
            tf.data.Dataset.from_tensors(m).repeat()
            for m in x[1:]
        ]
        x = x[0]

    x = _crop(x)
    y = _crop(y)

    if INCLUDE_METADATA:
        # Zipping with the repeated metadata stops once the tiles run out.
        x = tf.data.Dataset.zip((x, *metadata))

    return tf.data.Dataset.zip((x, y))


def crop(x, y, seed):
    """Take the same random HEIGHT x WIDTH window out of an image and label.

    The image and label are concatenated along the channel axis so that one
    stateless random crop cuts both at the same location.

    Args:
        x: the image tensor (rows, cols, bands), or the tuple
            (image, *metadata) when INCLUDE_METADATA is set.
        y: the label, either (rows, cols) or (rows, cols, classes).
        seed: shape [2] integer seed for tf.image.stateless_random_crop.

    Returns:
        A tuple (x, y) of the cropped image and label. The label keeps its
        original dtype and rank; metadata is passed through unchanged.
    """
    if INCLUDE_METADATA:
        metadata = x[1:]
        x = x[0]

    # A 2-D label gets a temporary channel axis so it can be concatenated with
    # the image. The rank is static, so this is decided with plain Python.
    added_channel = y.shape.rank == 2
    if added_channel:
        y = tf.expand_dims(y, axis=-1)
    num_y_bands = tf.shape(y)[-1]

    # The label temporarily takes the image dtype for the concatenation.
    y_type = y.dtype
    y = tf.cast(y, x.dtype)

    num_x_bands = tf.shape(x)[-1]
    xy = tf.concat([x, y], -1)

    target_shape = (HEIGHT, WIDTH, num_x_bands + num_y_bands)
    xy = tf.image.stateless_random_crop(xy, target_shape, seed=seed)

    x = xy[:, :, :-num_y_bands]
    y = tf.cast(xy[:, :, -num_y_bands:], y_type)

    if added_channel:
        # Remove only the axis added above, never a genuine size-1 axis.
        y = tf.squeeze(y, axis=-1)

    if INCLUDE_METADATA:
        x = (x, *metadata)

    return x, y


def crop_wrapper(x, y):
    """Randomly crop (x, y) with a fresh seed drawn from the global RNG.

    Adapter that lets `crop` be used directly in `Dataset.map`.
    """
    seeds = RNG.make_seeds(2)
    cropped_x, cropped_y = crop(x, y, seeds[0])
    return cropped_x, cropped_y


def build_dataset(tfrecord_pattern, train=True):
    """Build the batched input pipeline for training or evaluation.

    Args:
        tfrecord_pattern: glob pattern (or list of patterns) of the
            GZIP-compressed TFRecord files to read.
        train: if True, shuffle, crop at random locations and repeat forever;
            otherwise tile every patch into non-overlapping crops, in a
            deterministic order, for a single pass.

    Returns:
        When train is False, the batched dataset.
        When train is True, a tuple (dataset, subset) where subset is a
        (n, HEIGHT, WIDTH, bands) tensor of up to SUBSET_SIZE cropped training
        images (without metadata), used to adapt the Normalization layer.

    Raises:
        ValueError: if train is True and the pattern yields no examples.
    """
    tfrecords = tf.data.Dataset.list_files(tfrecord_pattern, shuffle=train)
    dataset = tfrecords.interleave(
        lambda path: tf.data.TFRecordDataset(
            path, compression_type='GZIP'
        ).map(parse, num_parallel_calls=1),
        cycle_length=3 * NUM_ECOZONES,
        block_length=BATCH_SIZE // 4,
        num_parallel_calls=tf.data.AUTOTUNE,
        deterministic=not train,
    )

    # Cache the parsed full-size patches, i.e. before cropping, so repeating
    # the training data draws new random crops instead of replaying old ones.
    dataset = dataset.cache()

    if not train:
        dataset = dataset.flat_map(non_overlapping_crop)
        dataset = dataset.batch(BATCH_SIZE)
        return dataset.prefetch(tf.data.AUTOTUNE)

    dataset = dataset.shuffle(SHUFFLE_BUFFER)
    dataset = dataset.map(crop_wrapper, num_parallel_calls=tf.data.AUTOTUNE)

    # The dataset is not batched yet, so take SUBSET_SIZE individual crops
    # (taking ceil(SUBSET_SIZE / BATCH_SIZE) elements here would only yield a
    # handful of images). This holds SUBSET_SIZE crops in memory at once.
    subset = [
        x[0] if INCLUDE_METADATA else x
        for x, _ in dataset.take(SUBSET_SIZE)
    ]
    if not subset:
        raise ValueError(
            f"no training examples matched pattern {tfrecord_pattern!r}"
        )
    subset = tf.stack(subset, axis=0)

    dataset = dataset.repeat()
    dataset = dataset.batch(BATCH_SIZE)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset, subset


# Build both pipelines. The training call also returns the sample of images
# that is later used to adapt the model's Normalization layer.
train_dataset, normalize_subset = build_dataset(
    TRAIN_PATTERN,
    train=True,
)
test_dataset = build_dataset(
    TEST_PATTERN,
    train=False,
)

# Spatial Model


In [None]:
class TemporalFusion(tf.keras.layers.Layer):
    """ Change detection layer.

    Concatenates two feature maps along the channel axis and mixes them with
    a 1x1 ReLU convolution.

    Based on Late Fusion from Maretto et al. 2021 10.1109/LGRS.2020.2986407
    """
    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        # A 1x1 kernel combines channels without looking at neighbouring
        # pixels.
        conv_options = dict(
            filters=filters,
            kernel_size=(1, 1),
            padding="same",
            activation="relu",
        )
        self.conv = tf.keras.layers.Conv2D(**conv_options)

    def call(self, input1, input2):
        stacked = tf.concat([input1, input2], -1)
        return self.conv(stacked)


class DownSample(tf.keras.layers.Layer):
    """Encoder block: dilated separable ReLU convolution, then batch norm.

    The convolution uses "same" padding and no stride, so the spatial size of
    the input is preserved.
    """
    def __init__(self, filters, kernel_size, dilation_rate, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            filters=filters,
            kernel_size=kernel_size,
            dilation_rate=dilation_rate,
            padding="same",
            activation="relu",
        )
        self.separable_conv2d_1 = tf.keras.layers.SeparableConv2D(
            **conv_options
        )
        self.batch_norm = tf.keras.layers.BatchNormalization()

    def call(self, x):
        features = self.separable_conv2d_1(x)
        return self.batch_norm(features)


class UpSample(tf.keras.layers.Layer):
    """Decoder block: transposed ReLU convolution, then batch norm.

    The transposed convolution uses "same" padding and no stride, so the
    spatial size of the input is preserved.
    """
    def __init__(self, filters, kernel_size, **kwargs):
        super().__init__(**kwargs)
        conv_options = dict(
            filters=filters,
            kernel_size=kernel_size,
            padding="same",
            activation="relu",
        )
        self.transposed_conv2d_1 = tf.keras.layers.Conv2DTranspose(
            **conv_options
        )
        self.batch_norm = tf.keras.layers.BatchNormalization()

    def call(self, x):
        features = self.transposed_conv2d_1(x)
        return self.batch_norm(features)


class MetadataBias(tf.keras.layers.Layer):
    """ Layer to include scalar metadata in a fully convolutional network.

    Based on LSENet from Xie, Guo, and Dong 2022 10.1109/TGRS.2022.3176635

    x += Dense(Concat([Dense(GlobalAvgPool(x)), Embedding(scalars)]))

    The resulting bias has one value per channel and is broadcast over every
    pixel, so `num_outputs` must equal the channel count of `x`.
    """
    def __init__(self, num_outputs, **kwargs):
        super().__init__(**kwargs)

        self.num_outputs = num_outputs
        # Each of the three concatenated parts (pooled image, doy embedding,
        # ecozone embedding) gets a third of the output width.
        self.num_inputs = self.num_outputs // 3

        embedding = tf.keras.layers.Embedding
        self.doy_embedding = embedding(MAX_DOY, self.num_inputs)
        self.ecozone_embedding = embedding(NUM_ECOZONES, self.num_inputs)

        self.pool = tf.keras.layers.GlobalAveragePooling2D()
        self.dense1 = tf.keras.layers.Dense(self.num_inputs)
        self.dense2 = tf.keras.layers.Dense(self.num_outputs)

    def call(self, x, doy, ecozone):
        # Index 0 on axis 1 drops that axis from each embedding lookup.
        doy_vector = self.doy_embedding(doy)[:, 0]
        ecozone_vector = self.ecozone_embedding(ecozone)[:, 0]

        image_summary = self.dense1(self.pool(x))

        combined = tf.concat(
            [image_summary, doy_vector, ecozone_vector],
            axis=-1
        )
        bias = self.dense2(combined)

        # Shape the bias as (batch, 1, 1, channels) so it broadcasts over the
        # spatial axes of x.
        bias = tf.reshape(bias, (-1, 1, 1, self.num_outputs))
        return x + bias

In [None]:
def build_two_downstack_model(subset):
    """Build the model with separate encoders for current and historical bands.

    The normalized input is split on the channel axis into the first
    NUM_INPUTS channels and the remaining channels. Each half goes through
    its own stack of DownSample layers, and after every level the two feature
    maps are merged by a TemporalFusion layer.

    Args:
        subset: sample of input images used to adapt the Normalization layer.

    Returns:
        A tf.keras.Model producing per-pixel softmax scores over NUM_OUTPUTS
        classes. Its inputs are [image, doy, ecozone] when INCLUDE_METADATA is
        set, otherwise just the image.
    """
    normalizer = tf.keras.layers.Normalization()
    normalizer.adapt(subset)

    input_layer = tf.keras.layers.Input(shape=(HEIGHT, WIDTH, 2 * NUM_INPUTS))
    x = normalizer(input_layer)

    # Channel split: first half / second half of the stacked bands.
    x1 = x[:, :, :, :NUM_INPUTS]
    x2 = x[:, :, :, NUM_INPUTS:]

    down_stack_1 = [DownSample(*config) for config in MODEL_CONFIG]
    down_stack_2 = [DownSample(*config) for config in MODEL_CONFIG]
    # NOTE: UPSAMPLE_FILTERS is received by UpSample as its kernel size.
    up_stack = [UpSample(f, UPSAMPLE_FILTERS) for f in reversed(FILTERS)]

    skips = []
    for i, (down1, down2) in enumerate(zip(down_stack_1, down_stack_2)):
        x1 = down1(x1)
        x2 = down2(x2)
        # Only the fused map is kept as a skip connection; the two encoders
        # continue from their own un-fused outputs.
        x = TemporalFusion(FILTERS[i])(x1, x2)
        skips.append(x)

    if INCLUDE_METADATA:
        doy_input = tf.keras.layers.Input(shape=1, dtype=tf.int64)
        ecozone_input = tf.keras.layers.Input(shape=1, dtype=tf.int64)

        # The bias is applied to the deepest fused feature map, which has
        # FILTERS[-1] channels.
        metadata_bias = MetadataBias(FILTERS[-1])
        x = metadata_bias(x, doy_input, ecozone_input)

    # The deepest fused map is the decoder's starting point, not a skip.
    skips = reversed(skips[:-1])

    # zip stops at the shorter sequence, so the last UpSample layer in
    # up_stack is never applied.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = tf.keras.layers.Conv2DTranspose(
        NUM_OUTPUTS,
        kernel_size=OUTPUT_KERNEL,
        padding="same",
        activation="softmax",
    )(x)

    if INCLUDE_METADATA:
        inputs = [input_layer, doy_input, ecozone_input]
    else:
        inputs = input_layer

    model = tf.keras.Model(inputs, x)
    return model


def build_single_downstack_model(subset):
    """Build the model with a single encoder over all loaded bands.

    Args:
        subset: sample of input images used to adapt the Normalization layer.

    Returns:
        A tf.keras.Model producing per-pixel softmax scores over NUM_OUTPUTS
        classes. Its inputs are [image, doy, ecozone] when INCLUDE_METADATA is
        set, otherwise just the image.
    """
    # Size the input from the bands that are actually loaded. BANDS has
    # already had the historical bands removed whenever they are not used, so
    # this always agrees with the channel count the dataset yields.
    input_layer = tf.keras.layers.Input(shape=(HEIGHT, WIDTH, len(BANDS)))

    image_normalizer = tf.keras.layers.Normalization()
    image_normalizer.adapt(subset)

    x = image_normalizer(input_layer)

    down_stack = [DownSample(*config) for config in MODEL_CONFIG]
    # NOTE: UPSAMPLE_FILTERS is received by UpSample as its kernel size.
    up_stack = [UpSample(f, UPSAMPLE_FILTERS) for f in reversed(FILTERS)]

    skips = []
    for down in down_stack:
        x = down(x)
        skips.append(x)

    if INCLUDE_METADATA:
        # Same metadata interface as build_two_downstack_model: the dataset
        # yields exactly (doy, ecozone) as int64, and MetadataBias takes the
        # channel count of the tensor it biases plus those two inputs.
        doy_input = tf.keras.layers.Input(shape=1, dtype=tf.int64)
        ecozone_input = tf.keras.layers.Input(shape=1, dtype=tf.int64)

        metadata_bias = MetadataBias(FILTERS[-1])
        x = metadata_bias(x, doy_input, ecozone_input)

    # The deepest feature map is the decoder's starting point, not a skip.
    skips = reversed(skips[:-1])

    # zip stops at the shorter sequence, so the last UpSample layer in
    # up_stack is never applied.
    for up, skip in zip(up_stack, skips):
        x = up(x)
        x = tf.keras.layers.Concatenate()([x, skip])

    x = tf.keras.layers.Conv2DTranspose(
        NUM_OUTPUTS,
        kernel_size=OUTPUT_KERNEL,
        padding="same",
        activation="softmax",
    )(x)

    if INCLUDE_METADATA:
        inputs = [input_layer, doy_input, ecozone_input]
    else:
        inputs = input_layer

    model = tf.keras.Model(inputs, x)
    return model


# Reset Keras' global state before building, then build whichever
# architecture the config selects, adapting its Normalization layer on the
# sample of training images returned by build_dataset.
tf.keras.backend.clear_session()
if TWO_DOWNSTACKS:
    model = build_two_downstack_model(normalize_subset)
else:
    model = build_single_downstack_model(normalize_subset)
# tf.keras.utils.plot_model(model, show_shapes=True)

# Temporal Model

In [None]:
class RecurrentBlock(tf.keras.layers.Layer):
    """Three stacked LSTMs that expose their final states.

    The first two LSTMs return full sequences so the next LSTM has a sequence
    to consume; the last one returns only its final output.
    """
    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        # return_state=True on every layer: call() hands the final states of
        # all three layers back so another block can start from them.
        self.lstm1 = tf.keras.layers.LSTM(
            units,
            return_sequences=True,
            return_state=True,
        )
        self.lstm2 = tf.keras.layers.LSTM(
            units,
            return_sequences=True,
            return_state=True,
        )
        self.lstm3 = tf.keras.layers.LSTM(
            units,
            return_sequences=False,
            return_state=True,
        )

    def call(self, x, initial_states=None):
        """Run the stack over a sequence.

        Args:
            x: input sequence tensor.
            initial_states: optional list with one entry per LSTM layer, each
                either None or that layer's [hidden, cell] starting state
                (e.g. the states returned by a previous RecurrentBlock).

        Returns:
            A tuple (output, states): the final output of the last LSTM and a
            list of the three layers' final [hidden, cell] states.
        """
        if initial_states is None:
            initial_states = [None] * 3

        # With return_state=True an LSTM returns (output, hidden, cell).
        x, *state1 = self.lstm1(x, initial_state=initial_states[0])
        x, *state2 = self.lstm2(x, initial_state=initial_states[1])
        x, *state3 = self.lstm3(x, initial_state=initial_states[2])
        return x, [state1, state2, state3]

In [None]:
def build_temporal_model(units, num_inputs, num_outputs):
    """Build a recurrent classifier over three consecutive time windows.

    The lookback, target and lookahead sequences are each summarized by their
    own RecurrentBlock. The final LSTM states of one block initialize the
    next, and the three summaries are concatenated and classified.

    Args:
        units: number of units in every LSTM.
        num_inputs: number of features per time step.
        num_outputs: number of output classes; a single output uses a sigmoid
            activation, more than one uses softmax.

    Returns:
        A tf.keras.Model taking [lookback, target, lookahead] sequences of
        shape (batch, time, num_inputs) and returning (batch, num_outputs).
    """
    lookback_input = tf.keras.layers.Input(shape=(None, num_inputs))
    target_input = tf.keras.layers.Input(shape=(None, num_inputs))
    lookahead_input = tf.keras.layers.Input(shape=(None, num_inputs))

    lookback, states = RecurrentBlock(units)(lookback_input)
    target, states = RecurrentBlock(units)(
        target_input, initial_states=states
    )
    lookahead, _ = RecurrentBlock(units)(
        lookahead_input, initial_states=states
    )

    # Each summary is (batch, units); join them on the feature axis.
    x = tf.keras.layers.Concatenate(axis=-1)([lookback, target, lookahead])

    x = tf.keras.layers.Dense(
        num_outputs,
        activation="softmax" if num_outputs > 1 else "sigmoid",
    )(x)

    model = tf.keras.Model(
        inputs=[lookback_input, target_input, lookahead_input],
        outputs=x,
    )
    return model

# Build a small temporal model (64 LSTM units, 16 features per time step,
# 3 classes) and draw its layer graph.
temporal_model = build_temporal_model(64, 16, 3)
tf.keras.utils.plot_model(temporal_model)

# Train Model

In [None]:
# Random stand-in data with the same per-sample shapes as the real inputs:
# two standard-normal image stacks and uniformly random one-hot labels.
# NOTE(review): nothing else in this notebook reads data_A, data_B or labels.
rng = np.random.default_rng()

size = 50
data_A = rng.normal(0, 1, (size, HEIGHT, WIDTH, NUM_INPUTS))
data_B = rng.normal(0, 1, (size, HEIGHT, WIDTH, NUM_INPUTS))

labels = tf.one_hot(rng.integers(0, NUM_OUTPUTS, (size, HEIGHT, WIDTH)), NUM_OUTPUTS)

In [None]:
# checkpoint to save progress during training and for easier loading of the
# model later on, but need to use model.save(...) for EEification
# NOTE(review): MODEL_DIR already ends with '/', so these paths contain '//'.
# The EEification cells build the path the same way, so they stay consistent.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    filepath=f"{MODEL_DIR}/test/checkpoint",
    save_weights_only=True,
)

# Labels are one-hot and the model ends in a softmax, hence categorical
# cross-entropy.
model.compile(
    loss=tf.keras.losses.categorical_crossentropy,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4)
)

# train_dataset repeats forever, so steps_per_epoch defines the epoch length.
model.fit(
    train_dataset,
    steps_per_epoch=100,
    epochs=10,
    callbacks=[checkpoint],
)
# model.load_weights(f"{MODEL_DIR}/test/checkpoint")
# Full SavedModel export; this directory is what the EEification cells read.
model.save(f"{MODEL_DIR}/test/full_model/", save_format='tf')

# Visualization


In [None]:
# Display colour and name of every class, indexed by class id.
class_colours = ["white", "black", "gold", "darkCyan", "darkOrange", "red",
                 "orchid", "purple", "cornsilk", "dimGrey"]
CLASS_LIST = ["None", "Non-Forest", "Forest", "Water", "Previous Burn",
              "Burn", "Previous Harvest", "Harvest", "Cloud", "Cloud Shadow"]
# ListedColormap's second positional parameter is the colormap *name*, not
# the number of colours, so the colour count is fixed by slicing the list
# instead of passing NUM_OUTPUTS positionally.
cmap = ListedColormap(class_colours[:NUM_OUTPUTS], name="disturbance_classes")
# Integer boundaries 0..NUM_OUTPUTS put class id k in colour bin k.
norm = BoundaryNorm(np.arange(NUM_OUTPUTS + 1), NUM_OUTPUTS)

In [None]:
# Run this cell to verify that the cropping does what we expect
def _plot_x(x, axes, i, j):
    x = tf.gather(x, (0, 1, 2), axis=-1)
    std = np.std(x)
    vmin = np.mean(x) - std
    vmax = np.mean(x) + std
    axes[i, j].imshow(x, vmin=vmin, vmax=vmax)

def _plot_y(y, axes, i, j):
    """Show the per-pixel argmax class of `y` on axes[i, j] in class colours."""
    class_ids = np.argmax(y, axis=-1)
    class_ids = np.squeeze(class_ids)
    panel = axes[i, j]
    panel.imshow(class_ids, cmap=cmap, norm=norm)

def crop_visualizer(pattern, count=10, deterministic_crop=False):
    """Plot exported patches next to the crops taken from them.

    Every row shows one patch. The left half of the row holds the image and
    its crop(s), the right half the label and its crop(s).

    Args:
        pattern: glob pattern of the GZIP-compressed TFRecord files to read.
        count: number of patches (rows) to plot.
        deterministic_crop: if True show all non-overlapping crops of each
            patch, otherwise show a single random crop per patch.
    """
    files = tf.data.Dataset.list_files(pattern, shuffle=False)
    raw_dataset = tf.data.TFRecordDataset(files, compression_type='GZIP')
    dataset = raw_dataset.map(parse)
    dataset = dataset.cache()

    size = 6

    if deterministic_crop:
        # Number of tiles non_overlapping_crop produces per patch.
        crops_per_image = (EXPORT_HEIGHT // HEIGHT) * (EXPORT_WIDTH // WIDTH)
        cropped_dataset = dataset.flat_map(non_overlapping_crop)
    else:
        crops_per_image = 1
        cropped_dataset = dataset.map(crop_wrapper)
    cropped_dataset = cropped_dataset.take(crops_per_image * count)

    # Each half of a row is: the full patch followed by its crops.
    group_width = 1 + crops_per_image
    num_cols = 2 * group_width
    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    fig, axes = plt.subplots(
        count,
        num_cols,
        figsize=(num_cols * size, count * size),
        squeeze=False,
    )

    def _image(x):
        # x is the tuple (image, *metadata) exactly when metadata is included.
        return x[0] if INCLUDE_METADATA else x

    for row, (x, y) in enumerate(dataset.take(count)):
        _plot_x(_image(x), axes, row, 0)
        _plot_y(y, axes, row, group_width)

    for k, (x, y) in enumerate(cropped_dataset):
        row, col = divmod(k, crops_per_image)
        _plot_x(_image(x), axes, row, 1 + col)
        _plot_y(y, axes, row, group_width + 1 + col)

# crop_visualizer(TRAIN_PATTERN, deterministic_crop=True)
# crop_visualizer(TRAIN_PATTERN, deterministic_crop=False)

In [None]:
def visualizer(dataset, model=None, count=10):
    """Plot inputs, labels and (optionally) predictions for a few samples.

    Each row shows, left to right: the historical image, the current image,
    the label and, when `model` is given, the model's prediction.

    Args:
        dataset: batched dataset of (x, y) elements as built by build_dataset.
        model: optional model called on each sample to get a prediction.
        count: number of samples (rows) to plot.
    """
    rgb_indices = [0, 1, 2]
    # The historical bands follow the NUM_INPUTS current bands in BANDS.
    historical_rgb_indices = [NUM_INPUTS + i for i in rgb_indices]
    has_historical = any("historical" in b for b in BANDS)

    data = dataset.unbatch()

    num = 3 if model is None else 4
    size = 10
    # squeeze=False keeps `axes` two-dimensional even when count == 1.
    fig, axes = plt.subplots(
        count, num, figsize=(num * size, count * size), squeeze=False
    )

    def stretch(image, num_std=0.5):
        """Rescale so mean +/- num_std standard deviations spans [0, 1].

        matplotlib ignores vmin/vmax for RGB input, so the contrast stretch
        has to be applied to the data itself.
        """
        image = np.asarray(image, dtype=np.float64)
        mean = image.mean()
        std = image.std()
        low = mean - (num_std * std)
        high = mean + (num_std * std)
        if high <= low:  # constant image: no contrast to stretch
            return np.zeros_like(image)
        return np.clip((image - low) / (high - low), 0.0, 1.0)

    def plot_row(x, hx, y, model_output, index):
        if hx is None:
            # No historical bands loaded: leave the first panel empty.
            axes[index, 0].axis("off")
        else:
            axes[index, 0].imshow(stretch(hx))
        axes[index, 1].imshow(stretch(x))
        y = np.argmax(y, axis=-1)
        axes[index, 2].imshow(y, cmap=cmap, norm=norm)
        if model_output is not None:
            model_output = np.squeeze(np.argmax(model_output, axis=-1))
            axes[index, 3].imshow(model_output, cmap=cmap, norm=norm)

    for i, (x, y) in enumerate(data.take(count)):
        # The model expects a batch axis, which unbatch() removed.
        if INCLUDE_METADATA:
            _x = tf.expand_dims(x[0], axis=0)
            metadata = [tf.expand_dims(m, axis=0) for m in x[1:]]
            _x = [_x, *metadata]
        else:
            _x = tf.expand_dims(x, axis=0)

        model_output = None if model is None else model(_x)

        if INCLUDE_METADATA:
            x = x[0]

        x_rgb = tf.gather(x, rgb_indices, axis=-1)
        hx_rgb = None
        if has_historical:
            hx_rgb = tf.gather(x, historical_rgb_indices, axis=-1)

        plot_row(x_rgb, hx_rgb, y, model_output, i)


# Plot 25 training samples without predictions (no model is passed).
visualizer(train_dataset, model=None, count=25)

# Assessment

In [None]:
def assess():
    """Placeholder for the model assessment code; currently does nothing."""
    return None

# EEification

See https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_TensorFlow_AI_Platform.ipynb

In [None]:
# Authenticate with and initialize the Earth Engine Python client.
import ee
ee.Authenticate()
ee.Initialize()

In [None]:
# Read the serving signature of the exported SavedModel to build the
# input/output name mappings passed to `earthengine model prepare` below.
meta_graph_def = saved_model_utils.get_meta_graph_def(f"{MODEL_DIR}/test/full_model/", 'serve')
inputs = meta_graph_def.signature_def['serving_default'].inputs
outputs = meta_graph_def.signature_def['serving_default'].outputs

# Rank of every input minus one (presumably discounting the batch axis).
input_dims = {k: len(v.tensor_shape.dim) - 1 for k, v in inputs.items()}
# The image is the input with three remaining dimensions; the metadata
# scalars are the inputs with one.
array_input = [k for k in input_dims.keys() if input_dims[k] == 3][0]
scalar_inputs = [k for k in input_dims.keys() if input_dims[k] == 1]
# Order the scalar inputs by the integer after the last '_' in their
# signature key so they pair up with METADATA by position.
# NOTE(review): assumes keys of the form '<prefix>_<int>' that were numbered
# in the same order as METADATA -- confirm against the saved signature.
scalar_inputs.sort(key=lambda name: int(name.split("_")[-1]))
input_dict = {inputs[k].name: METADATA[i] for i, k in enumerate(scalar_inputs)}
input_dict[inputs[array_input].name] = "array"
# Wrap the JSON in single quotes so it reaches the shell command as a single
# argument.
input_dict = "'" + json.dumps(input_dict) + "'"

# The loop leaves output_name at the last entry of the signature's outputs.
output_name = None
for k, v in outputs.items():
    output_name = v.name
output_dict = "'" + json.dumps({output_name: "class"}) + "'"

print(input_dict)
print(output_dict)

In [None]:
# Run `earthengine model prepare` on the exported SavedModel, writing the
# result to EEIFIED_DIR and passing the input/output name mappings that were
# built from the serving signature.
model_path = f"{MODEL_DIR}/test/full_model/"
!earthengine set_project {PROJECT_ID}
!earthengine model prepare \
    --source_dir {model_path} \
    --dest_dir {EEIFIED_DIR} \
    --input {input_dict} \
    --output {output_dict}

In [None]:
# Model creation is left commented out; the command below only creates a new
# version under the existing MODEL_NAME.
# !gcloud ai-platform models create {MODEL_NAME} \
#     --project {PROJECT_ID} \
#     --region {REGION}

# Create version VERSION_NAME of the model from the files in EEIFIED_DIR.
!gcloud ai-platform versions create {VERSION_NAME} \
    --project {PROJECT_ID} \
    --region {REGION} \
    --model {MODEL_NAME} \
    --origin {EEIFIED_DIR} \
    --framework "TENSORFLOW" \
    --runtime-version=2.11 \
    --python-version=3.7

# TODO
* set up assessment code
* grid creation
    * simplify grid creation code in data.py
    * move train/test/val split and large disturbance/small disturbance code to data.py
    * ensure that large/small disturbance code doesn't introduce duplicate cells
* __not enough disturbances in exported data__
* Add Digital Elevation Model band back to export
* Add index to distinguish new harvest from old harvest
    * red / ndvi
    * need way to prove/argue that this is a useful spectral index
* Add index to distinguish new burn scar from old burn scar
* More data augmentation
    * random flips
    * random rotations
    * random colour/brightness adjustments
* temporal model
    * write code
    * figure out how to export training data
* Try out larger node size in AI Platform to see if we can use a larger patch size
* Figure out how to run colab with a paid backend
