# 1. Presto to Vertex AI

<a target="_blank" href="https://colab.research.google.com/github/nasaharvest/presto/blob/initial-deploy-code/deploy/1_Presto_to_VertexAI.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

**Authors**: Ivan Zvonkov, Gabriel Tseng, (additional credits: [Earth_Engine_PyTorch_Vertex_AI](https://github.com/google/earthengine-community/blob/master/guides/linked/Earth_Engine_PyTorch_Vertex_AI.ipynb))

**Description**: This notebook deploys Presto to Vertex AI. This is a prerequisite to generating Presto embeddings on Google Earth Engine using
[ee.Model.fromVertexAi](https://developers.google.com/earth-engine/apidocs/ee-model-fromvertexai).

Once the model is deployed, this [GEE script](https://code.earthengine.google.com/df6348b8d47cd751eb5164dccb7b26a9) can be used to generate Presto embeddings.

**Steps**:
1. Set up environment
2. Load default Presto model
3. Transform Presto model into TorchScript
4. Package TorchScript model into TorchServe
5. Deploy and use Vertex AI

    5a. Upload TorchServe model to Vertex AI Model Registry [Free]

    5b. Create a Vertex AI Endpoint [Free]

    5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]

    5d. Generate embeddings in Google Earth Engine [Cost depends on region size]

    5e. Undeploy model from endpoint [Free]

**Cost Breakdown**:

*5a. Upload TorchServe model to Vertex AI Model Registry [Free]*
- Model files are uploaded to Cloud Storage, but they are lightweight (3.37 MB total) and easily fit within Google Cloud's 5 GB-month Cloud Storage [Free Tier](https://cloud.google.com/storage/pricing#cloud-storage-always-free)
- There is no cost to storing models in Vertex AI Model Registry ([source](https://cloud.google.com/vertex-ai/pricing#modelregistry))

*5b. Create a Vertex AI Endpoint [Free]*
- There is no cost to creating an endpoint. Costs start when a model is deployed to that endpoint

*5c. Deploy model to endpoint [Cost depends on Minimum Replica Count parameter]*
- The `Minimum Replica Count` is the minimum number of compute nodes kept running while a model is deployed; each node is an e2-standard-2 machine (\$0.0771/node hour in us-central1)
- So as long as the model is deployed, you pay at least \$0.0771/hour (with a Minimum Replica Count of 1), even if no predictions are made
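
As a rough worked example of the rate above, the sketch below multiplies hours by replicas by the quoted price. It only covers compute nodes and assumes every node is an e2-standard-2 at \$0.0771/node hour; `deployment_cost` is a hypothetical helper, not part of any Google API. The 100-replica case mirrors the `--max-replica-count="100"` used in the deploy command of step 5c.

```python
# Hypothetical helper: estimate Vertex AI node cost for a deployed model.
# Assumes every node is an e2-standard-2 at $0.0771 per node hour (us-central1)
# and ignores any other charges.
E2_STANDARD_2_HOURLY_USD = 0.0771


def deployment_cost(hours: float, replicas: int = 1) -> float:
    """Cost in USD of keeping `replicas` nodes running for `hours` hours."""
    return round(hours * replicas * E2_STANDARD_2_HOURLY_USD, 4)


# Idle cost with the minimum of one replica, even if no predictions are made
print(deployment_cost(24))        # one day
print(deployment_cost(24 * 30))   # roughly one month
# Upper bound per hour if autoscaling reaches 100 replicas
print(deployment_cost(1, replicas=100))
```

The monthly idle figure is the main reason to undeploy the model as soon as embeddings have been generated.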

*5d. Generate embeddings in Google Earth Engine [Cost depends on region size]*
- Once a model is deployed and `ee.Model.fromVertexAi` is used, Vertex AI scales the number of nodes based on the amount of data (size of the region)
- Our current embedding generation cost estimates are <strong>\$5.37 - \$10.14 per 1000 km<sup>2</sup> </strong>
- We compute a cost estimate for your region of interest (ROI) in our Google Earth Engine script
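
For a quick back-of-the-envelope number before opening the script, the quoted range can be scaled by area. This is a minimal sketch that assumes cost grows linearly with area, which is only an approximation; `embedding_cost_range` is a hypothetical helper, and the Earth Engine script's own estimate should take precedence.

```python
from typing import Tuple

# Quoted embedding cost range, in USD per 1000 km^2
LOW_USD_PER_1000_KM2 = 5.37
HIGH_USD_PER_1000_KM2 = 10.14


def embedding_cost_range(area_km2: float) -> Tuple[float, float]:
    """Return a (low, high) USD estimate for generating embeddings over area_km2.

    Assumes cost scales linearly with area (an approximation).
    """
    if area_km2 < 0:
        raise ValueError("area_km2 must be non-negative")
    scale = area_km2 / 1000
    return round(scale * LOW_USD_PER_1000_KM2, 2), round(scale * HIGH_USD_PER_1000_KM2, 2)


# Example: a 10,000 km^2 region of interest
low, high = embedding_cost_range(10000)
print(f"${low} - ${high}")
```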

*5e. Undeploy model from endpoint [Free]*
- Necessary to stop incurring charges from 5c

## 1. Set up environment

In [None]:
from google.colab import auth

auth.authenticate_user()

In [None]:
PROJECT = '<YOUR CLOUD PROJECT>'
!gcloud config set project {PROJECT}

In [None]:
!git clone https://github.com/nasaharvest/presto.git

## 2. Load default Presto model

In [None]:
# Navigate inside of the repository to import Presto
%cd /content/presto

import torch
from single_file_presto import Presto

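# The Vertex AI serving container used below (pytorch-cpu) is CPU-only,
# so load and trace the model on CPU to keep all tensors on the same device
device = torch.device("cpu")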

model = Presto.construct()
model.load_state_dict(torch.load("data/default_model.pt", map_location=device))
model.eval();

# Navigate back to main directory
%cd /content

## 3. Transform Presto model into TorchScript
> TorchScript is a way to create serializable and optimizable models from PyTorch code.
https://docs.pytorch.org/docs/stable/jit.html

In [None]:
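# Construct dummy input manually: 12 timesteps x 17 bands per pixel
batch_size = 256

X_tensor = torch.zeros([batch_size, 12, 17], device=device)
latlons_tensor = torch.zeros([batch_size, 2], device=device)

# 9 is the Dynamic World "no class available" value, so Presto ignores this input
dw_empty = torch.full([batch_size, 12], 9, device=device).long()
month_tensor = torch.full([batch_size], 1, device=device)

# Band order along the last dimension of X_tensor:
# [0   1   2   3   4   5   6   7   8   9    10   11   12    13      14    15     16  ]
# [VV, VH, B2, B3, B4, B5, B6, B7, B8, B8A, B11, B12, temp, precip, elev, slope, NDVI]
# mask == 0 everywhere: no input values are masked
mask = torch.zeros(X_tensor.shape, device=device).float()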
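
The band order in the comment above can also be kept as a lookup table, which avoids hard-coding positions such as the NDVI index. This is a minimal sketch; `PRESTO_BANDS` and `band_index` are hypothetical names, and the band names and order are copied from the comment in the cell above.

```python
# Band order along the last dimension of X_tensor, copied from the comment above
PRESTO_BANDS = [
    "VV", "VH", "B2", "B3", "B4", "B5", "B6", "B7", "B8", "B8A",
    "B11", "B12", "temp", "precip", "elev", "slope", "NDVI",
]
BAND_INDEX = {name: i for i, name in enumerate(PRESTO_BANDS)}


def band_index(name: str) -> int:
    """Return the position of a band in the last dimension of X_tensor."""
    try:
        return BAND_INDEX[name]
    except KeyError:
        raise ValueError(f"Unknown band {name!r}; expected one of {PRESTO_BANDS}") from None


assert len(PRESTO_BANDS) == 17  # matches the last dimension of X_tensor
print(band_index("NDVI"), band_index("slope"))  # 16 15
```

For example, `mask[:, :, band_index("NDVI")] = 1` would mark NDVI as masked at every timestep, since a mask value of 1 means the corresponding input is ignored.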

In [None]:
# Verify forward pass with regular model
with torch.no_grad():
    preds = model.encoder(
        x=X_tensor,
        dynamic_world=dw_empty,
        latlons=latlons_tensor,
        mask=mask,
        month=month_tensor
    )

In [None]:
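# Trace the encoder into TorchScript using the example inputs above.
# Tracing records the operations run on these example shapes, so the handler
# below keeps inputs at this batch size (256) by padding the last batch.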
example_kwargs = {
    'x': X_tensor,
    'dynamic_world': dw_empty,
    'latlons': latlons_tensor,
    'mask': mask,
    'month': month_tensor
}
sm = torch.jit.trace(model.encoder, example_kwarg_inputs=example_kwargs)

!mkdir -p pytorch_model
sm.save('pytorch_model/model.pt')

In [None]:
# Reload the saved TorchScript model and check that it produces an output
jit_model = torch.jit.load('pytorch_model/model.pt')
jit_model(**example_kwargs).shape

## 4. Package TorchScript model into TorchServe
> TorchServe is a performant, flexible and easy to use tool for serving PyTorch models in production.
https://docs.pytorch.org/serve/

In [None]:
!pip install torchserve torch-model-archiver -q

In [None]:
%%writefile pytorch_model/custom_handler.py
import logging
import torch
from ts.torch_handler.base_handler import BaseHandler
import numpy as np

# UPDATE BASED ON YOUR NEEDS
########################################
VERSION = "v1"
START_MONTH = 3
BATCH_SIZE = 256
########################################

def printh(text):
    # Prepends HANDLER to each print statement to make it easier to find in logs.
    print(f"HANDLER {VERSION}: {text}")

class ClassifierHandler(BaseHandler):

    def inference(self, data):
        printh("Inference begin")

        # Data shape: [ num_pixels, composite_bands, 1, 1 ]
        data = data[:, :, 0, 0]
        printh(f"Data shape {data.shape}")

        num_bands = 17
        printh(f"Num_bands {num_bands}")

        # The first two values per pixel are latitude and longitude; the rest are num_timesteps * num_bands values
        num_timesteps = (data.shape[1] - 2) // num_bands
        printh(f"Num_timesteps {num_timesteps}")

        with torch.no_grad():

            batches = torch.split(data, BATCH_SIZE, dim=0)

            # month: An int or torch.Tensor describing the first month of the instances being passed. If an int, all instances in the batch are assumed to have the same starting month.
            month_tensor = torch.full([BATCH_SIZE], START_MONTH, device=self.device)
            printh(f"Month: {START_MONTH}")

            # dynamic_world: torch.Tensor of shape [BATCH_SIZE, num_timesteps]. If no Dynamic World classes are available, this tensor should be filled with the value DynamicWorld2020_2021.class_amount (i.e. 9), in which case it is ignored.
            dw_empty = torch.full([BATCH_SIZE, num_timesteps], 9, device=self.device).long()
            printh(f"DW {dw_empty[0]}")

            # mask: An optional torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands]. mask[i, j, k] == 1 means x[i, j, k] is considered masked. If the mask is None, no values in x are ignored.
            mask = torch.zeros((BATCH_SIZE, num_timesteps, num_bands), device=self.device).float()
            printh(f"Mask sample one timestep: {mask[0, 0]}")

            preds_list = []
            for batch in batches:
                padding = 0
                if batch.shape[0] < BATCH_SIZE:
                    padding = BATCH_SIZE - batch.shape[0]
                    batch = torch.cat([batch, torch.zeros([padding, batch.shape[1]], device=self.device)])

                # x: torch.Tensor of shape [BATCH_SIZE, num_timesteps, bands] where bands is described by NORMED_BANDS.
                X_tensor = batch[:, 2:]
                printh(f"X {X_tensor.shape}")

                X_tensor_reshaped = X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)
                printh(f"X sample one timestep: {X_tensor_reshaped[0, 0]}")

                # latlons: torch.Tensor of shape [BATCH_SIZE, 2] describing the latitude and longitude of each input instance.
                latlons_tensor = batch[:, :2]

                printh("SHAPES")
                printh(f"X {X_tensor_reshaped.shape}")
                printh(f"DW {dw_empty.shape}")
                printh(f"Latlons {latlons_tensor.shape}")
                printh(f"Mask {mask.shape}")
                printh(f"Month {month_tensor.shape}")

                pred = self.model(
                    x=X_tensor_reshaped,
                    dynamic_world=dw_empty,
                    latlons=latlons_tensor,
                    mask=mask,
                    month=month_tensor
                )
                pred_np = np.expand_dims(pred.cpu().numpy(), axis=[1, 2])
                if padding == 0:
                    preds_list.append(pred_np[:])
                else:
                    preds_list.append(pred_np[:-padding])

        for p in preds_list:
            printh(f"{p.shape}")
        preds = np.concatenate(preds_list)
        printh(f"Preds shape {preds.shape}")
        return preds

    def handle(self, data, context):
        self.context = context
        printh(f"Handle begin")
        input_tensor = self.preprocess(data)
        printh(f"Input_tensor shape {input_tensor.shape}")
        pred_out = self.inference(input_tensor)
        printh(f"Inference complete")
        return self.postprocess(pred_out)
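
The batching loop in `inference` pads the last, short batch up to `BATCH_SIZE`, so every call to the model sees the batch size it was traced with, and then drops the padded rows from the output. The same pad-then-trim bookkeeping can be checked in isolation with plain Python. This is a sketch only: `run_in_fixed_batches` is a hypothetical helper and `fake_model` is a stand-in for the TorchScript encoder.

```python
from typing import Callable, List, Sequence


def run_in_fixed_batches(
    rows: Sequence[List[float]],
    batch_size: int,
    model: Callable[[List[List[float]]], List[List[float]]],
) -> List[List[float]]:
    """Apply `model` to fixed-size batches, zero-padding the last batch and trimming its outputs."""
    outputs: List[List[float]] = []
    for start in range(0, len(rows), batch_size):
        batch = [list(r) for r in rows[start:start + batch_size]]
        padding = batch_size - len(batch)
        if padding:
            # Same idea as torch.cat([batch, torch.zeros(...)]) in the handler
            batch += [[0.0] * len(batch[0]) for _ in range(padding)]
        preds = model(batch)
        assert len(preds) == batch_size  # the model always sees a full batch
        # Equivalent to pred_np[:-padding], but also correct when padding == 0
        outputs.extend(preds[: batch_size - padding])
    return outputs


def fake_model(batch: List[List[float]]) -> List[List[float]]:
    # Stand-in "embedding": the sum of each row
    return [[sum(row)] for row in batch]


pixels = [[float(i), 1.0] for i in range(713)]
embeddings = run_in_fixed_batches(pixels, 256, fake_model)
print(len(embeddings))  # 713
```

Slicing with `[:batch_size - padding]` avoids the special case in the handler, where `pred_np[:-0]` would return an empty array and therefore needs a separate `padding == 0` branch.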

In [None]:
import importlib
import pytorch_model.custom_handler

importlib.reload(pytorch_model.custom_handler)

from pytorch_model.custom_handler import ClassifierHandler, VERSION

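# Test the handler locally on dummy input: 713 pixels x 206 values
# (2 latlons + 12 timesteps * 17 bands), shaped [num_pixels, values, 1, 1]
data = torch.zeros([713, 206, 1, 1])
handler = ClassifierHandler()
handler.model = jit_model
preds = handler.handle(data=data, context=None)
assert len(preds) == 713  # one embedding per input pixel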
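
The dummy input uses 206 values per pixel because the handler expects the two coordinates first, followed by the 12 timesteps × 17 bands flattened one timestep after another; that layout is what reshaping `batch[:, 2:]` with `X_tensor.reshape(BATCH_SIZE, num_timesteps, num_bands)` implies. A small sketch of that index arithmetic, with hypothetical helpers `flat_index` and `num_timesteps_from_length`:

```python
NUM_BANDS = 17
NUM_LATLON = 2  # latitude and longitude come first in each per-pixel vector


def num_timesteps_from_length(vector_length: int, num_bands: int = NUM_BANDS) -> int:
    """Mirror of the handler's (data.shape[1] - 2) // num_bands, with a divisibility check."""
    n, rem = divmod(vector_length - NUM_LATLON, num_bands)
    if rem:
        raise ValueError(f"{vector_length} - {NUM_LATLON} is not a multiple of {num_bands}")
    return n


def flat_index(timestep: int, band: int, num_bands: int = NUM_BANDS) -> int:
    """Position of (timestep, band) in a vector laid out as [lat, lon, t0 bands..., t1 bands..., ...]."""
    return NUM_LATLON + timestep * num_bands + band


print(num_timesteps_from_length(206))   # 12
print(flat_index(0, 0), flat_index(11, 16))  # 2 205
```

Unlike the handler's integer division, the divisibility check here rejects a malformed composite instead of silently truncating it; whether that strictness is wanted in the deployed handler is a design choice.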

In [None]:
!torch-model-archiver -f \
  --model-name model \
  --version 1.0 \
  --serialized-file 'pytorch_model/model.pt' \
  --handler 'pytorch_model/custom_handler.py' \
  --export-path pytorch_model/

## 5. Deploy and use Vertex AI

### 5a. Upload TorchServe model to Vertex AI Model Registry
> The Vertex AI Model Registry is a central repository where you can manage the lifecycle of your ML models.
https://cloud.google.com/vertex-ai/docs/model-registry/introduction

In [None]:
REGION = 'us-central1'
BUCKET_NAME = "<YOUR CLOUD BUCKET>"

In [None]:
# Create a bucket to store model artifacts (if it already exists, the command reports an error that can be ignored)
!gcloud storage buckets create gs://{BUCKET_NAME} --location={REGION}

In [None]:
MODEL_DIR = f'gs://{BUCKET_NAME}/{VERSION}'
!gsutil cp -r pytorch_model {MODEL_DIR}

In [None]:
# Can take 2 minutes
MODEL_NAME = f'model_{VERSION}'
CONTAINER_IMAGE = 'us-docker.pkg.dev/vertex-ai/prediction/pytorch-cpu.2-4:latest'

!gcloud ai models upload \
  --artifact-uri={MODEL_DIR} \
  --region={REGION} \
  --container-image-uri={CONTAINER_IMAGE} \
  --description={MODEL_NAME} \
  --display-name={MODEL_NAME} \
  --model-id={MODEL_NAME}

### 5b. Create a Vertex AI Endpoint
> To deploy a model for online prediction, you need an endpoint.
https://cloud.google.com/vertex-ai/docs/predictions/choose-endpoint-type


In [None]:
ENDPOINT_NAME = 'vertex-pytorch-presto-endpoint'

endpoints = !gcloud ai endpoints list --region={REGION} --format='get(DISPLAY_NAME)'

if ENDPOINT_NAME in endpoints:
    print(f"Endpoint: '{ENDPOINT_NAME}' already exists, skipping endpoint creation.")
else:
    print(f"Endpoint: '{ENDPOINT_NAME}' does not exist, creating... (~3 minutes)")
    !gcloud ai endpoints create \
    --display-name={ENDPOINT_NAME} \
    --endpoint-id={ENDPOINT_NAME} \
    --region={REGION}

### 5c. Deploy model to endpoint
> Deploying a model associates physical resources with the model so that it can serve online predictions with low latency.
https://cloud.google.com/vertex-ai/docs/general/deployment

In [None]:
# Deploy the model to the endpoint. This starts at least one e2-standard-2 machine,
# which incurs charges until the model is undeployed (step 5e)
print("Track model deployment progress and prediction logs:")
print(f"https://console.cloud.google.com/vertex-ai/online-prediction/locations/{REGION}/endpoints/{ENDPOINT_NAME}?project={PROJECT}\n")

# For a large region, set --min-replica-count higher to avoid waiting for autoscaling
# Can take 4-27 minutes
# Relevant quota: "Custom model serving CPUs per region"
!gcloud ai endpoints deploy-model {ENDPOINT_NAME} \
    --region={REGION} \
    --model={MODEL_NAME} \
    --display-name={MODEL_NAME} \
    --machine-type="e2-standard-2" \
    --min-replica-count='1' \
    --max-replica-count="100"

### 5d. Generate embeddings in Google Earth Engine


In [None]:
GEE_SCRIPT_URL = "https://code.earthengine.google.com/c239905f788f67ecf0cee42753893d1c"
print(f"Open this script: {GEE_SCRIPT_URL}")
print("Use the string below as the ENDPOINT variable:")
print(f"projects/{PROJECT}/locations/{REGION}/endpoints/{ENDPOINT_NAME}")

### 5e. Undeploy model from endpoint

Once predictions are made, you must <strong>undeploy your model</strong> to stop incurring further charges.

This can be done using the code below or directly in the Google Cloud console.

In [None]:
import ast

def get_deployed_model():
    deployed_models = !gcloud ai endpoints describe {ENDPOINT_NAME} --region={REGION} --format 'get(deployedModels)'
    if deployed_models[1] == '':
        print("No models deployed")
        return None
    # Parse the printed dict with ast.literal_eval rather than eval
    model_id = ast.literal_eval(deployed_models[1])['id']
    print(f"Deployed model ID: {model_id}")
    return model_id

deployed_model_id = get_deployed_model()

In [None]:
!gcloud ai endpoints undeploy-model {ENDPOINT_NAME} --region={REGION} --deployed-model-id={deployed_model_id}

In [None]:
get_deployed_model()