# Presto in Earth Engine

**Authors**: Ivan Zvonkov, Gabriel Tseng

**Description**:
1. Loads default Presto model.
2. Deploys default model to Vertex AI.

Inspired by: https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb

**Running this demo may incur charges to your Google Cloud Account!**

# Set up

In [None]:
from google.colab import auth

import ee
import google

# REPLACE WITH YOUR CLOUD PROJECT!
PROJECT = 'presto-deployment'

# Authenticate the notebook.
auth.authenticate_user()

# Authenticate to Earth Engine.
credentials, _ = google.auth.default()
ee.Initialize(credentials, project=PROJECT, opt_url='https://earthengine-highvolume.googleapis.com')

# Set the gcloud project for Vertex AI deployment.
!gcloud config set project {PROJECT}

In [None]:
# Fetch the Presto repository; later cells import `single_file_presto` and
# load `data/default_model.pt` from inside it.
!git clone https://github.com/nasaharvest/presto.git

## 1. Load model

In [None]:
# Work from inside the cloned repository so that `single_file_presto` is
# importable and the relative path to the default weights resolves.
%cd /content/presto

In [None]:
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
import numpy as np

from single_file_presto import Presto


# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
# Read the default checkpoint shipped with the repository, remapping its
# tensors onto `device` while loading.
default_weights = torch.load("data/default_model.pt", map_location=device)

# Build the architecture, populate it with the checkpoint, and switch to
# inference mode. The trailing semicolon suppresses the notebook's echo of
# the module repr.
model = Presto.construct()
model.load_state_dict(default_weights)
model.eval();

In [None]:
# Sanity check

# from presto.eval import CropHarvestEval

# togo_eval = CropHarvestEval("Togo", ignore_dynamic_world=False, num_timesteps=12, seed=0)
# results = togo_eval.finetuning_results(model, model_modes=["Regression"])
# results

# batch_size = 8
# X_np, dw_np, latlons_np, y_np = togo_eval.dataset.as_array(num_samples=batch_size)
# month_np = np.array([togo_eval.dataset.start_month] * batch_size)

In [None]:
# Construct input manually
batch_size = 256

# Create every input on the device that actually holds the model's weights.
# `device` can be "cuda:0" on a GPU runtime, but the model is not moved to
# `device` anywhere in this notebook; mixing CPU and GPU tensors in a single
# forward pass raises a device-mismatch error.
input_device = next(model.parameters()).device

# x: [batch_size, num_timesteps=12, bands=17]; latlons: [batch_size, 2]
X_tensor = torch.zeros([batch_size, 12, 17], device=input_device)
latlons_tensor = torch.zeros([batch_size, 2], device=input_device)

# Dynamic World filled with 9 (the "no class available" value) and a start
# month of 1 for every instance.
dw_empty = torch.full([batch_size, 12], 9, device=input_device).long()
month_tensor = torch.full([batch_size], 1, device=input_device)

# [0   1   2   3   4   5   6   7   8   9    10   11   12    13      14    15     16  ]
# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]
# All-zero mask: no value of x is treated as masked.
mask = torch.zeros(X_tensor.shape, device=input_device).float()

In [None]:
# Run the encoder once on the dummy inputs; gradients are disabled because
# this is inference only. The bare `preds` at the end displays the result.
encoder_inputs = dict(
    x=X_tensor,
    dynamic_world=dw_empty,
    latlons=latlons_tensor,
    mask=mask,
    month=month_tensor,
)
with torch.no_grad():
    preds = model.encoder(**encoder_inputs)
preds

## 2. Deploy to Vertex AI

In [None]:
# Step back out of the cloned repository so the `pytorch_model` directory
# used below is created next to it rather than inside it.
%cd ..

In [None]:
!pip install torchserve torch-model-archiver -q
!mkdir pytorch_model

In [None]:
from ts.torch_handler.base_handler import BaseHandler

# Make model torchscriptable
example_kwargs = {
    'x': X_tensor,
    'dynamic_world': dw_empty,
    'latlons': latlons_tensor,
    'mask': mask,
    'month': month_tensor
}
sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)

!mkdir -p pytorch_model
sm.save('pytorch_model/model.pt')

In [None]:
# Load the serialized TorchScript module back from disk and run it on the
# example inputs as a round-trip check; the cell displays the output shape.
jit_model = torch.jit.load('pytorch_model/model.pt')
jit_output = jit_model(**example_kwargs)
jit_output.shape

In [None]:
%%writefile pytorch_model/custom_handler.py

import logging

import torch
from ts.torch_handler.base_handler import BaseHandler
import numpy as np
import sys

logger = logging.getLogger(__name__)
version = "v27"
batch_size = 256

def printh(text):
    """Print ``text`` to stdout, prefixed with the handler tag and version."""
    tagged = f"HANDLER {version}: {text}"
    print(tagged)

class ClassifierHandler(BaseHandler):
    """TorchServe handler that runs the model over flattened per-pixel inputs.

    Every input row holds two latlon values followed by
    ``num_timesteps * num_bands`` band values. Rows are processed in batches
    of exactly ``batch_size`` rows; a short final batch is zero-padded before
    the model call and the padded rows are dropped from the output.
    """

    # Number of band values per timestep in each input row.
    num_bands = 17
    # Value passed as the `month` input (first month of each instance).
    start_month = 3

    def inference(self, data):
        """Run the model on every row of ``data`` and return stacked outputs.

        Args:
            data: Tensor indexed as [num_pixels, composite_bands, 1, 1], where
                ``composite_bands == 2 + num_timesteps * num_bands``.

        Returns:
            A numpy array containing the model output for each input row, with
            two singleton axes inserted after the first axis.

        Raises:
            ValueError: If the number of features after the two latlon values
                is not a positive multiple of ``num_bands``.
        """
        printh("inference begin")

        # Data shape: [ num_pixels, composite_bands, 1, 1 ]
        data = data[:, :, 0, 0]
        printh(f"data shape {data.shape}")

        num_bands = self.num_bands
        printh(f"num_bands {num_bands}")

        # Subtract first two latlon
        num_features = data.shape[1] - 2
        if num_features <= 0 or num_features % num_bands != 0:
            # Fail early with a readable message instead of an opaque reshape
            # error further down.
            raise ValueError(
                f"expected 2 latlon values plus a positive multiple of {num_bands} "
                f"band values per row, got {data.shape[1]} values"
            )
        num_timesteps = num_features // num_bands
        printh(f"num_timesteps {num_timesteps}")

        # Allocate the auxiliary inputs wherever the input data lives so all
        # model arguments share one device.
        device = data.device

        with torch.no_grad():

            batches = torch.split(data, batch_size, dim=0)

            # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.
            month_tensor = torch.full([batch_size], self.start_month, device=device)
            printh(f"month: {self.start_month}")

            # dynamic_world: torch.Tensor of shape [batch_size, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.
            dw_empty = torch.full([batch_size, num_timesteps], 9, device=device).long()
            printh(f"dw {dw_empty[0]}")

            # mask: An optional torch.Tensor of shape [batch_size, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.
            mask = torch.zeros((batch_size, num_timesteps, num_bands), device=device).float()
            printh(f"mask sample one timestep: {mask[0, 0]}")

            preds = []
            for batch in batches:

                # Remember how many rows are real so the padding can be
                # trimmed from the output again.
                num_valid = batch.shape[0]
                if num_valid < batch_size:
                    # new_zeros matches the batch's device and dtype.
                    filler = batch.new_zeros((batch_size - num_valid, batch.shape[1]))
                    batch = torch.cat([batch, filler])

                # x: torch.Tensor of shape [batch_size, num_timesteps, bands] where bands is described by NORMED_BANDS.
                X_tensor = batch[:, 2:]
                printh(f"X {X_tensor.shape}")

                X_tensor_reshaped = X_tensor.reshape(batch_size, num_timesteps, num_bands)
                printh(f"X sample one timestep: {X_tensor_reshaped[0, 0]}")

                # latlons: torch.Tensor of shape [batch_size, 2] describing the latitude and longitude of each input instance.
                latlons_tensor = batch[:, :2]

                # Shapes
                printh("SHAPES")
                printh(f"X {X_tensor_reshaped.shape}")
                printh(f"dw {dw_empty.shape}")
                printh(f"latlons {latlons_tensor.shape}")
                printh(f"mask {mask.shape}")
                printh(f"month {month_tensor.shape}")

                pred = self.model(
                    x=X_tensor_reshaped,
                    dynamic_world=dw_empty,
                    latlons=latlons_tensor,
                    mask=mask,
                    month=month_tensor
                )
                # .cpu() is a no-op for CPU tensors and is required before
                # .numpy() when the model ran on another device.
                pred_np = np.expand_dims(pred.cpu().numpy(), axis=(1, 2))
                # Keep only the rows that correspond to real (unpadded) input.
                preds.append(pred_np[:num_valid])

        for p in preds:
            printh(f"{p.shape}")
        preds = np.concatenate(preds)
        printh(f"preds shape {preds.shape}")
        return preds

    def handle(self, data, context):
        """Entry point: preprocess ``data``, run inference, postprocess the result.

        Args:
            data: Raw request payload, passed to ``self.preprocess``.
            context: Serving context; stored on the handler as ``self.context``.

        Returns:
            Whatever ``self.postprocess`` returns for the inference output.
        """
        self.context = context
        printh("handle begin")
        input_tensor = self.preprocess(data)
        printh(f"input_tensor shape {input_tensor.shape}")
        pred_out = self.inference(input_tensor)
        return self.postprocess(pred_out)

In [None]:
import importlib
# Import the package and the handler module first so that
# `pytorch_model.custom_handler` exists as an attribute before it is reloaded.
import pytorch_model
from pytorch_model.custom_handler import ClassifierHandler
# Re-execute the module so edits written by the %%writefile cell are picked up
# without restarting the runtime.
importlib.reload(pytorch_model.custom_handler)

# Import again: the name bound before the reload still refers to the class
# object from the previous version of the module.
from pytorch_model.custom_handler import ClassifierHandler

# Test output
# 713 rows is not a multiple of the handler's batch size (256), so the padded
# final batch is exercised; 206 = 2 latlon values + 12 timesteps * 17 bands.
data = torch.zeros([713, 206, 1, 1])
handler = ClassifierHandler()
# Attach the TorchScript module loaded earlier directly to the handler.
handler.model = jit_model
preds = handler.handle(data=data, context=None)

In [None]:
# Package the traced model and the custom handler into a TorchServe model
# archive written to pytorch_model/. -f overwrites an existing archive.
!torch-model-archiver -f \
  --model-name model \
  --version 1.0 \
  --serialized-file 'pytorch_model/model.pt' \
  --handler 'pytorch_model/custom_handler.py' \
  --export-path pytorch_model/

In [None]:
version = "v27"
MODEL_DIR = f'gs://presto-models/default_v2025_04_10_{version}'
!gsutil cp -r pytorch_model {MODEL_DIR}

In [None]:
# Deployment settings: region, model name (derived from `version`), the
# serving container image, and the endpoint the model is deployed to.
REGION = 'us-central1'
MODEL_NAME = f'model_{version}'
CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'
ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'

# Register the artifacts stored at MODEL_DIR as a Vertex AI model served by
# CONTAINER_IMAGE. MODEL_NAME is used as id, display name and description.
!gcloud ai models upload \
  --artifact-uri={MODEL_DIR} \
  --project={PROJECT} \
  --region={REGION} \
  --container-image-uri={CONTAINER_IMAGE} \
  --description={MODEL_NAME} \
  --display-name={MODEL_NAME} \
  --model-id={MODEL_NAME}

# Create endpoint, if endpoint does not exist
# !gcloud ai endpoints create \
#   --display-name={ENDPOINT_NAME} \
#   --endpoint-id={ENDPOINT_NAME} \
#   --region={REGION} \
#   --project={PROJECT}

# Deploy the uploaded model to the endpoint on an e2-standard-4 machine.
!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \
  --project={PROJECT} \
  --region={REGION} \
  --model={MODEL_NAME} \
  --display-name={MODEL_NAME} \
  --machine-type="e2-standard-4"

# Observed run time: ~21 minutes when there are issues with the server,
# ~4 minutes when it's working